Fix not backing off when connection requires encryption (#762)

This commit is contained in:
J. Nick Koston 2023-11-27 18:39:22 -06:00 committed by GitHub
parent 31c6e4abc6
commit ab5834ca0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 120 additions and 45 deletions

View File

@ -72,7 +72,7 @@ cdef class APIConnection:
cdef public APIFrameHelper _frame_helper
cdef public object api_version
cdef public object connection_state
cdef dict _message_handlers
cdef public dict _message_handlers
cdef public str log_name
cdef set _read_exception_futures
cdef object _ping_timer
@ -81,7 +81,7 @@ cdef class APIConnection:
cdef float _keep_alive_timeout
cdef object _start_connect_task
cdef object _finish_connect_task
cdef object _fatal_exception
cdef public Exception _fatal_exception
cdef bint _expected_disconnect
cdef object _loop
cdef bint _send_pending_ping
@ -102,7 +102,7 @@ cdef class APIConnection:
cpdef _async_schedule_keep_alive(self, object now)
cpdef _cleanup(self)
cdef _cleanup(self)
cpdef set_log_name(self, str name)
@ -112,7 +112,7 @@ cdef class APIConnection:
cdef _process_login_response(self, object hello_response)
cpdef _set_connection_state(self, object state)
cdef _set_connection_state(self, object state)
cpdef report_fatal_error(self, Exception err)
@ -129,3 +129,5 @@ cdef class APIConnection:
cpdef _handle_ping_request_internal(self, object msg)
cpdef _handle_get_time_request_internal(self, object msg)
cdef _set_fatal_exception_if_unset(self, Exception err)

View File

@ -120,8 +120,8 @@ class ConnectionState(enum.Enum):
# The handshake has been completed, messages can be exchanged
HANDSHAKE_COMPLETE = 2
# The connection has been established, authenticated data can be exchanged
CONNECTED = 2
CLOSED = 3
CONNECTED = 3
CLOSED = 4
CONNECTION_STATE_INITIALIZED = ConnectionState.INITIALIZED
@ -593,7 +593,10 @@ class APIConnection:
"""Set the connection state and log the change."""
self.connection_state = state
self.is_connected = state is CONNECTION_STATE_CONNECTED
self._handshake_complete = state is CONNECTION_STATE_HANDSHAKE_COMPLETE
self._handshake_complete = (
state is CONNECTION_STATE_HANDSHAKE_COMPLETE
or state is CONNECTION_STATE_CONNECTED
)
def _make_connect_request(self) -> ConnectRequest:
"""Make a ConnectRequest."""
@ -772,17 +775,28 @@ class APIConnection:
The connection will be closed, all exception handlers notified.
This method does not log the error, the call site should do so.
"""
if self._expected_disconnect is False and not self._fatal_exception:
# Only log the first error
_LOGGER.warning(
"%s: Connection error occurred: %s",
self.log_name,
err or type(err),
exc_info=not str(err), # Log the full stack on empty error string
)
self._fatal_exception = err
if not self._fatal_exception:
if self._expected_disconnect is False:
# Only log the first error
_LOGGER.warning(
"%s: Connection error occurred: %s",
self.log_name,
err or type(err),
exc_info=not str(err), # Log the full stack on empty error string
)
# Only set the first error since otherwise the original
# error will be lost (ie RequiresEncryptionAPIError) and than
# SocketClosedAPIError will be raised instead
self._set_fatal_exception_if_unset(err)
self._cleanup()
def _set_fatal_exception_if_unset(self, err: Exception) -> None:
"""Set the fatal exception if it hasn't been set yet."""
if self._fatal_exception is None:
self._fatal_exception = err
def process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
"""Process an incoming packet."""
debug_enabled = self._debug_enabled
@ -889,8 +903,10 @@ class APIConnection:
[self._finish_connect_task], timeout=DISCONNECT_CONNECT_TIMEOUT
)
if pending:
self._fatal_exception = TimeoutAPIError(
"Timed out waiting to finish connect before disconnecting"
self._set_fatal_exception_if_unset(
TimeoutAPIError(
"Timed out waiting to finish connect before disconnecting"
)
)
if self._debug_enabled:
_LOGGER.debug(
@ -931,15 +947,3 @@ class APIConnection:
)
self._cleanup()
def _get_message_handlers(
self,
) -> dict[Any, set[Callable[[message.Message], None]]]:
"""Get the message handlers.
This function is only used for testing for leaks.
It has to be bound to the real instance to work since
_message_handlers is not a public attribute.
"""
return self._message_handlers

View File

@ -1246,9 +1246,7 @@ async def test_bluetooth_gatt_start_notify(
client, connection, transport, protocol = api_client
notifies = []
handlers_before = len(
list(itertools.chain(*connection._get_message_handlers().values()))
)
handlers_before = len(list(itertools.chain(*connection._message_handlers.values())))
def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None:
notifies.append((handle, data))
@ -1280,7 +1278,7 @@ async def test_bluetooth_gatt_start_notify(
await cancel_cb()
assert (
len(list(itertools.chain(*connection._get_message_handlers().values())))
len(list(itertools.chain(*connection._message_handlers.values())))
== handlers_before
)
# Ensure abort callback is a no-op after cancel
@ -1305,9 +1303,7 @@ async def test_bluetooth_gatt_start_notify_fails(
def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None:
notifies.append((handle, data))
handlers_before = len(
list(itertools.chain(*connection._get_message_handlers().values()))
)
handlers_before = len(list(itertools.chain(*connection._message_handlers.values())))
with patch.object(
connection,
@ -1317,7 +1313,7 @@ async def test_bluetooth_gatt_start_notify_fails(
await client.bluetooth_gatt_start_notify(1234, 1, on_bluetooth_gatt_notify)
assert (
len(list(itertools.chain(*connection._get_message_handlers().values())))
len(list(itertools.chain(*connection._message_handlers.values())))
== handlers_before
)
@ -1769,9 +1765,7 @@ async def test_bluetooth_device_connect_cancelled(
client, connection, transport, protocol = api_client
states = []
handlers_before = len(
list(itertools.chain(*connection._get_message_handlers().values()))
)
handlers_before = len(list(itertools.chain(*connection._message_handlers.values())))
def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None:
states.append((connected, mtu, error))
@ -1795,9 +1789,7 @@ async def test_bluetooth_device_connect_cancelled(
await connect_task
assert states == []
handlers_after = len(
list(itertools.chain(*connection._get_message_handlers().values()))
)
handlers_after = len(list(itertools.chain(*connection._message_handlers.values())))
# Make sure we do not leak message handlers
assert handlers_after == handlers_before

View File

@ -170,6 +170,14 @@ async def test_requires_encryption_propagates(conn: APIConnection):
mock_data_received(protocol, b"\x01\x00\x00")
await task
await asyncio.sleep(0)
await asyncio.sleep(0)
assert isinstance(conn._fatal_exception, RequiresEncryptionAPIError)
conn.force_disconnect()
assert isinstance(conn._fatal_exception, RequiresEncryptionAPIError)
conn.report_fatal_error(Exception("test"))
assert isinstance(conn._fatal_exception, RequiresEncryptionAPIError)
@pytest.mark.asyncio
async def test_plaintext_connection(

View File

@ -30,6 +30,7 @@ from aioesphomeapi.reconnect_logic import (
from .common import (
get_mock_async_zeroconf,
get_mock_zeroconf,
mock_data_received,
send_plaintext_connect_response,
send_plaintext_hello,
)
@ -719,6 +720,8 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL
protocol = cli._connection._frame_helper
send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False)
await asyncio.sleep(0)
await asyncio.sleep(0)
assert cli._connection.is_connected is True
await asyncio.sleep(0)
@ -736,5 +739,71 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL
assert mock_create_connection.call_count == 0
assert len(on_disconnect_calls) == 1
assert on_disconnect_calls[0] is False
expected_disconnect = on_disconnect_calls[-1]
assert expected_disconnect is False
await logic.stop()
@pytest.mark.asyncio
async def test_backoff_on_encryption_error(
event_loop: asyncio.AbstractEventLoop, caplog: pytest.LogCaptureFixture
) -> None:
"""Test we backoff on encryption error."""
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
connected = asyncio.Event()
class PatchableAPIClient(APIClient):
pass
async_zeroconf = get_mock_async_zeroconf()
cli = PatchableAPIClient(
address="1.2.3.4",
port=6052,
password=None,
noise_psk="",
expected_name="fake",
zeroconf_instance=async_zeroconf.zeroconf,
)
connected = asyncio.Event()
on_disconnect_calls = []
async def on_disconnect(expected_disconnect: bool) -> None:
on_disconnect_calls.append(expected_disconnect)
async def on_connect() -> None:
connected.set()
logic = ReconnectLogic(
client=cli,
on_connect=on_connect,
on_disconnect=on_disconnect,
zeroconf_instance=async_zeroconf,
name="fake",
)
with patch.object(event_loop, "sock_connect"), patch.object(
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
):
await logic.start()
await connected.wait()
protocol = cli._connection._frame_helper
mock_data_received(protocol, b"\x01\x00\x00")
assert cli._connection.is_connected is False
await asyncio.sleep(0)
await asyncio.sleep(0)
assert len(on_disconnect_calls) == 0
assert "Scheduling new connect attempt in 60.00 seconds" in caplog.text
assert "Connection requires encryption (RequiresEncryptionAPIError)" in caplog.text
now = loop.time()
assert logic._connect_timer.when() - now == pytest.approx(60, 1)
assert logic._tries == MAXIMUM_BACKOFF_TRIES
await logic.stop()