Small cleanups to connection (#648)

This commit is contained in:
J. Nick Koston 2023-11-21 13:08:48 +01:00 committed by GitHub
parent b40d34c154
commit 298aa01b00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 36 additions and 16 deletions

View File

@ -451,7 +451,7 @@ class APIClient:
assert self._connection is not None assert self._connection is not None
resp = await self._connection.send_messages_await_response_complex( resp = await self._connection.send_messages_await_response_complex(
(ListEntitiesRequest(),), do_append, do_stop, msg_types, timeout=60 (ListEntitiesRequest(),), do_append, do_stop, msg_types, 60
) )
entities: list[EntityInfo] = [] entities: list[EntityInfo] = []
services: list[UserService] = [] services: list[UserService] = []
@ -567,7 +567,7 @@ class APIClient:
message_filter = partial(self._filter_bluetooth_message, address, handle) message_filter = partial(self._filter_bluetooth_message, address, handle)
resp = await self._connection.send_messages_await_response_complex( resp = await self._connection.send_messages_await_response_complex(
(request,), message_filter, message_filter, msg_types, timeout=timeout (request,), message_filter, message_filter, msg_types, timeout
) )
if isinstance(resp[0], BluetoothGATTErrorResponse): if isinstance(resp[0], BluetoothGATTErrorResponse):
@ -893,7 +893,7 @@ class APIClient:
do_append, do_append,
do_stop, do_stop,
msg_types, msg_types,
timeout=DEFAULT_BLE_TIMEOUT, DEFAULT_BLE_TIMEOUT,
) )
services = [] services = []
for msg in resp: for msg in resp:

View File

@ -33,7 +33,7 @@ cdef object partial
cdef object hr cdef object hr
cdef object RESOLVE_TIMEOUT cdef object RESOLVE_TIMEOUT
cdef object CONNECT_AND_SETUP_TIMEOUT cdef object CONNECT_AND_SETUP_TIMEOUT, CONNECT_REQUEST_TIMEOUT
cdef object APIConnectionError cdef object APIConnectionError
cdef object BadNameAPIError cdef object BadNameAPIError
@ -42,10 +42,23 @@ cdef object PingFailedAPIError
cdef object ReadFailedAPIError cdef object ReadFailedAPIError
cdef object TimeoutAPIError cdef object TimeoutAPIError
cdef object in_do_connect, astuple
@cython.dataclasses.dataclass
cdef class ConnectionParams:
cdef public str address
cdef public object port
cdef public object password
cdef public object client_info
cdef public object keepalive
cdef public object zeroconf_manager
cdef public object noise_psk
cdef public object expected_name
cdef class APIConnection: cdef class APIConnection:
cdef object _params cdef ConnectionParams _params
cdef public object on_stop cdef public object on_stop
cdef object _on_stop_task cdef object _on_stop_task
cdef public object _socket cdef public object _socket
@ -95,3 +108,9 @@ cdef class APIConnection:
@cython.locals(handlers=set) @cython.locals(handlers=set)
cpdef _remove_message_callback(self, object on_message, tuple msg_types) cpdef _remove_message_callback(self, object on_message, tuple msg_types)
cpdef _handle_disconnect_request_internal(self, object msg)
cpdef _handle_ping_request_internal(self, object msg)
cpdef _handle_get_time_request_internal(self, object msg)

View File

@ -339,7 +339,8 @@ class APIConnection:
"""Step 3 in connect process: initialize the frame helper and init read loop.""" """Step 3 in connect process: initialize the frame helper and init read loop."""
fh: APIPlaintextFrameHelper | APINoiseFrameHelper fh: APIPlaintextFrameHelper | APINoiseFrameHelper
loop = self._loop loop = self._loop
assert self._socket is not None if TYPE_CHECKING:
assert self._socket is not None
if (noise_psk := self._params.noise_psk) is None: if (noise_psk := self._params.noise_psk) is None:
_, fh = await loop.create_connection( # type: ignore[type-var] _, fh = await loop.create_connection( # type: ignore[type-var]
@ -461,7 +462,7 @@ class APIConnection:
now = loop.time() now = loop.time()
if self._send_pending_ping: if self._send_pending_ping:
self.send_message(PING_REQUEST_MESSAGE) self.send_messages((PING_REQUEST_MESSAGE,))
if self._pong_timer is None: if self._pong_timer is None:
# Do not reset the timer if it's already set # Do not reset the timer if it's already set
# since the only thing we want to reset the timer # since the only thing we want to reset the timer
@ -694,7 +695,7 @@ class APIConnection:
msg_types: tuple[type[Any], ...], msg_types: tuple[type[Any], ...],
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Send a message to the remote and register the given message handler.""" """Send a message to the remote and register the given message handler."""
self.send_message(send_msg) self.send_messages((send_msg,))
# Since we do not return control to the event loop (no awaits) # Since we do not return control to the event loop (no awaits)
# between sending the message and registering the handler # between sending the message and registering the handler
# we can be sure that we will not miss any messages even though # we can be sure that we will not miss any messages even though
@ -729,7 +730,7 @@ class APIConnection:
do_append: Callable[[message.Message], bool] | None, do_append: Callable[[message.Message], bool] | None,
do_stop: Callable[[message.Message], bool] | None, do_stop: Callable[[message.Message], bool] | None,
msg_types: tuple[type[Any], ...], msg_types: tuple[type[Any], ...],
timeout: float = 10.0, timeout: _float,
) -> list[message.Message]: ) -> list[message.Message]:
"""Send a message to the remote and build up a list response. """Send a message to the remote and build up a list response.
@ -779,7 +780,7 @@ class APIConnection:
return responses return responses
async def send_message_await_response( async def send_message_await_response(
self, send_msg: message.Message, response_type: Any, timeout: float = 10.0 self, send_msg: message.Message, response_type: Any, timeout: _float = 10.0
) -> Any: ) -> Any:
[response] = await self.send_messages_await_response_complex( [response] = await self.send_messages_await_response_complex(
(send_msg,), (send_msg,),
@ -887,14 +888,14 @@ class APIConnection:
# the response if for some reason sending the response # the response if for some reason sending the response
# fails we will still mark the disconnect as expected # fails we will still mark the disconnect as expected
self._expected_disconnect = True self._expected_disconnect = True
self.send_message(DISCONNECT_RESPONSE_MESSAGE) self.send_messages((DISCONNECT_RESPONSE_MESSAGE,))
self._cleanup() self._cleanup()
def _handle_ping_request_internal( # pylint: disable=unused-argument def _handle_ping_request_internal( # pylint: disable=unused-argument
self, _msg: PingRequest self, _msg: PingRequest
) -> None: ) -> None:
"""Handle a PingRequest.""" """Handle a PingRequest."""
self.send_message(PING_RESPONSE_MESSAGE) self.send_messages((PING_RESPONSE_MESSAGE,))
def _handle_get_time_request_internal( # pylint: disable=unused-argument def _handle_get_time_request_internal( # pylint: disable=unused-argument
self, _msg: GetTimeRequest self, _msg: GetTimeRequest
@ -902,7 +903,7 @@ class APIConnection:
"""Handle a GetTimeRequest.""" """Handle a GetTimeRequest."""
resp = GetTimeResponse() resp = GetTimeResponse()
resp.epoch_seconds = int(time.time()) resp.epoch_seconds = int(time.time())
self.send_message(resp) self.send_messages((resp,))
async def disconnect(self) -> None: async def disconnect(self) -> None:
"""Disconnect from the API.""" """Disconnect from the API."""
@ -946,7 +947,7 @@ class APIConnection:
# Still try to tell the esp to disconnect gracefully # Still try to tell the esp to disconnect gracefully
# but don't wait for it to finish # but don't wait for it to finish
try: try:
self.send_message(DISCONNECT_REQUEST_MESSAGE) self.send_messages((DISCONNECT_REQUEST_MESSAGE,))
except APIConnectionError as err: except APIConnectionError as err:
_LOGGER.error( _LOGGER.error(
"%s: Failed to send (forced) disconnect request: %s", "%s: Failed to send (forced) disconnect request: %s",

View File

@ -79,7 +79,7 @@ def auth_client():
def patch_response_complex(client: APIClient, messages): def patch_response_complex(client: APIClient, messages):
async def patched(req, app, stop, msg_types, timeout=5.0): async def patched(req, app, stop, msg_types, timeout):
resp = [] resp = []
for msg in messages: for msg in messages:
if app(msg): if app(msg):

View File

@ -97,7 +97,7 @@ async def test_timeout_sending_message(
with pytest.raises(TimeoutAPIError): with pytest.raises(TimeoutAPIError):
await conn.send_messages_await_response_complex( await conn.send_messages_await_response_complex(
(PingRequest(),), None, None, (PingResponse,), timeout=0 (PingRequest(),), None, None, (PingResponse,), 0
) )
transport.reset_mock() transport.reset_mock()