Improve error reporting when encryption is disabled on device but client requests it (#464)

This commit is contained in:
J. Nick Koston 2023-07-10 21:15:14 -10:00 committed by GitHub
parent 24be8b0666
commit 7a80e3529b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 37 additions and 11 deletions

View File

@ -61,12 +61,14 @@ class APIFrameHelper(asyncio.Protocol):
"_buffer",
"_buffer_len",
"_pos",
"_client_info",
)
def __init__(
self,
on_pkt: Callable[[int, bytes], None],
on_error: Callable[[Exception], None],
client_info: str,
) -> None:
"""Initialize the API frame helper."""
self._on_pkt = on_pkt
@ -76,6 +78,7 @@ class APIFrameHelper(asyncio.Protocol):
self._buffer = bytearray()
self._buffer_len = 0
self._pos = 0
self._client_info = client_info
def _read_exactly(self, length: int) -> Optional[bytearray]:
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
@ -124,10 +127,9 @@ class APIPlaintextFrameHelper(APIFrameHelper):
"""Frame helper for plaintext API connections."""
def write_packet(self, type_: int, data: bytes) -> None:
"""Write a packet to the socket, the caller should not have the lock.
"""Write a packet to the socket.
The entire packet must be written in a single call to write
to avoid locking.
The entire packet must be written in a single call.
"""
assert self._transport is not None, "Transport should be set"
data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data
@ -273,9 +275,10 @@ class APINoiseFrameHelper(APIFrameHelper):
on_error: Callable[[Exception], None],
noise_psk: str,
expected_name: Optional[str],
client_info: str,
) -> None:
"""Initialize the API frame helper."""
super().__init__(on_pkt, on_error)
super().__init__(on_pkt, on_error, client_info)
self._ready_future = asyncio.get_event_loop().create_future()
self._noise_psk = noise_psk
self._expected_name = expected_name
@ -302,11 +305,25 @@ class APINoiseFrameHelper(APIFrameHelper):
self._set_ready_future_exception(exc)
super()._handle_error_and_close(exc)
def _write_frame(self, frame: bytes) -> None:
"""Write a packet to the socket, the caller should not have the lock.
def _handle_error(self, exc: Exception) -> None:
"""Handle an error, and provide a good message when during hello."""
if (
isinstance(exc, ConnectionResetError)
and self._state == NoiseConnectionState.HELLO
):
original_exc = exc
exc = HandshakeAPIError(
"The connection dropped immediately after encrypted hello; "
"Try enabling encryption on the device or turning off "
f"encryption on the client ({self._client_info})."
)
exc.__cause__ = original_exc
super()._handle_error(exc)
The entire packet must be written in a single call to write
to avoid locking.
def _write_frame(self, frame: bytes) -> None:
"""Write a packet to the socket.
The entire packet must be written in a single call to write.
"""
assert self._transport is not None, "Transport is not set"
if _LOGGER.isEnabledFor(logging.DEBUG):
@ -371,7 +388,7 @@ class APINoiseFrameHelper(APIFrameHelper):
self._write_frame(b"") # ClientHello
def _handle_hello(self, server_hello: bytearray) -> None:
"""Perform the handshake with the server, the caller is responsible for having the lock."""
"""Perform the handshake with the server."""
if not server_hello:
self._handle_error_and_close(HandshakeAPIError("ServerHello is empty"))
return

View File

@ -301,6 +301,7 @@ class APIConnection:
lambda: APIPlaintextFrameHelper(
on_pkt=self._process_packet,
on_error=self._report_fatal_error,
client_info=self._params.client_info,
),
sock=self._socket,
)
@ -311,6 +312,7 @@ class APIConnection:
expected_name=self._params.expected_name,
on_pkt=self._process_packet,
on_error=self._report_fatal_error,
client_info=self._params.client_info,
),
sock=self._socket,
)

View File

@ -68,7 +68,9 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
def _on_error(exc: Exception):
raise exc
helper = APIPlaintextFrameHelper(on_pkt=_packet, on_error=_on_error)
helper = APIPlaintextFrameHelper(
on_pkt=_packet, on_error=_on_error, client_info="my client"
)
helper.data_received(in_bytes)
@ -113,6 +115,7 @@ async def test_noise_frame_helper_incorrect_key():
on_error=_on_error,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="servicetest",
client_info="my client",
)
helper._transport = MagicMock()
@ -151,6 +154,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
on_error=_on_error,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="servicetest",
client_info="my client",
)
helper._transport = MagicMock()
@ -191,6 +195,7 @@ async def test_noise_incorrect_name():
on_error=_on_error,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="wrongname",
client_info="my client",
)
helper._transport = MagicMock()

View File

@ -54,7 +54,9 @@ def socket_socket():
def _get_mock_protocol(conn: APIConnection):
protocol = APIPlaintextFrameHelper(
on_pkt=conn._process_packet, on_error=conn._report_fatal_error
on_pkt=conn._process_packet,
on_error=conn._report_fatal_error,
client_info="mock",
)
protocol._connected_event.set()
protocol._transport = MagicMock()