Improve error reporting when encryption is disabled on device but client requests it (#464)

This commit is contained in:
J. Nick Koston 2023-07-10 21:15:14 -10:00 committed by GitHub
parent 24be8b0666
commit 7a80e3529b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 37 additions and 11 deletions

View File

@ -61,12 +61,14 @@ class APIFrameHelper(asyncio.Protocol):
"_buffer", "_buffer",
"_buffer_len", "_buffer_len",
"_pos", "_pos",
"_client_info",
) )
def __init__( def __init__(
self, self,
on_pkt: Callable[[int, bytes], None], on_pkt: Callable[[int, bytes], None],
on_error: Callable[[Exception], None], on_error: Callable[[Exception], None],
client_info: str,
) -> None: ) -> None:
"""Initialize the API frame helper.""" """Initialize the API frame helper."""
self._on_pkt = on_pkt self._on_pkt = on_pkt
@ -76,6 +78,7 @@ class APIFrameHelper(asyncio.Protocol):
self._buffer = bytearray() self._buffer = bytearray()
self._buffer_len = 0 self._buffer_len = 0
self._pos = 0 self._pos = 0
self._client_info = client_info
def _read_exactly(self, length: int) -> Optional[bytearray]: def _read_exactly(self, length: int) -> Optional[bytearray]:
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available.""" """Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
@ -124,10 +127,9 @@ class APIPlaintextFrameHelper(APIFrameHelper):
"""Frame helper for plaintext API connections.""" """Frame helper for plaintext API connections."""
def write_packet(self, type_: int, data: bytes) -> None: def write_packet(self, type_: int, data: bytes) -> None:
"""Write a packet to the socket, the caller should not have the lock. """Write a packet to the socket.
The entire packet must be written in a single call to write The entire packet must be written in a single call.
to avoid locking.
""" """
assert self._transport is not None, "Transport should be set" assert self._transport is not None, "Transport should be set"
data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data
@ -273,9 +275,10 @@ class APINoiseFrameHelper(APIFrameHelper):
on_error: Callable[[Exception], None], on_error: Callable[[Exception], None],
noise_psk: str, noise_psk: str,
expected_name: Optional[str], expected_name: Optional[str],
client_info: str,
) -> None: ) -> None:
"""Initialize the API frame helper.""" """Initialize the API frame helper."""
super().__init__(on_pkt, on_error) super().__init__(on_pkt, on_error, client_info)
self._ready_future = asyncio.get_event_loop().create_future() self._ready_future = asyncio.get_event_loop().create_future()
self._noise_psk = noise_psk self._noise_psk = noise_psk
self._expected_name = expected_name self._expected_name = expected_name
@ -302,11 +305,25 @@ class APINoiseFrameHelper(APIFrameHelper):
self._set_ready_future_exception(exc) self._set_ready_future_exception(exc)
super()._handle_error_and_close(exc) super()._handle_error_and_close(exc)
def _write_frame(self, frame: bytes) -> None: def _handle_error(self, exc: Exception) -> None:
"""Write a packet to the socket, the caller should not have the lock. """Handle an error, and provide a good message when during hello."""
if (
isinstance(exc, ConnectionResetError)
and self._state == NoiseConnectionState.HELLO
):
original_exc = exc
exc = HandshakeAPIError(
"The connection dropped immediately after encrypted hello; "
"Try enabling encryption on the device or turning off "
f"encryption on the client ({self._client_info})."
)
exc.__cause__ = original_exc
super()._handle_error(exc)
The entire packet must be written in a single call to write def _write_frame(self, frame: bytes) -> None:
to avoid locking. """Write a packet to the socket.
The entire packet must be written in a single call to write.
""" """
assert self._transport is not None, "Transport is not set" assert self._transport is not None, "Transport is not set"
if _LOGGER.isEnabledFor(logging.DEBUG): if _LOGGER.isEnabledFor(logging.DEBUG):
@ -371,7 +388,7 @@ class APINoiseFrameHelper(APIFrameHelper):
self._write_frame(b"") # ClientHello self._write_frame(b"") # ClientHello
def _handle_hello(self, server_hello: bytearray) -> None: def _handle_hello(self, server_hello: bytearray) -> None:
"""Perform the handshake with the server, the caller is responsible for having the lock.""" """Perform the handshake with the server."""
if not server_hello: if not server_hello:
self._handle_error_and_close(HandshakeAPIError("ServerHello is empty")) self._handle_error_and_close(HandshakeAPIError("ServerHello is empty"))
return return

View File

@ -301,6 +301,7 @@ class APIConnection:
lambda: APIPlaintextFrameHelper( lambda: APIPlaintextFrameHelper(
on_pkt=self._process_packet, on_pkt=self._process_packet,
on_error=self._report_fatal_error, on_error=self._report_fatal_error,
client_info=self._params.client_info,
), ),
sock=self._socket, sock=self._socket,
) )
@ -311,6 +312,7 @@ class APIConnection:
expected_name=self._params.expected_name, expected_name=self._params.expected_name,
on_pkt=self._process_packet, on_pkt=self._process_packet,
on_error=self._report_fatal_error, on_error=self._report_fatal_error,
client_info=self._params.client_info,
), ),
sock=self._socket, sock=self._socket,
) )

View File

@ -68,7 +68,9 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
def _on_error(exc: Exception): def _on_error(exc: Exception):
raise exc raise exc
helper = APIPlaintextFrameHelper(on_pkt=_packet, on_error=_on_error) helper = APIPlaintextFrameHelper(
on_pkt=_packet, on_error=_on_error, client_info="my client"
)
helper.data_received(in_bytes) helper.data_received(in_bytes)
@ -113,6 +115,7 @@ async def test_noise_frame_helper_incorrect_key():
on_error=_on_error, on_error=_on_error,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="servicetest", expected_name="servicetest",
client_info="my client",
) )
helper._transport = MagicMock() helper._transport = MagicMock()
@ -151,6 +154,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
on_error=_on_error, on_error=_on_error,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="servicetest", expected_name="servicetest",
client_info="my client",
) )
helper._transport = MagicMock() helper._transport = MagicMock()
@ -191,6 +195,7 @@ async def test_noise_incorrect_name():
on_error=_on_error, on_error=_on_error,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
expected_name="wrongname", expected_name="wrongname",
client_info="my client",
) )
helper._transport = MagicMock() helper._transport = MagicMock()

View File

@ -54,7 +54,9 @@ def socket_socket():
def _get_mock_protocol(conn: APIConnection): def _get_mock_protocol(conn: APIConnection):
protocol = APIPlaintextFrameHelper( protocol = APIPlaintextFrameHelper(
on_pkt=conn._process_packet, on_error=conn._report_fatal_error on_pkt=conn._process_packet,
on_error=conn._report_fatal_error,
client_info="mock",
) )
protocol._connected_event.set() protocol._connected_event.set()
protocol._transport = MagicMock() protocol._transport = MagicMock()