mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-22 16:48:04 +01:00
Fix unhandled exception when handshake times out (#601)
This commit is contained in:
parent
cdca073972
commit
1630816dc8
@ -62,7 +62,7 @@ class APIFrameHelper:
|
||||
self._log_name = log_name
|
||||
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
|
||||
|
||||
def _set_ready_future_exception(self, exc: Exception) -> None:
|
||||
def _set_ready_future_exception(self, exc: Exception | type[Exception]) -> None:
|
||||
if not self._ready_future.done():
|
||||
self._ready_future.set_exception(exc)
|
||||
|
||||
|
@ -502,17 +502,23 @@ class APIConnection:
|
||||
await start_connect_task
|
||||
except (Exception, CancelledError) as ex:
|
||||
# If the task was cancelled, we need to clean up the connection
|
||||
# and raise the CancelledError
|
||||
# and raise the CancelledError as APIConnectionError
|
||||
self._cleanup()
|
||||
if isinstance(ex, CancelledError):
|
||||
raise self._fatal_exception or APIConnectionError(
|
||||
"Connection cancelled"
|
||||
if not isinstance(ex, APIConnectionError):
|
||||
cause: Exception | None = None
|
||||
if isinstance(ex, CancelledError):
|
||||
err_str = "Starting connection cancelled"
|
||||
if self._fatal_exception:
|
||||
err_str += f" due to fatal exception: {self._fatal_exception}"
|
||||
cause = self._fatal_exception
|
||||
else:
|
||||
err_str = str(ex) or type(ex).__name__
|
||||
new_exc = APIConnectionError(
|
||||
f"Error while starting connection: {err_str}"
|
||||
)
|
||||
if not start_connect_task.cancelled() and (
|
||||
task_exc := start_connect_task.exception()
|
||||
):
|
||||
raise task_exc
|
||||
raise
|
||||
new_exc.__cause__ = cause or ex
|
||||
raise new_exc
|
||||
raise ex
|
||||
finally:
|
||||
self._start_connect_task = None
|
||||
self._set_connection_state(ConnectionState.SOCKET_OPENED)
|
||||
@ -550,17 +556,24 @@ class APIConnection:
|
||||
await self._finish_connect_task
|
||||
except (Exception, CancelledError) as ex:
|
||||
# If the task was cancelled, we need to clean up the connection
|
||||
# and raise the CancelledError
|
||||
# and raise the CancelledError as APIConnectionError
|
||||
self._cleanup()
|
||||
if isinstance(ex, CancelledError):
|
||||
raise self._fatal_exception or APIConnectionError(
|
||||
"Connection cancelled"
|
||||
if not isinstance(ex, APIConnectionError):
|
||||
cause: Exception | None = None
|
||||
if isinstance(ex, CancelledError):
|
||||
err_str = "Finishing connection cancelled"
|
||||
if self._fatal_exception:
|
||||
err_str += f" due to fatal exception: {self._fatal_exception}"
|
||||
cause = self._fatal_exception
|
||||
else:
|
||||
err_str = str(ex) or type(ex).__name__
|
||||
cause = ex
|
||||
new_exc = APIConnectionError(
|
||||
f"Error while finishing connection: {err_str}"
|
||||
)
|
||||
if not finish_connect_task.cancelled() and (
|
||||
task_exc := finish_connect_task.exception()
|
||||
):
|
||||
raise task_exc
|
||||
raise
|
||||
new_exc.__cause__ = cause or ex
|
||||
raise new_exc
|
||||
raise ex
|
||||
finally:
|
||||
self._finish_connect_task = None
|
||||
self._set_connection_state(ConnectionState.CONNECTED)
|
||||
|
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
51
tests/common.py
Normal file
51
tests/common.py
Normal file
@ -0,0 +1,51 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
|
||||
UTC = timezone.utc
|
||||
_MONOTONIC_RESOLUTION = time.get_clock_info("monotonic").resolution
|
||||
# We use a partial here since it is implemented in native code
|
||||
# and avoids the global lookup of UTC
|
||||
utcnow: partial[datetime] = partial(datetime.now, UTC)
|
||||
utcnow.__doc__ = "Get now in UTC time."
|
||||
|
||||
|
||||
def as_utc(dattim: datetime) -> datetime:
|
||||
"""Return a datetime as UTC time."""
|
||||
if dattim.tzinfo == UTC:
|
||||
return dattim
|
||||
return dattim.astimezone(UTC)
|
||||
|
||||
|
||||
def async_fire_time_changed(
|
||||
datetime_: datetime | None = None, fire_all: bool = False
|
||||
) -> None:
|
||||
"""Fire a time changed event at an exact microsecond.
|
||||
|
||||
Consider that it is not possible to actually achieve an exact
|
||||
microsecond in production as the event loop is not precise enough.
|
||||
If your code relies on this level of precision, consider a different
|
||||
approach, as this is only for testing.
|
||||
"""
|
||||
loop = asyncio.get_running_loop()
|
||||
if datetime_ is None:
|
||||
utc_datetime = datetime.now(UTC)
|
||||
else:
|
||||
utc_datetime = as_utc(datetime_)
|
||||
|
||||
timestamp = utc_datetime.timestamp()
|
||||
for task in list(loop._scheduled):
|
||||
if not isinstance(task, asyncio.TimerHandle):
|
||||
continue
|
||||
if task.cancelled():
|
||||
continue
|
||||
|
||||
mock_seconds_into_future = timestamp - time.time()
|
||||
future_seconds = task.when() - (loop.time() + _MONOTONIC_RESOLUTION)
|
||||
|
||||
if fire_all or mock_seconds_into_future >= future_seconds:
|
||||
task._run()
|
||||
task.cancel()
|
@ -1,7 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import timedelta
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from aioesphomeapi import HandshakeAPIError
|
||||
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||
from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS
|
||||
from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint
|
||||
@ -18,6 +23,8 @@ from aioesphomeapi.core import (
|
||||
SocketAPIError,
|
||||
)
|
||||
|
||||
from .common import async_fire_time_changed, utcnow
|
||||
|
||||
PREAMBLE = b"\x00"
|
||||
|
||||
|
||||
@ -243,6 +250,43 @@ async def test_noise_incorrect_name():
|
||||
await helper.perform_handshake(30)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noise_timeout():
|
||||
"""Test we raise on bad name."""
|
||||
outgoing_packets = [
|
||||
"010000", # hello packet
|
||||
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
|
||||
]
|
||||
packets = []
|
||||
|
||||
def _packet(type_: int, data: bytes):
|
||||
packets.append((type_, data))
|
||||
|
||||
def _on_error(exc: Exception):
|
||||
raise exc
|
||||
|
||||
helper = MockAPINoiseFrameHelper(
|
||||
on_pkt=_packet,
|
||||
on_error=_on_error,
|
||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||
expected_name="wrongname",
|
||||
client_info="my client",
|
||||
log_name="test",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
helper._writer = MagicMock()
|
||||
|
||||
for pkt in outgoing_packets:
|
||||
helper.mock_write_frame(bytes.fromhex(pkt))
|
||||
|
||||
task = asyncio.create_task(helper.perform_handshake(30))
|
||||
await asyncio.sleep(0)
|
||||
async_fire_time_changed(utcnow() + timedelta(seconds=60))
|
||||
await asyncio.sleep(0)
|
||||
with pytest.raises(HandshakeAPIError):
|
||||
await task
|
||||
|
||||
|
||||
VARUINT_TESTCASES = [
|
||||
(0, b"\x00"),
|
||||
(42, b"\x2a"),
|
||||
|
@ -1,16 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from mock import MagicMock, patch
|
||||
|
||||
from aioesphomeapi import APIConnectionError
|
||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||
from aioesphomeapi.api_pb2 import DeviceInfoResponse, HelloResponse
|
||||
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
|
||||
from aioesphomeapi.core import RequiresEncryptionAPIError
|
||||
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
|
||||
|
||||
from .common import async_fire_time_changed, utcnow
|
||||
|
||||
|
||||
async def connect(conn: APIConnection, login: bool = True):
|
||||
"""Wrapper for connection logic to do both parts."""
|
||||
@ -180,3 +186,152 @@ async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_so
|
||||
remove()
|
||||
await conn.force_disconnect()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_connection_socket_error(
|
||||
conn: APIConnection, resolve_host, socket_socket
|
||||
):
|
||||
"""Test handling of socket error during start connection."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
with patch.object(loop, "create_connection", side_effect=OSError("Socket error")):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await asyncio.sleep(0)
|
||||
with pytest.raises(APIConnectionError, match="Socket error"):
|
||||
await connect_task
|
||||
|
||||
async_fire_time_changed(utcnow() + timedelta(seconds=600))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_connection_times_out(
|
||||
conn: APIConnection, resolve_host, socket_socket
|
||||
):
|
||||
"""Test handling of start connection timing out."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
async def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
await asyncio.sleep(500)
|
||||
|
||||
with patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async_fire_time_changed(utcnow() + timedelta(seconds=200))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
with pytest.raises(
|
||||
APIConnectionError, match="Error while starting connection: TimeoutError"
|
||||
):
|
||||
await connect_task
|
||||
|
||||
async_fire_time_changed(utcnow() + timedelta(seconds=600))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_connection_os_error(
|
||||
conn: APIConnection, resolve_host, socket_socket
|
||||
):
|
||||
"""Test handling of start connection has an OSError."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
with patch.object(loop, "sock_connect", side_effect=OSError("Socket error")):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await asyncio.sleep(0)
|
||||
with pytest.raises(APIConnectionError, match="Socket error"):
|
||||
await connect_task
|
||||
|
||||
async_fire_time_changed(utcnow() + timedelta(seconds=600))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_connection_is_cancelled(
|
||||
conn: APIConnection, resolve_host, socket_socket
|
||||
):
|
||||
"""Test handling of start connection is cancelled."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
with patch.object(loop, "sock_connect", side_effect=asyncio.CancelledError):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await asyncio.sleep(0)
|
||||
with pytest.raises(APIConnectionError, match="Starting connection cancelled"):
|
||||
await connect_task
|
||||
|
||||
async_fire_time_changed(utcnow() + timedelta(seconds=600))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finish_connection_is_cancelled(
|
||||
conn: APIConnection, resolve_host, socket_socket
|
||||
):
|
||||
"""Test handling of finishing connection being cancelled."""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
with patch.object(loop, "create_connection", side_effect=asyncio.CancelledError):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await asyncio.sleep(0)
|
||||
with pytest.raises(APIConnectionError, match="Finishing connection cancelled"):
|
||||
await connect_task
|
||||
|
||||
async_fire_time_changed(utcnow() + timedelta(seconds=600))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_finish_connection_times_out(
|
||||
conn: APIConnection, resolve_host, socket_socket
|
||||
):
|
||||
"""Test handling of finish connection timing out."""
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol = _get_mock_protocol(conn)
|
||||
messages = []
|
||||
protocol: Optional[APIPlaintextFrameHelper] = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
def on_msg(msg):
|
||||
messages.append(msg)
|
||||
|
||||
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
|
||||
transport = MagicMock()
|
||||
|
||||
with patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
|
||||
protocol.data_received(
|
||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async_fire_time_changed(utcnow() + timedelta(seconds=200))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
with pytest.raises(
|
||||
APIConnectionError, match="Error while finishing connection: TimeoutError"
|
||||
):
|
||||
await connect_task
|
||||
|
||||
async_fire_time_changed(utcnow() + timedelta(seconds=600))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert not conn.is_connected
|
||||
remove()
|
||||
await conn.force_disconnect()
|
||||
await asyncio.sleep(0)
|
||||
|
Loading…
Reference in New Issue
Block a user