mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-22 16:48:04 +01:00
Fix pong timer warning when pending ping is skipped (#483)
This commit is contained in:
parent
6aeea79884
commit
e909891ebe
@ -70,6 +70,7 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
"_buffer_len",
|
||||
"_pos",
|
||||
"_client_info",
|
||||
"_log_name",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@ -77,6 +78,7 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
on_pkt: Callable[[int, bytes], None],
|
||||
on_error: Callable[[Exception], None],
|
||||
client_info: str,
|
||||
log_name: str,
|
||||
) -> None:
|
||||
"""Initialize the API frame helper."""
|
||||
self._on_pkt = on_pkt
|
||||
@ -87,6 +89,7 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
self._buffer_len = 0
|
||||
self._pos = 0
|
||||
self._client_info = client_info
|
||||
self._log_name = log_name
|
||||
|
||||
def _read_exactly(self, length: int) -> Optional[bytearray]:
|
||||
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
|
||||
@ -118,11 +121,13 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
self._on_error(exc)
|
||||
|
||||
def connection_lost(self, exc: Optional[Exception]) -> None:
|
||||
self._handle_error(exc or SocketClosedAPIError("Connection lost"))
|
||||
self._handle_error(
|
||||
exc or SocketClosedAPIError(f"{self._log_name}: Connection lost")
|
||||
)
|
||||
return super().connection_lost(exc)
|
||||
|
||||
def eof_received(self) -> Optional[bool]:
|
||||
self._handle_error(SocketClosedAPIError("EOF received"))
|
||||
self._handle_error(SocketClosedAPIError(f"{self._log_name}: EOF received"))
|
||||
return super().eof_received()
|
||||
|
||||
def close(self) -> None:
|
||||
@ -142,12 +147,14 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
assert self._transport is not None, "Transport should be set"
|
||||
data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug("Sending plaintext frame %s", data.hex())
|
||||
_LOGGER.debug("%s: Sending plaintext frame %s", self._log_name, data.hex())
|
||||
|
||||
try:
|
||||
self._transport.write(data)
|
||||
except (RuntimeError, ConnectionResetError, OSError) as err:
|
||||
raise SocketAPIError(f"Error while writing data: {err}") from err
|
||||
raise SocketAPIError(
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
) from err
|
||||
|
||||
async def perform_handshake(self) -> None:
|
||||
"""Perform the handshake."""
|
||||
@ -170,11 +177,15 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
if preamble != 0x00:
|
||||
if preamble == 0x01:
|
||||
self._handle_error_and_close(
|
||||
RequiresEncryptionAPIError("Connection requires encryption")
|
||||
RequiresEncryptionAPIError(
|
||||
f"{self._log_name}: Connection requires encryption"
|
||||
)
|
||||
)
|
||||
return
|
||||
self._handle_error_and_close(
|
||||
ProtocolAPIError(f"Invalid preamble {preamble:02x}")
|
||||
ProtocolAPIError(
|
||||
f"{self._log_name}: Invalid preamble {preamble:02x}"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
@ -285,9 +296,10 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
noise_psk: str,
|
||||
expected_name: Optional[str],
|
||||
client_info: str,
|
||||
log_name: str,
|
||||
) -> None:
|
||||
"""Initialize the API frame helper."""
|
||||
super().__init__(on_pkt, on_error, client_info)
|
||||
super().__init__(on_pkt, on_error, client_info, log_name)
|
||||
self._ready_future = asyncio.get_event_loop().create_future()
|
||||
self._noise_psk = noise_psk
|
||||
self._expected_name = expected_name
|
||||
@ -311,7 +323,9 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
# Make sure we set the ready event if its not already set
|
||||
# so that we don't block forever on the ready event if we
|
||||
# are waiting for the handshake to complete.
|
||||
self._set_ready_future_exception(APIConnectionError("Connection closed"))
|
||||
self._set_ready_future_exception(
|
||||
APIConnectionError(f"{self._log_name}: Connection closed")
|
||||
)
|
||||
self._set_state(NoiseConnectionState.CLOSED)
|
||||
super().close()
|
||||
|
||||
@ -327,7 +341,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
):
|
||||
original_exc = exc
|
||||
exc = HandshakeAPIError(
|
||||
"The connection dropped immediately after encrypted hello; "
|
||||
f"{self._log_name}: The connection dropped immediately after encrypted hello; "
|
||||
"Try enabling encryption on the device or turning off "
|
||||
f"encryption on the client ({self._client_info})."
|
||||
)
|
||||
@ -341,7 +355,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
"""
|
||||
assert self._transport is not None, "Transport is not set"
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug("Sending frame: [%s]", frame.hex())
|
||||
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex())
|
||||
|
||||
frame_len = len(frame)
|
||||
try:
|
||||
@ -354,7 +368,9 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
)
|
||||
self._transport.write(header + frame)
|
||||
except (RuntimeError, ConnectionResetError, OSError) as err:
|
||||
raise SocketAPIError(f"Error while writing data: {err}") from err
|
||||
raise SocketAPIError(
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
) from err
|
||||
|
||||
async def perform_handshake(self) -> None:
|
||||
"""Perform the handshake with the server."""
|
||||
@ -363,7 +379,9 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
async with async_timeout.timeout(60.0):
|
||||
await self._ready_future
|
||||
except asyncio.TimeoutError as err:
|
||||
raise HandshakeAPIError("Timeout during handshake") from err
|
||||
raise HandshakeAPIError(
|
||||
f"{self._log_name}: Timeout during handshake"
|
||||
) from err
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self._buffer += data
|
||||
@ -376,7 +394,9 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
preamble, msg_size_high, msg_size_low = header
|
||||
if preamble != 0x01:
|
||||
self._handle_error_and_close(
|
||||
ProtocolAPIError(f"Marker byte invalid: {header[0]}")
|
||||
ProtocolAPIError(
|
||||
f"{self._log_name}: Marker byte invalid: {header[0]}"
|
||||
)
|
||||
)
|
||||
return
|
||||
frame = self._read_exactly((msg_size_high << 8) | msg_size_low)
|
||||
@ -404,7 +424,9 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
def _handle_hello(self, server_hello: bytearray) -> None:
|
||||
"""Perform the handshake with the server."""
|
||||
if not server_hello:
|
||||
self._handle_error_and_close(HandshakeAPIError("ServerHello is empty"))
|
||||
self._handle_error_and_close(
|
||||
HandshakeAPIError(f"{self._log_name}: ServerHello is empty")
|
||||
)
|
||||
return
|
||||
|
||||
# First byte of server hello is the protocol the server chose
|
||||
@ -413,7 +435,9 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
chosen_proto = server_hello[0]
|
||||
if chosen_proto != 0x01:
|
||||
self._handle_error_and_close(
|
||||
HandshakeAPIError(f"Unknown protocol selected by client {chosen_proto}")
|
||||
HandshakeAPIError(
|
||||
f"{self._log_name}: Unknown protocol selected by client {chosen_proto}"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
@ -429,7 +453,8 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
if self._expected_name is not None and self._expected_name != server_name:
|
||||
self._handle_error_and_close(
|
||||
BadNameAPIError(
|
||||
f"Server sent a different name '{server_name}'", server_name
|
||||
f"{self._log_name}: Server sent a different name '{server_name}'",
|
||||
server_name,
|
||||
)
|
||||
)
|
||||
return
|
||||
@ -458,19 +483,19 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
if explanation == "Handshake MAC failure":
|
||||
self._handle_error_and_close(
|
||||
InvalidEncryptionKeyAPIError(
|
||||
"Invalid encryption key", self._server_name
|
||||
f"{self._log_name}: Invalid encryption key", self._server_name
|
||||
)
|
||||
)
|
||||
return
|
||||
self._handle_error_and_close(
|
||||
HandshakeAPIError(f"Handshake failure: {explanation}")
|
||||
HandshakeAPIError(f"{self._log_name}: Handshake failure: {explanation}")
|
||||
)
|
||||
return
|
||||
try:
|
||||
self._proto.read_message(msg[1:])
|
||||
except InvalidTag as invalid_tag_exc:
|
||||
ex = InvalidEncryptionKeyAPIError(
|
||||
"Invalid encryption key", self._server_name
|
||||
f"{self._log_name}: Invalid encryption key", self._server_name
|
||||
)
|
||||
ex.__cause__ = invalid_tag_exc
|
||||
self._handle_error_and_close(ex)
|
||||
@ -491,7 +516,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
def write_packet(self, type_: int, data: bytes) -> None:
|
||||
"""Write a packet to the socket."""
|
||||
if self._state != NoiseConnectionState.READY:
|
||||
raise HandshakeAPIError("Noise connection is not ready")
|
||||
raise HandshakeAPIError(f"{self._log_name}: Noise connection is not ready")
|
||||
if TYPE_CHECKING:
|
||||
assert self._encrypt is not None, "Handshake should be complete"
|
||||
data_len = len(data)
|
||||
@ -517,7 +542,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
msg = self._decrypt(bytes(frame))
|
||||
except InvalidTag as ex:
|
||||
self._handle_error_and_close(
|
||||
ProtocolAPIError(f"Bad encryption frame: {ex!r}")
|
||||
ProtocolAPIError(f"{self._log_name}: Bad encryption frame: {ex!r}")
|
||||
)
|
||||
return
|
||||
# Message layout is
|
||||
@ -530,7 +555,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
self, frame: bytearray
|
||||
) -> None:
|
||||
"""Handle a closed frame."""
|
||||
self._handle_error(ProtocolAPIError("Connection closed"))
|
||||
self._handle_error(ProtocolAPIError(f"{self._log_name}: Connection closed"))
|
||||
|
||||
STATE_TO_CALLABLE = {
|
||||
NoiseConnectionState.HELLO: _handle_hello,
|
||||
|
@ -316,6 +316,7 @@ class APIConnection:
|
||||
on_pkt=process_packet,
|
||||
on_error=self._report_fatal_error,
|
||||
client_info=self._params.client_info,
|
||||
log_name=self.log_name,
|
||||
),
|
||||
sock=self._socket,
|
||||
)
|
||||
@ -327,6 +328,7 @@ class APIConnection:
|
||||
on_pkt=process_packet,
|
||||
on_error=self._report_fatal_error,
|
||||
client_info=self._params.client_info,
|
||||
log_name=self.log_name,
|
||||
),
|
||||
sock=self._socket,
|
||||
)
|
||||
@ -391,7 +393,6 @@ class APIConnection:
|
||||
|
||||
if self._send_pending_ping:
|
||||
self.send_message(PING_REQUEST_MESSAGE)
|
||||
|
||||
if self._pong_timer is None:
|
||||
# Do not reset the timer if it's already set
|
||||
# since the only thing we want to reset the timer
|
||||
@ -540,15 +541,14 @@ class APIConnection:
|
||||
|
||||
frame_helper = self._frame_helper
|
||||
assert frame_helper is not None
|
||||
message_type = PROTO_TO_MESSAGE_TYPE.get(type(msg))
|
||||
_msg_type = type(msg)
|
||||
message_type = PROTO_TO_MESSAGE_TYPE.get(_msg_type)
|
||||
if not message_type:
|
||||
raise ValueError(f"Message type id not found for type {type(msg)}")
|
||||
raise ValueError(f"Message type id not found for type {_msg_type}")
|
||||
encoded = msg.SerializeToString()
|
||||
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug(
|
||||
"%s: Sending %s: %s", self._params.address, type(msg), str(msg)
|
||||
)
|
||||
_LOGGER.debug("%s: Sending %s: %s", self.log_name, _msg_type.__name__, msg)
|
||||
|
||||
try:
|
||||
frame_helper.write_packet(message_type, encoded)
|
||||
@ -730,7 +730,9 @@ class APIConnection:
|
||||
class_ = message_type_to_proto[msg_type_proto]
|
||||
except KeyError:
|
||||
_LOGGER.debug(
|
||||
"%s: Skipping message type %s", self.log_name, msg_type_proto
|
||||
"%s: Skipping message type %s",
|
||||
self.log_name,
|
||||
msg_type_proto,
|
||||
)
|
||||
return
|
||||
|
||||
@ -760,7 +762,10 @@ class APIConnection:
|
||||
|
||||
if is_enabled_for(logging_debug):
|
||||
_LOGGER.debug(
|
||||
"%s: Got message of type %s: %s", self.log_name, msg_type, msg
|
||||
"%s: Got message of type %s: %s",
|
||||
self.log_name,
|
||||
msg_type.__name__,
|
||||
msg,
|
||||
)
|
||||
|
||||
if self._pong_timer:
|
||||
|
@ -69,7 +69,7 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
|
||||
raise exc
|
||||
|
||||
helper = APIPlaintextFrameHelper(
|
||||
on_pkt=_packet, on_error=_on_error, client_info="my client"
|
||||
on_pkt=_packet, on_error=_on_error, client_info="my client", log_name="test"
|
||||
)
|
||||
|
||||
helper.data_received(in_bytes)
|
||||
@ -116,6 +116,7 @@ async def test_noise_frame_helper_incorrect_key():
|
||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||
expected_name="servicetest",
|
||||
client_info="my client",
|
||||
log_name="test",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
|
||||
@ -155,6 +156,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
|
||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||
expected_name="servicetest",
|
||||
client_info="my client",
|
||||
log_name="test",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
|
||||
@ -196,6 +198,7 @@ async def test_noise_incorrect_name():
|
||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||
expected_name="wrongname",
|
||||
client_info="my client",
|
||||
log_name="test",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
|
||||
|
@ -57,6 +57,7 @@ def _get_mock_protocol(conn: APIConnection):
|
||||
on_pkt=conn._process_packet_factory(),
|
||||
on_error=conn._report_fatal_error,
|
||||
client_info="mock",
|
||||
log_name="mock_device",
|
||||
)
|
||||
protocol._connected_event.set()
|
||||
protocol._transport = MagicMock()
|
||||
|
Loading…
Reference in New Issue
Block a user