Fix handshake getting the wrong exception when the ESP drops the connection because its not using noise (#681)

This commit is contained in:
J. Nick Koston 2023-11-24 08:26:12 -06:00 committed by GitHub
parent 304379ff48
commit c21e32fda7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 30 additions and 4 deletions

View File

@ -154,6 +154,7 @@ class APIFrameHelper:
self.close()
def _handle_error(self, exc: Exception) -> None:
self._set_ready_future_exception(exc)
self._connection.report_fatal_error(exc)
def connection_lost(self, exc: Exception | None) -> None:

View File

@ -110,10 +110,6 @@ class APINoiseFrameHelper(APIFrameHelper):
self._state = NOISE_STATE_CLOSED
super().close()
def _handle_error_and_close(self, exc: Exception) -> None:
self._set_ready_future_exception(exc)
super()._handle_error_and_close(exc)
def _handle_error(self, exc: Exception) -> None:
"""Handle an error, and provide a good message when during hello."""
if self._state == NOISE_STATE_HELLO and isinstance(exc, ConnectionResetError):

View File

@ -621,6 +621,35 @@ async def test_init_noise_with_wrong_byte_marker(noise_conn: APIConnection) -> N
await task
@pytest.mark.asyncio
async def test_init_noise_attempted_when_esp_uses_plaintext(
noise_conn: APIConnection,
) -> None:
loop = asyncio.get_event_loop()
transport = MagicMock()
protocol: APINoiseFrameHelper | None = None
async def _create_connection(create, sock, *args, **kwargs):
nonlocal protocol
protocol = create()
protocol.connection_made(transport)
return transport, protocol
with patch.object(loop, "create_connection", side_effect=_create_connection):
task = asyncio.create_task(noise_conn._connect_init_frame_helper())
await asyncio.sleep(0)
assert isinstance(noise_conn._frame_helper, APINoiseFrameHelper)
protocol = noise_conn._frame_helper
protocol.connection_lost(ConnectionResetError())
with pytest.raises(
APIConnectionError, match="The connection dropped immediately"
):
await task
@pytest.mark.asyncio
async def test_eof_received_closes_connection(
plaintext_connect_task_with_login: tuple[