Send hello and login asynchronously to speed up connecting (#628)

This commit is contained in:
J. Nick Koston 2023-11-09 19:17:53 -06:00 committed by GitHub
parent 8678fa9ebc
commit a15b96c76b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 208 additions and 59 deletions

View File

@ -441,8 +441,8 @@ class APIClient:
return isinstance(msg, ListEntitiesDoneResponse)
assert self._connection is not None
resp = await self._connection.send_message_await_response_complex(
ListEntitiesRequest(), do_append, do_stop, msg_types, timeout=60
resp = await self._connection.send_messages_await_response_complex(
(ListEntitiesRequest(),), do_append, do_stop, msg_types, timeout=60
)
entities: list[EntityInfo] = []
services: list[UserService] = []
@ -557,8 +557,8 @@ class APIClient:
assert self._connection is not None
message_filter = partial(self._filter_bluetooth_message, address, handle)
resp = await self._connection.send_message_await_response_complex(
request, message_filter, message_filter, msg_types, timeout=timeout
resp = await self._connection.send_messages_await_response_complex(
(request,), message_filter, message_filter, msg_types, timeout=timeout
)
if isinstance(resp[0], BluetoothGATTErrorResponse):
@ -791,9 +791,11 @@ class APIClient:
)
return True
[response] = await self._connection.send_message_await_response_complex(
BluetoothDeviceRequest(
address=address, request_type=BluetoothDeviceRequestType.PAIR
[response] = await self._connection.send_messages_await_response_complex(
(
BluetoothDeviceRequest(
address=address, request_type=BluetoothDeviceRequestType.PAIR
),
),
predicate_func,
predicate_func,
@ -812,9 +814,11 @@ class APIClient:
def predicate_func(msg: BluetoothDeviceUnpairingResponse) -> bool:
return bool(msg.address == address)
[response] = await self._connection.send_message_await_response_complex(
BluetoothDeviceRequest(
address=address, request_type=BluetoothDeviceRequestType.UNPAIR
[response] = await self._connection.send_messages_await_response_complex(
(
BluetoothDeviceRequest(
address=address, request_type=BluetoothDeviceRequestType.UNPAIR
),
),
predicate_func,
predicate_func,
@ -833,9 +837,11 @@ class APIClient:
def predicate_func(msg: BluetoothDeviceClearCacheResponse) -> bool:
return bool(msg.address == address)
[response] = await self._connection.send_message_await_response_complex(
BluetoothDeviceRequest(
address=address, request_type=BluetoothDeviceRequestType.CLEAR_CACHE
[response] = await self._connection.send_messages_await_response_complex(
(
BluetoothDeviceRequest(
address=address, request_type=BluetoothDeviceRequestType.CLEAR_CACHE
),
),
predicate_func,
predicate_func,
@ -853,10 +859,12 @@ class APIClient:
return bool(msg.address == address and not msg.connected)
assert self._connection is not None
await self._connection.send_message_await_response_complex(
BluetoothDeviceRequest(
address=address,
request_type=BluetoothDeviceRequestType.DISCONNECT,
await self._connection.send_messages_await_response_complex(
(
BluetoothDeviceRequest(
address=address,
request_type=BluetoothDeviceRequestType.DISCONNECT,
),
),
predicate_func,
predicate_func,
@ -883,8 +891,8 @@ class APIClient:
return isinstance(msg, stop_types) and msg.address == address
assert self._connection is not None
resp = await self._connection.send_message_await_response_complex(
BluetoothGATTGetServicesRequest(address=address),
resp = await self._connection.send_messages_await_response_complex(
(BluetoothGATTGetServicesRequest(address=address),),
do_append,
do_stop,
msg_types,

View File

@ -89,3 +89,5 @@ cdef class APIConnection:
@cython.locals(handlers=set)
cpdef _remove_message_callback(self, object on_message, tuple msg_types)
cdef _send_messages(self, tuple messages)

View File

@ -369,17 +369,48 @@ class APIConnection:
raise HandshakeAPIError(f"Handshake failed: {err}") from err
self._set_connection_state(ConnectionState.HANDSHAKE_COMPLETE)
async def _connect_hello(self) -> None:
"""Step 4 in connect process: send hello and get api version."""
def _make_hello_request(self) -> HelloRequest:
"""Make a HelloRequest."""
hello = HelloRequest()
hello.client_info = self._params.client_info
hello.api_version_major = 1
hello.api_version_minor = 9
return hello
async def _connect_hello_login(self, login: bool) -> None:
"""Step 4 in connect process: send hello and login and get api version."""
messages = [self._make_hello_request()]
msg_types = [HelloResponse]
if login:
messages.append(self._make_connect_request())
msg_types.append(ConnectResponse)
try:
resp = await self.send_message_await_response(hello, HelloResponse)
responses = await self.send_messages_await_response_complex(
tuple(messages),
None,
lambda resp: type(resp) # pylint: disable=unidiomatic-typecheck
is msg_types[-1],
tuple(msg_types),
CONNECT_REQUEST_TIMEOUT,
)
except TimeoutAPIError as err:
self._report_fatal_error(err)
raise TimeoutAPIError("Hello timed out") from err
resp = responses.pop(0)
self._process_hello_resp(resp)
if login:
login_response = responses.pop(0)
self._process_login_response(login_response)
def _process_login_response(self, login_response: ConnectResponse) -> None:
"""Process a ConnectResponse."""
if login_response.invalid_password:
raise InvalidAuthAPIError("Invalid password!")
def _process_hello_resp(self, resp: HelloResponse) -> None:
"""Process a HelloResponse."""
_LOGGER.debug(
"%s: Successfully connected ('%s' API=%s.%s)",
self.log_name,
@ -525,9 +556,7 @@ class APIConnection:
in_do_connect.set(True)
await self._connect_init_frame_helper()
self._register_internal_message_handlers()
await self._connect_hello()
if login:
await self._login()
await self._connect_hello_login(login)
self._async_schedule_keep_alive(self._loop.time())
async def finish_connection(self, *, login: bool) -> None:
@ -577,25 +606,22 @@ class APIConnection:
self.is_connected = state is ConnectionState.CONNECTED
self._handshake_complete = state is ConnectionState.HANDSHAKE_COMPLETE
async def _login(self) -> None:
"""Send a login (ConnectRequest) and await the response."""
def _make_connect_request(self) -> ConnectRequest:
"""Make a ConnectRequest."""
connect = ConnectRequest()
if self._params.password is not None:
connect.password = self._params.password
try:
resp = await self.send_message_await_response(
connect, ConnectResponse, timeout=CONNECT_REQUEST_TIMEOUT
)
except TimeoutAPIError as err:
# After a timeout for connect the connection can no longer be used
# We don't know what state the device may be in after ConnectRequest
# was already sent
_LOGGER.debug("%s: Login timed out", self.log_name)
self._report_fatal_error(err)
raise
return connect
if resp.invalid_password:
raise InvalidAuthAPIError("Invalid password!")
def _send_messages(self, messages: tuple[message.Message, ...]) -> None:
"""Send a message to the remote.
Currently this is a wrapper around send_message
but may be changed in the future to batch messages
together.
"""
for msg in messages:
self.send_message(msg)
def send_message(self, msg: message.Message) -> None:
"""Send a protobuf message to the remote."""
@ -692,9 +718,9 @@ class APIConnection:
if do_stop is None or do_stop(resp):
fut.set_result(None)
async def send_message_await_response_complex( # pylint: disable=too-many-locals
async def send_messages_await_response_complex( # pylint: disable=too-many-locals
self,
send_msg: message.Message,
messages: tuple[message.Message, ...],
do_append: Callable[[message.Message], bool] | None,
do_stop: Callable[[message.Message], bool] | None,
msg_types: tuple[type[Any], ...],
@ -712,8 +738,7 @@ class APIConnection:
# Send the message right away to reduce latency.
# This is safe because we are not awaiting between
# sending the message and registering the handler
self.send_message(send_msg)
self._send_messages(messages)
loop = self._loop
# Unsafe to await between sending the message and registering the handler
fut: asyncio.Future[None] = loop.create_future()
@ -736,8 +761,9 @@ class APIConnection:
await fut
except asyncio_TimeoutError as err:
timeout_expired = True
response_names = ", ".join(t.__name__ for t in msg_types)
raise TimeoutAPIError(
f"Timeout waiting for response to {type(send_msg).__name__} after {timeout}s"
f"Timeout waiting for {response_names} after {timeout}s"
) from err
finally:
if not timeout_expired:
@ -750,8 +776,8 @@ class APIConnection:
async def send_message_await_response(
self, send_msg: message.Message, response_type: Any, timeout: float = 10.0
) -> Any:
[response] = await self.send_message_await_response_complex(
send_msg,
[response] = await self.send_messages_await_response_complex(
(send_msg,),
None, # we will only get responses of `response_type`
None, # we will only get responses of `response_type`
(response_type,),

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import Awaitable
from enum import Enum
from typing import Callable
@ -181,6 +182,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
async def _try_connect(self) -> bool:
"""Try connecting to the API client."""
self._async_set_connection_state_while_locked(ReconnectLogicState.CONNECTING)
start_connect_time = time.perf_counter()
try:
await self._cli.start_connection(on_stop=self._on_disconnect)
except Exception as err: # pylint: disable=broad-except
@ -192,7 +194,11 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
self._async_log_connection_error(err)
self._tries += 1
return False
_LOGGER.info("Successfully connected to %s", self._log_name)
finish_connect_time = time.perf_counter()
connect_time = finish_connect_time - start_connect_time
_LOGGER.info(
"Successfully connected to %s in %0.3fs", self._log_name, connect_time
)
self._stop_zc_listen()
self._async_set_connection_state_while_locked(ReconnectLogicState.HANDSHAKING)
try:
@ -212,7 +218,11 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
self._tries += 1
return False
self._tries = 0
_LOGGER.info("Successful handshake with %s", self._log_name)
finish_handshake_time = time.perf_counter()
handshake_time = finish_handshake_time - finish_connect_time
_LOGGER.info(
"Successful handshake with %s in %0.3fs", self._log_name, handshake_time
)
self._async_set_connection_state_while_locked(ReconnectLogicState.READY)
await self._on_connect_cb()
return True

View File

@ -74,7 +74,7 @@ def patch_response_complex(client: APIClient, messages):
raise ValueError("Response never stopped")
return resp
client._connection.send_message_await_response_complex = patched
client._connection.send_messages_await_response_complex = patched
def patch_response_callback(client: APIClient):

View File

@ -1,16 +1,20 @@
from __future__ import annotations
import asyncio
import logging
import socket
from datetime import timedelta
from typing import Any, Coroutine, Generator, Optional
from unittest.mock import AsyncMock
import pytest
from google.protobuf import message
from mock import MagicMock, patch
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
from aioesphomeapi.api_pb2 import (
ConnectResponse,
DeviceInfoResponse,
HelloResponse,
PingRequest,
@ -18,8 +22,10 @@ from aioesphomeapi.api_pb2 import (
)
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
from aioesphomeapi.core import (
MESSAGE_TYPE_TO_PROTO,
APIConnectionError,
HandshakeAPIError,
InvalidAuthAPIError,
RequiresEncryptionAPIError,
TimeoutAPIError,
)
@ -27,6 +33,11 @@ from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
from .common import async_fire_time_changed, utcnow
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG)
async def connect(conn: APIConnection, login: bool = True):
"""Wrapper for connection logic to do both parts."""
@ -165,8 +176,8 @@ async def test_timeout_sending_message(
await connect_task
with pytest.raises(TimeoutAPIError):
await conn.send_message_await_response_complex(
PingRequest(), None, None, (PingResponse,), timeout=0
await conn.send_messages_await_response_complex(
(PingRequest(),), None, None, (PingResponse,), timeout=0
)
transport.reset_mock()
@ -176,9 +187,7 @@ async def test_timeout_sending_message(
transport.write.assert_called_with(b"\x00\x00\x05")
assert "disconnect request failed" in caplog.text
assert (
" Timeout waiting for response to DisconnectRequest after 0.0s" in caplog.text
)
assert " Timeout waiting for DisconnectResponse after 0.0s" in caplog.text
@pytest.mark.asyncio
@ -232,9 +241,7 @@ async def test_disconnect_when_not_fully_connected(
transport.write.assert_called_with(b"\x00\x00\x05")
assert "disconnect request failed" in caplog.text
assert (
" Timeout waiting for response to DisconnectRequest after 0.0s" in caplog.text
)
assert " Timeout waiting for DisconnectResponse after 0.0s" in caplog.text
@pytest.mark.asyncio
@ -250,7 +257,7 @@ async def test_requires_encryption_propagates(conn: APIConnection):
conn.connection_state = ConnectionState.CONNECTED
with pytest.raises(RequiresEncryptionAPIError):
task = asyncio.create_task(conn._connect_hello())
task = asyncio.create_task(conn._connect_hello_login(login=True))
await asyncio.sleep(0)
protocol.data_received(b"\x01\x00\x00")
await task
@ -554,3 +561,99 @@ async def test_plaintext_connection_fails_handshake(
remove()
await conn.force_disconnect()
await asyncio.sleep(0)
def _generate_plaintext_packet(msg: bytes, type_: int) -> bytes:
return (
b"\0"
+ _cached_varuint_to_bytes(len(msg))
+ _cached_varuint_to_bytes(type_)
+ msg
)
@pytest.mark.asyncio
async def test_connect_wrong_password(conn, resolve_host, socket_socket, event_loop):
loop = asyncio.get_event_loop()
protocol: Optional[APIPlaintextFrameHelper] = None
transport = MagicMock()
connected = asyncio.Event()
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(connect(conn, login=True))
await connected.wait()
hello_response: message.Message = HelloResponse()
hello_response.api_version_major = 1
hello_response.api_version_minor = 9
hello_response.name = "fake"
hello_msg = hello_response.SerializeToString()
connect_response: message.Message = ConnectResponse()
connect_response.invalid_password = True
connect_msg = connect_response.SerializeToString()
protocol.data_received(
_generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
)
protocol.data_received(
_generate_plaintext_packet(
connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse]
)
)
with pytest.raises(InvalidAuthAPIError):
await connect_task
assert not conn.is_connected
@pytest.mark.asyncio
async def test_connect_correct_password(conn, resolve_host, socket_socket, event_loop):
loop = asyncio.get_event_loop()
protocol: Optional[APIPlaintextFrameHelper] = None
transport = MagicMock()
connected = asyncio.Event()
def _create_mock_transport_protocol(create_func, **kwargs):
nonlocal protocol
protocol = create_func()
protocol.connection_made(transport)
connected.set()
return transport, protocol
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
):
connect_task = asyncio.create_task(connect(conn, login=True))
await connected.wait()
hello_response: message.Message = HelloResponse()
hello_response.api_version_major = 1
hello_response.api_version_minor = 9
hello_response.name = "fake"
hello_msg = hello_response.SerializeToString()
connect_response: message.Message = ConnectResponse()
connect_response.invalid_password = False
connect_msg = connect_response.SerializeToString()
protocol.data_received(
_generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
)
protocol.data_received(
_generate_plaintext_packet(
connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse]
)
)
await connect_task
assert conn.is_connected