From dc367b67bb9864fecedd2e65a7e91ec1147afb5e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Fri, 13 Oct 2023 18:25:27 -1000 Subject: [PATCH] Narrow msg_types to only accept tuples (#574) --- aioesphomeapi/connection.pxd | 7 ++++--- aioesphomeapi/connection.py | 19 +++++++++---------- tests/test_connection.py | 2 +- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/aioesphomeapi/connection.pxd b/aioesphomeapi/connection.pxd index 7dc314d..3eeb137 100644 --- a/aioesphomeapi/connection.pxd +++ b/aioesphomeapi/connection.pxd @@ -17,6 +17,7 @@ cdef object PING_RESPONSE_MESSAGE cdef object DisconnectRequest cdef object PingRequest cdef object GetTimeRequest +cdef object partial cdef class APIConnection: @@ -61,9 +62,9 @@ cdef class APIConnection: cpdef _report_fatal_error(self, Exception err) @cython.locals(handlers=set) - cpdef _add_message_callback_without_remove(self, object on_message, object msg_types) + cpdef _add_message_callback_without_remove(self, object on_message, tuple msg_types) - cpdef add_message_callback(self, object on_message, object msg_types) + cpdef add_message_callback(self, object on_message, tuple msg_types) @cython.locals(handlers=set) - cpdef _remove_message_callback(self, object on_message, object msg_types) \ No newline at end of file + cpdef _remove_message_callback(self, object on_message, tuple msg_types) \ No newline at end of file diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 3e0f4d1..2141490 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -7,7 +7,7 @@ import logging import socket import sys import time -from collections.abc import Coroutine, Iterable +from collections.abc import Coroutine from dataclasses import astuple, dataclass from functools import partial from typing import TYPE_CHECKING, Any, Callable @@ -584,7 +584,7 @@ class APIConnection: raise def _add_message_callback_without_remove( - self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]] + self, on_message: Callable[[Any], None], msg_types: tuple[type[Any], ...] ) -> None: """Add a message callback without returning a remove callable.""" message_handlers = self._message_handlers @@ -595,14 +595,14 @@ class APIConnection: handlers.add(on_message) def add_message_callback( - self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]] + self, on_message: Callable[[Any], None], msg_types: tuple[type[Any], ...] ) -> Callable[[], None]: """Add a message callback.""" self._add_message_callback_without_remove(on_message, msg_types) return partial(self._remove_message_callback, on_message, msg_types) def _remove_message_callback( - self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]] + self, on_message: Callable[[Any], None], msg_types: tuple[type[Any], ...] ) -> None: """Remove a message callback.""" message_handlers = self._message_handlers @@ -614,7 +614,7 @@ class APIConnection: self, send_msg: message.Message, on_message: Callable[[Any], None], - msg_types: Iterable[type[Any]], + msg_types: tuple[type[Any], ...], ) -> Callable[[], None]: """Send a message to the remote and register the given message handler.""" self.send_message(send_msg) @@ -651,7 +651,7 @@ class APIConnection: send_msg: message.Message, do_append: Callable[[message.Message], bool] | None, do_stop: Callable[[message.Message], bool] | None, - msg_types: Iterable[type[Any]], + msg_types: tuple[type[Any], ...], timeout: float = 10.0, ) -> list[message.Message]: """Send a message to the remote and build up a list response. @@ -672,9 +672,8 @@ class APIConnection: # Unsafe to await between sending the message and registering the handler fut: asyncio.Future[None] = loop.create_future() responses: list[message.Message] = [] - on_message = partial( - self._handle_complex_message, fut, responses, do_append, do_stop - ) + handler = self._handle_complex_message + on_message = partial(handler, fut, responses, do_append, do_stop) read_exception_futures = self._read_exception_futures self._add_message_callback_without_remove(on_message, msg_types) @@ -710,7 +709,7 @@ class APIConnection: None, # we will only get responses of `response_type` None, # we will only get responses of `response_type` (response_type,), - timeout=timeout, + timeout, ) return response diff --git a/tests/test_connection.py b/tests/test_connection.py index 8005fee..49f184f 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -143,7 +143,7 @@ async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_so def on_msg(msg): messages.append(msg) - remove = conn.add_message_callback(on_msg, {HelloResponse, DeviceInfoResponse}) + remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse)) transport = MagicMock() with patch.object(