mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-21 16:37:41 +01:00
Fix Invalid protobuf message: expected bytes, bytearray found and add coverage (#359)
This commit is contained in:
parent
14a9ffc26b
commit
a83838d025
@ -33,7 +33,7 @@ SOCKET_ERRORS = (
|
||||
@dataclass
|
||||
class Packet:
|
||||
type: int
|
||||
data: Union[bytes, bytearray]
|
||||
data: bytes
|
||||
|
||||
|
||||
class APIFrameHelper(asyncio.Protocol):
|
||||
@ -192,7 +192,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
if packet_data is None:
|
||||
return
|
||||
|
||||
self._callback_packet(msg_type_int, packet_data)
|
||||
self._callback_packet(msg_type_int, bytes(packet_data))
|
||||
# If we have more data, continue processing
|
||||
|
||||
|
||||
|
@ -409,7 +409,9 @@ class APIConnection:
|
||||
def send_message(self, msg: message.Message) -> None:
|
||||
"""Send a protobuf message to the remote."""
|
||||
if not self._is_socket_open:
|
||||
raise APIConnectionError("Connection isn't established yet")
|
||||
raise APIConnectionError(
|
||||
f"Connection isn't established yet ({self._connection_state})"
|
||||
)
|
||||
|
||||
frame_helper = self._frame_helper
|
||||
assert frame_helper is not None
|
||||
@ -577,7 +579,10 @@ class APIConnection:
|
||||
|
||||
msg = MESSAGE_TYPE_TO_PROTO[msg_type_proto]()
|
||||
try:
|
||||
msg.ParseFromString(pkt.data)
|
||||
# MergeFromString instead of ParseFromString since
|
||||
# ParseFromString will clear the message first and
|
||||
# the msg is already empty.
|
||||
msg.MergeFromString(pkt.data)
|
||||
except Exception as e:
|
||||
_LOGGER.info(
|
||||
"%s: Invalid protobuf message: type=%s data=%s: %s",
|
||||
@ -587,7 +592,11 @@ class APIConnection:
|
||||
e,
|
||||
exc_info=True,
|
||||
)
|
||||
self._report_fatal_error(ProtocolAPIError(f"Invalid protobuf message: {e}"))
|
||||
self._report_fatal_error(
|
||||
ProtocolAPIError(
|
||||
f"Invalid protobuf message: type={pkt.type} data={pkt.data!r}: {e}"
|
||||
)
|
||||
)
|
||||
raise
|
||||
|
||||
msg_type = type(msg)
|
||||
|
@ -2,12 +2,12 @@ import asyncio
|
||||
import socket
|
||||
|
||||
import pytest
|
||||
from mock import AsyncMock, MagicMock, Mock, patch
|
||||
from mock import MagicMock, patch
|
||||
|
||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper, Packet
|
||||
from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse
|
||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||
from aioesphomeapi.api_pb2 import DeviceInfoResponse, HelloResponse
|
||||
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
|
||||
from aioesphomeapi.core import APIConnectionError, RequiresEncryptionAPIError
|
||||
from aioesphomeapi.core import RequiresEncryptionAPIError
|
||||
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
|
||||
|
||||
|
||||
@ -51,14 +51,10 @@ def socket_socket():
|
||||
yield func
|
||||
|
||||
|
||||
def _get_mock_protocol():
|
||||
def _on_packet(pkt: Packet):
|
||||
pass
|
||||
|
||||
def _on_error(exc: Exception):
|
||||
raise exc
|
||||
|
||||
protocol = APIPlaintextFrameHelper(on_pkt=_on_packet, on_error=_on_error)
|
||||
def _get_mock_protocol(conn: APIConnection):
|
||||
protocol = APIPlaintextFrameHelper(
|
||||
on_pkt=conn._process_packet, on_error=conn._report_fatal_error
|
||||
)
|
||||
protocol._connected_event.set()
|
||||
protocol._transport = MagicMock()
|
||||
return protocol
|
||||
@ -67,7 +63,7 @@ def _get_mock_protocol():
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect(conn, resolve_host, socket_socket, event_loop):
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol = _get_mock_protocol()
|
||||
protocol = _get_mock_protocol(conn)
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", return_value=(MagicMock(), protocol)
|
||||
), patch.object(conn, "_connect_start_ping"), patch.object(
|
||||
@ -79,16 +75,63 @@ async def test_connect(conn, resolve_host, socket_socket, event_loop):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_requires_encryption_propagates(conn):
|
||||
async def test_requires_encryption_propagates(conn: APIConnection):
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol = _get_mock_protocol()
|
||||
protocol = _get_mock_protocol(conn)
|
||||
with patch.object(loop, "create_connection") as create_connection, patch.object(
|
||||
protocol, "perform_handshake"
|
||||
):
|
||||
create_connection.return_value = (MagicMock(), protocol)
|
||||
|
||||
await conn._connect_init_frame_helper()
|
||||
conn._connection_state = ConnectionState.CONNECTED
|
||||
|
||||
with pytest.raises(RequiresEncryptionAPIError):
|
||||
task = asyncio.create_task(conn._connect_hello())
|
||||
await asyncio.sleep(0)
|
||||
protocol.data_received(b"\x01\x00\x00")
|
||||
await conn._connect_hello()
|
||||
await task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_socket):
|
||||
"""Test that a plaintext connection works."""
|
||||
loop = asyncio.get_event_loop()
|
||||
protocol = _get_mock_protocol(conn)
|
||||
messages = []
|
||||
|
||||
def on_msg(msg):
|
||||
messages.append(msg)
|
||||
|
||||
remove = conn.add_message_callback(on_msg, {HelloResponse, DeviceInfoResponse})
|
||||
transport = MagicMock()
|
||||
|
||||
with patch.object(conn, "_connect_hello"), patch.object(
|
||||
loop, "sock_connect"
|
||||
), patch.object(loop, "create_connection") as create_connection, patch.object(
|
||||
protocol, "perform_handshake"
|
||||
):
|
||||
create_connection.return_value = (transport, protocol)
|
||||
await conn.connect(login=False)
|
||||
|
||||
protocol.data_received(
|
||||
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
|
||||
)
|
||||
protocol.data_received(b"5stackatomproxy")
|
||||
protocol.data_received(b"\x00\x00$")
|
||||
protocol.data_received(b"\x00\x00\x04")
|
||||
protocol.data_received(
|
||||
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d'
|
||||
)
|
||||
protocol.data_received(
|
||||
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
assert conn.is_connected
|
||||
assert len(messages) == 2
|
||||
assert isinstance(messages[0], HelloResponse)
|
||||
assert isinstance(messages[1], DeviceInfoResponse)
|
||||
assert messages[1].name == "m5stackatomproxy"
|
||||
remove()
|
||||
await conn.force_disconnect()
|
||||
await asyncio.sleep(0)
|
||||
|
Loading…
Reference in New Issue
Block a user