diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index bdf2290..a98a7a2 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -815,14 +815,14 @@ class APIConnection: _LOGGER.error( "%s: Invalid protobuf message: type=%s data=%s: %s", self.log_name, - msg_type_proto, + klass.__name__, data, e, exc_info=True, ) self.report_fatal_error( ProtocolAPIError( - f"Invalid protobuf message: type={msg_type_proto} data={data!r}: {e}" + f"Invalid protobuf message: type={klass.__name__} data={data!r}: {e}" ) ) raise diff --git a/tests/test_connection.py b/tests/test_connection.py index 190fd42..270fc5f 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -19,6 +19,7 @@ from aioesphomeapi.api_pb2 import ( HelloResponse, PingRequest, PingResponse, + TextSensorStateResponse, ) from aioesphomeapi.connection import APIConnection, ConnectionParams, ConnectionState from aioesphomeapi.core import ( @@ -738,3 +739,33 @@ async def test_unknown_protobuf_message_type_logged( assert connection.is_connected await connection.force_disconnect() await asyncio.sleep(0) + + +@pytest.mark.asyncio +async def test_bad_protobuf_message_drops_connection( + api_client: tuple[ + APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper + ], + caplog: pytest.LogCaptureFixture, +) -> None: + """Test ad bad protobuf messages is logged and causes the connection to collapse.""" + client, connection, transport, protocol = api_client + msg: message.Message = TextSensorStateResponse( + key=1, state="invalid", missing_state=False + ) + caplog.clear() + caplog.set_level(logging.DEBUG) + client.set_debug(True) + bytes_ = msg.SerializeToString() + # Replace the bytes with invalid UTF-8 + bytes_ = bytes.replace(bytes_, b"invalid", b"inval\xe9 ") + + message_with_bad_protobuf_data = ( + b"\0" + + _cached_varuint_to_bytes(len(bytes_)) + + _cached_varuint_to_bytes(27) + + bytes_ + ) + mock_data_received(protocol, message_with_bad_protobuf_data) + assert "Invalid protobuf message: type=TextSensorStateResponse" in caplog.text + assert connection.is_connected is False