Narrow msg_types to only accept tuples (#574)

This commit is contained in:
J. Nick Koston 2023-10-13 18:25:27 -10:00 committed by GitHub
parent 3ad315cae0
commit dc367b67bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 14 deletions

View File

@ -17,6 +17,7 @@ cdef object PING_RESPONSE_MESSAGE
cdef object DisconnectRequest cdef object DisconnectRequest
cdef object PingRequest cdef object PingRequest
cdef object GetTimeRequest cdef object GetTimeRequest
cdef object partial
cdef class APIConnection: cdef class APIConnection:
@ -61,9 +62,9 @@ cdef class APIConnection:
cpdef _report_fatal_error(self, Exception err) cpdef _report_fatal_error(self, Exception err)
@cython.locals(handlers=set) @cython.locals(handlers=set)
cpdef _add_message_callback_without_remove(self, object on_message, object msg_types) cpdef _add_message_callback_without_remove(self, object on_message, tuple msg_types)
cpdef add_message_callback(self, object on_message, object msg_types) cpdef add_message_callback(self, object on_message, tuple msg_types)
@cython.locals(handlers=set) @cython.locals(handlers=set)
cpdef _remove_message_callback(self, object on_message, object msg_types) cpdef _remove_message_callback(self, object on_message, tuple msg_types)

View File

@ -7,7 +7,7 @@ import logging
import socket import socket
import sys import sys
import time import time
from collections.abc import Coroutine, Iterable from collections.abc import Coroutine
from dataclasses import astuple, dataclass from dataclasses import astuple, dataclass
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Any, Callable from typing import TYPE_CHECKING, Any, Callable
@ -584,7 +584,7 @@ class APIConnection:
raise raise
def _add_message_callback_without_remove( def _add_message_callback_without_remove(
self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]] self, on_message: Callable[[Any], None], msg_types: tuple[type[Any], ...]
) -> None: ) -> None:
"""Add a message callback without returning a remove callable.""" """Add a message callback without returning a remove callable."""
message_handlers = self._message_handlers message_handlers = self._message_handlers
@ -595,14 +595,14 @@ class APIConnection:
handlers.add(on_message) handlers.add(on_message)
def add_message_callback( def add_message_callback(
self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]] self, on_message: Callable[[Any], None], msg_types: tuple[type[Any], ...]
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Add a message callback.""" """Add a message callback."""
self._add_message_callback_without_remove(on_message, msg_types) self._add_message_callback_without_remove(on_message, msg_types)
return partial(self._remove_message_callback, on_message, msg_types) return partial(self._remove_message_callback, on_message, msg_types)
def _remove_message_callback( def _remove_message_callback(
self, on_message: Callable[[Any], None], msg_types: Iterable[type[Any]] self, on_message: Callable[[Any], None], msg_types: tuple[type[Any], ...]
) -> None: ) -> None:
"""Remove a message callback.""" """Remove a message callback."""
message_handlers = self._message_handlers message_handlers = self._message_handlers
@ -614,7 +614,7 @@ class APIConnection:
self, self,
send_msg: message.Message, send_msg: message.Message,
on_message: Callable[[Any], None], on_message: Callable[[Any], None],
msg_types: Iterable[type[Any]], msg_types: tuple[type[Any], ...],
) -> Callable[[], None]: ) -> Callable[[], None]:
"""Send a message to the remote and register the given message handler.""" """Send a message to the remote and register the given message handler."""
self.send_message(send_msg) self.send_message(send_msg)
@ -651,7 +651,7 @@ class APIConnection:
send_msg: message.Message, send_msg: message.Message,
do_append: Callable[[message.Message], bool] | None, do_append: Callable[[message.Message], bool] | None,
do_stop: Callable[[message.Message], bool] | None, do_stop: Callable[[message.Message], bool] | None,
msg_types: Iterable[type[Any]], msg_types: tuple[type[Any], ...],
timeout: float = 10.0, timeout: float = 10.0,
) -> list[message.Message]: ) -> list[message.Message]:
"""Send a message to the remote and build up a list response. """Send a message to the remote and build up a list response.
@ -672,9 +672,8 @@ class APIConnection:
# Unsafe to await between sending the message and registering the handler # Unsafe to await between sending the message and registering the handler
fut: asyncio.Future[None] = loop.create_future() fut: asyncio.Future[None] = loop.create_future()
responses: list[message.Message] = [] responses: list[message.Message] = []
on_message = partial( handler = self._handle_complex_message
self._handle_complex_message, fut, responses, do_append, do_stop on_message = partial(handler, fut, responses, do_append, do_stop)
)
read_exception_futures = self._read_exception_futures read_exception_futures = self._read_exception_futures
self._add_message_callback_without_remove(on_message, msg_types) self._add_message_callback_without_remove(on_message, msg_types)
@ -710,7 +709,7 @@ class APIConnection:
None, # we will only get responses of `response_type` None, # we will only get responses of `response_type`
None, # we will only get responses of `response_type` None, # we will only get responses of `response_type`
(response_type,), (response_type,),
timeout=timeout, timeout,
) )
return response return response

View File

@ -143,7 +143,7 @@ async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_so
def on_msg(msg): def on_msg(msg):
messages.append(msg) messages.append(msg)
remove = conn.add_message_callback(on_msg, {HelloResponse, DeviceInfoResponse}) remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
transport = MagicMock() transport = MagicMock()
with patch.object( with patch.object(