mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-01-04 18:58:05 +01:00
Reduce number of calls to readexactly (#324)
This commit is contained in:
parent
a452e738ff
commit
6273f785f4
@ -74,29 +74,47 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
async def read_packet(self) -> Packet:
|
async def read_packet(self) -> Packet:
|
||||||
async with self._read_lock:
|
async with self._read_lock:
|
||||||
try:
|
try:
|
||||||
preamble = await self._reader.readexactly(1)
|
# Read preamble, which should always 0x00
|
||||||
if preamble[0] != 0x00:
|
# Also try to get the length and msg type
|
||||||
if preamble[0] == 0x01:
|
# to avoid multiple calls to readexactly
|
||||||
|
init_bytes = await self._reader.readexactly(3)
|
||||||
|
if init_bytes[0] != 0x00:
|
||||||
|
if init_bytes[0] == 0x01:
|
||||||
raise RequiresEncryptionAPIError(
|
raise RequiresEncryptionAPIError(
|
||||||
"Connection requires encryption"
|
"Connection requires encryption"
|
||||||
)
|
)
|
||||||
raise ProtocolAPIError(f"Invalid preamble {preamble[0]:02x}")
|
raise ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}")
|
||||||
|
|
||||||
length = b""
|
if init_bytes[1] & 0x80 == 0x80:
|
||||||
while not length or (length[-1] & 0x80) == 0x80:
|
# Length is longer than 1 byte
|
||||||
length += await self._reader.readexactly(1)
|
length = init_bytes[1:3]
|
||||||
length_int = bytes_to_varuint(length)
|
|
||||||
assert length_int is not None
|
|
||||||
msg_type = b""
|
msg_type = b""
|
||||||
|
else:
|
||||||
|
# This is the most common case with 99% of messages
|
||||||
|
# needing a single byte for length and type which means
|
||||||
|
# we avoid 2 calls to readexactly
|
||||||
|
length = init_bytes[1:2]
|
||||||
|
msg_type = init_bytes[2:3]
|
||||||
|
|
||||||
|
# If the message is long, we need to read the rest of the length
|
||||||
|
while length[-1] & 0x80 == 0x80:
|
||||||
|
length += await self._reader.readexactly(1)
|
||||||
|
|
||||||
|
# If the message length was longer than 1 byte, we need to read the
|
||||||
|
# message type
|
||||||
while not msg_type or (msg_type[-1] & 0x80) == 0x80:
|
while not msg_type or (msg_type[-1] & 0x80) == 0x80:
|
||||||
msg_type += await self._reader.readexactly(1)
|
msg_type += await self._reader.readexactly(1)
|
||||||
|
|
||||||
|
length_int = bytes_to_varuint(length)
|
||||||
|
assert length_int is not None
|
||||||
msg_type_int = bytes_to_varuint(msg_type)
|
msg_type_int = bytes_to_varuint(msg_type)
|
||||||
assert msg_type_int is not None
|
assert msg_type_int is not None
|
||||||
|
|
||||||
raw_msg = b""
|
if length_int == 0:
|
||||||
if length_int != 0:
|
return Packet(type=msg_type_int, data=b"")
|
||||||
raw_msg = await self._reader.readexactly(length_int)
|
|
||||||
return Packet(type=msg_type_int, data=raw_msg)
|
data = await self._reader.readexactly(length_int)
|
||||||
|
return Packet(type=msg_type_int, data=data)
|
||||||
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
||||||
if (
|
if (
|
||||||
isinstance(err, asyncio.IncompleteReadError)
|
isinstance(err, asyncio.IncompleteReadError)
|
||||||
|
@ -23,9 +23,9 @@ def bytes_to_varuint(value: bytes) -> Optional[int]:
|
|||||||
bitpos = 0
|
bitpos = 0
|
||||||
for val in value:
|
for val in value:
|
||||||
result |= (val & 0x7F) << bitpos
|
result |= (val & 0x7F) << bitpos
|
||||||
bitpos += 7
|
|
||||||
if (val & 0x80) == 0:
|
if (val & 0x80) == 0:
|
||||||
return result
|
return result
|
||||||
|
bitpos += 7
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user