Fix not backing off when connection requires encryption

issue https://github.com/home-assistant/core/issues/104624
This commit is contained in:
J. Nick Koston 2023-11-27 18:08:06 -06:00
parent 31c6e4abc6
commit b1682b286d
No known key found for this signature in database
2 changed files with 90 additions and 15 deletions

View File

@ -120,8 +120,8 @@ class ConnectionState(enum.Enum):
# The handshake has been completed, messages can be exchanged
HANDSHAKE_COMPLETE = 2
# The connection has been established, authenticated data can be exchanged
CONNECTED = 2
CLOSED = 3
CONNECTED = 3
CLOSED = 4
CONNECTION_STATE_INITIALIZED = ConnectionState.INITIALIZED
@ -593,7 +593,10 @@ class APIConnection:
"""Set the connection state and log the change."""
self.connection_state = state
self.is_connected = state is CONNECTION_STATE_CONNECTED
self._handshake_complete = state is CONNECTION_STATE_HANDSHAKE_COMPLETE
self._handshake_complete = (
state is CONNECTION_STATE_HANDSHAKE_COMPLETE
or state is CONNECTION_STATE_CONNECTED
)
def _make_connect_request(self) -> ConnectRequest:
"""Make a ConnectRequest."""
@ -772,15 +775,21 @@ class APIConnection:
The connection will be closed, all exception handlers notified.
This method does not log the error, the call site should do so.
"""
if self._expected_disconnect is False and not self._fatal_exception:
# Only log the first error
_LOGGER.warning(
"%s: Connection error occurred: %s",
self.log_name,
err or type(err),
exc_info=not str(err), # Log the full stack on empty error string
)
self._fatal_exception = err
if not self._fatal_exception:
if self._expected_disconnect is False:
# Only log the first error
_LOGGER.warning(
"%s: Connection error occurred: %s",
self.log_name,
err or type(err),
exc_info=not str(err), # Log the full stack on empty error string
)
# Only set the first error since otherwise the original
# error will be lost (ie RequiresEncryptionAPIError) and than
# SocketClosedAPIError will be raised instead
self._fatal_exception = err
self._cleanup()
def process_packet(self, msg_type_proto: _int, data: _bytes) -> None:

View File

@ -30,6 +30,7 @@ from aioesphomeapi.reconnect_logic import (
from .common import (
get_mock_async_zeroconf,
get_mock_zeroconf,
mock_data_received,
send_plaintext_connect_response,
send_plaintext_hello,
)
@ -720,7 +721,7 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL
send_plaintext_hello(protocol)
send_plaintext_connect_response(protocol, False)
assert cli._connection.is_connected is True
assert cli._connection.is_connected is False
await asyncio.sleep(0)
with patch.object(event_loop, "sock_connect"), patch.object(
@ -735,6 +736,71 @@ async def test_handling_unexpected_disconnect(event_loop: asyncio.AbstractEventL
# since its an unexpected disconnect
assert mock_create_connection.call_count == 0
assert len(on_disconnect_calls) == 1
assert on_disconnect_calls[0] is False
# We never actually finished the connection
assert len(on_disconnect_calls) == 0
await logic.stop()
@pytest.mark.asyncio
async def test_backoff_on_encryption_error(
event_loop: asyncio.AbstractEventLoop, caplog: pytest.LogCaptureFixture
) -> None:
"""Test we backoff on encryption error."""
loop = asyncio.get_event_loop()
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()
connected = asyncio.Event()
class PatchableAPIClient(APIClient):
pass
async_zeroconf = get_mock_async_zeroconf()
cli = PatchableAPIClient(
address="1.2.3.4",
port=6052,
password=None,
noise_psk="",
expected_name="fake",
zeroconf_instance=async_zeroconf.zeroconf,
)
connected = asyncio.Event()
on_disconnect_calls = []
async def on_disconnect(expected_disconnect: bool) -> None:
on_disconnect_calls.append(expected_disconnect)
async def on_connect() -> None:
connected.set()
logic = ReconnectLogic(
client=cli,
on_connect=on_connect,
on_disconnect=on_disconnect,
zeroconf_instance=async_zeroconf,
name="fake",
)
with patch.object(event_loop, "sock_connect"), patch.object(
loop,
"create_connection",
side_effect=partial(_create_mock_transport_protocol, transport, connected),
):
await logic.start()
await connected.wait()
protocol = cli._connection._frame_helper
mock_data_received(protocol, b"\x01\x00\x00")
assert cli._connection.is_connected is False
await asyncio.sleep(0)
await asyncio.sleep(0)
assert len(on_disconnect_calls) == 0
assert "Scheduling new connect attempt in 60.00 seconds" in caplog.text
assert "Connection requires encryption (RequiresEncryptionAPIError)" in caplog.text
now = loop.time()
assert logic._connect_timer.when() - now == pytest.approx(60, 1)
assert logic._tries == MAXIMUM_BACKOFF_TRIES
await logic.stop()