diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 20cbbf8..b5f35ca 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -216,27 +216,32 @@ class APIFrameHelper: async def _read_packet_plaintext(self) -> Packet: async with self._read_lock: - preamble = await self._reader.readexactly(1) - if preamble[0] != 0x00: - if preamble[0] == 0x01: - raise RequiresEncryptionAPIError("Connection requires encryption") - raise ProtocolAPIError(f"Invalid preamble {preamble[0]:02x}") + try: + preamble = await self._reader.readexactly(1) + if preamble[0] != 0x00: + if preamble[0] == 0x01: + raise RequiresEncryptionAPIError( + "Connection requires encryption" + ) + raise ProtocolAPIError(f"Invalid preamble {preamble[0]:02x}") - length = b"" - while not length or (length[-1] & 0x80) == 0x80: - length += await self._reader.readexactly(1) - length_int = bytes_to_varuint(length) - assert length_int is not None - msg_type = b"" - while not msg_type or (msg_type[-1] & 0x80) == 0x80: - msg_type += await self._reader.readexactly(1) - msg_type_int = bytes_to_varuint(msg_type) - assert msg_type_int is not None + length = b"" + while not length or (length[-1] & 0x80) == 0x80: + length += await self._reader.readexactly(1) + length_int = bytes_to_varuint(length) + assert length_int is not None + msg_type = b"" + while not msg_type or (msg_type[-1] & 0x80) == 0x80: + msg_type += await self._reader.readexactly(1) + msg_type_int = bytes_to_varuint(msg_type) + assert msg_type_int is not None - raw_msg = b"" - if length_int != 0: - raw_msg = await self._reader.readexactly(length_int) - return Packet(type=msg_type_int, data=raw_msg) + raw_msg = b"" + if length_int != 0: + raw_msg = await self._reader.readexactly(length_int) + return Packet(type=msg_type_int, data=raw_msg) + except (asyncio.IncompleteReadError, OSError, TimeoutError) as err: + raise SocketAPIError(f"Error while reading data: {err}") from err async def read_packet(self) -> Packet: if self._params.noise_psk is None: @@ -531,6 +536,9 @@ class APIConnection: async def run_forever(self) -> None: while True: + if self._frame_helper is None: + # Socket closed + break try: await self._run_once() except APIConnectionError as err: