From a452e738ffceb156080f9aa16d3c7d670e26244c Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Wed, 30 Nov 2022 12:42:15 -1000 Subject: [PATCH] Move message parsing out of the read loop (#323) --- aioesphomeapi/connection.py | 100 ++++++++++++++++++------------------ 1 file changed, 49 insertions(+), 51 deletions(-) diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 97253e3..0688f4a 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -105,7 +105,7 @@ class APIConnection: self._ping_stop_event = asyncio.Event() - self._to_process: asyncio.Queue[message.Message] = asyncio.Queue() + self._to_process: asyncio.Queue[Packet] = asyncio.Queue() self._process_task: Optional[asyncio.Task[None]] = None @@ -511,24 +511,6 @@ class APIConnection: handler(err) await self._cleanup() - async def _read_once(self) -> None: - assert self._frame_helper is not None - pkt = await self._frame_helper.read_packet() - - msg_type = pkt.type - raw_msg = pkt.data - if msg_type not in MESSAGE_TYPE_TO_PROTO: - _LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type) - return - - msg = MESSAGE_TYPE_TO_PROTO[msg_type]() - try: - msg.ParseFromString(raw_msg) - except Exception as e: - raise ProtocolAPIError(f"Invalid protobuf message: {e}") from e - _LOGGER.debug("%s: Got message of type %s: %s", self.log_name, type(msg), msg) - self._to_process.put_nowait(msg) - async def _process_loop(self) -> None: while True: if not self._is_socket_open: @@ -536,46 +518,62 @@ class APIConnection: break try: - msg = await self._to_process.get() + pkt = await self._to_process.get() except RuntimeError: break + msg_type = pkt.type + raw_msg = pkt.data + if msg_type not in MESSAGE_TYPE_TO_PROTO: + _LOGGER.debug("%s: Skipping message type %s", self.log_name, msg_type) + continue + + msg = MESSAGE_TYPE_TO_PROTO[msg_type]() + try: + msg.ParseFromString(raw_msg) + except Exception as e: + await self._report_fatal_error( + ProtocolAPIError(f"Invalid protobuf message: {e}") + ) + raise + _LOGGER.debug( + "%s: Got message of type %s: %s", self.log_name, type(msg), msg + ) + for handler in self._message_handlers[:]: handler(msg) await self._handle_internal_messages(msg) async def _read_loop(self) -> None: - while True: - if not self._is_socket_open: - # Socket closed but task isn't cancelled yet - break - try: - await self._read_once() - except SocketClosedAPIError as err: - # don't log with info, if closed the site that closed the connection should log - _LOGGER.debug( - "%s: Socket closed, stopping read loop", - self.log_name, - ) - await self._report_fatal_error(err) - break - except APIConnectionError as err: - _LOGGER.info( - "%s: Error while reading incoming messages: %s", - self.log_name, - err, - ) - await self._report_fatal_error(err) - break - except Exception as err: # pylint: disable=broad-except - _LOGGER.warning( - "%s: Unexpected error while reading incoming messages: %s", - self.log_name, - err, - exc_info=True, - ) - await self._report_fatal_error(err) - break + assert self._frame_helper is not None + try: + while True: + if not self._is_socket_open: + # Socket closed but task isn't cancelled yet + break + self._to_process.put_nowait(await self._frame_helper.read_packet()) + except SocketClosedAPIError as err: + # don't log with info, if closed the site that closed the connection should log + _LOGGER.debug( + "%s: Socket closed, stopping read loop", + self.log_name, + ) + await self._report_fatal_error(err) + except APIConnectionError as err: + _LOGGER.info( + "%s: Error while reading incoming messages: %s", + self.log_name, + err, + ) + await self._report_fatal_error(err) + except Exception as err: # pylint: disable=broad-except + _LOGGER.warning( + "%s: Unexpected error while reading incoming messages: %s", + self.log_name, + err, + exc_info=True, + ) + await self._report_fatal_error(err) async def _handle_internal_messages(self, msg: Any) -> None: if isinstance(msg, DisconnectRequest):