diff --git a/aioesphomeapi/_frame_helper/base.py b/aioesphomeapi/_frame_helper/base.py index 2ad4b36..c405888 100644 --- a/aioesphomeapi/_frame_helper/base.py +++ b/aioesphomeapi/_frame_helper/base.py @@ -154,6 +154,7 @@ class APIFrameHelper: self.close() def _handle_error(self, exc: Exception) -> None: + self._set_ready_future_exception(exc) self._connection.report_fatal_error(exc) def connection_lost(self, exc: Exception | None) -> None: diff --git a/aioesphomeapi/_frame_helper/noise.py b/aioesphomeapi/_frame_helper/noise.py index 1416df5..75066ed 100644 --- a/aioesphomeapi/_frame_helper/noise.py +++ b/aioesphomeapi/_frame_helper/noise.py @@ -110,10 +110,6 @@ class APINoiseFrameHelper(APIFrameHelper): self._state = NOISE_STATE_CLOSED super().close() - def _handle_error_and_close(self, exc: Exception) -> None: - self._set_ready_future_exception(exc) - super()._handle_error_and_close(exc) - def _handle_error(self, exc: Exception) -> None: """Handle an error, and provide a good message when during hello.""" if self._state == NOISE_STATE_HELLO and isinstance(exc, ConnectionResetError): diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index 4956b54..3f699e7 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -621,6 +621,35 @@ async def test_init_noise_with_wrong_byte_marker(noise_conn: APIConnection) -> N await task +@pytest.mark.asyncio +async def test_init_noise_attempted_when_esp_uses_plaintext( + noise_conn: APIConnection, +) -> None: + loop = asyncio.get_event_loop() + transport = MagicMock() + protocol: APINoiseFrameHelper | None = None + + async def _create_connection(create, sock, *args, **kwargs): + nonlocal protocol + protocol = create() + protocol.connection_made(transport) + return transport, protocol + + with patch.object(loop, "create_connection", side_effect=_create_connection): + task = asyncio.create_task(noise_conn._connect_init_frame_helper()) + await asyncio.sleep(0) + + assert isinstance(noise_conn._frame_helper, APINoiseFrameHelper) + protocol = noise_conn._frame_helper + + protocol.connection_lost(ConnectionResetError()) + + with pytest.raises( + APIConnectionError, match="The connection dropped immediately" + ): + await task + + @pytest.mark.asyncio async def test_eof_received_closes_connection( plaintext_connect_task_with_login: tuple[