mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-22 16:48:04 +01:00
Speed up noise handshake (#486)
This commit is contained in:
parent
0c1f710869
commit
ab3c096c9b
@ -5,9 +5,8 @@ from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from struct import Struct
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
|
||||
|
||||
import async_timeout
|
||||
from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable
|
||||
from cryptography.exceptions import InvalidTag
|
||||
from noise.backends.default import DefaultNoiseBackend # type: ignore[import]
|
||||
@ -65,6 +64,7 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
"_on_pkt",
|
||||
"_on_error",
|
||||
"_transport",
|
||||
"_writer",
|
||||
"_connected_event",
|
||||
"_buffer",
|
||||
"_buffer_len",
|
||||
@ -84,6 +84,9 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
self._on_pkt = on_pkt
|
||||
self._on_error = on_error
|
||||
self._transport: Optional[asyncio.Transport] = None
|
||||
self._writer: Optional[
|
||||
Callable[[Union[bytes, bytearray, memoryview]], None]
|
||||
] = None
|
||||
self._connected_event = asyncio.Event()
|
||||
self._buffer = bytearray()
|
||||
self._buffer_len = 0
|
||||
@ -111,6 +114,7 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
"""Handle a new connection."""
|
||||
self._transport = cast(asyncio.Transport, transport)
|
||||
self._writer = self._transport.write
|
||||
self._connected_event.set()
|
||||
|
||||
def _handle_error_and_close(self, exc: Exception) -> None:
|
||||
@ -134,6 +138,8 @@ class APIFrameHelper(asyncio.Protocol):
|
||||
"""Close the connection."""
|
||||
if self._transport:
|
||||
self._transport.close()
|
||||
self._transport = None
|
||||
self._writer = None
|
||||
|
||||
|
||||
class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
@ -144,13 +150,15 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
|
||||
The entire packet must be written in a single call.
|
||||
"""
|
||||
assert self._transport is not None, "Transport should be set"
|
||||
if TYPE_CHECKING:
|
||||
assert self._writer is not None, "Writer should be set"
|
||||
|
||||
data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug("%s: Sending plaintext frame %s", self._log_name, data.hex())
|
||||
|
||||
try:
|
||||
self._transport.write(data)
|
||||
self._writer(data)
|
||||
except (RuntimeError, ConnectionResetError, OSError) as err:
|
||||
raise SocketAPIError(
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
@ -250,21 +258,6 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
# If we have more data, continue processing
|
||||
|
||||
|
||||
def _decode_noise_psk(psk: str, server_name: Optional[str]) -> bytes:
|
||||
"""Decode the given noise psk from base64 format to raw bytes."""
|
||||
try:
|
||||
psk_bytes = base64.b64decode(psk)
|
||||
except ValueError:
|
||||
raise InvalidEncryptionKeyAPIError(
|
||||
f"Malformed PSK {psk}, expected base64-encoded value", server_name
|
||||
)
|
||||
if len(psk_bytes) != 32:
|
||||
raise InvalidEncryptionKeyAPIError(
|
||||
f"Malformed PSK {psk}, expected 32-bytes of base64 data", server_name
|
||||
)
|
||||
return psk_bytes
|
||||
|
||||
|
||||
class NoiseConnectionState(Enum):
|
||||
"""Noise connection state."""
|
||||
|
||||
@ -287,6 +280,8 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
"_proto",
|
||||
"_decrypt",
|
||||
"_encrypt",
|
||||
"_is_ready",
|
||||
"_loop",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@ -300,7 +295,9 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
) -> None:
|
||||
"""Initialize the API frame helper."""
|
||||
super().__init__(on_pkt, on_error, client_info, log_name)
|
||||
self._ready_future = asyncio.get_event_loop().create_future()
|
||||
loop = asyncio.get_event_loop()
|
||||
self._loop = loop
|
||||
self._ready_future = loop.create_future()
|
||||
self._noise_psk = noise_psk
|
||||
self._expected_name = expected_name
|
||||
self._set_state(NoiseConnectionState.HELLO)
|
||||
@ -308,6 +305,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
self._decrypt: Optional[Callable[[bytes], bytes]] = None
|
||||
self._encrypt: Optional[Callable[[bytes], bytes]] = None
|
||||
self._setup_proto()
|
||||
self._is_ready = False
|
||||
|
||||
def _set_ready_future_exception(self, exc: Exception) -> None:
|
||||
if not self._ready_future.done():
|
||||
@ -316,6 +314,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
def _set_state(self, state: NoiseConnectionState) -> None:
|
||||
"""Set the current state."""
|
||||
self._state = state
|
||||
self._is_ready = state == NoiseConnectionState.READY
|
||||
self._dispatch = self.STATE_TO_CALLABLE[state]
|
||||
|
||||
def close(self) -> None:
|
||||
@ -353,20 +352,24 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
|
||||
The entire packet must be written in a single call to write.
|
||||
"""
|
||||
assert self._transport is not None, "Transport is not set"
|
||||
if TYPE_CHECKING:
|
||||
assert self._writer is not None, "Writer is not set"
|
||||
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, frame.hex())
|
||||
|
||||
frame_len = len(frame)
|
||||
try:
|
||||
header = bytes(
|
||||
[
|
||||
0x01,
|
||||
(frame_len >> 8) & 0xFF,
|
||||
frame_len & 0xFF,
|
||||
]
|
||||
self._writer(
|
||||
bytes(
|
||||
(
|
||||
0x01,
|
||||
(frame_len >> 8) & 0xFF,
|
||||
frame_len & 0xFF,
|
||||
)
|
||||
)
|
||||
+ frame
|
||||
)
|
||||
self._transport.write(header + frame)
|
||||
except (RuntimeError, ConnectionResetError, OSError) as err:
|
||||
raise SocketAPIError(
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
@ -375,13 +378,17 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
async def perform_handshake(self) -> None:
|
||||
"""Perform the handshake with the server."""
|
||||
self._send_hello()
|
||||
handshake_handle = self._loop.call_later(
|
||||
60, self._set_ready_future_exception, asyncio.TimeoutError()
|
||||
)
|
||||
try:
|
||||
async with async_timeout.timeout(60.0):
|
||||
await self._ready_future
|
||||
await self._ready_future
|
||||
except asyncio.TimeoutError as err:
|
||||
raise HandshakeAPIError(
|
||||
f"{self._log_name}: Timeout during handshake"
|
||||
) from err
|
||||
finally:
|
||||
handshake_handle.cancel()
|
||||
|
||||
def data_received(self, data: bytes) -> None:
|
||||
self._buffer += data
|
||||
@ -462,15 +469,36 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
self._set_state(NoiseConnectionState.HANDSHAKE)
|
||||
self._send_handshake()
|
||||
|
||||
def _decode_noise_psk(self) -> bytes:
|
||||
"""Decode the given noise psk from base64 format to raw bytes."""
|
||||
psk = self._noise_psk
|
||||
server_name = self._server_name
|
||||
try:
|
||||
psk_bytes = base64.b64decode(psk)
|
||||
except ValueError:
|
||||
raise InvalidEncryptionKeyAPIError(
|
||||
f"{self._log_name}: Malformed PSK {psk}, expected "
|
||||
"base64-encoded value",
|
||||
server_name,
|
||||
)
|
||||
if len(psk_bytes) != 32:
|
||||
raise InvalidEncryptionKeyAPIError(
|
||||
f"{self._log_name}:Malformed PSK {psk}, expected"
|
||||
f" 32-bytes of base64 data",
|
||||
server_name,
|
||||
)
|
||||
return psk_bytes
|
||||
|
||||
def _setup_proto(self) -> None:
|
||||
"""Set up the noise protocol."""
|
||||
self._proto = NoiseConnection.from_name(
|
||||
proto = NoiseConnection.from_name(
|
||||
b"Noise_NNpsk0_25519_ChaChaPoly_SHA256", backend=ESPHOME_NOISE_BACKEND
|
||||
)
|
||||
self._proto.set_as_initiator()
|
||||
self._proto.set_psks(_decode_noise_psk(self._noise_psk, self._server_name))
|
||||
self._proto.set_prologue(b"NoiseAPIInit" + b"\x00\x00")
|
||||
self._proto.start_handshake()
|
||||
proto.set_as_initiator()
|
||||
proto.set_psks(self._decode_noise_psk())
|
||||
proto.set_prologue(b"NoiseAPIInit\x00\x00")
|
||||
proto.start_handshake()
|
||||
self._proto = proto
|
||||
|
||||
def _send_handshake(self) -> None:
|
||||
"""Send the handshake message."""
|
||||
@ -515,10 +543,12 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
|
||||
def write_packet(self, type_: int, data: bytes) -> None:
|
||||
"""Write a packet to the socket."""
|
||||
if self._state != NoiseConnectionState.READY:
|
||||
if not self._is_ready:
|
||||
raise HandshakeAPIError(f"{self._log_name}: Noise connection is not ready")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert self._encrypt is not None, "Handshake should be complete"
|
||||
|
||||
data_len = len(data)
|
||||
self._write_frame(
|
||||
self._encrypt(
|
||||
|
@ -715,8 +715,7 @@ class APIConnection:
|
||||
def _process_packet_factory(self) -> Callable[[int, bytes], None]:
|
||||
"""Factory to make a packet processor."""
|
||||
message_type_to_proto = MESSAGE_TYPE_TO_PROTO
|
||||
is_enabled_for = _LOGGER.isEnabledFor
|
||||
logging_debug = logging.DEBUG
|
||||
debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
|
||||
message_handlers = self._message_handlers
|
||||
internal_message_types = INTERNAL_MESSAGE_TYPES
|
||||
|
||||
@ -759,7 +758,7 @@ class APIConnection:
|
||||
|
||||
msg_type = type(msg)
|
||||
|
||||
if is_enabled_for(logging_debug):
|
||||
if debug_enabled():
|
||||
_LOGGER.debug(
|
||||
"%s: Got message of type %s: %s",
|
||||
self.log_name,
|
||||
|
@ -119,6 +119,7 @@ async def test_noise_frame_helper_incorrect_key():
|
||||
log_name="test",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
helper._writer = MagicMock()
|
||||
|
||||
for pkt in outgoing_packets:
|
||||
helper._write_frame(bytes.fromhex(pkt))
|
||||
@ -159,6 +160,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
|
||||
log_name="test",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
helper._writer = MagicMock()
|
||||
|
||||
for pkt in outgoing_packets:
|
||||
helper._write_frame(bytes.fromhex(pkt))
|
||||
@ -201,6 +203,7 @@ async def test_noise_incorrect_name():
|
||||
log_name="test",
|
||||
)
|
||||
helper._transport = MagicMock()
|
||||
helper._writer = MagicMock()
|
||||
|
||||
for pkt in outgoing_packets:
|
||||
helper._write_frame(bytes.fromhex(pkt))
|
||||
|
@ -61,6 +61,7 @@ def _get_mock_protocol(conn: APIConnection):
|
||||
)
|
||||
protocol._connected_event.set()
|
||||
protocol._transport = MagicMock()
|
||||
protocol._writer = MagicMock()
|
||||
return protocol
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user