mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-09-27 04:22:46 +02:00
Narrow msg_types to only accept tuples (#574)
This commit is contained in:
parent
3ad315cae0
commit
dc367b67bb
@ -17,6 +17,7 @@ cdef object PING_RESPONSE_MESSAGE
|
|||||||
cdef object DisconnectRequest
|
cdef object DisconnectRequest
|
||||||
cdef object PingRequest
|
cdef object PingRequest
|
||||||
cdef object GetTimeRequest
|
cdef object GetTimeRequest
|
||||||
|
cdef object partial
|
||||||
|
|
||||||
cdef class APIConnection:
|
cdef class APIConnection:
|
||||||
|
|
||||||
@ -61,9 +62,9 @@ cdef class APIConnection:
|
|||||||
cpdef _report_fatal_error(self, Exception err)
|
cpdef _report_fatal_error(self, Exception err)
|
||||||
|
|
||||||
@cython.locals(handlers=set)
|
@cython.locals(handlers=set)
|
||||||
cpdef _add_message_callback_without_remove(self, object on_message, object msg_types)
|
cpdef _add_message_callback_without_remove(self, object on_message, tuple msg_types)
|
||||||
|
|
||||||
cpdef add_message_callback(self, object on_message, object msg_types)
|
cpdef add_message_callback(self, object on_message, tuple msg_types)
|
||||||
|
|
||||||
@cython.locals(handlers=set)
|
@cython.locals(handlers=set)
|
||||||
cpdef _remove_message_callback(self, object on_message, object msg_types)
|
cpdef _remove_message_callback(self, object on_message, tuple msg_types)
|
@ -7,7 +7,7 @@ import logging
|
|||||||
import socket
|
import socket
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from collections.abc import Coroutine, Iterable
|
from collections.abc import Coroutine
|
||||||
from dataclasses import astuple, dataclass
|
from dataclasses import astuple, dataclass
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import TYPE_CHECKING, Any, Callable
|
from typing import TYPE_CHECKING, Any, Callable
|
||||||
@ -584,7 +584,7 @@ class APIConnection:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def _add_message_callback_without_remove(
|
def _add_message_callback_without_remove(
|
||||||
self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]]
|
self, on_message: Callable[[Any], None], msg_types: tuple[type[Any], ...]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Add a message callback without returning a remove callable."""
|
"""Add a message callback without returning a remove callable."""
|
||||||
message_handlers = self._message_handlers
|
message_handlers = self._message_handlers
|
||||||
@ -595,14 +595,14 @@ class APIConnection:
|
|||||||
handlers.add(on_message)
|
handlers.add(on_message)
|
||||||
|
|
||||||
def add_message_callback(
|
def add_message_callback(
|
||||||
self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]]
|
self, on_message: Callable[[Any], None], msg_types: tuple[type[Any], ...]
|
||||||
) -> Callable[[], None]:
|
) -> Callable[[], None]:
|
||||||
"""Add a message callback."""
|
"""Add a message callback."""
|
||||||
self._add_message_callback_without_remove(on_message, msg_types)
|
self._add_message_callback_without_remove(on_message, msg_types)
|
||||||
return partial(self._remove_message_callback, on_message, msg_types)
|
return partial(self._remove_message_callback, on_message, msg_types)
|
||||||
|
|
||||||
def _remove_message_callback(
|
def _remove_message_callback(
|
||||||
self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]]
|
self, on_message: Callable[[Any], None], msg_types: tuple[type[Any], ...]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Remove a message callback."""
|
"""Remove a message callback."""
|
||||||
message_handlers = self._message_handlers
|
message_handlers = self._message_handlers
|
||||||
@ -614,7 +614,7 @@ class APIConnection:
|
|||||||
self,
|
self,
|
||||||
send_msg: message.Message,
|
send_msg: message.Message,
|
||||||
on_message: Callable[[Any], None],
|
on_message: Callable[[Any], None],
|
||||||
msg_types: Iterable[type[Any]],
|
msg_types: tuple[type[Any], ...],
|
||||||
) -> Callable[[], None]:
|
) -> Callable[[], None]:
|
||||||
"""Send a message to the remote and register the given message handler."""
|
"""Send a message to the remote and register the given message handler."""
|
||||||
self.send_message(send_msg)
|
self.send_message(send_msg)
|
||||||
@ -651,7 +651,7 @@ class APIConnection:
|
|||||||
send_msg: message.Message,
|
send_msg: message.Message,
|
||||||
do_append: Callable[[message.Message], bool] | None,
|
do_append: Callable[[message.Message], bool] | None,
|
||||||
do_stop: Callable[[message.Message], bool] | None,
|
do_stop: Callable[[message.Message], bool] | None,
|
||||||
msg_types: Iterable[type[Any]],
|
msg_types: tuple[type[Any], ...],
|
||||||
timeout: float = 10.0,
|
timeout: float = 10.0,
|
||||||
) -> list[message.Message]:
|
) -> list[message.Message]:
|
||||||
"""Send a message to the remote and build up a list response.
|
"""Send a message to the remote and build up a list response.
|
||||||
@ -672,9 +672,8 @@ class APIConnection:
|
|||||||
# Unsafe to await between sending the message and registering the handler
|
# Unsafe to await between sending the message and registering the handler
|
||||||
fut: asyncio.Future[None] = loop.create_future()
|
fut: asyncio.Future[None] = loop.create_future()
|
||||||
responses: list[message.Message] = []
|
responses: list[message.Message] = []
|
||||||
on_message = partial(
|
handler = self._handle_complex_message
|
||||||
self._handle_complex_message, fut, responses, do_append, do_stop
|
on_message = partial(handler, fut, responses, do_append, do_stop)
|
||||||
)
|
|
||||||
|
|
||||||
read_exception_futures = self._read_exception_futures
|
read_exception_futures = self._read_exception_futures
|
||||||
self._add_message_callback_without_remove(on_message, msg_types)
|
self._add_message_callback_without_remove(on_message, msg_types)
|
||||||
@ -710,7 +709,7 @@ class APIConnection:
|
|||||||
None, # we will only get responses of `response_type`
|
None, # we will only get responses of `response_type`
|
||||||
None, # we will only get responses of `response_type`
|
None, # we will only get responses of `response_type`
|
||||||
(response_type,),
|
(response_type,),
|
||||||
timeout=timeout,
|
timeout,
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
@ -143,7 +143,7 @@ async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_so
|
|||||||
def on_msg(msg):
|
def on_msg(msg):
|
||||||
messages.append(msg)
|
messages.append(msg)
|
||||||
|
|
||||||
remove = conn.add_message_callback(on_msg, {HelloResponse, DeviceInfoResponse})
|
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
|
||||||
transport = MagicMock()
|
transport = MagicMock()
|
||||||
|
|
||||||
with patch.object(
|
with patch.object(
|
||||||
|
Loading…
Reference in New Issue
Block a user