mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-03-15 13:59:17 +01:00
Fix eof_received
not raising SocketClosedAPIError (#651)
This commit is contained in:
parent
ccf2f1f245
commit
f88b15e33b
@ -537,25 +537,34 @@ class APIConnection:
|
||||
# If the task was cancelled, we need to clean up the connection
|
||||
# and raise the CancelledError as APIConnectionError
|
||||
self._cleanup()
|
||||
if not isinstance(ex, APIConnectionError):
|
||||
cause: Exception | None = None
|
||||
if isinstance(ex, CancelledError):
|
||||
err_str = "Starting connection cancelled"
|
||||
if self._fatal_exception:
|
||||
err_str += f" due to fatal exception: {self._fatal_exception}"
|
||||
cause = self._fatal_exception
|
||||
else:
|
||||
err_str = str(ex) or type(ex).__name__
|
||||
new_exc = APIConnectionError(
|
||||
f"Error while starting connection: {err_str}"
|
||||
)
|
||||
new_exc.__cause__ = cause or ex
|
||||
raise new_exc
|
||||
raise ex
|
||||
raise self._wrap_fatal_connection_exception("starting", ex)
|
||||
finally:
|
||||
self._start_connect_task = None
|
||||
self._set_connection_state(ConnectionState.SOCKET_OPENED)
|
||||
|
||||
def _wrap_fatal_connection_exception(
|
||||
self, action: str, ex: BaseException
|
||||
) -> APIConnectionError:
|
||||
"""Ensure a fatal exception is wrapped as as an APIConnectionError."""
|
||||
if isinstance(ex, APIConnectionError):
|
||||
return ex
|
||||
cause: BaseException | None = None
|
||||
if isinstance(ex, CancelledError):
|
||||
err_str = f"{action.title()} connection cancelled"
|
||||
if self._fatal_exception:
|
||||
err_str += f" due to fatal exception: {self._fatal_exception}"
|
||||
cause = self._fatal_exception
|
||||
else:
|
||||
err_str = str(ex) or type(ex).__name__
|
||||
cause = ex
|
||||
if isinstance(self._fatal_exception, APIConnectionError):
|
||||
klass = type(self._fatal_exception)
|
||||
else:
|
||||
klass = APIConnectionError
|
||||
new_exc = klass(f"Error while {action} connection: {err_str}")
|
||||
new_exc.__cause__ = cause or ex
|
||||
return new_exc
|
||||
|
||||
async def _do_finish_connect(self, login: bool) -> None:
|
||||
"""Finish the connection process."""
|
||||
in_do_connect.set(True)
|
||||
@ -585,22 +594,7 @@ class APIConnection:
|
||||
# If the task was cancelled, we need to clean up the connection
|
||||
# and raise the CancelledError as APIConnectionError
|
||||
self._cleanup()
|
||||
if not isinstance(ex, APIConnectionError):
|
||||
cause: Exception | None = None
|
||||
if isinstance(ex, CancelledError):
|
||||
err_str = "Finishing connection cancelled"
|
||||
if self._fatal_exception:
|
||||
err_str += f" due to fatal exception: {self._fatal_exception}"
|
||||
cause = self._fatal_exception
|
||||
else:
|
||||
err_str = str(ex) or type(ex).__name__
|
||||
cause = ex
|
||||
new_exc = APIConnectionError(
|
||||
f"Error while finishing connection: {err_str}"
|
||||
)
|
||||
new_exc.__cause__ = cause or ex
|
||||
raise new_exc
|
||||
raise ex
|
||||
raise self._wrap_fatal_connection_exception("finishing", ex)
|
||||
finally:
|
||||
self._finish_connect_task = None
|
||||
self._set_connection_state(ConnectionState.CONNECTED)
|
||||
|
@ -117,3 +117,14 @@ def send_plaintext_connect_response(
|
||||
def send_ping_response(protocol: APIPlaintextFrameHelper) -> None:
|
||||
ping_response: message.Message = PingResponse()
|
||||
protocol.data_received(generate_plaintext_packet(ping_response))
|
||||
|
||||
|
||||
def get_mock_protocol(conn: APIConnection):
|
||||
protocol = APIPlaintextFrameHelper(
|
||||
connection=conn,
|
||||
client_info="mock",
|
||||
log_name="mock_device",
|
||||
)
|
||||
transport = MagicMock()
|
||||
protocol.connection_made(transport)
|
||||
return protocol
|
||||
|
@ -4,7 +4,7 @@ import asyncio
|
||||
import base64
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from noise.connection import NoiseConnection # type: ignore[import-untyped]
|
||||
@ -20,16 +20,17 @@ from aioesphomeapi._frame_helper.plain_text import (
|
||||
_cached_varuint_to_bytes as cached_varuint_to_bytes,
|
||||
)
|
||||
from aioesphomeapi._frame_helper.plain_text import _varuint_to_bytes as varuint_to_bytes
|
||||
from aioesphomeapi.connection import ConnectionState
|
||||
from aioesphomeapi.core import (
|
||||
APIConnectionError,
|
||||
BadNameAPIError,
|
||||
HandshakeAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
ProtocolAPIError,
|
||||
SocketAPIError,
|
||||
SocketClosedAPIError,
|
||||
)
|
||||
|
||||
from .common import async_fire_time_changed, utcnow
|
||||
from .common import async_fire_time_changed, get_mock_protocol, utcnow
|
||||
|
||||
PREAMBLE = b"\x00"
|
||||
|
||||
@ -385,6 +386,10 @@ def test_bytes_to_varuint(val, encoded):
|
||||
assert cached_bytes_to_varuint(encoded) == val
|
||||
|
||||
|
||||
def test_bytes_to_varuint_invalid():
|
||||
assert bytes_to_varuint(b"\xFF") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_noise_frame_helper_handshake_failure():
|
||||
"""Test the noise frame helper handshake failure."""
|
||||
@ -568,3 +573,52 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
||||
|
||||
with pytest.raises(ProtocolAPIError, match="Connection closed"):
|
||||
helper.data_received(encrypted_header + encrypted_payload)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_plaintext_with_wrong_preamble(conn: APIConnection):
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol = get_mock_protocol(conn)
|
||||
with patch.object(loop, "create_connection") as create_connection:
|
||||
create_connection.return_value = (MagicMock(), protocol)
|
||||
|
||||
conn._socket = MagicMock()
|
||||
await conn._connect_init_frame_helper()
|
||||
loop.call_soon(conn._frame_helper._ready_future.set_result, None)
|
||||
conn.connection_state = ConnectionState.CONNECTED
|
||||
|
||||
task = asyncio.create_task(conn._connect_hello_login(login=True))
|
||||
await asyncio.sleep(0)
|
||||
# The preamble should be \x00 but we send \x09
|
||||
protocol.data_received(b"\x09\x00\x00")
|
||||
|
||||
with pytest.raises(ProtocolAPIError):
|
||||
await task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_eof_received_closes_connection(
|
||||
plaintext_connect_task_with_login: tuple[
|
||||
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
|
||||
],
|
||||
) -> None:
|
||||
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
|
||||
assert protocol.eof_received() is False
|
||||
assert conn.is_connected is False
|
||||
with pytest.raises(SocketClosedAPIError, match="EOF received"):
|
||||
await connect_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connection_lost_closes_connection_and_logs(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
plaintext_connect_task_with_login: tuple[
|
||||
APIConnection, asyncio.Transport, APIPlaintextFrameHelper, asyncio.Task
|
||||
],
|
||||
) -> None:
|
||||
conn, transport, protocol, connect_task = plaintext_connect_task_with_login
|
||||
protocol.connection_lost(OSError("original message"))
|
||||
assert conn.is_connected is False
|
||||
assert "original message" in caplog.text
|
||||
with pytest.raises(APIConnectionError, match="original message"):
|
||||
await connect_task
|
||||
|
@ -58,12 +58,7 @@ from aioesphomeapi.model import (
|
||||
)
|
||||
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
|
||||
|
||||
from .common import (
|
||||
PROTO_TO_MESSAGE_TYPE,
|
||||
Estr,
|
||||
generate_plaintext_packet,
|
||||
get_mock_zeroconf,
|
||||
)
|
||||
from .common import Estr, generate_plaintext_packet, get_mock_zeroconf
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -31,6 +31,7 @@ from .common import (
|
||||
async_fire_time_changed,
|
||||
connect,
|
||||
generate_plaintext_packet,
|
||||
get_mock_protocol,
|
||||
send_ping_response,
|
||||
send_plaintext_connect_response,
|
||||
send_plaintext_hello,
|
||||
@ -41,17 +42,6 @@ from .conftest import KEEP_ALIVE_INTERVAL
|
||||
KEEP_ALIVE_TIMEOUT_RATIO = 4.5
|
||||
|
||||
|
||||
def _get_mock_protocol(conn: APIConnection):
|
||||
protocol = APIPlaintextFrameHelper(
|
||||
connection=conn,
|
||||
client_info="mock",
|
||||
log_name="mock_device",
|
||||
)
|
||||
transport = MagicMock()
|
||||
protocol.connection_made(transport)
|
||||
return protocol
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect(
|
||||
plaintext_connect_task_no_login: tuple[
|
||||
@ -152,7 +142,7 @@ async def test_disconnect_when_not_fully_connected(
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_encryption_propagates(conn: APIConnection):
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol = _get_mock_protocol(conn)
|
||||
protocol = get_mock_protocol(conn)
|
||||
with patch.object(loop, "create_connection") as create_connection:
|
||||
create_connection.return_value = (MagicMock(), protocol)
|
||||
|
||||
@ -357,7 +347,7 @@ async def test_plaintext_connection_fails_handshake(
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
exception, raised_exception = exception_map
|
||||
protocol = _get_mock_protocol(conn)
|
||||
protocol = get_mock_protocol(conn)
|
||||
messages = []
|
||||
protocol: APIPlaintextFrameHelper | None = None
|
||||
transport = MagicMock()
|
||||
|
@ -14,7 +14,6 @@ from aioesphomeapi.connection import APIConnection
|
||||
from aioesphomeapi.log_runner import async_run
|
||||
|
||||
from .common import (
|
||||
PROTO_TO_MESSAGE_TYPE,
|
||||
Estr,
|
||||
generate_plaintext_packet,
|
||||
get_mock_async_zeroconf,
|
||||
|
Loading…
Reference in New Issue
Block a user