mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2025-01-10 19:57:40 +01:00
Refactor frame_helper into new module (#109)
This commit is contained in:
parent
48ea96f9da
commit
9ca228cd1e
254
aioesphomeapi/_frame_helper.py
Normal file
254
aioesphomeapi/_frame_helper.py
Normal file
@ -0,0 +1,254 @@
|
|||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from noise.connection import NoiseConnection # type: ignore
|
||||||
|
|
||||||
|
from .core import (
|
||||||
|
HandshakeAPIError,
|
||||||
|
InvalidEncryptionKeyAPIError,
|
||||||
|
ProtocolAPIError,
|
||||||
|
RequiresEncryptionAPIError,
|
||||||
|
SocketAPIError,
|
||||||
|
SocketClosedAPIError,
|
||||||
|
)
|
||||||
|
from .util import bytes_to_varuint, varuint_to_bytes
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Packet:
|
||||||
|
type: int
|
||||||
|
data: bytes
|
||||||
|
|
||||||
|
|
||||||
|
class APIFrameHelper(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
async def close(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def write_packet(self, packet: Packet) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def read_packet(self) -> Packet:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class APIPlaintextFrameHelper(APIFrameHelper):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
reader: asyncio.StreamReader,
|
||||||
|
writer: asyncio.StreamWriter,
|
||||||
|
):
|
||||||
|
self._reader = reader
|
||||||
|
self._writer = writer
|
||||||
|
self._write_lock = asyncio.Lock()
|
||||||
|
self._read_lock = asyncio.Lock()
|
||||||
|
self._closed_event = asyncio.Event()
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
self._closed_event.set()
|
||||||
|
self._writer.close()
|
||||||
|
|
||||||
|
async def write_packet(self, packet: Packet) -> None:
|
||||||
|
data = b"\0"
|
||||||
|
data += varuint_to_bytes(len(packet.data))
|
||||||
|
data += varuint_to_bytes(packet.type)
|
||||||
|
data += packet.data
|
||||||
|
try:
|
||||||
|
async with self._write_lock:
|
||||||
|
_LOGGER.debug("Sending plaintext frame %s", data.hex())
|
||||||
|
self._writer.write(data)
|
||||||
|
await self._writer.drain()
|
||||||
|
except OSError as err:
|
||||||
|
raise SocketAPIError(f"Error while writing data: {err}") from err
|
||||||
|
|
||||||
|
async def read_packet(self) -> Packet:
|
||||||
|
async with self._read_lock:
|
||||||
|
try:
|
||||||
|
preamble = await self._reader.readexactly(1)
|
||||||
|
if preamble[0] != 0x00:
|
||||||
|
if preamble[0] == 0x01:
|
||||||
|
raise RequiresEncryptionAPIError(
|
||||||
|
"Connection requires encryption"
|
||||||
|
)
|
||||||
|
raise ProtocolAPIError(f"Invalid preamble {preamble[0]:02x}")
|
||||||
|
|
||||||
|
length = b""
|
||||||
|
while not length or (length[-1] & 0x80) == 0x80:
|
||||||
|
length += await self._reader.readexactly(1)
|
||||||
|
length_int = bytes_to_varuint(length)
|
||||||
|
assert length_int is not None
|
||||||
|
msg_type = b""
|
||||||
|
while not msg_type or (msg_type[-1] & 0x80) == 0x80:
|
||||||
|
msg_type += await self._reader.readexactly(1)
|
||||||
|
msg_type_int = bytes_to_varuint(msg_type)
|
||||||
|
assert msg_type_int is not None
|
||||||
|
|
||||||
|
raw_msg = b""
|
||||||
|
if length_int != 0:
|
||||||
|
raw_msg = await self._reader.readexactly(length_int)
|
||||||
|
return Packet(type=msg_type_int, data=raw_msg)
|
||||||
|
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
||||||
|
if (
|
||||||
|
isinstance(err, asyncio.IncompleteReadError)
|
||||||
|
and self._closed_event.is_set()
|
||||||
|
):
|
||||||
|
raise SocketClosedAPIError(
|
||||||
|
f"Socket closed while reading data: {err}"
|
||||||
|
) from err
|
||||||
|
raise SocketAPIError(f"Error while reading data: {err}") from err
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_noise_psk(psk: str) -> bytes:
|
||||||
|
"""Decode the given noise psk from base64 format to raw bytes."""
|
||||||
|
try:
|
||||||
|
psk_bytes = base64.b64decode(psk)
|
||||||
|
except ValueError:
|
||||||
|
raise InvalidEncryptionKeyAPIError(
|
||||||
|
f"Malformed PSK {psk}, expected base64-encoded value"
|
||||||
|
)
|
||||||
|
if len(psk_bytes) != 32:
|
||||||
|
raise InvalidEncryptionKeyAPIError(
|
||||||
|
f"Malformed PSK {psk}, expected 32-bytes of base64 data"
|
||||||
|
)
|
||||||
|
return psk_bytes
|
||||||
|
|
||||||
|
|
||||||
|
class APINoiseFrameHelper(APIFrameHelper):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
reader: asyncio.StreamReader,
|
||||||
|
writer: asyncio.StreamWriter,
|
||||||
|
noise_psk: str,
|
||||||
|
):
|
||||||
|
self._reader = reader
|
||||||
|
self._writer = writer
|
||||||
|
self._write_lock = asyncio.Lock()
|
||||||
|
self._read_lock = asyncio.Lock()
|
||||||
|
self._ready_event = asyncio.Event()
|
||||||
|
self._closed_event = asyncio.Event()
|
||||||
|
self._proto: Optional[NoiseConnection] = None
|
||||||
|
self._noise_psk = noise_psk
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
self._closed_event.set()
|
||||||
|
self._writer.close()
|
||||||
|
|
||||||
|
async def _write_frame(self, frame: bytes) -> None:
|
||||||
|
try:
|
||||||
|
async with self._write_lock:
|
||||||
|
_LOGGER.debug("Sending frame %s", frame.hex())
|
||||||
|
header = bytes(
|
||||||
|
[
|
||||||
|
0x01,
|
||||||
|
(len(frame) >> 8) & 0xFF,
|
||||||
|
len(frame) & 0xFF,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self._writer.write(header + frame)
|
||||||
|
await self._writer.drain()
|
||||||
|
except OSError as err:
|
||||||
|
raise SocketAPIError(f"Error while writing data: {err}") from err
|
||||||
|
|
||||||
|
async def _read_frame(self) -> bytes:
|
||||||
|
try:
|
||||||
|
async with self._read_lock:
|
||||||
|
header = await self._reader.readexactly(3)
|
||||||
|
if header[0] != 0x01:
|
||||||
|
raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
|
||||||
|
msg_size = (header[1] << 8) | header[2]
|
||||||
|
frame = await self._reader.readexactly(msg_size)
|
||||||
|
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
||||||
|
if (
|
||||||
|
isinstance(err, asyncio.IncompleteReadError)
|
||||||
|
and self._closed_event.is_set()
|
||||||
|
):
|
||||||
|
raise SocketClosedAPIError(
|
||||||
|
f"Socket closed while reading data: {err}"
|
||||||
|
) from err
|
||||||
|
raise SocketAPIError(f"Error while reading data: {err}") from err
|
||||||
|
|
||||||
|
_LOGGER.debug("Received frame %s", frame.hex())
|
||||||
|
return frame
|
||||||
|
|
||||||
|
async def perform_handshake(self) -> None:
|
||||||
|
await self._write_frame(b"") # ClientHello
|
||||||
|
prologue = b"NoiseAPIInit" + b"\x00\x00"
|
||||||
|
server_hello = await self._read_frame() # ServerHello
|
||||||
|
if not server_hello:
|
||||||
|
raise HandshakeAPIError("ServerHello is empty")
|
||||||
|
chosen_proto = server_hello[0]
|
||||||
|
if chosen_proto != 0x01:
|
||||||
|
raise HandshakeAPIError(
|
||||||
|
f"Unknown protocol selected by client {chosen_proto}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
|
||||||
|
self._proto.set_as_initiator()
|
||||||
|
self._proto.set_psks(_decode_noise_psk(self._noise_psk))
|
||||||
|
self._proto.set_prologue(prologue)
|
||||||
|
self._proto.start_handshake()
|
||||||
|
|
||||||
|
_LOGGER.debug("Starting handshake...")
|
||||||
|
do_write = True
|
||||||
|
while not self._proto.handshake_finished:
|
||||||
|
if do_write:
|
||||||
|
msg = self._proto.write_message()
|
||||||
|
await self._write_frame(b"\x00" + msg)
|
||||||
|
else:
|
||||||
|
msg = await self._read_frame()
|
||||||
|
if not msg:
|
||||||
|
raise HandshakeAPIError("Handshake message too short")
|
||||||
|
if msg[0] != 0:
|
||||||
|
explanation = msg[1:].decode()
|
||||||
|
if explanation == "Handshake MAC failure":
|
||||||
|
raise InvalidEncryptionKeyAPIError("Invalid encryption key")
|
||||||
|
raise HandshakeAPIError(f"Handshake failure: {explanation}")
|
||||||
|
self._proto.read_message(msg[1:])
|
||||||
|
|
||||||
|
do_write = not do_write
|
||||||
|
|
||||||
|
_LOGGER.debug("Handshake complete!")
|
||||||
|
self._ready_event.set()
|
||||||
|
|
||||||
|
async def write_packet(self, packet: Packet) -> None:
|
||||||
|
# Wait for handshake to complete
|
||||||
|
await self._ready_event.wait()
|
||||||
|
padding = 0
|
||||||
|
data = (
|
||||||
|
bytes(
|
||||||
|
[
|
||||||
|
(packet.type >> 8) & 0xFF,
|
||||||
|
(packet.type >> 0) & 0xFF,
|
||||||
|
(len(packet.data) >> 8) & 0xFF,
|
||||||
|
(len(packet.data) >> 0) & 0xFF,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
+ packet.data
|
||||||
|
+ b"\x00" * padding
|
||||||
|
)
|
||||||
|
assert self._proto is not None
|
||||||
|
frame = self._proto.encrypt(data)
|
||||||
|
await self._write_frame(frame)
|
||||||
|
|
||||||
|
async def read_packet(self) -> Packet:
|
||||||
|
# Wait for handshake to complete
|
||||||
|
await self._ready_event.wait()
|
||||||
|
frame = await self._read_frame()
|
||||||
|
assert self._proto is not None
|
||||||
|
msg = self._proto.decrypt(frame)
|
||||||
|
if len(msg) < 4:
|
||||||
|
raise ProtocolAPIError(f"Bad packet frame: {msg}")
|
||||||
|
pkt_type = (msg[0] << 8) | msg[1]
|
||||||
|
data_len = (msg[2] << 8) | msg[3]
|
||||||
|
if data_len + 4 > len(msg):
|
||||||
|
raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
|
||||||
|
data = msg[4 : 4 + data_len]
|
||||||
|
return Packet(type=pkt_type, data=data)
|
@ -1,5 +1,4 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
|
||||||
import enum
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
@ -9,10 +8,15 @@ from dataclasses import astuple, dataclass
|
|||||||
from typing import Any, Awaitable, Callable, List, Optional
|
from typing import Any, Awaitable, Callable, List, Optional
|
||||||
|
|
||||||
from google.protobuf import message
|
from google.protobuf import message
|
||||||
from noise.connection import NoiseConnection # type: ignore
|
|
||||||
|
|
||||||
import aioesphomeapi.host_resolver as hr
|
import aioesphomeapi.host_resolver as hr
|
||||||
|
|
||||||
|
from ._frame_helper import (
|
||||||
|
APIFrameHelper,
|
||||||
|
APINoiseFrameHelper,
|
||||||
|
APIPlaintextFrameHelper,
|
||||||
|
Packet,
|
||||||
|
)
|
||||||
from .api_pb2 import ( # type: ignore
|
from .api_pb2 import ( # type: ignore
|
||||||
ConnectRequest,
|
ConnectRequest,
|
||||||
ConnectResponse,
|
ConnectResponse,
|
||||||
@ -28,20 +32,16 @@ from .api_pb2 import ( # type: ignore
|
|||||||
from .core import (
|
from .core import (
|
||||||
MESSAGE_TYPE_TO_PROTO,
|
MESSAGE_TYPE_TO_PROTO,
|
||||||
APIConnectionError,
|
APIConnectionError,
|
||||||
HandshakeAPIError,
|
|
||||||
InvalidAuthAPIError,
|
InvalidAuthAPIError,
|
||||||
InvalidEncryptionKeyAPIError,
|
|
||||||
PingFailedAPIError,
|
PingFailedAPIError,
|
||||||
ProtocolAPIError,
|
ProtocolAPIError,
|
||||||
ReadFailedAPIError,
|
ReadFailedAPIError,
|
||||||
RequiresEncryptionAPIError,
|
|
||||||
ResolveAPIError,
|
ResolveAPIError,
|
||||||
SocketAPIError,
|
SocketAPIError,
|
||||||
SocketClosedAPIError,
|
SocketClosedAPIError,
|
||||||
TimeoutAPIError,
|
TimeoutAPIError,
|
||||||
)
|
)
|
||||||
from .model import APIVersion
|
from .model import APIVersion
|
||||||
from .util import bytes_to_varuint, varuint_to_bytes
|
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -58,217 +58,6 @@ class ConnectionParams:
|
|||||||
noise_psk: Optional[str]
|
noise_psk: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Packet:
|
|
||||||
type: int
|
|
||||||
data: bytes
|
|
||||||
|
|
||||||
|
|
||||||
class APIFrameHelper:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
reader: asyncio.StreamReader,
|
|
||||||
writer: asyncio.StreamWriter,
|
|
||||||
params: ConnectionParams,
|
|
||||||
):
|
|
||||||
self._reader = reader
|
|
||||||
self._writer = writer
|
|
||||||
self._params = params
|
|
||||||
self._write_lock = asyncio.Lock()
|
|
||||||
self._read_lock = asyncio.Lock()
|
|
||||||
self._ready_event = asyncio.Event()
|
|
||||||
self._proto: Optional[NoiseConnection] = None
|
|
||||||
self._closed_event = asyncio.Event()
|
|
||||||
|
|
||||||
async def close(self) -> None:
|
|
||||||
self._closed_event.set()
|
|
||||||
self._writer.close()
|
|
||||||
|
|
||||||
async def _write_frame_noise(self, frame: bytes) -> None:
|
|
||||||
try:
|
|
||||||
async with self._write_lock:
|
|
||||||
_LOGGER.debug("Sending frame %s", frame.hex())
|
|
||||||
header = bytes(
|
|
||||||
[
|
|
||||||
0x01,
|
|
||||||
(len(frame) >> 8) & 0xFF,
|
|
||||||
len(frame) & 0xFF,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self._writer.write(header + frame)
|
|
||||||
await self._writer.drain()
|
|
||||||
except OSError as err:
|
|
||||||
raise SocketAPIError(f"Error while writing data: {err}") from err
|
|
||||||
|
|
||||||
async def _read_frame_noise(self) -> bytes:
|
|
||||||
try:
|
|
||||||
async with self._read_lock:
|
|
||||||
header = await self._reader.readexactly(3)
|
|
||||||
if header[0] != 0x01:
|
|
||||||
raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
|
|
||||||
msg_size = (header[1] << 8) | header[2]
|
|
||||||
frame = await self._reader.readexactly(msg_size)
|
|
||||||
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
|
||||||
if (
|
|
||||||
isinstance(err, asyncio.IncompleteReadError)
|
|
||||||
and self._closed_event.is_set()
|
|
||||||
):
|
|
||||||
raise SocketClosedAPIError(
|
|
||||||
f"Socket closed while reading data: {err}"
|
|
||||||
) from err
|
|
||||||
raise SocketAPIError(f"Error while reading data: {err}") from err
|
|
||||||
|
|
||||||
_LOGGER.debug("Received frame %s", frame.hex())
|
|
||||||
return frame
|
|
||||||
|
|
||||||
async def perform_handshake(self) -> None:
|
|
||||||
if self._params.noise_psk is None:
|
|
||||||
return
|
|
||||||
await self._write_frame_noise(b"") # ClientHello
|
|
||||||
prologue = b"NoiseAPIInit" + b"\x00\x00"
|
|
||||||
server_hello = await self._read_frame_noise() # ServerHello
|
|
||||||
if not server_hello:
|
|
||||||
raise HandshakeAPIError("ServerHello is empty")
|
|
||||||
chosen_proto = server_hello[0]
|
|
||||||
if chosen_proto != 0x01:
|
|
||||||
raise HandshakeAPIError(
|
|
||||||
f"Unknown protocol selected by client {chosen_proto}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
|
|
||||||
self._proto.set_as_initiator()
|
|
||||||
|
|
||||||
try:
|
|
||||||
noise_psk_bytes = base64.b64decode(self._params.noise_psk)
|
|
||||||
except ValueError:
|
|
||||||
raise InvalidEncryptionKeyAPIError(
|
|
||||||
f"Malformed PSK {self._params.noise_psk}, expected base64-encoded value"
|
|
||||||
)
|
|
||||||
if len(noise_psk_bytes) != 32:
|
|
||||||
raise InvalidEncryptionKeyAPIError(
|
|
||||||
f"Malformed PSK {self._params.noise_psk}, expected 32-bytes of base64 data"
|
|
||||||
)
|
|
||||||
|
|
||||||
self._proto.set_psks(noise_psk_bytes)
|
|
||||||
self._proto.set_prologue(prologue)
|
|
||||||
self._proto.start_handshake()
|
|
||||||
|
|
||||||
_LOGGER.debug("Starting handshake...")
|
|
||||||
do_write = True
|
|
||||||
while not self._proto.handshake_finished:
|
|
||||||
if do_write:
|
|
||||||
msg = self._proto.write_message()
|
|
||||||
await self._write_frame_noise(b"\x00" + msg)
|
|
||||||
else:
|
|
||||||
msg = await self._read_frame_noise()
|
|
||||||
if not msg:
|
|
||||||
raise HandshakeAPIError("Handshake message too short")
|
|
||||||
if msg[0] != 0:
|
|
||||||
explanation = msg[1:].decode()
|
|
||||||
if explanation == "Handshake MAC failure":
|
|
||||||
raise InvalidEncryptionKeyAPIError("Invalid encryption key")
|
|
||||||
raise HandshakeAPIError(f"Handshake failure: {explanation}")
|
|
||||||
self._proto.read_message(msg[1:])
|
|
||||||
|
|
||||||
do_write = not do_write
|
|
||||||
|
|
||||||
_LOGGER.debug("Handshake complete!")
|
|
||||||
self._ready_event.set()
|
|
||||||
|
|
||||||
async def _write_packet_noise(self, packet: Packet) -> None:
|
|
||||||
await self._ready_event.wait()
|
|
||||||
padding = 0
|
|
||||||
data = (
|
|
||||||
bytes(
|
|
||||||
[
|
|
||||||
(packet.type >> 8) & 0xFF,
|
|
||||||
(packet.type >> 0) & 0xFF,
|
|
||||||
(len(packet.data) >> 8) & 0xFF,
|
|
||||||
(len(packet.data) >> 0) & 0xFF,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
+ packet.data
|
|
||||||
+ b"\x00" * padding
|
|
||||||
)
|
|
||||||
assert self._proto is not None
|
|
||||||
frame = self._proto.encrypt(data)
|
|
||||||
await self._write_frame_noise(frame)
|
|
||||||
|
|
||||||
async def _write_packet_plaintext(self, packet: Packet) -> None:
|
|
||||||
data = b"\0"
|
|
||||||
data += varuint_to_bytes(len(packet.data))
|
|
||||||
data += varuint_to_bytes(packet.type)
|
|
||||||
data += packet.data
|
|
||||||
try:
|
|
||||||
async with self._write_lock:
|
|
||||||
_LOGGER.debug("Sending frame %s", data.hex())
|
|
||||||
self._writer.write(data)
|
|
||||||
await self._writer.drain()
|
|
||||||
except OSError as err:
|
|
||||||
raise SocketAPIError(f"Error while writing data: {err}") from err
|
|
||||||
|
|
||||||
async def write_packet(self, packet: Packet) -> None:
|
|
||||||
if self._params.noise_psk is None:
|
|
||||||
await self._write_packet_plaintext(packet)
|
|
||||||
else:
|
|
||||||
await self._write_packet_noise(packet)
|
|
||||||
|
|
||||||
async def _read_packet_noise(self) -> Packet:
|
|
||||||
await self._ready_event.wait()
|
|
||||||
frame = await self._read_frame_noise()
|
|
||||||
assert self._proto is not None
|
|
||||||
msg = self._proto.decrypt(frame)
|
|
||||||
if len(msg) < 4:
|
|
||||||
raise ProtocolAPIError(f"Bad packet frame: {msg}")
|
|
||||||
pkt_type = (msg[0] << 8) | msg[1]
|
|
||||||
data_len = (msg[2] << 8) | msg[3]
|
|
||||||
if data_len + 4 > len(msg):
|
|
||||||
raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
|
|
||||||
data = msg[4 : 4 + data_len]
|
|
||||||
return Packet(type=pkt_type, data=data)
|
|
||||||
|
|
||||||
async def _read_packet_plaintext(self) -> Packet:
|
|
||||||
async with self._read_lock:
|
|
||||||
try:
|
|
||||||
preamble = await self._reader.readexactly(1)
|
|
||||||
if preamble[0] != 0x00:
|
|
||||||
if preamble[0] == 0x01:
|
|
||||||
raise RequiresEncryptionAPIError(
|
|
||||||
"Connection requires encryption"
|
|
||||||
)
|
|
||||||
raise ProtocolAPIError(f"Invalid preamble {preamble[0]:02x}")
|
|
||||||
|
|
||||||
length = b""
|
|
||||||
while not length or (length[-1] & 0x80) == 0x80:
|
|
||||||
length += await self._reader.readexactly(1)
|
|
||||||
length_int = bytes_to_varuint(length)
|
|
||||||
assert length_int is not None
|
|
||||||
msg_type = b""
|
|
||||||
while not msg_type or (msg_type[-1] & 0x80) == 0x80:
|
|
||||||
msg_type += await self._reader.readexactly(1)
|
|
||||||
msg_type_int = bytes_to_varuint(msg_type)
|
|
||||||
assert msg_type_int is not None
|
|
||||||
|
|
||||||
raw_msg = b""
|
|
||||||
if length_int != 0:
|
|
||||||
raw_msg = await self._reader.readexactly(length_int)
|
|
||||||
return Packet(type=msg_type_int, data=raw_msg)
|
|
||||||
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
|
||||||
if (
|
|
||||||
isinstance(err, asyncio.IncompleteReadError)
|
|
||||||
and self._closed_event.is_set()
|
|
||||||
):
|
|
||||||
raise SocketClosedAPIError(
|
|
||||||
f"Socket closed while reading data: {err}"
|
|
||||||
) from err
|
|
||||||
raise SocketAPIError(f"Error while reading data: {err}") from err
|
|
||||||
|
|
||||||
async def read_packet(self) -> Packet:
|
|
||||||
if self._params.noise_psk is None:
|
|
||||||
return await self._read_packet_plaintext()
|
|
||||||
return await self._read_packet_noise()
|
|
||||||
|
|
||||||
|
|
||||||
class ConnectionState(enum.Enum):
|
class ConnectionState(enum.Enum):
|
||||||
# The connection is initialized, but connect() wasn't called yet
|
# The connection is initialized, but connect() wasn't called yet
|
||||||
INITIALIZED = 0
|
INITIALIZED = 0
|
||||||
@ -381,8 +170,13 @@ class APIConnection:
|
|||||||
"""Step 3 in connect process: initialize the frame helper and init read loop."""
|
"""Step 3 in connect process: initialize the frame helper and init read loop."""
|
||||||
reader, writer = await asyncio.open_connection(sock=self._socket)
|
reader, writer = await asyncio.open_connection(sock=self._socket)
|
||||||
|
|
||||||
self._frame_helper = APIFrameHelper(reader, writer, self._params)
|
if self._params.noise_psk is None:
|
||||||
await self._frame_helper.perform_handshake()
|
self._frame_helper = APIPlaintextFrameHelper(reader, writer)
|
||||||
|
else:
|
||||||
|
fh = self._frame_helper = APINoiseFrameHelper(
|
||||||
|
reader, writer, self._params.noise_psk
|
||||||
|
)
|
||||||
|
await fh.perform_handshake()
|
||||||
|
|
||||||
self._connection_state = ConnectionState.SOCKET_OPENED
|
self._connection_state = ConnectionState.SOCKET_OPENED
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user