Reduce number of calls to readexactly (#324)

This commit is contained in:
J. Nick Koston 2022-11-30 12:47:26 -10:00 committed by GitHub
parent a452e738ff
commit 6273f785f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 32 additions and 14 deletions

View File

@ -74,29 +74,47 @@ class APIPlaintextFrameHelper(APIFrameHelper):
async def read_packet(self) -> Packet: async def read_packet(self) -> Packet:
async with self._read_lock: async with self._read_lock:
try: try:
preamble = await self._reader.readexactly(1) # Read preamble, which should always 0x00
if preamble[0] != 0x00: # Also try to get the length and msg type
if preamble[0] == 0x01: # to avoid multiple calls to readexactly
init_bytes = await self._reader.readexactly(3)
if init_bytes[0] != 0x00:
if init_bytes[0] == 0x01:
raise RequiresEncryptionAPIError( raise RequiresEncryptionAPIError(
"Connection requires encryption" "Connection requires encryption"
) )
raise ProtocolAPIError(f"Invalid preamble {preamble[0]:02x}") raise ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}")
length = b"" if init_bytes[1] & 0x80 == 0x80:
while not length or (length[-1] & 0x80) == 0x80: # Length is longer than 1 byte
length = init_bytes[1:3]
msg_type = b""
else:
# This is the most common case with 99% of messages
# needing a single byte for length and type which means
# we avoid 2 calls to readexactly
length = init_bytes[1:2]
msg_type = init_bytes[2:3]
# If the message is long, we need to read the rest of the length
while length[-1] & 0x80 == 0x80:
length += await self._reader.readexactly(1) length += await self._reader.readexactly(1)
length_int = bytes_to_varuint(length)
assert length_int is not None # If the message length was longer than 1 byte, we need to read the
msg_type = b"" # message type
while not msg_type or (msg_type[-1] & 0x80) == 0x80: while not msg_type or (msg_type[-1] & 0x80) == 0x80:
msg_type += await self._reader.readexactly(1) msg_type += await self._reader.readexactly(1)
length_int = bytes_to_varuint(length)
assert length_int is not None
msg_type_int = bytes_to_varuint(msg_type) msg_type_int = bytes_to_varuint(msg_type)
assert msg_type_int is not None assert msg_type_int is not None
raw_msg = b"" if length_int == 0:
if length_int != 0: return Packet(type=msg_type_int, data=b"")
raw_msg = await self._reader.readexactly(length_int)
return Packet(type=msg_type_int, data=raw_msg) data = await self._reader.readexactly(length_int)
return Packet(type=msg_type_int, data=data)
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err: except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
if ( if (
isinstance(err, asyncio.IncompleteReadError) isinstance(err, asyncio.IncompleteReadError)

View File

@ -23,9 +23,9 @@ def bytes_to_varuint(value: bytes) -> Optional[int]:
bitpos = 0 bitpos = 0
for val in value: for val in value:
result |= (val & 0x7F) << bitpos result |= (val & 0x7F) << bitpos
bitpos += 7
if (val & 0x80) == 0: if (val & 0x80) == 0:
return result return result
bitpos += 7
return None return None