mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-12 10:33:57 +01:00
Send hello and login asynchronously to speed up connecting (#628)
This commit is contained in:
parent
8678fa9ebc
commit
a15b96c76b
@ -441,8 +441,8 @@ class APIClient:
|
||||
return isinstance(msg, ListEntitiesDoneResponse)
|
||||
|
||||
assert self._connection is not None
|
||||
resp = await self._connection.send_message_await_response_complex(
|
||||
ListEntitiesRequest(), do_append, do_stop, msg_types, timeout=60
|
||||
resp = await self._connection.send_messages_await_response_complex(
|
||||
(ListEntitiesRequest(),), do_append, do_stop, msg_types, timeout=60
|
||||
)
|
||||
entities: list[EntityInfo] = []
|
||||
services: list[UserService] = []
|
||||
@ -557,8 +557,8 @@ class APIClient:
|
||||
assert self._connection is not None
|
||||
|
||||
message_filter = partial(self._filter_bluetooth_message, address, handle)
|
||||
resp = await self._connection.send_message_await_response_complex(
|
||||
request, message_filter, message_filter, msg_types, timeout=timeout
|
||||
resp = await self._connection.send_messages_await_response_complex(
|
||||
(request,), message_filter, message_filter, msg_types, timeout=timeout
|
||||
)
|
||||
|
||||
if isinstance(resp[0], BluetoothGATTErrorResponse):
|
||||
@ -791,9 +791,11 @@ class APIClient:
|
||||
)
|
||||
return True
|
||||
|
||||
[response] = await self._connection.send_message_await_response_complex(
|
||||
BluetoothDeviceRequest(
|
||||
address=address, request_type=BluetoothDeviceRequestType.PAIR
|
||||
[response] = await self._connection.send_messages_await_response_complex(
|
||||
(
|
||||
BluetoothDeviceRequest(
|
||||
address=address, request_type=BluetoothDeviceRequestType.PAIR
|
||||
),
|
||||
),
|
||||
predicate_func,
|
||||
predicate_func,
|
||||
@ -812,9 +814,11 @@ class APIClient:
|
||||
def predicate_func(msg: BluetoothDeviceUnpairingResponse) -> bool:
|
||||
return bool(msg.address == address)
|
||||
|
||||
[response] = await self._connection.send_message_await_response_complex(
|
||||
BluetoothDeviceRequest(
|
||||
address=address, request_type=BluetoothDeviceRequestType.UNPAIR
|
||||
[response] = await self._connection.send_messages_await_response_complex(
|
||||
(
|
||||
BluetoothDeviceRequest(
|
||||
address=address, request_type=BluetoothDeviceRequestType.UNPAIR
|
||||
),
|
||||
),
|
||||
predicate_func,
|
||||
predicate_func,
|
||||
@ -833,9 +837,11 @@ class APIClient:
|
||||
def predicate_func(msg: BluetoothDeviceClearCacheResponse) -> bool:
|
||||
return bool(msg.address == address)
|
||||
|
||||
[response] = await self._connection.send_message_await_response_complex(
|
||||
BluetoothDeviceRequest(
|
||||
address=address, request_type=BluetoothDeviceRequestType.CLEAR_CACHE
|
||||
[response] = await self._connection.send_messages_await_response_complex(
|
||||
(
|
||||
BluetoothDeviceRequest(
|
||||
address=address, request_type=BluetoothDeviceRequestType.CLEAR_CACHE
|
||||
),
|
||||
),
|
||||
predicate_func,
|
||||
predicate_func,
|
||||
@ -853,10 +859,12 @@ class APIClient:
|
||||
return bool(msg.address == address and not msg.connected)
|
||||
|
||||
assert self._connection is not None
|
||||
await self._connection.send_message_await_response_complex(
|
||||
BluetoothDeviceRequest(
|
||||
address=address,
|
||||
request_type=BluetoothDeviceRequestType.DISCONNECT,
|
||||
await self._connection.send_messages_await_response_complex(
|
||||
(
|
||||
BluetoothDeviceRequest(
|
||||
address=address,
|
||||
request_type=BluetoothDeviceRequestType.DISCONNECT,
|
||||
),
|
||||
),
|
||||
predicate_func,
|
||||
predicate_func,
|
||||
@ -883,8 +891,8 @@ class APIClient:
|
||||
return isinstance(msg, stop_types) and msg.address == address
|
||||
|
||||
assert self._connection is not None
|
||||
resp = await self._connection.send_message_await_response_complex(
|
||||
BluetoothGATTGetServicesRequest(address=address),
|
||||
resp = await self._connection.send_messages_await_response_complex(
|
||||
(BluetoothGATTGetServicesRequest(address=address),),
|
||||
do_append,
|
||||
do_stop,
|
||||
msg_types,
|
||||
|
@ -89,3 +89,5 @@ cdef class APIConnection:
|
||||
|
||||
@cython.locals(handlers=set)
|
||||
cpdef _remove_message_callback(self, object on_message, tuple msg_types)
|
||||
|
||||
cdef _send_messages(self, tuple messages)
|
||||
|
@ -369,17 +369,48 @@ class APIConnection:
|
||||
raise HandshakeAPIError(f"Handshake failed: {err}") from err
|
||||
self._set_connection_state(ConnectionState.HANDSHAKE_COMPLETE)
|
||||
|
||||
async def _connect_hello(self) -> None:
|
||||
"""Step 4 in connect process: send hello and get api version."""
|
||||
def _make_hello_request(self) -> HelloRequest:
|
||||
"""Make a HelloRequest."""
|
||||
hello = HelloRequest()
|
||||
hello.client_info = self._params.client_info
|
||||
hello.api_version_major = 1
|
||||
hello.api_version_minor = 9
|
||||
return hello
|
||||
|
||||
async def _connect_hello_login(self, login: bool) -> None:
|
||||
"""Step 4 in connect process: send hello and login and get api version."""
|
||||
messages = [self._make_hello_request()]
|
||||
msg_types = [HelloResponse]
|
||||
if login:
|
||||
messages.append(self._make_connect_request())
|
||||
msg_types.append(ConnectResponse)
|
||||
|
||||
try:
|
||||
resp = await self.send_message_await_response(hello, HelloResponse)
|
||||
responses = await self.send_messages_await_response_complex(
|
||||
tuple(messages),
|
||||
None,
|
||||
lambda resp: type(resp) # pylint: disable=unidiomatic-typecheck
|
||||
is msg_types[-1],
|
||||
tuple(msg_types),
|
||||
CONNECT_REQUEST_TIMEOUT,
|
||||
)
|
||||
except TimeoutAPIError as err:
|
||||
self._report_fatal_error(err)
|
||||
raise TimeoutAPIError("Hello timed out") from err
|
||||
|
||||
resp = responses.pop(0)
|
||||
self._process_hello_resp(resp)
|
||||
if login:
|
||||
login_response = responses.pop(0)
|
||||
self._process_login_response(login_response)
|
||||
|
||||
def _process_login_response(self, login_response: ConnectResponse) -> None:
|
||||
"""Process a ConnectResponse."""
|
||||
if login_response.invalid_password:
|
||||
raise InvalidAuthAPIError("Invalid password!")
|
||||
|
||||
def _process_hello_resp(self, resp: HelloResponse) -> None:
|
||||
"""Process a HelloResponse."""
|
||||
_LOGGER.debug(
|
||||
"%s: Successfully connected ('%s' API=%s.%s)",
|
||||
self.log_name,
|
||||
@ -525,9 +556,7 @@ class APIConnection:
|
||||
in_do_connect.set(True)
|
||||
await self._connect_init_frame_helper()
|
||||
self._register_internal_message_handlers()
|
||||
await self._connect_hello()
|
||||
if login:
|
||||
await self._login()
|
||||
await self._connect_hello_login(login)
|
||||
self._async_schedule_keep_alive(self._loop.time())
|
||||
|
||||
async def finish_connection(self, *, login: bool) -> None:
|
||||
@ -577,25 +606,22 @@ class APIConnection:
|
||||
self.is_connected = state is ConnectionState.CONNECTED
|
||||
self._handshake_complete = state is ConnectionState.HANDSHAKE_COMPLETE
|
||||
|
||||
async def _login(self) -> None:
|
||||
"""Send a login (ConnectRequest) and await the response."""
|
||||
def _make_connect_request(self) -> ConnectRequest:
|
||||
"""Make a ConnectRequest."""
|
||||
connect = ConnectRequest()
|
||||
if self._params.password is not None:
|
||||
connect.password = self._params.password
|
||||
try:
|
||||
resp = await self.send_message_await_response(
|
||||
connect, ConnectResponse, timeout=CONNECT_REQUEST_TIMEOUT
|
||||
)
|
||||
except TimeoutAPIError as err:
|
||||
# After a timeout for connect the connection can no longer be used
|
||||
# We don't know what state the device may be in after ConnectRequest
|
||||
# was already sent
|
||||
_LOGGER.debug("%s: Login timed out", self.log_name)
|
||||
self._report_fatal_error(err)
|
||||
raise
|
||||
return connect
|
||||
|
||||
if resp.invalid_password:
|
||||
raise InvalidAuthAPIError("Invalid password!")
|
||||
def _send_messages(self, messages: tuple[message.Message, ...]) -> None:
|
||||
"""Send a message to the remote.
|
||||
|
||||
Currently this is a wrapper around send_message
|
||||
but may be changed in the future to batch messages
|
||||
together.
|
||||
"""
|
||||
for msg in messages:
|
||||
self.send_message(msg)
|
||||
|
||||
def send_message(self, msg: message.Message) -> None:
|
||||
"""Send a protobuf message to the remote."""
|
||||
@ -692,9 +718,9 @@ class APIConnection:
|
||||
if do_stop is None or do_stop(resp):
|
||||
fut.set_result(None)
|
||||
|
||||
async def send_message_await_response_complex( # pylint: disable=too-many-locals
|
||||
async def send_messages_await_response_complex( # pylint: disable=too-many-locals
|
||||
self,
|
||||
send_msg: message.Message,
|
||||
messages: tuple[message.Message, ...],
|
||||
do_append: Callable[[message.Message], bool] | None,
|
||||
do_stop: Callable[[message.Message], bool] | None,
|
||||
msg_types: tuple[type[Any], ...],
|
||||
@ -712,8 +738,7 @@ class APIConnection:
|
||||
# Send the message right away to reduce latency.
|
||||
# This is safe because we are not awaiting between
|
||||
# sending the message and registering the handler
|
||||
|
||||
self.send_message(send_msg)
|
||||
self._send_messages(messages)
|
||||
loop = self._loop
|
||||
# Unsafe to await between sending the message and registering the handler
|
||||
fut: asyncio.Future[None] = loop.create_future()
|
||||
@ -736,8 +761,9 @@ class APIConnection:
|
||||
await fut
|
||||
except asyncio_TimeoutError as err:
|
||||
timeout_expired = True
|
||||
response_names = ", ".join(t.__name__ for t in msg_types)
|
||||
raise TimeoutAPIError(
|
||||
f"Timeout waiting for response to {type(send_msg).__name__} after {timeout}s"
|
||||
f"Timeout waiting for {response_names} after {timeout}s"
|
||||
) from err
|
||||
finally:
|
||||
if not timeout_expired:
|
||||
@ -750,8 +776,8 @@ class APIConnection:
|
||||
async def send_message_await_response(
|
||||
self, send_msg: message.Message, response_type: Any, timeout: float = 10.0
|
||||
) -> Any:
|
||||
[response] = await self.send_message_await_response_complex(
|
||||
send_msg,
|
||||
[response] = await self.send_messages_await_response_complex(
|
||||
(send_msg,),
|
||||
None, # we will only get responses of `response_type`
|
||||
None, # we will only get responses of `response_type`
|
||||
(response_type,),
|
||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Awaitable
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
@ -181,6 +182,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
async def _try_connect(self) -> bool:
|
||||
"""Try connecting to the API client."""
|
||||
self._async_set_connection_state_while_locked(ReconnectLogicState.CONNECTING)
|
||||
start_connect_time = time.perf_counter()
|
||||
try:
|
||||
await self._cli.start_connection(on_stop=self._on_disconnect)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
@ -192,7 +194,11 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
self._async_log_connection_error(err)
|
||||
self._tries += 1
|
||||
return False
|
||||
_LOGGER.info("Successfully connected to %s", self._log_name)
|
||||
finish_connect_time = time.perf_counter()
|
||||
connect_time = finish_connect_time - start_connect_time
|
||||
_LOGGER.info(
|
||||
"Successfully connected to %s in %0.3fs", self._log_name, connect_time
|
||||
)
|
||||
self._stop_zc_listen()
|
||||
self._async_set_connection_state_while_locked(ReconnectLogicState.HANDSHAKING)
|
||||
try:
|
||||
@ -212,7 +218,11 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
self._tries += 1
|
||||
return False
|
||||
self._tries = 0
|
||||
_LOGGER.info("Successful handshake with %s", self._log_name)
|
||||
finish_handshake_time = time.perf_counter()
|
||||
handshake_time = finish_handshake_time - finish_connect_time
|
||||
_LOGGER.info(
|
||||
"Successful handshake with %s in %0.3fs", self._log_name, handshake_time
|
||||
)
|
||||
self._async_set_connection_state_while_locked(ReconnectLogicState.READY)
|
||||
await self._on_connect_cb()
|
||||
return True
|
||||
|
@ -74,7 +74,7 @@ def patch_response_complex(client: APIClient, messages):
|
||||
raise ValueError("Response never stopped")
|
||||
return resp
|
||||
|
||||
client._connection.send_message_await_response_complex = patched
|
||||
client._connection.send_messages_await_response_complex = patched
|
||||
|
||||
|
||||
def patch_response_callback(client: APIClient):
|
||||
|
@ -1,16 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import socket
|
||||
from datetime import timedelta
|
||||
from typing import Any, Coroutine, Generator, Optional
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from google.protobuf import message
|
||||
from mock import MagicMock, patch
|
||||
|
||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
|
||||
from aioesphomeapi.api_pb2 import (
|
||||
ConnectResponse,
|
||||
DeviceInfoResponse,
|
||||
HelloResponse,
|
||||
PingRequest,
|
||||
@ -18,8 +22,10 @@ from aioesphomeapi.api_pb2 import (
|
||||
)
|
||||
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
|
||||
from aioesphomeapi.core import (
|
||||
MESSAGE_TYPE_TO_PROTO,
|
||||
APIConnectionError,
|
||||
HandshakeAPIError,
|
||||
InvalidAuthAPIError,
|
||||
RequiresEncryptionAPIError,
|
||||
TimeoutAPIError,
|
||||
)
|
||||
@ -27,6 +33,11 @@ from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
|
||||
|
||||
from .common import async_fire_time_changed, utcnow
|
||||
|
||||
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
|
||||
|
||||
|
||||
logging.getLogger("aioesphomeapi").setLevel(logging.DEBUG)
|
||||
|
||||
|
||||
async def connect(conn: APIConnection, login: bool = True):
|
||||
"""Wrapper for connection logic to do both parts."""
|
||||
@ -165,8 +176,8 @@ async def test_timeout_sending_message(
|
||||
await connect_task
|
||||
|
||||
with pytest.raises(TimeoutAPIError):
|
||||
await conn.send_message_await_response_complex(
|
||||
PingRequest(), None, None, (PingResponse,), timeout=0
|
||||
await conn.send_messages_await_response_complex(
|
||||
(PingRequest(),), None, None, (PingResponse,), timeout=0
|
||||
)
|
||||
|
||||
transport.reset_mock()
|
||||
@ -176,9 +187,7 @@ async def test_timeout_sending_message(
|
||||
transport.write.assert_called_with(b"\x00\x00\x05")
|
||||
|
||||
assert "disconnect request failed" in caplog.text
|
||||
assert (
|
||||
" Timeout waiting for response to DisconnectRequest after 0.0s" in caplog.text
|
||||
)
|
||||
assert " Timeout waiting for DisconnectResponse after 0.0s" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -232,9 +241,7 @@ async def test_disconnect_when_not_fully_connected(
|
||||
transport.write.assert_called_with(b"\x00\x00\x05")
|
||||
|
||||
assert "disconnect request failed" in caplog.text
|
||||
assert (
|
||||
" Timeout waiting for response to DisconnectRequest after 0.0s" in caplog.text
|
||||
)
|
||||
assert " Timeout waiting for DisconnectResponse after 0.0s" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -250,7 +257,7 @@ async def test_requires_encryption_propagates(conn: APIConnection):
|
||||
conn.connection_state = ConnectionState.CONNECTED
|
||||
|
||||
with pytest.raises(RequiresEncryptionAPIError):
|
||||
task = asyncio.create_task(conn._connect_hello())
|
||||
task = asyncio.create_task(conn._connect_hello_login(login=True))
|
||||
await asyncio.sleep(0)
|
||||
protocol.data_received(b"\x01\x00\x00")
|
||||
await task
|
||||
@ -554,3 +561,99 @@ async def test_plaintext_connection_fails_handshake(
|
||||
remove()
|
||||
await conn.force_disconnect()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
def _generate_plaintext_packet(msg: bytes, type_: int) -> bytes:
|
||||
return (
|
||||
b"\0"
|
||||
+ _cached_varuint_to_bytes(len(msg))
|
||||
+ _cached_varuint_to_bytes(type_)
|
||||
+ msg
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_wrong_password(conn, resolve_host, socket_socket, event_loop):
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol: Optional[APIPlaintextFrameHelper] = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=True))
|
||||
await connected.wait()
|
||||
hello_response: message.Message = HelloResponse()
|
||||
hello_response.api_version_major = 1
|
||||
hello_response.api_version_minor = 9
|
||||
hello_response.name = "fake"
|
||||
hello_msg = hello_response.SerializeToString()
|
||||
|
||||
connect_response: message.Message = ConnectResponse()
|
||||
connect_response.invalid_password = True
|
||||
connect_msg = connect_response.SerializeToString()
|
||||
|
||||
protocol.data_received(
|
||||
_generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
|
||||
)
|
||||
protocol.data_received(
|
||||
_generate_plaintext_packet(
|
||||
connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse]
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(InvalidAuthAPIError):
|
||||
await connect_task
|
||||
|
||||
assert not conn.is_connected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_correct_password(conn, resolve_host, socket_socket, event_loop):
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol: Optional[APIPlaintextFrameHelper] = None
|
||||
transport = MagicMock()
|
||||
connected = asyncio.Event()
|
||||
|
||||
def _create_mock_transport_protocol(create_func, **kwargs):
|
||||
nonlocal protocol
|
||||
protocol = create_func()
|
||||
protocol.connection_made(transport)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
):
|
||||
connect_task = asyncio.create_task(connect(conn, login=True))
|
||||
await connected.wait()
|
||||
hello_response: message.Message = HelloResponse()
|
||||
hello_response.api_version_major = 1
|
||||
hello_response.api_version_minor = 9
|
||||
hello_response.name = "fake"
|
||||
hello_msg = hello_response.SerializeToString()
|
||||
|
||||
connect_response: message.Message = ConnectResponse()
|
||||
connect_response.invalid_password = False
|
||||
connect_msg = connect_response.SerializeToString()
|
||||
|
||||
protocol.data_received(
|
||||
_generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
|
||||
)
|
||||
protocol.data_received(
|
||||
_generate_plaintext_packet(
|
||||
connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse]
|
||||
)
|
||||
)
|
||||
|
||||
await connect_task
|
||||
|
||||
assert conn.is_connected
|
||||
|
Loading…
Reference in New Issue
Block a user