2021-10-13 10:05:08 +02:00
|
|
|
import asyncio
|
|
|
|
import base64
|
|
|
|
import logging
|
2022-12-02 20:12:19 +01:00
|
|
|
from abc import ABC, abstractmethod, abstractproperty
|
2021-10-13 10:05:08 +02:00
|
|
|
from dataclasses import dataclass
|
|
|
|
from typing import Optional
|
|
|
|
|
2022-08-22 05:26:53 +02:00
|
|
|
import async_timeout
|
2021-10-13 10:05:08 +02:00
|
|
|
from noise.connection import NoiseConnection # type: ignore
|
|
|
|
|
|
|
|
from .core import (
|
2022-01-20 12:03:36 +01:00
|
|
|
BadNameAPIError,
|
2021-10-13 10:05:08 +02:00
|
|
|
HandshakeAPIError,
|
|
|
|
InvalidEncryptionKeyAPIError,
|
|
|
|
ProtocolAPIError,
|
|
|
|
RequiresEncryptionAPIError,
|
|
|
|
SocketAPIError,
|
|
|
|
SocketClosedAPIError,
|
|
|
|
)
|
|
|
|
from .util import bytes_to_varuint, varuint_to_bytes
|
|
|
|
|
|
|
|
# Logger shared by every frame helper in this module.
_LOGGER: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Packet:
    """A single API message as exchanged by the frame helpers.

    ``type`` carries the numeric message type and ``data`` the raw payload
    bytes (an empty payload is represented as ``b""``).
    """

    # Numeric message type identifier.
    type: int
    # Raw message payload bytes.
    data: bytes
|
|
|
|
|
|
|
|
|
|
|
|
class APIFrameHelper(ABC):
    """Helper class to handle the API frame protocol.

    Subclasses implement one framing flavour (plaintext or Noise-encrypted)
    on top of an asyncio reader/writer stream pair.
    """

    def __init__(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ) -> None:
        """Initialize the API frame helper.

        ``reader`` is the stream frames are read from and ``writer`` the
        stream frames are written to.
        """
        self._reader = reader
        self._writer = writer
        # Callers of read_packet_with_lock() must hold this lock.
        self.read_lock = asyncio.Lock()
        # Set when close() is called so read errors caused by our own close
        # can be distinguished from the peer dropping the connection.
        self._closed_event = asyncio.Event()

    # ``abc.abstractproperty`` is deprecated since Python 3.3; stacking
    # ``property`` on top of ``abstractmethod`` is the documented
    # replacement and is enforced identically by ``ABC``.
    @property
    @abstractmethod
    def ready(self) -> bool:
        """Return if the connection is ready."""

    @abstractmethod
    async def close(self) -> None:
        """Close the connection."""

    @abstractmethod
    def write_packet(self, packet: Packet) -> None:
        """Write a packet to the socket."""

    @abstractmethod
    async def read_packet_with_lock(self) -> Packet:
        """Read a packet from the socket, the caller is responsible for having the lock."""

    @abstractmethod
    async def wait_for_ready(self) -> None:
        """Wait for the connection to be ready."""
|
2021-10-13 10:05:08 +02:00
|
|
|
|
|
|
|
|
|
|
|
class APIPlaintextFrameHelper(APIFrameHelper):
    """Frame helper for plaintext API connections.

    Each frame is a 0x00 preamble byte, the payload length encoded with
    ``varuint_to_bytes``, the message type encoded the same way, and then
    the payload bytes.
    """

    async def close(self) -> None:
        """Close the connection."""
        self._closed_event.set()
        self._writer.close()

    @property
    def ready(self) -> bool:
        """Return if the connection is ready."""
        # Plaintext is always ready
        return True

    def write_packet(self, packet: Packet) -> None:
        """Write a packet to the socket, the caller should not have the lock.

        The entire packet must be written in a single call to write
        to avoid locking.

        Raises SocketAPIError if the underlying write fails.
        """
        data = (
            b"\0"
            + varuint_to_bytes(len(packet.data))
            + varuint_to_bytes(packet.type)
            + packet.data
        )
        _LOGGER.debug("Sending plaintext frame %s", data.hex())

        try:
            self._writer.write(data)
        except OSError as err:  # also covers ConnectionResetError (a subclass)
            raise SocketAPIError(f"Error while writing data: {err}") from err

    async def wait_for_ready(self) -> None:
        """Wait for the connection to be ready."""
        # No handshake for plaintext

    async def read_packet_with_lock(self) -> Packet:
        """Read a packet from the socket, the caller is responsible for having the lock.

        Raises RequiresEncryptionAPIError when the first byte is 0x01 (the
        encrypted-frame marker), ProtocolAPIError for any other non-zero
        preamble, SocketClosedAPIError when the stream ends after close()
        was called, and SocketAPIError for other read failures.
        """
        assert self.read_lock.locked(), "read_packet_with_lock called without lock"
        try:
            # Read preamble, which should always be 0x00
            # Also try to get the length and msg type
            # to avoid multiple calls to readexactly
            init_bytes = await self._reader.readexactly(3)
            if init_bytes[0] != 0x00:
                if init_bytes[0] == 0x01:
                    raise RequiresEncryptionAPIError("Connection requires encryption")
                raise ProtocolAPIError(f"Invalid preamble {init_bytes[0]:02x}")

            if init_bytes[1] & 0x80 == 0x80:
                # Length is longer than 1 byte, so both remaining init bytes
                # belong to the length and the type has not been read yet.
                length = init_bytes[1:3]
                msg_type = b""
            else:
                # This is the most common case with 99% of messages
                # needing a single byte for length and type which means
                # we avoid 2 calls to readexactly
                length = init_bytes[1:2]
                msg_type = init_bytes[2:3]

            # If the message is long, we need to read the rest of the length
            while length[-1] & 0x80 == 0x80:
                length += await self._reader.readexactly(1)

            # Read the message type if it was not in the init bytes, and
            # finish it if its last byte read so far has the continuation bit
            while not msg_type or (msg_type[-1] & 0x80) == 0x80:
                msg_type += await self._reader.readexactly(1)

            length_int = bytes_to_varuint(length)
            assert length_int is not None
            msg_type_int = bytes_to_varuint(msg_type)
            assert msg_type_int is not None

            if length_int == 0:
                return Packet(type=msg_type_int, data=b"")

            data = await self._reader.readexactly(length_int)
            return Packet(type=msg_type_int, data=data)
        except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
            if (
                isinstance(err, asyncio.IncompleteReadError)
                and self._closed_event.is_set()
            ):
                # The stream ended because we closed it ourselves.
                raise SocketClosedAPIError(
                    f"Socket closed while reading data: {err}"
                ) from err
            raise SocketAPIError(f"Error while reading data: {err}") from err
|
2021-10-13 10:05:08 +02:00
|
|
|
|
|
|
|
|
|
|
|
def _decode_noise_psk(psk: str) -> bytes:
    """Decode the given noise psk from base64 format to raw bytes.

    ``base64.b64decode`` is called with its default (non-validating) mode,
    so characters outside the base64 alphabet are discarded before decoding.

    Raises InvalidEncryptionKeyAPIError if decoding fails or the decoded key
    is not exactly 32 bytes long.
    """
    try:
        psk_bytes = base64.b64decode(psk)
    except ValueError as err:  # binascii.Error is a ValueError subclass
        raise InvalidEncryptionKeyAPIError(
            f"Malformed PSK {psk}, expected base64-encoded value"
        ) from err
    if len(psk_bytes) != 32:
        raise InvalidEncryptionKeyAPIError(
            f"Malformed PSK {psk}, expected 32-bytes of base64 data"
        )
    return psk_bytes
|
|
|
|
|
|
|
|
|
|
|
|
class APINoiseFrameHelper(APIFrameHelper):
    """Frame helper for noise encrypted connections.

    Each frame on the wire is a 0x01 marker byte, a 16-bit big-endian
    length, and that many bytes of frame payload.
    """

    def __init__(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
        noise_psk: str,
    ) -> None:
        """Initialize the API frame helper.

        ``noise_psk`` is the base64-encoded pre-shared key; it is decoded
        only when the handshake runs.
        """
        super().__init__(reader, writer)
        # Set when the handshake completes, or by close() to release waiters.
        self._ready_event = asyncio.Event()
        # Noise session state, created during the handshake.
        self._proto: Optional[NoiseConnection] = None
        self._noise_psk = noise_psk

    @property
    def ready(self) -> bool:
        """Return if the connection is ready."""
        return self._ready_event.is_set()

    async def close(self) -> None:
        """Close the connection."""
        # Make sure we set the ready event if it's not already set
        # so that we don't block forever on the ready event if we
        # are waiting for the handshake to complete.
        self._ready_event.set()
        self._closed_event.set()
        self._writer.close()

    def _write_frame(self, frame: bytes) -> None:
        """Write a frame to the socket, the caller should not have the lock.

        The header and the frame must be written in a single call to write
        to avoid locking.

        Raises SocketAPIError if the underlying write fails.
        """
        _LOGGER.debug("Sending frame %s", frame.hex())
        # NOTE(review): the length field is 16 bits; frames longer than
        # 0xFFFF bytes would get a truncated length header.
        header = bytes(
            [
                0x01,
                (len(frame) >> 8) & 0xFF,
                len(frame) & 0xFF,
            ]
        )
        try:
            self._writer.write(header + frame)
        except OSError as err:
            raise SocketAPIError(f"Error while writing data: {err}") from err

    async def _read_frame_with_lock(self) -> bytes:
        """Read a frame from the socket, the caller is responsible for having the lock.

        Raises ProtocolAPIError if the marker byte is not 0x01,
        SocketClosedAPIError if the stream ended after close() was called,
        and SocketAPIError for other read failures.
        """
        assert self.read_lock.locked(), "_read_frame_with_lock called without lock"
        try:
            header = await self._reader.readexactly(3)
            if header[0] != 0x01:
                raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
            msg_size = (header[1] << 8) | header[2]
            frame = await self._reader.readexactly(msg_size)
        except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
            if (
                isinstance(err, asyncio.IncompleteReadError)
                and self._closed_event.is_set()
            ):
                # The stream ended because we closed it ourselves.
                raise SocketClosedAPIError(
                    f"Socket closed while reading data: {err}"
                ) from err
            raise SocketAPIError(f"Error while reading data: {err}") from err

        _LOGGER.debug("Received frame %s", frame.hex())
        return frame

    async def _perform_handshake(self, expected_name: Optional[str]) -> None:
        """Perform the handshake with the server, the caller is responsible for having the lock.

        Sends ClientHello, validates ServerHello (protocol byte and, when
        present, the server name against ``expected_name``), then runs the
        Noise_NNpsk0 handshake. The ready event is set only on success.
        """
        assert self.read_lock.locked(), "_perform_handshake called without lock"
        self._write_frame(b"")  # ClientHello
        prologue = b"NoiseAPIInit" + b"\x00\x00"

        server_hello = await self._read_frame_with_lock()  # ServerHello
        if not server_hello:
            raise HandshakeAPIError("ServerHello is empty")

        # First byte of server hello is the protocol the server chose
        # for this session. Currently only 0x01 (Noise_NNpsk0_25519_ChaChaPoly_SHA256)
        # exists.
        chosen_proto = server_hello[0]
        if chosen_proto != 0x01:
            raise HandshakeAPIError(
                f"Unknown protocol selected by server {chosen_proto}"
            )

        # Check name matches expected name (for noise sessions, this is done
        # during hello phase before a connection is set up)
        # Server name is encoded as a string followed by a zero byte after the chosen proto byte
        server_name_i = server_hello.find(b"\0", 1)
        if server_name_i != -1:
            # server name found, this extension was added in 2022.2
            server_name = server_hello[1:server_name_i].decode()
            if expected_name is not None and expected_name != server_name:
                raise BadNameAPIError(
                    f"Server sent a different name '{server_name}'", server_name
                )

        self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
        self._proto.set_as_initiator()
        self._proto.set_psks(_decode_noise_psk(self._noise_psk))
        self._proto.set_prologue(prologue)
        self._proto.start_handshake()

        _LOGGER.debug("Starting handshake...")
        # Alternate between sending our handshake message and reading the
        # server's until the Noise state machine reports completion.
        do_write = True
        while not self._proto.handshake_finished:
            if do_write:
                msg = self._proto.write_message()
                self._write_frame(b"\x00" + msg)
            else:
                msg = await self._read_frame_with_lock()
                if not msg:
                    raise HandshakeAPIError("Handshake message too short")
                if msg[0] != 0:
                    # A non-zero first byte means the rest is an error
                    # explanation. Decode leniently so a malformed message
                    # still surfaces as a HandshakeAPIError rather than a
                    # UnicodeDecodeError.
                    explanation = msg[1:].decode(errors="replace")
                    if explanation == "Handshake MAC failure":
                        raise InvalidEncryptionKeyAPIError("Invalid encryption key")
                    raise HandshakeAPIError(f"Handshake failure: {explanation}")
                self._proto.read_message(msg[1:])

            do_write = not do_write

        _LOGGER.debug("Handshake complete!")
        self._ready_event.set()

    async def perform_handshake(self, expected_name: Optional[str]) -> None:
        """Perform the handshake with the server.

        Acquires ``read_lock`` first; the 60 second timeout starts once the
        lock is held. Raises HandshakeAPIError on timeout.
        """
        # Allow up to 60 seconds for handshake
        try:
            async with self.read_lock, async_timeout.timeout(60.0):
                await self._perform_handshake(expected_name)
        except asyncio.TimeoutError as err:
            raise HandshakeAPIError("Timeout during handshake") from err

    def write_packet(self, packet: Packet) -> None:
        """Write a packet to the socket.

        The packet is serialised as a 16-bit big-endian type, a 16-bit
        big-endian payload length, and the payload. It is then encrypted
        with the Noise session and written as a single frame.
        """
        padding = 0
        # NOTE(review): type and length are masked to 16 bits; larger values
        # would be silently truncated on the wire.
        data = (
            bytes(
                [
                    (packet.type >> 8) & 0xFF,
                    (packet.type >> 0) & 0xFF,
                    (len(packet.data) >> 8) & 0xFF,
                    (len(packet.data) >> 0) & 0xFF,
                ]
            )
            + packet.data
            + b"\x00" * padding
        )
        # The Noise session only exists after the handshake has created it.
        assert self._proto is not None
        frame = self._proto.encrypt(data)
        self._write_frame(frame)

    async def wait_for_ready(self) -> None:
        """Wait for the connection to be ready."""
        await self._ready_event.wait()

    async def read_packet_with_lock(self) -> Packet:
        """Read a packet from the socket, the caller is responsible for having the lock.

        Decrypts one frame and splits it into a 16-bit big-endian type, a
        16-bit big-endian data length, and the data. Any bytes beyond the
        declared length are ignored. Raises ProtocolAPIError if the message
        is shorter than the 4-byte header or than the declared data length.
        """
        frame = await self._read_frame_with_lock()
        assert self._proto is not None
        msg = self._proto.decrypt(frame)
        if len(msg) < 4:
            raise ProtocolAPIError(f"Bad packet frame: {msg}")
        pkt_type = (msg[0] << 8) | msg[1]
        data_len = (msg[2] << 8) | msg[3]
        if data_len + 4 > len(msg):
            raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
        data = msg[4 : 4 + data_len]
        return Packet(type=pkt_type, data=data)
|