Fix Invalid protobuf message: expected bytes, bytearray found and add coverage (#359)

This commit is contained in:
J. Nick Koston 2023-01-07 14:24:24 -10:00 committed by GitHub
parent 14a9ffc26b
commit a83838d025
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 73 additions and 21 deletions

View File

@ -33,7 +33,7 @@ SOCKET_ERRORS = (
@dataclass
class Packet:
type: int
data: Union[bytes, bytearray]
data: bytes
class APIFrameHelper(asyncio.Protocol):
@ -192,7 +192,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
if packet_data is None:
return
self._callback_packet(msg_type_int, packet_data)
self._callback_packet(msg_type_int, bytes(packet_data))
# If we have more data, continue processing

View File

@ -409,7 +409,9 @@ class APIConnection:
def send_message(self, msg: message.Message) -> None:
"""Send a protobuf message to the remote."""
if not self._is_socket_open:
raise APIConnectionError("Connection isn't established yet")
raise APIConnectionError(
f"Connection isn't established yet ({self._connection_state})"
)
frame_helper = self._frame_helper
assert frame_helper is not None
@ -577,7 +579,10 @@ class APIConnection:
msg = MESSAGE_TYPE_TO_PROTO[msg_type_proto]()
try:
msg.ParseFromString(pkt.data)
# MergeFromString instead of ParseFromString since
# ParseFromString will clear the message first and
# the msg is already empty.
msg.MergeFromString(pkt.data)
except Exception as e:
_LOGGER.info(
"%s: Invalid protobuf message: type=%s data=%s: %s",
@ -587,7 +592,11 @@ class APIConnection:
e,
exc_info=True,
)
self._report_fatal_error(ProtocolAPIError(f"Invalid protobuf message: {e}"))
self._report_fatal_error(
ProtocolAPIError(
f"Invalid protobuf message: type={pkt.type} data={pkt.data!r}: {e}"
)
)
raise
msg_type = type(msg)

View File

@ -2,12 +2,12 @@ import asyncio
import socket
import pytest
from mock import AsyncMock, MagicMock, Mock, patch
from mock import MagicMock, patch
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper, Packet
from aioesphomeapi.api_pb2 import ConnectResponse, HelloResponse
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi.api_pb2 import DeviceInfoResponse, HelloResponse
from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState
from aioesphomeapi.core import APIConnectionError, RequiresEncryptionAPIError
from aioesphomeapi.core import RequiresEncryptionAPIError
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
@ -51,14 +51,10 @@ def socket_socket():
yield func
def _get_mock_protocol():
def _on_packet(pkt: Packet):
pass
def _on_error(exc: Exception):
raise exc
protocol = APIPlaintextFrameHelper(on_pkt=_on_packet, on_error=_on_error)
def _get_mock_protocol(conn: APIConnection):
protocol = APIPlaintextFrameHelper(
on_pkt=conn._process_packet, on_error=conn._report_fatal_error
)
protocol._connected_event.set()
protocol._transport = MagicMock()
return protocol
@ -67,7 +63,7 @@ def _get_mock_protocol():
@pytest.mark.asyncio
async def test_connect(conn, resolve_host, socket_socket, event_loop):
loop = asyncio.get_event_loop()
protocol = _get_mock_protocol()
protocol = _get_mock_protocol(conn)
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", return_value=(MagicMock(), protocol)
), patch.object(conn, "_connect_start_ping"), patch.object(
@ -79,16 +75,63 @@ async def test_connect(conn, resolve_host, socket_socket, event_loop):
@pytest.mark.asyncio
async def test_requires_encryption_propagates(conn):
async def test_requires_encryption_propagates(conn: APIConnection):
loop = asyncio.get_event_loop()
protocol = _get_mock_protocol()
protocol = _get_mock_protocol(conn)
with patch.object(loop, "create_connection") as create_connection, patch.object(
protocol, "perform_handshake"
):
create_connection.return_value = (MagicMock(), protocol)
await conn._connect_init_frame_helper()
conn._connection_state = ConnectionState.CONNECTED
with pytest.raises(RequiresEncryptionAPIError):
task = asyncio.create_task(conn._connect_hello())
await asyncio.sleep(0)
protocol.data_received(b"\x01\x00\x00")
await conn._connect_hello()
await task
@pytest.mark.asyncio
async def test_plaintext_connection(conn: APIConnection, resolve_host, socket_socket):
"""Test that a plaintext connection works."""
loop = asyncio.get_event_loop()
protocol = _get_mock_protocol(conn)
messages = []
def on_msg(msg):
messages.append(msg)
remove = conn.add_message_callback(on_msg, {HelloResponse, DeviceInfoResponse})
transport = MagicMock()
with patch.object(conn, "_connect_hello"), patch.object(
loop, "sock_connect"
), patch.object(loop, "create_connection") as create_connection, patch.object(
protocol, "perform_handshake"
):
create_connection.return_value = (transport, protocol)
await conn.connect(login=False)
protocol.data_received(
b'\x00@\x02\x08\x01\x10\x07\x1a(m5stackatomproxy (esphome v2023.1.0-dev)"\x10m'
)
protocol.data_received(b"5stackatomproxy")
protocol.data_received(b"\x00\x00$")
protocol.data_received(b"\x00\x00\x04")
protocol.data_received(
b'\x00e\n\x12\x10m5stackatomproxy\x1a\x11E8:9F:6D:0A:68:E0"\x0c2023.1.0-d'
)
protocol.data_received(
b"ev*\x15Jan 7 2023, 13:19:532\x0cm5stack-atomX\x03b\tEspressif"
)
await asyncio.sleep(0)
assert conn.is_connected
assert len(messages) == 2
assert isinstance(messages[0], HelloResponse)
assert isinstance(messages[1], DeviceInfoResponse)
assert messages[1].name == "m5stackatomproxy"
remove()
await conn.force_disconnect()
await asyncio.sleep(0)