Fix `eof_received` not raising SocketClosedAPIError (#651)

This commit is contained in:
J. Nick Koston 2023-11-21 14:56:31 +01:00 committed by GitHub
parent ccf2f1f245
commit f88b15e33b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 97 additions and 54 deletions

View File

@ -537,25 +537,34 @@ class APIConnection:
# If the task was cancelled, we need to clean up the connection
# and raise the CancelledError as APIConnectionError
self._cleanup()
if not isinstance(ex, APIConnectionError):
cause: Exception | None = None
if isinstance(ex, CancelledError):
err_str = "Starting connection cancelled"
if self._fatal_exception:
err_str += f" due to fatal exception: {self._fatal_exception}"
cause = self._fatal_exception
else:
err_str = str(ex) or type(ex).__name__
new_exc = APIConnectionError(
f"Error while starting connection: {err_str}"
)
new_exc.__cause__ = cause or ex
raise new_exc
raise ex
raise self._wrap_fatal_connection_exception("starting", ex)
finally:
self._start_connect_task = None
self._set_connection_state(ConnectionState.SOCKET_OPENED)
def _wrap_fatal_connection_exception(
self, action: str, ex: BaseException
) -> APIConnectionError:
"""Ensure a fatal exception is wrapped as as an APIConnectionError."""
if isinstance(ex, APIConnectionError):
return ex
cause: BaseException | None = None
if isinstance(ex, CancelledError):
err_str = f"{action.title()} connection cancelled"
if self._fatal_exception:
err_str += f" due to fatal exception: {self._fatal_exception}"
cause = self._fatal_exception
else:
err_str = str(ex) or type(ex).__name__
cause = ex
if isinstance(self._fatal_exception, APIConnectionError):
klass = type(self._fatal_exception)
else:
klass = APIConnectionError
new_exc = klass(f"Error while {action} connection: {err_str}")
new_exc.__cause__ = cause or ex
return new_exc
async def _do_finish_connect(self, login: bool) -> None:
"""Finish the connection process."""
in_do_connect.set(True)
@ -585,22 +594,7 @@ class APIConnection:
# If the task was cancelled, we need to clean up the connection
# and raise the CancelledError as APIConnectionError
self._cleanup()
if not isinstance(ex, APIConnectionError):
cause: Exception | None = None
if isinstance(ex, CancelledError):
err_str = "Finishing connection cancelled"
if self._fatal_exception:
err_str += f" due to fatal exception: {self._fatal_exception}"
cause = self._fatal_exception
else:
err_str = str(ex) or type(ex).__name__
cause = ex
new_exc = APIConnectionError(
f"Error while finishing connection: {err_str}"
)
new_exc.__cause__ = cause or ex
raise new_exc
raise ex
raise self._wrap_fatal_connection_exception("finishing", ex)
finally:
self._finish_connect_task = None
self._set_connection_state(ConnectionState.CONNECTED)

View File

@ -117,3 +117,14 @@ def send_plaintext_connect_response(
def send_ping_response(protocol: APIPlaintextFrameHelper) -> None:
ping_response: message.Message = PingResponse()
protocol.data_received(generate_plaintext_packet(ping_response))
def get_mock_protocol(conn: APIConnection):
protocol = APIPlaintextFrameHelper(
connection=conn,
client_info="mock",
log_name="mock_device",
)
transport = MagicMock()
protocol.connection_made(transport)
return protocol

View File

@ -4,7 +4,7 @@ import asyncio
import base64
from datetime import timedelta
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch
import pytest
from noise.connection import NoiseConnection # type: ignore[import-untyped]
@ -20,16 +20,17 @@ from aioesphomeapi._frame_helper.plain_text import (
_cached_varuint_to_bytes as cached_varuint_to_bytes,
)
from aioesphomeapi._frame_helper.plain_text import _varuint_to_bytes as varuint_to_bytes
from aioesphomeapi.connection import ConnectionState
from aioesphomeapi.core import (
APIConnectionError,
BadNameAPIError,
HandshakeAPIError,
InvalidEncryptionKeyAPIError,
ProtocolAPIError,
SocketAPIError,
SocketClosedAPIError,
)
from .common import async_fire_time_changed, utcnow
from .common import async_fire_time_changed, get_mock_protocol, utcnow
PREAMBLE = b"\x00"
@ -385,6 +386,10 @@ def test_bytes_to_varuint(val, encoded):
assert cached_bytes_to_varuint(encoded) == val
def test_bytes_to_varuint_invalid():
assert bytes_to_varuint(b"\xFF") is None
@pytest.mark.asyncio
async def test_noise_frame_helper_handshake_failure():
"""Test the noise frame helper handshake failure."""
@ -568,3 +573,52 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
with pytest.raises(ProtocolAPIError, match="Connection closed"):
helper.data_received(encrypted_header + encrypted_payload)
@pytest.mark.asyncio
async def test_init_plaintext_with_wrong_preamble(conn: APIConnection):
loop = asyncio.get_event_loop()
protocol = get_mock_protocol(conn)
with patch.object(loop, "create_connection") as create_connection:
create_connection.return_value = (MagicMock(), protocol)
conn._socket = MagicMock()
await conn._connect_init_frame_helper()
loop.call_soon(conn._frame_helper._ready_future.set_result, None)
conn.connection_state = ConnectionState.CONNECTED
task = asyncio.create_task(conn._connect_hello_login(login=True))
await asyncio.sleep(0)
# The preamble should be \x00 but we send \x09
protocol.data_received(b"\x09\x00\x00")
with pytest.raises(ProtocolAPIError):
await task
@pytest.mark.asyncio
async def test_eof_received_closes_connection(
plaintext_connect_task_with_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
assert protocol.eof_received() is False
assert conn.is_connected is False
with pytest.raises(SocketClosedAPIError, match="EOF received"):
await connect_task
@pytest.mark.asyncio
async def test_connection_lost_closes_connection_and_logs(
caplog: pytest.LogCaptureFixture,
plaintext_connect_task_with_login: tuple[
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
],
) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
protocol.connection_lost(OSError("original message"))
assert conn.is_connected is False
assert "original message" in caplog.text
with pytest.raises(APIConnectionError, match="original message"):
await connect_task

View File

@ -58,12 +58,7 @@ from aioesphomeapi.model import (
)
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
from .common import (
PROTO_TO_MESSAGE_TYPE,
Estr,
generate_plaintext_packet,
get_mock_zeroconf,
)
from .common import Estr, generate_plaintext_packet, get_mock_zeroconf
@pytest.fixture

View File

@ -31,6 +31,7 @@ from .common import (
async_fire_time_changed,
connect,
generate_plaintext_packet,
get_mock_protocol,
send_ping_response,
send_plaintext_connect_response,
send_plaintext_hello,
@ -41,17 +42,6 @@ from .conftest import KEEP_ALIVE_INTERVAL
KEEP_ALIVE_TIMEOUT_RATIO = 4.5
def _get_mock_protocol(conn: APIConnection):
protocol = APIPlaintextFrameHelper(
connection=conn,
client_info="mock",
log_name="mock_device",
)
transport = MagicMock()
protocol.connection_made(transport)
return protocol
@pytest.mark.asyncio
async def test_connect(
plaintext_connect_task_no_login: tuple[
@ -152,7 +142,7 @@ async def test_disconnect_when_not_fully_connected(
@pytest.mark.asyncio
async def test_requires_encryption_propagates(conn: APIConnection):
loop = asyncio.get_event_loop()
protocol = _get_mock_protocol(conn)
protocol = get_mock_protocol(conn)
with patch.object(loop, "create_connection") as create_connection:
create_connection.return_value = (MagicMock(), protocol)
@ -357,7 +347,7 @@ async def test_plaintext_connection_fails_handshake(
"""
loop = asyncio.get_event_loop()
exception, raised_exception = exception_map
protocol = _get_mock_protocol(conn)
protocol = get_mock_protocol(conn)
messages = []
protocol: APIPlaintextFrameHelper | None = None
transport = MagicMock()

View File

@ -14,7 +14,6 @@ from aioesphomeapi.connection import APIConnection
from aioesphomeapi.log_runner import async_run
from .common import (
PROTO_TO_MESSAGE_TYPE,
Estr,
generate_plaintext_packet,
get_mock_async_zeroconf,