Fix `eof_received` not raising SocketClosedAPIError (#651)

This commit is contained in:
J. Nick Koston 2023-11-21 14:56:31 +01:00 committed by GitHub
parent ccf2f1f245
commit f88b15e33b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 97 additions and 54 deletions

View File

@ -537,25 +537,34 @@ class APIConnection:
# If the task was cancelled, we need to clean up the connection # If the task was cancelled, we need to clean up the connection
# and raise the CancelledError as APIConnectionError # and raise the CancelledError as APIConnectionError
self._cleanup() self._cleanup()
if not isinstance(ex, APIConnectionError): raise self._wrap_fatal_connection_exception("starting", ex)
cause: Exception | None = None
if isinstance(ex, CancelledError):
err_str = "Starting connection cancelled"
if self._fatal_exception:
err_str += f" due to fatal exception: {self._fatal_exception}"
cause = self._fatal_exception
else:
err_str = str(ex) or type(ex).__name__
new_exc = APIConnectionError(
f"Error while starting connection: {err_str}"
)
new_exc.__cause__ = cause or ex
raise new_exc
raise ex
finally: finally:
self._start_connect_task = None self._start_connect_task = None
self._set_connection_state(ConnectionState.SOCKET_OPENED) self._set_connection_state(ConnectionState.SOCKET_OPENED)
def _wrap_fatal_connection_exception(
self, action: str, ex: BaseException
) -> APIConnectionError:
"""Ensure a fatal exception is wrapped as as an APIConnectionError."""
if isinstance(ex, APIConnectionError):
return ex
cause: BaseException | None = None
if isinstance(ex, CancelledError):
err_str = f"{action.title()} connection cancelled"
if self._fatal_exception:
err_str += f" due to fatal exception: {self._fatal_exception}"
cause = self._fatal_exception
else:
err_str = str(ex) or type(ex).__name__
cause = ex
if isinstance(self._fatal_exception, APIConnectionError):
klass = type(self._fatal_exception)
else:
klass = APIConnectionError
new_exc = klass(f"Error while {action} connection: {err_str}")
new_exc.__cause__ = cause or ex
return new_exc
async def _do_finish_connect(self, login: bool) -> None: async def _do_finish_connect(self, login: bool) -> None:
"""Finish the connection process.""" """Finish the connection process."""
in_do_connect.set(True) in_do_connect.set(True)
@ -585,22 +594,7 @@ class APIConnection:
# If the task was cancelled, we need to clean up the connection # If the task was cancelled, we need to clean up the connection
# and raise the CancelledError as APIConnectionError # and raise the CancelledError as APIConnectionError
self._cleanup() self._cleanup()
if not isinstance(ex, APIConnectionError): raise self._wrap_fatal_connection_exception("finishing", ex)
cause: Exception | None = None
if isinstance(ex, CancelledError):
err_str = "Finishing connection cancelled"
if self._fatal_exception:
err_str += f" due to fatal exception: {self._fatal_exception}"
cause = self._fatal_exception
else:
err_str = str(ex) or type(ex).__name__
cause = ex
new_exc = APIConnectionError(
f"Error while finishing connection: {err_str}"
)
new_exc.__cause__ = cause or ex
raise new_exc
raise ex
finally: finally:
self._finish_connect_task = None self._finish_connect_task = None
self._set_connection_state(ConnectionState.CONNECTED) self._set_connection_state(ConnectionState.CONNECTED)

View File

@ -117,3 +117,14 @@ def send_plaintext_connect_response(
def send_ping_response(protocol: APIPlaintextFrameHelper) -> None: def send_ping_response(protocol: APIPlaintextFrameHelper) -> None:
ping_response: message.Message = PingResponse() ping_response: message.Message = PingResponse()
protocol.data_received(generate_plaintext_packet(ping_response)) protocol.data_received(generate_plaintext_packet(ping_response))
def get_mock_protocol(conn: APIConnection):
protocol = APIPlaintextFrameHelper(
connection=conn,
client_info="mock",
log_name="mock_device",
)
transport = MagicMock()
protocol.connection_made(transport)
return protocol

View File

@ -4,7 +4,7 @@ import asyncio
import base64 import base64
from datetime import timedelta from datetime import timedelta
from typing import Any from typing import Any
from unittest.mock import MagicMock from unittest.mock import MagicMock, patch
import pytest import pytest
from noise.connection import NoiseConnection # type: ignore[import-untyped] from noise.connection import NoiseConnection # type: ignore[import-untyped]
@ -20,16 +20,17 @@ from aioesphomeapi._frame_helper.plain_text import (
_cached_varuint_to_bytes as cached_varuint_to_bytes, _cached_varuint_to_bytes as cached_varuint_to_bytes,
) )
from aioesphomeapi._frame_helper.plain_text import _varuint_to_bytes as varuint_to_bytes from aioesphomeapi._frame_helper.plain_text import _varuint_to_bytes as varuint_to_bytes
from aioesphomeapi.connection import ConnectionState
from aioesphomeapi.core import ( from aioesphomeapi.core import (
APIConnectionError,
BadNameAPIError, BadNameAPIError,
HandshakeAPIError, HandshakeAPIError,
InvalidEncryptionKeyAPIError, InvalidEncryptionKeyAPIError,
ProtocolAPIError, ProtocolAPIError,
SocketAPIError,
SocketClosedAPIError, SocketClosedAPIError,
) )
from .common import async_fire_time_changed, utcnow from .common import async_fire_time_changed, get_mock_protocol, utcnow
PREAMBLE = b"\x00" PREAMBLE = b"\x00"
@ -385,6 +386,10 @@ def test_bytes_to_varuint(val, encoded):
assert cached_bytes_to_varuint(encoded) == val assert cached_bytes_to_varuint(encoded) == val
def test_bytes_to_varuint_invalid():
assert bytes_to_varuint(b"\xFF") is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_noise_frame_helper_handshake_failure(): async def test_noise_frame_helper_handshake_failure():
"""Test the noise frame helper handshake failure.""" """Test the noise frame helper handshake failure."""
@ -568,3 +573,52 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
with pytest.raises(ProtocolAPIError, match="Connection closed"): with pytest.raises(ProtocolAPIError, match="Connection closed"):
helper.data_received(encrypted_header + encrypted_payload) helper.data_received(encrypted_header + encrypted_payload)
@pytest.mark.asyncio
async def test_init_plaintext_with_wrong_preamble(conn: APIConnection):
loop = asyncio.get_event_loop()
protocol = get_mock_protocol(conn)
with patch.object(loop, "create_connection") as create_connection:
create_connection.return_value = (MagicMock(), protocol)
conn._socket = MagicMock()
await conn._connect_init_frame_helper()
loop.call_soon(conn._frame_helper._ready_future.set_result, None)
conn.connection_state = ConnectionState.CONNECTED
task = asyncio.create_task(conn._connect_hello_login(login=True))
await asyncio.sleep(0)
# The preamble should be \x00 but we send \x09
protocol.data_received(b"\x09\x00\x00")
with pytest.raises(ProtocolAPIError):
await task
@pytest.mark.asyncio
async def test_eof_received_closes_connection(
plaintext_connect_task_with_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
assert protocol.eof_received() is False
assert conn.is_connected is False
with pytest.raises(SocketClosedAPIError, match="EOF received"):
await connect_task
@pytest.mark.asyncio
async def test_connection_lost_closes_connection_and_logs(
caplog: pytest.LogCaptureFixture,
plaintext_connect_task_with_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
protocol.connection_lost(OSError("original message"))
assert conn.is_connected is False
assert "original message" in caplog.text
with pytest.raises(APIConnectionError, match="original message"):
await connect_task

View File

@ -58,12 +58,7 @@ from aioesphomeapi.model import (
) )
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
from .common import ( from .common import Estr, generate_plaintext_packet, get_mock_zeroconf
PROTO_TO_MESSAGE_TYPE,
Estr,
generate_plaintext_packet,
get_mock_zeroconf,
)
@pytest.fixture @pytest.fixture

View File

@ -31,6 +31,7 @@ from .common import (
async_fire_time_changed, async_fire_time_changed,
connect, connect,
generate_plaintext_packet, generate_plaintext_packet,
get_mock_protocol,
send_ping_response, send_ping_response,
send_plaintext_connect_response, send_plaintext_connect_response,
send_plaintext_hello, send_plaintext_hello,
@ -41,17 +42,6 @@ from .conftest import KEEP_ALIVE_INTERVAL
KEEP_ALIVE_TIMEOUT_RATIO = 4.5 KEEP_ALIVE_TIMEOUT_RATIO = 4.5
def _get_mock_protocol(conn: APIConnection):
protocol = APIPlaintextFrameHelper(
connection=conn,
client_info="mock",
log_name="mock_device",
)
transport = MagicMock()
protocol.connection_made(transport)
return protocol
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_connect( async def test_connect(
plaintext_connect_task_no_login: tuple[ plaintext_connect_task_no_login: tuple[
@ -152,7 +142,7 @@ async def test_disconnect_when_not_fully_connected(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_requires_encryption_propagates(conn: APIConnection): async def test_requires_encryption_propagates(conn: APIConnection):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
protocol = _get_mock_protocol(conn) protocol = get_mock_protocol(conn)
with patch.object(loop, "create_connection") as create_connection: with patch.object(loop, "create_connection") as create_connection:
create_connection.return_value = (MagicMock(), protocol) create_connection.return_value = (MagicMock(), protocol)
@ -357,7 +347,7 @@ async def test_plaintext_connection_fails_handshake(
""" """
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
exception, raised_exception = exception_map exception, raised_exception = exception_map
protocol = _get_mock_protocol(conn) protocol = get_mock_protocol(conn)
messages = [] messages = []
protocol: APIPlaintextFrameHelper | None = None protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock() transport = MagicMock()

View File

@ -14,7 +14,6 @@ from aioesphomeapi.connection import APIConnection
from aioesphomeapi.log_runner import async_run from aioesphomeapi.log_runner import async_run
from .common import ( from .common import (
PROTO_TO_MESSAGE_TYPE,
Estr, Estr,
generate_plaintext_packet, generate_plaintext_packet,
get_mock_async_zeroconf, get_mock_async_zeroconf,