Update host_resolve for zeroconf 0.32.0 (#52)

This commit is contained in:
Otto Winter 2021-06-30 17:00:22 +02:00 committed by GitHub
parent 9cfe8199f7
commit 2629e8d86c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 197 additions and 168 deletions

View File

@ -3,4 +3,3 @@ from .client import APIClient
from .connection import APIConnection, ConnectionParams from .connection import APIConnection, ConnectionParams
from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
from .model import * from .model import *
from .util import resolve_ip_address, resolve_ip_address_getaddrinfo

View File

@ -13,10 +13,9 @@ from typing import (
cast, cast,
) )
import zeroconf
from google.protobuf import message from google.protobuf import message
from aioesphomeapi.api_pb2 import ( # type: ignore from .api_pb2 import ( # type: ignore
BinarySensorStateResponse, BinarySensorStateResponse,
CameraImageRequest, CameraImageRequest,
CameraImageResponse, CameraImageResponse,
@ -60,9 +59,10 @@ from aioesphomeapi.api_pb2 import ( # type: ignore
SwitchStateResponse, SwitchStateResponse,
TextSensorStateResponse, TextSensorStateResponse,
) )
from aioesphomeapi.connection import APIConnection, ConnectionParams from .connection import APIConnection, ConnectionParams
from aioesphomeapi.core import APIConnectionError from .core import APIConnectionError
from aioesphomeapi.model import ( from .host_resolver import ZeroconfInstanceType
from .model import (
APIVersion, APIVersion,
BinarySensorInfo, BinarySensorInfo,
BinarySensorState, BinarySensorState,
@ -107,6 +107,7 @@ ExecuteServiceDataType = Dict[
] ]
# pylint: disable=too-many-public-methods
class APIClient: class APIClient:
def __init__( def __init__(
self, self,
@ -117,7 +118,7 @@ class APIClient:
*, *,
client_info: str = "aioesphomeapi", client_info: str = "aioesphomeapi",
keepalive: float = 15.0, keepalive: float = 15.0,
zeroconf_instance: Optional[zeroconf.Zeroconf] = None zeroconf_instance: ZeroconfInstanceType = None,
): ):
self._params = ConnectionParams( self._params = ConnectionParams(
eventloop=eventloop, eventloop=eventloop,
@ -128,7 +129,7 @@ class APIClient:
keepalive=keepalive, keepalive=keepalive,
zeroconf_instance=zeroconf_instance, zeroconf_instance=zeroconf_instance,
) )
self._connection = None # type: Optional[APIConnection] self._connection: Optional[APIConnection] = None
async def connect( async def connect(
self, self,

View File

@ -2,13 +2,12 @@ import asyncio
import logging import logging
import socket import socket
import time import time
from dataclasses import dataclass from dataclasses import astuple, dataclass
from typing import Any, Awaitable, Callable, List, Optional, cast from typing import Any, Awaitable, Callable, List, Optional, cast
import zeroconf
from google.protobuf import message from google.protobuf import message
from aioesphomeapi.api_pb2 import ( # type: ignore from .api_pb2 import ( # type: ignore
ConnectRequest, ConnectRequest,
ConnectResponse, ConnectResponse,
DisconnectRequest, DisconnectRequest,
@ -20,9 +19,10 @@ from aioesphomeapi.api_pb2 import ( # type: ignore
PingRequest, PingRequest,
PingResponse, PingResponse,
) )
from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO, APIConnectionError from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
from aioesphomeapi.model import APIVersion from .host_resolver import ZeroconfInstanceType, async_resolve_host
from aioesphomeapi.util import _bytes_to_varuint, _varuint_to_bytes, resolve_ip_address from .model import APIVersion
from .util import bytes_to_varuint, varuint_to_bytes
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -35,7 +35,7 @@ class ConnectionParams:
password: Optional[str] password: Optional[str]
client_info: str client_info: str
keepalive: float keepalive: float
zeroconf_instance: Optional[zeroconf.Zeroconf] zeroconf_instance: ZeroconfInstanceType
class APIConnection: class APIConnection:
@ -111,13 +111,13 @@ class APIConnection:
raise APIConnectionError("Already connected!") raise APIConnectionError("Already connected!")
try: try:
coro = resolve_ip_address( coro = async_resolve_host(
self._params.eventloop, self._params.eventloop,
self._params.address, self._params.address,
self._params.port, self._params.port,
self._params.zeroconf_instance, self._params.zeroconf_instance,
) )
sockaddr = await asyncio.wait_for(coro, 30.0) addr = await asyncio.wait_for(coro, 30.0)
except APIConnectionError as err: except APIConnectionError as err:
await self._on_error() await self._on_error()
raise err raise err
@ -125,7 +125,9 @@ class APIConnection:
await self._on_error() await self._on_error()
raise APIConnectionError("Timeout while resolving IP address") raise APIConnectionError("Timeout while resolving IP address")
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._socket = socket.socket(
family=addr.family, type=addr.type, proto=addr.proto
)
self._socket.setblocking(False) self._socket.setblocking(False)
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
@ -134,17 +136,18 @@ class APIConnection:
self._params.address, self._params.address,
self._params.address, self._params.address,
self._params.port, self._params.port,
sockaddr, addr,
) )
sockaddr = astuple(addr.sockaddr)
try: try:
coro2 = self._params.eventloop.sock_connect(self._socket, sockaddr) coro2 = self._params.eventloop.sock_connect(self._socket, sockaddr)
await asyncio.wait_for(coro2, 30.0) await asyncio.wait_for(coro2, 30.0)
except OSError as err: except OSError as err:
await self._on_error() await self._on_error()
raise APIConnectionError("Error connecting to {}: {}".format(sockaddr, err)) raise APIConnectionError(f"Error connecting to {sockaddr}: {err}")
except asyncio.TimeoutError: except asyncio.TimeoutError:
await self._on_error() await self._on_error()
raise APIConnectionError("Timeout while connecting to {}".format(sockaddr)) raise APIConnectionError(f"Timeout while connecting to {sockaddr}")
_LOGGER.debug("%s: Opened socket for", self._params.address) _LOGGER.debug("%s: Opened socket for", self._params.address)
self._socket_reader, self._socket_writer = await asyncio.open_connection( self._socket_reader, self._socket_writer = await asyncio.open_connection(
@ -230,9 +233,9 @@ class APIConnection:
encoded = msg.SerializeToString() encoded = msg.SerializeToString()
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg)) _LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
req = bytes([0]) req = bytes([0])
req += _varuint_to_bytes(len(encoded)) req += varuint_to_bytes(len(encoded))
# pylint: disable=undefined-loop-variable # pylint: disable=undefined-loop-variable
req += _varuint_to_bytes(message_type) req += varuint_to_bytes(message_type)
req += encoded req += encoded
await self._write(req) await self._write(req)
@ -307,7 +310,7 @@ class APIConnection:
raw = bytes() raw = bytes()
while not raw or raw[-1] & 0x80: while not raw or raw[-1] & 0x80:
raw += await self._recv(1) raw += await self._recv(1)
return cast(int, _bytes_to_varuint(raw)) return cast(int, bytes_to_varuint(raw))
async def _run_once(self) -> None: async def _run_once(self) -> None:
preamble = await self._recv(1) preamble = await self._recv(1)

View File

@ -1,4 +1,4 @@
from aioesphomeapi.api_pb2 import ( # type: ignore from .api_pb2 import ( # type: ignore
BinarySensorStateResponse, BinarySensorStateResponse,
CameraImageRequest, CameraImageRequest,
CameraImageResponse, CameraImageResponse,

View File

@ -1,97 +1,181 @@
import asyncio
import socket import socket
import time from dataclasses import dataclass
from typing import Optional from typing import List, Tuple, Union, cast
import zeroconf import zeroconf
from zeroconf import Zeroconf
from zeroconf.asyncio import AsyncZeroconf
from .core import APIConnectionError
ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf, None]
class HostResolver(zeroconf.RecordUpdateListener): @dataclass(frozen=True)
def __init__(self, name: str): class Sockaddr:
self.name = name pass
self.address: Optional[bytes] = None
def update_record(
self, zc: zeroconf.Zeroconf, now: float, record: zeroconf.DNSRecord
) -> None:
if record is None:
return
if record.type == zeroconf._TYPE_A:
assert isinstance(record, zeroconf.DNSAddress)
if record.name == self.name:
self.address = record.address
def request(self, zc: zeroconf.Zeroconf, timeout: float) -> bool:
now = time.time()
delay = 0.2
next_ = now + delay
last = now + timeout
try:
zc.add_listener(
self,
zeroconf.DNSQuestion(self.name, zeroconf._TYPE_ANY, zeroconf._CLASS_IN),
)
while self.address is None:
if last <= now:
# Timeout
return False
if next_ <= now:
out = zeroconf.DNSOutgoing(zeroconf._FLAGS_QR_QUERY)
out.add_question(
zeroconf.DNSQuestion(
self.name, zeroconf._TYPE_A, zeroconf._CLASS_IN
)
)
out.add_answer_at_time(
zc.cache.get_by_details(
self.name, zeroconf._TYPE_A, zeroconf._CLASS_IN
),
now,
)
zc.send(out)
next_ = now + delay
delay *= 2
zc.wait(min(next_, last) - now)
now = time.time()
finally:
zc.remove_listener(self)
return True
def resolve_host( @dataclass(frozen=True)
class IPv4Sockaddr(Sockaddr):
address: str
port: int
@dataclass(frozen=True)
class IPv6Sockaddr(Sockaddr):
address: str
port: int
flowinfo: int
scope_id: int
@dataclass(frozen=True)
class AddrInfo:
family: int
type: int
proto: int
sockaddr: Sockaddr
async def _async_resolve_host_zeroconf( # pylint: disable=too-many-branches
host: str, host: str,
port: int,
*,
timeout: float = 3.0, timeout: float = 3.0,
zeroconf_instance: Optional[zeroconf.Zeroconf] = None, zeroconf_instance: ZeroconfInstanceType = None,
) -> str: ) -> List[AddrInfo]:
from aioesphomeapi.core import APIConnectionError # Use or create zeroconf instance, ensure it's an AsyncZeroconf
if zeroconf_instance is None:
try: try:
zc = zeroconf_instance or zeroconf.Zeroconf() zc = AsyncZeroconf()
except Exception: except Exception:
raise APIConnectionError( raise APIConnectionError(
"Cannot start mDNS sockets, is this a docker container without " "Cannot start mDNS sockets, is this a docker container without "
"host network mode?" "host network mode?"
)
do_close = True
elif isinstance(zeroconf_instance, AsyncZeroconf):
zc = zeroconf_instance
do_close = False
elif isinstance(zeroconf_instance, Zeroconf):
zc = AsyncZeroconf(zc=zeroconf_instance)
do_close = False
else:
raise ValueError(
f"Invalid type passed for zeroconf_instance: {type(zeroconf_instance)}"
) )
service_type = "_esphomelib._tcp.local."
service_name = f"{host}.{service_type}"
try: try:
info = HostResolver(host + ".") info = await zc.async_get_service_info(
assert info.address is not None service_type, service_name, int(timeout * 1000)
address = None )
if info.request(zc, timeout): except Exception as exc:
address = socket.inet_ntoa(info.address)
except Exception as err:
raise APIConnectionError( raise APIConnectionError(
"Error resolving mDNS hostname: {}".format(err) f"Error resolving host {host} via mDNS: {exc}"
) from err ) from exc
finally: finally:
if not zeroconf_instance: if do_close:
zc.close() await zc.async_close()
if address is None: if info is None:
raise APIConnectionError( return []
"Error resolving address with mDNS: Did not respond. "
"Maybe the device is offline." addrs: List[AddrInfo] = []
for raw in info.addresses_by_version(zeroconf.IPVersion.All):
is_ipv6 = len(raw) == 16
sockaddr: Sockaddr
if is_ipv6:
sockaddr = IPv6Sockaddr(
address=socket.inet_ntop(socket.AF_INET6, raw),
port=port,
flowinfo=0,
scope_id=0,
)
else:
sockaddr = IPv4Sockaddr(
address=socket.inet_ntop(socket.AF_INET, raw),
port=port,
)
addrs.append(
AddrInfo(
family=socket.AF_INET6 if is_ipv6 else socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
sockaddr=sockaddr,
)
) )
return address return addrs
async def _async_resolve_host_getaddrinfo(
eventloop: asyncio.events.AbstractEventLoop, host: str, port: int
) -> List[AddrInfo]:
try:
# Limit to TCP IP protocol
res = await eventloop.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)
except OSError as err:
raise APIConnectionError("Error resolving IP address: {}".format(err))
addrs: List[AddrInfo] = []
for family, type_, proto, _, raw in res:
sockaddr: Sockaddr
if family == socket.AF_INET:
raw = cast(Tuple[str, int], raw)
address, port = raw
sockaddr = IPv4Sockaddr(address=address, port=port)
elif family == socket.AF_INET6:
raw = cast(Tuple[str, int, int, int], raw)
address, port, flowinfo, scope_id = raw
sockaddr = IPv6Sockaddr(
address=address, port=port, flowinfo=flowinfo, scope_id=scope_id
)
else:
# Unknown family
continue
addrs.append(
AddrInfo(family=family, type=type_, proto=proto, sockaddr=sockaddr)
)
return addrs
async def async_resolve_host(
eventloop: asyncio.events.AbstractEventLoop,
host: str,
port: int,
zeroconf_instance: ZeroconfInstanceType = None,
) -> AddrInfo:
addrs: List[AddrInfo] = []
zc_error = None
if host.endswith(".local"):
name = host[: -len(".local")]
try:
addrs.extend(
await _async_resolve_host_zeroconf(
name, port, zeroconf_instance=zeroconf_instance
)
)
except APIConnectionError as err:
zc_error = err
if not addrs:
addrs.extend(await _async_resolve_host_getaddrinfo(eventloop, host, port))
if not addrs:
if zc_error:
# Only show ZC error if getaddrinfo also didn't work
raise zc_error
raise APIConnectionError(
f"Could not resolve host {host} - got no results from OS"
)
# Use first matching result
# Future: return all matches and use first working one
return addrs[0]

View File

@ -1,15 +1,7 @@
import asyncio from typing import Optional
import functools
import socket
from typing import Any, Optional, Tuple
import zeroconf
# pylint: disable=cyclic-import
from aioesphomeapi.core import APIConnectionError
def _varuint_to_bytes(value: int) -> bytes: def varuint_to_bytes(value: int) -> bytes:
if value <= 0x7F: if value <= 0x7F:
return bytes([value]) return bytes([value])
@ -25,7 +17,7 @@ def _varuint_to_bytes(value: int) -> bytes:
return ret return ret
def _bytes_to_varuint(value: bytes) -> Optional[int]: def bytes_to_varuint(value: bytes) -> Optional[int]:
result = 0 result = 0
bitpos = 0 bitpos = 0
for val in value: for val in value:
@ -34,53 +26,3 @@ def _bytes_to_varuint(value: bytes) -> Optional[int]:
if (val & 0x80) == 0: if (val & 0x80) == 0:
return result return result
return None return None
async def resolve_ip_address_getaddrinfo(
eventloop: asyncio.events.AbstractEventLoop, host: str, port: int
) -> Tuple[Any, ...]:
try:
socket.inet_aton(host)
except OSError:
pass
else:
return (host, port)
try:
res = await eventloop.getaddrinfo(
host, port, family=socket.AF_INET, proto=socket.IPPROTO_TCP
)
except OSError as err:
raise APIConnectionError("Error resolving IP address: {}".format(err))
if not res:
raise APIConnectionError("Error resolving IP address: No matches!")
_, _, _, _, sockaddr = res[0]
return sockaddr
async def resolve_ip_address(
eventloop: asyncio.events.AbstractEventLoop,
host: str,
port: int,
zeroconf_instance: Optional[zeroconf.Zeroconf] = None,
) -> Tuple[Any, ...]:
if host.endswith(".local"):
from aioesphomeapi.host_resolver import resolve_host
try:
return (
await eventloop.run_in_executor(
None,
functools.partial(
resolve_host, host, zeroconf_instance=zeroconf_instance
),
),
port,
)
except APIConnectionError:
pass
return await resolve_ip_address_getaddrinfo(eventloop, host, port)

View File

@ -1,2 +1,2 @@
protobuf>=3.12.2,<4.0 protobuf>=3.12.2,<4.0
zeroconf>=0.28.0,<1.0 zeroconf>=0.32.0,<1.0