mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-02-17 01:51:23 +01:00
Ensure frame_helper is always closed before the underlying socket (#602)
This commit is contained in:
parent
9ecf2fc2b7
commit
e1c42e95bf
@ -288,6 +288,7 @@ class APIConnection:
|
||||
"""Step 2 in connect process: connect the socket."""
|
||||
debug_enable = self._debug_enabled()
|
||||
sock = socket.socket(family=addr.family, type=addr.type, proto=addr.proto)
|
||||
self._socket = sock
|
||||
sock.setblocking(False)
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
# Try to reduce the pressure on esphome device as it measures
|
||||
@ -319,7 +320,6 @@ class APIConnection:
|
||||
except OSError as err:
|
||||
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err
|
||||
|
||||
self._socket = sock
|
||||
if debug_enable is True:
|
||||
_LOGGER.debug(
|
||||
"%s: Opened socket to %s:%s (%s)",
|
||||
@ -359,6 +359,10 @@ class APIConnection:
|
||||
sock=self._socket,
|
||||
)
|
||||
|
||||
# Set the frame helper right away to ensure
|
||||
# the socket gets closed if we fail to handshake
|
||||
self._frame_helper = fh
|
||||
|
||||
try:
|
||||
await fh.perform_handshake(HANDSHAKE_TIMEOUT)
|
||||
except asyncio_TimeoutError as err:
|
||||
@ -366,7 +370,6 @@ class APIConnection:
|
||||
except OSError as err:
|
||||
raise HandshakeAPIError(f"Handshake failed: {err}") from err
|
||||
self._set_connection_state(ConnectionState.HANDSHAKE_COMPLETE)
|
||||
self._frame_helper = fh
|
||||
|
||||
async def _connect_hello(self) -> None:
|
||||
"""Step 4 in connect process: send hello and get api version."""
|
||||
|
@ -3,16 +3,21 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import socket
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
from typing import Any, Coroutine, Optional
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from mock import MagicMock, patch
|
||||
|
||||
from aioesphomeapi import APIConnectionError
|
||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||
from aioesphomeapi.api_pb2 import DeviceInfoResponse, HelloResponse
|
||||
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
|
||||
from aioesphomeapi.core import RequiresEncryptionAPIError
|
||||
from aioesphomeapi.core import (
|
||||
APIConnectionError,
|
||||
HandshakeAPIError,
|
||||
RequiresEncryptionAPIError,
|
||||
TimeoutAPIError,
|
||||
)
|
||||
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
|
||||
|
||||
from .common import async_fire_time_changed, utcnow
|
||||
@ -335,3 +340,105 @@ async def test_finish_connection_times_out(
|
||||
remove()
|
||||
await conn.force_disconnect()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exception_map"),
|
||||
[
|
||||
(OSError("Socket error"), HandshakeAPIError),
|
||||
(asyncio.TimeoutError, TimeoutAPIError),
|
||||
(asyncio.CancelledError, APIConnectionError),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_plaintext_connection_fails_handshake(
|
||||
conn: APIConnection,
|
||||
resolve_host: AsyncMock,
|
||||
socket_socket: MagicMock,
|
||||
exception_map: tuple[Exception, Exception],
|
||||
) -> None:
|
||||
"""Test that the frame helper is closed before the underlying socket.
|
||||
|
||||
If we don't do this, asyncio will get confused and not release the socket.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
exception, raised_exception = exception_map
|
||||
protocol = _get_mock_protocol(conn)
|
||||
messages = []
|
||||
protocol: Optional[APIPlaintextFrameHelper] = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
class APIPlaintextFrameHelperHandshakeException(APIPlaintextFrameHelper):
|
||||
"""Plaintext frame helper that raises exception on handshake."""
|
||||
|
||||
def perform_handshake(self, timeout: float) -> Coroutine[Any, Any, None]:
|
||||
raise exception
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
def on_msg(msg):
|
||||
messages.append(msg)
|
||||
|
||||
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
|
||||
transport = MagicMock()
|
||||
|
||||
with patch(
|
||||
"aioesphomeapi.connection.APIPlaintextFrameHelper",
|
||||
APIPlaintextFrameHelperHandshakeException,
|
||||
), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
|
||||
assert conn._socket is not None
|
||||
assert conn._frame_helper is not None
|
||||
|
||||
protocol.data_received(
|
||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
|
||||
)
|
||||
protocol.data_received(b"5stackatomproxy")
|
||||
protocol.data_received(b"\x00\x00$")
|
||||
protocol.data_received(b"\x00\x00\x04")
|
||||
protocol.data_received(
|
||||
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d'
|
||||
)
|
||||
protocol.data_received(
|
||||
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
|
||||
)
|
||||
|
||||
call_order = []
|
||||
|
||||
def _socket_close_call():
|
||||
call_order.append("socket_close")
|
||||
|
||||
def _frame_helper_close_call():
|
||||
call_order.append("frame_helper_close")
|
||||
|
||||
with patch.object(
|
||||
conn._socket, "close", side_effect=_socket_close_call
|
||||
), patch.object(
|
||||
conn._frame_helper, "close", side_effect=_frame_helper_close_call
|
||||
), pytest.raises(
|
||||
raised_exception
|
||||
):
|
||||
await asyncio.sleep(0)
|
||||
await connect_task
|
||||
|
||||
# Ensure the frame helper is closed before the socket
|
||||
# so asyncio releases the socket
|
||||
assert call_order == ["frame_helper_close", "socket_close"]
|
||||
assert not conn.is_connected
|
||||
assert len(messages) == 2
|
||||
assert isinstance(messages[0], HelloResponse)
|
||||
assert isinstance(messages[1], DeviceInfoResponse)
|
||||
assert messages[1].name == "m5stackatomproxy"
|
||||
remove()
|
||||
await conn.force_disconnect()
|
||||
await asyncio.sleep(0)
|
||||
|
Loading…
Reference in New Issue
Block a user