Emit different Exception types to differentiate between connection errors (#102)

* Emit different Exception types to differentiate between connection errors

* Import in init
This commit is contained in:
Otto Winter 2021-09-14 12:44:52 +02:00 committed by GitHub
parent 0660f1cd05
commit 5c9e7acbce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 113 additions and 42 deletions

View File

@ -1,6 +1,16 @@
# flake8: noqa # flake8: noqa
from .client import APIClient from .client import APIClient
from .connection import APIConnection, ConnectionParams from .connection import APIConnection, ConnectionParams
from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError from .core import (
MESSAGE_TYPE_TO_PROTO,
APIConnectionError,
HandshakeAPIError,
InvalidAuthAPIError,
InvalidEncryptionKeyAPIError,
ProtocolAPIError,
RequiresEncryptionAPIError,
ResolveAPIError,
SocketAPIError,
)
from .model import * from .model import *
from .reconnect_logic import ReconnectLogic from .reconnect_logic import ReconnectLogic

View File

@ -139,7 +139,8 @@ class APIClient:
client_info=client_info, client_info=client_info,
keepalive=keepalive, keepalive=keepalive,
zeroconf_instance=zeroconf_instance, zeroconf_instance=zeroconf_instance,
noise_psk=noise_psk, # treat empty psk string as missing (like password)
noise_psk=noise_psk or None,
) )
self._connection: Optional[APIConnection] = None self._connection: Optional[APIConnection] = None
self._cached_name: Optional[str] = None self._cached_name: Optional[str] = None

View File

@ -3,6 +3,7 @@ import base64
import logging import logging
import socket import socket
import time import time
from contextlib import suppress
from dataclasses import astuple, dataclass from dataclasses import astuple, dataclass
from typing import Any, Awaitable, Callable, List, Optional from typing import Any, Awaitable, Callable, List, Optional
@ -23,7 +24,17 @@ from .api_pb2 import ( # type: ignore
PingRequest, PingRequest,
PingResponse, PingResponse,
) )
from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError from .core import (
MESSAGE_TYPE_TO_PROTO,
APIConnectionError,
HandshakeAPIError,
InvalidAuthAPIError,
InvalidEncryptionKeyAPIError,
ProtocolAPIError,
RequiresEncryptionAPIError,
ResolveAPIError,
SocketAPIError,
)
from .model import APIVersion from .model import APIVersion
from .util import bytes_to_varuint, varuint_to_bytes from .util import bytes_to_varuint, varuint_to_bytes
@ -41,12 +52,6 @@ class ConnectionParams:
zeroconf_instance: hr.ZeroconfInstanceType zeroconf_instance: hr.ZeroconfInstanceType
noise_psk: Optional[str] noise_psk: Optional[str]
@property
def noise_psk_bytes(self) -> Optional[bytes]:
if self.noise_psk is None:
return None
return base64.b64decode(self.noise_psk)
@dataclass @dataclass
class Packet: class Packet:
@ -87,18 +92,18 @@ class APIFrameHelper:
self._writer.write(header + frame) self._writer.write(header + frame)
await self._writer.drain() await self._writer.drain()
except OSError as err: except OSError as err:
raise APIConnectionError(f"Error while writing data: {err}") from err raise SocketAPIError(f"Error while writing data: {err}") from err
async def _read_frame_noise(self) -> bytes: async def _read_frame_noise(self) -> bytes:
try: try:
async with self._read_lock: async with self._read_lock:
header = await self._reader.readexactly(3) header = await self._reader.readexactly(3)
if header[0] != 0x01: if header[0] != 0x01:
raise APIConnectionError(f"Marker byte invalid: {header[0]}") raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
msg_size = (header[1] << 8) | header[2] msg_size = (header[1] << 8) | header[2]
frame = await self._reader.readexactly(msg_size) frame = await self._reader.readexactly(msg_size)
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err: except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
raise APIConnectionError(f"Error while reading data: {err}") from err raise SocketAPIError(f"Error while reading data: {err}") from err
_LOGGER.debug("Received frame %s", frame.hex()) _LOGGER.debug("Received frame %s", frame.hex())
return frame return frame
@ -110,16 +115,28 @@ class APIFrameHelper:
prologue = b"NoiseAPIInit" + b"\x00\x00" prologue = b"NoiseAPIInit" + b"\x00\x00"
server_hello = await self._read_frame_noise() # ServerHello server_hello = await self._read_frame_noise() # ServerHello
if not server_hello: if not server_hello:
raise APIConnectionError("ServerHello is empty") raise HandshakeAPIError("ServerHello is empty")
chosen_proto = server_hello[0] chosen_proto = server_hello[0]
if chosen_proto != 0x01: if chosen_proto != 0x01:
raise APIConnectionError( raise HandshakeAPIError(
f"Unknown protocol selected by client {chosen_proto}" f"Unknown protocol selected by client {chosen_proto}"
) )
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256") self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
self._proto.set_as_initiator() self._proto.set_as_initiator()
self._proto.set_psks(self._params.noise_psk_bytes)
try:
noise_psk_bytes = base64.b64decode(self._params.noise_psk)
except ValueError:
raise InvalidEncryptionKeyAPIError(
f"Malformed PSK {self._params.noise_psk}, expected base64-encoded value"
)
if len(noise_psk_bytes) != 32:
raise InvalidEncryptionKeyAPIError(
f"Malformed PSK {self._params.noise_psk}, expected 32-bytes of base64 data"
)
self._proto.set_psks(noise_psk_bytes)
self._proto.set_prologue(prologue) self._proto.set_prologue(prologue)
self._proto.start_handshake() self._proto.start_handshake()
@ -131,8 +148,13 @@ class APIFrameHelper:
await self._write_frame_noise(b"\x00" + msg) await self._write_frame_noise(b"\x00" + msg)
else: else:
msg = await self._read_frame_noise() msg = await self._read_frame_noise()
if not msg or msg[0] != 0: if not msg:
raise APIConnectionError(f"Handshake failure: {msg[1:].decode()}") raise HandshakeAPIError("Handshake message too short")
if msg[0] != 0:
explanation = msg[1:].decode()
if explanation == "Handshake MAC failure":
raise InvalidEncryptionKeyAPIError("Invalid encryption key")
raise HandshakeAPIError(f"Handshake failure: {explanation}")
self._proto.read_message(msg[1:]) self._proto.read_message(msg[1:])
do_write = not do_write do_write = not do_write
@ -170,7 +192,7 @@ class APIFrameHelper:
self._writer.write(data) self._writer.write(data)
await self._writer.drain() await self._writer.drain()
except OSError as err: except OSError as err:
raise APIConnectionError(f"Error while writing data: {err}") from err raise SocketAPIError(f"Error while writing data: {err}") from err
async def write_packet(self, packet: Packet) -> None: async def write_packet(self, packet: Packet) -> None:
if self._params.noise_psk is None: if self._params.noise_psk is None:
@ -184,11 +206,11 @@ class APIFrameHelper:
assert self._proto is not None assert self._proto is not None
msg = self._proto.decrypt(frame) msg = self._proto.decrypt(frame)
if len(msg) < 4: if len(msg) < 4:
raise APIConnectionError(f"Bad packet frame: {msg}") raise ProtocolAPIError(f"Bad packet frame: {msg}")
pkt_type = (msg[0] << 8) | msg[1] pkt_type = (msg[0] << 8) | msg[1]
data_len = (msg[2] << 8) | msg[3] data_len = (msg[2] << 8) | msg[3]
if data_len + 4 > len(msg): if data_len + 4 > len(msg):
raise APIConnectionError(f"Bad data len: {data_len} vs {len(msg)}") raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
data = msg[4 : 4 + data_len] data = msg[4 : 4 + data_len]
return Packet(type=pkt_type, data=data) return Packet(type=pkt_type, data=data)
@ -196,7 +218,9 @@ class APIFrameHelper:
async with self._read_lock: async with self._read_lock:
preamble = await self._reader.readexactly(1) preamble = await self._reader.readexactly(1)
if preamble[0] != 0x00: if preamble[0] != 0x00:
raise APIConnectionError("Invalid preamble") if preamble[0] == 0x01:
raise RequiresEncryptionAPIError("Connection requires encryption")
raise ProtocolAPIError(f"Invalid preamble {preamble[0]:02x}")
length = b"" length = b""
while not length or (length[-1] & 0x80) == 0x80: while not length or (length[-1] & 0x80) == 0x80:
@ -238,6 +262,7 @@ class APIConnection:
self._message_handlers: List[Callable[[message.Message], None]] = [] self._message_handlers: List[Callable[[message.Message], None]] = []
self.log_name = params.address self.log_name = params.address
self._ping_task: Optional[asyncio.Task[None]] = None self._ping_task: Optional[asyncio.Task[None]] = None
self._read_exception_handlers: List[Callable[[Exception], None]] = []
def _start_ping(self) -> None: def _start_ping(self) -> None:
async def func() -> None: async def func() -> None:
@ -305,7 +330,7 @@ class APIConnection:
raise err raise err
except asyncio.TimeoutError: except asyncio.TimeoutError:
await self._on_error() await self._on_error()
raise APIConnectionError( raise ResolveAPIError(
f"Timeout while resolving IP address for {self.log_name}" f"Timeout while resolving IP address for {self.log_name}"
) )
@ -328,10 +353,10 @@ class APIConnection:
await asyncio.wait_for(coro2, 30.0) await asyncio.wait_for(coro2, 30.0)
except OSError as err: except OSError as err:
await self._on_error() await self._on_error()
raise APIConnectionError(f"Error connecting to {sockaddr}: {err}") raise SocketAPIError(f"Error connecting to {sockaddr}: {err}")
except asyncio.TimeoutError: except asyncio.TimeoutError:
await self._on_error() await self._on_error()
raise APIConnectionError(f"Timeout while connecting to {sockaddr}") raise SocketAPIError(f"Timeout while connecting to {sockaddr}")
_LOGGER.debug("%s: Opened socket for", self._params.address) _LOGGER.debug("%s: Opened socket for", self._params.address)
reader, writer = await asyncio.open_connection(sock=self._socket) reader, writer = await asyncio.open_connection(sock=self._socket)
@ -383,7 +408,7 @@ class APIConnection:
connect.password = self._params.password connect.password = self._params.password
resp = await self.send_message_await_response(connect, ConnectResponse) resp = await self.send_message_await_response(connect, ConnectResponse)
if resp.invalid_password: if resp.invalid_password:
raise APIConnectionError("Invalid password!") raise InvalidAuthAPIError("Invalid password!")
self._authenticated = True self._authenticated = True
@ -444,20 +469,25 @@ class APIConnection:
if do_stop(resp): if do_stop(resp):
fut.set_result(responses) fut.set_result(responses)
def on_read_exception(exc: Exception) -> None:
if not fut.done():
fut.set_exception(exc)
self._message_handlers.append(on_message) self._message_handlers.append(on_message)
self._read_exception_handlers.append(on_read_exception)
await self.send_message(send_msg) await self.send_message(send_msg)
try: try:
await asyncio.wait_for(fut, timeout) await asyncio.wait_for(fut, timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
if self._stopped: if self._stopped:
raise APIConnectionError("Disconnected while waiting for API response!") raise SocketAPIError("Disconnected while waiting for API response!")
raise APIConnectionError("Timeout while waiting for API response!") raise SocketAPIError("Timeout while waiting for API response!")
finally:
try: with suppress(ValueError):
self._message_handlers.remove(on_message) self._message_handlers.remove(on_message)
except ValueError: with suppress(ValueError):
pass self._read_exception_handlers.remove(on_read_exception)
return responses return responses
@ -491,7 +521,7 @@ class APIConnection:
try: try:
msg.ParseFromString(raw_msg) msg.ParseFromString(raw_msg)
except Exception as e: except Exception as e:
raise APIConnectionError("Invalid protobuf message: {}".format(e)) raise ProtocolAPIError(f"Invalid protobuf message: {e}") from e
_LOGGER.debug( _LOGGER.debug(
"%s: Got message of type %s: %s", self._params.address, type(msg), msg "%s: Got message of type %s: %s", self._params.address, type(msg), msg
) )
@ -509,15 +539,19 @@ class APIConnection:
self.log_name, self.log_name,
err, err,
) )
for handler in self._read_exception_handlers[:]:
handler(err)
await self._on_error() await self._on_error()
break break
except Exception as err: # pylint: disable=broad-except except Exception as err: # pylint: disable=broad-except
_LOGGER.info( _LOGGER.warning(
"%s: Unexpected error while reading incoming messages: %s", "%s: Unexpected error while reading incoming messages: %s",
self.log_name, self.log_name,
err, err,
exc_info=True, exc_info=True,
) )
for handler in self._read_exception_handlers[:]:
handler(err)
await self._on_error() await self._on_error()
break break

View File

@ -63,6 +63,34 @@ class APIConnectionError(Exception):
pass pass
class InvalidAuthAPIError(APIConnectionError):
pass
class ResolveAPIError(APIConnectionError):
pass
class ProtocolAPIError(APIConnectionError):
pass
class RequiresEncryptionAPIError(ProtocolAPIError):
pass
class SocketAPIError(APIConnectionError):
pass
class HandshakeAPIError(APIConnectionError):
pass
class InvalidEncryptionKeyAPIError(HandshakeAPIError):
pass
MESSAGE_TYPE_TO_PROTO = { MESSAGE_TYPE_TO_PROTO = {
1: HelloRequest, 1: HelloRequest,
2: HelloResponse, 2: HelloResponse,

View File

@ -13,7 +13,7 @@ try:
except ImportError: except ImportError:
ZC_ASYNCIO = False ZC_ASYNCIO = False
from .core import APIConnectionError from .core import APIConnectionError, ResolveAPIError
ZeroconfInstanceType = Union[zeroconf.Zeroconf, "zeroconf.asyncio.AsyncZeroconf", None] ZeroconfInstanceType = Union[zeroconf.Zeroconf, "zeroconf.asyncio.AsyncZeroconf", None]
@ -56,7 +56,7 @@ def _sync_zeroconf_get_service_info(
try: try:
zc = zeroconf.Zeroconf() zc = zeroconf.Zeroconf()
except Exception: except Exception:
raise APIConnectionError( raise ResolveAPIError(
"Cannot start mDNS sockets, is this a docker container without " "Cannot start mDNS sockets, is this a docker container without "
"host network mode?" "host network mode?"
) )
@ -72,7 +72,7 @@ def _sync_zeroconf_get_service_info(
try: try:
info = zc.get_service_info(service_type, service_name, int(timeout * 1000)) info = zc.get_service_info(service_type, service_name, int(timeout * 1000))
except Exception as exc: except Exception as exc:
raise APIConnectionError( raise ResolveAPIError(
f"Error resolving mDNS {service_name} via mDNS: {exc}" f"Error resolving mDNS {service_name} via mDNS: {exc}"
) from exc ) from exc
finally: finally:
@ -105,7 +105,7 @@ async def _async_zeroconf_get_service_info(
try: try:
zc = zeroconf.asyncio.AsyncZeroconf() zc = zeroconf.asyncio.AsyncZeroconf()
except Exception: except Exception:
raise APIConnectionError( raise ResolveAPIError(
"Cannot start mDNS sockets, is this a docker container without " "Cannot start mDNS sockets, is this a docker container without "
"host network mode?" "host network mode?"
) )
@ -126,7 +126,7 @@ async def _async_zeroconf_get_service_info(
service_type, service_name, int(timeout * 1000) service_type, service_name, int(timeout * 1000)
) )
except Exception as exc: except Exception as exc:
raise APIConnectionError( raise ResolveAPIError(
f"Error resolving mDNS {service_name} via mDNS: {exc}" f"Error resolving mDNS {service_name} via mDNS: {exc}"
) from exc ) from exc
finally: finally:
@ -240,9 +240,7 @@ async def async_resolve_host(
if zc_error: if zc_error:
# Only show ZC error if getaddrinfo also didn't work # Only show ZC error if getaddrinfo also didn't work
raise zc_error raise zc_error
raise APIConnectionError( raise ResolveAPIError(f"Could not resolve host {host} - got no results from OS")
f"Could not resolve host {host} - got no results from OS"
)
# Use first matching result # Use first matching result
# Future: return all matches and use first working one # Future: return all matches and use first working one