mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-01-17 20:51:36 +01:00
Split _frame_helper into plain_text and noise (#491)
This commit is contained in:
parent
85cf377e14
commit
65e659e4a5
7
aioesphomeapi/_frame_helper/__init__.py
Normal file
7
aioesphomeapi/_frame_helper/__init__.py
Normal file
@ -0,0 +1,7 @@
|
||||
from .noise import APINoiseFrameHelper
|
||||
from .plain_text import APIPlaintextFrameHelper
|
||||
|
||||
__all__ = (
|
||||
"APINoiseFrameHelper",
|
||||
"APIPlaintextFrameHelper",
|
||||
)
|
105
aioesphomeapi/_frame_helper/base.py
Normal file
105
aioesphomeapi/_frame_helper/base.py
Normal file
@ -0,0 +1,105 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Union, cast
|
||||
|
||||
from ..core import SocketClosedAPIError
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
SOCKET_ERRORS = (
|
||||
ConnectionResetError,
|
||||
asyncio.IncompleteReadError,
|
||||
OSError,
|
||||
TimeoutError,
|
||||
)
|
||||
|
||||
WRITE_EXCEPTIONS = (RuntimeError, ConnectionResetError, OSError)
|
||||
|
||||
|
||||
class APIFrameHelper(asyncio.Protocol):
|
||||
"""Helper class to handle the API frame protocol."""
|
||||
|
||||
__slots__ = (
|
||||
"_on_pkt",
|
||||
"_on_error",
|
||||
"_transport",
|
||||
"_writer",
|
||||
"_connected_event",
|
||||
"_buffer",
|
||||
"_buffer_len",
|
||||
"_pos",
|
||||
"_client_info",
|
||||
"_log_name",
|
||||
"_debug_enabled",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
on_pkt: Callable[[int, bytes], None],
|
||||
on_error: Callable[[Exception], None],
|
||||
client_info: str,
|
||||
log_name: str,
|
||||
) -> None:
|
||||
"""Initialize the API frame helper."""
|
||||
self._on_pkt = on_pkt
|
||||
self._on_error = on_error
|
||||
self._transport: Optional[asyncio.Transport] = None
|
||||
self._writer: Optional[
|
||||
Callable[[Union[bytes, bytearray, memoryview]], None]
|
||||
] = None
|
||||
self._connected_event = asyncio.Event()
|
||||
self._buffer = bytearray()
|
||||
self._buffer_len = 0
|
||||
self._pos = 0
|
||||
self._client_info = client_info
|
||||
self._log_name = log_name
|
||||
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
|
||||
|
||||
def _read_exactly(self, length: int) -> Optional[bytearray]:
|
||||
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
|
||||
original_pos = self._pos
|
||||
new_pos = original_pos + length
|
||||
if self._buffer_len < new_pos:
|
||||
return None
|
||||
self._pos = new_pos
|
||||
return self._buffer[original_pos:new_pos]
|
||||
|
||||
@abstractmethod
|
||||
async def perform_handshake(self) -> None:
|
||||
"""Perform the handshake."""
|
||||
|
||||
@abstractmethod
|
||||
def write_packet(self, type_: int, data: bytes) -> None:
|
||||
"""Write a packet to the socket."""
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
"""Handle a new connection."""
|
||||
self._transport = cast(asyncio.Transport, transport)
|
||||
self._writer = self._transport.write
|
||||
self._connected_event.set()
|
||||
|
||||
def _handle_error_and_close(self, exc: Exception) -> None:
|
||||
self._handle_error(exc)
|
||||
self.close()
|
||||
|
||||
def _handle_error(self, exc: Exception) -> None:
|
||||
self._on_error(exc)
|
||||
|
||||
def connection_lost(self, exc: Optional[Exception]) -> None:
|
||||
self._handle_error(
|
||||
exc or SocketClosedAPIError(f"{self._log_name}: Connection lost")
|
||||
)
|
||||
return super().connection_lost(exc)
|
||||
|
||||
def eof_received(self) -> Optional[bool]:
|
||||
self._handle_error(SocketClosedAPIError(f"{self._log_name}: EOF received"))
|
||||
return super().eof_received()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the connection."""
|
||||
if self._transport:
|
||||
self._transport.close()
|
||||
self._transport = None
|
||||
self._writer = None
|
@ -1,11 +1,10 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from struct import Struct
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Type, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Type
|
||||
|
||||
from chacha20poly1305_reuseable import ChaCha20Poly1305Reusable
|
||||
from cryptography.exceptions import InvalidTag
|
||||
@ -13,31 +12,21 @@ from noise.backends.default import DefaultNoiseBackend # type: ignore[import]
|
||||
from noise.backends.default.ciphers import ChaCha20Cipher # type: ignore[import]
|
||||
from noise.connection import NoiseConnection # type: ignore[import]
|
||||
|
||||
from .core import (
|
||||
from ..core import (
|
||||
APIConnectionError,
|
||||
BadNameAPIError,
|
||||
HandshakeAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
ProtocolAPIError,
|
||||
RequiresEncryptionAPIError,
|
||||
SocketAPIError,
|
||||
SocketClosedAPIError,
|
||||
)
|
||||
from .util import bytes_to_varuint, varuint_to_bytes
|
||||
from .base import WRITE_EXCEPTIONS, APIFrameHelper
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
SOCKET_ERRORS = (
|
||||
ConnectionResetError,
|
||||
asyncio.IncompleteReadError,
|
||||
OSError,
|
||||
TimeoutError,
|
||||
)
|
||||
|
||||
PACK_NONCE = partial(Struct("<LQ").pack, 0)
|
||||
|
||||
WRITE_EXCEPTIONS = (RuntimeError, ConnectionResetError, OSError)
|
||||
|
||||
|
||||
class ChaCha20CipherReuseable(ChaCha20Cipher): # type: ignore[misc]
|
||||
"""ChaCha20 cipher that can be reused."""
|
||||
@ -58,209 +47,6 @@ class ESPHomeNoiseBackend(DefaultNoiseBackend): # type: ignore[misc]
|
||||
ESPHOME_NOISE_BACKEND = ESPHomeNoiseBackend()
|
||||
|
||||
|
||||
class APIFrameHelper(asyncio.Protocol):
|
||||
"""Helper class to handle the API frame protocol."""
|
||||
|
||||
__slots__ = (
|
||||
"_on_pkt",
|
||||
"_on_error",
|
||||
"_transport",
|
||||
"_writer",
|
||||
"_connected_event",
|
||||
"_buffer",
|
||||
"_buffer_len",
|
||||
"_pos",
|
||||
"_client_info",
|
||||
"_log_name",
|
||||
"_debug_enabled",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
on_pkt: Callable[[int, bytes], None],
|
||||
on_error: Callable[[Exception], None],
|
||||
client_info: str,
|
||||
log_name: str,
|
||||
) -> None:
|
||||
"""Initialize the API frame helper."""
|
||||
self._on_pkt = on_pkt
|
||||
self._on_error = on_error
|
||||
self._transport: Optional[asyncio.Transport] = None
|
||||
self._writer: Optional[
|
||||
Callable[[Union[bytes, bytearray, memoryview]], None]
|
||||
] = None
|
||||
self._connected_event = asyncio.Event()
|
||||
self._buffer = bytearray()
|
||||
self._buffer_len = 0
|
||||
self._pos = 0
|
||||
self._client_info = client_info
|
||||
self._log_name = log_name
|
||||
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
|
||||
|
||||
def _read_exactly(self, length: int) -> Optional[bytearray]:
|
||||
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
|
||||
original_pos = self._pos
|
||||
new_pos = original_pos + length
|
||||
if self._buffer_len < new_pos:
|
||||
return None
|
||||
self._pos = new_pos
|
||||
return self._buffer[original_pos:new_pos]
|
||||
|
||||
@abstractmethod
|
||||
async def perform_handshake(self) -> None:
|
||||
"""Perform the handshake."""
|
||||
|
||||
@abstractmethod
|
||||
def write_packet(self, type_: int, data: bytes) -> None:
|
||||
"""Write a packet to the socket."""
|
||||
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
"""Handle a new connection."""
|
||||
self._transport = cast(asyncio.Transport, transport)
|
||||
self._writer = self._transport.write
|
||||
self._connected_event.set()
|
||||
|
||||
def _handle_error_and_close(self, exc: Exception) -> None:
|
||||
self._handle_error(exc)
|
||||
self.close()
|
||||
|
||||
def _handle_error(self, exc: Exception) -> None:
|
||||
self._on_error(exc)
|
||||
|
||||
def connection_lost(self, exc: Optional[Exception]) -> None:
|
||||
self._handle_error(
|
||||
exc or SocketClosedAPIError(f"{self._log_name}: Connection lost")
|
||||
)
|
||||
return super().connection_lost(exc)
|
||||
|
||||
def eof_received(self) -> Optional[bool]:
|
||||
self._handle_error(SocketClosedAPIError(f"{self._log_name}: EOF received"))
|
||||
return super().eof_received()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the connection."""
|
||||
if self._transport:
|
||||
self._transport.close()
|
||||
self._transport = None
|
||||
self._writer = None
|
||||
|
||||
|
||||
class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
"""Frame helper for plaintext API connections."""
|
||||
|
||||
def write_packet(self, type_: int, data: bytes) -> None:
|
||||
"""Write a packet to the socket.
|
||||
|
||||
The entire packet must be written in a single call.
|
||||
"""
|
||||
if TYPE_CHECKING:
|
||||
assert self._writer is not None, "Writer should be set"
|
||||
|
||||
data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data
|
||||
if self._debug_enabled():
|
||||
_LOGGER.debug("%s: Sending plaintext frame %s", self._log_name, data.hex())
|
||||
|
||||
try:
|
||||
self._writer(data)
|
||||
except WRITE_EXCEPTIONS as err:
|
||||
raise SocketAPIError(
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
) from err
|
||||
|
||||
async def perform_handshake(self) -> None:
|
||||
"""Perform the handshake."""
|
||||
await self._connected_event.wait()
|
||||
|
||||
def data_received(self, data: bytes) -> None: # pylint: disable=too-many-branches
|
||||
self._buffer += data
|
||||
self._buffer_len += len(data)
|
||||
while self._buffer:
|
||||
# Read preamble, which should always 0x00
|
||||
# Also try to get the length and msg type
|
||||
# to avoid multiple calls to _read_exactly
|
||||
self._pos = 0
|
||||
init_bytes = self._read_exactly(3)
|
||||
if init_bytes is None:
|
||||
return
|
||||
msg_type_int: Optional[int] = None
|
||||
length_int: Optional[int] = None
|
||||
preamble, length_high, maybe_msg_type = init_bytes
|
||||
if preamble != 0x00:
|
||||
if preamble == 0x01:
|
||||
self._handle_error_and_close(
|
||||
RequiresEncryptionAPIError(
|
||||
f"{self._log_name}: Connection requires encryption"
|
||||
)
|
||||
)
|
||||
return
|
||||
self._handle_error_and_close(
|
||||
ProtocolAPIError(
|
||||
f"{self._log_name}: Invalid preamble {preamble:02x}"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if length_high & 0x80 != 0x80:
|
||||
# Length is only 1 byte
|
||||
#
|
||||
# This is the most common case needing a single byte for
|
||||
# length and type which means we avoid 2 calls to _read_exactly
|
||||
length_int = length_high
|
||||
if maybe_msg_type & 0x80 != 0x80:
|
||||
# Message type is also only 1 byte
|
||||
msg_type_int = maybe_msg_type
|
||||
else:
|
||||
# Message type is longer than 1 byte
|
||||
msg_type = bytes(init_bytes[2:3])
|
||||
else:
|
||||
# Length is longer than 1 byte
|
||||
length = bytes(init_bytes[1:3])
|
||||
# If the message is long, we need to read the rest of the length
|
||||
while length[-1] & 0x80 == 0x80:
|
||||
add_length = self._read_exactly(1)
|
||||
if add_length is None:
|
||||
return
|
||||
length += add_length
|
||||
length_int = bytes_to_varuint(length)
|
||||
# Since the length is longer than 1 byte we do not have the
|
||||
# message type yet.
|
||||
msg_type = b""
|
||||
|
||||
# If the we do not have the message type yet because the message
|
||||
# length was so long it did not fit into the first byte we need
|
||||
# to read the (rest) of the message type
|
||||
if msg_type_int is None:
|
||||
while not msg_type or msg_type[-1] & 0x80 == 0x80:
|
||||
add_msg_type = self._read_exactly(1)
|
||||
if add_msg_type is None:
|
||||
return
|
||||
msg_type += add_msg_type
|
||||
msg_type_int = bytes_to_varuint(msg_type)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert length_int is not None
|
||||
assert msg_type_int is not None
|
||||
|
||||
if length_int == 0:
|
||||
packet_data = b""
|
||||
else:
|
||||
packet_data_bytearray = self._read_exactly(length_int)
|
||||
# The packet data is not yet available, wait for more data
|
||||
# to arrive before continuing, since callback_packet has not
|
||||
# been called yet the buffer will not be cleared and the next
|
||||
# call to data_received will continue processing the packet
|
||||
# at the start of the frame.
|
||||
if packet_data_bytearray is None:
|
||||
return
|
||||
packet_data = bytes(packet_data_bytearray)
|
||||
|
||||
end_of_frame_pos = self._pos
|
||||
del self._buffer[:end_of_frame_pos]
|
||||
self._buffer_len -= end_of_frame_pos
|
||||
self._on_pkt(msg_type_int, packet_data)
|
||||
# If we have more data, continue processing
|
||||
|
||||
|
||||
class NoiseConnectionState(Enum):
|
||||
"""Noise connection state."""
|
||||
|
124
aioesphomeapi/_frame_helper/plain_text.py
Normal file
124
aioesphomeapi/_frame_helper/plain_text.py
Normal file
@ -0,0 +1,124 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from ..core import ProtocolAPIError, RequiresEncryptionAPIError, SocketAPIError
|
||||
from ..util import bytes_to_varuint, varuint_to_bytes
|
||||
from .base import WRITE_EXCEPTIONS, APIFrameHelper
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
"""Frame helper for plaintext API connections."""
|
||||
|
||||
def write_packet(self, type_: int, data: bytes) -> None:
|
||||
"""Write a packet to the socket.
|
||||
|
||||
The entire packet must be written in a single call.
|
||||
"""
|
||||
if TYPE_CHECKING:
|
||||
assert self._writer is not None, "Writer should be set"
|
||||
|
||||
data = b"\0" + varuint_to_bytes(len(data)) + varuint_to_bytes(type_) + data
|
||||
if self._debug_enabled():
|
||||
_LOGGER.debug("%s: Sending plaintext frame %s", self._log_name, data.hex())
|
||||
|
||||
try:
|
||||
self._writer(data)
|
||||
except WRITE_EXCEPTIONS as err:
|
||||
raise SocketAPIError(
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
) from err
|
||||
|
||||
async def perform_handshake(self) -> None:
|
||||
"""Perform the handshake."""
|
||||
await self._connected_event.wait()
|
||||
|
||||
def data_received(self, data: bytes) -> None: # pylint: disable=too-many-branches
|
||||
self._buffer += data
|
||||
self._buffer_len += len(data)
|
||||
while self._buffer:
|
||||
# Read preamble, which should always 0x00
|
||||
# Also try to get the length and msg type
|
||||
# to avoid multiple calls to _read_exactly
|
||||
self._pos = 0
|
||||
init_bytes = self._read_exactly(3)
|
||||
if init_bytes is None:
|
||||
return
|
||||
msg_type_int: Optional[int] = None
|
||||
length_int: Optional[int] = None
|
||||
preamble, length_high, maybe_msg_type = init_bytes
|
||||
if preamble != 0x00:
|
||||
if preamble == 0x01:
|
||||
self._handle_error_and_close(
|
||||
RequiresEncryptionAPIError(
|
||||
f"{self._log_name}: Connection requires encryption"
|
||||
)
|
||||
)
|
||||
return
|
||||
self._handle_error_and_close(
|
||||
ProtocolAPIError(
|
||||
f"{self._log_name}: Invalid preamble {preamble:02x}"
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if length_high & 0x80 != 0x80:
|
||||
# Length is only 1 byte
|
||||
#
|
||||
# This is the most common case needing a single byte for
|
||||
# length and type which means we avoid 2 calls to _read_exactly
|
||||
length_int = length_high
|
||||
if maybe_msg_type & 0x80 != 0x80:
|
||||
# Message type is also only 1 byte
|
||||
msg_type_int = maybe_msg_type
|
||||
else:
|
||||
# Message type is longer than 1 byte
|
||||
msg_type = bytes(init_bytes[2:3])
|
||||
else:
|
||||
# Length is longer than 1 byte
|
||||
length = bytes(init_bytes[1:3])
|
||||
# If the message is long, we need to read the rest of the length
|
||||
while length[-1] & 0x80 == 0x80:
|
||||
add_length = self._read_exactly(1)
|
||||
if add_length is None:
|
||||
return
|
||||
length += add_length
|
||||
length_int = bytes_to_varuint(length)
|
||||
# Since the length is longer than 1 byte we do not have the
|
||||
# message type yet.
|
||||
msg_type = b""
|
||||
|
||||
# If the we do not have the message type yet because the message
|
||||
# length was so long it did not fit into the first byte we need
|
||||
# to read the (rest) of the message type
|
||||
if msg_type_int is None:
|
||||
while not msg_type or msg_type[-1] & 0x80 == 0x80:
|
||||
add_msg_type = self._read_exactly(1)
|
||||
if add_msg_type is None:
|
||||
return
|
||||
msg_type += add_msg_type
|
||||
msg_type_int = bytes_to_varuint(msg_type)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert length_int is not None
|
||||
assert msg_type_int is not None
|
||||
|
||||
if length_int == 0:
|
||||
packet_data = b""
|
||||
else:
|
||||
packet_data_bytearray = self._read_exactly(length_int)
|
||||
# The packet data is not yet available, wait for more data
|
||||
# to arrive before continuing, since callback_packet has not
|
||||
# been called yet the buffer will not be cleared and the next
|
||||
# call to data_received will continue processing the packet
|
||||
# at the start of the frame.
|
||||
if packet_data_bytearray is None:
|
||||
return
|
||||
packet_data = bytes(packet_data_bytearray)
|
||||
|
||||
end_of_frame_pos = self._pos
|
||||
del self._buffer[:end_of_frame_pos]
|
||||
self._buffer_len -= end_of_frame_pos
|
||||
self._on_pkt(msg_type_int, packet_data)
|
||||
# If we have more data, continue processing
|
@ -25,7 +25,7 @@ from google.protobuf import message
|
||||
|
||||
import aioesphomeapi.host_resolver as hr
|
||||
|
||||
from ._frame_helper import APIFrameHelper, APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||
from ._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||
from .api_pb2 import ( # type: ignore
|
||||
ConnectRequest,
|
||||
ConnectResponse,
|
||||
@ -168,7 +168,9 @@ class APIConnection:
|
||||
self.on_stop: Optional[Callable[[bool], Coroutine[Any, Any, None]]] = on_stop
|
||||
self._on_stop_task: Optional[asyncio.Task[None]] = None
|
||||
self._socket: Optional[socket.socket] = None
|
||||
self._frame_helper: Optional[APIFrameHelper] = None
|
||||
self._frame_helper: Optional[
|
||||
Union[APINoiseFrameHelper, APIPlaintextFrameHelper]
|
||||
] = None
|
||||
self.api_version: Optional[APIVersion] = None
|
||||
|
||||
self._connection_state = ConnectionState.INITIALIZED
|
||||
|
@ -2,11 +2,8 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from aioesphomeapi._frame_helper import (
|
||||
WRITE_EXCEPTIONS,
|
||||
APINoiseFrameHelper,
|
||||
APIPlaintextFrameHelper,
|
||||
)
|
||||
from aioesphomeapi._frame_helper import APINoiseFrameHelper, APIPlaintextFrameHelper
|
||||
from aioesphomeapi._frame_helper.base import WRITE_EXCEPTIONS
|
||||
from aioesphomeapi.core import (
|
||||
BadNameAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
|
Loading…
Reference in New Issue
Block a user