mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-03-11 13:21:25 +01:00
Narrow msg_types to only accept tuples (#574)
This commit is contained in:
parent
3ad315cae0
commit
dc367b67bb
@ -17,6 +17,7 @@ cdef object PING_RESPONSE_MESSAGE
|
||||
cdef object DisconnectRequest
|
||||
cdef object PingRequest
|
||||
cdef object GetTimeRequest
|
||||
cdef object partial
|
||||
|
||||
cdef class APIConnection:
|
||||
|
||||
@ -61,9 +62,9 @@ cdef class APIConnection:
|
||||
cpdef _report_fatal_error(self, Exception err)
|
||||
|
||||
@cython.locals(handlers=set)
|
||||
cpdef _add_message_callback_without_remove(self, object on_message, object msg_types)
|
||||
cpdef _add_message_callback_without_remove(self, object on_message, tuple msg_types)
|
||||
|
||||
cpdef add_message_callback(self, object on_message, object msg_types)
|
||||
cpdef add_message_callback(self, object on_message, tuple msg_types)
|
||||
|
||||
@cython.locals(handlers=set)
|
||||
cpdef _remove_message_callback(self, object on_message, object msg_types)
|
||||
cpdef _remove_message_callback(self, object on_message, tuple msg_types)
|
@ -7,7 +7,7 @@ import logging
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Coroutine, Iterable
|
||||
from collections.abc import Coroutine
|
||||
from dataclasses import astuple, dataclass
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
@ -584,7 +584,7 @@ class APIConnection:
|
||||
raise
|
||||
|
||||
def _add_message_callback_without_remove(
|
||||
self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]]
|
||||
self, on_message: Callable[[Any], None], msg_types: tuple[type[Any], ...]
|
||||
) -> None:
|
||||
"""Add a message callback without returning a remove callable."""
|
||||
message_handlers = self._message_handlers
|
||||
@ -595,14 +595,14 @@ class APIConnection:
|
||||
handlers.add(on_message)
|
||||
|
||||
def add_message_callback(
|
||||
self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]]
|
||||
self, on_message: Callable[[Any], None], msg_types: tuple[type[Any], ...]
|
||||
) -> Callable[[], None]:
|
||||
"""Add a message callback."""
|
||||
self._add_message_callback_without_remove(on_message, msg_types)
|
||||
return partial(self._remove_message_callback, on_message, msg_types)
|
||||
|
||||
def _remove_message_callback(
|
||||
self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]]
|
||||
self, on_message: Callable[[Any], None], msg_types: tuple[type[Any], ...]
|
||||
) -> None:
|
||||
"""Remove a message callback."""
|
||||
message_handlers = self._message_handlers
|
||||
@ -614,7 +614,7 @@ class APIConnection:
|
||||
self,
|
||||
send_msg: message.Message,
|
||||
on_message: Callable[[Any], None],
|
||||
msg_types: Iterable[type[Any]],
|
||||
msg_types: tuple[type[Any], ...],
|
||||
) -> Callable[[], None]:
|
||||
"""Send a message to the remote and register the given message handler."""
|
||||
self.send_message(send_msg)
|
||||
@ -651,7 +651,7 @@ class APIConnection:
|
||||
send_msg: message.Message,
|
||||
do_append: Callable[[message.Message], bool] | None,
|
||||
do_stop: Callable[[message.Message], bool] | None,
|
||||
msg_types: Iterable[type[Any]],
|
||||
msg_types: tuple[type[Any], ...],
|
||||
timeout: float = 10.0,
|
||||
) -> list[message.Message]:
|
||||
"""Send a message to the remote and build up a list response.
|
||||
@ -672,9 +672,8 @@ class APIConnection:
|
||||
# Unsafe to await between sending the message and registering the handler
|
||||
fut: asyncio.Future[None] = loop.create_future()
|
||||
responses: list[message.Message] = []
|
||||
on_message = partial(
|
||||
self._handle_complex_message, fut, responses, do_append, do_stop
|
||||
)
|
||||
handler = self._handle_complex_message
|
||||
on_message = partial(handler, fut, responses, do_append, do_stop)
|
||||
|
||||
read_exception_futures = self._read_exception_futures
|
||||
self._add_message_callback_without_remove(on_message, msg_types)
|
||||
@ -710,7 +709,7 @@ class APIConnection:
|
||||
None, # we will only get responses of `response_type`
|
||||
None, # we will only get responses of `response_type`
|
||||
(response_type,),
|
||||
timeout=timeout,
|
||||
timeout,
|
||||
)
|
||||
return response
|
||||
|
||||
|
@ -143,7 +143,7 @@ async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_so
|
||||
def on_msg(msg):
|
||||
messages.append(msg)
|
||||
|
||||
remove = conn.add_message_callback(on_msg, {HelloResponse, DeviceInfoResponse})
|
||||
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
|
||||
transport = MagicMock()
|
||||
|
||||
with patch.object(
|
||||
|
Loading…
Reference in New Issue
Block a user