mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-28 13:05:12 +01:00
Update host_resolve for zeroconf 0.32.0 (#52)
This commit is contained in:
parent
9cfe8199f7
commit
2629e8d86c
@ -3,4 +3,3 @@ from .client import APIClient
|
||||
from .connection import APIConnection, ConnectionParams
|
||||
from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
|
||||
from .model import *
|
||||
from .util import resolve_ip_address, resolve_ip_address_getaddrinfo
|
||||
|
@ -13,10 +13,9 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
import zeroconf
|
||||
from google.protobuf import message
|
||||
|
||||
from aioesphomeapi.api_pb2 import ( # type: ignore
|
||||
from .api_pb2 import ( # type: ignore
|
||||
BinarySensorStateResponse,
|
||||
CameraImageRequest,
|
||||
CameraImageResponse,
|
||||
@ -60,9 +59,10 @@ from aioesphomeapi.api_pb2 import ( # type: ignore
|
||||
SwitchStateResponse,
|
||||
TextSensorStateResponse,
|
||||
)
|
||||
from aioesphomeapi.connection import APIConnection, ConnectionParams
|
||||
from aioesphomeapi.core import APIConnectionError
|
||||
from aioesphomeapi.model import (
|
||||
from .connection import APIConnection, ConnectionParams
|
||||
from .core import APIConnectionError
|
||||
from .host_resolver import ZeroconfInstanceType
|
||||
from .model import (
|
||||
APIVersion,
|
||||
BinarySensorInfo,
|
||||
BinarySensorState,
|
||||
@ -107,6 +107,7 @@ ExecuteServiceDataType = Dict[
|
||||
]
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods
|
||||
class APIClient:
|
||||
def __init__(
|
||||
self,
|
||||
@ -117,7 +118,7 @@ class APIClient:
|
||||
*,
|
||||
client_info: str = "aioesphomeapi",
|
||||
keepalive: float = 15.0,
|
||||
zeroconf_instance: Optional[zeroconf.Zeroconf] = None
|
||||
zeroconf_instance: ZeroconfInstanceType = None,
|
||||
):
|
||||
self._params = ConnectionParams(
|
||||
eventloop=eventloop,
|
||||
@ -128,7 +129,7 @@ class APIClient:
|
||||
keepalive=keepalive,
|
||||
zeroconf_instance=zeroconf_instance,
|
||||
)
|
||||
self._connection = None # type: Optional[APIConnection]
|
||||
self._connection: Optional[APIConnection] = None
|
||||
|
||||
async def connect(
|
||||
self,
|
||||
|
@ -2,13 +2,12 @@ import asyncio
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import astuple, dataclass
|
||||
from typing import Any, Awaitable, Callable, List, Optional, cast
|
||||
|
||||
import zeroconf
|
||||
from google.protobuf import message
|
||||
|
||||
from aioesphomeapi.api_pb2 import ( # type: ignore
|
||||
from .api_pb2 import ( # type: ignore
|
||||
ConnectRequest,
|
||||
ConnectResponse,
|
||||
DisconnectRequest,
|
||||
@ -20,9 +19,10 @@ from aioesphomeapi.api_pb2 import ( # type: ignore
|
||||
PingRequest,
|
||||
PingResponse,
|
||||
)
|
||||
from aioesphomeapi.core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
|
||||
from aioesphomeapi.model import APIVersion
|
||||
from aioesphomeapi.util import _bytes_to_varuint, _varuint_to_bytes, resolve_ip_address
|
||||
from .core import MESSAGE_TYPE_TO_PROTO, APIConnectionError
|
||||
from .host_resolver import ZeroconfInstanceType, async_resolve_host
|
||||
from .model import APIVersion
|
||||
from .util import bytes_to_varuint, varuint_to_bytes
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -35,7 +35,7 @@ class ConnectionParams:
|
||||
password: Optional[str]
|
||||
client_info: str
|
||||
keepalive: float
|
||||
zeroconf_instance: Optional[zeroconf.Zeroconf]
|
||||
zeroconf_instance: ZeroconfInstanceType
|
||||
|
||||
|
||||
class APIConnection:
|
||||
@ -111,13 +111,13 @@ class APIConnection:
|
||||
raise APIConnectionError("Already connected!")
|
||||
|
||||
try:
|
||||
coro = resolve_ip_address(
|
||||
coro = async_resolve_host(
|
||||
self._params.eventloop,
|
||||
self._params.address,
|
||||
self._params.port,
|
||||
self._params.zeroconf_instance,
|
||||
)
|
||||
sockaddr = await asyncio.wait_for(coro, 30.0)
|
||||
addr = await asyncio.wait_for(coro, 30.0)
|
||||
except APIConnectionError as err:
|
||||
await self._on_error()
|
||||
raise err
|
||||
@ -125,7 +125,9 @@ class APIConnection:
|
||||
await self._on_error()
|
||||
raise APIConnectionError("Timeout while resolving IP address")
|
||||
|
||||
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self._socket = socket.socket(
|
||||
family=addr.family, type=addr.type, proto=addr.proto
|
||||
)
|
||||
self._socket.setblocking(False)
|
||||
self._socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
|
||||
@ -134,17 +136,18 @@ class APIConnection:
|
||||
self._params.address,
|
||||
self._params.address,
|
||||
self._params.port,
|
||||
sockaddr,
|
||||
addr,
|
||||
)
|
||||
sockaddr = astuple(addr.sockaddr)
|
||||
try:
|
||||
coro2 = self._params.eventloop.sock_connect(self._socket, sockaddr)
|
||||
await asyncio.wait_for(coro2, 30.0)
|
||||
except OSError as err:
|
||||
await self._on_error()
|
||||
raise APIConnectionError("Error connecting to {}: {}".format(sockaddr, err))
|
||||
raise APIConnectionError(f"Error connecting to {sockaddr}: {err}")
|
||||
except asyncio.TimeoutError:
|
||||
await self._on_error()
|
||||
raise APIConnectionError("Timeout while connecting to {}".format(sockaddr))
|
||||
raise APIConnectionError(f"Timeout while connecting to {sockaddr}")
|
||||
|
||||
_LOGGER.debug("%s: Opened socket for", self._params.address)
|
||||
self._socket_reader, self._socket_writer = await asyncio.open_connection(
|
||||
@ -230,9 +233,9 @@ class APIConnection:
|
||||
encoded = msg.SerializeToString()
|
||||
_LOGGER.debug("%s: Sending %s: %s", self._params.address, type(msg), str(msg))
|
||||
req = bytes([0])
|
||||
req += _varuint_to_bytes(len(encoded))
|
||||
req += varuint_to_bytes(len(encoded))
|
||||
# pylint: disable=undefined-loop-variable
|
||||
req += _varuint_to_bytes(message_type)
|
||||
req += varuint_to_bytes(message_type)
|
||||
req += encoded
|
||||
await self._write(req)
|
||||
|
||||
@ -307,7 +310,7 @@ class APIConnection:
|
||||
raw = bytes()
|
||||
while not raw or raw[-1] & 0x80:
|
||||
raw += await self._recv(1)
|
||||
return cast(int, _bytes_to_varuint(raw))
|
||||
return cast(int, bytes_to_varuint(raw))
|
||||
|
||||
async def _run_once(self) -> None:
|
||||
preamble = await self._recv(1)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from aioesphomeapi.api_pb2 import ( # type: ignore
|
||||
from .api_pb2 import ( # type: ignore
|
||||
BinarySensorStateResponse,
|
||||
CameraImageRequest,
|
||||
CameraImageResponse,
|
||||
|
@ -1,97 +1,181 @@
|
||||
import asyncio
|
||||
import socket
|
||||
import time
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple, Union, cast
|
||||
|
||||
import zeroconf
|
||||
from zeroconf import Zeroconf
|
||||
from zeroconf.asyncio import AsyncZeroconf
|
||||
|
||||
from .core import APIConnectionError
|
||||
|
||||
ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf, None]
|
||||
|
||||
|
||||
class HostResolver(zeroconf.RecordUpdateListener):
|
||||
def __init__(self, name: str):
|
||||
self.name = name
|
||||
self.address: Optional[bytes] = None
|
||||
|
||||
def update_record(
|
||||
self, zc: zeroconf.Zeroconf, now: float, record: zeroconf.DNSRecord
|
||||
) -> None:
|
||||
if record is None:
|
||||
return
|
||||
if record.type == zeroconf._TYPE_A:
|
||||
assert isinstance(record, zeroconf.DNSAddress)
|
||||
if record.name == self.name:
|
||||
self.address = record.address
|
||||
|
||||
def request(self, zc: zeroconf.Zeroconf, timeout: float) -> bool:
|
||||
now = time.time()
|
||||
delay = 0.2
|
||||
next_ = now + delay
|
||||
last = now + timeout
|
||||
|
||||
try:
|
||||
zc.add_listener(
|
||||
self,
|
||||
zeroconf.DNSQuestion(self.name, zeroconf._TYPE_ANY, zeroconf._CLASS_IN),
|
||||
)
|
||||
while self.address is None:
|
||||
if last <= now:
|
||||
# Timeout
|
||||
return False
|
||||
if next_ <= now:
|
||||
out = zeroconf.DNSOutgoing(zeroconf._FLAGS_QR_QUERY)
|
||||
out.add_question(
|
||||
zeroconf.DNSQuestion(
|
||||
self.name, zeroconf._TYPE_A, zeroconf._CLASS_IN
|
||||
)
|
||||
)
|
||||
out.add_answer_at_time(
|
||||
zc.cache.get_by_details(
|
||||
self.name, zeroconf._TYPE_A, zeroconf._CLASS_IN
|
||||
),
|
||||
now,
|
||||
)
|
||||
zc.send(out)
|
||||
next_ = now + delay
|
||||
delay *= 2
|
||||
|
||||
zc.wait(min(next_, last) - now)
|
||||
now = time.time()
|
||||
finally:
|
||||
zc.remove_listener(self)
|
||||
|
||||
return True
|
||||
@dataclass(frozen=True)
|
||||
class Sockaddr:
|
||||
pass
|
||||
|
||||
|
||||
def resolve_host(
|
||||
@dataclass(frozen=True)
|
||||
class IPv4Sockaddr(Sockaddr):
|
||||
address: str
|
||||
port: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IPv6Sockaddr(Sockaddr):
|
||||
address: str
|
||||
port: int
|
||||
flowinfo: int
|
||||
scope_id: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AddrInfo:
|
||||
family: int
|
||||
type: int
|
||||
proto: int
|
||||
sockaddr: Sockaddr
|
||||
|
||||
|
||||
async def _async_resolve_host_zeroconf( # pylint: disable=too-many-branches
|
||||
host: str,
|
||||
port: int,
|
||||
*,
|
||||
timeout: float = 3.0,
|
||||
zeroconf_instance: Optional[zeroconf.Zeroconf] = None,
|
||||
) -> str:
|
||||
from aioesphomeapi.core import APIConnectionError
|
||||
|
||||
zeroconf_instance: ZeroconfInstanceType = None,
|
||||
) -> List[AddrInfo]:
|
||||
# Use or create zeroconf instance, ensure it's an AsyncZeroconf
|
||||
if zeroconf_instance is None:
|
||||
try:
|
||||
zc = zeroconf_instance or zeroconf.Zeroconf()
|
||||
zc = AsyncZeroconf()
|
||||
except Exception:
|
||||
raise APIConnectionError(
|
||||
"Cannot start mDNS sockets, is this a docker container without "
|
||||
"host network mode?"
|
||||
)
|
||||
do_close = True
|
||||
elif isinstance(zeroconf_instance, AsyncZeroconf):
|
||||
zc = zeroconf_instance
|
||||
do_close = False
|
||||
elif isinstance(zeroconf_instance, Zeroconf):
|
||||
zc = AsyncZeroconf(zc=zeroconf_instance)
|
||||
do_close = False
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid type passed for zeroconf_instance: {type(zeroconf_instance)}"
|
||||
)
|
||||
|
||||
service_type = "_esphomelib._tcp.local."
|
||||
service_name = f"{host}.{service_type}"
|
||||
|
||||
try:
|
||||
info = HostResolver(host + ".")
|
||||
assert info.address is not None
|
||||
address = None
|
||||
if info.request(zc, timeout):
|
||||
address = socket.inet_ntoa(info.address)
|
||||
except Exception as err:
|
||||
raise APIConnectionError(
|
||||
"Error resolving mDNS hostname: {}".format(err)
|
||||
) from err
|
||||
finally:
|
||||
if not zeroconf_instance:
|
||||
zc.close()
|
||||
|
||||
if address is None:
|
||||
raise APIConnectionError(
|
||||
"Error resolving address with mDNS: Did not respond. "
|
||||
"Maybe the device is offline."
|
||||
info = await zc.async_get_service_info(
|
||||
service_type, service_name, int(timeout * 1000)
|
||||
)
|
||||
return address
|
||||
except Exception as exc:
|
||||
raise APIConnectionError(
|
||||
f"Error resolving host {host} via mDNS: {exc}"
|
||||
) from exc
|
||||
finally:
|
||||
if do_close:
|
||||
await zc.async_close()
|
||||
|
||||
if info is None:
|
||||
return []
|
||||
|
||||
addrs: List[AddrInfo] = []
|
||||
for raw in info.addresses_by_version(zeroconf.IPVersion.All):
|
||||
is_ipv6 = len(raw) == 16
|
||||
sockaddr: Sockaddr
|
||||
if is_ipv6:
|
||||
sockaddr = IPv6Sockaddr(
|
||||
address=socket.inet_ntop(socket.AF_INET6, raw),
|
||||
port=port,
|
||||
flowinfo=0,
|
||||
scope_id=0,
|
||||
)
|
||||
else:
|
||||
sockaddr = IPv4Sockaddr(
|
||||
address=socket.inet_ntop(socket.AF_INET, raw),
|
||||
port=port,
|
||||
)
|
||||
|
||||
addrs.append(
|
||||
AddrInfo(
|
||||
family=socket.AF_INET6 if is_ipv6 else socket.AF_INET,
|
||||
type=socket.SOCK_STREAM,
|
||||
proto=socket.IPPROTO_TCP,
|
||||
sockaddr=sockaddr,
|
||||
)
|
||||
)
|
||||
return addrs
|
||||
|
||||
|
||||
async def _async_resolve_host_getaddrinfo(
|
||||
eventloop: asyncio.events.AbstractEventLoop, host: str, port: int
|
||||
) -> List[AddrInfo]:
|
||||
try:
|
||||
# Limit to TCP IP protocol
|
||||
res = await eventloop.getaddrinfo(host, port, proto=socket.IPPROTO_TCP)
|
||||
except OSError as err:
|
||||
raise APIConnectionError("Error resolving IP address: {}".format(err))
|
||||
|
||||
addrs: List[AddrInfo] = []
|
||||
for family, type_, proto, _, raw in res:
|
||||
sockaddr: Sockaddr
|
||||
if family == socket.AF_INET:
|
||||
raw = cast(Tuple[str, int], raw)
|
||||
address, port = raw
|
||||
sockaddr = IPv4Sockaddr(address=address, port=port)
|
||||
elif family == socket.AF_INET6:
|
||||
raw = cast(Tuple[str, int, int, int], raw)
|
||||
address, port, flowinfo, scope_id = raw
|
||||
sockaddr = IPv6Sockaddr(
|
||||
address=address, port=port, flowinfo=flowinfo, scope_id=scope_id
|
||||
)
|
||||
else:
|
||||
# Unknown family
|
||||
continue
|
||||
|
||||
addrs.append(
|
||||
AddrInfo(family=family, type=type_, proto=proto, sockaddr=sockaddr)
|
||||
)
|
||||
return addrs
|
||||
|
||||
|
||||
async def async_resolve_host(
|
||||
eventloop: asyncio.events.AbstractEventLoop,
|
||||
host: str,
|
||||
port: int,
|
||||
zeroconf_instance: ZeroconfInstanceType = None,
|
||||
) -> AddrInfo:
|
||||
addrs: List[AddrInfo] = []
|
||||
|
||||
zc_error = None
|
||||
if host.endswith(".local"):
|
||||
name = host[: -len(".local")]
|
||||
try:
|
||||
addrs.extend(
|
||||
await _async_resolve_host_zeroconf(
|
||||
name, port, zeroconf_instance=zeroconf_instance
|
||||
)
|
||||
)
|
||||
except APIConnectionError as err:
|
||||
zc_error = err
|
||||
|
||||
if not addrs:
|
||||
addrs.extend(await _async_resolve_host_getaddrinfo(eventloop, host, port))
|
||||
|
||||
if not addrs:
|
||||
if zc_error:
|
||||
# Only show ZC error if getaddrinfo also didn't work
|
||||
raise zc_error
|
||||
raise APIConnectionError(
|
||||
f"Could not resolve host {host} - got no results from OS"
|
||||
)
|
||||
|
||||
# Use first matching result
|
||||
# Future: return all matches and use first working one
|
||||
return addrs[0]
|
||||
|
@ -1,15 +1,7 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import socket
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import zeroconf
|
||||
|
||||
# pylint: disable=cyclic-import
|
||||
from aioesphomeapi.core import APIConnectionError
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def _varuint_to_bytes(value: int) -> bytes:
|
||||
def varuint_to_bytes(value: int) -> bytes:
|
||||
if value <= 0x7F:
|
||||
return bytes([value])
|
||||
|
||||
@ -25,7 +17,7 @@ def _varuint_to_bytes(value: int) -> bytes:
|
||||
return ret
|
||||
|
||||
|
||||
def _bytes_to_varuint(value: bytes) -> Optional[int]:
|
||||
def bytes_to_varuint(value: bytes) -> Optional[int]:
|
||||
result = 0
|
||||
bitpos = 0
|
||||
for val in value:
|
||||
@ -34,53 +26,3 @@ def _bytes_to_varuint(value: bytes) -> Optional[int]:
|
||||
if (val & 0x80) == 0:
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
async def resolve_ip_address_getaddrinfo(
|
||||
eventloop: asyncio.events.AbstractEventLoop, host: str, port: int
|
||||
) -> Tuple[Any, ...]:
|
||||
|
||||
try:
|
||||
socket.inet_aton(host)
|
||||
except OSError:
|
||||
pass
|
||||
else:
|
||||
return (host, port)
|
||||
|
||||
try:
|
||||
res = await eventloop.getaddrinfo(
|
||||
host, port, family=socket.AF_INET, proto=socket.IPPROTO_TCP
|
||||
)
|
||||
except OSError as err:
|
||||
raise APIConnectionError("Error resolving IP address: {}".format(err))
|
||||
|
||||
if not res:
|
||||
raise APIConnectionError("Error resolving IP address: No matches!")
|
||||
|
||||
_, _, _, _, sockaddr = res[0]
|
||||
|
||||
return sockaddr
|
||||
|
||||
|
||||
async def resolve_ip_address(
|
||||
eventloop: asyncio.events.AbstractEventLoop,
|
||||
host: str,
|
||||
port: int,
|
||||
zeroconf_instance: Optional[zeroconf.Zeroconf] = None,
|
||||
) -> Tuple[Any, ...]:
|
||||
if host.endswith(".local"):
|
||||
from aioesphomeapi.host_resolver import resolve_host
|
||||
|
||||
try:
|
||||
return (
|
||||
await eventloop.run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
resolve_host, host, zeroconf_instance=zeroconf_instance
|
||||
),
|
||||
),
|
||||
port,
|
||||
)
|
||||
except APIConnectionError:
|
||||
pass
|
||||
return await resolve_ip_address_getaddrinfo(eventloop, host, port)
|
||||
|
@ -1,2 +1,2 @@
|
||||
protobuf>=3.12.2,<4.0
|
||||
zeroconf>=0.28.0,<1.0
|
||||
zeroconf>=0.32.0,<1.0
|
||||
|
Loading…
Reference in New Issue
Block a user