Move varuint functions into plain_text frame_helper (#587)

This commit is contained in:
J. Nick Koston 2023-10-16 17:24:03 -10:00 committed by GitHub
parent b059dd6a3a
commit 63897ed680
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 76 additions and 51 deletions

View File

@ -7,6 +7,11 @@ cdef bint TYPE_CHECKING
cdef object WRITE_EXCEPTIONS
cdef object bytes_to_varuint, varuint_to_bytes
cpdef _varuint_to_bytes(cython.int value)
@cython.locals(result=cython.int, bitpos=cython.int, val=cython.int)
cpdef _bytes_to_varuint(cython.bytes value)
cdef class APIPlaintextFrameHelper(APIFrameHelper):
@cython.locals(

View File

@ -2,14 +2,54 @@ from __future__ import annotations
import asyncio
import logging
from functools import lru_cache
from typing import TYPE_CHECKING
from ..core import ProtocolAPIError, RequiresEncryptionAPIError, SocketAPIError
from ..util import bytes_to_varuint, varuint_to_bytes
from .base import WRITE_EXCEPTIONS, APIFrameHelper
_LOGGER = logging.getLogger(__name__)
_int = int
_bytes = bytes
def _varuint_to_bytes(value: _int) -> bytes:
"""Convert a varuint to bytes."""
if value <= 0x7F:
return bytes((value,))
result = []
while value:
temp = value & 0x7F
value >>= 7
if value:
result.append(temp | 0x80)
else:
result.append(temp)
return bytes(result)
_cached_varuint_to_bytes = lru_cache(maxsize=1024)(_varuint_to_bytes)
varuint_to_bytes = _cached_varuint_to_bytes
def _bytes_to_varuint(value: _bytes) -> _int | None:
"""Convert bytes to a varuint."""
result = 0
bitpos = 0
for val in value:
result |= (val & 0x7F) << bitpos
if (val & 0x80) == 0:
return result
bitpos += 7
return None
_cached_bytes_to_varuint = lru_cache(maxsize=1024)(_bytes_to_varuint)
bytes_to_varuint = _cached_bytes_to_varuint
class APIPlaintextFrameHelper(APIFrameHelper):
"""Frame helper for plaintext API connections."""

View File

@ -1,36 +1,6 @@
from __future__ import annotations
import math
from functools import lru_cache
@lru_cache(maxsize=1024)
def varuint_to_bytes(value: int) -> bytes:
if value <= 0x7F:
return bytes([value])
ret = b""
while value:
temp = value & 0x7F
value >>= 7
if value:
ret += bytes([temp | 0x80])
else:
ret += bytes([temp])
return ret
@lru_cache(maxsize=1024)
def bytes_to_varuint(value: bytes) -> int | None:
result = 0
bitpos = 0
for val in value:
result |= (val & 0x7F) << bitpos
if (val & 0x80) == 0:
return result
bitpos += 7
return None
def fix_float_single_double_conversion(value: float) -> float:

View File

@ -4,12 +4,19 @@ import pytest
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS
from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint
from aioesphomeapi._frame_helper.plain_text import (
_cached_bytes_to_varuint as cached_bytes_to_varuint,
)
from aioesphomeapi._frame_helper.plain_text import (
_cached_varuint_to_bytes as cached_varuint_to_bytes,
)
from aioesphomeapi._frame_helper.plain_text import _varuint_to_bytes as varuint_to_bytes
from aioesphomeapi.core import (
BadNameAPIError,
InvalidEncryptionKeyAPIError,
SocketAPIError,
)
from aioesphomeapi.util import varuint_to_bytes
PREAMBLE = b"\x00"
@ -234,3 +241,25 @@ async def test_noise_incorrect_name():
with pytest.raises(BadNameAPIError):
await helper.perform_handshake(30)
VARUINT_TESTCASES = [
(0, b"\x00"),
(42, b"\x2a"),
(127, b"\x7f"),
(128, b"\x80\x01"),
(300, b"\xac\x02"),
(65536, b"\x80\x80\x04"),
]
@pytest.mark.parametrize("val, encoded", VARUINT_TESTCASES)
def test_varuint_to_bytes(val, encoded):
assert varuint_to_bytes(val) == encoded
assert cached_varuint_to_bytes(val) == encoded
@pytest.mark.parametrize("val, encoded", VARUINT_TESTCASES)
def test_bytes_to_varuint(val, encoded):
assert bytes_to_varuint(encoded) == val
assert cached_bytes_to_varuint(encoded) == val

View File

@ -4,25 +4,6 @@ import pytest
from aioesphomeapi import util
VARUINT_TESTCASES = [
(0, b"\x00"),
(42, b"\x2a"),
(127, b"\x7f"),
(128, b"\x80\x01"),
(300, b"\xac\x02"),
(65536, b"\x80\x80\x04"),
]
@pytest.mark.parametrize("val, encoded", VARUINT_TESTCASES)
def test_varuint_to_bytes(val, encoded):
assert util.varuint_to_bytes(val) == encoded
@pytest.mark.parametrize("val, encoded", VARUINT_TESTCASES)
def test_bytes_to_varuint(val, encoded):
assert util.bytes_to_varuint(encoded) == val
@pytest.mark.parametrize(
"input, output",