mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-09-28 04:27:27 +02:00
Refactor frame helper to avoid py conversions when processing packets (#641)
This commit is contained in:
parent
371143d383
commit
3ccb36b6fc
@ -1,30 +1,31 @@
|
|||||||
|
|
||||||
import cython
|
import cython
|
||||||
|
|
||||||
|
from ..connection cimport APIConnection
|
||||||
|
|
||||||
|
|
||||||
cdef bint TYPE_CHECKING
|
cdef bint TYPE_CHECKING
|
||||||
|
|
||||||
cdef class APIFrameHelper:
|
cdef class APIFrameHelper:
|
||||||
|
|
||||||
cdef object _loop
|
cdef object _loop
|
||||||
cdef object _on_pkt
|
cdef APIConnection _connection
|
||||||
cdef object _on_error
|
|
||||||
cdef object _transport
|
cdef object _transport
|
||||||
cdef public object _writer
|
cdef public object _writer
|
||||||
cdef public object _ready_future
|
cdef public object _ready_future
|
||||||
cdef bytes _buffer
|
cdef bytes _buffer
|
||||||
cdef cython.uint _buffer_len
|
cdef unsigned int _buffer_len
|
||||||
cdef cython.uint _pos
|
cdef unsigned int _pos
|
||||||
cdef object _client_info
|
cdef object _client_info
|
||||||
cdef str _log_name
|
cdef str _log_name
|
||||||
cdef object _debug_enabled
|
cdef object _debug_enabled
|
||||||
|
|
||||||
@cython.locals(original_pos=cython.uint, new_pos=cython.uint)
|
@cython.locals(original_pos="unsigned int", new_pos="unsigned int")
|
||||||
cdef bytes _read_exactly(self, int length)
|
cdef bytes _read_exactly(self, int length)
|
||||||
|
|
||||||
cdef _add_to_buffer(self, bytes data)
|
cdef _add_to_buffer(self, bytes data)
|
||||||
|
|
||||||
@cython.locals(end_of_frame_pos=cython.uint)
|
@cython.locals(end_of_frame_pos="unsigned int")
|
||||||
cdef _remove_from_buffer(self)
|
cdef _remove_from_buffer(self)
|
||||||
|
|
||||||
cpdef write_packets(self, list packets)
|
cpdef write_packets(self, list packets)
|
||||||
|
@ -8,6 +8,9 @@ from typing import TYPE_CHECKING, Callable, cast
|
|||||||
|
|
||||||
from ..core import HandshakeAPIError, SocketClosedAPIError
|
from ..core import HandshakeAPIError, SocketClosedAPIError
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..connection import APIConnection
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
SOCKET_ERRORS = (
|
SOCKET_ERRORS = (
|
||||||
@ -27,8 +30,7 @@ class APIFrameHelper:
|
|||||||
|
|
||||||
__slots__ = (
|
__slots__ = (
|
||||||
"_loop",
|
"_loop",
|
||||||
"_on_pkt",
|
"_connection",
|
||||||
"_on_error",
|
|
||||||
"_transport",
|
"_transport",
|
||||||
"_writer",
|
"_writer",
|
||||||
"_ready_future",
|
"_ready_future",
|
||||||
@ -42,16 +44,14 @@ class APIFrameHelper:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
on_pkt: Callable[[int, bytes], None],
|
connection: "APIConnection",
|
||||||
on_error: Callable[[Exception], None],
|
|
||||||
client_info: str,
|
client_info: str,
|
||||||
log_name: str,
|
log_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the API frame helper."""
|
"""Initialize the API frame helper."""
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
self._loop = loop
|
self._loop = loop
|
||||||
self._on_pkt = on_pkt
|
self._connection = connection
|
||||||
self._on_error = on_error
|
|
||||||
self._transport: asyncio.Transport | None = None
|
self._transport: asyncio.Transport | None = None
|
||||||
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
|
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
|
||||||
self._ready_future = self._loop.create_future()
|
self._ready_future = self._loop.create_future()
|
||||||
@ -143,7 +143,7 @@ class APIFrameHelper:
|
|||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
def _handle_error(self, exc: Exception) -> None:
|
def _handle_error(self, exc: Exception) -> None:
|
||||||
self._on_error(exc)
|
self._connection.report_fatal_error(exc)
|
||||||
|
|
||||||
def connection_lost(self, exc: Exception | None) -> None:
|
def connection_lost(self, exc: Exception | None) -> None:
|
||||||
"""Handle the connection being lost."""
|
"""Handle the connection being lost."""
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import cython
|
import cython
|
||||||
|
|
||||||
|
from ..connection cimport APIConnection
|
||||||
from .base cimport APIFrameHelper
|
from .base cimport APIFrameHelper
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,6 +25,9 @@ from ..core import (
|
|||||||
)
|
)
|
||||||
from .base import WRITE_EXCEPTIONS, APIFrameHelper
|
from .base import WRITE_EXCEPTIONS, APIFrameHelper
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..connection import APIConnection
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@ -81,15 +84,14 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
on_pkt: Callable[[int, bytes], None],
|
connection: "APIConnection",
|
||||||
on_error: Callable[[Exception], None],
|
|
||||||
noise_psk: str,
|
noise_psk: str,
|
||||||
expected_name: str | None,
|
expected_name: str | None,
|
||||||
client_info: str,
|
client_info: str,
|
||||||
log_name: str,
|
log_name: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the API frame helper."""
|
"""Initialize the API frame helper."""
|
||||||
super().__init__(on_pkt, on_error, client_info, log_name)
|
super().__init__(connection, client_info, log_name)
|
||||||
self._noise_psk = noise_psk
|
self._noise_psk = noise_psk
|
||||||
self._expected_name = expected_name
|
self._expected_name = expected_name
|
||||||
self._set_state(NoiseConnectionState.HELLO)
|
self._set_state(NoiseConnectionState.HELLO)
|
||||||
@ -364,7 +366,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
# N bytes: message data
|
# N bytes: message data
|
||||||
type_high = msg[0]
|
type_high = msg[0]
|
||||||
type_low = msg[1]
|
type_low = msg[1]
|
||||||
self._on_pkt((type_high << 8) | type_low, msg[4:])
|
self._connection.process_packet((type_high << 8) | type_low, msg[4:])
|
||||||
|
|
||||||
def _handle_closed(self, frame: bytes) -> None: # pylint: disable=unused-argument
|
def _handle_closed(self, frame: bytes) -> None: # pylint: disable=unused-argument
|
||||||
"""Handle a closed frame."""
|
"""Handle a closed frame."""
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import cython
|
import cython
|
||||||
|
|
||||||
|
from ..connection cimport APIConnection
|
||||||
from .base cimport APIFrameHelper
|
from .base cimport APIFrameHelper
|
||||||
|
|
||||||
|
|
||||||
|
@ -166,7 +166,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
packet_data = maybe_packet_data
|
packet_data = maybe_packet_data
|
||||||
|
|
||||||
self._remove_from_buffer()
|
self._remove_from_buffer()
|
||||||
self._on_pkt(msg_type_int, packet_data)
|
self._connection.process_packet(msg_type_int, packet_data)
|
||||||
# If we have more data, continue processing
|
# If we have more data, continue processing
|
||||||
|
|
||||||
def _error_on_incorrect_preamble(self, preamble: _int) -> None:
|
def _error_on_incorrect_preamble(self, preamble: _int) -> None:
|
||||||
|
@ -74,7 +74,7 @@ cdef class APIConnection:
|
|||||||
cdef send_messages(self, tuple messages)
|
cdef send_messages(self, tuple messages)
|
||||||
|
|
||||||
@cython.locals(handlers=set, handlers_copy=set)
|
@cython.locals(handlers=set, handlers_copy=set)
|
||||||
cpdef _process_packet(self, object msg_type_proto, object data)
|
cpdef process_packet(self, object msg_type_proto, object data)
|
||||||
|
|
||||||
cpdef _async_cancel_pong_timer(self)
|
cpdef _async_cancel_pong_timer(self)
|
||||||
|
|
||||||
@ -84,7 +84,7 @@ cdef class APIConnection:
|
|||||||
|
|
||||||
cpdef _set_connection_state(self, object state)
|
cpdef _set_connection_state(self, object state)
|
||||||
|
|
||||||
cpdef _report_fatal_error(self, Exception err)
|
cpdef report_fatal_error(self, Exception err)
|
||||||
|
|
||||||
@cython.locals(handlers=set)
|
@cython.locals(handlers=set)
|
||||||
cpdef _add_message_callback_without_remove(self, object on_message, tuple msg_types)
|
cpdef _add_message_callback_without_remove(self, object on_message, tuple msg_types)
|
||||||
|
@ -337,8 +337,7 @@ class APIConnection:
|
|||||||
if (noise_psk := self._params.noise_psk) is None:
|
if (noise_psk := self._params.noise_psk) is None:
|
||||||
_, fh = await loop.create_connection( # type: ignore[type-var]
|
_, fh = await loop.create_connection( # type: ignore[type-var]
|
||||||
lambda: APIPlaintextFrameHelper(
|
lambda: APIPlaintextFrameHelper(
|
||||||
on_pkt=self._process_packet,
|
connection=self,
|
||||||
on_error=self._report_fatal_error,
|
|
||||||
client_info=self._params.client_info,
|
client_info=self._params.client_info,
|
||||||
log_name=self.log_name,
|
log_name=self.log_name,
|
||||||
),
|
),
|
||||||
@ -349,8 +348,7 @@ class APIConnection:
|
|||||||
lambda: APINoiseFrameHelper(
|
lambda: APINoiseFrameHelper(
|
||||||
noise_psk=noise_psk,
|
noise_psk=noise_psk,
|
||||||
expected_name=self._params.expected_name,
|
expected_name=self._params.expected_name,
|
||||||
on_pkt=self._process_packet,
|
connection=self,
|
||||||
on_error=self._report_fatal_error,
|
|
||||||
client_info=self._params.client_info,
|
client_info=self._params.client_info,
|
||||||
log_name=self.log_name,
|
log_name=self.log_name,
|
||||||
),
|
),
|
||||||
@ -395,7 +393,7 @@ class APIConnection:
|
|||||||
CONNECT_REQUEST_TIMEOUT,
|
CONNECT_REQUEST_TIMEOUT,
|
||||||
)
|
)
|
||||||
except TimeoutAPIError as err:
|
except TimeoutAPIError as err:
|
||||||
self._report_fatal_error(err)
|
self.report_fatal_error(err)
|
||||||
raise TimeoutAPIError("Hello timed out") from err
|
raise TimeoutAPIError("Hello timed out") from err
|
||||||
|
|
||||||
resp = responses.pop(0)
|
resp = responses.pop(0)
|
||||||
@ -499,7 +497,7 @@ class APIConnection:
|
|||||||
self.log_name,
|
self.log_name,
|
||||||
self._keep_alive_timeout,
|
self._keep_alive_timeout,
|
||||||
)
|
)
|
||||||
self._report_fatal_error(
|
self.report_fatal_error(
|
||||||
PingFailedAPIError(
|
PingFailedAPIError(
|
||||||
f"Ping response not received after {self._keep_alive_timeout} seconds"
|
f"Ping response not received after {self._keep_alive_timeout} seconds"
|
||||||
)
|
)
|
||||||
@ -653,7 +651,7 @@ class APIConnection:
|
|||||||
# If writing packet fails, we don't know what state the frames
|
# If writing packet fails, we don't know what state the frames
|
||||||
# are in anymore and we have to close the connection
|
# are in anymore and we have to close the connection
|
||||||
_LOGGER.info("%s: Error writing packets: %s", self.log_name, err)
|
_LOGGER.info("%s: Error writing packets: %s", self.log_name, err)
|
||||||
self._report_fatal_error(err)
|
self.report_fatal_error(err)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _add_message_callback_without_remove(
|
def _add_message_callback_without_remove(
|
||||||
@ -786,7 +784,7 @@ class APIConnection:
|
|||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def _report_fatal_error(self, err: Exception) -> None:
|
def report_fatal_error(self, err: Exception) -> None:
|
||||||
"""Report a fatal error that occurred during an operation.
|
"""Report a fatal error that occurred during an operation.
|
||||||
|
|
||||||
This should only be called for errors that mean the connection
|
This should only be called for errors that mean the connection
|
||||||
@ -806,8 +804,8 @@ class APIConnection:
|
|||||||
self._fatal_exception = err
|
self._fatal_exception = err
|
||||||
self._cleanup()
|
self._cleanup()
|
||||||
|
|
||||||
def _process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
|
def process_packet(self, msg_type_proto: _int, data: _bytes) -> None:
|
||||||
"""Factory to make a packet processor."""
|
"""Process an incoming packet."""
|
||||||
if (klass := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)) is None:
|
if (klass := MESSAGE_TYPE_TO_PROTO.get(msg_type_proto)) is None:
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s: Skipping message type %s",
|
"%s: Skipping message type %s",
|
||||||
@ -831,7 +829,7 @@ class APIConnection:
|
|||||||
e,
|
e,
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
self._report_fatal_error(
|
self.report_fatal_error(
|
||||||
ProtocolAPIError(
|
ProtocolAPIError(
|
||||||
f"Invalid protobuf message: type={msg_type_proto} data={data!r}: {e}"
|
f"Invalid protobuf message: type={msg_type_proto} data={data!r}: {e}"
|
||||||
)
|
)
|
||||||
|
@ -3,11 +3,13 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from typing import Any
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from noise.connection import NoiseConnection # type: ignore[import-untyped]
|
from noise.connection import NoiseConnection # type: ignore[import-untyped]
|
||||||
|
|
||||||
|
from aioesphomeapi import APIConnection
|
||||||
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||||
from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS
|
from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS
|
||||||
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND, NOISE_HELLO
|
from aioesphomeapi._frame_helper.noise import ESPHOME_NOISE_BACKEND, NOISE_HELLO
|
||||||
@ -31,6 +33,24 @@ from .common import async_fire_time_changed, utcnow
|
|||||||
PREAMBLE = b"\x00"
|
PREAMBLE = b"\x00"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]:
|
||||||
|
"""Make a mock connection."""
|
||||||
|
packets: list[tuple[int, bytes]] = []
|
||||||
|
|
||||||
|
class MockConnection(APIConnection):
|
||||||
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
"""Swallow args."""
|
||||||
|
|
||||||
|
def process_packet(self, type_: int, data: bytes):
|
||||||
|
packets.append((type_, data))
|
||||||
|
|
||||||
|
def report_fatal_error(self, exc: Exception):
|
||||||
|
raise exc
|
||||||
|
|
||||||
|
connection = MockConnection()
|
||||||
|
return connection, packets
|
||||||
|
|
||||||
|
|
||||||
class MockAPINoiseFrameHelper(APINoiseFrameHelper):
|
class MockAPINoiseFrameHelper(APINoiseFrameHelper):
|
||||||
def mock_write_frame(self, frame: bytes) -> None:
|
def mock_write_frame(self, frame: bytes) -> None:
|
||||||
"""Write a packet to the socket.
|
"""Write a packet to the socket.
|
||||||
@ -97,16 +117,9 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
|
|||||||
)
|
)
|
||||||
async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
|
async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
packets = []
|
connection, packets = _make_mock_connection()
|
||||||
|
|
||||||
def _packet(type_: int, data: bytes):
|
|
||||||
packets.append((type_, data))
|
|
||||||
|
|
||||||
def _on_error(exc: Exception):
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
helper = APIPlaintextFrameHelper(
|
helper = APIPlaintextFrameHelper(
|
||||||
on_pkt=_packet, on_error=_on_error, client_info="my client", log_name="test"
|
connection=connection, client_info="my client", log_name="test"
|
||||||
)
|
)
|
||||||
|
|
||||||
helper.data_received(in_bytes)
|
helper.data_received(in_bytes)
|
||||||
@ -139,17 +152,10 @@ async def test_noise_frame_helper_incorrect_key():
|
|||||||
"01000d01736572766963657465737400",
|
"01000d01736572766963657465737400",
|
||||||
"0100160148616e647368616b65204d4143206661696c757265",
|
"0100160148616e647368616b65204d4143206661696c757265",
|
||||||
]
|
]
|
||||||
packets = []
|
connection, _ = _make_mock_connection()
|
||||||
|
|
||||||
def _packet(type_: int, data: bytes):
|
|
||||||
packets.append((type_, data))
|
|
||||||
|
|
||||||
def _on_error(exc: Exception):
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
helper = MockAPINoiseFrameHelper(
|
helper = MockAPINoiseFrameHelper(
|
||||||
on_pkt=_packet,
|
connection=connection,
|
||||||
on_error=_on_error,
|
|
||||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||||
expected_name="servicetest",
|
expected_name="servicetest",
|
||||||
client_info="my client",
|
client_info="my client",
|
||||||
@ -180,17 +186,10 @@ async def test_noise_frame_helper_incorrect_key_fragments():
|
|||||||
"01000d01736572766963657465737400",
|
"01000d01736572766963657465737400",
|
||||||
"0100160148616e647368616b65204d4143206661696c757265",
|
"0100160148616e647368616b65204d4143206661696c757265",
|
||||||
]
|
]
|
||||||
packets = []
|
connection, _ = _make_mock_connection()
|
||||||
|
|
||||||
def _packet(type_: int, data: bytes):
|
|
||||||
packets.append((type_, data))
|
|
||||||
|
|
||||||
def _on_error(exc: Exception):
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
helper = MockAPINoiseFrameHelper(
|
helper = MockAPINoiseFrameHelper(
|
||||||
on_pkt=_packet,
|
connection=connection,
|
||||||
on_error=_on_error,
|
|
||||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||||
expected_name="servicetest",
|
expected_name="servicetest",
|
||||||
client_info="my client",
|
client_info="my client",
|
||||||
@ -223,17 +222,10 @@ async def test_noise_incorrect_name():
|
|||||||
"01000d01736572766963657465737400",
|
"01000d01736572766963657465737400",
|
||||||
"0100160148616e647368616b65204d4143206661696c757265",
|
"0100160148616e647368616b65204d4143206661696c757265",
|
||||||
]
|
]
|
||||||
packets = []
|
connection, _ = _make_mock_connection()
|
||||||
|
|
||||||
def _packet(type_: int, data: bytes):
|
|
||||||
packets.append((type_, data))
|
|
||||||
|
|
||||||
def _on_error(exc: Exception):
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
helper = MockAPINoiseFrameHelper(
|
helper = MockAPINoiseFrameHelper(
|
||||||
on_pkt=_packet,
|
connection=connection,
|
||||||
on_error=_on_error,
|
|
||||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||||
expected_name="wrongname",
|
expected_name="wrongname",
|
||||||
client_info="my client",
|
client_info="my client",
|
||||||
@ -260,17 +252,11 @@ async def test_noise_timeout():
|
|||||||
"010000", # hello packet
|
"010000", # hello packet
|
||||||
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
|
"010031001ed7f7bb0b74085418258ed5928931bc36ade7cf06937fcff089044d4ab142643f1b2c9935bb77696f23d930836737a4",
|
||||||
]
|
]
|
||||||
packets = []
|
|
||||||
|
|
||||||
def _packet(type_: int, data: bytes):
|
connection, _ = _make_mock_connection()
|
||||||
packets.append((type_, data))
|
|
||||||
|
|
||||||
def _on_error(exc: Exception):
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
helper = MockAPINoiseFrameHelper(
|
helper = MockAPINoiseFrameHelper(
|
||||||
on_pkt=_packet,
|
connection=connection,
|
||||||
on_error=_on_error,
|
|
||||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||||
expected_name="wrongname",
|
expected_name="wrongname",
|
||||||
client_info="my client",
|
client_info="my client",
|
||||||
@ -317,21 +303,15 @@ async def test_noise_frame_helper_handshake_failure():
|
|||||||
"""Test the noise frame helper handshake failure."""
|
"""Test the noise frame helper handshake failure."""
|
||||||
noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc="
|
noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc="
|
||||||
psk_bytes = base64.b64decode(noise_psk)
|
psk_bytes = base64.b64decode(noise_psk)
|
||||||
packets = []
|
|
||||||
writes = []
|
writes = []
|
||||||
|
|
||||||
def _packet(type_: int, data: bytes):
|
|
||||||
packets.append((type_, data))
|
|
||||||
|
|
||||||
def _writer(data: bytes):
|
def _writer(data: bytes):
|
||||||
writes.append(data)
|
writes.append(data)
|
||||||
|
|
||||||
def _on_error(exc: Exception):
|
connection, _ = _make_mock_connection()
|
||||||
raise exc
|
|
||||||
|
|
||||||
helper = MockAPINoiseFrameHelper(
|
helper = MockAPINoiseFrameHelper(
|
||||||
on_pkt=_packet,
|
connection=connection,
|
||||||
on_error=_on_error,
|
|
||||||
noise_psk=noise_psk,
|
noise_psk=noise_psk,
|
||||||
expected_name="servicetest",
|
expected_name="servicetest",
|
||||||
client_info="my client",
|
client_info="my client",
|
||||||
@ -398,21 +378,15 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
|||||||
"""Test the noise frame helper handshake success with a single packet."""
|
"""Test the noise frame helper handshake success with a single packet."""
|
||||||
noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc="
|
noise_psk = "QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc="
|
||||||
psk_bytes = base64.b64decode(noise_psk)
|
psk_bytes = base64.b64decode(noise_psk)
|
||||||
packets = []
|
|
||||||
writes = []
|
writes = []
|
||||||
|
|
||||||
def _packet(type_: int, data: bytes):
|
|
||||||
packets.append((type_, data))
|
|
||||||
|
|
||||||
def _writer(data: bytes):
|
def _writer(data: bytes):
|
||||||
writes.append(data)
|
writes.append(data)
|
||||||
|
|
||||||
def _on_error(exc: Exception):
|
connection, packets = _make_mock_connection()
|
||||||
raise exc
|
|
||||||
|
|
||||||
helper = MockAPINoiseFrameHelper(
|
helper = MockAPINoiseFrameHelper(
|
||||||
on_pkt=_packet,
|
connection=connection,
|
||||||
on_error=_on_error,
|
|
||||||
noise_psk=noise_psk,
|
noise_psk=noise_psk,
|
||||||
expected_name="servicetest",
|
expected_name="servicetest",
|
||||||
client_info="my client",
|
client_info="my client",
|
||||||
|
@ -35,8 +35,7 @@ from .common import (
|
|||||||
|
|
||||||
def _get_mock_protocol(conn: APIConnection):
|
def _get_mock_protocol(conn: APIConnection):
|
||||||
protocol = APIPlaintextFrameHelper(
|
protocol = APIPlaintextFrameHelper(
|
||||||
on_pkt=conn._process_packet,
|
connection=conn,
|
||||||
on_error=conn._report_fatal_error,
|
|
||||||
client_info="mock",
|
client_info="mock",
|
||||||
log_name="mock_device",
|
log_name="mock_device",
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user