From 2629e8d86cb1c9c5d8c3eda8f70d0983bb8f2eda Mon Sep 17 00:00:00 2001 From: Otto Winter Date: Wed, 30 Jun 2021 17:00:22 +0200 Subject: [PATCH] Update host_resolve for zeroconf 0.32.0 (#52) --- aioesphomeapi/__init__.py | 1 - aioesphomeapi/client.py | 15 +- aioesphomeapi/connection.py | 35 ++--- aioesphomeapi/core.py | 2 +- aioesphomeapi/host_resolver.py | 246 ++++++++++++++++++++++----------- aioesphomeapi/util.py | 64 +-------- requirements.txt | 2 +- 7 files changed, 197 insertions(+), 168 deletions(-) diff --git a/aioesphomeapi/__init__.py b/aioesphomeapi/__init__.py index 94e6e50..1f27050 100644 --- a/aioesphomeapi/__init__.py +++ b/aioesphomeapi/__init__.py @@ -3,4 +3,3 @@ from .client import APIClient from .connection import APIConnection, ConnectionParams from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError from .model import * -from .util import resolve_ip_address, resolve_ip_address_getaddrinfo diff --git a/aioesphomeapi/client.py b/aioesphomeapi/client.py index 95436f4..293fa29 100644 --- a/aioesphomeapi/client.py +++ b/aioesphomeapi/client.py @@ -13,10 +13,9 @@ from typing import ( cast, ) -import zeroconf from google.protobuf import message -from aioesphomeapi.api_pb2 import ( # type: ignore +from .api_pb2 import ( # type: ignore BinarySensorStateResponse, CameraImageRequest, CameraImageResponse, @@ -60,9 +59,10 @@ from aioesphomeapi.api_pb2 import ( # type: ignore SwitchStateResponse, TextSensorStateResponse, ) -from aioesphomeapi.connection import APIConnection, ConnectionParams -from aioesphomeapi.core import APIConnectionError -from aioesphomeapi.model import ( +from .connection import APIConnection, ConnectionParams +from .core import APIConnectionError +from .host_resolver import ZeroconfInstanceType +from .model import ( APIVersion, BinarySensorInfo, BinarySensorState, @@ -107,6 +107,7 @@ ExecuteServiceDataType = Dict[ ] +# pylint: disable=too-many-public-methods class APIClient: def __init__( self, @@ -117,7 +118,7 @@ class APIClient: *, client_info: str = "aioesphomeapi", keepalive: float = 15.0, - zeroconf_instance: Optional[zeroconf.Zeroconf] = None + zeroconf_instance: ZeroconfInstanceType = None, ): self._params = ConnectionParams( eventloop=eventloop, @@ -128,7 +129,7 @@ class APIClient: keepalive=keepalive, zeroconf_instance=zeroconf_instance, ) - self._connection = None # type: Optional[APIConnection] + self._connection: Optional[APIConnection] = None async def connect( self, diff --git a/aioesphomeapi/connection.py b/aioesphomeapi/connection.py index 7c303a5..5848d61 100644 --- a/aioesphomeapi/connection.py +++ b/aioesphomeapi/connection.py @@ -2,13 +2,12 @@ import asyncio import logging import socket import time -from dataclasses import dataclass +from dataclasses import astuple, dataclass from typing import Any, Awaitable, Callable, List, Optional, cast -import zeroconf from google.protobuf import message -from aioesphomeapi.api_pb2 import ( # type: ignore +from .api_pb2 import ( # type: ignore ConnectRequest, ConnectResponse, DisconnectRequest, @@ -20,9 +19,10 @@ from aioesphomeapi.api_pb2 import ( # type: ignore PingRequest, PingResponse, ) -from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO, APIConnectionError -from aioesphomeapi.model import APIVersion -from aioesphomeapi.util import _bytes_to_varuint, _varuint_to_bytes, resolve_ip_address +from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError +from .host_resolver import ZeroconfInstanceType, async_resolve_host +from .model import APIVersion +from .util import bytes_to_varuint, varuint_to_bytes _LOGGER = logging.getLogger(__name__) @@ -35,7 +35,7 @@ class ConnectionParams: password: Optional[str] client_info: str keepalive: float - zeroconf_instance: Optional[zeroconf.Zeroconf] + zeroconf_instance: ZeroconfInstanceType class APIConnection: @@ -111,13 +111,13 @@ class APIConnection: raise APIConnectionError("Already connected!") try: - coro = resolve_ip_address( + coro = async_resolve_host( self._params.eventloop, self._params.address, self._params.port, self._params.zeroconf_instance, ) - sockaddr = await asyncio.wait_for(coro, 30.0) + addr = await asyncio.wait_for(coro, 30.0) except APIConnectionError as err: await self._on_error() raise err @@ -125,7 +125,9 @@ class APIConnection: await self._on_error() raise APIConnectionError("Timeout while resolving IP address") - self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._socket = socket.socket( + family=addr.family, type=addr.type, proto=addr.proto + ) self._socket.setblocking(False) self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) @@ -134,17 +136,18 @@ class APIConnection: self._params.address, self._params.address, self._params.port, - sockaddr, + addr, ) + sockaddr = astuple(addr.sockaddr) try: coro2 = self._params.eventloop.sock_connect(self._socket, sockaddr) await asyncio.wait_for(coro2, 30.0) except OSError as err: await self._on_error() - raise APIConnectionError("Error connecting to {}: {}".format(sockaddr, err)) + raise APIConnectionError(f"Error connecting to {sockaddr}: {err}") except asyncio.TimeoutError: await self._on_error() - raise APIConnectionError("Timeout while connecting to {}".format(sockaddr)) + raise APIConnectionError(f"Timeout while connecting to {sockaddr}") _LOGGER.debug("%s: Opened socket for", self._params.address) self._socket_reader, self._socket_writer = await asyncio.open_connection( @@ -230,9 +233,9 @@ class APIConnection: encoded = msg.SerializeToString() _LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg)) req = bytes([0]) - req += _varuint_to_bytes(len(encoded)) + req += varuint_to_bytes(len(encoded)) # pylint: disable=undefined-loop-variable - req += _varuint_to_bytes(message_type) + req += varuint_to_bytes(message_type) req += encoded await self._write(req) @@ -307,7 +310,7 @@ class APIConnection: raw = bytes() while not raw or raw[-1] & 0x80: raw += await self._recv(1) - return cast(int, _bytes_to_varuint(raw)) + return cast(int, bytes_to_varuint(raw)) async def _run_once(self) -> None: preamble = await self._recv(1) diff --git a/aioesphomeapi/core.py b/aioesphomeapi/core.py index c4d5649..af2ab7d 100644 --- a/aioesphomeapi/core.py +++ b/aioesphomeapi/core.py @@ -1,4 +1,4 @@ -from aioesphomeapi.api_pb2 import ( # type: ignore +from .api_pb2 import ( # type: ignore BinarySensorStateResponse, CameraImageRequest, CameraImageResponse, diff --git a/aioesphomeapi/host_resolver.py b/aioesphomeapi/host_resolver.py index 28a9783..1897301 100644 --- a/aioesphomeapi/host_resolver.py +++ b/aioesphomeapi/host_resolver.py @@ -1,97 +1,181 @@ +import asyncio import socket -import time -from typing import Optional +from dataclasses import dataclass +from typing import List, Tuple, Union, cast import zeroconf +from zeroconf import Zeroconf +from zeroconf.asyncio import AsyncZeroconf + +from .core import APIConnectionError + +ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf, None] -class HostResolver(zeroconf.RecordUpdateListener): - def __init__(self, name: str): - self.name = name - self.address: Optional[bytes] = None - - def update_record( - self, zc: zeroconf.Zeroconf, now: float, record: zeroconf.DNSRecord - ) -> None: - if record is None: - return - if record.type == zeroconf._TYPE_A: - assert isinstance(record, zeroconf.DNSAddress) - if record.name == self.name: - self.address = record.address - - def request(self, zc: zeroconf.Zeroconf, timeout: float) -> bool: - now = time.time() - delay = 0.2 - next_ = now + delay - last = now + timeout - - try: - zc.add_listener( - self, - zeroconf.DNSQuestion(self.name, zeroconf._TYPE_ANY, zeroconf._CLASS_IN), - ) - while self.address is None: - if last <= now: - # Timeout - return False - if next_ <= now: - out = zeroconf.DNSOutgoing(zeroconf._FLAGS_QR_QUERY) - out.add_question( - zeroconf.DNSQuestion( - self.name, zeroconf._TYPE_A, zeroconf._CLASS_IN - ) - ) - out.add_answer_at_time( - zc.cache.get_by_details( - self.name, zeroconf._TYPE_A, zeroconf._CLASS_IN - ), - now, - ) - zc.send(out) - next_ = now + delay - delay *= 2 - - zc.wait(min(next_, last) - now) - now = time.time() - finally: - zc.remove_listener(self) - - return True +@dataclass(frozen=True) +class Sockaddr: + pass -def resolve_host( +@dataclass(frozen=True) +class IPv4Sockaddr(Sockaddr): + address: str + port: int + + +@dataclass(frozen=True) +class IPv6Sockaddr(Sockaddr): + address: str + port: int + flowinfo: int + scope_id: int + + +@dataclass(frozen=True) +class AddrInfo: + family: int + type: int + proto: int + sockaddr: Sockaddr + + +async def _async_resolve_host_zeroconf( # pylint: disable=too-many-branches host: str, + port: int, + *, timeout: float = 3.0, - zeroconf_instance: Optional[zeroconf.Zeroconf] = None, -) -> str: - from aioesphomeapi.core import APIConnectionError - - try: - zc = zeroconf_instance or zeroconf.Zeroconf() - except Exception: - raise APIConnectionError( - "Cannot start mDNS sockets, is this a docker container without " - "host network mode?" + zeroconf_instance: ZeroconfInstanceType = None, +) -> List[AddrInfo]: + # Use or create zeroconf instance, ensure it's an AsyncZeroconf + if zeroconf_instance is None: + try: + zc = AsyncZeroconf() + except Exception: + raise APIConnectionError( + "Cannot start mDNS sockets, is this a docker container without " + "host network mode?" + ) + do_close = True + elif isinstance(zeroconf_instance, AsyncZeroconf): + zc = zeroconf_instance + do_close = False + elif isinstance(zeroconf_instance, Zeroconf): + zc = AsyncZeroconf(zc=zeroconf_instance) + do_close = False + else: + raise ValueError( + f"Invalid type passed for zeroconf_instance: {type(zeroconf_instance)}" ) + service_type = "_esphomelib._tcp.local." + service_name = f"{host}.{service_type}" + try: - info = HostResolver(host + ".") - assert info.address is not None - address = None - if info.request(zc, timeout): - address = socket.inet_ntoa(info.address) - except Exception as err: + info = await zc.async_get_service_info( + service_type, service_name, int(timeout * 1000) + ) + except Exception as exc: raise APIConnectionError( - "Error resolving mDNS hostname: {}".format(err) - ) from err + f"Error resolving host {host} via mDNS: {exc}" + ) from exc finally: - if not zeroconf_instance: - zc.close() + if do_close: + await zc.async_close() - if address is None: - raise APIConnectionError( - "Error resolving address with mDNS: Did not respond. " - "Maybe the device is offline." + if info is None: + return [] + + addrs: List[AddrInfo] = [] + for raw in info.addresses_by_version(zeroconf.IPVersion.All): + is_ipv6 = len(raw) == 16 + sockaddr: Sockaddr + if is_ipv6: + sockaddr = IPv6Sockaddr( + address=socket.inet_ntop(socket.AF_INET6, raw), + port=port, + flowinfo=0, + scope_id=0, + ) + else: + sockaddr = IPv4Sockaddr( + address=socket.inet_ntop(socket.AF_INET, raw), + port=port, + ) + + addrs.append( + AddrInfo( + family=socket.AF_INET6 if is_ipv6 else socket.AF_INET, + type=socket.SOCK_STREAM, + proto=socket.IPPROTO_TCP, + sockaddr=sockaddr, + ) ) - return address + return addrs + + +async def _async_resolve_host_getaddrinfo( + eventloop: asyncio.events.AbstractEventLoop, host: str, port: int +) -> List[AddrInfo]: + try: + # Limit to TCP IP protocol + res = await eventloop.getaddrinfo(host, port, proto=socket.IPPROTO_TCP) + except OSError as err: + raise APIConnectionError("Error resolving IP address: {}".format(err)) + + addrs: List[AddrInfo] = [] + for family, type_, proto, _, raw in res: + sockaddr: Sockaddr + if family == socket.AF_INET: + raw = cast(Tuple[str, int], raw) + address, port = raw + sockaddr = IPv4Sockaddr(address=address, port=port) + elif family == socket.AF_INET6: + raw = cast(Tuple[str, int, int, int], raw) + address, port, flowinfo, scope_id = raw + sockaddr = IPv6Sockaddr( + address=address, port=port, flowinfo=flowinfo, scope_id=scope_id + ) + else: + # Unknown family + continue + + addrs.append( + AddrInfo(family=family, type=type_, proto=proto, sockaddr=sockaddr) + ) + return addrs + + +async def async_resolve_host( + eventloop: asyncio.events.AbstractEventLoop, + host: str, + port: int, + zeroconf_instance: ZeroconfInstanceType = None, +) -> AddrInfo: + addrs: List[AddrInfo] = [] + + zc_error = None + if host.endswith(".local"): + name = host[: -len(".local")] + try: + addrs.extend( + await _async_resolve_host_zeroconf( + name, port, zeroconf_instance=zeroconf_instance + ) + ) + except APIConnectionError as err: + zc_error = err + + if not addrs: + addrs.extend(await _async_resolve_host_getaddrinfo(eventloop, host, port)) + + if not addrs: + if zc_error: + # Only show ZC error if getaddrinfo also didn't work + raise zc_error + raise APIConnectionError( + f"Could not resolve host {host} - got no results from OS" + ) + + # Use first matching result + # Future: return all matches and use first working one + return addrs[0] diff --git a/aioesphomeapi/util.py b/aioesphomeapi/util.py index 93c2a57..8594c44 100644 --- a/aioesphomeapi/util.py +++ b/aioesphomeapi/util.py @@ -1,15 +1,7 @@ -import asyncio -import functools -import socket -from typing import Any, Optional, Tuple - -import zeroconf - -# pylint: disable=cyclic-import -from aioesphomeapi.core import APIConnectionError +from typing import Optional -def _varuint_to_bytes(value: int) -> bytes: +def varuint_to_bytes(value: int) -> bytes: if value <= 0x7F: return bytes([value]) @@ -25,7 +17,7 @@ def _varuint_to_bytes(value: int) -> bytes: return ret -def _bytes_to_varuint(value: bytes) -> Optional[int]: +def bytes_to_varuint(value: bytes) -> Optional[int]: result = 0 bitpos = 0 for val in value: @@ -34,53 +26,3 @@ def _bytes_to_varuint(value: bytes) -> Optional[int]: if (val & 0x80) == 0: return result return None - - -async def resolve_ip_address_getaddrinfo( - eventloop: asyncio.events.AbstractEventLoop, host: str, port: int -) -> Tuple[Any, ...]: - - try: - socket.inet_aton(host) - except OSError: - pass - else: - return (host, port) - - try: - res = await eventloop.getaddrinfo( - host, port, family=socket.AF_INET, proto=socket.IPPROTO_TCP - ) - except OSError as err: - raise APIConnectionError("Error resolving IP address: {}".format(err)) - - if not res: - raise APIConnectionError("Error resolving IP address: No matches!") - - _, _, _, _, sockaddr = res[0] - - return sockaddr - - -async def resolve_ip_address( - eventloop: asyncio.events.AbstractEventLoop, - host: str, - port: int, - zeroconf_instance: Optional[zeroconf.Zeroconf] = None, -) -> Tuple[Any, ...]: - if host.endswith(".local"): - from aioesphomeapi.host_resolver import resolve_host - - try: - return ( - await eventloop.run_in_executor( - None, - functools.partial( - resolve_host, host, zeroconf_instance=zeroconf_instance - ), - ), - port, - ) - except APIConnectionError: - pass - return await resolve_ip_address_getaddrinfo(eventloop, host, port) diff --git a/requirements.txt b/requirements.txt index 1916b64..9cfded3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ protobuf>=3.12.2,<4.0 -zeroconf>=0.28.0,<1.0 +zeroconf>=0.32.0,<1.0