Fix unhandled exception when handshake times out (#601)

This commit is contained in:
J. Nick Koston 2023-10-23 12:32:20 -05:00 committed by GitHub
parent cdca073972
commit 1630816dc8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 282 additions and 19 deletions

View File

@ -62,7 +62,7 @@ class APIFrameHelper:
self._log_name = log_name
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
def _set_ready_future_exception(self, exc: Exception) -> None:
def _set_ready_future_exception(self, exc: Exception | type[Exception]) -> None:
if not self._ready_future.done():
self._ready_future.set_exception(exc)

View File

@ -502,17 +502,23 @@ class APIConnection:
await start_connect_task
except (Exception, CancelledError) as ex:
# If the task was cancelled, we need to clean up the connection
# and raise the CancelledError
# and raise the CancelledError as APIConnectionError
self._cleanup()
if isinstance(ex, CancelledError):
raise self._fatal_exception or APIConnectionError(
"Connection cancelled"
if not isinstance(ex, APIConnectionError):
cause: Exception | None = None
if isinstance(ex, CancelledError):
err_str = "Starting connection cancelled"
if self._fatal_exception:
err_str += f" due to fatal exception: {self._fatal_exception}"
cause = self._fatal_exception
else:
err_str = str(ex) or type(ex).__name__
new_exc = APIConnectionError(
f"Error while starting connection: {err_str}"
)
if not start_connect_task.cancelled() and (
task_exc := start_connect_task.exception()
):
raise task_exc
raise
new_exc.__cause__ = cause or ex
raise new_exc
raise ex
finally:
self._start_connect_task = None
self._set_connection_state(ConnectionState.SOCKET_OPENED)
@ -550,17 +556,24 @@ class APIConnection:
await self._finish_connect_task
except (Exception, CancelledError) as ex:
# If the task was cancelled, we need to clean up the connection
# and raise the CancelledError
# and raise the CancelledError as APIConnectionError
self._cleanup()
if isinstance(ex, CancelledError):
raise self._fatal_exception or APIConnectionError(
"Connection cancelled"
if not isinstance(ex, APIConnectionError):
cause: Exception | None = None
if isinstance(ex, CancelledError):
err_str = "Finishing connection cancelled"
if self._fatal_exception:
err_str += f" due to fatal exception: {self._fatal_exception}"
cause = self._fatal_exception
else:
err_str = str(ex) or type(ex).__name__
cause = ex
new_exc = APIConnectionError(
f"Error while finishing connection: {err_str}"
)
if not finish_connect_task.cancelled() and (
task_exc := finish_connect_task.exception()
):
raise task_exc
raise
new_exc.__cause__ = cause or ex
raise new_exc
raise ex
finally:
self._finish_connect_task = None
self._set_connection_state(ConnectionState.CONNECTED)

0
tests/__init__.py Normal file
View File

51
tests/common.py Normal file
View File

@ -0,0 +1,51 @@
from __future__ import annotations
import asyncio
import time
from datetime import datetime, timezone
from functools import partial
UTC = timezone.utc
_MONOTONIC_RESOLUTION = time.get_clock_info("monotonic").resolution
# We use a partial here since it is implemented in native code
# and avoids the global lookup of UTC
utcnow: partial[datetime] = partial(datetime.now, UTC)
utcnow.__doc__ = "Get now in UTC time."
def as_utc(dattim: datetime) -> datetime:
"""Return a datetime as UTC time."""
if dattim.tzinfo == UTC:
return dattim
return dattim.astimezone(UTC)
def async_fire_time_changed(
datetime_: datetime | None = None, fire_all: bool = False
) -> None:
"""Fire a time changed event at an exact microsecond.
Consider that it is not possible to actually achieve an exact
microsecond in production as the event loop is not precise enough.
If your code relies on this level of precision, consider a different
approach, as this is only for testing.
"""
loop = asyncio.get_running_loop()
if datetime_ is None:
utc_datetime = datetime.now(UTC)
else:
utc_datetime = as_utc(datetime_)
timestamp = utc_datetime.timestamp()
for task in list(loop._scheduled):
if not isinstance(task, asyncio.TimerHandle):
continue
if task.cancelled():
continue
mock_seconds_into_future = timestamp - time.time()
future_seconds = task.when() - (loop.time() + _MONOTONIC_RESOLUTION)
if fire_all or mock_seconds_into_future >= future_seconds:
task._run()
task.cancel()

View File

@ -1,7 +1,12 @@
from __future__ import annotations
import asyncio
from datetime import timedelta
from unittest.mock import MagicMock
import pytest
from aioesphomeapi import HandshakeAPIError
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS
from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint
@ -18,6 +23,8 @@ from aioesphomeapi.core import (
SocketAPIError,
)
from .common import async_fire_time_changed, utcnow
PREAMBLE = b"\x00"
@ -243,6 +250,43 @@ async def test_noise_incorrect_name():
await helper.perform_handshake(30)
@pytest.mark.asyncio
async def test_noise_timeout():
"""Test we raise on bad name."""
outgoing_packets = [
"010000", # hello packet
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
]
packets = []
def _packet(type_: int, data: bytes):
packets.append((type_, data))
def _on_error(exc: Exception):
raise exc
helper = MockAPINoiseFrameHelper(
on_pkt=_packet,
on_error=_on_error,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="wrongname",
client_info="my client",
log_name="test",
)
helper._transport = MagicMock()
helper._writer = MagicMock()
for pkt in outgoing_packets:
helper.mock_write_frame(bytes.fromhex(pkt))
task = asyncio.create_task(helper.perform_handshake(30))
await asyncio.sleep(0)
async_fire_time_changed(utcnow() + timedelta(seconds=60))
await asyncio.sleep(0)
with pytest.raises(HandshakeAPIError):
await task
VARUINT_TESTCASES = [
(0, b"\x00"),
(42, b"\x2a"),

View File

@ -1,16 +1,22 @@
from __future__ import annotations
import asyncio
import socket
from datetime import timedelta
from typing import Optional
import pytest
from mock import MagicMock, patch
from aioesphomeapi import APIConnectionError
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi.api_pb2 import DeviceInfoResponse, HelloResponse
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
from aioesphomeapi.core import RequiresEncryptionAPIError
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
from .common import async_fire_time_changed, utcnow
async def connect(conn: APIConnection, login: bool = True):
"""Wrapper for connection logic to do both parts."""
@ -180,3 +186,152 @@ async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_so
remove()
await conn.force_disconnect()
await asyncio.sleep(0)
@pytest.mark.asyncio
async def test_start_connection_socket_error(
conn: APIConnection, resolve_host, socket_socket
):
"""Test handling of socket error during start connection."""
loop = asyncio.get_event_loop()
with patch.object(loop, "create_connection", side_effect=OSError("Socket error")):
connect_task = asyncio.create_task(connect(conn, login=False))
await asyncio.sleep(0)
with pytest.raises(APIConnectionError, match="Socket error"):
await connect_task
async_fire_time_changed(utcnow() + timedelta(seconds=600))
await asyncio.sleep(0)
@pytest.mark.asyncio
async def test_start_connection_times_out(
conn: APIConnection, resolve_host, socket_socket
):
"""Test handling of start connection timing out."""
loop = asyncio.get_event_loop()
async def _create_mock_transport_protocol(create_func, **kwargs):
await asyncio.sleep(500)
with patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(connect(conn, login=False))
await asyncio.sleep(0)
async_fire_time_changed(utcnow() + timedelta(seconds=200))
await asyncio.sleep(0)
with pytest.raises(
APIConnectionError, match="Error while starting connection: TimeoutError"
):
await connect_task
async_fire_time_changed(utcnow() + timedelta(seconds=600))
await asyncio.sleep(0)
@pytest.mark.asyncio
async def test_start_connection_os_error(
conn: APIConnection, resolve_host, socket_socket
):
"""Test handling of start connection has an OSError."""
loop = asyncio.get_event_loop()
with patch.object(loop, "sock_connect", side_effect=OSError("Socket error")):
connect_task = asyncio.create_task(connect(conn, login=False))
await asyncio.sleep(0)
with pytest.raises(APIConnectionError, match="Socket error"):
await connect_task
async_fire_time_changed(utcnow() + timedelta(seconds=600))
await asyncio.sleep(0)
@pytest.mark.asyncio
async def test_start_connection_is_cancelled(
conn: APIConnection, resolve_host, socket_socket
):
"""Test handling of start connection is cancelled."""
loop = asyncio.get_event_loop()
with patch.object(loop, "sock_connect", side_effect=asyncio.CancelledError):
connect_task = asyncio.create_task(connect(conn, login=False))
await asyncio.sleep(0)
with pytest.raises(APIConnectionError, match="Starting connection cancelled"):
await connect_task
async_fire_time_changed(utcnow() + timedelta(seconds=600))
await asyncio.sleep(0)
@pytest.mark.asyncio
async def test_finish_connection_is_cancelled(
conn: APIConnection, resolve_host, socket_socket
):
"""Test handling of finishing connection being cancelled."""
loop = asyncio.get_event_loop()
with patch.object(loop, "create_connection", side_effect=asyncio.CancelledError):
connect_task = asyncio.create_task(connect(conn, login=False))
await asyncio.sleep(0)
with pytest.raises(APIConnectionError, match="Finishing connection cancelled"):
await connect_task
async_fire_time_changed(utcnow() + timedelta(seconds=600))
await asyncio.sleep(0)
@pytest.mark.asyncio
async def test_finish_connection_times_out(
conn: APIConnection, resolve_host, socket_socket
):
"""Test handling of finish connection timing out."""
loop = asyncio.get_event_loop()
protocol = _get_mock_protocol(conn)
messages = []
protocol: Optional[APIPlaintextFrameHelper] = None
transport = MagicMock()
connected = asyncio.Event()
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
def on_msg(msg):
messages.append(msg)
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
transport = MagicMock()
with patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()
protocol.data_received(
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
)
await asyncio.sleep(0)
async_fire_time_changed(utcnow() + timedelta(seconds=200))
await asyncio.sleep(0)
with pytest.raises(
APIConnectionError, match="Error while finishing connection: TimeoutError"
):
await connect_task
async_fire_time_changed(utcnow() + timedelta(seconds=600))
await asyncio.sleep(0)
assert not conn.is_connected
remove()
await conn.force_disconnect()
await asyncio.sleep(0)