diff --git a/aioesphomeapi/_frame_helper.py b/aioesphomeapi/_frame_helper.py index 4eb4e88..223419a 100644 --- a/aioesphomeapi/_frame_helper.py +++ b/aioesphomeapi/_frame_helper.py @@ -5,9 +5,8 @@ from abc import abstractmethod from enum import Enum from functools import partial from struct import Struct -from typing import TYPE_CHECKING, Any, Callable, Optional, cast +from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast -import async_timeout from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable from cryptography.exceptions import InvalidTag from noise.backends.default import DefaultNoiseBackend # type: ignore[import] @@ -65,6 +64,7 @@ class APIFrameHelper(asyncio.Protocol): "_on_pkt", "_on_error", "_transport", + "_writer", "_connected_event", "_buffer", "_buffer_len", @@ -84,6 +84,9 @@ class APIFrameHelper(asyncio.Protocol): self._on_pkt = on_pkt self._on_error = on_error self._transport: Optional[asyncio.Transport] = None + self._writer: Optional[ + Callable[[Union[bytes, bytearray, memoryview]], None] + ] = None self._connected_event = asyncio.Event() self._buffer = bytearray() self._buffer_len = 0 @@ -111,6 +114,7 @@ class APIFrameHelper(asyncio.Protocol): def connection_made(self, transport: asyncio.BaseTransport) -> None: """Handle a new connection.""" self._transport = cast(asyncio.Transport, transport) + self._writer = self._transport.write self._connected_event.set() def _handle_error_and_close(self, exc: Exception) -> None: @@ -134,6 +138,8 @@ class APIFrameHelper(asyncio.Protocol): """Close the connection.""" if self._transport: self._transport.close() + self._transport = None + self._writer = None class APIPlaintextFrameHelper(APIFrameHelper): @@ -144,13 +150,15 @@ class APIPlaintextFrameHelper(APIFrameHelper): The entire packet must be written in a single call. """ - assert self._transport is not None, "Transport should be set" + if TYPE_CHECKING: + assert self._writer is not None, "Writer should be set" + data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data if _LOGGER.isEnabledFor(logging.DEBUG): _LOGGER.debug("%s: Sending plaintext frame %s", self._log_name, data.hex()) try: - self._transport.write(data) + self._writer(data) except (RuntimeError, ConnectionResetError, OSError) as err: raise SocketAPIError( f"{self._log_name}: Error while writing data: {err}" @@ -250,21 +258,6 @@ class APIPlaintextFrameHelper(APIFrameHelper): # If we have more data, continue processing -def _decode_noise_psk(psk: str, server_name: Optional[str]) -> bytes: - """Decode the given noise psk from base64 format to raw bytes.""" - try: - psk_bytes = base64.b64decode(psk) - except ValueError: - raise InvalidEncryptionKeyAPIError( - f"Malformed PSK {psk}, expected base64-encoded value", server_name - ) - if len(psk_bytes) != 32: - raise InvalidEncryptionKeyAPIError( - f"Malformed PSK {psk}, expected 32-bytes of base64 data", server_name - ) - return psk_bytes - - class NoiseConnectionState(Enum): """Noise connection state.""" @@ -287,6 +280,8 @@ class APINoiseFrameHelper(APIFrameHelper): "_proto", "_decrypt", "_encrypt", + "_is_ready", + "_loop", ) def __init__( @@ -300,7 +295,9 @@ class APINoiseFrameHelper(APIFrameHelper): ) -> None: """Initialize the API frame helper.""" super().__init__(on_pkt, on_error, client_info, log_name) - self._ready_future = asyncio.get_event_loop().create_future() + loop = asyncio.get_event_loop() + self._loop = loop + self._ready_future = loop.create_future() self._noise_psk = noise_psk self._expected_name = expected_name self._set_state(NoiseConnectionState.HELLO) @@ -308,6 +305,7 @@ class APINoiseFrameHelper(APIFrameHelper): self._decrypt: Optional[Callable[[bytes], bytes]] = None self._encrypt: Optional[Callable[[bytes], bytes]] = None self._setup_proto() + self._is_ready = False def _set_ready_future_exception(self, exc: Exception) -> None: if not self._ready_future.done(): @@ -316,6 +314,7 @@ class APINoiseFrameHelper(APIFrameHelper): def _set_state(self, state: NoiseConnectionState) -> None: """Set the current state.""" self._state = state + self._is_ready = state == NoiseConnectionState.READY self._dispatch = self.STATE_TO_CALLABLE[state] def close(self) -> None: @@ -353,20 +352,24 @@ class APINoiseFrameHelper(APIFrameHelper): The entire packet must be written in a single call to write. """ - assert self._transport is not None, "Transport is not set" + if TYPE_CHECKING: + assert self._writer is not None, "Writer is not set" + if _LOGGER.isEnabledFor(logging.DEBUG): _LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex()) frame_len = len(frame) try: - header = bytes( - [ - 0x01, - (frame_len >> 8) & 0xFF, - frame_len & 0xFF, - ] + self._writer( + bytes( + ( + 0x01, + (frame_len >> 8) & 0xFF, + frame_len & 0xFF, + ) + ) + + frame ) - self._transport.write(header + frame) except (RuntimeError, ConnectionResetError, OSError) as err: raise SocketAPIError( f"{self._log_name}: Error while writing data: {err}" @@ -375,13 +378,17 @@ class APINoiseFrameHelper(APIFrameHelper): async def perform_handshake(self) -> None: """Perform the handshake with the server.""" self._send_hello() + handshake_handle = self._loop.call_later( + 60, self._set_ready_future_exception, asyncio.TimeoutError() + ) try: - async with async_timeout.timeout(60.0): - await self._ready_future + await self._ready_future except asyncio.TimeoutError as err: raise HandshakeAPIError( f"{self._log_name}: Timeout during handshake" ) from err + finally: + handshake_handle.cancel() def data_received(self, data: bytes) -> None: self._buffer += data @@ -462,15 +469,36 @@ class APINoiseFrameHelper(APIFrameHelper): self._set_state(NoiseConnectionState.HANDSHAKE) self._send_handshake() + def _decode_noise_psk(self) -> bytes: + """Decode the given noise psk from base64 format to raw bytes.""" + psk = self._noise_psk + server_name = self._server_name + try: + psk_bytes = base64.b64decode(psk) + except ValueError: + raise InvalidEncryptionKeyAPIError( + f"{self._log_name}: Malformed PSK {psk}, expected " + "base64-encoded value", + server_name, + ) + if len(psk_bytes) != 32: + raise InvalidEncryptionKeyAPIError( + f"{self._log_name}:Malformed PSK {psk}, expected" + f" 32-bytes of base64 data", + server_name, + ) + return psk_bytes + def _setup_proto(self) -> None: """Set up the noise protocol.""" - self._proto = NoiseConnection.from_name( + proto = NoiseConnection.from_name( b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND ) - self._proto.set_as_initiator() - self._proto.set_psks(_decode_noise_psk(self._noise_psk, self._server_name)) - self._proto.set_prologue(b"NoiseAPIInit" + b"\x00\x00") - self._proto.start_handshake() + proto.set_as_initiator() + proto.set_psks(self._decode_noise_psk()) + proto.set_prologue(b"NoiseAPIInit\x00\x00") + proto.start_handshake() + self._proto = proto def _send_handshake(self) -> None: """Send the handshake message.""" @@ -515,10 +543,12 @@ class APINoiseFrameHelper(APIFrameHelper): def write_packet(self, type_: int, data: bytes) -> None: """Write a packet to the socket.""" - if self._state != NoiseConnectionState.READY: + if not self._is_ready: raise HandshakeAPIError(f"{self._log_name}: Noise connection is not ready") + if TYPE_CHECKING: assert self._encrypt is not None, "Handshake should be complete" + data_len = len(data) self._write_frame( self._encrypt( diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 2069601..05683a7 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -715,8 +715,7 @@ class APIConnection: def _process_packet_factory(self) -> Callable[[int, bytes], None]: """Factory to make a packet processor.""" message_type_to_proto = MESSAGE_TYPE_TO_PROTO - is_enabled_for = _LOGGER.isEnabledFor - logging_debug = logging.DEBUG + debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG) message_handlers = self._message_handlers internal_message_types = INTERNAL_MESSAGE_TYPES @@ -759,7 +758,7 @@ class APIConnection: msg_type = type(msg) - if is_enabled_for(logging_debug): + if debug_enabled(): _LOGGER.debug( "%s: Got message of type %s: %s", self.log_name, diff --git a/tests/test__frame_helper.py b/tests/test__frame_helper.py index 0efaaac..aa40d66 100644 --- a/tests/test__frame_helper.py +++ b/tests/test__frame_helper.py @@ -119,6 +119,7 @@ async def test_noise_frame_helper_incorrect_key(): log_name="test", ) helper._transport = MagicMock() + helper._writer = MagicMock() for pkt in outgoing_packets: helper._write_frame(bytes.fromhex(pkt)) @@ -159,6 +160,7 @@ async def test_noise_frame_helper_incorrect_key_fragments(): log_name="test", ) helper._transport = MagicMock() + helper._writer = MagicMock() for pkt in outgoing_packets: helper._write_frame(bytes.fromhex(pkt)) @@ -201,6 +203,7 @@ async def test_noise_incorrect_name(): log_name="test", ) helper._transport = MagicMock() + helper._writer = MagicMock() for pkt in outgoing_packets: helper._write_frame(bytes.fromhex(pkt)) diff --git a/tests/test_connection.py b/tests/test_connection.py index af4d248..e90efa4 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -61,6 +61,7 @@ def _get_mock_protocol(conn: APIConnection): ) protocol._connected_event.set() protocol._transport = MagicMock() + protocol._writer = MagicMock() return protocol