Move varuint functions into plain_text frame_helper (#587)
This commit is contained in:
parent
b059dd6a3a
commit
63897ed680
|
@ -7,6 +7,11 @@ cdef bint TYPE_CHECKING
|
||||||
cdef object WRITE_EXCEPTIONS
|
cdef object WRITE_EXCEPTIONS
|
||||||
cdef object bytes_to_varuint, varuint_to_bytes
|
cdef object bytes_to_varuint, varuint_to_bytes
|
||||||
|
|
||||||
|
cpdef _varuint_to_bytes(cython.int value)
|
||||||
|
|
||||||
|
@cython.locals(result=cython.int, bitpos=cython.int, val=cython.int)
|
||||||
|
cpdef _bytes_to_varuint(cython.bytes value)
|
||||||
|
|
||||||
cdef class APIPlaintextFrameHelper(APIFrameHelper):
|
cdef class APIPlaintextFrameHelper(APIFrameHelper):
|
||||||
|
|
||||||
@cython.locals(
|
@cython.locals(
|
||||||
|
|
|
@ -2,14 +2,54 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
from functools import lru_cache
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from ..core import ProtocolAPIError, RequiresEncryptionAPIError, SocketAPIError
|
from ..core import ProtocolAPIError, RequiresEncryptionAPIError, SocketAPIError
|
||||||
from ..util import bytes_to_varuint, varuint_to_bytes
|
|
||||||
from .base import WRITE_EXCEPTIONS, APIFrameHelper
|
from .base import WRITE_EXCEPTIONS, APIFrameHelper
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_int = int
|
||||||
|
_bytes = bytes
|
||||||
|
|
||||||
|
|
||||||
|
def _varuint_to_bytes(value: _int) -> bytes:
|
||||||
|
"""Convert a varuint to bytes."""
|
||||||
|
if value <= 0x7F:
|
||||||
|
return bytes((value,))
|
||||||
|
|
||||||
|
result = []
|
||||||
|
while value:
|
||||||
|
temp = value & 0x7F
|
||||||
|
value >>= 7
|
||||||
|
if value:
|
||||||
|
result.append(temp | 0x80)
|
||||||
|
else:
|
||||||
|
result.append(temp)
|
||||||
|
|
||||||
|
return bytes(result)
|
||||||
|
|
||||||
|
|
||||||
|
_cached_varuint_to_bytes = lru_cache(maxsize=1024)(_varuint_to_bytes)
|
||||||
|
varuint_to_bytes = _cached_varuint_to_bytes
|
||||||
|
|
||||||
|
|
||||||
|
def _bytes_to_varuint(value: _bytes) -> _int | None:
|
||||||
|
"""Convert bytes to a varuint."""
|
||||||
|
result = 0
|
||||||
|
bitpos = 0
|
||||||
|
for val in value:
|
||||||
|
result |= (val & 0x7F) << bitpos
|
||||||
|
if (val & 0x80) == 0:
|
||||||
|
return result
|
||||||
|
bitpos += 7
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
_cached_bytes_to_varuint = lru_cache(maxsize=1024)(_bytes_to_varuint)
|
||||||
|
bytes_to_varuint = _cached_bytes_to_varuint
|
||||||
|
|
||||||
|
|
||||||
class APIPlaintextFrameHelper(APIFrameHelper):
|
class APIPlaintextFrameHelper(APIFrameHelper):
|
||||||
"""Frame helper for plaintext API connections."""
|
"""Frame helper for plaintext API connections."""
|
||||||
|
|
|
@ -1,36 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from functools import lru_cache
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1024)
|
|
||||||
def varuint_to_bytes(value: int) -> bytes:
|
|
||||||
if value <= 0x7F:
|
|
||||||
return bytes([value])
|
|
||||||
|
|
||||||
ret = b""
|
|
||||||
while value:
|
|
||||||
temp = value & 0x7F
|
|
||||||
value >>= 7
|
|
||||||
if value:
|
|
||||||
ret += bytes([temp | 0x80])
|
|
||||||
else:
|
|
||||||
ret += bytes([temp])
|
|
||||||
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1024)
|
|
||||||
def bytes_to_varuint(value: bytes) -> int | None:
|
|
||||||
result = 0
|
|
||||||
bitpos = 0
|
|
||||||
for val in value:
|
|
||||||
result |= (val & 0x7F) << bitpos
|
|
||||||
if (val & 0x80) == 0:
|
|
||||||
return result
|
|
||||||
bitpos += 7
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def fix_float_single_double_conversion(value: float) -> float:
|
def fix_float_single_double_conversion(value: float) -> float:
|
||||||
|
|
|
@ -4,12 +4,19 @@ import pytest
|
||||||
|
|
||||||
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||||
from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS
|
from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS
|
||||||
|
from aioesphomeapi._frame_helper.plain_text import _bytes_to_varuint as bytes_to_varuint
|
||||||
|
from aioesphomeapi._frame_helper.plain_text import (
|
||||||
|
_cached_bytes_to_varuint as cached_bytes_to_varuint,
|
||||||
|
)
|
||||||
|
from aioesphomeapi._frame_helper.plain_text import (
|
||||||
|
_cached_varuint_to_bytes as cached_varuint_to_bytes,
|
||||||
|
)
|
||||||
|
from aioesphomeapi._frame_helper.plain_text import _varuint_to_bytes as varuint_to_bytes
|
||||||
from aioesphomeapi.core import (
|
from aioesphomeapi.core import (
|
||||||
BadNameAPIError,
|
BadNameAPIError,
|
||||||
InvalidEncryptionKeyAPIError,
|
InvalidEncryptionKeyAPIError,
|
||||||
SocketAPIError,
|
SocketAPIError,
|
||||||
)
|
)
|
||||||
from aioesphomeapi.util import varuint_to_bytes
|
|
||||||
|
|
||||||
PREAMBLE = b"\x00"
|
PREAMBLE = b"\x00"
|
||||||
|
|
||||||
|
@ -234,3 +241,25 @@ async def test_noise_incorrect_name():
|
||||||
|
|
||||||
with pytest.raises(BadNameAPIError):
|
with pytest.raises(BadNameAPIError):
|
||||||
await helper.perform_handshake(30)
|
await helper.perform_handshake(30)
|
||||||
|
|
||||||
|
|
||||||
|
VARUINT_TESTCASES = [
|
||||||
|
(0, b"\x00"),
|
||||||
|
(42, b"\x2a"),
|
||||||
|
(127, b"\x7f"),
|
||||||
|
(128, b"\x80\x01"),
|
||||||
|
(300, b"\xac\x02"),
|
||||||
|
(65536, b"\x80\x80\x04"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("val, encoded", VARUINT_TESTCASES)
|
||||||
|
def test_varuint_to_bytes(val, encoded):
|
||||||
|
assert varuint_to_bytes(val) == encoded
|
||||||
|
assert cached_varuint_to_bytes(val) == encoded
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("val, encoded", VARUINT_TESTCASES)
|
||||||
|
def test_bytes_to_varuint(val, encoded):
|
||||||
|
assert bytes_to_varuint(encoded) == val
|
||||||
|
assert cached_bytes_to_varuint(encoded) == val
|
||||||
|
|
|
@ -4,25 +4,6 @@ import pytest
|
||||||
|
|
||||||
from aioesphomeapi import util
|
from aioesphomeapi import util
|
||||||
|
|
||||||
VARUINT_TESTCASES = [
|
|
||||||
(0, b"\x00"),
|
|
||||||
(42, b"\x2a"),
|
|
||||||
(127, b"\x7f"),
|
|
||||||
(128, b"\x80\x01"),
|
|
||||||
(300, b"\xac\x02"),
|
|
||||||
(65536, b"\x80\x80\x04"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("val, encoded", VARUINT_TESTCASES)
|
|
||||||
def test_varuint_to_bytes(val, encoded):
|
|
||||||
assert util.varuint_to_bytes(val) == encoded
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("val, encoded", VARUINT_TESTCASES)
|
|
||||||
def test_bytes_to_varuint(val, encoded):
|
|
||||||
assert util.bytes_to_varuint(encoded) == val
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"input, output",
|
"input, output",
|
||||||
|
|
Loading…
Reference in New Issue