Emit different Exception types to differentiate between connection errors (#102)

* Emit different Exception types to differentiate between connection errors

* Import in init
This commit is contained in:
Otto Winter 2021-09-14 12:44:52 +02:00 committed by GitHub
parent 0660f1cd05
commit 5c9e7acbce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 113 additions and 42 deletions

View File

@ -1,6 +1,16 @@
# flake8: noqa
from .client import APIClient
from .connection import APIConnection, ConnectionParams
from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
from .core import (
MESSAGE_TYPE_TO_PROTO,
APIConnectionError,
HandshakeAPIError,
InvalidAuthAPIError,
InvalidEncryptionKeyAPIError,
ProtocolAPIError,
RequiresEncryptionAPIError,
ResolveAPIError,
SocketAPIError,
)
from .model import *
from .reconnect_logic import ReconnectLogic

View File

@ -139,7 +139,8 @@ class APIClient:
client_info=client_info,
keepalive=keepalive,
zeroconf_instance=zeroconf_instance,
noise_psk=noise_psk,
# treat empty psk string as missing (like password)
noise_psk=noise_psk or None,
)
self._connection: Optional[APIConnection] = None
self._cached_name: Optional[str] = None

View File

@ -3,6 +3,7 @@ import base64
import logging
import socket
import time
from contextlib import suppress
from dataclasses import astuple, dataclass
from typing import Any, Awaitable, Callable, List, Optional
@ -23,7 +24,17 @@ from .api_pb2 import ( # type: ignore
PingRequest,
PingResponse,
)
from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
from .core import (
MESSAGE_TYPE_TO_PROTO,
APIConnectionError,
HandshakeAPIError,
InvalidAuthAPIError,
InvalidEncryptionKeyAPIError,
ProtocolAPIError,
RequiresEncryptionAPIError,
ResolveAPIError,
SocketAPIError,
)
from .model import APIVersion
from .util import bytes_to_varuint, varuint_to_bytes
@ -41,12 +52,6 @@ class ConnectionParams:
zeroconf_instance: hr.ZeroconfInstanceType
noise_psk: Optional[str]
@property
def noise_psk_bytes(self) -> Optional[bytes]:
if self.noise_psk is None:
return None
return base64.b64decode(self.noise_psk)
@dataclass
class Packet:
@ -87,18 +92,18 @@ class APIFrameHelper:
self._writer.write(header + frame)
await self._writer.drain()
except OSError as err:
raise APIConnectionError(f"Error while writing data: {err}") from err
raise SocketAPIError(f"Error while writing data: {err}") from err
async def _read_frame_noise(self) -> bytes:
try:
async with self._read_lock:
header = await self._reader.readexactly(3)
if header[0] != 0x01:
raise APIConnectionError(f"Marker byte invalid: {header[0]}")
raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
msg_size = (header[1] << 8) | header[2]
frame = await self._reader.readexactly(msg_size)
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
raise APIConnectionError(f"Error while reading data: {err}") from err
raise SocketAPIError(f"Error while reading data: {err}") from err
_LOGGER.debug("Received frame %s", frame.hex())
return frame
@ -110,16 +115,28 @@ class APIFrameHelper:
prologue = b"NoiseAPIInit" + b"\x00\x00"
server_hello = await self._read_frame_noise() # ServerHello
if not server_hello:
raise APIConnectionError("ServerHello is empty")
raise HandshakeAPIError("ServerHello is empty")
chosen_proto = server_hello[0]
if chosen_proto != 0x01:
raise APIConnectionError(
raise HandshakeAPIError(
f"Unknown protocol selected by client {chosen_proto}"
)
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
self._proto.set_as_initiator()
self._proto.set_psks(self._params.noise_psk_bytes)
try:
noise_psk_bytes = base64.b64decode(self._params.noise_psk)
except ValueError:
raise InvalidEncryptionKeyAPIError(
f"Malformed PSK {self._params.noise_psk}, expected base64-encoded value"
)
if len(noise_psk_bytes) != 32:
raise InvalidEncryptionKeyAPIError(
f"Malformed PSK {self._params.noise_psk}, expected 32-bytes of base64 data"
)
self._proto.set_psks(noise_psk_bytes)
self._proto.set_prologue(prologue)
self._proto.start_handshake()
@ -131,8 +148,13 @@ class APIFrameHelper:
await self._write_frame_noise(b"\x00" + msg)
else:
msg = await self._read_frame_noise()
if not msg or msg[0] != 0:
raise APIConnectionError(f"Handshake failure: {msg[1:].decode()}")
if not msg:
raise HandshakeAPIError("Handshake message too short")
if msg[0] != 0:
explanation = msg[1:].decode()
if explanation == "Handshake MAC failure":
raise InvalidEncryptionKeyAPIError("Invalid encryption key")
raise HandshakeAPIError(f"Handshake failure: {explanation}")
self._proto.read_message(msg[1:])
do_write = not do_write
@ -170,7 +192,7 @@ class APIFrameHelper:
self._writer.write(data)
await self._writer.drain()
except OSError as err:
raise APIConnectionError(f"Error while writing data: {err}") from err
raise SocketAPIError(f"Error while writing data: {err}") from err
async def write_packet(self, packet: Packet) -> None:
if self._params.noise_psk is None:
@ -184,11 +206,11 @@ class APIFrameHelper:
assert self._proto is not None
msg = self._proto.decrypt(frame)
if len(msg) < 4:
raise APIConnectionError(f"Bad packet frame: {msg}")
raise ProtocolAPIError(f"Bad packet frame: {msg}")
pkt_type = (msg[0] << 8) | msg[1]
data_len = (msg[2] << 8) | msg[3]
if data_len + 4 > len(msg):
raise APIConnectionError(f"Bad data len: {data_len} vs {len(msg)}")
raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
data = msg[4 : 4 + data_len]
return Packet(type=pkt_type, data=data)
@ -196,7 +218,9 @@ class APIFrameHelper:
async with self._read_lock:
preamble = await self._reader.readexactly(1)
if preamble[0] != 0x00:
raise APIConnectionError("Invalid preamble")
if preamble[0] == 0x01:
raise RequiresEncryptionAPIError("Connection requires encryption")
raise ProtocolAPIError(f"Invalid preamble {preamble[0]:02x}")
length = b""
while not length or (length[-1] & 0x80) == 0x80:
@ -238,6 +262,7 @@ class APIConnection:
self._message_handlers: List[Callable[[message.Message], None]] = []
self.log_name = params.address
self._ping_task: Optional[asyncio.Task[None]] = None
self._read_exception_handlers: List[Callable[[Exception], None]] = []
def _start_ping(self) -> None:
async def func() -> None:
@ -305,7 +330,7 @@ class APIConnection:
raise err
except asyncio.TimeoutError:
await self._on_error()
raise APIConnectionError(
raise ResolveAPIError(
f"Timeout while resolving IP address for {self.log_name}"
)
@ -328,10 +353,10 @@ class APIConnection:
await asyncio.wait_for(coro2, 30.0)
except OSError as err:
await self._on_error()
raise APIConnectionError(f"Error connecting to {sockaddr}: {err}")
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}")
except asyncio.TimeoutError:
await self._on_error()
raise APIConnectionError(f"Timeout while connecting to {sockaddr}")
raise SocketAPIError(f"Timeout while connecting to {sockaddr}")
_LOGGER.debug("%s: Opened socket for", self._params.address)
reader, writer = await asyncio.open_connection(sock=self._socket)
@ -383,7 +408,7 @@ class APIConnection:
connect.password = self._params.password
resp = await self.send_message_await_response(connect, ConnectResponse)
if resp.invalid_password:
raise APIConnectionError("Invalid password!")
raise InvalidAuthAPIError("Invalid password!")
self._authenticated = True
@ -444,20 +469,25 @@ class APIConnection:
if do_stop(resp):
fut.set_result(responses)
def on_read_exception(exc: Exception) -> None:
if not fut.done():
fut.set_exception(exc)
self._message_handlers.append(on_message)
self._read_exception_handlers.append(on_read_exception)
await self.send_message(send_msg)
try:
await asyncio.wait_for(fut, timeout)
except asyncio.TimeoutError:
if self._stopped:
raise APIConnectionError("Disconnected while waiting for API response!")
raise APIConnectionError("Timeout while waiting for API response!")
try:
self._message_handlers.remove(on_message)
except ValueError:
pass
raise SocketAPIError("Disconnected while waiting for API response!")
raise SocketAPIError("Timeout while waiting for API response!")
finally:
with suppress(ValueError):
self._message_handlers.remove(on_message)
with suppress(ValueError):
self._read_exception_handlers.remove(on_read_exception)
return responses
@ -491,7 +521,7 @@ class APIConnection:
try:
msg.ParseFromString(raw_msg)
except Exception as e:
raise APIConnectionError("Invalid protobuf message: {}".format(e))
raise ProtocolAPIError(f"Invalid protobuf message: {e}") from e
_LOGGER.debug(
"%s: Got message of type %s: %s", self._params.address, type(msg), msg
)
@ -509,15 +539,19 @@ class APIConnection:
self.log_name,
err,
)
for handler in self._read_exception_handlers[:]:
handler(err)
await self._on_error()
break
except Exception as err: # pylint: disable=broad-except
_LOGGER.info(
_LOGGER.warning(
"%s: Unexpected error while reading incoming messages: %s",
self.log_name,
err,
exc_info=True,
)
for handler in self._read_exception_handlers[:]:
handler(err)
await self._on_error()
break

View File

@ -63,6 +63,34 @@ class APIConnectionError(Exception):
pass
class InvalidAuthAPIError(APIConnectionError):
pass
class ResolveAPIError(APIConnectionError):
pass
class ProtocolAPIError(APIConnectionError):
pass
class RequiresEncryptionAPIError(ProtocolAPIError):
pass
class SocketAPIError(APIConnectionError):
pass
class HandshakeAPIError(APIConnectionError):
pass
class InvalidEncryptionKeyAPIError(HandshakeAPIError):
pass
MESSAGE_TYPE_TO_PROTO = {
1: HelloRequest,
2: HelloResponse,

View File

@ -13,7 +13,7 @@ try:
except ImportError:
ZC_ASYNCIO = False
from .core import APIConnectionError
from .core import APIConnectionError, ResolveAPIError
ZeroconfInstanceType = Union[zeroconf.Zeroconf, "zeroconf.asyncio.AsyncZeroconf", None]
@ -56,7 +56,7 @@ def _sync_zeroconf_get_service_info(
try:
zc = zeroconf.Zeroconf()
except Exception:
raise APIConnectionError(
raise ResolveAPIError(
"Cannot start mDNS sockets, is this a docker container without "
"host network mode?"
)
@ -72,7 +72,7 @@ def _sync_zeroconf_get_service_info(
try:
info = zc.get_service_info(service_type, service_name, int(timeout * 1000))
except Exception as exc:
raise APIConnectionError(
raise ResolveAPIError(
f"Error resolving mDNS {service_name} via mDNS: {exc}"
) from exc
finally:
@ -105,7 +105,7 @@ async def _async_zeroconf_get_service_info(
try:
zc = zeroconf.asyncio.AsyncZeroconf()
except Exception:
raise APIConnectionError(
raise ResolveAPIError(
"Cannot start mDNS sockets, is this a docker container without "
"host network mode?"
)
@ -126,7 +126,7 @@ async def _async_zeroconf_get_service_info(
service_type, service_name, int(timeout * 1000)
)
except Exception as exc:
raise APIConnectionError(
raise ResolveAPIError(
f"Error resolving mDNS {service_name} via mDNS: {exc}"
) from exc
finally:
@ -240,9 +240,7 @@ async def async_resolve_host(
if zc_error:
# Only show ZC error if getaddrinfo also didn't work
raise zc_error
raise APIConnectionError(
f"Could not resolve host {host} - got no results from OS"
)
raise ResolveAPIError(f"Could not resolve host {host} - got no results from OS")
# Use first matching result
# Future: return all matches and use first working one