Narrow msg_types to only accept tuples (#574)

This commit is contained in:
J. Nick Koston 2023-10-13 18:25:27 -10:00 committed by GitHub
parent 3ad315cae0
commit dc367b67bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 14 deletions

View File

@ -17,6 +17,7 @@ cdef object PING_RESPONSE_MESSAGE
cdef object DisconnectRequest
cdef object PingRequest
cdef object GetTimeRequest
cdef object partial
cdef class APIConnection:
@ -61,9 +62,9 @@ cdef class APIConnection:
cpdef _report_fatal_error(self, Exception err)
@cython.locals(handlers=set)
cpdef _add_message_callback_without_remove(self, object on_message, object msg_types)
cpdef _add_message_callback_without_remove(self, object on_message, tuple msg_types)
cpdef add_message_callback(self, object on_message, object msg_types)
cpdef add_message_callback(self, object on_message, tuple msg_types)
@cython.locals(handlers=set)
cpdef _remove_message_callback(self, object on_message, object msg_types)
cpdef _remove_message_callback(self, object on_message, tuple msg_types)

View File

@ -7,7 +7,7 @@ import logging
import socket
import sys
import time
from collections.abc import Coroutine, Iterable
from collections.abc import Coroutine
from dataclasses import astuple, dataclass
from functools import partial
from typing import TYPE_CHECKING, Any, Callable
@ -584,7 +584,7 @@ class APIConnection:
raise
def _add_message_callback_without_remove(
self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]]
self, on_message: Callable[[Any], None], msg_types: tuple[type[Any], ...]
) -> None:
"""Add a message callback without returning a remove callable."""
message_handlers = self._message_handlers
@ -595,14 +595,14 @@ class APIConnection:
handlers.add(on_message)
def add_message_callback(
self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]]
self, on_message: Callable[[Any], None], msg_types: tuple[type[Any], ...]
) -> Callable[[], None]:
"""Add a message callback."""
self._add_message_callback_without_remove(on_message, msg_types)
return partial(self._remove_message_callback, on_message, msg_types)
def _remove_message_callback(
self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]]
self, on_message: Callable[[Any], None], msg_types: tuple[type[Any], ...]
) -> None:
"""Remove a message callback."""
message_handlers = self._message_handlers
@ -614,7 +614,7 @@ class APIConnection:
self,
send_msg: message.Message,
on_message: Callable[[Any], None],
msg_types: Iterable[type[Any]],
msg_types: tuple[type[Any], ...],
) -> Callable[[], None]:
"""Send a message to the remote and register the given message handler."""
self.send_message(send_msg)
@ -651,7 +651,7 @@ class APIConnection:
send_msg: message.Message,
do_append: Callable[[message.Message], bool] | None,
do_stop: Callable[[message.Message], bool] | None,
msg_types: Iterable[type[Any]],
msg_types: tuple[type[Any], ...],
timeout: float = 10.0,
) -> list[message.Message]:
"""Send a message to the remote and build up a list response.
@ -672,9 +672,8 @@ class APIConnection:
# Unsafe to await between sending the message and registering the handler
fut: asyncio.Future[None] = loop.create_future()
responses: list[message.Message] = []
on_message = partial(
self._handle_complex_message, fut, responses, do_append, do_stop
)
handler = self._handle_complex_message
on_message = partial(handler, fut, responses, do_append, do_stop)
read_exception_futures = self._read_exception_futures
self._add_message_callback_without_remove(on_message, msg_types)
@ -710,7 +709,7 @@ class APIConnection:
None, # we will only get responses of `response_type`
None, # we will only get responses of `response_type`
(response_type,),
timeout=timeout,
timeout,
)
return response

View File

@ -143,7 +143,7 @@ async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_so
def on_msg(msg):
messages.append(msg)
remove = conn.add_message_callback(on_msg, {HelloResponse, DeviceInfoResponse})
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
transport = MagicMock()
with patch.object(