diff --git a/aioesphomeapi/_frame_helper/plain_text.pxd b/aioesphomeapi/_frame_helper/plain_text.pxd index b0bd3a1..994804f 100644 --- a/aioesphomeapi/_frame_helper/plain_text.pxd +++ b/aioesphomeapi/_frame_helper/plain_text.pxd @@ -7,6 +7,11 @@ cdef bint TYPE_CHECKING cdef object WRITE_EXCEPTIONS cdef object bytes_to_varuint, varuint_to_bytes +cpdef _varuint_to_bytes(cython.int value) + +@cython.locals(result=cython.int, bitpos=cython.int, val=cython.int) +cpdef _bytes_to_varuint(cython.bytes value) + cdef class APIPlaintextFrameHelper(APIFrameHelper): @cython.locals( diff --git a/aioesphomeapi/_frame_helper/plain_text.py b/aioesphomeapi/_frame_helper/plain_text.py index b9b149b..b629b3e 100644 --- a/aioesphomeapi/_frame_helper/plain_text.py +++ b/aioesphomeapi/_frame_helper/plain_text.py @@ -2,14 +2,54 @@ from __future__ import annotations import asyncio import logging +from functools import lru_cache from typing import TYPE_CHECKING from ..core import ProtocolAPIError, RequiresEncryptionAPIError, SocketAPIError -from ..util import bytes_to_varuint, varuint_to_bytes from .base import WRITE_EXCEPTIONS, APIFrameHelper _LOGGER = logging.getLogger(__name__) +_int = int +_bytes = bytes + + +def _varuint_to_bytes(value: _int) -> bytes: + """Convert a varuint to bytes.""" + if value <= 0x7F: + return bytes((value,)) + + result = [] + while value: + temp = value & 0x7F + value >>= 7 + if value: + result.append(temp | 0x80) + else: + result.append(temp) + + return bytes(result) + + +_cached_varuint_to_bytes = lru_cache(maxsize=1024)(_varuint_to_bytes) +varuint_to_bytes = _cached_varuint_to_bytes + + +def _bytes_to_varuint(value: _bytes) -> _int | None: + """Convert bytes to a varuint.""" + result = 0 + bitpos = 0 + for val in value: + result |= (val & 0x7F) << bitpos + if (val & 0x80) == 0: + return result + bitpos += 7 + return None + + +_cached_bytes_to_varuint = lru_cache(maxsize=1024)(_bytes_to_varuint) +bytes_to_varuint = _cached_bytes_to_varuint + class APIPlaintextFrameHelper(APIFrameHelper): """Frame helper for plaintext API connections.""" diff --git a/aioesphomeapi/util.py b/aioesphomeapi/util.py index 4d9c6d2..2f87f21 100644 --- a/aioesphomeapi/util.py +++ b/aioesphomeapi/util.py @@ -1,36 +1,6 @@ from __future__ import annotations import math -from functools import lru_cache - - -@lru_cache(maxsize=1024) -def varuint_to_bytes(value: int) -> bytes: - if value <= 0x7F: - return bytes([value]) - - ret = b"" - while value: - temp = value & 0x7F - value >>= 7 - if value: - ret += bytes([temp | 0x80]) - else: - ret += bytes([temp]) - - return ret - - -@lru_cache(maxsize=1024) -def bytes_to_varuint(value: bytes) -> int | None: - result = 0 - bitpos = 0 - for val in value: - result |= (val & 0x7F) << bitpos - if (val & 0x80) == 0: - return result - bitpos += 7 - return None def fix_float_single_double_conversion(value: float) -> float: diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index 6b2b25c..d333447 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -4,12 +4,19 @@ import pytest from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS +from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint +from aioesphomeapi._frame_helper.plain_text import ( + _cached_bytes_to_varuint as cached_bytes_to_varuint, +) +from aioesphomeapi._frame_helper.plain_text import ( + _cached_varuint_to_bytes as cached_varuint_to_bytes, +) +from aioesphomeapi._frame_helper.plain_text import _varuint_to_bytes as varuint_to_bytes from aioesphomeapi.core import ( BadNameAPIError, InvalidEncryptionKeyAPIError, SocketAPIError, ) -from aioesphomeapi.util import varuint_to_bytes PREAMBLE = b"\x00" @@ -234,3 +241,25 @@ async def test_noise_incorrect_name(): with pytest.raises(BadNameAPIError): await helper.perform_handshake(30) + + +VARUINT_TESTCASES = [ + (0, b"\x00"), + (42, b"\x2a"), + (127, b"\x7f"), + (128, b"\x80\x01"), + (300, b"\xac\x02"), + (65536, b"\x80\x80\x04"), +] + + +@pytest.mark.parametrize("val, encoded", VARUINT_TESTCASES) +def test_varuint_to_bytes(val, encoded): + assert varuint_to_bytes(val) == encoded + assert cached_varuint_to_bytes(val) == encoded + + +@pytest.mark.parametrize("val, encoded", VARUINT_TESTCASES) +def test_bytes_to_varuint(val, encoded): + assert bytes_to_varuint(encoded) == val + assert cached_bytes_to_varuint(encoded) == val diff --git a/tests/test_util.py b/tests/test_util.py index 897a9da..b4ba380 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -4,25 +4,6 @@ import pytest from aioesphomeapi import util -VARUINT_TESTCASES = [ - (0, b"\x00"), - (42, b"\x2a"), - (127, b"\x7f"), - (128, b"\x80\x01"), - (300, b"\xac\x02"), - (65536, b"\x80\x80\x04"), -] - - -@pytest.mark.parametrize("val, encoded", VARUINT_TESTCASES) -def test_varuint_to_bytes(val, encoded): - assert util.varuint_to_bytes(val) == encoded - - -@pytest.mark.parametrize("val, encoded", VARUINT_TESTCASES) -def test_bytes_to_varuint(val, encoded): - assert util.bytes_to_varuint(encoded) == val - @pytest.mark.parametrize( "input, output",