Avoid a dict lookup and int conversion to process every packet (#946)

This commit is contained in:
J. Nick Koston 2024-09-03 12:35:18 -10:00 committed by GitHub
parent 88a256c5e8
commit 7d112a82e9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 37 additions and 12 deletions

View File

@ -49,7 +49,9 @@ cdef class APINoiseFrameHelper(APIFrameHelper):
@cython.locals( @cython.locals(
msg=bytes, msg=bytes,
type_high="unsigned char", type_high="unsigned char",
type_low="unsigned char" type_low="unsigned char",
msg_type="unsigned int",
payload=bytes
) )
cdef void _handle_frame(self, bytes frame) cdef void _handle_frame(self, bytes frame)

View File

@ -359,7 +359,9 @@ class APINoiseFrameHelper(APIFrameHelper):
# N bytes: message data # N bytes: message data
type_high = msg[0] type_high = msg[0]
type_low = msg[1] type_low = msg[1]
self._connection.process_packet((type_high << 8) | type_low, msg[4:]) msg_type = (type_high << 8) | type_low
payload = msg[4:]
self._connection.process_packet(msg_type, payload)
def _handle_closed(self, frame: bytes) -> None: # pylint: disable=unused-argument def _handle_closed(self, frame: bytes) -> None: # pylint: disable=unused-argument
"""Handle a closed frame.""" """Handle a closed frame."""

View File

@ -73,6 +73,8 @@ cpdef void handle_complex_message(
cdef object _handle_timeout cdef object _handle_timeout
cdef object _handle_complex_message cdef object _handle_complex_message
cdef tuple MESSAGE_NUMBER_TO_PROTO
@cython.dataclasses.dataclass @cython.dataclasses.dataclass
cdef class ConnectionParams: cdef class ConnectionParams:
@ -119,7 +121,7 @@ cdef class APIConnection:
cdef void send_messages(self, tuple messages) cdef void send_messages(self, tuple messages)
@cython.locals(handlers=set, handlers_copy=set) @cython.locals(handlers=set, handlers_copy=set)
cpdef void process_packet(self, object msg_type_proto, object data) cpdef void process_packet(self, unsigned int msg_type_proto, object data)
cdef void _async_cancel_pong_timer(self) cdef void _async_cancel_pong_timer(self)

View File

@ -63,6 +63,9 @@ else:
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
MESSAGE_NUMBER_TO_PROTO = tuple(MESSAGE_TYPE_TO_PROTO.values())
PREFERRED_BUFFER_SIZE = 2097152 # Set buffer limit to 2MB PREFERRED_BUFFER_SIZE = 2097152 # Set buffer limit to 2MB
MIN_BUFFER_SIZE = 131072 # Minimum buffer size to use MIN_BUFFER_SIZE = 131072 # Minimum buffer size to use
@ -888,22 +891,27 @@ class APIConnection:
def process_packet(self, msg_type_proto: _int, data: _bytes) -> None: def process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
"""Process an incoming packet.""" """Process an incoming packet."""
debug_enabled = self._debug_enabled debug_enabled = self._debug_enabled
if (klass := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)) is None:
if debug_enabled:
_LOGGER.debug(
"%s: Skipping unknown message type %s",
self.log_name,
msg_type_proto,
)
return
try: try:
# MESSAGE_NUMBER_TO_PROTO is 0-indexed
# but the message type is 1-indexed
klass = MESSAGE_NUMBER_TO_PROTO[msg_type_proto - 1]
msg: message.Message = klass() msg: message.Message = klass()
# MergeFromString instead of ParseFromString since # MergeFromString instead of ParseFromString since
# ParseFromString will clear the message first and # ParseFromString will clear the message first and
# the msg is already empty. # the msg is already empty.
msg.MergeFromString(data) msg.MergeFromString(data)
except Exception as e: except Exception as e:
# IndexError will be very rare so we check for it
# after the broad exception catch to avoid having
# to check the exception type twice for the common case
if isinstance(e, IndexError):
if debug_enabled:
_LOGGER.debug(
"%s: Skipping unknown message type %s",
self.log_name,
msg_type_proto,
)
return
_LOGGER.error( _LOGGER.error(
"%s: Invalid protobuf message: type=%s data=%s: %s", "%s: Invalid protobuf message: type=%s data=%s: %s",
self.log_name, self.log_name,

View File

@ -393,3 +393,5 @@ MESSAGE_TYPE_TO_PROTO = {
117: UpdateStateResponse, 117: UpdateStateResponse,
118: UpdateCommandRequest, 118: UpdateCommandRequest,
} }
MESSAGE_NUMBER_TO_PROTO = tuple(MESSAGE_TYPE_TO_PROTO.values())

9
tests/test_core.py Normal file
View File

@ -0,0 +1,9 @@
from __future__ import annotations
from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO
def test_order_and_no_missing_numbers_in_message_type_to_proto():
"""Test that MESSAGE_TYPE_TO_PROTO has no missing numbers."""
for idx, (k, v) in enumerate(MESSAGE_TYPE_TO_PROTO.items()):
assert idx + 1 == k