Improve performance of processing incoming packets (#475)

This commit is contained in:
J. Nick Koston 2023-07-15 08:48:47 -10:00 committed by GitHub
parent 0dbab1ebac
commit 8306058703
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 74 additions and 59 deletions

View File

@ -306,11 +306,12 @@ class APIConnection:
"""Step 3 in connect process: initialize the frame helper and init read loop.""" """Step 3 in connect process: initialize the frame helper and init read loop."""
fh: Union[APIPlaintextFrameHelper, APINoiseFrameHelper] fh: Union[APIPlaintextFrameHelper, APINoiseFrameHelper]
loop = self._loop loop = self._loop
process_packet = self._process_packet_factory()
if self._params.noise_psk is None: if self._params.noise_psk is None:
_, fh = await loop.create_connection( _, fh = await loop.create_connection(
lambda: APIPlaintextFrameHelper( lambda: APIPlaintextFrameHelper(
on_pkt=self._process_packet, on_pkt=process_packet,
on_error=self._report_fatal_error, on_error=self._report_fatal_error,
client_info=self._params.client_info, client_info=self._params.client_info,
), ),
@ -321,7 +322,7 @@ class APIConnection:
lambda: APINoiseFrameHelper( lambda: APINoiseFrameHelper(
noise_psk=self._params.noise_psk, noise_psk=self._params.noise_psk,
expected_name=self._params.expected_name, expected_name=self._params.expected_name,
on_pkt=self._process_packet, on_pkt=process_packet,
on_error=self._report_fatal_error, on_error=self._report_fatal_error,
client_info=self._params.client_info, client_info=self._params.client_info,
), ),
@ -699,74 +700,88 @@ class APIConnection:
self._read_exception_futures.clear() self._read_exception_futures.clear()
self._cleanup() self._cleanup()
def _process_packet(self, msg_type_proto: int, data: bytes) -> None: def _process_packet_factory(self) -> Callable[[int, bytes], None]:
"""Process a packet from the socket.""" """Factory to make a packet processor."""
debug = _LOGGER.isEnabledFor(logging.DEBUG) message_type_to_proto = MESSAGE_TYPE_TO_PROTO
if not (class_ := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)): is_enabled_for = _LOGGER.isEnabledFor
if debug: logging_debug = logging.DEBUG
message_handlers = self._message_handlers
def _process_packet(msg_type_proto: int, data: bytes) -> None:
"""Process a packet from the socket."""
try:
# python 3.11 has near zero cost exception handling
# if we do not raise which is almost never expected
# so we can just use a try/except here
class_ = message_type_to_proto[msg_type_proto]
except KeyError:
_LOGGER.debug( _LOGGER.debug(
"%s: Skipping message type %s", self.log_name, msg_type_proto "%s: Skipping message type %s", self.log_name, msg_type_proto
) )
return return
msg = class_() msg = class_()
try: try:
# MergeFromString instead of ParseFromString since # MergeFromString instead of ParseFromString since
# ParseFromString will clear the message first and # ParseFromString will clear the message first and
# the msg is already empty. # the msg is already empty.
msg.MergeFromString(data) msg.MergeFromString(data)
except Exception as e: except Exception as e:
_LOGGER.info( _LOGGER.info(
"%s: Invalid protobuf message: type=%s data=%s: %s", "%s: Invalid protobuf message: type=%s data=%s: %s",
self.log_name, self.log_name,
msg_type_proto, msg_type_proto,
data, data,
e, e,
exc_info=True, exc_info=True,
)
self._report_fatal_error(
ProtocolAPIError(
f"Invalid protobuf message: type={msg_type_proto} data={data!r}: {e}"
) )
) self._report_fatal_error(
raise ProtocolAPIError(
f"Invalid protobuf message: type={msg_type_proto} data={data!r}: {e}"
)
)
raise
msg_type = type(msg) msg_type = type(msg)
if debug: if is_enabled_for(logging_debug):
_LOGGER.debug( _LOGGER.debug(
"%s: Got message of type %s: %s", self.log_name, msg_type, msg "%s: Got message of type %s: %s", self.log_name, msg_type, msg
) )
if self._pong_timer: if self._pong_timer:
# Any valid message from the remote cancels the pong timer # Any valid message from the remote cancels the pong timer
# as we know the connection is still alive # as we know the connection is still alive
self._async_cancel_pong_timer() self._async_cancel_pong_timer()
if self._send_pending_ping: if self._send_pending_ping:
# Any valid message from the remove cancels the pending ping # Any valid message from the remove cancels the pending ping
# since we know the connection is still alive # since we know the connection is still alive
self._send_pending_ping = False self._send_pending_ping = False
for handler in self._message_handlers.get(msg_type, [])[:]: handlers = message_handlers.get(msg_type)
handler(msg) if handlers is not None:
for handler in handlers[:]:
handler(msg)
# Pre-check the message type to avoid awaiting # Pre-check the message type to avoid awaiting
# since most messages are not internal messages # since most messages are not internal messages
if msg_type not in INTERNAL_MESSAGE_TYPES: if msg_type not in INTERNAL_MESSAGE_TYPES:
return return
if msg_type is DisconnectRequest: if msg_type is DisconnectRequest:
self.send_message(DisconnectResponse()) self.send_message(DisconnectResponse())
self._connection_state = ConnectionState.CLOSED self._connection_state = ConnectionState.CLOSED
self._expected_disconnect = True self._expected_disconnect = True
self._cleanup() self._cleanup()
elif msg_type is PingRequest: elif msg_type is PingRequest:
self.send_message(PING_RESPONSE_MESSAGE) self.send_message(PING_RESPONSE_MESSAGE)
elif msg_type is GetTimeRequest: elif msg_type is GetTimeRequest:
resp = GetTimeResponse() resp = GetTimeResponse()
resp.epoch_seconds = int(time.time()) resp.epoch_seconds = int(time.time())
self.send_message(resp) self.send_message(resp)
return _process_packet
async def disconnect(self) -> None: async def disconnect(self) -> None:
"""Disconnect from the API.""" """Disconnect from the API."""

View File

@ -54,7 +54,7 @@ def socket_socket():
def _get_mock_protocol(conn: APIConnection): def _get_mock_protocol(conn: APIConnection):
protocol = APIPlaintextFrameHelper( protocol = APIPlaintextFrameHelper(
on_pkt=conn._process_packet, on_pkt=conn._process_packet_factory(),
on_error=conn._report_fatal_error, on_error=conn._report_fatal_error,
client_info="mock", client_info="mock",
) )