Fix pong timer warning when pending ping is skipped (#483)

This commit is contained in:
J. Nick Koston 2023-07-17 09:27:59 -10:00 committed by GitHub
parent 6aeea79884
commit e909891ebe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 88 additions and 54 deletions

View File

@ -70,6 +70,7 @@ class APIFrameHelper(asyncio.Protocol):
"_buffer_len",
"_pos",
"_client_info",
"_log_name",
)
def __init__(
@ -77,6 +78,7 @@ class APIFrameHelper(asyncio.Protocol):
on_pkt: Callable[[int, bytes], None],
on_error: Callable[[Exception], None],
client_info: str,
log_name: str,
) -> None:
"""Initialize the API frame helper."""
self._on_pkt = on_pkt
@ -87,6 +89,7 @@ class APIFrameHelper(asyncio.Protocol):
self._buffer_len = 0
self._pos = 0
self._client_info = client_info
self._log_name = log_name
def _read_exactly(self, length: int) -> Optional[bytearray]:
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
@ -118,11 +121,13 @@ class APIFrameHelper(asyncio.Protocol):
self._on_error(exc)
def connection_lost(self, exc: Optional[Exception]) -> None:
self._handle_error(exc or SocketClosedAPIError("Connection lost"))
self._handle_error(
exc or SocketClosedAPIError(f"{self._log_name}: Connection lost")
)
return super().connection_lost(exc)
def eof_received(self) -> Optional[bool]:
self._handle_error(SocketClosedAPIError("EOF received"))
self._handle_error(SocketClosedAPIError(f"{self._log_name}: EOF received"))
return super().eof_received()
def close(self) -> None:
@ -142,12 +147,14 @@ class APIPlaintextFrameHelper(APIFrameHelper):
assert self._transport is not None, "Transport should be set"
data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug("Sending plaintext frame %s", data.hex())
_LOGGER.debug("%s: Sending plaintext frame %s", self._log_name, data.hex())
try:
self._transport.write(data)
except (RuntimeError, ConnectionResetError, OSError) as err:
raise SocketAPIError(f"Error while writing data: {err}") from err
raise SocketAPIError(
f"{self._log_name}: Error while writing data: {err}"
) from err
async def perform_handshake(self) -> None:
"""Perform the handshake."""
@ -170,11 +177,15 @@ class APIPlaintextFrameHelper(APIFrameHelper):
if preamble != 0x00:
if preamble == 0x01:
self._handle_error_and_close(
RequiresEncryptionAPIError("Connection requires encryption")
RequiresEncryptionAPIError(
f"{self._log_name}: Connection requires encryption"
)
)
return
self._handle_error_and_close(
ProtocolAPIError(f"Invalid preamble {preamble:02x}")
ProtocolAPIError(
f"{self._log_name}: Invalid preamble {preamble:02x}"
)
)
return
@ -285,9 +296,10 @@ class APINoiseFrameHelper(APIFrameHelper):
noise_psk: str,
expected_name: Optional[str],
client_info: str,
log_name: str,
) -> None:
"""Initialize the API frame helper."""
super().__init__(on_pkt, on_error, client_info)
super().__init__(on_pkt, on_error, client_info, log_name)
self._ready_future = asyncio.get_event_loop().create_future()
self._noise_psk = noise_psk
self._expected_name = expected_name
@ -311,7 +323,9 @@ class APINoiseFrameHelper(APIFrameHelper):
# Make sure we set the ready event if its not already set
# so that we don't block forever on the ready event if we
# are waiting for the handshake to complete.
self._set_ready_future_exception(APIConnectionError("Connection closed"))
self._set_ready_future_exception(
APIConnectionError(f"{self._log_name}: Connection closed")
)
self._set_state(NoiseConnectionState.CLOSED)
super().close()
@ -327,7 +341,7 @@ class APINoiseFrameHelper(APIFrameHelper):
):
original_exc = exc
exc = HandshakeAPIError(
"The connection dropped immediately after encrypted hello; "
f"{self._log_name}: The connection dropped immediately after encrypted hello; "
"Try enabling encryption on the device or turning off "
f"encryption on the client ({self._client_info})."
)
@ -341,7 +355,7 @@ class APINoiseFrameHelper(APIFrameHelper):
"""
assert self._transport is not None, "Transport is not set"
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug("Sending frame: [%s]", frame.hex())
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex())
frame_len = len(frame)
try:
@ -354,7 +368,9 @@ class APINoiseFrameHelper(APIFrameHelper):
)
self._transport.write(header + frame)
except (RuntimeError, ConnectionResetError, OSError) as err:
raise SocketAPIError(f"Error while writing data: {err}") from err
raise SocketAPIError(
f"{self._log_name}: Error while writing data: {err}"
) from err
async def perform_handshake(self) -> None:
"""Perform the handshake with the server."""
@ -363,7 +379,9 @@ class APINoiseFrameHelper(APIFrameHelper):
async with async_timeout.timeout(60.0):
await self._ready_future
except asyncio.TimeoutError as err:
raise HandshakeAPIError("Timeout during handshake") from err
raise HandshakeAPIError(
f"{self._log_name}: Timeout during handshake"
) from err
def data_received(self, data: bytes) -> None:
self._buffer += data
@ -376,7 +394,9 @@ class APINoiseFrameHelper(APIFrameHelper):
preamble, msg_size_high, msg_size_low = header
if preamble != 0x01:
self._handle_error_and_close(
ProtocolAPIError(f"Marker byte invalid: {header[0]}")
ProtocolAPIError(
f"{self._log_name}: Marker byte invalid: {header[0]}"
)
)
return
frame = self._read_exactly((msg_size_high << 8) | msg_size_low)
@ -404,7 +424,9 @@ class APINoiseFrameHelper(APIFrameHelper):
def _handle_hello(self, server_hello: bytearray) -> None:
"""Perform the handshake with the server."""
if not server_hello:
self._handle_error_and_close(HandshakeAPIError("ServerHello is empty"))
self._handle_error_and_close(
HandshakeAPIError(f"{self._log_name}: ServerHello is empty")
)
return
# First byte of server hello is the protocol the server chose
@ -413,7 +435,9 @@ class APINoiseFrameHelper(APIFrameHelper):
chosen_proto = server_hello[0]
if chosen_proto != 0x01:
self._handle_error_and_close(
HandshakeAPIError(f"Unknown protocol selected by client {chosen_proto}")
HandshakeAPIError(
f"{self._log_name}: Unknown protocol selected by client {chosen_proto}"
)
)
return
@ -429,7 +453,8 @@ class APINoiseFrameHelper(APIFrameHelper):
if self._expected_name is not None and self._expected_name != server_name:
self._handle_error_and_close(
BadNameAPIError(
f"Server sent a different name '{server_name}'", server_name
f"{self._log_name}: Server sent a different name '{server_name}'",
server_name,
)
)
return
@ -458,19 +483,19 @@ class APINoiseFrameHelper(APIFrameHelper):
if explanation == "Handshake MAC failure":
self._handle_error_and_close(
InvalidEncryptionKeyAPIError(
"Invalid encryption key", self._server_name
f"{self._log_name}: Invalid encryption key", self._server_name
)
)
return
self._handle_error_and_close(
HandshakeAPIError(f"Handshake failure: {explanation}")
HandshakeAPIError(f"{self._log_name}: Handshake failure: {explanation}")
)
return
try:
self._proto.read_message(msg[1:])
except InvalidTag as invalid_tag_exc:
ex = InvalidEncryptionKeyAPIError(
"Invalid encryption key", self._server_name
f"{self._log_name}: Invalid encryption key", self._server_name
)
ex.__cause__ = invalid_tag_exc
self._handle_error_and_close(ex)
@ -491,7 +516,7 @@ class APINoiseFrameHelper(APIFrameHelper):
def write_packet(self, type_: int, data: bytes) -> None:
"""Write a packet to the socket."""
if self._state != NoiseConnectionState.READY:
raise HandshakeAPIError("Noise connection is not ready")
raise HandshakeAPIError(f"{self._log_name}: Noise connection is not ready")
if TYPE_CHECKING:
assert self._encrypt is not None, "Handshake should be complete"
data_len = len(data)
@ -517,7 +542,7 @@ class APINoiseFrameHelper(APIFrameHelper):
msg = self._decrypt(bytes(frame))
except InvalidTag as ex:
self._handle_error_and_close(
ProtocolAPIError(f"Bad encryption frame: {ex!r}")
ProtocolAPIError(f"{self._log_name}: Bad encryption frame: {ex!r}")
)
return
# Message layout is
@ -530,7 +555,7 @@ class APINoiseFrameHelper(APIFrameHelper):
self, frame: bytearray
) -> None:
"""Handle a closed frame."""
self._handle_error(ProtocolAPIError("Connection closed"))
self._handle_error(ProtocolAPIError(f"{self._log_name}: Connection closed"))
STATE_TO_CALLABLE = {
NoiseConnectionState.HELLO: _handle_hello,

View File

@ -316,6 +316,7 @@ class APIConnection:
on_pkt=process_packet,
on_error=self._report_fatal_error,
client_info=self._params.client_info,
log_name=self.log_name,
),
sock=self._socket,
)
@ -327,6 +328,7 @@ class APIConnection:
on_pkt=process_packet,
on_error=self._report_fatal_error,
client_info=self._params.client_info,
log_name=self.log_name,
),
sock=self._socket,
)
@ -391,30 +393,29 @@ class APIConnection:
if self._send_pending_ping:
self.send_message(PING_REQUEST_MESSAGE)
if self._pong_timer is None:
# Do not reset the timer if it's already set
# since the only thing we want to reset the timer
# is if we receive a pong.
self._pong_timer = self._loop.call_later(
self._keep_alive_timeout, self._async_pong_not_received
)
else:
#
# We haven't reached the ping response (pong) timeout yet
# and we haven't seen a response to the last ping
#
# We send another ping in case the device has
# rebooted and dropped the connection without telling
# us to force a TCP RST aka connection reset by peer.
#
_LOGGER.debug(
"%s: PingResponse (pong) was not received "
"since last keep alive after %s seconds; "
"rescheduling keep alive",
self.log_name,
self._keep_alive_interval,
)
if self._pong_timer is None:
# Do not reset the timer if it's already set
# since the only thing we want to reset the timer
# is if we receive a pong.
self._pong_timer = self._loop.call_later(
self._keep_alive_timeout, self._async_pong_not_received
)
else:
#
# We haven't reached the ping response (pong) timeout yet
# and we haven't seen a response to the last ping
#
# We send another ping in case the device has
# rebooted and dropped the connection without telling
# us to force a TCP RST aka connection reset by peer.
#
_LOGGER.debug(
"%s: PingResponse (pong) was not received "
"since last keep alive after %s seconds; "
"rescheduling keep alive",
self.log_name,
self._keep_alive_interval,
)
self._async_schedule_keep_alive()
@ -540,15 +541,14 @@ class APIConnection:
frame_helper = self._frame_helper
assert frame_helper is not None
message_type = PROTO_TO_MESSAGE_TYPE.get(type(msg))
_msg_type = type(msg)
message_type = PROTO_TO_MESSAGE_TYPE.get(_msg_type)
if not message_type:
raise ValueError(f"Message type id not found for type {type(msg)}")
raise ValueError(f"Message type id not found for type {_msg_type}")
encoded = msg.SerializeToString()
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug(
"%s: Sending %s: %s", self._params.address, type(msg), str(msg)
)
_LOGGER.debug("%s: Sending %s: %s", self.log_name, _msg_type.__name__, msg)
try:
frame_helper.write_packet(message_type, encoded)
@ -730,7 +730,9 @@ class APIConnection:
class_ = message_type_to_proto[msg_type_proto]
except KeyError:
_LOGGER.debug(
"%s: Skipping message type %s", self.log_name, msg_type_proto
"%s: Skipping message type %s",
self.log_name,
msg_type_proto,
)
return
@ -760,7 +762,10 @@ class APIConnection:
if is_enabled_for(logging_debug):
_LOGGER.debug(
"%s: Got message of type %s: %s", self.log_name, msg_type, msg
"%s: Got message of type %s: %s",
self.log_name,
msg_type.__name__,
msg,
)
if self._pong_timer:

View File

@ -69,7 +69,7 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
raise exc
helper = APIPlaintextFrameHelper(
on_pkt=_packet, on_error=_on_error, client_info="my client"
on_pkt=_packet, on_error=_on_error, client_info="my client", log_name="test"
)
helper.data_received(in_bytes)
@ -116,6 +116,7 @@ async def test_noise_frame_helper_incorrect_key():
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="servicetest",
client_info="my client",
log_name="test",
)
helper._transport = MagicMock()
@ -155,6 +156,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="servicetest",
client_info="my client",
log_name="test",
)
helper._transport = MagicMock()
@ -196,6 +198,7 @@ async def test_noise_incorrect_name():
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="wrongname",
client_info="my client",
log_name="test",
)
helper._transport = MagicMock()

View File

@ -57,6 +57,7 @@ def _get_mock_protocol(conn: APIConnection):
on_pkt=conn._process_packet_factory(),
on_error=conn._report_fatal_error,
client_info="mock",
log_name="mock_device",
)
protocol._connected_event.set()
protocol._transport = MagicMock()