Fix error in noise frame helper were we could write when the writer was unset (#685)
This commit is contained in:
parent
dba6c72735
commit
095ef822f1
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import binascii
|
||||
import logging
|
||||
from functools import partial
|
||||
|
@ -128,10 +129,10 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||
exc.__cause__ = original_exc
|
||||
super()._handle_error(exc)
|
||||
|
||||
async def perform_handshake(self, timeout: float) -> None:
|
||||
"""Perform the handshake with the server."""
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
"""Handle a new connection."""
|
||||
super().connection_made(transport)
|
||||
self._send_hello_handshake()
|
||||
await super().perform_handshake(timeout)
|
||||
|
||||
def data_received(self, data: bytes | bytearray | memoryview) -> None:
|
||||
self._add_to_buffer(data)
|
||||
|
|
|
@ -10,7 +10,7 @@ from google.protobuf import message
|
|||
from zeroconf import Zeroconf
|
||||
from zeroconf.asyncio import AsyncZeroconf
|
||||
|
||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||
from aioesphomeapi._frame_helper.plain_text import _cached_varuint_to_bytes
|
||||
from aioesphomeapi.api_pb2 import (
|
||||
ConnectResponse,
|
||||
|
@ -31,6 +31,20 @@ utcnow.__doc__ = "Get now in UTC time."
|
|||
PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
|
||||
|
||||
|
||||
def mock_data_received(
|
||||
protocol: APINoiseFrameHelper | APIPlaintextFrameHelper, data: bytes
|
||||
) -> None:
|
||||
"""Mock data received on the protocol."""
|
||||
try:
|
||||
protocol.data_received(data)
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.call_soon(
|
||||
protocol.connection_lost,
|
||||
err,
|
||||
)
|
||||
|
||||
|
||||
def get_mock_zeroconf() -> MagicMock:
|
||||
with patch("zeroconf.Zeroconf.start"):
|
||||
zc = Zeroconf()
|
||||
|
|
|
@ -46,8 +46,7 @@ def socket_socket():
|
|||
yield func
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection_params() -> ConnectionParams:
|
||||
def get_mock_connection_params() -> ConnectionParams:
|
||||
return ConnectionParams(
|
||||
address="fake.address",
|
||||
port=6052,
|
||||
|
@ -60,6 +59,11 @@ def connection_params() -> ConnectionParams:
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def connection_params() -> ConnectionParams:
|
||||
return get_mock_connection_params()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def noise_connection_params() -> ConnectionParams:
|
||||
return ConnectionParams(
|
||||
|
|
|
@ -4,7 +4,7 @@ import asyncio
|
|||
import base64
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from noise.connection import NoiseConnection # type: ignore[import-untyped]
|
||||
|
@ -30,7 +30,13 @@ from aioesphomeapi.core import (
|
|||
SocketClosedAPIError,
|
||||
)
|
||||
|
||||
from .common import async_fire_time_changed, get_mock_protocol, utcnow
|
||||
from .common import (
|
||||
async_fire_time_changed,
|
||||
get_mock_protocol,
|
||||
mock_data_received,
|
||||
utcnow,
|
||||
)
|
||||
from .conftest import get_mock_connection_params
|
||||
|
||||
PREAMBLE = b"\x00"
|
||||
|
||||
|
@ -42,18 +48,27 @@ def _make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]:
|
|||
class MockConnection(APIConnection):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
"""Swallow args."""
|
||||
super().__init__(get_mock_connection_params(), AsyncMock(), *args, **kwargs)
|
||||
|
||||
def process_packet(self, type_: int, data: bytes):
|
||||
packets.append((type_, data))
|
||||
|
||||
def report_fatal_error(self, exc: Exception):
|
||||
raise exc
|
||||
|
||||
connection = MockConnection()
|
||||
return connection, packets
|
||||
|
||||
|
||||
class MockAPINoiseFrameHelper(APINoiseFrameHelper):
|
||||
def __init__(self, *args: Any, writer: Any | None = None, **kwargs: Any) -> None:
|
||||
"""Swallow args."""
|
||||
super().__init__(*args, **kwargs)
|
||||
transport = MagicMock()
|
||||
transport.write = writer or MagicMock()
|
||||
self.__transport = transport
|
||||
self.connection_made(transport)
|
||||
|
||||
def connection_made(self, transport: Any) -> None:
|
||||
return super().connection_made(self.__transport)
|
||||
|
||||
def mock_write_frame(self, frame: bytes) -> None:
|
||||
"""Write a packet to the socket.
|
||||
|
||||
|
@ -125,7 +140,7 @@ def test_plaintext_frame_helper(
|
|||
connection=connection, client_info="my client", log_name="test"
|
||||
)
|
||||
|
||||
helper.data_received(in_bytes)
|
||||
mock_data_received(helper, in_bytes)
|
||||
|
||||
pkt = packets.pop()
|
||||
type_, data = pkt
|
||||
|
@ -135,7 +150,7 @@ def test_plaintext_frame_helper(
|
|||
|
||||
# Make sure we correctly handle fragments
|
||||
for i in range(len(in_bytes)):
|
||||
helper.data_received(in_bytes[i : i + 1])
|
||||
mock_data_received(helper, in_bytes[i : i + 1])
|
||||
|
||||
pkt = packets.pop()
|
||||
type_, data = pkt
|
||||
|
@ -166,7 +181,7 @@ def test_plaintext_frame_helper_protractor_event_loop(byte_type: Any) -> None:
|
|||
PREAMBLE + varuint_to_bytes(4) + varuint_to_bytes(100) + (b"\x42" * 4)
|
||||
)
|
||||
|
||||
helper.data_received(in_bytes)
|
||||
mock_data_received(helper, in_bytes)
|
||||
|
||||
pkt = packets.pop()
|
||||
type_, data = pkt
|
||||
|
@ -176,7 +191,7 @@ def test_plaintext_frame_helper_protractor_event_loop(byte_type: Any) -> None:
|
|||
|
||||
# Make sure we correctly handle fragments
|
||||
for i in range(len(in_bytes)):
|
||||
helper.data_received(in_bytes[i : i + 1])
|
||||
mock_data_received(helper, in_bytes[i : i + 1])
|
||||
|
||||
pkt = packets.pop()
|
||||
type_, data = pkt
|
||||
|
@ -215,15 +230,12 @@ async def test_noise_protector_event_loop(byte_type: Any) -> None:
|
|||
client_info="my client",
|
||||
log_name="test",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
helper._writer = MagicMock()
|
||||
|
||||
for pkt in outgoing_packets:
|
||||
helper.mock_write_frame(byte_type(bytes.fromhex(pkt)))
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
for pkt in incoming_packets:
|
||||
helper.data_received(byte_type(bytes.fromhex(pkt)))
|
||||
for pkt in incoming_packets:
|
||||
mock_data_received(helper, byte_type(bytes.fromhex(pkt)))
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
await helper.perform_handshake(30)
|
||||
|
@ -249,15 +261,12 @@ async def test_noise_frame_helper_incorrect_key():
|
|||
client_info="my client",
|
||||
log_name="test",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
helper._writer = MagicMock()
|
||||
|
||||
for pkt in outgoing_packets:
|
||||
helper.mock_write_frame(bytes.fromhex(pkt))
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
for pkt in incoming_packets:
|
||||
helper.data_received(bytes.fromhex(pkt))
|
||||
for pkt in incoming_packets:
|
||||
mock_data_received(helper, bytes.fromhex(pkt))
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
await helper.perform_handshake(30)
|
||||
|
@ -283,17 +292,14 @@ async def test_noise_frame_helper_incorrect_key_fragments():
|
|||
client_info="my client",
|
||||
log_name="test",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
helper._writer = MagicMock()
|
||||
|
||||
for pkt in outgoing_packets:
|
||||
helper.mock_write_frame(bytes.fromhex(pkt))
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
for pkt in incoming_packets:
|
||||
in_pkt = bytes.fromhex(pkt)
|
||||
for i in range(len(in_pkt)):
|
||||
helper.data_received(in_pkt[i : i + 1])
|
||||
for pkt in incoming_packets:
|
||||
in_pkt = bytes.fromhex(pkt)
|
||||
for i in range(len(in_pkt)):
|
||||
mock_data_received(helper, in_pkt[i : i + 1])
|
||||
|
||||
with pytest.raises(InvalidEncryptionKeyAPIError):
|
||||
await helper.perform_handshake(30)
|
||||
|
@ -319,15 +325,12 @@ async def test_noise_incorrect_name():
|
|||
client_info="my client",
|
||||
log_name="test",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
helper._writer = MagicMock()
|
||||
|
||||
for pkt in outgoing_packets:
|
||||
helper.mock_write_frame(bytes.fromhex(pkt))
|
||||
|
||||
with pytest.raises(BadNameAPIError):
|
||||
for pkt in incoming_packets:
|
||||
helper.data_received(bytes.fromhex(pkt))
|
||||
for pkt in incoming_packets:
|
||||
mock_data_received(helper, bytes.fromhex(pkt))
|
||||
|
||||
with pytest.raises(BadNameAPIError):
|
||||
await helper.perform_handshake(30)
|
||||
|
@ -350,8 +353,6 @@ async def test_noise_timeout():
|
|||
client_info="my client",
|
||||
log_name="test",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
helper._writer = MagicMock()
|
||||
|
||||
for pkt in outgoing_packets:
|
||||
helper.mock_write_frame(bytes.fromhex(pkt))
|
||||
|
@ -408,9 +409,8 @@ async def test_noise_frame_helper_handshake_failure():
|
|||
expected_name="servicetest",
|
||||
client_info="my client",
|
||||
log_name="test",
|
||||
writer=_writer,
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
helper._writer = _writer
|
||||
|
||||
proto = NoiseConnection.from_name(
|
||||
b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND
|
||||
|
@ -448,7 +448,7 @@ async def test_noise_frame_helper_handshake_failure():
|
|||
hello_pkg_length_low = hello_pkg_length & 0xFF
|
||||
hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low))
|
||||
hello_pkt_with_header = hello_header + hello_pkt
|
||||
helper.data_received(hello_pkt_with_header)
|
||||
mock_data_received(helper, hello_pkt_with_header)
|
||||
|
||||
error_pkt = b"\x01forced to fail"
|
||||
preamble = 1
|
||||
|
@ -458,8 +458,7 @@ async def test_noise_frame_helper_handshake_failure():
|
|||
error_header = bytes((preamble, error_pkg_length_high, error_pkg_length_low))
|
||||
error_pkt_with_header = error_header + error_pkt
|
||||
|
||||
with pytest.raises(HandshakeAPIError, match="forced to fail"):
|
||||
helper.data_received(error_pkt_with_header)
|
||||
mock_data_received(helper, error_pkt_with_header)
|
||||
|
||||
with pytest.raises(HandshakeAPIError, match="forced to fail"):
|
||||
await handshake_task
|
||||
|
@ -483,9 +482,8 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
|||
expected_name="servicetest",
|
||||
client_info="my client",
|
||||
log_name="test",
|
||||
writer=_writer,
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
helper._writer = _writer
|
||||
|
||||
proto = NoiseConnection.from_name(
|
||||
b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND
|
||||
|
@ -523,7 +521,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
|||
hello_pkg_length_low = hello_pkg_length & 0xFF
|
||||
hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low))
|
||||
hello_pkt_with_header = hello_header + hello_pkt
|
||||
helper.data_received(hello_pkt_with_header)
|
||||
mock_data_received(helper, hello_pkt_with_header)
|
||||
|
||||
handshake = proto.write_message(b"")
|
||||
handshake_pkt = b"\x00" + handshake
|
||||
|
@ -536,7 +534,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
|||
)
|
||||
handshake_with_header = handshake_header + handshake_pkt
|
||||
|
||||
helper.data_received(handshake_with_header)
|
||||
mock_data_received(helper, handshake_with_header)
|
||||
|
||||
assert not writes
|
||||
|
||||
|
@ -566,13 +564,12 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
|||
encrypted_header = bytes(
|
||||
(preamble, encrypted_pkg_length_high, encrypted_pkg_length_low)
|
||||
)
|
||||
helper.data_received(encrypted_header + encrypted_payload)
|
||||
mock_data_received(helper, encrypted_header + encrypted_payload)
|
||||
|
||||
assert packets == [(42, b"from device")]
|
||||
helper.close()
|
||||
|
||||
with pytest.raises(ProtocolAPIError, match="Connection closed"):
|
||||
helper.data_received(encrypted_header + encrypted_payload)
|
||||
mock_data_received(helper, encrypted_header + encrypted_payload)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -590,7 +587,7 @@ async def test_init_plaintext_with_wrong_preamble(conn: APIConnection):
|
|||
task = asyncio.create_task(conn._connect_hello_login(login=True))
|
||||
await asyncio.sleep(0)
|
||||
# The preamble should be \x00 but we send \x09
|
||||
protocol.data_received(b"\x09\x00\x00")
|
||||
mock_data_received(protocol, b"\x09\x00\x00")
|
||||
|
||||
with pytest.raises(ProtocolAPIError):
|
||||
await task
|
||||
|
@ -615,7 +612,7 @@ async def test_init_noise_with_wrong_byte_marker(noise_conn: APIConnection) -> N
|
|||
assert protocol is not None
|
||||
assert isinstance(noise_conn._frame_helper, APINoiseFrameHelper)
|
||||
|
||||
protocol.data_received(b"\x00\x00\x00")
|
||||
mock_data_received(protocol, b"\x00\x00\x00")
|
||||
|
||||
with pytest.raises(ProtocolAPIError, match="Marker byte invalid"):
|
||||
await task
|
||||
|
@ -632,8 +629,6 @@ async def test_noise_frame_helper_empty_hello():
|
|||
client_info="my client",
|
||||
log_name="test",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
helper._writer = MagicMock()
|
||||
|
||||
handshake_task = asyncio.create_task(helper.perform_handshake(30))
|
||||
empty_hello_pkt = b""
|
||||
|
@ -644,8 +639,7 @@ async def test_noise_frame_helper_empty_hello():
|
|||
hello_header = bytes((preamble, hello_pkg_length_high, hello_pkg_length_low))
|
||||
hello_pkt_with_header = hello_header + empty_hello_pkt
|
||||
|
||||
with pytest.raises(HandshakeAPIError, match="ServerHello is empty"):
|
||||
helper.data_received(hello_pkt_with_header)
|
||||
mock_data_received(helper, hello_pkt_with_header)
|
||||
|
||||
with pytest.raises(HandshakeAPIError, match="ServerHello is empty"):
|
||||
await handshake_task
|
||||
|
|
|
@ -88,7 +88,12 @@ from aioesphomeapi.model import (
|
|||
)
|
||||
from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
|
||||
|
||||
from .common import Estr, generate_plaintext_packet, get_mock_zeroconf
|
||||
from .common import (
|
||||
Estr,
|
||||
generate_plaintext_packet,
|
||||
get_mock_zeroconf,
|
||||
mock_data_received,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -849,7 +854,7 @@ async def test_bluetooth_disconnect(
|
|||
response: message.Message = BluetoothDeviceConnectionResponse(
|
||||
address=1234, connected=False
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
await disconnect_task
|
||||
|
||||
|
||||
|
@ -864,7 +869,7 @@ async def test_bluetooth_pair(
|
|||
pair_task = asyncio.create_task(client.bluetooth_device_pair(1234))
|
||||
await asyncio.sleep(0)
|
||||
response: message.Message = BluetoothDevicePairingResponse(address=1234)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
await pair_task
|
||||
|
||||
|
||||
|
@ -879,7 +884,7 @@ async def test_bluetooth_unpair(
|
|||
unpair_task = asyncio.create_task(client.bluetooth_device_unpair(1234))
|
||||
await asyncio.sleep(0)
|
||||
response: message.Message = BluetoothDeviceUnpairingResponse(address=1234)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
await unpair_task
|
||||
|
||||
|
||||
|
@ -894,7 +899,7 @@ async def test_bluetooth_clear_cache(
|
|||
clear_task = asyncio.create_task(client.bluetooth_device_clear_cache(1234))
|
||||
await asyncio.sleep(0)
|
||||
response: message.Message = BluetoothDeviceClearCacheResponse(address=1234)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
await clear_task
|
||||
|
||||
|
||||
|
@ -914,7 +919,7 @@ async def test_device_info(
|
|||
friendly_name="My Device",
|
||||
has_deep_sleep=True,
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
device_info = await device_info_task
|
||||
assert device_info.name == "realname"
|
||||
assert device_info.friendly_name == "My Device"
|
||||
|
@ -923,7 +928,7 @@ async def test_device_info(
|
|||
disconnect_task = asyncio.create_task(client.disconnect())
|
||||
await asyncio.sleep(0)
|
||||
response: message.Message = DisconnectResponse()
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
await disconnect_task
|
||||
with pytest.raises(APIConnectionError, match="CLOSED"):
|
||||
await client.device_info()
|
||||
|
@ -943,12 +948,12 @@ async def test_bluetooth_gatt_read(
|
|||
other_response: message.Message = BluetoothGATTReadResponse(
|
||||
address=1234, handle=4567, data=b"4567"
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(other_response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(other_response))
|
||||
|
||||
response: message.Message = BluetoothGATTReadResponse(
|
||||
address=1234, handle=1234, data=b"1234"
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
assert await read_task == b"1234"
|
||||
|
||||
|
||||
|
@ -966,12 +971,12 @@ async def test_bluetooth_gatt_read_descriptor(
|
|||
other_response: message.Message = BluetoothGATTReadResponse(
|
||||
address=1234, handle=4567, data=b"4567"
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(other_response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(other_response))
|
||||
|
||||
response: message.Message = BluetoothGATTReadResponse(
|
||||
address=1234, handle=1234, data=b"1234"
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
assert await read_task == b"1234"
|
||||
|
||||
|
||||
|
@ -991,10 +996,10 @@ async def test_bluetooth_gatt_write(
|
|||
other_response: message.Message = BluetoothGATTWriteResponse(
|
||||
address=1234, handle=4567
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(other_response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(other_response))
|
||||
|
||||
response: message.Message = BluetoothGATTWriteResponse(address=1234, handle=1234)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
await write_task
|
||||
|
||||
|
||||
|
@ -1034,10 +1039,10 @@ async def test_bluetooth_gatt_write_descriptor(
|
|||
other_response: message.Message = BluetoothGATTWriteResponse(
|
||||
address=1234, handle=4567
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(other_response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(other_response))
|
||||
|
||||
response: message.Message = BluetoothGATTWriteResponse(address=1234, handle=1234)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
await write_task
|
||||
|
||||
|
||||
|
@ -1077,12 +1082,12 @@ async def test_bluetooth_gatt_read_descriptor(
|
|||
other_response: message.Message = BluetoothGATTReadResponse(
|
||||
address=1234, handle=4567, data=b"4567"
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(other_response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(other_response))
|
||||
|
||||
response: message.Message = BluetoothGATTReadResponse(
|
||||
address=1234, handle=1234, data=b"1234"
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
assert await read_task == b"1234"
|
||||
|
||||
|
||||
|
@ -1102,9 +1107,9 @@ async def test_bluetooth_gatt_get_services(
|
|||
response: message.Message = BluetoothGATTGetServicesResponse(
|
||||
address=1234, services=[service1]
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
done_response: message.Message = BluetoothGATTGetServicesDoneResponse(address=1234)
|
||||
protocol.data_received(generate_plaintext_packet(done_response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(done_response))
|
||||
|
||||
services = await services_task
|
||||
assert services == ESPHomeBluetoothGATTServices(
|
||||
|
@ -1129,9 +1134,9 @@ async def test_bluetooth_gatt_get_services_errors(
|
|||
response: message.Message = BluetoothGATTGetServicesResponse(
|
||||
address=1234, services=[service1]
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
done_response: message.Message = BluetoothGATTErrorResponse(address=1234)
|
||||
protocol.data_received(generate_plaintext_packet(done_response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(done_response))
|
||||
|
||||
with pytest.raises(BluetoothGATTAPIError):
|
||||
await services_task
|
||||
|
@ -1164,9 +1169,10 @@ async def test_bluetooth_gatt_start_notify(
|
|||
data_response: message.Message = BluetoothGATTNotifyDataResponse(
|
||||
address=1234, handle=1, data=b"gotit"
|
||||
)
|
||||
protocol.data_received(
|
||||
mock_data_received(
|
||||
protocol,
|
||||
generate_plaintext_packet(notify_response)
|
||||
+ generate_plaintext_packet(data_response)
|
||||
+ generate_plaintext_packet(data_response),
|
||||
)
|
||||
|
||||
cancel_cb, abort_cb = await notify_task
|
||||
|
@ -1175,7 +1181,7 @@ async def test_bluetooth_gatt_start_notify(
|
|||
second_data_response: message.Message = BluetoothGATTNotifyDataResponse(
|
||||
address=1234, handle=1, data=b"after finished"
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(second_data_response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(second_data_response))
|
||||
assert notifies == [(1, b"gotit"), (1, b"after finished")]
|
||||
await cancel_cb()
|
||||
|
||||
|
@ -1244,7 +1250,7 @@ async def test_subscribe_bluetooth_le_advertisements(
|
|||
manufacturer_data={},
|
||||
address_type=1,
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
|
||||
assert advs == [
|
||||
BluetoothLEAdvertisement(
|
||||
|
@ -1290,7 +1296,7 @@ async def test_subscribe_bluetooth_le_raw_advertisements(
|
|||
)
|
||||
]
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
assert len(adv_groups) == 1
|
||||
first_adv = adv_groups[0][0]
|
||||
assert first_adv.address == 1234
|
||||
|
@ -1318,7 +1324,7 @@ async def test_subscribe_bluetooth_connections_free(
|
|||
)
|
||||
await asyncio.sleep(0)
|
||||
response: message.Message = BluetoothConnectionsFreeResponse(free=2, limit=3)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
|
||||
assert connections == [(2, 3)]
|
||||
unsub()
|
||||
|
@ -1345,7 +1351,7 @@ async def test_subscribe_home_assistant_states(
|
|||
response: message.Message = SubscribeHomeAssistantStateResponse(
|
||||
entity_id="sensor.red", attribute="any"
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
|
||||
assert states == [("sensor.red", "any")]
|
||||
|
||||
|
|
|
@ -33,6 +33,7 @@ from .common import (
|
|||
connect,
|
||||
generate_plaintext_packet,
|
||||
get_mock_protocol,
|
||||
mock_data_received,
|
||||
send_ping_request,
|
||||
send_ping_response,
|
||||
send_plaintext_connect_response,
|
||||
|
@ -52,20 +53,22 @@ async def test_connect(
|
|||
) -> None:
|
||||
"""Test that a plaintext connection works."""
|
||||
conn, transport, protocol, connect_task = plaintext_connect_task_no_login
|
||||
protocol.data_received(
|
||||
mock_data_received(
|
||||
protocol,
|
||||
bytes.fromhex(
|
||||
"003602080110091a216d6173746572617672656c61792028657"
|
||||
"370686f6d652076323032332e362e3329220d6d617374657261"
|
||||
"7672656c6179"
|
||||
)
|
||||
),
|
||||
)
|
||||
protocol.data_received(
|
||||
mock_data_received(
|
||||
protocol,
|
||||
bytes.fromhex(
|
||||
"005b0a120d6d6173746572617672656c61791a1130383a33413a"
|
||||
"46323a33453a35453a36302208323032332e362e332a154a756e"
|
||||
"20323820323032332c2031383a31323a3236320965737033322d"
|
||||
"65766250506209457370726573736966"
|
||||
)
|
||||
),
|
||||
)
|
||||
await connect_task
|
||||
assert conn.is_connected
|
||||
|
@ -80,13 +83,14 @@ async def test_timeout_sending_message(
|
|||
) -> None:
|
||||
conn, transport, protocol, connect_task = plaintext_connect_task_no_login
|
||||
|
||||
protocol.data_received(
|
||||
mock_data_received(
|
||||
protocol,
|
||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
|
||||
b"5stackatomproxy"
|
||||
b"\x00\x00$"
|
||||
b"\x00\x00\x04"
|
||||
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d'
|
||||
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
|
||||
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif",
|
||||
)
|
||||
|
||||
await connect_task
|
||||
|
@ -117,8 +121,9 @@ async def test_disconnect_when_not_fully_connected(
|
|||
|
||||
# Only send the first part of the handshake
|
||||
# so we are stuck in the middle of the connection process
|
||||
protocol.data_received(
|
||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
|
||||
mock_data_received(
|
||||
protocol,
|
||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m',
|
||||
)
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
@ -156,7 +161,7 @@ async def test_requires_encryption_propagates(conn: APIConnection):
|
|||
with pytest.raises(RequiresEncryptionAPIError):
|
||||
task = asyncio.create_task(conn._connect_hello_login(login=True))
|
||||
await asyncio.sleep(0)
|
||||
protocol.data_received(b"\x01\x00\x00")
|
||||
mock_data_received(protocol, b"\x01\x00\x00")
|
||||
await task
|
||||
|
||||
|
||||
|
@ -175,17 +180,19 @@ async def test_plaintext_connection(
|
|||
messages.append(msg)
|
||||
|
||||
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
|
||||
protocol.data_received(
|
||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
|
||||
mock_data_received(
|
||||
protocol,
|
||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m',
|
||||
)
|
||||
protocol.data_received(b"5stackatomproxy")
|
||||
protocol.data_received(b"\x00\x00$")
|
||||
protocol.data_received(b"\x00\x00\x04")
|
||||
protocol.data_received(
|
||||
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d'
|
||||
mock_data_received(protocol, b"5stackatomproxy")
|
||||
mock_data_received(protocol, b"\x00\x00$")
|
||||
mock_data_received(protocol, b"\x00\x00\x04")
|
||||
mock_data_received(
|
||||
protocol,
|
||||
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d',
|
||||
)
|
||||
protocol.data_received(
|
||||
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
|
||||
mock_data_received(
|
||||
protocol, b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
await connect_task
|
||||
|
@ -308,8 +315,9 @@ async def test_finish_connection_times_out(
|
|||
messages.append(msg)
|
||||
|
||||
remove = conn.add_message_callback(on_msg, (HelloResponse, DeviceInfoResponse))
|
||||
protocol.data_received(
|
||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
|
||||
mock_data_received(
|
||||
protocol,
|
||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m',
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
@ -386,17 +394,19 @@ async def test_plaintext_connection_fails_handshake(
|
|||
assert conn._socket is not None
|
||||
assert conn._frame_helper is not None
|
||||
|
||||
protocol.data_received(
|
||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
|
||||
mock_data_received(
|
||||
protocol,
|
||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m',
|
||||
)
|
||||
protocol.data_received(b"5stackatomproxy")
|
||||
protocol.data_received(b"\x00\x00$")
|
||||
protocol.data_received(b"\x00\x00\x04")
|
||||
protocol.data_received(
|
||||
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d'
|
||||
mock_data_received(protocol, b"5stackatomproxy")
|
||||
mock_data_received(protocol, b"\x00\x00$")
|
||||
mock_data_received(protocol, b"\x00\x00\x04")
|
||||
mock_data_received(
|
||||
protocol,
|
||||
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d',
|
||||
)
|
||||
protocol.data_received(
|
||||
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
|
||||
mock_data_received(
|
||||
protocol, b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
|
||||
)
|
||||
|
||||
call_order = []
|
||||
|
@ -530,11 +540,9 @@ async def test_disconnect_fails_to_send_response(
|
|||
await connect_task
|
||||
assert conn.is_connected
|
||||
|
||||
with pytest.raises(SocketAPIError), patch.object(
|
||||
protocol, "_writer", side_effect=OSError
|
||||
):
|
||||
with patch.object(protocol, "_writer", side_effect=OSError):
|
||||
disconnect_request = DisconnectRequest()
|
||||
protocol.data_received(generate_plaintext_packet(disconnect_request))
|
||||
mock_data_received(protocol, generate_plaintext_packet(disconnect_request))
|
||||
|
||||
# Wait one loop iteration for the disconnect to be processed
|
||||
await asyncio.sleep(0)
|
||||
|
@ -589,7 +597,7 @@ async def test_disconnect_success_case(
|
|||
assert conn.is_connected
|
||||
|
||||
disconnect_request = DisconnectRequest()
|
||||
protocol.data_received(generate_plaintext_packet(disconnect_request))
|
||||
mock_data_received(protocol, generate_plaintext_packet(disconnect_request))
|
||||
|
||||
# Wait one loop iteration for the disconnect to be processed
|
||||
await asyncio.sleep(0)
|
||||
|
|
|
@ -17,6 +17,7 @@ from .common import (
|
|||
Estr,
|
||||
generate_plaintext_packet,
|
||||
get_mock_async_zeroconf,
|
||||
mock_data_received,
|
||||
send_plaintext_connect_response,
|
||||
send_plaintext_hello,
|
||||
)
|
||||
|
@ -74,11 +75,11 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec
|
|||
|
||||
response: message.Message = SubscribeLogsResponse()
|
||||
response.message = b"Hello world"
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(response))
|
||||
assert len(messages) == 1
|
||||
assert messages[0].message == b"Hello world"
|
||||
stop_task = asyncio.create_task(stop())
|
||||
await asyncio.sleep(0)
|
||||
disconnect_response = DisconnectResponse()
|
||||
protocol.data_received(generate_plaintext_packet(disconnect_response))
|
||||
mock_data_received(protocol, generate_plaintext_packet(disconnect_response))
|
||||
await stop_task
|
||||
|
|
Loading…
Reference in New Issue