diff --git a/aioesphomeapi/connection.pxd b/aioesphomeapi/connection.pxd index 1b393d8..9d420ac 100644 --- a/aioesphomeapi/connection.pxd +++ b/aioesphomeapi/connection.pxd @@ -72,7 +72,7 @@ cdef class APIConnection: cdef public APIFrameHelper _frame_helper cdef public object api_version cdef public object connection_state - cdef dict _message_handlers + cdef public dict _message_handlers cdef public str log_name cdef set _read_exception_futures cdef object _ping_timer @@ -81,7 +81,7 @@ cdef class APIConnection: cdef float _keep_alive_timeout cdef object _start_connect_task cdef object _finish_connect_task - cdef object _fatal_exception + cdef public Exception _fatal_exception cdef bint _expected_disconnect cdef object _loop cdef bint _send_pending_ping @@ -102,7 +102,7 @@ cdef class APIConnection: cpdef _async_schedule_keep_alive(self, object now) - cpdef _cleanup(self) + cdef _cleanup(self) cpdef set_log_name(self, str name) @@ -112,7 +112,7 @@ cdef class APIConnection: cdef _process_login_response(self, object hello_response) - cpdef _set_connection_state(self, object state) + cdef _set_connection_state(self, object state) cpdef report_fatal_error(self, Exception err) @@ -129,3 +129,5 @@ cdef class APIConnection: cpdef _handle_ping_request_internal(self, object msg) cpdef _handle_get_time_request_internal(self, object msg) + + cdef _set_fatal_exception_if_unset(self, Exception err) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index e074afd..521c2a3 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -120,8 +120,8 @@ class ConnectionState(enum.Enum): # The handshake has been completed, messages can be exchanged HANDSHAKE_COMPLETE = 2 # The connection has been established, authenticated data can be exchanged - CONNECTED = 2 - CLOSED = 3 + CONNECTED = 3 + CLOSED = 4 CONNECTION_STATE_INITIALIZED = ConnectionState.INITIALIZED @@ -593,7 +593,10 @@ class APIConnection: """Set the connection state and log the change.""" self.connection_state = state self.is_connected = state is CONNECTION_STATE_CONNECTED - self._handshake_complete = state is CONNECTION_STATE_HANDSHAKE_COMPLETE + self._handshake_complete = ( + state is CONNECTION_STATE_HANDSHAKE_COMPLETE + or state is CONNECTION_STATE_CONNECTED + ) def _make_connect_request(self) -> ConnectRequest: """Make a ConnectRequest.""" @@ -772,17 +775,28 @@ class APIConnection: The connection will be closed, all exception handlers notified. This method does not log the error, the call site should do so. """ - if self._expected_disconnect is False and not self._fatal_exception: - # Only log the first error - _LOGGER.warning( - "%s: Connection error occurred: %s", - self.log_name, - err or type(err), - exc_info=not str(err), # Log the full stack on empty error string - ) - self._fatal_exception = err + if not self._fatal_exception: + if self._expected_disconnect is False: + # Only log the first error + _LOGGER.warning( + "%s: Connection error occurred: %s", + self.log_name, + err or type(err), + exc_info=not str(err), # Log the full stack on empty error string + ) + + # Only set the first error since otherwise the original + # error will be lost (ie RequiresEncryptionAPIError) and than + # SocketClosedAPIError will be raised instead + self._set_fatal_exception_if_unset(err) + self._cleanup() + def _set_fatal_exception_if_unset(self, err: Exception) -> None: + """Set the fatal exception if it hasn't been set yet.""" + if self._fatal_exception is None: + self._fatal_exception = err + def process_packet(self, msg_type_proto: _int, data: _bytes) -> None: """Process an incoming packet.""" debug_enabled = self._debug_enabled @@ -889,8 +903,10 @@ class APIConnection: [self._finish_connect_task], timeout=DISCONNECT_CONNECT_TIMEOUT ) if pending: - self._fatal_exception = TimeoutAPIError( - "Timed out waiting to finish connect before disconnecting" + self._set_fatal_exception_if_unset( + TimeoutAPIError( + "Timed out waiting to finish connect before disconnecting" + ) ) if self._debug_enabled: _LOGGER.debug( @@ -931,15 +947,3 @@ class APIConnection: ) self._cleanup() - - def _get_message_handlers( - self, - ) -> dict[Any, set[Callable[[message.Message], None]]]: - """Get the message handlers. - - This function is only used for testing for leaks. - - It has to be bound to the real instance to work since - _message_handlers is not a public attribute. - """ - return self._message_handlers diff --git a/tests/test_client.py b/tests/test_client.py index ccf3be2..161b326 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1246,9 +1246,7 @@ async def test_bluetooth_gatt_start_notify( client, connection, transport, protocol = api_client notifies = [] - handlers_before = len( - list(itertools.chain(*connection._get_message_handlers().values())) - ) + handlers_before = len(list(itertools.chain(*connection._message_handlers.values()))) def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None: notifies.append((handle, data)) @@ -1280,7 +1278,7 @@ async def test_bluetooth_gatt_start_notify( await cancel_cb() assert ( - len(list(itertools.chain(*connection._get_message_handlers().values()))) + len(list(itertools.chain(*connection._message_handlers.values()))) == handlers_before ) # Ensure abort callback is a no-op after cancel @@ -1305,9 +1303,7 @@ async def test_bluetooth_gatt_start_notify_fails( def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None: notifies.append((handle, data)) - handlers_before = len( - list(itertools.chain(*connection._get_message_handlers().values())) - ) + handlers_before = len(list(itertools.chain(*connection._message_handlers.values()))) with patch.object( connection, @@ -1317,7 +1313,7 @@ async def test_bluetooth_gatt_start_notify_fails( await client.bluetooth_gatt_start_notify(1234, 1, on_bluetooth_gatt_notify) assert ( - len(list(itertools.chain(*connection._get_message_handlers().values()))) + len(list(itertools.chain(*connection._message_handlers.values()))) == handlers_before ) @@ -1769,9 +1765,7 @@ async def test_bluetooth_device_connect_cancelled( client, connection, transport, protocol = api_client states = [] - handlers_before = len( - list(itertools.chain(*connection._get_message_handlers().values())) - ) + handlers_before = len(list(itertools.chain(*connection._message_handlers.values()))) def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None: states.append((connected, mtu, error)) @@ -1795,9 +1789,7 @@ async def test_bluetooth_device_connect_cancelled( await connect_task assert states == [] - handlers_after = len( - list(itertools.chain(*connection._get_message_handlers().values())) - ) + handlers_after = len(list(itertools.chain(*connection._message_handlers.values()))) # Make sure we do not leak message handlers assert handlers_after == handlers_before diff --git a/tests/test_connection.py b/tests/test_connection.py index 3652b4b..4f95593 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -170,6 +170,14 @@ async def test_requires_encryption_propagates(conn: APIConnection): mock_data_received(protocol, b"\x01\x00\x00") await task + await asyncio.sleep(0) + await asyncio.sleep(0) + assert isinstance(conn._fatal_exception, RequiresEncryptionAPIError) + conn.force_disconnect() + assert isinstance(conn._fatal_exception, RequiresEncryptionAPIError) + conn.report_fatal_error(Exception("test")) + assert isinstance(conn._fatal_exception, RequiresEncryptionAPIError) + @pytest.mark.asyncio async def test_plaintext_connection( diff --git a/tests/test_reconnect_logic.py b/tests/test_reconnect_logic.py index 0b605a0..c9dbb9f 100644 --- a/tests/test_reconnect_logic.py +++ b/tests/test_reconnect_logic.py @@ -30,6 +30,7 @@ from aioesphomeapi.reconnect_logic import ( from .common import ( get_mock_async_zeroconf, get_mock_zeroconf, + mock_data_received, send_plaintext_connect_response, send_plaintext_hello, ) @@ -719,6 +720,8 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL protocol = cli._connection._frame_helper send_plaintext_hello(protocol) send_plaintext_connect_response(protocol, False) + await asyncio.sleep(0) + await asyncio.sleep(0) assert cli._connection.is_connected is True await asyncio.sleep(0) @@ -736,5 +739,71 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL assert mock_create_connection.call_count == 0 assert len(on_disconnect_calls) == 1 - assert on_disconnect_calls[0] is False + expected_disconnect = on_disconnect_calls[-1] + assert expected_disconnect is False + await logic.stop() + + +@pytest.mark.asyncio +async def test_backoff_on_encryption_error( + event_loop: asyncio.AbstractEventLoop, caplog: pytest.LogCaptureFixture +) -> None: + """Test we backoff on encryption error.""" + loop = asyncio.get_event_loop() + protocol: APIPlaintextFrameHelper | None = None + transport = MagicMock() + connected = asyncio.Event() + + class PatchableAPIClient(APIClient): + pass + + async_zeroconf = get_mock_async_zeroconf() + + cli = PatchableAPIClient( + address="1.2.3.4", + port=6052, + password=None, + noise_psk="", + expected_name="fake", + zeroconf_instance=async_zeroconf.zeroconf, + ) + + connected = asyncio.Event() + on_disconnect_calls = [] + + async def on_disconnect(expected_disconnect: bool) -> None: + on_disconnect_calls.append(expected_disconnect) + + async def on_connect() -> None: + connected.set() + + logic = ReconnectLogic( + client=cli, + on_connect=on_connect, + on_disconnect=on_disconnect, + zeroconf_instance=async_zeroconf, + name="fake", + ) + + with patch.object(event_loop, "sock_connect"), patch.object( + loop, + "create_connection", + side_effect=partial(_create_mock_transport_protocol, transport, connected), + ): + await logic.start() + await connected.wait() + protocol = cli._connection._frame_helper + mock_data_received(protocol, b"\x01\x00\x00") + + assert cli._connection.is_connected is False + await asyncio.sleep(0) + await asyncio.sleep(0) + + assert len(on_disconnect_calls) == 0 + + assert "Scheduling new connect attempt in 60.00 seconds" in caplog.text + assert "Connection requires encryption (RequiresEncryptionAPIError)" in caplog.text + now = loop.time() + assert logic._connect_timer.when() - now == pytest.approx(60, 1) + assert logic._tries == MAXIMUM_BACKOFF_TRIES await logic.stop()