mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-03-02 04:01:56 +01:00
Fix Invalid protobuf message: expected bytes, bytearray found and add coverage (#359)
This commit is contained in:
parent
14a9ffc26b
commit
a83838d025
@ -33,7 +33,7 @@ SOCKET_ERRORS = (
|
|||||||
@dataclass
|
@dataclass
|
||||||
class Packet:
|
class Packet:
|
||||||
type: int
|
type: int
|
||||||
data: Union[bytes, bytearray]
|
data: bytes
|
||||||
|
|
||||||
|
|
||||||
class APIFrameHelper(asyncio.Protocol):
|
class APIFrameHelper(asyncio.Protocol):
|
||||||
@ -192,7 +192,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
if packet_data is None:
|
if packet_data is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._callback_packet(msg_type_int, packet_data)
|
self._callback_packet(msg_type_int, bytes(packet_data))
|
||||||
# If we have more data, continue processing
|
# If we have more data, continue processing
|
||||||
|
|
||||||
|
|
||||||
|
@ -409,7 +409,9 @@ class APIConnection:
|
|||||||
def send_message(self, msg: message.Message) -> None:
|
def send_message(self, msg: message.Message) -> None:
|
||||||
"""Send a protobuf message to the remote."""
|
"""Send a protobuf message to the remote."""
|
||||||
if not self._is_socket_open:
|
if not self._is_socket_open:
|
||||||
raise APIConnectionError("Connection isn't established yet")
|
raise APIConnectionError(
|
||||||
|
f"Connection isn't established yet ({self._connection_state})"
|
||||||
|
)
|
||||||
|
|
||||||
frame_helper = self._frame_helper
|
frame_helper = self._frame_helper
|
||||||
assert frame_helper is not None
|
assert frame_helper is not None
|
||||||
@ -577,7 +579,10 @@ class APIConnection:
|
|||||||
|
|
||||||
msg = MESSAGE_TYPE_TO_PROTO[msg_type_proto]()
|
msg = MESSAGE_TYPE_TO_PROTO[msg_type_proto]()
|
||||||
try:
|
try:
|
||||||
msg.ParseFromString(pkt.data)
|
# MergeFromString instead of ParseFromString since
|
||||||
|
# ParseFromString will clear the message first and
|
||||||
|
# the msg is already empty.
|
||||||
|
msg.MergeFromString(pkt.data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_LOGGER.info(
|
_LOGGER.info(
|
||||||
"%s: Invalid protobuf message: type=%s data=%s: %s",
|
"%s: Invalid protobuf message: type=%s data=%s: %s",
|
||||||
@ -587,7 +592,11 @@ class APIConnection:
|
|||||||
e,
|
e,
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
self._report_fatal_error(ProtocolAPIError(f"Invalid protobuf message: {e}"))
|
self._report_fatal_error(
|
||||||
|
ProtocolAPIError(
|
||||||
|
f"Invalid protobuf message: type={pkt.type} data={pkt.data!r}: {e}"
|
||||||
|
)
|
||||||
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
msg_type = type(msg)
|
msg_type = type(msg)
|
||||||
|
@ -2,12 +2,12 @@ import asyncio
|
|||||||
import socket
|
import socket
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from mock import AsyncMock, MagicMock, Mock, patch
|
from mock import MagicMock, patch
|
||||||
|
|
||||||
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper, Packet
|
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||||
from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse
|
from aioesphomeapi.api_pb2 import DeviceInfoResponse, HelloResponse
|
||||||
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
|
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
|
||||||
from aioesphomeapi.core import APIConnectionError, RequiresEncryptionAPIError
|
from aioesphomeapi.core import RequiresEncryptionAPIError
|
||||||
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
|
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
|
||||||
|
|
||||||
|
|
||||||
@ -51,14 +51,10 @@ def socket_socket():
|
|||||||
yield func
|
yield func
|
||||||
|
|
||||||
|
|
||||||
def _get_mock_protocol():
|
def _get_mock_protocol(conn: APIConnection):
|
||||||
def _on_packet(pkt: Packet):
|
protocol = APIPlaintextFrameHelper(
|
||||||
pass
|
on_pkt=conn._process_packet, on_error=conn._report_fatal_error
|
||||||
|
)
|
||||||
def _on_error(exc: Exception):
|
|
||||||
raise exc
|
|
||||||
|
|
||||||
protocol = APIPlaintextFrameHelper(on_pkt=_on_packet, on_error=_on_error)
|
|
||||||
protocol._connected_event.set()
|
protocol._connected_event.set()
|
||||||
protocol._transport = MagicMock()
|
protocol._transport = MagicMock()
|
||||||
return protocol
|
return protocol
|
||||||
@ -67,7 +63,7 @@ def _get_mock_protocol():
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_connect(conn, resolve_host, socket_socket, event_loop):
|
async def test_connect(conn, resolve_host, socket_socket, event_loop):
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
protocol = _get_mock_protocol()
|
protocol = _get_mock_protocol(conn)
|
||||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||||
loop, "create_connection", return_value=(MagicMock(), protocol)
|
loop, "create_connection", return_value=(MagicMock(), protocol)
|
||||||
), patch.object(conn, "_connect_start_ping"), patch.object(
|
), patch.object(conn, "_connect_start_ping"), patch.object(
|
||||||
@ -79,16 +75,63 @@ async def test_connect(conn, resolve_host, socket_socket, event_loop):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_requires_encryption_propagates(conn):
|
async def test_requires_encryption_propagates(conn: APIConnection):
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
protocol = _get_mock_protocol()
|
protocol = _get_mock_protocol(conn)
|
||||||
with patch.object(loop, "create_connection") as create_connection, patch.object(
|
with patch.object(loop, "create_connection") as create_connection, patch.object(
|
||||||
protocol, "perform_handshake"
|
protocol, "perform_handshake"
|
||||||
):
|
):
|
||||||
create_connection.return_value = (MagicMock(), protocol)
|
create_connection.return_value = (MagicMock(), protocol)
|
||||||
|
|
||||||
await conn._connect_init_frame_helper()
|
await conn._connect_init_frame_helper()
|
||||||
|
conn._connection_state = ConnectionState.CONNECTED
|
||||||
|
|
||||||
with pytest.raises(RequiresEncryptionAPIError):
|
with pytest.raises(RequiresEncryptionAPIError):
|
||||||
|
task = asyncio.create_task(conn._connect_hello())
|
||||||
|
await asyncio.sleep(0)
|
||||||
protocol.data_received(b"\x01\x00\x00")
|
protocol.data_received(b"\x01\x00\x00")
|
||||||
await conn._connect_hello()
|
await task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_socket):
|
||||||
|
"""Test that a plaintext connection works."""
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
protocol = _get_mock_protocol(conn)
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
def on_msg(msg):
|
||||||
|
messages.append(msg)
|
||||||
|
|
||||||
|
remove = conn.add_message_callback(on_msg, {HelloResponse, DeviceInfoResponse})
|
||||||
|
transport = MagicMock()
|
||||||
|
|
||||||
|
with patch.object(conn, "_connect_hello"), patch.object(
|
||||||
|
loop, "sock_connect"
|
||||||
|
), patch.object(loop, "create_connection") as create_connection, patch.object(
|
||||||
|
protocol, "perform_handshake"
|
||||||
|
):
|
||||||
|
create_connection.return_value = (transport, protocol)
|
||||||
|
await conn.connect(login=False)
|
||||||
|
|
||||||
|
protocol.data_received(
|
||||||
|
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
|
||||||
|
)
|
||||||
|
protocol.data_received(b"5stackatomproxy")
|
||||||
|
protocol.data_received(b"\x00\x00$")
|
||||||
|
protocol.data_received(b"\x00\x00\x04")
|
||||||
|
protocol.data_received(
|
||||||
|
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d'
|
||||||
|
)
|
||||||
|
protocol.data_received(
|
||||||
|
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert conn.is_connected
|
||||||
|
assert len(messages) == 2
|
||||||
|
assert isinstance(messages[0], HelloResponse)
|
||||||
|
assert isinstance(messages[1], DeviceInfoResponse)
|
||||||
|
assert messages[1].name == "m5stackatomproxy"
|
||||||
|
remove()
|
||||||
|
await conn.force_disconnect()
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
Loading…
Reference in New Issue
Block a user