mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-22 16:48:04 +01:00
Fix not backing off when connection requires encryption (#762)
This commit is contained in:
parent
31c6e4abc6
commit
ab5834ca0d
@ -72,7 +72,7 @@ cdef class APIConnection:
|
||||
cdef public APIFrameHelper _frame_helper
|
||||
cdef public object api_version
|
||||
cdef public object connection_state
|
||||
cdef dict _message_handlers
|
||||
cdef public dict _message_handlers
|
||||
cdef public str log_name
|
||||
cdef set _read_exception_futures
|
||||
cdef object _ping_timer
|
||||
@ -81,7 +81,7 @@ cdef class APIConnection:
|
||||
cdef float _keep_alive_timeout
|
||||
cdef object _start_connect_task
|
||||
cdef object _finish_connect_task
|
||||
cdef object _fatal_exception
|
||||
cdef public Exception _fatal_exception
|
||||
cdef bint _expected_disconnect
|
||||
cdef object _loop
|
||||
cdef bint _send_pending_ping
|
||||
@ -102,7 +102,7 @@ cdef class APIConnection:
|
||||
|
||||
cpdef _async_schedule_keep_alive(self, object now)
|
||||
|
||||
cpdef _cleanup(self)
|
||||
cdef _cleanup(self)
|
||||
|
||||
cpdef set_log_name(self, str name)
|
||||
|
||||
@ -112,7 +112,7 @@ cdef class APIConnection:
|
||||
|
||||
cdef _process_login_response(self, object hello_response)
|
||||
|
||||
cpdef _set_connection_state(self, object state)
|
||||
cdef _set_connection_state(self, object state)
|
||||
|
||||
cpdef report_fatal_error(self, Exception err)
|
||||
|
||||
@ -129,3 +129,5 @@ cdef class APIConnection:
|
||||
cpdef _handle_ping_request_internal(self, object msg)
|
||||
|
||||
cpdef _handle_get_time_request_internal(self, object msg)
|
||||
|
||||
cdef _set_fatal_exception_if_unset(self, Exception err)
|
||||
|
@ -120,8 +120,8 @@ class ConnectionState(enum.Enum):
|
||||
# The handshake has been completed, messages can be exchanged
|
||||
HANDSHAKE_COMPLETE = 2
|
||||
# The connection has been established, authenticated data can be exchanged
|
||||
CONNECTED = 2
|
||||
CLOSED = 3
|
||||
CONNECTED = 3
|
||||
CLOSED = 4
|
||||
|
||||
|
||||
CONNECTION_STATE_INITIALIZED = ConnectionState.INITIALIZED
|
||||
@ -593,7 +593,10 @@ class APIConnection:
|
||||
"""Set the connection state and log the change."""
|
||||
self.connection_state = state
|
||||
self.is_connected = state is CONNECTION_STATE_CONNECTED
|
||||
self._handshake_complete = state is CONNECTION_STATE_HANDSHAKE_COMPLETE
|
||||
self._handshake_complete = (
|
||||
state is CONNECTION_STATE_HANDSHAKE_COMPLETE
|
||||
or state is CONNECTION_STATE_CONNECTED
|
||||
)
|
||||
|
||||
def _make_connect_request(self) -> ConnectRequest:
|
||||
"""Make a ConnectRequest."""
|
||||
@ -772,17 +775,28 @@ class APIConnection:
|
||||
The connection will be closed, all exception handlers notified.
|
||||
This method does not log the error, the call site should do so.
|
||||
"""
|
||||
if self._expected_disconnect is False and not self._fatal_exception:
|
||||
# Only log the first error
|
||||
_LOGGER.warning(
|
||||
"%s: Connection error occurred: %s",
|
||||
self.log_name,
|
||||
err or type(err),
|
||||
exc_info=not str(err), # Log the full stack on empty error string
|
||||
)
|
||||
self._fatal_exception = err
|
||||
if not self._fatal_exception:
|
||||
if self._expected_disconnect is False:
|
||||
# Only log the first error
|
||||
_LOGGER.warning(
|
||||
"%s: Connection error occurred: %s",
|
||||
self.log_name,
|
||||
err or type(err),
|
||||
exc_info=not str(err), # Log the full stack on empty error string
|
||||
)
|
||||
|
||||
# Only set the first error since otherwise the original
|
||||
# error will be lost (ie RequiresEncryptionAPIError) and than
|
||||
# SocketClosedAPIError will be raised instead
|
||||
self._set_fatal_exception_if_unset(err)
|
||||
|
||||
self._cleanup()
|
||||
|
||||
def _set_fatal_exception_if_unset(self, err: Exception) -> None:
|
||||
"""Set the fatal exception if it hasn't been set yet."""
|
||||
if self._fatal_exception is None:
|
||||
self._fatal_exception = err
|
||||
|
||||
def process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
|
||||
"""Process an incoming packet."""
|
||||
debug_enabled = self._debug_enabled
|
||||
@ -889,8 +903,10 @@ class APIConnection:
|
||||
[self._finish_connect_task], timeout=DISCONNECT_CONNECT_TIMEOUT
|
||||
)
|
||||
if pending:
|
||||
self._fatal_exception = TimeoutAPIError(
|
||||
"Timed out waiting to finish connect before disconnecting"
|
||||
self._set_fatal_exception_if_unset(
|
||||
TimeoutAPIError(
|
||||
"Timed out waiting to finish connect before disconnecting"
|
||||
)
|
||||
)
|
||||
if self._debug_enabled:
|
||||
_LOGGER.debug(
|
||||
@ -931,15 +947,3 @@ class APIConnection:
|
||||
)
|
||||
|
||||
self._cleanup()
|
||||
|
||||
def _get_message_handlers(
|
||||
self,
|
||||
) -> dict[Any, set[Callable[[message.Message], None]]]:
|
||||
"""Get the message handlers.
|
||||
|
||||
This function is only used for testing for leaks.
|
||||
|
||||
It has to be bound to the real instance to work since
|
||||
_message_handlers is not a public attribute.
|
||||
"""
|
||||
return self._message_handlers
|
||||
|
@ -1246,9 +1246,7 @@ async def test_bluetooth_gatt_start_notify(
|
||||
client, connection, transport, protocol = api_client
|
||||
notifies = []
|
||||
|
||||
handlers_before = len(
|
||||
list(itertools.chain(*connection._get_message_handlers().values()))
|
||||
)
|
||||
handlers_before = len(list(itertools.chain(*connection._message_handlers.values())))
|
||||
|
||||
def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None:
|
||||
notifies.append((handle, data))
|
||||
@ -1280,7 +1278,7 @@ async def test_bluetooth_gatt_start_notify(
|
||||
await cancel_cb()
|
||||
|
||||
assert (
|
||||
len(list(itertools.chain(*connection._get_message_handlers().values())))
|
||||
len(list(itertools.chain(*connection._message_handlers.values())))
|
||||
== handlers_before
|
||||
)
|
||||
# Ensure abort callback is a no-op after cancel
|
||||
@ -1305,9 +1303,7 @@ async def test_bluetooth_gatt_start_notify_fails(
|
||||
def on_bluetooth_gatt_notify(handle: int, data: bytearray) -> None:
|
||||
notifies.append((handle, data))
|
||||
|
||||
handlers_before = len(
|
||||
list(itertools.chain(*connection._get_message_handlers().values()))
|
||||
)
|
||||
handlers_before = len(list(itertools.chain(*connection._message_handlers.values())))
|
||||
|
||||
with patch.object(
|
||||
connection,
|
||||
@ -1317,7 +1313,7 @@ async def test_bluetooth_gatt_start_notify_fails(
|
||||
await client.bluetooth_gatt_start_notify(1234, 1, on_bluetooth_gatt_notify)
|
||||
|
||||
assert (
|
||||
len(list(itertools.chain(*connection._get_message_handlers().values())))
|
||||
len(list(itertools.chain(*connection._message_handlers.values())))
|
||||
== handlers_before
|
||||
)
|
||||
|
||||
@ -1769,9 +1765,7 @@ async def test_bluetooth_device_connect_cancelled(
|
||||
client, connection, transport, protocol = api_client
|
||||
states = []
|
||||
|
||||
handlers_before = len(
|
||||
list(itertools.chain(*connection._get_message_handlers().values()))
|
||||
)
|
||||
handlers_before = len(list(itertools.chain(*connection._message_handlers.values())))
|
||||
|
||||
def on_bluetooth_connection_state(connected: bool, mtu: int, error: int) -> None:
|
||||
states.append((connected, mtu, error))
|
||||
@ -1795,9 +1789,7 @@ async def test_bluetooth_device_connect_cancelled(
|
||||
await connect_task
|
||||
assert states == []
|
||||
|
||||
handlers_after = len(
|
||||
list(itertools.chain(*connection._get_message_handlers().values()))
|
||||
)
|
||||
handlers_after = len(list(itertools.chain(*connection._message_handlers.values())))
|
||||
# Make sure we do not leak message handlers
|
||||
assert handlers_after == handlers_before
|
||||
|
||||
|
@ -170,6 +170,14 @@ async def test_requires_encryption_propagates(conn: APIConnection):
|
||||
mock_data_received(protocol, b"\x01\x00\x00")
|
||||
await task
|
||||
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
assert isinstance(conn._fatal_exception, RequiresEncryptionAPIError)
|
||||
conn.force_disconnect()
|
||||
assert isinstance(conn._fatal_exception, RequiresEncryptionAPIError)
|
||||
conn.report_fatal_error(Exception("test"))
|
||||
assert isinstance(conn._fatal_exception, RequiresEncryptionAPIError)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plaintext_connection(
|
||||
|
@ -30,6 +30,7 @@ from aioesphomeapi.reconnect_logic import (
|
||||
from .common import (
|
||||
get_mock_async_zeroconf,
|
||||
get_mock_zeroconf,
|
||||
mock_data_received,
|
||||
send_plaintext_connect_response,
|
||||
send_plaintext_hello,
|
||||
)
|
||||
@ -719,6 +720,8 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL
|
||||
protocol = cli._connection._frame_helper
|
||||
send_plaintext_hello(protocol)
|
||||
send_plaintext_connect_response(protocol, False)
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert cli._connection.is_connected is True
|
||||
await asyncio.sleep(0)
|
||||
@ -736,5 +739,71 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL
|
||||
assert mock_create_connection.call_count == 0
|
||||
|
||||
assert len(on_disconnect_calls) == 1
|
||||
assert on_disconnect_calls[0] is False
|
||||
expected_disconnect = on_disconnect_calls[-1]
|
||||
assert expected_disconnect is False
|
||||
await logic.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backoff_on_encryption_error(
|
||||
event_loop: asyncio.AbstractEventLoop, caplog: pytest.LogCaptureFixture
|
||||
) -> None:
|
||||
"""Test we backoff on encryption error."""
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol: APIPlaintextFrameHelper | None = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
class PatchableAPIClient(APIClient):
|
||||
pass
|
||||
|
||||
async_zeroconf = get_mock_async_zeroconf()
|
||||
|
||||
cli = PatchableAPIClient(
|
||||
address="1.2.3.4",
|
||||
port=6052,
|
||||
password=None,
|
||||
noise_psk="",
|
||||
expected_name="fake",
|
||||
zeroconf_instance=async_zeroconf.zeroconf,
|
||||
)
|
||||
|
||||
connected = asyncio.Event()
|
||||
on_disconnect_calls = []
|
||||
|
||||
async def on_disconnect(expected_disconnect: bool) -> None:
|
||||
on_disconnect_calls.append(expected_disconnect)
|
||||
|
||||
async def on_connect() -> None:
|
||||
connected.set()
|
||||
|
||||
logic = ReconnectLogic(
|
||||
client=cli,
|
||||
on_connect=on_connect,
|
||||
on_disconnect=on_disconnect,
|
||||
zeroconf_instance=async_zeroconf,
|
||||
name="fake",
|
||||
)
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop,
|
||||
"create_connection",
|
||||
side_effect=partial(_create_mock_transport_protocol, transport, connected),
|
||||
):
|
||||
await logic.start()
|
||||
await connected.wait()
|
||||
protocol = cli._connection._frame_helper
|
||||
mock_data_received(protocol, b"\x01\x00\x00")
|
||||
|
||||
assert cli._connection.is_connected is False
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert len(on_disconnect_calls) == 0
|
||||
|
||||
assert "Scheduling new connect attempt in 60.00 seconds" in caplog.text
|
||||
assert "Connection requires encryption (RequiresEncryptionAPIError)" in caplog.text
|
||||
now = loop.time()
|
||||
assert logic._connect_timer.when() - now == pytest.approx(60, 1)
|
||||
assert logic._tries == MAXIMUM_BACKOFF_TRIES
|
||||
await logic.stop()
|
||||
|
Loading…
Reference in New Issue
Block a user