Fix error in noise frame helper were we could write when the writer was unset (#685)

This commit is contained in:
J. Nick Koston 2023-11-24 09:42:56 -06:00 committed by GitHub
parent dba6c72735
commit 095ef822f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 149 additions and 121 deletions

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import binascii import binascii
import logging import logging
from functools import partial from functools import partial
@ -128,10 +129,10 @@ class APINoiseFrameHelper(APIFrameHelper):
exc.__cause__ = original_exc exc.__cause__ = original_exc
super()._handle_error(exc) super()._handle_error(exc)
async def perform_handshake(self, timeout: float) -> None: def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Perform the handshake with the server.""" """Handle a new connection."""
super().connection_made(transport)
self._send_hello_handshake() self._send_hello_handshake()
await super().perform_handshake(timeout)
def data_received(self, data: bytes | bytearray | memoryview) -> None: def data_received(self, data: bytes | bytearray | memoryview) -> None:
self._add_to_buffer(data) self._add_to_buffer(data)

View File

@ -10,7 +10,7 @@ from google.protobuf import message
from zeroconf import Zeroconf from zeroconf import Zeroconf
from zeroconf.asyncio import AsyncZeroconf from zeroconf.asyncio import AsyncZeroconf
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
from aioesphomeapi.api_pb2 import ( from aioesphomeapi.api_pb2 import (
ConnectResponse, ConnectResponse,
@ -31,6 +31,20 @@ utcnow.__doc__ = "Get now in UTC time."
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()} PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
def mock_data_received(
protocol: APINoiseFrameHelper | APIPlaintextFrameHelper, data: bytes
) -> None:
"""Mock data received on the protocol."""
try:
protocol.data_received(data)
except Exception as err: # pylint: disable=broad-except
loop = asyncio.get_running_loop()
loop.call_soon(
protocol.connection_lost,
err,
)
def get_mock_zeroconf() -> MagicMock: def get_mock_zeroconf() -> MagicMock:
with patch("zeroconf.Zeroconf.start"): with patch("zeroconf.Zeroconf.start"):
zc = Zeroconf() zc = Zeroconf()

View File

@ -46,8 +46,7 @@ def socket_socket():
yield func yield func
@pytest.fixture def get_mock_connection_params() -> ConnectionParams:
def connection_params() -> ConnectionParams:
return ConnectionParams( return ConnectionParams(
address="fake.address", address="fake.address",
port=6052, port=6052,
@ -60,6 +59,11 @@ def connection_params() -> ConnectionParams:
) )
@pytest.fixture
def connection_params() -> ConnectionParams:
return get_mock_connection_params()
@pytest.fixture @pytest.fixture
def noise_connection_params() -> ConnectionParams: def noise_connection_params() -> ConnectionParams:
return ConnectionParams( return ConnectionParams(

View File

@ -4,7 +4,7 @@ import asyncio
import base64 import base64
from datetime import timedelta from datetime import timedelta
from typing import Any from typing import Any
from unittest.mock import MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from noise.connection import NoiseConnection # type: ignore[import-untyped] from noise.connection import NoiseConnection # type: ignore[import-untyped]
@ -30,7 +30,13 @@ from aioesphomeapi.core import (
SocketClosedAPIError, SocketClosedAPIError,
) )
from .common import async_fire_time_changed, get_mock_protocol, utcnow from .common import (
async_fire_time_changed,
get_mock_protocol,
mock_data_received,
utcnow,
)
from .conftest import get_mock_connection_params
PREAMBLE = b"\x00" PREAMBLE = b"\x00"
@ -42,18 +48,27 @@ def _make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]:
class MockConnection(APIConnection): class MockConnection(APIConnection):
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Swallow args.""" """Swallow args."""
super().__init__(get_mock_connection_params(), AsyncMock(), *args, **kwargs)
def process_packet(self, type_: int, data: bytes): def process_packet(self, type_: int, data: bytes):
packets.append((type_, data)) packets.append((type_, data))
def report_fatal_error(self, exc: Exception):
raise exc
connection = MockConnection() connection = MockConnection()
return connection, packets return connection, packets
class MockAPINoiseFrameHelper(APINoiseFrameHelper): class MockAPINoiseFrameHelper(APINoiseFrameHelper):
def __init__(self, *args: Any, writer: Any | None = None, **kwargs: Any) -> None:
"""Swallow args."""
super().__init__(*args, **kwargs)
transport = MagicMock()
transport.write = writer or MagicMock()
self.__transport = transport
self.connection_made(transport)
def connection_made(self, transport: Any) -> None:
return super().connection_made(self.__transport)
def mock_write_frame(self, frame: bytes) -> None: def mock_write_frame(self, frame: bytes) -> None:
"""Write a packet to the socket. """Write a packet to the socket.
@ -125,7 +140,7 @@ def test_plaintext_frame_helper(
connection=connection, client_info="my client", log_name="test" connection=connection, client_info="my client", log_name="test"
) )
helper.data_received(in_bytes) mock_data_received(helper, in_bytes)
pkt = packets.pop() pkt = packets.pop()
type_, data = pkt type_, data = pkt
@ -135,7 +150,7 @@ def test_plaintext_frame_helper(
# Make sure we correctly handle fragments # Make sure we correctly handle fragments
for i in range(len(in_bytes)): for i in range(len(in_bytes)):
helper.data_received(in_bytes[i : i + 1]) mock_data_received(helper, in_bytes[i : i + 1])
pkt = packets.pop() pkt = packets.pop()
type_, data = pkt type_, data = pkt
@ -166,7 +181,7 @@ def test_plaintext_frame_helper_protractor_event_loop(byte_type: Any) -> None:
PREAMBLE + varuint_to_bytes(4) + varuint_to_bytes(100) + (b"\x42" * 4) PREAMBLE + varuint_to_bytes(4) + varuint_to_bytes(100) + (b"\x42" * 4)
) )
helper.data_received(in_bytes) mock_data_received(helper, in_bytes)
pkt = packets.pop() pkt = packets.pop()
type_, data = pkt type_, data = pkt
@ -176,7 +191,7 @@ def test_plaintext_frame_helper_protractor_event_loop(byte_type: Any) -> None:
# Make sure we correctly handle fragments # Make sure we correctly handle fragments
for i in range(len(in_bytes)): for i in range(len(in_bytes)):
helper.data_received(in_bytes[i : i + 1]) mock_data_received(helper, in_bytes[i : i + 1])
pkt = packets.pop() pkt = packets.pop()
type_, data = pkt type_, data = pkt
@ -215,15 +230,12 @@ async def test_noise_protector_event_loop(byte_type: Any) -> None:
client_info="my client", client_info="my client",
log_name="test", log_name="test",
) )
helper._transport = MagicMock()
helper._writer = MagicMock()
for pkt in outgoing_packets: for pkt in outgoing_packets:
helper.mock_write_frame(byte_type(bytes.fromhex(pkt))) helper.mock_write_frame(byte_type(bytes.fromhex(pkt)))
with pytest.raises(InvalidEncryptionKeyAPIError): for pkt in incoming_packets:
for pkt in incoming_packets: mock_data_received(helper, byte_type(bytes.fromhex(pkt)))
helper.data_received(byte_type(bytes.fromhex(pkt)))
with pytest.raises(InvalidEncryptionKeyAPIError): with pytest.raises(InvalidEncryptionKeyAPIError):
await helper.perform_handshake(30) await helper.perform_handshake(30)
@ -249,15 +261,12 @@ async def test_noise_frame_helper_incorrect_key():
client_info="my client", client_info="my client",
log_name="test", log_name="test",
) )
helper._transport = MagicMock()
helper._writer = MagicMock()
for pkt in outgoing_packets: for pkt in outgoing_packets:
helper.mock_write_frame(bytes.fromhex(pkt)) helper.mock_write_frame(bytes.fromhex(pkt))
with pytest.raises(InvalidEncryptionKeyAPIError): for pkt in incoming_packets:
for pkt in incoming_packets: mock_data_received(helper, bytes.fromhex(pkt))
helper.data_received(bytes.fromhex(pkt))
with pytest.raises(InvalidEncryptionKeyAPIError): with pytest.raises(InvalidEncryptionKeyAPIError):
await helper.perform_handshake(30) await helper.perform_handshake(30)
@ -283,17 +292,14 @@ async def test_noise_frame_helper_incorrect_key_fragments():
client_info="my client", client_info="my client",
log_name="test", log_name="test",
) )
helper._transport = MagicMock()
helper._writer = MagicMock()
for pkt in outgoing_packets: for pkt in outgoing_packets:
helper.mock_write_frame(bytes.fromhex(pkt)) helper.mock_write_frame(bytes.fromhex(pkt))
with pytest.raises(InvalidEncryptionKeyAPIError): for pkt in incoming_packets:
for pkt in incoming_packets: in_pkt = bytes.fromhex(pkt)
in_pkt = bytes.fromhex(pkt) for i in range(len(in_pkt)):
for i in range(len(in_pkt)): mock_data_received(helper, in_pkt[i : i + 1])
helper.data_received(in_pkt[i : i + 1])
with pytest.raises(InvalidEncryptionKeyAPIError): with pytest.raises(InvalidEncryptionKeyAPIError):
await helper.perform_handshake(30) await helper.perform_handshake(30)
@ -319,15 +325,12 @@ async def test_noise_incorrect_name():
client_info="my client", client_info="my client",
log_name="test", log_name="test",
) )
helper._transport = MagicMock()
helper._writer = MagicMock()
for pkt in outgoing_packets: for pkt in outgoing_packets:
helper.mock_write_frame(bytes.fromhex(pkt)) helper.mock_write_frame(bytes.fromhex(pkt))
with pytest.raises(BadNameAPIError): for pkt in incoming_packets:
for pkt in incoming_packets: mock_data_received(helper, bytes.fromhex(pkt))
helper.data_received(bytes.fromhex(pkt))
with pytest.raises(BadNameAPIError): with pytest.raises(BadNameAPIError):
await helper.perform_handshake(30) await helper.perform_handshake(30)
@ -350,8 +353,6 @@ async def test_noise_timeout():
client_info="my client", client_info="my client",
log_name="test", log_name="test",
) )
helper._transport = MagicMock()
helper._writer = MagicMock()
for pkt in outgoing_packets: for pkt in outgoing_packets:
helper.mock_write_frame(bytes.fromhex(pkt)) helper.mock_write_frame(bytes.fromhex(pkt))
@ -408,9 +409,8 @@ async def test_noise_frame_helper_handshake_failure():
expected_name="servicetest", expected_name="servicetest",
client_info="my client", client_info="my client",
log_name="test", log_name="test",
writer=_writer,
) )
helper._transport = MagicMock()
helper._writer = _writer
proto = NoiseConnection.from_name( proto = NoiseConnection.from_name(
b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND
@ -448,7 +448,7 @@ async def test_noise_frame_helper_handshake_failure():
hello_pkg_length_low = hello_pkg_length & 0xFF hello_pkg_length_low = hello_pkg_length & 0xFF
hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low)) hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low))
hello_pkt_with_header = hello_header + hello_pkt hello_pkt_with_header = hello_header + hello_pkt
helper.data_received(hello_pkt_with_header) mock_data_received(helper, hello_pkt_with_header)
error_pkt = b"\x01forced to fail" error_pkt = b"\x01forced to fail"
preamble = 1 preamble = 1
@ -458,8 +458,7 @@ async def test_noise_frame_helper_handshake_failure():
error_header = bytes((preamble, error_pkg_length_high, error_pkg_length_low)) error_header = bytes((preamble, error_pkg_length_high, error_pkg_length_low))
error_pkt_with_header = error_header + error_pkt error_pkt_with_header = error_header + error_pkt
with pytest.raises(HandshakeAPIError, match="forced to fail"): mock_data_received(helper, error_pkt_with_header)
helper.data_received(error_pkt_with_header)
with pytest.raises(HandshakeAPIError, match="forced to fail"): with pytest.raises(HandshakeAPIError, match="forced to fail"):
await handshake_task await handshake_task
@ -483,9 +482,8 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
expected_name="servicetest", expected_name="servicetest",
client_info="my client", client_info="my client",
log_name="test", log_name="test",
writer=_writer,
) )
helper._transport = MagicMock()
helper._writer = _writer
proto = NoiseConnection.from_name( proto = NoiseConnection.from_name(
b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND
@ -523,7 +521,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
hello_pkg_length_low = hello_pkg_length & 0xFF hello_pkg_length_low = hello_pkg_length & 0xFF
hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low)) hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low))
hello_pkt_with_header = hello_header + hello_pkt hello_pkt_with_header = hello_header + hello_pkt
helper.data_received(hello_pkt_with_header) mock_data_received(helper, hello_pkt_with_header)
handshake = proto.write_message(b"") handshake = proto.write_message(b"")
handshake_pkt = b"\x00" + handshake handshake_pkt = b"\x00" + handshake
@ -536,7 +534,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
) )
handshake_with_header = handshake_header + handshake_pkt handshake_with_header = handshake_header + handshake_pkt
helper.data_received(handshake_with_header) mock_data_received(helper, handshake_with_header)
assert not writes assert not writes
@ -566,13 +564,12 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
encrypted_header = bytes( encrypted_header = bytes(
(preamble, encrypted_pkg_length_high, encrypted_pkg_length_low) (preamble, encrypted_pkg_length_high, encrypted_pkg_length_low)
) )
helper.data_received(encrypted_header + encrypted_payload) mock_data_received(helper, encrypted_header + encrypted_payload)
assert packets == [(42, b"from device")] assert packets == [(42, b"from device")]
helper.close() helper.close()
with pytest.raises(ProtocolAPIError, match="Connection closed"): mock_data_received(helper, encrypted_header + encrypted_payload)
helper.data_received(encrypted_header + encrypted_payload)
@pytest.mark.asyncio @pytest.mark.asyncio
@ -590,7 +587,7 @@ async def test_init_plaintext_with_wrong_preamble(conn: APIConnection):
task = asyncio.create_task(conn._connect_hello_login(login=True)) task = asyncio.create_task(conn._connect_hello_login(login=True))
await asyncio.sleep(0) await asyncio.sleep(0)
# The preamble should be \x00 but we send \x09 # The preamble should be \x00 but we send \x09
protocol.data_received(b"\x09\x00\x00") mock_data_received(protocol, b"\x09\x00\x00")
with pytest.raises(ProtocolAPIError): with pytest.raises(ProtocolAPIError):
await task await task
@ -615,7 +612,7 @@ async def test_init_noise_with_wrong_byte_marker(noise_conn: APIConnection) -> N
assert protocol is not None assert protocol is not None
assert isinstance(noise_conn._frame_helper, APINoiseFrameHelper) assert isinstance(noise_conn._frame_helper, APINoiseFrameHelper)
protocol.data_received(b"\x00\x00\x00") mock_data_received(protocol, b"\x00\x00\x00")
with pytest.raises(ProtocolAPIError, match="Marker byte invalid"): with pytest.raises(ProtocolAPIError, match="Marker byte invalid"):
await task await task
@ -632,8 +629,6 @@ async def test_noise_frame_helper_empty_hello():
client_info="my client", client_info="my client",
log_name="test", log_name="test",
) )
helper._transport = MagicMock()
helper._writer = MagicMock()
handshake_task = asyncio.create_task(helper.perform_handshake(30)) handshake_task = asyncio.create_task(helper.perform_handshake(30))
empty_hello_pkt = b"" empty_hello_pkt = b""
@ -644,8 +639,7 @@ async def test_noise_frame_helper_empty_hello():
hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low)) hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low))
hello_pkt_with_header = hello_header + empty_hello_pkt hello_pkt_with_header = hello_header + empty_hello_pkt
with pytest.raises(HandshakeAPIError, match="ServerHello is empty"): mock_data_received(helper, hello_pkt_with_header)
helper.data_received(hello_pkt_with_header)
with pytest.raises(HandshakeAPIError, match="ServerHello is empty"): with pytest.raises(HandshakeAPIError, match="ServerHello is empty"):
await handshake_task await handshake_task

View File

@ -88,7 +88,12 @@ from aioesphomeapi.model import (
) )
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
from .common import Estr, generate_plaintext_packet, get_mock_zeroconf from .common import (
Estr,
generate_plaintext_packet,
get_mock_zeroconf,
mock_data_received,
)
@pytest.fixture @pytest.fixture
@ -849,7 +854,7 @@ async def test_bluetooth_disconnect(
response: message.Message = BluetoothDeviceConnectionResponse( response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=False address=1234, connected=False
) )
protocol.data_received(generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
await disconnect_task await disconnect_task
@ -864,7 +869,7 @@ async def test_bluetooth_pair(
pair_task = asyncio.create_task(client.bluetooth_device_pair(1234)) pair_task = asyncio.create_task(client.bluetooth_device_pair(1234))
await asyncio.sleep(0) await asyncio.sleep(0)
response: message.Message = BluetoothDevicePairingResponse(address=1234) response: message.Message = BluetoothDevicePairingResponse(address=1234)
protocol.data_received(generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
await pair_task await pair_task
@ -879,7 +884,7 @@ async def test_bluetooth_unpair(
unpair_task = asyncio.create_task(client.bluetooth_device_unpair(1234)) unpair_task = asyncio.create_task(client.bluetooth_device_unpair(1234))
await asyncio.sleep(0) await asyncio.sleep(0)
response: message.Message = BluetoothDeviceUnpairingResponse(address=1234) response: message.Message = BluetoothDeviceUnpairingResponse(address=1234)
protocol.data_received(generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
await unpair_task await unpair_task
@ -894,7 +899,7 @@ async def test_bluetooth_clear_cache(
clear_task = asyncio.create_task(client.bluetooth_device_clear_cache(1234)) clear_task = asyncio.create_task(client.bluetooth_device_clear_cache(1234))
await asyncio.sleep(0) await asyncio.sleep(0)
response: message.Message = BluetoothDeviceClearCacheResponse(address=1234) response: message.Message = BluetoothDeviceClearCacheResponse(address=1234)
protocol.data_received(generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
await clear_task await clear_task
@ -914,7 +919,7 @@ async def test_device_info(
friendly_name="My Device", friendly_name="My Device",
has_deep_sleep=True, has_deep_sleep=True,
) )
protocol.data_received(generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
device_info = await device_info_task device_info = await device_info_task
assert device_info.name == "realname" assert device_info.name == "realname"
assert device_info.friendly_name == "My Device" assert device_info.friendly_name == "My Device"
@ -923,7 +928,7 @@ async def test_device_info(
disconnect_task = asyncio.create_task(client.disconnect()) disconnect_task = asyncio.create_task(client.disconnect())
await asyncio.sleep(0) await asyncio.sleep(0)
response: message.Message = DisconnectResponse() response: message.Message = DisconnectResponse()
protocol.data_received(generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
await disconnect_task await disconnect_task
with pytest.raises(APIConnectionError, match="CLOSED"): with pytest.raises(APIConnectionError, match="CLOSED"):
await client.device_info() await client.device_info()
@ -943,12 +948,12 @@ async def test_bluetooth_gatt_read(
other_response: message.Message = BluetoothGATTReadResponse( other_response: message.Message = BluetoothGATTReadResponse(
address=1234, handle=4567, data=b"4567" address=1234, handle=4567, data=b"4567"
) )
protocol.data_received(generate_plaintext_packet(other_response)) mock_data_received(protocol, generate_plaintext_packet(other_response))
response: message.Message = BluetoothGATTReadResponse( response: message.Message = BluetoothGATTReadResponse(
address=1234, handle=1234, data=b"1234" address=1234, handle=1234, data=b"1234"
) )
protocol.data_received(generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
assert await read_task == b"1234" assert await read_task == b"1234"
@ -966,12 +971,12 @@ async def test_bluetooth_gatt_read_descriptor(
other_response: message.Message = BluetoothGATTReadResponse( other_response: message.Message = BluetoothGATTReadResponse(
address=1234, handle=4567, data=b"4567" address=1234, handle=4567, data=b"4567"
) )
protocol.data_received(generate_plaintext_packet(other_response)) mock_data_received(protocol, generate_plaintext_packet(other_response))
response: message.Message = BluetoothGATTReadResponse( response: message.Message = BluetoothGATTReadResponse(
address=1234, handle=1234, data=b"1234" address=1234, handle=1234, data=b"1234"
) )
protocol.data_received(generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
assert await read_task == b"1234" assert await read_task == b"1234"
@ -991,10 +996,10 @@ async def test_bluetooth_gatt_write(
other_response: message.Message = BluetoothGATTWriteResponse( other_response: message.Message = BluetoothGATTWriteResponse(
address=1234, handle=4567 address=1234, handle=4567
) )
protocol.data_received(generate_plaintext_packet(other_response)) mock_data_received(protocol, generate_plaintext_packet(other_response))
response: message.Message = BluetoothGATTWriteResponse(address=1234, handle=1234) response: message.Message = BluetoothGATTWriteResponse(address=1234, handle=1234)
protocol.data_received(generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
await write_task await write_task
@ -1034,10 +1039,10 @@ async def test_bluetooth_gatt_write_descriptor(
other_response: message.Message = BluetoothGATTWriteResponse( other_response: message.Message = BluetoothGATTWriteResponse(
address=1234, handle=4567 address=1234, handle=4567
) )
protocol.data_received(generate_plaintext_packet(other_response)) mock_data_received(protocol, generate_plaintext_packet(other_response))
response: message.Message = BluetoothGATTWriteResponse(address=1234, handle=1234) response: message.Message = BluetoothGATTWriteResponse(address=1234, handle=1234)
protocol.data_received(generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
await write_task await write_task
@ -1077,12 +1082,12 @@ async def test_bluetooth_gatt_read_descriptor(
other_response: message.Message = BluetoothGATTReadResponse( other_response: message.Message = BluetoothGATTReadResponse(
address=1234, handle=4567, data=b"4567" address=1234, handle=4567, data=b"4567"
) )
protocol.data_received(generate_plaintext_packet(other_response)) mock_data_received(protocol, generate_plaintext_packet(other_response))
response: message.Message = BluetoothGATTReadResponse( response: message.Message = BluetoothGATTReadResponse(
address=1234, handle=1234, data=b"1234" address=1234, handle=1234, data=b"1234"
) )
protocol.data_received(generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
assert await read_task == b"1234" assert await read_task == b"1234"
@ -1102,9 +1107,9 @@ async def test_bluetooth_gatt_get_services(
response: message.Message = BluetoothGATTGetServicesResponse( response: message.Message = BluetoothGATTGetServicesResponse(
address=1234, services=[service1] address=1234, services=[service1]
) )
protocol.data_received(generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
done_response: message.Message = BluetoothGATTGetServicesDoneResponse(address=1234) done_response: message.Message = BluetoothGATTGetServicesDoneResponse(address=1234)
protocol.data_received(generate_plaintext_packet(done_response)) mock_data_received(protocol, generate_plaintext_packet(done_response))
services = await services_task services = await services_task
assert services == ESPHomeBluetoothGATTServices( assert services == ESPHomeBluetoothGATTServices(
@ -1129,9 +1134,9 @@ async def test_bluetooth_gatt_get_services_errors(
response: message.Message = BluetoothGATTGetServicesResponse( response: message.Message = BluetoothGATTGetServicesResponse(
address=1234, services=[service1] address=1234, services=[service1]
) )
protocol.data_received(generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
done_response: message.Message = BluetoothGATTErrorResponse(address=1234) done_response: message.Message = BluetoothGATTErrorResponse(address=1234)
protocol.data_received(generate_plaintext_packet(done_response)) mock_data_received(protocol, generate_plaintext_packet(done_response))
with pytest.raises(BluetoothGATTAPIError): with pytest.raises(BluetoothGATTAPIError):
await services_task await services_task
@ -1164,9 +1169,10 @@ async def test_bluetooth_gatt_start_notify(
data_response: message.Message = BluetoothGATTNotifyDataResponse( data_response: message.Message = BluetoothGATTNotifyDataResponse(
address=1234, handle=1, data=b"gotit" address=1234, handle=1, data=b"gotit"
) )
protocol.data_received( mock_data_received(
protocol,
generate_plaintext_packet(notify_response) generate_plaintext_packet(notify_response)
+ generate_plaintext_packet(data_response) + generate_plaintext_packet(data_response),
) )
cancel_cb, abort_cb = await notify_task cancel_cb, abort_cb = await notify_task
@ -1175,7 +1181,7 @@ async def test_bluetooth_gatt_start_notify(
second_data_response: message.Message = BluetoothGATTNotifyDataResponse( second_data_response: message.Message = BluetoothGATTNotifyDataResponse(
address=1234, handle=1, data=b"after finished" address=1234, handle=1, data=b"after finished"
) )
protocol.data_received(generate_plaintext_packet(second_data_response)) mock_data_received(protocol, generate_plaintext_packet(second_data_response))
assert notifies == [(1, b"gotit"), (1, b"after finished")] assert notifies == [(1, b"gotit"), (1, b"after finished")]
await cancel_cb() await cancel_cb()
@ -1244,7 +1250,7 @@ async def test_subscribe_bluetooth_le_advertisements(
manufacturer_data={}, manufacturer_data={},
address_type=1, address_type=1,
) )
protocol.data_received(generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
assert advs == [ assert advs == [
BluetoothLEAdvertisement( BluetoothLEAdvertisement(
@ -1290,7 +1296,7 @@ async def test_subscribe_bluetooth_le_raw_advertisements(
) )
] ]
) )
protocol.data_received(generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
assert len(adv_groups) == 1 assert len(adv_groups) == 1
first_adv = adv_groups[0][0] first_adv = adv_groups[0][0]
assert first_adv.address == 1234 assert first_adv.address == 1234
@ -1318,7 +1324,7 @@ async def test_subscribe_bluetooth_connections_free(
) )
await asyncio.sleep(0) await asyncio.sleep(0)
response: message.Message = BluetoothConnectionsFreeResponse(free=2, limit=3) response: message.Message = BluetoothConnectionsFreeResponse(free=2, limit=3)
protocol.data_received(generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
assert connections == [(2, 3)] assert connections == [(2, 3)]
unsub() unsub()
@ -1345,7 +1351,7 @@ async def test_subscribe_home_assistant_states(
response: message.Message = SubscribeHomeAssistantStateResponse( response: message.Message = SubscribeHomeAssistantStateResponse(
entity_id="sensor.red", attribute="any" entity_id="sensor.red", attribute="any"
) )
protocol.data_received(generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
assert states == [("sensor.red", "any")] assert states == [("sensor.red", "any")]

View File

@ -33,6 +33,7 @@ from .common import (
connect, connect,
generate_plaintext_packet, generate_plaintext_packet,
get_mock_protocol, get_mock_protocol,
mock_data_received,
send_ping_request, send_ping_request,
send_ping_response, send_ping_response,
send_plaintext_connect_response, send_plaintext_connect_response,
@ -52,20 +53,22 @@ async def test_connect(
) -> None: ) -> None:
"""Test that a plaintext connection works.""" """Test that a plaintext connection works."""
conn, transport, protocol, connect_task = plaintext_connect_task_no_login conn, transport, protocol, connect_task = plaintext_connect_task_no_login
protocol.data_received( mock_data_received(
protocol,
bytes.fromhex( bytes.fromhex(
"003602080110091a216d6173746572617672656c61792028657" "003602080110091a216d6173746572617672656c61792028657"
"370686f6d652076323032332e362e3329220d6d617374657261" "370686f6d652076323032332e362e3329220d6d617374657261"
"7672656c6179" "7672656c6179"
) ),
) )
protocol.data_received( mock_data_received(
protocol,
bytes.fromhex( bytes.fromhex(
"005b0a120d6d6173746572617672656c61791a1130383a33413a" "005b0a120d6d6173746572617672656c61791a1130383a33413a"
"46323a33453a35453a36302208323032332e362e332a154a756e" "46323a33453a35453a36302208323032332e362e332a154a756e"
"20323820323032332c2031383a31323a3236320965737033322d" "20323820323032332c2031383a31323a3236320965737033322d"
"65766250506209457370726573736966" "65766250506209457370726573736966"
) ),
) )
await connect_task await connect_task
assert conn.is_connected assert conn.is_connected
@ -80,13 +83,14 @@ async def test_timeout_sending_message(
) -> None: ) -> None:
conn, transport, protocol, connect_task = plaintext_connect_task_no_login conn, transport, protocol, connect_task = plaintext_connect_task_no_login
protocol.data_received( mock_data_received(
protocol,
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m' b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
b"5stackatomproxy" b"5stackatomproxy"
b"\x00\x00$" b"\x00\x00$"
b"\x00\x00\x04" b"\x00\x00\x04"
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d' b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d'
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif" b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif",
) )
await connect_task await connect_task
@ -117,8 +121,9 @@ async def test_disconnect_when_not_fully_connected(
# Only send the first part of the handshake # Only send the first part of the handshake
# so we are stuck in the middle of the connection process # so we are stuck in the middle of the connection process
protocol.data_received( mock_data_received(
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m' protocol,
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m',
) )
await asyncio.sleep(0) await asyncio.sleep(0)
@ -156,7 +161,7 @@ async def test_requires_encryption_propagates(conn: APIConnection):
with pytest.raises(RequiresEncryptionAPIError): with pytest.raises(RequiresEncryptionAPIError):
task = asyncio.create_task(conn._connect_hello_login(login=True)) task = asyncio.create_task(conn._connect_hello_login(login=True))
await asyncio.sleep(0) await asyncio.sleep(0)
protocol.data_received(b"\x01\x00\x00") mock_data_received(protocol, b"\x01\x00\x00")
await task await task
@ -175,17 +180,19 @@ async def test_plaintext_connection(
messages.append(msg) messages.append(msg)
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse)) remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
protocol.data_received( mock_data_received(
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m' protocol,
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m',
) )
protocol.data_received(b"5stackatomproxy") mock_data_received(protocol, b"5stackatomproxy")
protocol.data_received(b"\x00\x00$") mock_data_received(protocol, b"\x00\x00$")
protocol.data_received(b"\x00\x00\x04") mock_data_received(protocol, b"\x00\x00\x04")
protocol.data_received( mock_data_received(
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d' protocol,
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d',
) )
protocol.data_received( mock_data_received(
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif" protocol, b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
) )
await asyncio.sleep(0) await asyncio.sleep(0)
await connect_task await connect_task
@ -308,8 +315,9 @@ async def test_finish_connection_times_out(
messages.append(msg) messages.append(msg)
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse)) remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
protocol.data_received( mock_data_received(
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m' protocol,
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m',
) )
await asyncio.sleep(0) await asyncio.sleep(0)
@ -386,17 +394,19 @@ async def test_plaintext_connection_fails_handshake(
assert conn._socket is not None assert conn._socket is not None
assert conn._frame_helper is not None assert conn._frame_helper is not None
protocol.data_received( mock_data_received(
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m' protocol,
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m',
) )
protocol.data_received(b"5stackatomproxy") mock_data_received(protocol, b"5stackatomproxy")
protocol.data_received(b"\x00\x00$") mock_data_received(protocol, b"\x00\x00$")
protocol.data_received(b"\x00\x00\x04") mock_data_received(protocol, b"\x00\x00\x04")
protocol.data_received( mock_data_received(
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d' protocol,
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d',
) )
protocol.data_received( mock_data_received(
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif" protocol, b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
) )
call_order = [] call_order = []
@ -530,11 +540,9 @@ async def test_disconnect_fails_to_send_response(
await connect_task await connect_task
assert conn.is_connected assert conn.is_connected
with pytest.raises(SocketAPIError), patch.object( with patch.object(protocol, "_writer", side_effect=OSError):
protocol, "_writer", side_effect=OSError
):
disconnect_request = DisconnectRequest() disconnect_request = DisconnectRequest()
protocol.data_received(generate_plaintext_packet(disconnect_request)) mock_data_received(protocol, generate_plaintext_packet(disconnect_request))
# Wait one loop iteration for the disconnect to be processed # Wait one loop iteration for the disconnect to be processed
await asyncio.sleep(0) await asyncio.sleep(0)
@ -589,7 +597,7 @@ async def test_disconnect_success_case(
assert conn.is_connected assert conn.is_connected
disconnect_request = DisconnectRequest() disconnect_request = DisconnectRequest()
protocol.data_received(generate_plaintext_packet(disconnect_request)) mock_data_received(protocol, generate_plaintext_packet(disconnect_request))
# Wait one loop iteration for the disconnect to be processed # Wait one loop iteration for the disconnect to be processed
await asyncio.sleep(0) await asyncio.sleep(0)

View File

@ -17,6 +17,7 @@ from .common import (
Estr, Estr,
generate_plaintext_packet, generate_plaintext_packet,
get_mock_async_zeroconf, get_mock_async_zeroconf,
mock_data_received,
send_plaintext_connect_response, send_plaintext_connect_response,
send_plaintext_hello, send_plaintext_hello,
) )
@ -74,11 +75,11 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec
response: message.Message = SubscribeLogsResponse() response: message.Message = SubscribeLogsResponse()
response.message = b"Hello world" response.message = b"Hello world"
protocol.data_received(generate_plaintext_packet(response)) mock_data_received(protocol, generate_plaintext_packet(response))
assert len(messages) == 1 assert len(messages) == 1
assert messages[0].message == b"Hello world" assert messages[0].message == b"Hello world"
stop_task = asyncio.create_task(stop()) stop_task = asyncio.create_task(stop())
await asyncio.sleep(0) await asyncio.sleep(0)
disconnect_response = DisconnectResponse() disconnect_response = DisconnectResponse()
protocol.data_received(generate_plaintext_packet(disconnect_response)) mock_data_received(protocol, generate_plaintext_packet(disconnect_response))
await stop_task await stop_task