Update host_resolve for zeroconf 0.32.0 (#52)

This commit is contained in:
Otto Winter 2021-06-30 17:00:22 +02:00 committed by GitHub
parent 9cfe8199f7
commit 2629e8d86c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 197 additions and 168 deletions

View File

@ -3,4 +3,3 @@ from .client import APIClient
from .connection import APIConnection, ConnectionParams
from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
from .model import *
from .util import resolve_ip_address, resolve_ip_address_getaddrinfo

View File

@ -13,10 +13,9 @@ from typing import (
cast,
)
import zeroconf
from google.protobuf import message
from aioesphomeapi.api_pb2 import ( # type: ignore
from .api_pb2 import ( # type: ignore
BinarySensorStateResponse,
CameraImageRequest,
CameraImageResponse,
@ -60,9 +59,10 @@ from aioesphomeapi.api_pb2 import ( # type: ignore
SwitchStateResponse,
TextSensorStateResponse,
)
from aioesphomeapi.connection import APIConnection, ConnectionParams
from aioesphomeapi.core import APIConnectionError
from aioesphomeapi.model import (
from .connection import APIConnection, ConnectionParams
from .core import APIConnectionError
from .host_resolver import ZeroconfInstanceType
from .model import (
APIVersion,
BinarySensorInfo,
BinarySensorState,
@ -107,6 +107,7 @@ ExecuteServiceDataType = Dict[
]
# pylint: disable=too-many-public-methods
class APIClient:
def __init__(
self,
@ -117,7 +118,7 @@ class APIClient:
*,
client_info: str = "aioesphomeapi",
keepalive: float = 15.0,
zeroconf_instance: Optional[zeroconf.Zeroconf] = None
zeroconf_instance: ZeroconfInstanceType = None,
):
self._params = ConnectionParams(
eventloop=eventloop,
@ -128,7 +129,7 @@ class APIClient:
keepalive=keepalive,
zeroconf_instance=zeroconf_instance,
)
self._connection = None # type: Optional[APIConnection]
self._connection: Optional[APIConnection] = None
async def connect(
self,

View File

@ -2,13 +2,12 @@ import asyncio
import logging
import socket
import time
from dataclasses import dataclass
from dataclasses import astuple, dataclass
from typing import Any, Awaitable, Callable, List, Optional, cast
import zeroconf
from google.protobuf import message
from aioesphomeapi.api_pb2 import ( # type: ignore
from .api_pb2 import ( # type: ignore
ConnectRequest,
ConnectResponse,
DisconnectRequest,
@ -20,9 +19,10 @@ from aioesphomeapi.api_pb2 import ( # type: ignore
PingRequest,
PingResponse,
)
from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
from aioesphomeapi.model import APIVersion
from aioesphomeapi.util import _bytes_to_varuint, _varuint_to_bytes, resolve_ip_address
from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
from .host_resolver import ZeroconfInstanceType, async_resolve_host
from .model import APIVersion
from .util import bytes_to_varuint, varuint_to_bytes
_LOGGER = logging.getLogger(__name__)
@ -35,7 +35,7 @@ class ConnectionParams:
password: Optional[str]
client_info: str
keepalive: float
zeroconf_instance: Optional[zeroconf.Zeroconf]
zeroconf_instance: ZeroconfInstanceType
class APIConnection:
@ -111,13 +111,13 @@ class APIConnection:
raise APIConnectionError("Already connected!")
try:
coro = resolve_ip_address(
coro = async_resolve_host(
self._params.eventloop,
self._params.address,
self._params.port,
self._params.zeroconf_instance,
)
sockaddr = await asyncio.wait_for(coro, 30.0)
addr = await asyncio.wait_for(coro, 30.0)
except APIConnectionError as err:
await self._on_error()
raise err
@ -125,7 +125,9 @@ class APIConnection:
await self._on_error()
raise APIConnectionError("Timeout while resolving IP address")
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._socket = socket.socket(
family=addr.family, type=addr.type, proto=addr.proto
)
self._socket.setblocking(False)
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
@ -134,17 +136,18 @@ class APIConnection:
self._params.address,
self._params.address,
self._params.port,
sockaddr,
addr,
)
sockaddr = astuple(addr.sockaddr)
try:
coro2 = self._params.eventloop.sock_connect(self._socket, sockaddr)
await asyncio.wait_for(coro2, 30.0)
except OSError as err:
await self._on_error()
raise APIConnectionError("Error connecting to {}: {}".format(sockaddr, err))
raise APIConnectionError(f"Error connecting to {sockaddr}: {err}")
except asyncio.TimeoutError:
await self._on_error()
raise APIConnectionError("Timeout while connecting to {}".format(sockaddr))
raise APIConnectionError(f"Timeout while connecting to {sockaddr}")
_LOGGER.debug("%s: Opened socket for", self._params.address)
self._socket_reader, self._socket_writer = await asyncio.open_connection(
@ -230,9 +233,9 @@ class APIConnection:
encoded = msg.SerializeToString()
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
req = bytes([0])
req += _varuint_to_bytes(len(encoded))
req += varuint_to_bytes(len(encoded))
# pylint: disable=undefined-loop-variable
req += _varuint_to_bytes(message_type)
req += varuint_to_bytes(message_type)
req += encoded
await self._write(req)
@ -307,7 +310,7 @@ class APIConnection:
raw = bytes()
while not raw or raw[-1] & 0x80:
raw += await self._recv(1)
return cast(int, _bytes_to_varuint(raw))
return cast(int, bytes_to_varuint(raw))
async def _run_once(self) -> None:
preamble = await self._recv(1)

View File

@ -1,4 +1,4 @@
from aioesphomeapi.api_pb2 import ( # type: ignore
from .api_pb2 import ( # type: ignore
BinarySensorStateResponse,
CameraImageRequest,
CameraImageResponse,

View File

@ -1,97 +1,181 @@
import asyncio
import socket
import time
from typing import Optional
from dataclasses import dataclass
from typing import List, Tuple, Union, cast
import zeroconf
from zeroconf import Zeroconf
from zeroconf.asyncio import AsyncZeroconf
from .core import APIConnectionError
ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf, None]
class HostResolver(zeroconf.RecordUpdateListener):
def __init__(self, name: str):
self.name = name
self.address: Optional[bytes] = None
def update_record(
self, zc: zeroconf.Zeroconf, now: float, record: zeroconf.DNSRecord
) -> None:
if record is None:
return
if record.type == zeroconf._TYPE_A:
assert isinstance(record, zeroconf.DNSAddress)
if record.name == self.name:
self.address = record.address
def request(self, zc: zeroconf.Zeroconf, timeout: float) -> bool:
now = time.time()
delay = 0.2
next_ = now + delay
last = now + timeout
try:
zc.add_listener(
self,
zeroconf.DNSQuestion(self.name, zeroconf._TYPE_ANY, zeroconf._CLASS_IN),
)
while self.address is None:
if last <= now:
# Timeout
return False
if next_ <= now:
out = zeroconf.DNSOutgoing(zeroconf._FLAGS_QR_QUERY)
out.add_question(
zeroconf.DNSQuestion(
self.name, zeroconf._TYPE_A, zeroconf._CLASS_IN
)
)
out.add_answer_at_time(
zc.cache.get_by_details(
self.name, zeroconf._TYPE_A, zeroconf._CLASS_IN
),
now,
)
zc.send(out)
next_ = now + delay
delay *= 2
zc.wait(min(next_, last) - now)
now = time.time()
finally:
zc.remove_listener(self)
return True
@dataclass(frozen=True)
class Sockaddr:
pass
def resolve_host(
@dataclass(frozen=True)
class IPv4Sockaddr(Sockaddr):
address: str
port: int
@dataclass(frozen=True)
class IPv6Sockaddr(Sockaddr):
address: str
port: int
flowinfo: int
scope_id: int
@dataclass(frozen=True)
class AddrInfo:
family: int
type: int
proto: int
sockaddr: Sockaddr
async def _async_resolve_host_zeroconf( # pylint: disable=too-many-branches
host: str,
port: int,
*,
timeout: float = 3.0,
zeroconf_instance: Optional[zeroconf.Zeroconf] = None,
) -> str:
from aioesphomeapi.core import APIConnectionError
try:
zc = zeroconf_instance or zeroconf.Zeroconf()
except Exception:
raise APIConnectionError(
"Cannot start mDNS sockets, is this a docker container without "
"host network mode?"
zeroconf_instance: ZeroconfInstanceType = None,
) -> List[AddrInfo]:
# Use or create zeroconf instance, ensure it's an AsyncZeroconf
if zeroconf_instance is None:
try:
zc = AsyncZeroconf()
except Exception:
raise APIConnectionError(
"Cannot start mDNS sockets, is this a docker container without "
"host network mode?"
)
do_close = True
elif isinstance(zeroconf_instance, AsyncZeroconf):
zc = zeroconf_instance
do_close = False
elif isinstance(zeroconf_instance, Zeroconf):
zc = AsyncZeroconf(zc=zeroconf_instance)
do_close = False
else:
raise ValueError(
f"Invalid type passed for zeroconf_instance: {type(zeroconf_instance)}"
)
service_type = "_esphomelib._tcp.local."
service_name = f"{host}.{service_type}"
try:
info = HostResolver(host + ".")
assert info.address is not None
address = None
if info.request(zc, timeout):
address = socket.inet_ntoa(info.address)
except Exception as err:
info = await zc.async_get_service_info(
service_type, service_name, int(timeout * 1000)
)
except Exception as exc:
raise APIConnectionError(
"Error resolving mDNS hostname: {}".format(err)
) from err
f"Error resolving host {host} via mDNS: {exc}"
) from exc
finally:
if not zeroconf_instance:
zc.close()
if do_close:
await zc.async_close()
if address is None:
raise APIConnectionError(
"Error resolving address with mDNS: Did not respond. "
"Maybe the device is offline."
if info is None:
return []
addrs: List[AddrInfo] = []
for raw in info.addresses_by_version(zeroconf.IPVersion.All):
is_ipv6 = len(raw) == 16
sockaddr: Sockaddr
if is_ipv6:
sockaddr = IPv6Sockaddr(
address=socket.inet_ntop(socket.AF_INET6, raw),
port=port,
flowinfo=0,
scope_id=0,
)
else:
sockaddr = IPv4Sockaddr(
address=socket.inet_ntop(socket.AF_INET, raw),
port=port,
)
addrs.append(
AddrInfo(
family=socket.AF_INET6 if is_ipv6 else socket.AF_INET,
type=socket.SOCK_STREAM,
proto=socket.IPPROTO_TCP,
sockaddr=sockaddr,
)
)
return address
return addrs
async def _async_resolve_host_getaddrinfo(
eventloop: asyncio.events.AbstractEventLoop, host: str, port: int
) -> List[AddrInfo]:
try:
# Limit to TCP IP protocol
res = await eventloop.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)
except OSError as err:
raise APIConnectionError("Error resolving IP address: {}".format(err))
addrs: List[AddrInfo] = []
for family, type_, proto, _, raw in res:
sockaddr: Sockaddr
if family == socket.AF_INET:
raw = cast(Tuple[str, int], raw)
address, port = raw
sockaddr = IPv4Sockaddr(address=address, port=port)
elif family == socket.AF_INET6:
raw = cast(Tuple[str, int, int, int], raw)
address, port, flowinfo, scope_id = raw
sockaddr = IPv6Sockaddr(
address=address, port=port, flowinfo=flowinfo, scope_id=scope_id
)
else:
# Unknown family
continue
addrs.append(
AddrInfo(family=family, type=type_, proto=proto, sockaddr=sockaddr)
)
return addrs
async def async_resolve_host(
eventloop: asyncio.events.AbstractEventLoop,
host: str,
port: int,
zeroconf_instance: ZeroconfInstanceType = None,
) -> AddrInfo:
addrs: List[AddrInfo] = []
zc_error = None
if host.endswith(".local"):
name = host[: -len(".local")]
try:
addrs.extend(
await _async_resolve_host_zeroconf(
name, port, zeroconf_instance=zeroconf_instance
)
)
except APIConnectionError as err:
zc_error = err
if not addrs:
addrs.extend(await _async_resolve_host_getaddrinfo(eventloop, host, port))
if not addrs:
if zc_error:
# Only show ZC error if getaddrinfo also didn't work
raise zc_error
raise APIConnectionError(
f"Could not resolve host {host} - got no results from OS"
)
# Use first matching result
# Future: return all matches and use first working one
return addrs[0]

View File

@ -1,15 +1,7 @@
import asyncio
import functools
import socket
from typing import Any, Optional, Tuple
import zeroconf
# pylint: disable=cyclic-import
from aioesphomeapi.core import APIConnectionError
from typing import Optional
def _varuint_to_bytes(value: int) -> bytes:
def varuint_to_bytes(value: int) -> bytes:
if value <= 0x7F:
return bytes([value])
@ -25,7 +17,7 @@ def _varuint_to_bytes(value: int) -> bytes:
return ret
def _bytes_to_varuint(value: bytes) -> Optional[int]:
def bytes_to_varuint(value: bytes) -> Optional[int]:
result = 0
bitpos = 0
for val in value:
@ -34,53 +26,3 @@ def _bytes_to_varuint(value: bytes) -> Optional[int]:
if (val & 0x80) == 0:
return result
return None
async def resolve_ip_address_getaddrinfo(
eventloop: asyncio.events.AbstractEventLoop, host: str, port: int
) -> Tuple[Any, ...]:
try:
socket.inet_aton(host)
except OSError:
pass
else:
return (host, port)
try:
res = await eventloop.getaddrinfo(
host, port, family=socket.AF_INET, proto=socket.IPPROTO_TCP
)
except OSError as err:
raise APIConnectionError("Error resolving IP address: {}".format(err))
if not res:
raise APIConnectionError("Error resolving IP address: No matches!")
_, _, _, _, sockaddr = res[0]
return sockaddr
async def resolve_ip_address(
eventloop: asyncio.events.AbstractEventLoop,
host: str,
port: int,
zeroconf_instance: Optional[zeroconf.Zeroconf] = None,
) -> Tuple[Any, ...]:
if host.endswith(".local"):
from aioesphomeapi.host_resolver import resolve_host
try:
return (
await eventloop.run_in_executor(
None,
functools.partial(
resolve_host, host, zeroconf_instance=zeroconf_instance
),
),
port,
)
except APIConnectionError:
pass
return await resolve_ip_address_getaddrinfo(eventloop, host, port)

View File

@ -1,2 +1,2 @@
protobuf>=3.12.2,<4.0
zeroconf>=0.28.0,<1.0
zeroconf>=0.32.0,<1.0