diff --git a/aioesphomeapi/_frame_helper.py b/aioesphomeapi/_frame_helper.py index 1688ac5..ffd18ab 100644 --- a/aioesphomeapi/_frame_helper.py +++ b/aioesphomeapi/_frame_helper.py @@ -33,7 +33,7 @@ SOCKET_ERRORS = ( @dataclass class Packet: type: int - data: Union[bytes, bytearray] + data: bytes class APIFrameHelper(asyncio.Protocol): @@ -192,7 +192,7 @@ class APIPlaintextFrameHelper(APIFrameHelper): if packet_data is None: return - self._callback_packet(msg_type_int, packet_data) + self._callback_packet(msg_type_int, bytes(packet_data)) # If we have more data, continue processing diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 6571cda..ec767a3 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -409,7 +409,9 @@ class APIConnection: def send_message(self, msg: message.Message) -> None: """Send a protobuf message to the remote.""" if not self._is_socket_open: - raise APIConnectionError("Connection isn't established yet") + raise APIConnectionError( + f"Connection isn't established yet ({self._connection_state})" + ) frame_helper = self._frame_helper assert frame_helper is not None @@ -577,7 +579,10 @@ class APIConnection: msg = MESSAGE_TYPE_TO_PROTO[msg_type_proto]() try: - msg.ParseFromString(pkt.data) + # MergeFromString instead of ParseFromString since + # ParseFromString will clear the message first and + # the msg is already empty. + msg.MergeFromString(pkt.data) except Exception as e: _LOGGER.info( "%s: Invalid protobuf message: type=%s data=%s: %s", @@ -587,7 +592,11 @@ class APIConnection: e, exc_info=True, ) - self._report_fatal_error(ProtocolAPIError(f"Invalid protobuf message: {e}")) + self._report_fatal_error( + ProtocolAPIError( + f"Invalid protobuf message: type={pkt.type} data={pkt.data!r}: {e}" + ) + ) raise msg_type = type(msg) diff --git a/tests/test_connection.py b/tests/test_connection.py index c483f1c..e62f79c 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -2,12 +2,12 @@ import asyncio import socket import pytest -from mock import AsyncMock, MagicMock, Mock, patch +from mock import MagicMock, patch -from aioesphomeapi._frame_helper import APIPlaintextFrameHelper, Packet -from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse +from aioesphomeapi._frame_helper import APIPlaintextFrameHelper +from aioesphomeapi.api_pb2 import DeviceInfoResponse, HelloResponse from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState -from aioesphomeapi.core import APIConnectionError, RequiresEncryptionAPIError +from aioesphomeapi.core import RequiresEncryptionAPIError from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr @@ -51,14 +51,10 @@ def socket_socket(): yield func -def _get_mock_protocol(): - def _on_packet(pkt: Packet): - pass - - def _on_error(exc: Exception): - raise exc - - protocol = APIPlaintextFrameHelper(on_pkt=_on_packet, on_error=_on_error) +def _get_mock_protocol(conn: APIConnection): + protocol = APIPlaintextFrameHelper( + on_pkt=conn._process_packet, on_error=conn._report_fatal_error + ) protocol._connected_event.set() protocol._transport = MagicMock() return protocol @@ -67,7 +63,7 @@ def _get_mock_protocol(): @pytest.mark.asyncio async def test_connect(conn, resolve_host, socket_socket, event_loop): loop = asyncio.get_event_loop() - protocol = _get_mock_protocol() + protocol = _get_mock_protocol(conn) with patch.object(event_loop, "sock_connect"), patch.object( loop, "create_connection", return_value=(MagicMock(), protocol) ), patch.object(conn, "_connect_start_ping"), patch.object( @@ -79,16 +75,63 @@ async def test_connect(conn, resolve_host, socket_socket, event_loop): @pytest.mark.asyncio -async def test_requires_encryption_propagates(conn): +async def test_requires_encryption_propagates(conn: APIConnection): loop = asyncio.get_event_loop() - protocol = _get_mock_protocol() + protocol = _get_mock_protocol(conn) with patch.object(loop, "create_connection") as create_connection, patch.object( protocol, "perform_handshake" ): create_connection.return_value = (MagicMock(), protocol) await conn._connect_init_frame_helper() + conn._connection_state = ConnectionState.CONNECTED with pytest.raises(RequiresEncryptionAPIError): + task = asyncio.create_task(conn._connect_hello()) + await asyncio.sleep(0) protocol.data_received(b"\x01\x00\x00") - await conn._connect_hello() + await task + + +@pytest.mark.asyncio +async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_socket): + """Test that a plaintext connection works.""" + loop = asyncio.get_event_loop() + protocol = _get_mock_protocol(conn) + messages = [] + + def on_msg(msg): + messages.append(msg) + + remove = conn.add_message_callback(on_msg, {HelloResponse, DeviceInfoResponse}) + transport = MagicMock() + + with patch.object(conn, "_connect_hello"), patch.object( + loop, "sock_connect" + ), patch.object(loop, "create_connection") as create_connection, patch.object( + protocol, "perform_handshake" + ): + create_connection.return_value = (transport, protocol) + await conn.connect(login=False) + + protocol.data_received( + b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m' + ) + protocol.data_received(b"5stackatomproxy") + protocol.data_received(b"\x00\x00$") + protocol.data_received(b"\x00\x00\x04") + protocol.data_received( + b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d' + ) + protocol.data_received( + b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif" + ) + await asyncio.sleep(0) + assert conn.is_connected + assert len(messages) == 2 + assert isinstance(messages[0], HelloResponse) + assert isinstance(messages[1], DeviceInfoResponse) + assert messages[1].name == "m5stackatomproxy" + remove() + await conn.force_disconnect() + await asyncio.sleep(0)