Ensure frame_helper is always closed before the underlying socket (#602)

This commit is contained in:
J. Nick Koston 2023-10-23 19:22:08 -05:00 committed by GitHub
parent 9ecf2fc2b7
commit e1c42e95bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 115 additions and 5 deletions

View File

@ -288,6 +288,7 @@ class APIConnection:
"""Step 2 in connect process: connect the socket."""
debug_enable = self._debug_enabled()
sock = socket.socket(family=addr.family, type=addr.type, proto=addr.proto)
self._socket = sock
sock.setblocking(False)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# Try to reduce the pressure on esphome device as it measures
@ -319,7 +320,6 @@ class APIConnection:
except OSError as err:
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err
self._socket = sock
if debug_enable is True:
_LOGGER.debug(
"%s: Opened socket to %s:%s (%s)",
@ -359,6 +359,10 @@ class APIConnection:
sock=self._socket,
)
# Set the frame helper right away to ensure
# the socket gets closed if we fail to handshake
self._frame_helper = fh
try:
await fh.perform_handshake(HANDSHAKE_TIMEOUT)
except asyncio_TimeoutError as err:
@ -366,7 +370,6 @@ class APIConnection:
except OSError as err:
raise HandshakeAPIError(f"Handshake failed: {err}") from err
self._set_connection_state(ConnectionState.HANDSHAKE_COMPLETE)
self._frame_helper = fh
async def _connect_hello(self) -> None:
"""Step 4 in connect process: send hello and get api version."""

View File

@ -3,16 +3,21 @@ from __future__ import annotations
import asyncio
import socket
from datetime import timedelta
from typing import Optional
from typing import Any, Coroutine, Optional
from unittest.mock import AsyncMock
import pytest
from mock import MagicMock, patch
from aioesphomeapi import APIConnectionError
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi.api_pb2 import DeviceInfoResponse, HelloResponse
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
from aioesphomeapi.core import RequiresEncryptionAPIError
from aioesphomeapi.core import (
APIConnectionError,
HandshakeAPIError,
RequiresEncryptionAPIError,
TimeoutAPIError,
)
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
from .common import async_fire_time_changed, utcnow
@ -335,3 +340,105 @@ async def test_finish_connection_times_out(
remove()
await conn.force_disconnect()
await asyncio.sleep(0)
@pytest.mark.parametrize(
("exception_map"),
[
(OSError("Socket error"), HandshakeAPIError),
(asyncio.TimeoutError, TimeoutAPIError),
(asyncio.CancelledError, APIConnectionError),
],
)
@pytest.mark.asyncio
async def test_plaintext_connection_fails_handshake(
conn: APIConnection,
resolve_host: AsyncMock,
socket_socket: MagicMock,
exception_map: tuple[Exception, Exception],
) -> None:
"""Test that the frame helper is closed before the underlying socket.
If we don't do this, asyncio will get confused and not release the socket.
"""
loop = asyncio.get_event_loop()
exception, raised_exception = exception_map
protocol = _get_mock_protocol(conn)
messages = []
protocol: Optional[APIPlaintextFrameHelper] = None
transport = MagicMock()
connected = asyncio.Event()
class APIPlaintextFrameHelperHandshakeException(APIPlaintextFrameHelper):
"""Plaintext frame helper that raises exception on handshake."""
def perform_handshake(self, timeout: float) -> Coroutine[Any, Any, None]:
raise exception
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
def on_msg(msg):
messages.append(msg)
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
transport = MagicMock()
with patch(
"aioesphomeapi.connection.APIPlaintextFrameHelper",
APIPlaintextFrameHelperHandshakeException,
), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()
assert conn._socket is not None
assert conn._frame_helper is not None
protocol.data_received(
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
)
protocol.data_received(b"5stackatomproxy")
protocol.data_received(b"\x00\x00$")
protocol.data_received(b"\x00\x00\x04")
protocol.data_received(
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d'
)
protocol.data_received(
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
)
call_order = []
def _socket_close_call():
call_order.append("socket_close")
def _frame_helper_close_call():
call_order.append("frame_helper_close")
with patch.object(
conn._socket, "close", side_effect=_socket_close_call
), patch.object(
conn._frame_helper, "close", side_effect=_frame_helper_close_call
), pytest.raises(
raised_exception
):
await asyncio.sleep(0)
await connect_task
# Ensure the frame helper is closed before the socket
# so asyncio releases the socket
assert call_order == ["frame_helper_close", "socket_close"]
assert not conn.is_connected
assert len(messages) == 2
assert isinstance(messages[0], HelloResponse)
assert isinstance(messages[1], DeviceInfoResponse)
assert messages[1].name == "m5stackatomproxy"
remove()
await conn.force_disconnect()
await asyncio.sleep(0)