Ability to use the shared Home Assistant Zeroconf instance (#13)

This commit is contained in:
J. Nick Koston 2020-08-21 22:45:29 -05:00 committed by GitHub
parent d285c26b16
commit ff70932064
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 23 additions and 8 deletions

View File

@ -1,5 +1,6 @@
import logging import logging
from typing import Any, Callable, Optional, Tuple from typing import Any, Callable, Optional, Tuple
import zeroconf
import aioesphomeapi.api_pb2 as pb import aioesphomeapi.api_pb2 as pb
from aioesphomeapi.connection import APIConnection, ConnectionParams from aioesphomeapi.connection import APIConnection, ConnectionParams
@ -11,7 +12,8 @@ _LOGGER = logging.getLogger(__name__)
class APIClient: class APIClient:
def __init__(self, eventloop, address: str, port: int, password: str, *, def __init__(self, eventloop, address: str, port: int, password: str, *,
client_info: str = 'aioesphomeapi', keepalive: float = 15.0): client_info: str = 'aioesphomeapi', keepalive: float = 15.0,
zeroconf_instance: zeroconf.Zeroconf = None):
self._params = ConnectionParams( self._params = ConnectionParams(
eventloop=eventloop, eventloop=eventloop,
address=address, address=address,
@ -19,6 +21,7 @@ class APIClient:
password=password, password=password,
client_info=client_info, client_info=client_info,
keepalive=keepalive, keepalive=keepalive,
zeroconf_instance=zeroconf_instance
) )
self._connection = None # type: Optional[APIConnection] self._connection = None # type: Optional[APIConnection]

View File

@ -5,6 +5,7 @@ import time
from typing import Any, Callable, List, Optional, cast from typing import Any, Callable, List, Optional, cast
import attr import attr
import zeroconf
from google.protobuf import message from google.protobuf import message
import aioesphomeapi.api_pb2 as pb import aioesphomeapi.api_pb2 as pb
@ -23,7 +24,7 @@ class ConnectionParams:
password = attr.ib(type=Optional[str]) password = attr.ib(type=Optional[str])
client_info = attr.ib(type=str) client_info = attr.ib(type=str)
keepalive = attr.ib(type=float) keepalive = attr.ib(type=float)
zeroconf_instance = attr.ib(type=zeroconf.Zeroconf)
class APIConnection: class APIConnection:
def __init__(self, params: ConnectionParams, on_stop): def __init__(self, params: ConnectionParams, on_stop):
@ -100,7 +101,7 @@ class APIConnection:
try: try:
coro = resolve_ip_address(self._params.eventloop, self._params.address, coro = resolve_ip_address(self._params.eventloop, self._params.address,
self._params.port) self._params.port, self._params.zeroconf_instance)
sockaddr = await asyncio.wait_for(coro, 30.0) sockaddr = await asyncio.wait_for(coro, 30.0)
except APIConnectionError as err: except APIConnectionError as err:
await self._on_error() await self._on_error()

View File

@ -49,11 +49,11 @@ class HostResolver(zeroconf.RecordUpdateListener):
return True return True
def resolve_host(host, timeout=3.0): def resolve_host(host, timeout=3.0, zeroconf_instance: zeroconf.Zeroconf = None):
from aioesphomeapi import APIConnectionError from aioesphomeapi import APIConnectionError
try: try:
zc = zeroconf.Zeroconf() zc = zeroconf_instance or zeroconf.Zeroconf()
except Exception: except Exception:
raise APIConnectionError("Cannot start mDNS sockets, is this a docker container without " raise APIConnectionError("Cannot start mDNS sockets, is this a docker container without "
"host network mode?") "host network mode?")
@ -66,7 +66,8 @@ def resolve_host(host, timeout=3.0):
except Exception as err: except Exception as err:
raise APIConnectionError("Error resolving mDNS hostname: {}".format(err)) raise APIConnectionError("Error resolving mDNS hostname: {}".format(err))
finally: finally:
zc.close() if not zeroconf_instance:
zc.close()
if address is None: if address is None:
raise APIConnectionError("Error resolving address with mDNS: Did not respond. " raise APIConnectionError("Error resolving address with mDNS: Did not respond. "

View File

@ -1,6 +1,8 @@
import asyncio import asyncio
import functools
import socket import socket
from typing import Optional, Tuple, Any from typing import Optional, Tuple, Any
import zeroconf
# pylint: disable=cyclic-import # pylint: disable=cyclic-import
from aioesphomeapi.core import APIConnectionError from aioesphomeapi.core import APIConnectionError
@ -50,12 +52,20 @@ async def resolve_ip_address_getaddrinfo(eventloop: asyncio.events.AbstractEvent
async def resolve_ip_address(eventloop: asyncio.events.AbstractEventLoop, async def resolve_ip_address(eventloop: asyncio.events.AbstractEventLoop,
host: str, port: int) -> Tuple[Any, ...]: host: str, port: int,
zeroconf_instance: zeroconf.Zeroconf = None) -> Tuple[Any, ...]:
if host.endswith('.local'): if host.endswith('.local'):
from aioesphomeapi.host_resolver import resolve_host from aioesphomeapi.host_resolver import resolve_host
try: try:
return await eventloop.run_in_executor(None, resolve_host, host), port return await eventloop.run_in_executor(
None,
functools.partial(
resolve_host,
host,
zeroconf_instance=zeroconf_instance
)
), port
except APIConnectionError: except APIConnectionError:
pass pass
return await resolve_ip_address_getaddrinfo(eventloop, host, port) return await resolve_ip_address_getaddrinfo(eventloop, host, port)