Fix error in noise frame helper were we could write when the writer was unset (#685)
This commit is contained in:
parent
dba6c72735
commit
095ef822f1
|
@ -1,5 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import binascii
|
import binascii
|
||||||
import logging
|
import logging
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
@ -128,10 +129,10 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||||
exc.__cause__ = original_exc
|
exc.__cause__ = original_exc
|
||||||
super()._handle_error(exc)
|
super()._handle_error(exc)
|
||||||
|
|
||||||
async def perform_handshake(self, timeout: float) -> None:
|
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||||
"""Perform the handshake with the server."""
|
"""Handle a new connection."""
|
||||||
|
super().connection_made(transport)
|
||||||
self._send_hello_handshake()
|
self._send_hello_handshake()
|
||||||
await super().perform_handshake(timeout)
|
|
||||||
|
|
||||||
def data_received(self, data: bytes | bytearray | memoryview) -> None:
|
def data_received(self, data: bytes | bytearray | memoryview) -> None:
|
||||||
self._add_to_buffer(data)
|
self._add_to_buffer(data)
|
||||||
|
|
|
@ -10,7 +10,7 @@ from google.protobuf import message
|
||||||
from zeroconf import Zeroconf
|
from zeroconf import Zeroconf
|
||||||
from zeroconf.asyncio import AsyncZeroconf
|
from zeroconf.asyncio import AsyncZeroconf
|
||||||
|
|
||||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||||
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
|
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
|
||||||
from aioesphomeapi.api_pb2 import (
|
from aioesphomeapi.api_pb2 import (
|
||||||
ConnectResponse,
|
ConnectResponse,
|
||||||
|
@ -31,6 +31,20 @@ utcnow.__doc__ = "Get now in UTC time."
|
||||||
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
|
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
|
||||||
|
|
||||||
|
|
||||||
|
def mock_data_received(
|
||||||
|
protocol: APINoiseFrameHelper | APIPlaintextFrameHelper, data: bytes
|
||||||
|
) -> None:
|
||||||
|
"""Mock data received on the protocol."""
|
||||||
|
try:
|
||||||
|
protocol.data_received(data)
|
||||||
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
loop.call_soon(
|
||||||
|
protocol.connection_lost,
|
||||||
|
err,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_mock_zeroconf() -> MagicMock:
|
def get_mock_zeroconf() -> MagicMock:
|
||||||
with patch("zeroconf.Zeroconf.start"):
|
with patch("zeroconf.Zeroconf.start"):
|
||||||
zc = Zeroconf()
|
zc = Zeroconf()
|
||||||
|
|
|
@ -46,8 +46,7 @@ def socket_socket():
|
||||||
yield func
|
yield func
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
def get_mock_connection_params() -> ConnectionParams:
|
||||||
def connection_params() -> ConnectionParams:
|
|
||||||
return ConnectionParams(
|
return ConnectionParams(
|
||||||
address="fake.address",
|
address="fake.address",
|
||||||
port=6052,
|
port=6052,
|
||||||
|
@ -60,6 +59,11 @@ def connection_params() -> ConnectionParams:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def connection_params() -> ConnectionParams:
|
||||||
|
return get_mock_connection_params()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def noise_connection_params() -> ConnectionParams:
|
def noise_connection_params() -> ConnectionParams:
|
||||||
return ConnectionParams(
|
return ConnectionParams(
|
||||||
|
|
|
@ -4,7 +4,7 @@ import asyncio
|
||||||
import base64
|
import base64
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from noise.connection import NoiseConnection # type: ignore[import-untyped]
|
from noise.connection import NoiseConnection # type: ignore[import-untyped]
|
||||||
|
@ -30,7 +30,13 @@ from aioesphomeapi.core import (
|
||||||
SocketClosedAPIError,
|
SocketClosedAPIError,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .common import async_fire_time_changed, get_mock_protocol, utcnow
|
from .common import (
|
||||||
|
async_fire_time_changed,
|
||||||
|
get_mock_protocol,
|
||||||
|
mock_data_received,
|
||||||
|
utcnow,
|
||||||
|
)
|
||||||
|
from .conftest import get_mock_connection_params
|
||||||
|
|
||||||
PREAMBLE = b"\x00"
|
PREAMBLE = b"\x00"
|
||||||
|
|
||||||
|
@ -42,18 +48,27 @@ def _make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]:
|
||||||
class MockConnection(APIConnection):
|
class MockConnection(APIConnection):
|
||||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
"""Swallow args."""
|
"""Swallow args."""
|
||||||
|
super().__init__(get_mock_connection_params(), AsyncMock(), *args, **kwargs)
|
||||||
|
|
||||||
def process_packet(self, type_: int, data: bytes):
|
def process_packet(self, type_: int, data: bytes):
|
||||||
packets.append((type_, data))
|
packets.append((type_, data))
|
||||||
|
|
||||||
def report_fatal_error(self, exc: Exception):
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
connection = MockConnection()
|
connection = MockConnection()
|
||||||
return connection, packets
|
return connection, packets
|
||||||
|
|
||||||
|
|
||||||
class MockAPINoiseFrameHelper(APINoiseFrameHelper):
|
class MockAPINoiseFrameHelper(APINoiseFrameHelper):
|
||||||
|
def __init__(self, *args: Any, writer: Any | None = None, **kwargs: Any) -> None:
|
||||||
|
"""Swallow args."""
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
transport = MagicMock()
|
||||||
|
transport.write = writer or MagicMock()
|
||||||
|
self.__transport = transport
|
||||||
|
self.connection_made(transport)
|
||||||
|
|
||||||
|
def connection_made(self, transport: Any) -> None:
|
||||||
|
return super().connection_made(self.__transport)
|
||||||
|
|
||||||
def mock_write_frame(self, frame: bytes) -> None:
|
def mock_write_frame(self, frame: bytes) -> None:
|
||||||
"""Write a packet to the socket.
|
"""Write a packet to the socket.
|
||||||
|
|
||||||
|
@ -125,7 +140,7 @@ def test_plaintext_frame_helper(
|
||||||
connection=connection, client_info="my client", log_name="test"
|
connection=connection, client_info="my client", log_name="test"
|
||||||
)
|
)
|
||||||
|
|
||||||
helper.data_received(in_bytes)
|
mock_data_received(helper, in_bytes)
|
||||||
|
|
||||||
pkt = packets.pop()
|
pkt = packets.pop()
|
||||||
type_, data = pkt
|
type_, data = pkt
|
||||||
|
@ -135,7 +150,7 @@ def test_plaintext_frame_helper(
|
||||||
|
|
||||||
# Make sure we correctly handle fragments
|
# Make sure we correctly handle fragments
|
||||||
for i in range(len(in_bytes)):
|
for i in range(len(in_bytes)):
|
||||||
helper.data_received(in_bytes[i : i + 1])
|
mock_data_received(helper, in_bytes[i : i + 1])
|
||||||
|
|
||||||
pkt = packets.pop()
|
pkt = packets.pop()
|
||||||
type_, data = pkt
|
type_, data = pkt
|
||||||
|
@ -166,7 +181,7 @@ def test_plaintext_frame_helper_protractor_event_loop(byte_type: Any) -> None:
|
||||||
PREAMBLE + varuint_to_bytes(4) + varuint_to_bytes(100) + (b"\x42" * 4)
|
PREAMBLE + varuint_to_bytes(4) + varuint_to_bytes(100) + (b"\x42" * 4)
|
||||||
)
|
)
|
||||||
|
|
||||||
helper.data_received(in_bytes)
|
mock_data_received(helper, in_bytes)
|
||||||
|
|
||||||
pkt = packets.pop()
|
pkt = packets.pop()
|
||||||
type_, data = pkt
|
type_, data = pkt
|
||||||
|
@ -176,7 +191,7 @@ def test_plaintext_frame_helper_protractor_event_loop(byte_type: Any) -> None:
|
||||||
|
|
||||||
# Make sure we correctly handle fragments
|
# Make sure we correctly handle fragments
|
||||||
for i in range(len(in_bytes)):
|
for i in range(len(in_bytes)):
|
||||||
helper.data_received(in_bytes[i : i + 1])
|
mock_data_received(helper, in_bytes[i : i + 1])
|
||||||
|
|
||||||
pkt = packets.pop()
|
pkt = packets.pop()
|
||||||
type_, data = pkt
|
type_, data = pkt
|
||||||
|
@ -215,15 +230,12 @@ async def test_noise_protector_event_loop(byte_type: Any) -> None:
|
||||||
client_info="my client",
|
client_info="my client",
|
||||||
log_name="test",
|
log_name="test",
|
||||||
)
|
)
|
||||||
helper._transport = MagicMock()
|
|
||||||
helper._writer = MagicMock()
|
|
||||||
|
|
||||||
for pkt in outgoing_packets:
|
for pkt in outgoing_packets:
|
||||||
helper.mock_write_frame(byte_type(bytes.fromhex(pkt)))
|
helper.mock_write_frame(byte_type(bytes.fromhex(pkt)))
|
||||||
|
|
||||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
for pkt in incoming_packets:
|
||||||
for pkt in incoming_packets:
|
mock_data_received(helper, byte_type(bytes.fromhex(pkt)))
|
||||||
helper.data_received(byte_type(bytes.fromhex(pkt)))
|
|
||||||
|
|
||||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||||
await helper.perform_handshake(30)
|
await helper.perform_handshake(30)
|
||||||
|
@ -249,15 +261,12 @@ async def test_noise_frame_helper_incorrect_key():
|
||||||
client_info="my client",
|
client_info="my client",
|
||||||
log_name="test",
|
log_name="test",
|
||||||
)
|
)
|
||||||
helper._transport = MagicMock()
|
|
||||||
helper._writer = MagicMock()
|
|
||||||
|
|
||||||
for pkt in outgoing_packets:
|
for pkt in outgoing_packets:
|
||||||
helper.mock_write_frame(bytes.fromhex(pkt))
|
helper.mock_write_frame(bytes.fromhex(pkt))
|
||||||
|
|
||||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
for pkt in incoming_packets:
|
||||||
for pkt in incoming_packets:
|
mock_data_received(helper, bytes.fromhex(pkt))
|
||||||
helper.data_received(bytes.fromhex(pkt))
|
|
||||||
|
|
||||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||||
await helper.perform_handshake(30)
|
await helper.perform_handshake(30)
|
||||||
|
@ -283,17 +292,14 @@ async def test_noise_frame_helper_incorrect_key_fragments():
|
||||||
client_info="my client",
|
client_info="my client",
|
||||||
log_name="test",
|
log_name="test",
|
||||||
)
|
)
|
||||||
helper._transport = MagicMock()
|
|
||||||
helper._writer = MagicMock()
|
|
||||||
|
|
||||||
for pkt in outgoing_packets:
|
for pkt in outgoing_packets:
|
||||||
helper.mock_write_frame(bytes.fromhex(pkt))
|
helper.mock_write_frame(bytes.fromhex(pkt))
|
||||||
|
|
||||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
for pkt in incoming_packets:
|
||||||
for pkt in incoming_packets:
|
in_pkt = bytes.fromhex(pkt)
|
||||||
in_pkt = bytes.fromhex(pkt)
|
for i in range(len(in_pkt)):
|
||||||
for i in range(len(in_pkt)):
|
mock_data_received(helper, in_pkt[i : i + 1])
|
||||||
helper.data_received(in_pkt[i : i + 1])
|
|
||||||
|
|
||||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||||
await helper.perform_handshake(30)
|
await helper.perform_handshake(30)
|
||||||
|
@ -319,15 +325,12 @@ async def test_noise_incorrect_name():
|
||||||
client_info="my client",
|
client_info="my client",
|
||||||
log_name="test",
|
log_name="test",
|
||||||
)
|
)
|
||||||
helper._transport = MagicMock()
|
|
||||||
helper._writer = MagicMock()
|
|
||||||
|
|
||||||
for pkt in outgoing_packets:
|
for pkt in outgoing_packets:
|
||||||
helper.mock_write_frame(bytes.fromhex(pkt))
|
helper.mock_write_frame(bytes.fromhex(pkt))
|
||||||
|
|
||||||
with pytest.raises(BadNameAPIError):
|
for pkt in incoming_packets:
|
||||||
for pkt in incoming_packets:
|
mock_data_received(helper, bytes.fromhex(pkt))
|
||||||
helper.data_received(bytes.fromhex(pkt))
|
|
||||||
|
|
||||||
with pytest.raises(BadNameAPIError):
|
with pytest.raises(BadNameAPIError):
|
||||||
await helper.perform_handshake(30)
|
await helper.perform_handshake(30)
|
||||||
|
@ -350,8 +353,6 @@ async def test_noise_timeout():
|
||||||
client_info="my client",
|
client_info="my client",
|
||||||
log_name="test",
|
log_name="test",
|
||||||
)
|
)
|
||||||
helper._transport = MagicMock()
|
|
||||||
helper._writer = MagicMock()
|
|
||||||
|
|
||||||
for pkt in outgoing_packets:
|
for pkt in outgoing_packets:
|
||||||
helper.mock_write_frame(bytes.fromhex(pkt))
|
helper.mock_write_frame(bytes.fromhex(pkt))
|
||||||
|
@ -408,9 +409,8 @@ async def test_noise_frame_helper_handshake_failure():
|
||||||
expected_name="servicetest",
|
expected_name="servicetest",
|
||||||
client_info="my client",
|
client_info="my client",
|
||||||
log_name="test",
|
log_name="test",
|
||||||
|
writer=_writer,
|
||||||
)
|
)
|
||||||
helper._transport = MagicMock()
|
|
||||||
helper._writer = _writer
|
|
||||||
|
|
||||||
proto = NoiseConnection.from_name(
|
proto = NoiseConnection.from_name(
|
||||||
b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND
|
b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND
|
||||||
|
@ -448,7 +448,7 @@ async def test_noise_frame_helper_handshake_failure():
|
||||||
hello_pkg_length_low = hello_pkg_length & 0xFF
|
hello_pkg_length_low = hello_pkg_length & 0xFF
|
||||||
hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low))
|
hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low))
|
||||||
hello_pkt_with_header = hello_header + hello_pkt
|
hello_pkt_with_header = hello_header + hello_pkt
|
||||||
helper.data_received(hello_pkt_with_header)
|
mock_data_received(helper, hello_pkt_with_header)
|
||||||
|
|
||||||
error_pkt = b"\x01forced to fail"
|
error_pkt = b"\x01forced to fail"
|
||||||
preamble = 1
|
preamble = 1
|
||||||
|
@ -458,8 +458,7 @@ async def test_noise_frame_helper_handshake_failure():
|
||||||
error_header = bytes((preamble, error_pkg_length_high, error_pkg_length_low))
|
error_header = bytes((preamble, error_pkg_length_high, error_pkg_length_low))
|
||||||
error_pkt_with_header = error_header + error_pkt
|
error_pkt_with_header = error_header + error_pkt
|
||||||
|
|
||||||
with pytest.raises(HandshakeAPIError, match="forced to fail"):
|
mock_data_received(helper, error_pkt_with_header)
|
||||||
helper.data_received(error_pkt_with_header)
|
|
||||||
|
|
||||||
with pytest.raises(HandshakeAPIError, match="forced to fail"):
|
with pytest.raises(HandshakeAPIError, match="forced to fail"):
|
||||||
await handshake_task
|
await handshake_task
|
||||||
|
@ -483,9 +482,8 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
||||||
expected_name="servicetest",
|
expected_name="servicetest",
|
||||||
client_info="my client",
|
client_info="my client",
|
||||||
log_name="test",
|
log_name="test",
|
||||||
|
writer=_writer,
|
||||||
)
|
)
|
||||||
helper._transport = MagicMock()
|
|
||||||
helper._writer = _writer
|
|
||||||
|
|
||||||
proto = NoiseConnection.from_name(
|
proto = NoiseConnection.from_name(
|
||||||
b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND
|
b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND
|
||||||
|
@ -523,7 +521,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
||||||
hello_pkg_length_low = hello_pkg_length & 0xFF
|
hello_pkg_length_low = hello_pkg_length & 0xFF
|
||||||
hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low))
|
hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low))
|
||||||
hello_pkt_with_header = hello_header + hello_pkt
|
hello_pkt_with_header = hello_header + hello_pkt
|
||||||
helper.data_received(hello_pkt_with_header)
|
mock_data_received(helper, hello_pkt_with_header)
|
||||||
|
|
||||||
handshake = proto.write_message(b"")
|
handshake = proto.write_message(b"")
|
||||||
handshake_pkt = b"\x00" + handshake
|
handshake_pkt = b"\x00" + handshake
|
||||||
|
@ -536,7 +534,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
||||||
)
|
)
|
||||||
handshake_with_header = handshake_header + handshake_pkt
|
handshake_with_header = handshake_header + handshake_pkt
|
||||||
|
|
||||||
helper.data_received(handshake_with_header)
|
mock_data_received(helper, handshake_with_header)
|
||||||
|
|
||||||
assert not writes
|
assert not writes
|
||||||
|
|
||||||
|
@ -566,13 +564,12 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
||||||
encrypted_header = bytes(
|
encrypted_header = bytes(
|
||||||
(preamble, encrypted_pkg_length_high, encrypted_pkg_length_low)
|
(preamble, encrypted_pkg_length_high, encrypted_pkg_length_low)
|
||||||
)
|
)
|
||||||
helper.data_received(encrypted_header + encrypted_payload)
|
mock_data_received(helper, encrypted_header + encrypted_payload)
|
||||||
|
|
||||||
assert packets == [(42, b"from device")]
|
assert packets == [(42, b"from device")]
|
||||||
helper.close()
|
helper.close()
|
||||||
|
|
||||||
with pytest.raises(ProtocolAPIError, match="Connection closed"):
|
mock_data_received(helper, encrypted_header + encrypted_payload)
|
||||||
helper.data_received(encrypted_header + encrypted_payload)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
@ -590,7 +587,7 @@ async def test_init_plaintext_with_wrong_preamble(conn: APIConnection):
|
||||||
task = asyncio.create_task(conn._connect_hello_login(login=True))
|
task = asyncio.create_task(conn._connect_hello_login(login=True))
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
# The preamble should be \x00 but we send \x09
|
# The preamble should be \x00 but we send \x09
|
||||||
protocol.data_received(b"\x09\x00\x00")
|
mock_data_received(protocol, b"\x09\x00\x00")
|
||||||
|
|
||||||
with pytest.raises(ProtocolAPIError):
|
with pytest.raises(ProtocolAPIError):
|
||||||
await task
|
await task
|
||||||
|
@ -615,7 +612,7 @@ async def test_init_noise_with_wrong_byte_marker(noise_conn: APIConnection) -> N
|
||||||
assert protocol is not None
|
assert protocol is not None
|
||||||
assert isinstance(noise_conn._frame_helper, APINoiseFrameHelper)
|
assert isinstance(noise_conn._frame_helper, APINoiseFrameHelper)
|
||||||
|
|
||||||
protocol.data_received(b"\x00\x00\x00")
|
mock_data_received(protocol, b"\x00\x00\x00")
|
||||||
|
|
||||||
with pytest.raises(ProtocolAPIError, match="Marker byte invalid"):
|
with pytest.raises(ProtocolAPIError, match="Marker byte invalid"):
|
||||||
await task
|
await task
|
||||||
|
@ -632,8 +629,6 @@ async def test_noise_frame_helper_empty_hello():
|
||||||
client_info="my client",
|
client_info="my client",
|
||||||
log_name="test",
|
log_name="test",
|
||||||
)
|
)
|
||||||
helper._transport = MagicMock()
|
|
||||||
helper._writer = MagicMock()
|
|
||||||
|
|
||||||
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
||||||
empty_hello_pkt = b""
|
empty_hello_pkt = b""
|
||||||
|
@ -644,8 +639,7 @@ async def test_noise_frame_helper_empty_hello():
|
||||||
hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low))
|
hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low))
|
||||||
hello_pkt_with_header = hello_header + empty_hello_pkt
|
hello_pkt_with_header = hello_header + empty_hello_pkt
|
||||||
|
|
||||||
with pytest.raises(HandshakeAPIError, match="ServerHello is empty"):
|
mock_data_received(helper, hello_pkt_with_header)
|
||||||
helper.data_received(hello_pkt_with_header)
|
|
||||||
|
|
||||||
with pytest.raises(HandshakeAPIError, match="ServerHello is empty"):
|
with pytest.raises(HandshakeAPIError, match="ServerHello is empty"):
|
||||||
await handshake_task
|
await handshake_task
|
||||||
|
|
|
@ -88,7 +88,12 @@ from aioesphomeapi.model import (
|
||||||
)
|
)
|
||||||
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
|
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
|
||||||
|
|
||||||
from .common import Estr, generate_plaintext_packet, get_mock_zeroconf
|
from .common import (
|
||||||
|
Estr,
|
||||||
|
generate_plaintext_packet,
|
||||||
|
get_mock_zeroconf,
|
||||||
|
mock_data_received,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -849,7 +854,7 @@ async def test_bluetooth_disconnect(
|
||||||
response: message.Message = BluetoothDeviceConnectionResponse(
|
response: message.Message = BluetoothDeviceConnectionResponse(
|
||||||
address=1234, connected=False
|
address=1234, connected=False
|
||||||
)
|
)
|
||||||
protocol.data_received(generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
await disconnect_task
|
await disconnect_task
|
||||||
|
|
||||||
|
|
||||||
|
@ -864,7 +869,7 @@ async def test_bluetooth_pair(
|
||||||
pair_task = asyncio.create_task(client.bluetooth_device_pair(1234))
|
pair_task = asyncio.create_task(client.bluetooth_device_pair(1234))
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
response: message.Message = BluetoothDevicePairingResponse(address=1234)
|
response: message.Message = BluetoothDevicePairingResponse(address=1234)
|
||||||
protocol.data_received(generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
await pair_task
|
await pair_task
|
||||||
|
|
||||||
|
|
||||||
|
@ -879,7 +884,7 @@ async def test_bluetooth_unpair(
|
||||||
unpair_task = asyncio.create_task(client.bluetooth_device_unpair(1234))
|
unpair_task = asyncio.create_task(client.bluetooth_device_unpair(1234))
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
response: message.Message = BluetoothDeviceUnpairingResponse(address=1234)
|
response: message.Message = BluetoothDeviceUnpairingResponse(address=1234)
|
||||||
protocol.data_received(generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
await unpair_task
|
await unpair_task
|
||||||
|
|
||||||
|
|
||||||
|
@ -894,7 +899,7 @@ async def test_bluetooth_clear_cache(
|
||||||
clear_task = asyncio.create_task(client.bluetooth_device_clear_cache(1234))
|
clear_task = asyncio.create_task(client.bluetooth_device_clear_cache(1234))
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
response: message.Message = BluetoothDeviceClearCacheResponse(address=1234)
|
response: message.Message = BluetoothDeviceClearCacheResponse(address=1234)
|
||||||
protocol.data_received(generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
await clear_task
|
await clear_task
|
||||||
|
|
||||||
|
|
||||||
|
@ -914,7 +919,7 @@ async def test_device_info(
|
||||||
friendly_name="My Device",
|
friendly_name="My Device",
|
||||||
has_deep_sleep=True,
|
has_deep_sleep=True,
|
||||||
)
|
)
|
||||||
protocol.data_received(generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
device_info = await device_info_task
|
device_info = await device_info_task
|
||||||
assert device_info.name == "realname"
|
assert device_info.name == "realname"
|
||||||
assert device_info.friendly_name == "My Device"
|
assert device_info.friendly_name == "My Device"
|
||||||
|
@ -923,7 +928,7 @@ async def test_device_info(
|
||||||
disconnect_task = asyncio.create_task(client.disconnect())
|
disconnect_task = asyncio.create_task(client.disconnect())
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
response: message.Message = DisconnectResponse()
|
response: message.Message = DisconnectResponse()
|
||||||
protocol.data_received(generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
await disconnect_task
|
await disconnect_task
|
||||||
with pytest.raises(APIConnectionError, match="CLOSED"):
|
with pytest.raises(APIConnectionError, match="CLOSED"):
|
||||||
await client.device_info()
|
await client.device_info()
|
||||||
|
@ -943,12 +948,12 @@ async def test_bluetooth_gatt_read(
|
||||||
other_response: message.Message = BluetoothGATTReadResponse(
|
other_response: message.Message = BluetoothGATTReadResponse(
|
||||||
address=1234, handle=4567, data=b"4567"
|
address=1234, handle=4567, data=b"4567"
|
||||||
)
|
)
|
||||||
protocol.data_received(generate_plaintext_packet(other_response))
|
mock_data_received(protocol, generate_plaintext_packet(other_response))
|
||||||
|
|
||||||
response: message.Message = BluetoothGATTReadResponse(
|
response: message.Message = BluetoothGATTReadResponse(
|
||||||
address=1234, handle=1234, data=b"1234"
|
address=1234, handle=1234, data=b"1234"
|
||||||
)
|
)
|
||||||
protocol.data_received(generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
assert await read_task == b"1234"
|
assert await read_task == b"1234"
|
||||||
|
|
||||||
|
|
||||||
|
@ -966,12 +971,12 @@ async def test_bluetooth_gatt_read_descriptor(
|
||||||
other_response: message.Message = BluetoothGATTReadResponse(
|
other_response: message.Message = BluetoothGATTReadResponse(
|
||||||
address=1234, handle=4567, data=b"4567"
|
address=1234, handle=4567, data=b"4567"
|
||||||
)
|
)
|
||||||
protocol.data_received(generate_plaintext_packet(other_response))
|
mock_data_received(protocol, generate_plaintext_packet(other_response))
|
||||||
|
|
||||||
response: message.Message = BluetoothGATTReadResponse(
|
response: message.Message = BluetoothGATTReadResponse(
|
||||||
address=1234, handle=1234, data=b"1234"
|
address=1234, handle=1234, data=b"1234"
|
||||||
)
|
)
|
||||||
protocol.data_received(generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
assert await read_task == b"1234"
|
assert await read_task == b"1234"
|
||||||
|
|
||||||
|
|
||||||
|
@ -991,10 +996,10 @@ async def test_bluetooth_gatt_write(
|
||||||
other_response: message.Message = BluetoothGATTWriteResponse(
|
other_response: message.Message = BluetoothGATTWriteResponse(
|
||||||
address=1234, handle=4567
|
address=1234, handle=4567
|
||||||
)
|
)
|
||||||
protocol.data_received(generate_plaintext_packet(other_response))
|
mock_data_received(protocol, generate_plaintext_packet(other_response))
|
||||||
|
|
||||||
response: message.Message = BluetoothGATTWriteResponse(address=1234, handle=1234)
|
response: message.Message = BluetoothGATTWriteResponse(address=1234, handle=1234)
|
||||||
protocol.data_received(generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
await write_task
|
await write_task
|
||||||
|
|
||||||
|
|
||||||
|
@ -1034,10 +1039,10 @@ async def test_bluetooth_gatt_write_descriptor(
|
||||||
other_response: message.Message = BluetoothGATTWriteResponse(
|
other_response: message.Message = BluetoothGATTWriteResponse(
|
||||||
address=1234, handle=4567
|
address=1234, handle=4567
|
||||||
)
|
)
|
||||||
protocol.data_received(generate_plaintext_packet(other_response))
|
mock_data_received(protocol, generate_plaintext_packet(other_response))
|
||||||
|
|
||||||
response: message.Message = BluetoothGATTWriteResponse(address=1234, handle=1234)
|
response: message.Message = BluetoothGATTWriteResponse(address=1234, handle=1234)
|
||||||
protocol.data_received(generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
await write_task
|
await write_task
|
||||||
|
|
||||||
|
|
||||||
|
@ -1077,12 +1082,12 @@ async def test_bluetooth_gatt_read_descriptor(
|
||||||
other_response: message.Message = BluetoothGATTReadResponse(
|
other_response: message.Message = BluetoothGATTReadResponse(
|
||||||
address=1234, handle=4567, data=b"4567"
|
address=1234, handle=4567, data=b"4567"
|
||||||
)
|
)
|
||||||
protocol.data_received(generate_plaintext_packet(other_response))
|
mock_data_received(protocol, generate_plaintext_packet(other_response))
|
||||||
|
|
||||||
response: message.Message = BluetoothGATTReadResponse(
|
response: message.Message = BluetoothGATTReadResponse(
|
||||||
address=1234, handle=1234, data=b"1234"
|
address=1234, handle=1234, data=b"1234"
|
||||||
)
|
)
|
||||||
protocol.data_received(generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
assert await read_task == b"1234"
|
assert await read_task == b"1234"
|
||||||
|
|
||||||
|
|
||||||
|
@ -1102,9 +1107,9 @@ async def test_bluetooth_gatt_get_services(
|
||||||
response: message.Message = BluetoothGATTGetServicesResponse(
|
response: message.Message = BluetoothGATTGetServicesResponse(
|
||||||
address=1234, services=[service1]
|
address=1234, services=[service1]
|
||||||
)
|
)
|
||||||
protocol.data_received(generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
done_response: message.Message = BluetoothGATTGetServicesDoneResponse(address=1234)
|
done_response: message.Message = BluetoothGATTGetServicesDoneResponse(address=1234)
|
||||||
protocol.data_received(generate_plaintext_packet(done_response))
|
mock_data_received(protocol, generate_plaintext_packet(done_response))
|
||||||
|
|
||||||
services = await services_task
|
services = await services_task
|
||||||
assert services == ESPHomeBluetoothGATTServices(
|
assert services == ESPHomeBluetoothGATTServices(
|
||||||
|
@ -1129,9 +1134,9 @@ async def test_bluetooth_gatt_get_services_errors(
|
||||||
response: message.Message = BluetoothGATTGetServicesResponse(
|
response: message.Message = BluetoothGATTGetServicesResponse(
|
||||||
address=1234, services=[service1]
|
address=1234, services=[service1]
|
||||||
)
|
)
|
||||||
protocol.data_received(generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
done_response: message.Message = BluetoothGATTErrorResponse(address=1234)
|
done_response: message.Message = BluetoothGATTErrorResponse(address=1234)
|
||||||
protocol.data_received(generate_plaintext_packet(done_response))
|
mock_data_received(protocol, generate_plaintext_packet(done_response))
|
||||||
|
|
||||||
with pytest.raises(BluetoothGATTAPIError):
|
with pytest.raises(BluetoothGATTAPIError):
|
||||||
await services_task
|
await services_task
|
||||||
|
@ -1164,9 +1169,10 @@ async def test_bluetooth_gatt_start_notify(
|
||||||
data_response: message.Message = BluetoothGATTNotifyDataResponse(
|
data_response: message.Message = BluetoothGATTNotifyDataResponse(
|
||||||
address=1234, handle=1, data=b"gotit"
|
address=1234, handle=1, data=b"gotit"
|
||||||
)
|
)
|
||||||
protocol.data_received(
|
mock_data_received(
|
||||||
|
protocol,
|
||||||
generate_plaintext_packet(notify_response)
|
generate_plaintext_packet(notify_response)
|
||||||
+ generate_plaintext_packet(data_response)
|
+ generate_plaintext_packet(data_response),
|
||||||
)
|
)
|
||||||
|
|
||||||
cancel_cb, abort_cb = await notify_task
|
cancel_cb, abort_cb = await notify_task
|
||||||
|
@ -1175,7 +1181,7 @@ async def test_bluetooth_gatt_start_notify(
|
||||||
second_data_response: message.Message = BluetoothGATTNotifyDataResponse(
|
second_data_response: message.Message = BluetoothGATTNotifyDataResponse(
|
||||||
address=1234, handle=1, data=b"after finished"
|
address=1234, handle=1, data=b"after finished"
|
||||||
)
|
)
|
||||||
protocol.data_received(generate_plaintext_packet(second_data_response))
|
mock_data_received(protocol, generate_plaintext_packet(second_data_response))
|
||||||
assert notifies == [(1, b"gotit"), (1, b"after finished")]
|
assert notifies == [(1, b"gotit"), (1, b"after finished")]
|
||||||
await cancel_cb()
|
await cancel_cb()
|
||||||
|
|
||||||
|
@ -1244,7 +1250,7 @@ async def test_subscribe_bluetooth_le_advertisements(
|
||||||
manufacturer_data={},
|
manufacturer_data={},
|
||||||
address_type=1,
|
address_type=1,
|
||||||
)
|
)
|
||||||
protocol.data_received(generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
|
|
||||||
assert advs == [
|
assert advs == [
|
||||||
BluetoothLEAdvertisement(
|
BluetoothLEAdvertisement(
|
||||||
|
@ -1290,7 +1296,7 @@ async def test_subscribe_bluetooth_le_raw_advertisements(
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
protocol.data_received(generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
assert len(adv_groups) == 1
|
assert len(adv_groups) == 1
|
||||||
first_adv = adv_groups[0][0]
|
first_adv = adv_groups[0][0]
|
||||||
assert first_adv.address == 1234
|
assert first_adv.address == 1234
|
||||||
|
@ -1318,7 +1324,7 @@ async def test_subscribe_bluetooth_connections_free(
|
||||||
)
|
)
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
response: message.Message = BluetoothConnectionsFreeResponse(free=2, limit=3)
|
response: message.Message = BluetoothConnectionsFreeResponse(free=2, limit=3)
|
||||||
protocol.data_received(generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
|
|
||||||
assert connections == [(2, 3)]
|
assert connections == [(2, 3)]
|
||||||
unsub()
|
unsub()
|
||||||
|
@ -1345,7 +1351,7 @@ async def test_subscribe_home_assistant_states(
|
||||||
response: message.Message = SubscribeHomeAssistantStateResponse(
|
response: message.Message = SubscribeHomeAssistantStateResponse(
|
||||||
entity_id="sensor.red", attribute="any"
|
entity_id="sensor.red", attribute="any"
|
||||||
)
|
)
|
||||||
protocol.data_received(generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
|
|
||||||
assert states == [("sensor.red", "any")]
|
assert states == [("sensor.red", "any")]
|
||||||
|
|
||||||
|
|
|
@ -33,6 +33,7 @@ from .common import (
|
||||||
connect,
|
connect,
|
||||||
generate_plaintext_packet,
|
generate_plaintext_packet,
|
||||||
get_mock_protocol,
|
get_mock_protocol,
|
||||||
|
mock_data_received,
|
||||||
send_ping_request,
|
send_ping_request,
|
||||||
send_ping_response,
|
send_ping_response,
|
||||||
send_plaintext_connect_response,
|
send_plaintext_connect_response,
|
||||||
|
@ -52,20 +53,22 @@ async def test_connect(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that a plaintext connection works."""
|
"""Test that a plaintext connection works."""
|
||||||
conn, transport, protocol, connect_task = plaintext_connect_task_no_login
|
conn, transport, protocol, connect_task = plaintext_connect_task_no_login
|
||||||
protocol.data_received(
|
mock_data_received(
|
||||||
|
protocol,
|
||||||
bytes.fromhex(
|
bytes.fromhex(
|
||||||
"003602080110091a216d6173746572617672656c61792028657"
|
"003602080110091a216d6173746572617672656c61792028657"
|
||||||
"370686f6d652076323032332e362e3329220d6d617374657261"
|
"370686f6d652076323032332e362e3329220d6d617374657261"
|
||||||
"7672656c6179"
|
"7672656c6179"
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
protocol.data_received(
|
mock_data_received(
|
||||||
|
protocol,
|
||||||
bytes.fromhex(
|
bytes.fromhex(
|
||||||
"005b0a120d6d6173746572617672656c61791a1130383a33413a"
|
"005b0a120d6d6173746572617672656c61791a1130383a33413a"
|
||||||
"46323a33453a35453a36302208323032332e362e332a154a756e"
|
"46323a33453a35453a36302208323032332e362e332a154a756e"
|
||||||
"20323820323032332c2031383a31323a3236320965737033322d"
|
"20323820323032332c2031383a31323a3236320965737033322d"
|
||||||
"65766250506209457370726573736966"
|
"65766250506209457370726573736966"
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
await connect_task
|
await connect_task
|
||||||
assert conn.is_connected
|
assert conn.is_connected
|
||||||
|
@ -80,13 +83,14 @@ async def test_timeout_sending_message(
|
||||||
) -> None:
|
) -> None:
|
||||||
conn, transport, protocol, connect_task = plaintext_connect_task_no_login
|
conn, transport, protocol, connect_task = plaintext_connect_task_no_login
|
||||||
|
|
||||||
protocol.data_received(
|
mock_data_received(
|
||||||
|
protocol,
|
||||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
|
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
|
||||||
b"5stackatomproxy"
|
b"5stackatomproxy"
|
||||||
b"\x00\x00$"
|
b"\x00\x00$"
|
||||||
b"\x00\x00\x04"
|
b"\x00\x00\x04"
|
||||||
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d'
|
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d'
|
||||||
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
|
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif",
|
||||||
)
|
)
|
||||||
|
|
||||||
await connect_task
|
await connect_task
|
||||||
|
@ -117,8 +121,9 @@ async def test_disconnect_when_not_fully_connected(
|
||||||
|
|
||||||
# Only send the first part of the handshake
|
# Only send the first part of the handshake
|
||||||
# so we are stuck in the middle of the connection process
|
# so we are stuck in the middle of the connection process
|
||||||
protocol.data_received(
|
mock_data_received(
|
||||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
|
protocol,
|
||||||
|
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m',
|
||||||
)
|
)
|
||||||
|
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
@ -156,7 +161,7 @@ async def test_requires_encryption_propagates(conn: APIConnection):
|
||||||
with pytest.raises(RequiresEncryptionAPIError):
|
with pytest.raises(RequiresEncryptionAPIError):
|
||||||
task = asyncio.create_task(conn._connect_hello_login(login=True))
|
task = asyncio.create_task(conn._connect_hello_login(login=True))
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
protocol.data_received(b"\x01\x00\x00")
|
mock_data_received(protocol, b"\x01\x00\x00")
|
||||||
await task
|
await task
|
||||||
|
|
||||||
|
|
||||||
|
@ -175,17 +180,19 @@ async def test_plaintext_connection(
|
||||||
messages.append(msg)
|
messages.append(msg)
|
||||||
|
|
||||||
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
|
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
|
||||||
protocol.data_received(
|
mock_data_received(
|
||||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
|
protocol,
|
||||||
|
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m',
|
||||||
)
|
)
|
||||||
protocol.data_received(b"5stackatomproxy")
|
mock_data_received(protocol, b"5stackatomproxy")
|
||||||
protocol.data_received(b"\x00\x00$")
|
mock_data_received(protocol, b"\x00\x00$")
|
||||||
protocol.data_received(b"\x00\x00\x04")
|
mock_data_received(protocol, b"\x00\x00\x04")
|
||||||
protocol.data_received(
|
mock_data_received(
|
||||||
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d'
|
protocol,
|
||||||
|
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d',
|
||||||
)
|
)
|
||||||
protocol.data_received(
|
mock_data_received(
|
||||||
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
|
protocol, b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
|
||||||
)
|
)
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
await connect_task
|
await connect_task
|
||||||
|
@ -308,8 +315,9 @@ async def test_finish_connection_times_out(
|
||||||
messages.append(msg)
|
messages.append(msg)
|
||||||
|
|
||||||
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
|
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
|
||||||
protocol.data_received(
|
mock_data_received(
|
||||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
|
protocol,
|
||||||
|
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m',
|
||||||
)
|
)
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
@ -386,17 +394,19 @@ async def test_plaintext_connection_fails_handshake(
|
||||||
assert conn._socket is not None
|
assert conn._socket is not None
|
||||||
assert conn._frame_helper is not None
|
assert conn._frame_helper is not None
|
||||||
|
|
||||||
protocol.data_received(
|
mock_data_received(
|
||||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
|
protocol,
|
||||||
|
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m',
|
||||||
)
|
)
|
||||||
protocol.data_received(b"5stackatomproxy")
|
mock_data_received(protocol, b"5stackatomproxy")
|
||||||
protocol.data_received(b"\x00\x00$")
|
mock_data_received(protocol, b"\x00\x00$")
|
||||||
protocol.data_received(b"\x00\x00\x04")
|
mock_data_received(protocol, b"\x00\x00\x04")
|
||||||
protocol.data_received(
|
mock_data_received(
|
||||||
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d'
|
protocol,
|
||||||
|
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d',
|
||||||
)
|
)
|
||||||
protocol.data_received(
|
mock_data_received(
|
||||||
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
|
protocol, b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
|
||||||
)
|
)
|
||||||
|
|
||||||
call_order = []
|
call_order = []
|
||||||
|
@ -530,11 +540,9 @@ async def test_disconnect_fails_to_send_response(
|
||||||
await connect_task
|
await connect_task
|
||||||
assert conn.is_connected
|
assert conn.is_connected
|
||||||
|
|
||||||
with pytest.raises(SocketAPIError), patch.object(
|
with patch.object(protocol, "_writer", side_effect=OSError):
|
||||||
protocol, "_writer", side_effect=OSError
|
|
||||||
):
|
|
||||||
disconnect_request = DisconnectRequest()
|
disconnect_request = DisconnectRequest()
|
||||||
protocol.data_received(generate_plaintext_packet(disconnect_request))
|
mock_data_received(protocol, generate_plaintext_packet(disconnect_request))
|
||||||
|
|
||||||
# Wait one loop iteration for the disconnect to be processed
|
# Wait one loop iteration for the disconnect to be processed
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
@ -589,7 +597,7 @@ async def test_disconnect_success_case(
|
||||||
assert conn.is_connected
|
assert conn.is_connected
|
||||||
|
|
||||||
disconnect_request = DisconnectRequest()
|
disconnect_request = DisconnectRequest()
|
||||||
protocol.data_received(generate_plaintext_packet(disconnect_request))
|
mock_data_received(protocol, generate_plaintext_packet(disconnect_request))
|
||||||
|
|
||||||
# Wait one loop iteration for the disconnect to be processed
|
# Wait one loop iteration for the disconnect to be processed
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
|
@ -17,6 +17,7 @@ from .common import (
|
||||||
Estr,
|
Estr,
|
||||||
generate_plaintext_packet,
|
generate_plaintext_packet,
|
||||||
get_mock_async_zeroconf,
|
get_mock_async_zeroconf,
|
||||||
|
mock_data_received,
|
||||||
send_plaintext_connect_response,
|
send_plaintext_connect_response,
|
||||||
send_plaintext_hello,
|
send_plaintext_hello,
|
||||||
)
|
)
|
||||||
|
@ -74,11 +75,11 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec
|
||||||
|
|
||||||
response: message.Message = SubscribeLogsResponse()
|
response: message.Message = SubscribeLogsResponse()
|
||||||
response.message = b"Hello world"
|
response.message = b"Hello world"
|
||||||
protocol.data_received(generate_plaintext_packet(response))
|
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||||
assert len(messages) == 1
|
assert len(messages) == 1
|
||||||
assert messages[0].message == b"Hello world"
|
assert messages[0].message == b"Hello world"
|
||||||
stop_task = asyncio.create_task(stop())
|
stop_task = asyncio.create_task(stop())
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
disconnect_response = DisconnectResponse()
|
disconnect_response = DisconnectResponse()
|
||||||
protocol.data_received(generate_plaintext_packet(disconnect_response))
|
mock_data_received(protocol, generate_plaintext_packet(disconnect_response))
|
||||||
await stop_task
|
await stop_task
|
||||||
|
|
Loading…
Reference in New Issue