mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-12 10:33:57 +01:00
Improve error reporting when encryption is disabled on device but client requests it (#464)
This commit is contained in:
parent
24be8b0666
commit
7a80e3529b
@ -61,12 +61,14 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
"_buffer",
|
||||
"_buffer_len",
|
||||
"_pos",
|
||||
"_client_info",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
on_pkt: Callable[[int, bytes], None],
|
||||
on_error: Callable[[Exception], None],
|
||||
client_info: str,
|
||||
) -> None:
|
||||
"""Initialize the API frame helper."""
|
||||
self._on_pkt = on_pkt
|
||||
@ -76,6 +78,7 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
self._buffer = bytearray()
|
||||
self._buffer_len = 0
|
||||
self._pos = 0
|
||||
self._client_info = client_info
|
||||
|
||||
def _read_exactly(self, length: int) -> Optional[bytearray]:
|
||||
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
|
||||
@ -124,10 +127,9 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
"""Frame helper for plaintext API connections."""
|
||||
|
||||
def write_packet(self, type_: int, data: bytes) -> None:
|
||||
"""Write a packet to the socket, the caller should not have the lock.
|
||||
"""Write a packet to the socket.
|
||||
|
||||
The entire packet must be written in a single call to write
|
||||
to avoid locking.
|
||||
The entire packet must be written in a single call.
|
||||
"""
|
||||
assert self._transport is not None, "Transport should be set"
|
||||
data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data
|
||||
@ -273,9 +275,10 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
on_error: Callable[[Exception], None],
|
||||
noise_psk: str,
|
||||
expected_name: Optional[str],
|
||||
client_info: str,
|
||||
) -> None:
|
||||
"""Initialize the API frame helper."""
|
||||
super().__init__(on_pkt, on_error)
|
||||
super().__init__(on_pkt, on_error, client_info)
|
||||
self._ready_future = asyncio.get_event_loop().create_future()
|
||||
self._noise_psk = noise_psk
|
||||
self._expected_name = expected_name
|
||||
@ -302,11 +305,25 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
self._set_ready_future_exception(exc)
|
||||
super()._handle_error_and_close(exc)
|
||||
|
||||
def _write_frame(self, frame: bytes) -> None:
|
||||
"""Write a packet to the socket, the caller should not have the lock.
|
||||
def _handle_error(self, exc: Exception) -> None:
|
||||
"""Handle an error, and provide a good message when during hello."""
|
||||
if (
|
||||
isinstance(exc, ConnectionResetError)
|
||||
and self._state == NoiseConnectionState.HELLO
|
||||
):
|
||||
original_exc = exc
|
||||
exc = HandshakeAPIError(
|
||||
"The connection dropped immediately after encrypted hello; "
|
||||
"Try enabling encryption on the device or turning off "
|
||||
f"encryption on the client ({self._client_info})."
|
||||
)
|
||||
exc.__cause__ = original_exc
|
||||
super()._handle_error(exc)
|
||||
|
||||
The entire packet must be written in a single call to write
|
||||
to avoid locking.
|
||||
def _write_frame(self, frame: bytes) -> None:
|
||||
"""Write a packet to the socket.
|
||||
|
||||
The entire packet must be written in a single call to write.
|
||||
"""
|
||||
assert self._transport is not None, "Transport is not set"
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
@ -371,7 +388,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
self._write_frame(b"") # ClientHello
|
||||
|
||||
def _handle_hello(self, server_hello: bytearray) -> None:
|
||||
"""Perform the handshake with the server, the caller is responsible for having the lock."""
|
||||
"""Perform the handshake with the server."""
|
||||
if not server_hello:
|
||||
self._handle_error_and_close(HandshakeAPIError("ServerHello is empty"))
|
||||
return
|
||||
|
@ -301,6 +301,7 @@ class APIConnection:
|
||||
lambda: APIPlaintextFrameHelper(
|
||||
on_pkt=self._process_packet,
|
||||
on_error=self._report_fatal_error,
|
||||
client_info=self._params.client_info,
|
||||
),
|
||||
sock=self._socket,
|
||||
)
|
||||
@ -311,6 +312,7 @@ class APIConnection:
|
||||
expected_name=self._params.expected_name,
|
||||
on_pkt=self._process_packet,
|
||||
on_error=self._report_fatal_error,
|
||||
client_info=self._params.client_info,
|
||||
),
|
||||
sock=self._socket,
|
||||
)
|
||||
|
@ -68,7 +68,9 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
|
||||
def _on_error(exc: Exception):
|
||||
raise exc
|
||||
|
||||
helper = APIPlaintextFrameHelper(on_pkt=_packet, on_error=_on_error)
|
||||
helper = APIPlaintextFrameHelper(
|
||||
on_pkt=_packet, on_error=_on_error, client_info="my client"
|
||||
)
|
||||
|
||||
helper.data_received(in_bytes)
|
||||
|
||||
@ -113,6 +115,7 @@ async def test_noise_frame_helper_incorrect_key():
|
||||
on_error=_on_error,
|
||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||
expected_name="servicetest",
|
||||
client_info="my client",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
|
||||
@ -151,6 +154,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
|
||||
on_error=_on_error,
|
||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||
expected_name="servicetest",
|
||||
client_info="my client",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
|
||||
@ -191,6 +195,7 @@ async def test_noise_incorrect_name():
|
||||
on_error=_on_error,
|
||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||
expected_name="wrongname",
|
||||
client_info="my client",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
|
||||
|
@ -54,7 +54,9 @@ def socket_socket():
|
||||
|
||||
def _get_mock_protocol(conn: APIConnection):
|
||||
protocol = APIPlaintextFrameHelper(
|
||||
on_pkt=conn._process_packet, on_error=conn._report_fatal_error
|
||||
on_pkt=conn._process_packet,
|
||||
on_error=conn._report_fatal_error,
|
||||
client_info="mock",
|
||||
)
|
||||
protocol._connected_event.set()
|
||||
protocol._transport = MagicMock()
|
||||
|
Loading…
Reference in New Issue
Block a user