mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-26 12:45:26 +01:00
Improve error reporting when encryption is disabled on device but client requests it (#464)
This commit is contained in:
parent
24be8b0666
commit
7a80e3529b
@ -61,12 +61,14 @@ class APIFrameHelper(asyncio.Protocol):
|
|||||||
"_buffer",
|
"_buffer",
|
||||||
"_buffer_len",
|
"_buffer_len",
|
||||||
"_pos",
|
"_pos",
|
||||||
|
"_client_info",
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
on_pkt: Callable[[int, bytes], None],
|
on_pkt: Callable[[int, bytes], None],
|
||||||
on_error: Callable[[Exception], None],
|
on_error: Callable[[Exception], None],
|
||||||
|
client_info: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the API frame helper."""
|
"""Initialize the API frame helper."""
|
||||||
self._on_pkt = on_pkt
|
self._on_pkt = on_pkt
|
||||||
@ -76,6 +78,7 @@ class APIFrameHelper(asyncio.Protocol):
|
|||||||
self._buffer = bytearray()
|
self._buffer = bytearray()
|
||||||
self._buffer_len = 0
|
self._buffer_len = 0
|
||||||
self._pos = 0
|
self._pos = 0
|
||||||
|
self._client_info = client_info
|
||||||
|
|
||||||
def _read_exactly(self, length: int) -> Optional[bytearray]:
|
def _read_exactly(self, length: int) -> Optional[bytearray]:
|
||||||
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
|
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
|
||||||
@ -124,10 +127,9 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
"""Frame helper for plaintext API connections."""
|
"""Frame helper for plaintext API connections."""
|
||||||
|
|
||||||
def write_packet(self, type_: int, data: bytes) -> None:
|
def write_packet(self, type_: int, data: bytes) -> None:
|
||||||
"""Write a packet to the socket, the caller should not have the lock.
|
"""Write a packet to the socket.
|
||||||
|
|
||||||
The entire packet must be written in a single call to write
|
The entire packet must be written in a single call.
|
||||||
to avoid locking.
|
|
||||||
"""
|
"""
|
||||||
assert self._transport is not None, "Transport should be set"
|
assert self._transport is not None, "Transport should be set"
|
||||||
data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data
|
data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data
|
||||||
@ -273,9 +275,10 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
on_error: Callable[[Exception], None],
|
on_error: Callable[[Exception], None],
|
||||||
noise_psk: str,
|
noise_psk: str,
|
||||||
expected_name: Optional[str],
|
expected_name: Optional[str],
|
||||||
|
client_info: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the API frame helper."""
|
"""Initialize the API frame helper."""
|
||||||
super().__init__(on_pkt, on_error)
|
super().__init__(on_pkt, on_error, client_info)
|
||||||
self._ready_future = asyncio.get_event_loop().create_future()
|
self._ready_future = asyncio.get_event_loop().create_future()
|
||||||
self._noise_psk = noise_psk
|
self._noise_psk = noise_psk
|
||||||
self._expected_name = expected_name
|
self._expected_name = expected_name
|
||||||
@ -302,11 +305,25 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
self._set_ready_future_exception(exc)
|
self._set_ready_future_exception(exc)
|
||||||
super()._handle_error_and_close(exc)
|
super()._handle_error_and_close(exc)
|
||||||
|
|
||||||
def _write_frame(self, frame: bytes) -> None:
|
def _handle_error(self, exc: Exception) -> None:
|
||||||
"""Write a packet to the socket, the caller should not have the lock.
|
"""Handle an error, and provide a good message when during hello."""
|
||||||
|
if (
|
||||||
|
isinstance(exc, ConnectionResetError)
|
||||||
|
and self._state == NoiseConnectionState.HELLO
|
||||||
|
):
|
||||||
|
original_exc = exc
|
||||||
|
exc = HandshakeAPIError(
|
||||||
|
"The connection dropped immediately after encrypted hello; "
|
||||||
|
"Try enabling encryption on the device or turning off "
|
||||||
|
f"encryption on the client ({self._client_info})."
|
||||||
|
)
|
||||||
|
exc.__cause__ = original_exc
|
||||||
|
super()._handle_error(exc)
|
||||||
|
|
||||||
The entire packet must be written in a single call to write
|
def _write_frame(self, frame: bytes) -> None:
|
||||||
to avoid locking.
|
"""Write a packet to the socket.
|
||||||
|
|
||||||
|
The entire packet must be written in a single call to write.
|
||||||
"""
|
"""
|
||||||
assert self._transport is not None, "Transport is not set"
|
assert self._transport is not None, "Transport is not set"
|
||||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||||
@ -371,7 +388,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
self._write_frame(b"") # ClientHello
|
self._write_frame(b"") # ClientHello
|
||||||
|
|
||||||
def _handle_hello(self, server_hello: bytearray) -> None:
|
def _handle_hello(self, server_hello: bytearray) -> None:
|
||||||
"""Perform the handshake with the server, the caller is responsible for having the lock."""
|
"""Perform the handshake with the server."""
|
||||||
if not server_hello:
|
if not server_hello:
|
||||||
self._handle_error_and_close(HandshakeAPIError("ServerHello is empty"))
|
self._handle_error_and_close(HandshakeAPIError("ServerHello is empty"))
|
||||||
return
|
return
|
||||||
|
@ -301,6 +301,7 @@ class APIConnection:
|
|||||||
lambda: APIPlaintextFrameHelper(
|
lambda: APIPlaintextFrameHelper(
|
||||||
on_pkt=self._process_packet,
|
on_pkt=self._process_packet,
|
||||||
on_error=self._report_fatal_error,
|
on_error=self._report_fatal_error,
|
||||||
|
client_info=self._params.client_info,
|
||||||
),
|
),
|
||||||
sock=self._socket,
|
sock=self._socket,
|
||||||
)
|
)
|
||||||
@ -311,6 +312,7 @@ class APIConnection:
|
|||||||
expected_name=self._params.expected_name,
|
expected_name=self._params.expected_name,
|
||||||
on_pkt=self._process_packet,
|
on_pkt=self._process_packet,
|
||||||
on_error=self._report_fatal_error,
|
on_error=self._report_fatal_error,
|
||||||
|
client_info=self._params.client_info,
|
||||||
),
|
),
|
||||||
sock=self._socket,
|
sock=self._socket,
|
||||||
)
|
)
|
||||||
|
@ -68,7 +68,9 @@ async def test_plaintext_frame_helper(in_bytes, pkt_data, pkt_type):
|
|||||||
def _on_error(exc: Exception):
|
def _on_error(exc: Exception):
|
||||||
raise exc
|
raise exc
|
||||||
|
|
||||||
helper = APIPlaintextFrameHelper(on_pkt=_packet, on_error=_on_error)
|
helper = APIPlaintextFrameHelper(
|
||||||
|
on_pkt=_packet, on_error=_on_error, client_info="my client"
|
||||||
|
)
|
||||||
|
|
||||||
helper.data_received(in_bytes)
|
helper.data_received(in_bytes)
|
||||||
|
|
||||||
@ -113,6 +115,7 @@ async def test_noise_frame_helper_incorrect_key():
|
|||||||
on_error=_on_error,
|
on_error=_on_error,
|
||||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||||
expected_name="servicetest",
|
expected_name="servicetest",
|
||||||
|
client_info="my client",
|
||||||
)
|
)
|
||||||
helper._transport = MagicMock()
|
helper._transport = MagicMock()
|
||||||
|
|
||||||
@ -151,6 +154,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
|
|||||||
on_error=_on_error,
|
on_error=_on_error,
|
||||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||||
expected_name="servicetest",
|
expected_name="servicetest",
|
||||||
|
client_info="my client",
|
||||||
)
|
)
|
||||||
helper._transport = MagicMock()
|
helper._transport = MagicMock()
|
||||||
|
|
||||||
@ -191,6 +195,7 @@ async def test_noise_incorrect_name():
|
|||||||
on_error=_on_error,
|
on_error=_on_error,
|
||||||
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
|
||||||
expected_name="wrongname",
|
expected_name="wrongname",
|
||||||
|
client_info="my client",
|
||||||
)
|
)
|
||||||
helper._transport = MagicMock()
|
helper._transport = MagicMock()
|
||||||
|
|
||||||
|
@ -54,7 +54,9 @@ def socket_socket():
|
|||||||
|
|
||||||
def _get_mock_protocol(conn: APIConnection):
|
def _get_mock_protocol(conn: APIConnection):
|
||||||
protocol = APIPlaintextFrameHelper(
|
protocol = APIPlaintextFrameHelper(
|
||||||
on_pkt=conn._process_packet, on_error=conn._report_fatal_error
|
on_pkt=conn._process_packet,
|
||||||
|
on_error=conn._report_fatal_error,
|
||||||
|
client_info="mock",
|
||||||
)
|
)
|
||||||
protocol._connected_event.set()
|
protocol._connected_event.set()
|
||||||
protocol._transport = MagicMock()
|
protocol._transport = MagicMock()
|
||||||
|
Loading…
Reference in New Issue
Block a user