Fix not backing off when connection requires encryption (#762)

This commit is contained in:
J. Nick Koston 2023-11-27 18:39:22 -06:00 committed by GitHub
parent 31c6e4abc6
commit ab5834ca0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 120 additions and 45 deletions

View File

@ -72,7 +72,7 @@ cdef class APIConnection:
cdef public APIFrameHelper _frame_helper cdef public APIFrameHelper _frame_helper
cdef public object api_version cdef public object api_version
cdef public object connection_state cdef public object connection_state
cdef dict _message_handlers cdef public dict _message_handlers
cdef public str log_name cdef public str log_name
cdef set _read_exception_futures cdef set _read_exception_futures
cdef object _ping_timer cdef object _ping_timer
@ -81,7 +81,7 @@ cdef class APIConnection:
cdef float _keep_alive_timeout cdef float _keep_alive_timeout
cdef object _start_connect_task cdef object _start_connect_task
cdef object _finish_connect_task cdef object _finish_connect_task
cdef object _fatal_exception cdef public Exception _fatal_exception
cdef bint _expected_disconnect cdef bint _expected_disconnect
cdef object _loop cdef object _loop
cdef bint _send_pending_ping cdef bint _send_pending_ping
@ -102,7 +102,7 @@ cdef class APIConnection:
cpdef _async_schedule_keep_alive(self, object now) cpdef _async_schedule_keep_alive(self, object now)
cpdef _cleanup(self) cdef _cleanup(self)
cpdef set_log_name(self, str name) cpdef set_log_name(self, str name)
@ -112,7 +112,7 @@ cdef class APIConnection:
cdef _process_login_response(self, object hello_response) cdef _process_login_response(self, object hello_response)
cpdef _set_connection_state(self, object state) cdef _set_connection_state(self, object state)
cpdef report_fatal_error(self, Exception err) cpdef report_fatal_error(self, Exception err)
@ -129,3 +129,5 @@ cdef class APIConnection:
cpdef _handle_ping_request_internal(self, object msg) cpdef _handle_ping_request_internal(self, object msg)
cpdef _handle_get_time_request_internal(self, object msg) cpdef _handle_get_time_request_internal(self, object msg)
cdef _set_fatal_exception_if_unset(self, Exception err)

View File

@ -120,8 +120,8 @@ class ConnectionState(enum.Enum):
# The handshake has been completed, messages can be exchanged # The handshake has been completed, messages can be exchanged
HANDSHAKE_COMPLETE = 2 HANDSHAKE_COMPLETE = 2
# The connection has been established, authenticated data can be exchanged # The connection has been established, authenticated data can be exchanged
CONNECTED = 2 CONNECTED = 3
CLOSED = 3 CLOSED = 4
CONNECTION_STATE_INITIALIZED = ConnectionState.INITIALIZED CONNECTION_STATE_INITIALIZED = ConnectionState.INITIALIZED
@ -593,7 +593,10 @@ class APIConnection:
"""Set the connection state and log the change.""" """Set the connection state and log the change."""
self.connection_state = state self.connection_state = state
self.is_connected = state is CONNECTION_STATE_CONNECTED self.is_connected = state is CONNECTION_STATE_CONNECTED
self._handshake_complete = state is CONNECTION_STATE_HANDSHAKE_COMPLETE self._handshake_complete = (
state is CONNECTION_STATE_HANDSHAKE_COMPLETE
or state is CONNECTION_STATE_CONNECTED
)
def _make_connect_request(self) -> ConnectRequest: def _make_connect_request(self) -> ConnectRequest:
"""Make a ConnectRequest.""" """Make a ConnectRequest."""
@ -772,17 +775,28 @@ class APIConnection:
The connection will be closed, all exception handlers notified. The connection will be closed, all exception handlers notified.
This method does not log the error, the call site should do so. This method does not log the error, the call site should do so.
""" """
if self._expected_disconnect is False and not self._fatal_exception: if not self._fatal_exception:
# Only log the first error if self._expected_disconnect is False:
_LOGGER.warning( # Only log the first error
"%s: Connection error occurred: %s", _LOGGER.warning(
self.log_name, "%s: Connection error occurred: %s",
err or type(err), self.log_name,
exc_info=not str(err), # Log the full stack on empty error string err or type(err),
) exc_info=not str(err), # Log the full stack on empty error string
self._fatal_exception = err )
# Only set the first error since otherwise the original
# error will be lost (ie RequiresEncryptionAPIError) and than
# SocketClosedAPIError will be raised instead
self._set_fatal_exception_if_unset(err)
self._cleanup() self._cleanup()
def _set_fatal_exception_if_unset(self, err: Exception) -> None:
"""Set the fatal exception if it hasn't been set yet."""
if self._fatal_exception is None:
self._fatal_exception = err
def process_packet(self, msg_type_proto: _int, data: _bytes) -> None: def process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
"""Process an incoming packet.""" """Process an incoming packet."""
debug_enabled = self._debug_enabled debug_enabled = self._debug_enabled
@ -889,8 +903,10 @@ class APIConnection:
[self._finish_connect_task], timeout=DISCONNECT_CONNECT_TIMEOUT [self._finish_connect_task], timeout=DISCONNECT_CONNECT_TIMEOUT
) )
if pending: if pending:
self._fatal_exception = TimeoutAPIError( self._set_fatal_exception_if_unset(
"Timed out waiting to finish connect before disconnecting" TimeoutAPIError(
"Timed out waiting to finish connect before disconnecting"
)
) )
if self._debug_enabled: if self._debug_enabled:
_LOGGER.debug( _LOGGER.debug(
@ -931,15 +947,3 @@ class APIConnection:
) )
self._cleanup() self._cleanup()
def _get_message_handlers(
self,
) -> dict[Any, set[Callable[[message.Message], None]]]:
"""Get the message handlers.
This function is only used for testing for leaks.
It has to be bound to the real instance to work since
_message_handlers is not a public attribute.
"""
return self._message_handlers

View File

@ -1246,9 +1246,7 @@ async def test_bluetooth_gatt_start_notify(
client, connection, transport, protocol = api_client client, connection, transport, protocol = api_client
notifies = [] notifies = []
handlers_before = len( handlers_before = len(list(itertools.chain(*connection._message_handlers.values())))
list(itertools.chain(*connection._get_message_handlers().values()))
)
def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None: def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None:
notifies.append((handle, data)) notifies.append((handle, data))
@ -1280,7 +1278,7 @@ async def test_bluetooth_gatt_start_notify(
await cancel_cb() await cancel_cb()
assert ( assert (
len(list(itertools.chain(*connection._get_message_handlers().values()))) len(list(itertools.chain(*connection._message_handlers.values())))
== handlers_before == handlers_before
) )
# Ensure abort callback is a no-op after cancel # Ensure abort callback is a no-op after cancel
@ -1305,9 +1303,7 @@ async def test_bluetooth_gatt_start_notify_fails(
def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None: def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None:
notifies.append((handle, data)) notifies.append((handle, data))
handlers_before = len( handlers_before = len(list(itertools.chain(*connection._message_handlers.values())))
list(itertools.chain(*connection._get_message_handlers().values()))
)
with patch.object( with patch.object(
connection, connection,
@ -1317,7 +1313,7 @@ async def test_bluetooth_gatt_start_notify_fails(
await client.bluetooth_gatt_start_notify(1234, 1, on_bluetooth_gatt_notify) await client.bluetooth_gatt_start_notify(1234, 1, on_bluetooth_gatt_notify)
assert ( assert (
len(list(itertools.chain(*connection._get_message_handlers().values()))) len(list(itertools.chain(*connection._message_handlers.values())))
== handlers_before == handlers_before
) )
@ -1769,9 +1765,7 @@ async def test_bluetooth_device_connect_cancelled(
client, connection, transport, protocol = api_client client, connection, transport, protocol = api_client
states = [] states = []
handlers_before = len( handlers_before = len(list(itertools.chain(*connection._message_handlers.values())))
list(itertools.chain(*connection._get_message_handlers().values()))
)
def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None: def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None:
states.append((connected, mtu, error)) states.append((connected, mtu, error))
@ -1795,9 +1789,7 @@ async def test_bluetooth_device_connect_cancelled(
await connect_task await connect_task
assert states == [] assert states == []
handlers_after = len( handlers_after = len(list(itertools.chain(*connection._message_handlers.values())))
list(itertools.chain(*connection._get_message_handlers().values()))
)
# Make sure we do not leak message handlers # Make sure we do not leak message handlers
assert handlers_after == handlers_before assert handlers_after == handlers_before

View File

@ -170,6 +170,14 @@ async def test_requires_encryption_propagates(conn: APIConnection):
mock_data_received(protocol, b"\x01\x00\x00") mock_data_received(protocol, b"\x01\x00\x00")
await task await task
await asyncio.sleep(0)
await asyncio.sleep(0)
assert isinstance(conn._fatal_exception, RequiresEncryptionAPIError)
conn.force_disconnect()
assert isinstance(conn._fatal_exception, RequiresEncryptionAPIError)
conn.report_fatal_error(Exception("test"))
assert isinstance(conn._fatal_exception, RequiresEncryptionAPIError)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_plaintext_connection( async def test_plaintext_connection(

View File

@ -30,6 +30,7 @@ from aioesphomeapi.reconnect_logic import (
from .common import ( from .common import (
get_mock_async_zeroconf, get_mock_async_zeroconf,
get_mock_zeroconf, get_mock_zeroconf,
mock_data_received,
send_plaintext_connect_response, send_plaintext_connect_response,
send_plaintext_hello, send_plaintext_hello,
) )
@ -719,6 +720,8 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL
protocol = cli._connection._frame_helper protocol = cli._connection._frame_helper
send_plaintext_hello(protocol) send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False) send_plaintext_connect_response(protocol, False)
await asyncio.sleep(0)
await asyncio.sleep(0)
assert cli._connection.is_connected is True assert cli._connection.is_connected is True
await asyncio.sleep(0) await asyncio.sleep(0)
@ -736,5 +739,71 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL
assert mock_create_connection.call_count == 0 assert mock_create_connection.call_count == 0
assert len(on_disconnect_calls) == 1 assert len(on_disconnect_calls) == 1
assert on_disconnect_calls[0] is False expected_disconnect = on_disconnect_calls[-1]
assert expected_disconnect is False
await logic.stop()
@pytest.mark.asyncio
async def test_backoff_on_encryption_error(
event_loop: asyncio.AbstractEventLoop, caplog: pytest.LogCaptureFixture
) -> None:
"""Test we backoff on encryption error."""
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
connected = asyncio.Event()
class PatchableAPIClient(APIClient):
pass
async_zeroconf = get_mock_async_zeroconf()
cli = PatchableAPIClient(
address="1.2.3.4",
port=6052,
password=None,
noise_psk="",
expected_name="fake",
zeroconf_instance=async_zeroconf.zeroconf,
)
connected = asyncio.Event()
on_disconnect_calls = []
async def on_disconnect(expected_disconnect: bool) -> None:
on_disconnect_calls.append(expected_disconnect)
async def on_connect() -> None:
connected.set()
logic = ReconnectLogic(
client=cli,
on_connect=on_connect,
on_disconnect=on_disconnect,
zeroconf_instance=async_zeroconf,
name="fake",
)
with patch.object(event_loop, "sock_connect"), patch.object(
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
):
await logic.start()
await connected.wait()
protocol = cli._connection._frame_helper
mock_data_received(protocol, b"\x01\x00\x00")
assert cli._connection.is_connected is False
await asyncio.sleep(0)
await asyncio.sleep(0)
assert len(on_disconnect_calls) == 0
assert "Scheduling new connect attempt in 60.00 seconds" in caplog.text
assert "Connection requires encryption (RequiresEncryptionAPIError)" in caplog.text
now = loop.time()
assert logic._connect_timer.when() - now == pytest.approx(60, 1)
assert logic._tries == MAXIMUM_BACKOFF_TRIES
await logic.stop() await logic.stop()