mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-26 12:45:26 +01:00
Emit different Exception types to differentiate between connection errors (#102)
* Emit different Exception types to differentiate between connection errors * Import in init
This commit is contained in:
parent
0660f1cd05
commit
5c9e7acbce
@ -1,6 +1,16 @@
|
|||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
from .client import APIClient
|
from .client import APIClient
|
||||||
from .connection import APIConnection, ConnectionParams
|
from .connection import APIConnection, ConnectionParams
|
||||||
from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
|
from .core import (
|
||||||
|
MESSAGE_TYPE_TO_PROTO,
|
||||||
|
APIConnectionError,
|
||||||
|
HandshakeAPIError,
|
||||||
|
InvalidAuthAPIError,
|
||||||
|
InvalidEncryptionKeyAPIError,
|
||||||
|
ProtocolAPIError,
|
||||||
|
RequiresEncryptionAPIError,
|
||||||
|
ResolveAPIError,
|
||||||
|
SocketAPIError,
|
||||||
|
)
|
||||||
from .model import *
|
from .model import *
|
||||||
from .reconnect_logic import ReconnectLogic
|
from .reconnect_logic import ReconnectLogic
|
||||||
|
@ -139,7 +139,8 @@ class APIClient:
|
|||||||
client_info=client_info,
|
client_info=client_info,
|
||||||
keepalive=keepalive,
|
keepalive=keepalive,
|
||||||
zeroconf_instance=zeroconf_instance,
|
zeroconf_instance=zeroconf_instance,
|
||||||
noise_psk=noise_psk,
|
# treat empty psk string as missing (like password)
|
||||||
|
noise_psk=noise_psk or None,
|
||||||
)
|
)
|
||||||
self._connection: Optional[APIConnection] = None
|
self._connection: Optional[APIConnection] = None
|
||||||
self._cached_name: Optional[str] = None
|
self._cached_name: Optional[str] = None
|
||||||
|
@ -3,6 +3,7 @@ import base64
|
|||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
|
from contextlib import suppress
|
||||||
from dataclasses import astuple, dataclass
|
from dataclasses import astuple, dataclass
|
||||||
from typing import Any, Awaitable, Callable, List, Optional
|
from typing import Any, Awaitable, Callable, List, Optional
|
||||||
|
|
||||||
@ -23,7 +24,17 @@ from .api_pb2 import ( # type: ignore
|
|||||||
PingRequest,
|
PingRequest,
|
||||||
PingResponse,
|
PingResponse,
|
||||||
)
|
)
|
||||||
from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
|
from .core import (
|
||||||
|
MESSAGE_TYPE_TO_PROTO,
|
||||||
|
APIConnectionError,
|
||||||
|
HandshakeAPIError,
|
||||||
|
InvalidAuthAPIError,
|
||||||
|
InvalidEncryptionKeyAPIError,
|
||||||
|
ProtocolAPIError,
|
||||||
|
RequiresEncryptionAPIError,
|
||||||
|
ResolveAPIError,
|
||||||
|
SocketAPIError,
|
||||||
|
)
|
||||||
from .model import APIVersion
|
from .model import APIVersion
|
||||||
from .util import bytes_to_varuint, varuint_to_bytes
|
from .util import bytes_to_varuint, varuint_to_bytes
|
||||||
|
|
||||||
@ -41,12 +52,6 @@ class ConnectionParams:
|
|||||||
zeroconf_instance: hr.ZeroconfInstanceType
|
zeroconf_instance: hr.ZeroconfInstanceType
|
||||||
noise_psk: Optional[str]
|
noise_psk: Optional[str]
|
||||||
|
|
||||||
@property
|
|
||||||
def noise_psk_bytes(self) -> Optional[bytes]:
|
|
||||||
if self.noise_psk is None:
|
|
||||||
return None
|
|
||||||
return base64.b64decode(self.noise_psk)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Packet:
|
class Packet:
|
||||||
@ -87,18 +92,18 @@ class APIFrameHelper:
|
|||||||
self._writer.write(header + frame)
|
self._writer.write(header + frame)
|
||||||
await self._writer.drain()
|
await self._writer.drain()
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
raise APIConnectionError(f"Error while writing data: {err}") from err
|
raise SocketAPIError(f"Error while writing data: {err}") from err
|
||||||
|
|
||||||
async def _read_frame_noise(self) -> bytes:
|
async def _read_frame_noise(self) -> bytes:
|
||||||
try:
|
try:
|
||||||
async with self._read_lock:
|
async with self._read_lock:
|
||||||
header = await self._reader.readexactly(3)
|
header = await self._reader.readexactly(3)
|
||||||
if header[0] != 0x01:
|
if header[0] != 0x01:
|
||||||
raise APIConnectionError(f"Marker byte invalid: {header[0]}")
|
raise ProtocolAPIError(f"Marker byte invalid: {header[0]}")
|
||||||
msg_size = (header[1] << 8) | header[2]
|
msg_size = (header[1] << 8) | header[2]
|
||||||
frame = await self._reader.readexactly(msg_size)
|
frame = await self._reader.readexactly(msg_size)
|
||||||
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
except (asyncio.IncompleteReadError, OSError, TimeoutError) as err:
|
||||||
raise APIConnectionError(f"Error while reading data: {err}") from err
|
raise SocketAPIError(f"Error while reading data: {err}") from err
|
||||||
|
|
||||||
_LOGGER.debug("Received frame %s", frame.hex())
|
_LOGGER.debug("Received frame %s", frame.hex())
|
||||||
return frame
|
return frame
|
||||||
@ -110,16 +115,28 @@ class APIFrameHelper:
|
|||||||
prologue = b"NoiseAPIInit" + b"\x00\x00"
|
prologue = b"NoiseAPIInit" + b"\x00\x00"
|
||||||
server_hello = await self._read_frame_noise() # ServerHello
|
server_hello = await self._read_frame_noise() # ServerHello
|
||||||
if not server_hello:
|
if not server_hello:
|
||||||
raise APIConnectionError("ServerHello is empty")
|
raise HandshakeAPIError("ServerHello is empty")
|
||||||
chosen_proto = server_hello[0]
|
chosen_proto = server_hello[0]
|
||||||
if chosen_proto != 0x01:
|
if chosen_proto != 0x01:
|
||||||
raise APIConnectionError(
|
raise HandshakeAPIError(
|
||||||
f"Unknown protocol selected by client {chosen_proto}"
|
f"Unknown protocol selected by client {chosen_proto}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
|
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
|
||||||
self._proto.set_as_initiator()
|
self._proto.set_as_initiator()
|
||||||
self._proto.set_psks(self._params.noise_psk_bytes)
|
|
||||||
|
try:
|
||||||
|
noise_psk_bytes = base64.b64decode(self._params.noise_psk)
|
||||||
|
except ValueError:
|
||||||
|
raise InvalidEncryptionKeyAPIError(
|
||||||
|
f"Malformed PSK {self._params.noise_psk}, expected base64-encoded value"
|
||||||
|
)
|
||||||
|
if len(noise_psk_bytes) != 32:
|
||||||
|
raise InvalidEncryptionKeyAPIError(
|
||||||
|
f"Malformed PSK {self._params.noise_psk}, expected 32-bytes of base64 data"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._proto.set_psks(noise_psk_bytes)
|
||||||
self._proto.set_prologue(prologue)
|
self._proto.set_prologue(prologue)
|
||||||
self._proto.start_handshake()
|
self._proto.start_handshake()
|
||||||
|
|
||||||
@ -131,8 +148,13 @@ class APIFrameHelper:
|
|||||||
await self._write_frame_noise(b"\x00" + msg)
|
await self._write_frame_noise(b"\x00" + msg)
|
||||||
else:
|
else:
|
||||||
msg = await self._read_frame_noise()
|
msg = await self._read_frame_noise()
|
||||||
if not msg or msg[0] != 0:
|
if not msg:
|
||||||
raise APIConnectionError(f"Handshake failure: {msg[1:].decode()}")
|
raise HandshakeAPIError("Handshake message too short")
|
||||||
|
if msg[0] != 0:
|
||||||
|
explanation = msg[1:].decode()
|
||||||
|
if explanation == "Handshake MAC failure":
|
||||||
|
raise InvalidEncryptionKeyAPIError("Invalid encryption key")
|
||||||
|
raise HandshakeAPIError(f"Handshake failure: {explanation}")
|
||||||
self._proto.read_message(msg[1:])
|
self._proto.read_message(msg[1:])
|
||||||
|
|
||||||
do_write = not do_write
|
do_write = not do_write
|
||||||
@ -170,7 +192,7 @@ class APIFrameHelper:
|
|||||||
self._writer.write(data)
|
self._writer.write(data)
|
||||||
await self._writer.drain()
|
await self._writer.drain()
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
raise APIConnectionError(f"Error while writing data: {err}") from err
|
raise SocketAPIError(f"Error while writing data: {err}") from err
|
||||||
|
|
||||||
async def write_packet(self, packet: Packet) -> None:
|
async def write_packet(self, packet: Packet) -> None:
|
||||||
if self._params.noise_psk is None:
|
if self._params.noise_psk is None:
|
||||||
@ -184,11 +206,11 @@ class APIFrameHelper:
|
|||||||
assert self._proto is not None
|
assert self._proto is not None
|
||||||
msg = self._proto.decrypt(frame)
|
msg = self._proto.decrypt(frame)
|
||||||
if len(msg) < 4:
|
if len(msg) < 4:
|
||||||
raise APIConnectionError(f"Bad packet frame: {msg}")
|
raise ProtocolAPIError(f"Bad packet frame: {msg}")
|
||||||
pkt_type = (msg[0] << 8) | msg[1]
|
pkt_type = (msg[0] << 8) | msg[1]
|
||||||
data_len = (msg[2] << 8) | msg[3]
|
data_len = (msg[2] << 8) | msg[3]
|
||||||
if data_len + 4 > len(msg):
|
if data_len + 4 > len(msg):
|
||||||
raise APIConnectionError(f"Bad data len: {data_len} vs {len(msg)}")
|
raise ProtocolAPIError(f"Bad data len: {data_len} vs {len(msg)}")
|
||||||
data = msg[4 : 4 + data_len]
|
data = msg[4 : 4 + data_len]
|
||||||
return Packet(type=pkt_type, data=data)
|
return Packet(type=pkt_type, data=data)
|
||||||
|
|
||||||
@ -196,7 +218,9 @@ class APIFrameHelper:
|
|||||||
async with self._read_lock:
|
async with self._read_lock:
|
||||||
preamble = await self._reader.readexactly(1)
|
preamble = await self._reader.readexactly(1)
|
||||||
if preamble[0] != 0x00:
|
if preamble[0] != 0x00:
|
||||||
raise APIConnectionError("Invalid preamble")
|
if preamble[0] == 0x01:
|
||||||
|
raise RequiresEncryptionAPIError("Connection requires encryption")
|
||||||
|
raise ProtocolAPIError(f"Invalid preamble {preamble[0]:02x}")
|
||||||
|
|
||||||
length = b""
|
length = b""
|
||||||
while not length or (length[-1] & 0x80) == 0x80:
|
while not length or (length[-1] & 0x80) == 0x80:
|
||||||
@ -238,6 +262,7 @@ class APIConnection:
|
|||||||
self._message_handlers: List[Callable[[message.Message], None]] = []
|
self._message_handlers: List[Callable[[message.Message], None]] = []
|
||||||
self.log_name = params.address
|
self.log_name = params.address
|
||||||
self._ping_task: Optional[asyncio.Task[None]] = None
|
self._ping_task: Optional[asyncio.Task[None]] = None
|
||||||
|
self._read_exception_handlers: List[Callable[[Exception], None]] = []
|
||||||
|
|
||||||
def _start_ping(self) -> None:
|
def _start_ping(self) -> None:
|
||||||
async def func() -> None:
|
async def func() -> None:
|
||||||
@ -305,7 +330,7 @@ class APIConnection:
|
|||||||
raise err
|
raise err
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
await self._on_error()
|
await self._on_error()
|
||||||
raise APIConnectionError(
|
raise ResolveAPIError(
|
||||||
f"Timeout while resolving IP address for {self.log_name}"
|
f"Timeout while resolving IP address for {self.log_name}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -328,10 +353,10 @@ class APIConnection:
|
|||||||
await asyncio.wait_for(coro2, 30.0)
|
await asyncio.wait_for(coro2, 30.0)
|
||||||
except OSError as err:
|
except OSError as err:
|
||||||
await self._on_error()
|
await self._on_error()
|
||||||
raise APIConnectionError(f"Error connecting to {sockaddr}: {err}")
|
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
await self._on_error()
|
await self._on_error()
|
||||||
raise APIConnectionError(f"Timeout while connecting to {sockaddr}")
|
raise SocketAPIError(f"Timeout while connecting to {sockaddr}")
|
||||||
|
|
||||||
_LOGGER.debug("%s: Opened socket for", self._params.address)
|
_LOGGER.debug("%s: Opened socket for", self._params.address)
|
||||||
reader, writer = await asyncio.open_connection(sock=self._socket)
|
reader, writer = await asyncio.open_connection(sock=self._socket)
|
||||||
@ -383,7 +408,7 @@ class APIConnection:
|
|||||||
connect.password = self._params.password
|
connect.password = self._params.password
|
||||||
resp = await self.send_message_await_response(connect, ConnectResponse)
|
resp = await self.send_message_await_response(connect, ConnectResponse)
|
||||||
if resp.invalid_password:
|
if resp.invalid_password:
|
||||||
raise APIConnectionError("Invalid password!")
|
raise InvalidAuthAPIError("Invalid password!")
|
||||||
|
|
||||||
self._authenticated = True
|
self._authenticated = True
|
||||||
|
|
||||||
@ -444,20 +469,25 @@ class APIConnection:
|
|||||||
if do_stop(resp):
|
if do_stop(resp):
|
||||||
fut.set_result(responses)
|
fut.set_result(responses)
|
||||||
|
|
||||||
|
def on_read_exception(exc: Exception) -> None:
|
||||||
|
if not fut.done():
|
||||||
|
fut.set_exception(exc)
|
||||||
|
|
||||||
self._message_handlers.append(on_message)
|
self._message_handlers.append(on_message)
|
||||||
|
self._read_exception_handlers.append(on_read_exception)
|
||||||
await self.send_message(send_msg)
|
await self.send_message(send_msg)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(fut, timeout)
|
await asyncio.wait_for(fut, timeout)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
if self._stopped:
|
if self._stopped:
|
||||||
raise APIConnectionError("Disconnected while waiting for API response!")
|
raise SocketAPIError("Disconnected while waiting for API response!")
|
||||||
raise APIConnectionError("Timeout while waiting for API response!")
|
raise SocketAPIError("Timeout while waiting for API response!")
|
||||||
|
finally:
|
||||||
try:
|
with suppress(ValueError):
|
||||||
self._message_handlers.remove(on_message)
|
self._message_handlers.remove(on_message)
|
||||||
except ValueError:
|
with suppress(ValueError):
|
||||||
pass
|
self._read_exception_handlers.remove(on_read_exception)
|
||||||
|
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
@ -491,7 +521,7 @@ class APIConnection:
|
|||||||
try:
|
try:
|
||||||
msg.ParseFromString(raw_msg)
|
msg.ParseFromString(raw_msg)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise APIConnectionError("Invalid protobuf message: {}".format(e))
|
raise ProtocolAPIError(f"Invalid protobuf message: {e}") from e
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"%s: Got message of type %s: %s", self._params.address, type(msg), msg
|
"%s: Got message of type %s: %s", self._params.address, type(msg), msg
|
||||||
)
|
)
|
||||||
@ -509,15 +539,19 @@ class APIConnection:
|
|||||||
self.log_name,
|
self.log_name,
|
||||||
err,
|
err,
|
||||||
)
|
)
|
||||||
|
for handler in self._read_exception_handlers[:]:
|
||||||
|
handler(err)
|
||||||
await self._on_error()
|
await self._on_error()
|
||||||
break
|
break
|
||||||
except Exception as err: # pylint: disable=broad-except
|
except Exception as err: # pylint: disable=broad-except
|
||||||
_LOGGER.info(
|
_LOGGER.warning(
|
||||||
"%s: Unexpected error while reading incoming messages: %s",
|
"%s: Unexpected error while reading incoming messages: %s",
|
||||||
self.log_name,
|
self.log_name,
|
||||||
err,
|
err,
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
)
|
)
|
||||||
|
for handler in self._read_exception_handlers[:]:
|
||||||
|
handler(err)
|
||||||
await self._on_error()
|
await self._on_error()
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -63,6 +63,34 @@ class APIConnectionError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidAuthAPIError(APIConnectionError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ResolveAPIError(APIConnectionError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ProtocolAPIError(APIConnectionError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class RequiresEncryptionAPIError(ProtocolAPIError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SocketAPIError(APIConnectionError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class HandshakeAPIError(APIConnectionError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidEncryptionKeyAPIError(HandshakeAPIError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
MESSAGE_TYPE_TO_PROTO = {
|
MESSAGE_TYPE_TO_PROTO = {
|
||||||
1: HelloRequest,
|
1: HelloRequest,
|
||||||
2: HelloResponse,
|
2: HelloResponse,
|
||||||
|
@ -13,7 +13,7 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
ZC_ASYNCIO = False
|
ZC_ASYNCIO = False
|
||||||
|
|
||||||
from .core import APIConnectionError
|
from .core import APIConnectionError, ResolveAPIError
|
||||||
|
|
||||||
ZeroconfInstanceType = Union[zeroconf.Zeroconf, "zeroconf.asyncio.AsyncZeroconf", None]
|
ZeroconfInstanceType = Union[zeroconf.Zeroconf, "zeroconf.asyncio.AsyncZeroconf", None]
|
||||||
|
|
||||||
@ -56,7 +56,7 @@ def _sync_zeroconf_get_service_info(
|
|||||||
try:
|
try:
|
||||||
zc = zeroconf.Zeroconf()
|
zc = zeroconf.Zeroconf()
|
||||||
except Exception:
|
except Exception:
|
||||||
raise APIConnectionError(
|
raise ResolveAPIError(
|
||||||
"Cannot start mDNS sockets, is this a docker container without "
|
"Cannot start mDNS sockets, is this a docker container without "
|
||||||
"host network mode?"
|
"host network mode?"
|
||||||
)
|
)
|
||||||
@ -72,7 +72,7 @@ def _sync_zeroconf_get_service_info(
|
|||||||
try:
|
try:
|
||||||
info = zc.get_service_info(service_type, service_name, int(timeout * 1000))
|
info = zc.get_service_info(service_type, service_name, int(timeout * 1000))
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise APIConnectionError(
|
raise ResolveAPIError(
|
||||||
f"Error resolving mDNS {service_name} via mDNS: {exc}"
|
f"Error resolving mDNS {service_name} via mDNS: {exc}"
|
||||||
) from exc
|
) from exc
|
||||||
finally:
|
finally:
|
||||||
@ -105,7 +105,7 @@ async def _async_zeroconf_get_service_info(
|
|||||||
try:
|
try:
|
||||||
zc = zeroconf.asyncio.AsyncZeroconf()
|
zc = zeroconf.asyncio.AsyncZeroconf()
|
||||||
except Exception:
|
except Exception:
|
||||||
raise APIConnectionError(
|
raise ResolveAPIError(
|
||||||
"Cannot start mDNS sockets, is this a docker container without "
|
"Cannot start mDNS sockets, is this a docker container without "
|
||||||
"host network mode?"
|
"host network mode?"
|
||||||
)
|
)
|
||||||
@ -126,7 +126,7 @@ async def _async_zeroconf_get_service_info(
|
|||||||
service_type, service_name, int(timeout * 1000)
|
service_type, service_name, int(timeout * 1000)
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise APIConnectionError(
|
raise ResolveAPIError(
|
||||||
f"Error resolving mDNS {service_name} via mDNS: {exc}"
|
f"Error resolving mDNS {service_name} via mDNS: {exc}"
|
||||||
) from exc
|
) from exc
|
||||||
finally:
|
finally:
|
||||||
@ -240,9 +240,7 @@ async def async_resolve_host(
|
|||||||
if zc_error:
|
if zc_error:
|
||||||
# Only show ZC error if getaddrinfo also didn't work
|
# Only show ZC error if getaddrinfo also didn't work
|
||||||
raise zc_error
|
raise zc_error
|
||||||
raise APIConnectionError(
|
raise ResolveAPIError(f"Could not resolve host {host} - got no results from OS")
|
||||||
f"Could not resolve host {host} - got no results from OS"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use first matching result
|
# Use first matching result
|
||||||
# Future: return all matches and use first working one
|
# Future: return all matches and use first working one
|
||||||
|
Loading…
Reference in New Issue
Block a user