mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-22 12:05:12 +01:00
Emit different Exception types to differentiate between connection errors (#102)
* Emit different Exception types to differentiate between connection errors * Import in init
This commit is contained in:
parent
0660f1cd05
commit
5c9e7acbce
@ -1,6 +1,16 @@
|
||||
# flake8: noqa
|
||||
from .client import APIClient
|
||||
from .connection import APIConnection, ConnectionParams
|
||||
from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
|
||||
from .core import (
|
||||
MESSAGE_TYPE_TO_PROTO,
|
||||
APIConnectionError,
|
||||
HandshakeAPIError,
|
||||
InvalidAuthAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
ProtocolAPIError,
|
||||
RequiresEncryptionAPIError,
|
||||
ResolveAPIError,
|
||||
SocketAPIError,
|
||||
)
|
||||
from .model import *
|
||||
from .reconnect_logic import ReconnectLogic
|
||||
|
@ -139,7 +139,8 @@ class APIClient:
|
||||
client_info=client_info,
|
||||
keepalive=keepalive,
|
||||
zeroconf_instance=zeroconf_instance,
|
||||
noise_psk=noise_psk,
|
||||
# treat empty psk string as missing (like password)
|
||||
noise_psk=noise_psk or None,
|
||||
)
|
||||
self._connection: Optional[APIConnection] = None
|
||||
self._cached_name: Optional[str] = None
|
||||
|
@ -3,6 +3,7 @@ import base64
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from contextlib import suppress
|
||||
from dataclasses import astuple, dataclass
|
||||
from typing import Any, Awaitable, Callable, List, Optional
|
||||
|
||||
@ -23,7 +24,17 @@ from .api_pb2 import ( # type: ignore
|
||||
PingRequest,
|
||||
PingResponse,
|
||||
)
|
||||
from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
|
||||
from .core import (
|
||||
MESSAGE_TYPE_TO_PROTO,
|
||||
APIConnectionError,
|
||||
HandshakeAPIError,
|
||||
InvalidAuthAPIError,
|
||||
InvalidEncryptionKeyAPIError,
|
||||
ProtocolAPIError,
|
||||
RequiresEncryptionAPIError,
|
||||
ResolveAPIError,
|
||||
SocketAPIError,
|
||||
)
|
||||
from .model import APIVersion
|
||||
from .util import bytes_to_varuint, varuint_to_bytes
|
||||
|
||||
@ -41,12 +52,6 @@ class ConnectionParams:
|
||||
zeroconf_instance: hr.ZeroconfInstanceType
|
||||
noise_psk: Optional[str]
|
||||
|
||||
@property
|
||||
def noise_psk_bytes(self) -> Optional[bytes]:
|
||||
if self.noise_psk is None:
|
||||
return None
|
||||
return base64.b64decode(self.noise_psk)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Packet:
|
||||
@ -87,18 +92,18 @@ class APIFrameHelper:
|
||||
self._writer.write(header + frame)
|
||||
await self._writer.drain()
|
||||
except OSError as err:
|
||||
raise APIConnectionError(f"Error while writing data: {err}") from err
|
||||
raise SocketAPIError(f"Error while writing data: {err}") from err
|
||||
|
||||
async def _read_frame_noise(self) -> bytes:
|
||||
try:
|
||||
async with self._read_lock:
|
||||
header = await self._reader.readexactly(3)
|
||||
if header[0] != 0x01:
|
||||
raise APIConnectionError(f"Marker byte invalid: {header[0]}")
|
||||
raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
|
||||
msg_size = (header[1] << 8) | header[2]
|
||||
frame = await self._reader.readexactly(msg_size)
|
||||
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
||||
raise APIConnectionError(f"Error while reading data: {err}") from err
|
||||
raise SocketAPIError(f"Error while reading data: {err}") from err
|
||||
|
||||
_LOGGER.debug("Received frame %s", frame.hex())
|
||||
return frame
|
||||
@ -110,16 +115,28 @@ class APIFrameHelper:
|
||||
prologue = b"NoiseAPIInit" + b"\x00\x00"
|
||||
server_hello = await self._read_frame_noise() # ServerHello
|
||||
if not server_hello:
|
||||
raise APIConnectionError("ServerHello is empty")
|
||||
raise HandshakeAPIError("ServerHello is empty")
|
||||
chosen_proto = server_hello[0]
|
||||
if chosen_proto != 0x01:
|
||||
raise APIConnectionError(
|
||||
raise HandshakeAPIError(
|
||||
f"Unknown protocol selected by client {chosen_proto}"
|
||||
)
|
||||
|
||||
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
|
||||
self._proto.set_as_initiator()
|
||||
self._proto.set_psks(self._params.noise_psk_bytes)
|
||||
|
||||
try:
|
||||
noise_psk_bytes = base64.b64decode(self._params.noise_psk)
|
||||
except ValueError:
|
||||
raise InvalidEncryptionKeyAPIError(
|
||||
f"Malformed PSK {self._params.noise_psk}, expected base64-encoded value"
|
||||
)
|
||||
if len(noise_psk_bytes) != 32:
|
||||
raise InvalidEncryptionKeyAPIError(
|
||||
f"Malformed PSK {self._params.noise_psk}, expected 32-bytes of base64 data"
|
||||
)
|
||||
|
||||
self._proto.set_psks(noise_psk_bytes)
|
||||
self._proto.set_prologue(prologue)
|
||||
self._proto.start_handshake()
|
||||
|
||||
@ -131,8 +148,13 @@ class APIFrameHelper:
|
||||
await self._write_frame_noise(b"\x00" + msg)
|
||||
else:
|
||||
msg = await self._read_frame_noise()
|
||||
if not msg or msg[0] != 0:
|
||||
raise APIConnectionError(f"Handshake failure: {msg[1:].decode()}")
|
||||
if not msg:
|
||||
raise HandshakeAPIError("Handshake message too short")
|
||||
if msg[0] != 0:
|
||||
explanation = msg[1:].decode()
|
||||
if explanation == "Handshake MAC failure":
|
||||
raise InvalidEncryptionKeyAPIError("Invalid encryption key")
|
||||
raise HandshakeAPIError(f"Handshake failure: {explanation}")
|
||||
self._proto.read_message(msg[1:])
|
||||
|
||||
do_write = not do_write
|
||||
@ -170,7 +192,7 @@ class APIFrameHelper:
|
||||
self._writer.write(data)
|
||||
await self._writer.drain()
|
||||
except OSError as err:
|
||||
raise APIConnectionError(f"Error while writing data: {err}") from err
|
||||
raise SocketAPIError(f"Error while writing data: {err}") from err
|
||||
|
||||
async def write_packet(self, packet: Packet) -> None:
|
||||
if self._params.noise_psk is None:
|
||||
@ -184,11 +206,11 @@ class APIFrameHelper:
|
||||
assert self._proto is not None
|
||||
msg = self._proto.decrypt(frame)
|
||||
if len(msg) < 4:
|
||||
raise APIConnectionError(f"Bad packet frame: {msg}")
|
||||
raise ProtocolAPIError(f"Bad packet frame: {msg}")
|
||||
pkt_type = (msg[0] << 8) | msg[1]
|
||||
data_len = (msg[2] << 8) | msg[3]
|
||||
if data_len + 4 > len(msg):
|
||||
raise APIConnectionError(f"Bad data len: {data_len} vs {len(msg)}")
|
||||
raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
|
||||
data = msg[4 : 4 + data_len]
|
||||
return Packet(type=pkt_type, data=data)
|
||||
|
||||
@ -196,7 +218,9 @@ class APIFrameHelper:
|
||||
async with self._read_lock:
|
||||
preamble = await self._reader.readexactly(1)
|
||||
if preamble[0] != 0x00:
|
||||
raise APIConnectionError("Invalid preamble")
|
||||
if preamble[0] == 0x01:
|
||||
raise RequiresEncryptionAPIError("Connection requires encryption")
|
||||
raise ProtocolAPIError(f"Invalid preamble {preamble[0]:02x}")
|
||||
|
||||
length = b""
|
||||
while not length or (length[-1] & 0x80) == 0x80:
|
||||
@ -238,6 +262,7 @@ class APIConnection:
|
||||
self._message_handlers: List[Callable[[message.Message], None]] = []
|
||||
self.log_name = params.address
|
||||
self._ping_task: Optional[asyncio.Task[None]] = None
|
||||
self._read_exception_handlers: List[Callable[[Exception], None]] = []
|
||||
|
||||
def _start_ping(self) -> None:
|
||||
async def func() -> None:
|
||||
@ -305,7 +330,7 @@ class APIConnection:
|
||||
raise err
|
||||
except asyncio.TimeoutError:
|
||||
await self._on_error()
|
||||
raise APIConnectionError(
|
||||
raise ResolveAPIError(
|
||||
f"Timeout while resolving IP address for {self.log_name}"
|
||||
)
|
||||
|
||||
@ -328,10 +353,10 @@ class APIConnection:
|
||||
await asyncio.wait_for(coro2, 30.0)
|
||||
except OSError as err:
|
||||
await self._on_error()
|
||||
raise APIConnectionError(f"Error connecting to {sockaddr}: {err}")
|
||||
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}")
|
||||
except asyncio.TimeoutError:
|
||||
await self._on_error()
|
||||
raise APIConnectionError(f"Timeout while connecting to {sockaddr}")
|
||||
raise SocketAPIError(f"Timeout while connecting to {sockaddr}")
|
||||
|
||||
_LOGGER.debug("%s: Opened socket for", self._params.address)
|
||||
reader, writer = await asyncio.open_connection(sock=self._socket)
|
||||
@ -383,7 +408,7 @@ class APIConnection:
|
||||
connect.password = self._params.password
|
||||
resp = await self.send_message_await_response(connect, ConnectResponse)
|
||||
if resp.invalid_password:
|
||||
raise APIConnectionError("Invalid password!")
|
||||
raise InvalidAuthAPIError("Invalid password!")
|
||||
|
||||
self._authenticated = True
|
||||
|
||||
@ -444,20 +469,25 @@ class APIConnection:
|
||||
if do_stop(resp):
|
||||
fut.set_result(responses)
|
||||
|
||||
def on_read_exception(exc: Exception) -> None:
|
||||
if not fut.done():
|
||||
fut.set_exception(exc)
|
||||
|
||||
self._message_handlers.append(on_message)
|
||||
self._read_exception_handlers.append(on_read_exception)
|
||||
await self.send_message(send_msg)
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(fut, timeout)
|
||||
except asyncio.TimeoutError:
|
||||
if self._stopped:
|
||||
raise APIConnectionError("Disconnected while waiting for API response!")
|
||||
raise APIConnectionError("Timeout while waiting for API response!")
|
||||
|
||||
try:
|
||||
self._message_handlers.remove(on_message)
|
||||
except ValueError:
|
||||
pass
|
||||
raise SocketAPIError("Disconnected while waiting for API response!")
|
||||
raise SocketAPIError("Timeout while waiting for API response!")
|
||||
finally:
|
||||
with suppress(ValueError):
|
||||
self._message_handlers.remove(on_message)
|
||||
with suppress(ValueError):
|
||||
self._read_exception_handlers.remove(on_read_exception)
|
||||
|
||||
return responses
|
||||
|
||||
@ -491,7 +521,7 @@ class APIConnection:
|
||||
try:
|
||||
msg.ParseFromString(raw_msg)
|
||||
except Exception as e:
|
||||
raise APIConnectionError("Invalid protobuf message: {}".format(e))
|
||||
raise ProtocolAPIError(f"Invalid protobuf message: {e}") from e
|
||||
_LOGGER.debug(
|
||||
"%s: Got message of type %s: %s", self._params.address, type(msg), msg
|
||||
)
|
||||
@ -509,15 +539,19 @@ class APIConnection:
|
||||
self.log_name,
|
||||
err,
|
||||
)
|
||||
for handler in self._read_exception_handlers[:]:
|
||||
handler(err)
|
||||
await self._on_error()
|
||||
break
|
||||
except Exception as err: # pylint: disable=broad-except
|
||||
_LOGGER.info(
|
||||
_LOGGER.warning(
|
||||
"%s: Unexpected error while reading incoming messages: %s",
|
||||
self.log_name,
|
||||
err,
|
||||
exc_info=True,
|
||||
)
|
||||
for handler in self._read_exception_handlers[:]:
|
||||
handler(err)
|
||||
await self._on_error()
|
||||
break
|
||||
|
||||
|
@ -63,6 +63,34 @@ class APIConnectionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidAuthAPIError(APIConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
class ResolveAPIError(APIConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
class ProtocolAPIError(APIConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
class RequiresEncryptionAPIError(ProtocolAPIError):
|
||||
pass
|
||||
|
||||
|
||||
class SocketAPIError(APIConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
class HandshakeAPIError(APIConnectionError):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidEncryptionKeyAPIError(HandshakeAPIError):
|
||||
pass
|
||||
|
||||
|
||||
MESSAGE_TYPE_TO_PROTO = {
|
||||
1: HelloRequest,
|
||||
2: HelloResponse,
|
||||
|
@ -13,7 +13,7 @@ try:
|
||||
except ImportError:
|
||||
ZC_ASYNCIO = False
|
||||
|
||||
from .core import APIConnectionError
|
||||
from .core import APIConnectionError, ResolveAPIError
|
||||
|
||||
ZeroconfInstanceType = Union[zeroconf.Zeroconf, "zeroconf.asyncio.AsyncZeroconf", None]
|
||||
|
||||
@ -56,7 +56,7 @@ def _sync_zeroconf_get_service_info(
|
||||
try:
|
||||
zc = zeroconf.Zeroconf()
|
||||
except Exception:
|
||||
raise APIConnectionError(
|
||||
raise ResolveAPIError(
|
||||
"Cannot start mDNS sockets, is this a docker container without "
|
||||
"host network mode?"
|
||||
)
|
||||
@ -72,7 +72,7 @@ def _sync_zeroconf_get_service_info(
|
||||
try:
|
||||
info = zc.get_service_info(service_type, service_name, int(timeout * 1000))
|
||||
except Exception as exc:
|
||||
raise APIConnectionError(
|
||||
raise ResolveAPIError(
|
||||
f"Error resolving mDNS {service_name} via mDNS: {exc}"
|
||||
) from exc
|
||||
finally:
|
||||
@ -105,7 +105,7 @@ async def _async_zeroconf_get_service_info(
|
||||
try:
|
||||
zc = zeroconf.asyncio.AsyncZeroconf()
|
||||
except Exception:
|
||||
raise APIConnectionError(
|
||||
raise ResolveAPIError(
|
||||
"Cannot start mDNS sockets, is this a docker container without "
|
||||
"host network mode?"
|
||||
)
|
||||
@ -126,7 +126,7 @@ async def _async_zeroconf_get_service_info(
|
||||
service_type, service_name, int(timeout * 1000)
|
||||
)
|
||||
except Exception as exc:
|
||||
raise APIConnectionError(
|
||||
raise ResolveAPIError(
|
||||
f"Error resolving mDNS {service_name} via mDNS: {exc}"
|
||||
) from exc
|
||||
finally:
|
||||
@ -240,9 +240,7 @@ async def async_resolve_host(
|
||||
if zc_error:
|
||||
# Only show ZC error if getaddrinfo also didn't work
|
||||
raise zc_error
|
||||
raise APIConnectionError(
|
||||
f"Could not resolve host {host} - got no results from OS"
|
||||
)
|
||||
raise ResolveAPIError(f"Could not resolve host {host} - got no results from OS")
|
||||
|
||||
# Use first matching result
|
||||
# Future: return all matches and use first working one
|
||||
|
Loading…
Reference in New Issue
Block a user