From 7a80e3529b767f7b0f91a625d1d9eebd7de1caaa Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 10 Jul 2023 21:15:14 -1000 Subject: [PATCH] Improve error reporting when encryption is disabled on device but client requests it (#464) --- aioesphomeapi/_frame_helper.py | 35 +++++++++++++++++++++++++--------- aioesphomeapi/connection.py | 2 ++ tests/test__frame_helper.py | 7 ++++++- tests/test_connection.py | 4 +++- 4 files changed, 37 insertions(+), 11 deletions(-) diff --git a/aioesphomeapi/_frame_helper.py b/aioesphomeapi/_frame_helper.py index c9452c1..ca9bac5 100644 --- a/aioesphomeapi/_frame_helper.py +++ b/aioesphomeapi/_frame_helper.py @@ -61,12 +61,14 @@ class APIFrameHelper(asyncio.Protocol): "_buffer", "_buffer_len", "_pos", + "_client_info", ) def __init__( self, on_pkt: Callable[[int, bytes], None], on_error: Callable[[Exception], None], + client_info: str, ) -> None: """Initialize the API frame helper.""" self._on_pkt = on_pkt @@ -76,6 +78,7 @@ class APIFrameHelper(asyncio.Protocol): self._buffer = bytearray() self._buffer_len = 0 self._pos = 0 + self._client_info = client_info def _read_exactly(self, length: int) -> Optional[bytearray]: """Read exactly length bytes from the buffer or None if all the bytes are not yet available.""" @@ -124,10 +127,9 @@ class APIPlaintextFrameHelper(APIFrameHelper): """Frame helper for plaintext API connections.""" def write_packet(self, type_: int, data: bytes) -> None: - """Write a packet to the socket, the caller should not have the lock. + """Write a packet to the socket. - The entire packet must be written in a single call to write - to avoid locking. + The entire packet must be written in a single call. """ assert self._transport is not None, "Transport should be set" data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data @@ -273,9 +275,10 @@ class APINoiseFrameHelper(APIFrameHelper): on_error: Callable[[Exception], None], noise_psk: str, expected_name: Optional[str], + client_info: str, ) -> None: """Initialize the API frame helper.""" - super().__init__(on_pkt, on_error) + super().__init__(on_pkt, on_error, client_info) self._ready_future = asyncio.get_event_loop().create_future() self._noise_psk = noise_psk self._expected_name = expected_name @@ -302,11 +305,25 @@ class APINoiseFrameHelper(APIFrameHelper): self._set_ready_future_exception(exc) super()._handle_error_and_close(exc) - def _write_frame(self, frame: bytes) -> None: - """Write a packet to the socket, the caller should not have the lock. + def _handle_error(self, exc: Exception) -> None: + """Handle an error, and provide a good message when during hello.""" + if ( + isinstance(exc, ConnectionResetError) + and self._state == NoiseConnectionState.HELLO + ): + original_exc = exc + exc = HandshakeAPIError( + "The connection dropped immediately after encrypted hello; " + "Try enabling encryption on the device or turning off " + f"encryption on the client ({self._client_info})." + ) + exc.__cause__ = original_exc + super()._handle_error(exc) - The entire packet must be written in a single call to write - to avoid locking. + def _write_frame(self, frame: bytes) -> None: + """Write a packet to the socket. + + The entire packet must be written in a single call to write. """ assert self._transport is not None, "Transport is not set" if _LOGGER.isEnabledFor(logging.DEBUG): @@ -371,7 +388,7 @@ class APINoiseFrameHelper(APIFrameHelper): self._write_frame(b"") # ClientHello def _handle_hello(self, server_hello: bytearray) -> None: - """Perform the handshake with the server, the caller is responsible for having the lock.""" + """Perform the handshake with the server.""" if not server_hello: self._handle_error_and_close(HandshakeAPIError("ServerHello is empty")) return diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 720dafa..5c8c573 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -301,6 +301,7 @@ class APIConnection: lambda: APIPlaintextFrameHelper( on_pkt=self._process_packet, on_error=self._report_fatal_error, + client_info=self._params.client_info, ), sock=self._socket, ) @@ -311,6 +312,7 @@ class APIConnection: expected_name=self._params.expected_name, on_pkt=self._process_packet, on_error=self._report_fatal_error, + client_info=self._params.client_info, ), sock=self._socket, ) diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index dac08d2..7b32900 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -68,7 +68,9 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type): def _on_error(exc: Exception): raise exc - helper = APIPlaintextFrameHelper(on_pkt=_packet, on_error=_on_error) + helper = APIPlaintextFrameHelper( + on_pkt=_packet, on_error=_on_error, client_info="my client" + ) helper.data_received(in_bytes) @@ -113,6 +115,7 @@ async def test_noise_frame_helper_incorrect_key(): on_error=_on_error, noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", expected_name="servicetest", + client_info="my client", ) helper._transport = MagicMock() @@ -151,6 +154,7 @@ async def test_noise_frame_helper_incorrect_key_fragments(): on_error=_on_error, noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", expected_name="servicetest", + client_info="my client", ) helper._transport = MagicMock() @@ -191,6 +195,7 @@ async def test_noise_incorrect_name(): on_error=_on_error, noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", expected_name="wrongname", + client_info="my client", ) helper._transport = MagicMock() diff --git a/tests/test_connection.py b/tests/test_connection.py index 1dacdcd..47db689 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -54,7 +54,9 @@ def socket_socket(): def _get_mock_protocol(conn: APIConnection): protocol = APIPlaintextFrameHelper( - on_pkt=conn._process_packet, on_error=conn._report_fatal_error + on_pkt=conn._process_packet, + on_error=conn._report_fatal_error, + client_info="mock", ) protocol._connected_event.set() protocol._transport = MagicMock()