mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-12-31 18:17:46 +01:00
Refactor zeroconf code to avoid creating instances when one is unneeded (#643)
This commit is contained in:
parent
9a86f449a6
commit
b12903e2e7
@ -20,6 +20,8 @@ cdef class APIFrameHelper:
|
||||
cdef str _log_name
|
||||
cdef object _debug_enabled
|
||||
|
||||
cpdef set_log_name(self, str log_name)
|
||||
|
||||
@cython.locals(original_pos="unsigned int", new_pos="unsigned int")
|
||||
cdef bytes _read_exactly(self, int length)
|
||||
|
||||
|
@ -62,6 +62,10 @@ class APIFrameHelper:
|
||||
self._log_name = log_name
|
||||
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
|
||||
|
||||
def set_log_name(self, log_name: str) -> None:
|
||||
"""Set the log name."""
|
||||
self._log_name = log_name
|
||||
|
||||
def _set_ready_future_exception(self, exc: Exception | type[Exception]) -> None:
|
||||
if not self._ready_future.done():
|
||||
self._ready_future.set_exception(exc)
|
||||
|
@ -111,7 +111,6 @@ from .core import (
|
||||
UnhandledAPIConnectionError,
|
||||
to_human_readable_address,
|
||||
)
|
||||
from .host_resolver import ZeroconfInstanceType
|
||||
from .model import (
|
||||
AlarmControlPanelCommand,
|
||||
AlarmControlPanelEntityState,
|
||||
@ -177,6 +176,8 @@ from .model import (
|
||||
VoiceAssistantCommand,
|
||||
VoiceAssistantEventType,
|
||||
)
|
||||
from .util import build_log_name
|
||||
from .zeroconf import ZeroconfInstanceType, ZeroconfManager
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -258,10 +259,10 @@ class APIClient:
|
||||
__slots__ = (
|
||||
"_params",
|
||||
"_connection",
|
||||
"_cached_name",
|
||||
"cached_name",
|
||||
"_background_tasks",
|
||||
"_loop",
|
||||
"_log_name",
|
||||
"log_name",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@ -272,7 +273,7 @@ class APIClient:
|
||||
*,
|
||||
client_info: str = "aioesphomeapi",
|
||||
keepalive: float = KEEP_ALIVE_FREQUENCY,
|
||||
zeroconf_instance: ZeroconfInstanceType = None,
|
||||
zeroconf_instance: ZeroconfInstanceType | None = None,
|
||||
noise_psk: str | None = None,
|
||||
expected_name: str | None = None,
|
||||
) -> None:
|
||||
@ -297,17 +298,21 @@ class APIClient:
|
||||
password=password,
|
||||
client_info=client_info,
|
||||
keepalive=keepalive,
|
||||
zeroconf_instance=zeroconf_instance,
|
||||
zeroconf_manager=ZeroconfManager(zeroconf_instance),
|
||||
# treat empty '' psk string as missing (like password)
|
||||
noise_psk=_stringify_or_none(noise_psk) or None,
|
||||
expected_name=_stringify_or_none(expected_name) or None,
|
||||
)
|
||||
self._connection: APIConnection | None = None
|
||||
self._cached_name: str | None = None
|
||||
self.cached_name: str | None = None
|
||||
self._background_tasks: set[asyncio.Task[Any]] = set()
|
||||
self._loop = asyncio.get_event_loop()
|
||||
self._set_log_name()
|
||||
|
||||
@property
|
||||
def zeroconf_manager(self) -> ZeroconfManager:
|
||||
return self._params.zeroconf_manager
|
||||
|
||||
@property
|
||||
def expected_name(self) -> str | None:
|
||||
return self._params.expected_name
|
||||
@ -320,26 +325,23 @@ class APIClient:
|
||||
def address(self) -> str:
|
||||
return self._params.address
|
||||
|
||||
def _get_log_name(self) -> str:
|
||||
"""Get the log name of the device."""
|
||||
address = self.address
|
||||
address_is_host = address.endswith(".local")
|
||||
if self._cached_name is not None:
|
||||
if address_is_host:
|
||||
return self._cached_name
|
||||
return f"{self._cached_name} @ {address}"
|
||||
if address_is_host:
|
||||
return address[:-6]
|
||||
return address
|
||||
|
||||
def _set_log_name(self) -> None:
|
||||
"""Set the log name of the device."""
|
||||
self._log_name = self._get_log_name()
|
||||
resolved_address: str | None = None
|
||||
if self._connection and self._connection.resolved_addr_info:
|
||||
resolved_address = self._connection.resolved_addr_info.sockaddr.address
|
||||
self.log_name = build_log_name(
|
||||
self.cached_name,
|
||||
self.address,
|
||||
resolved_address,
|
||||
)
|
||||
if self._connection:
|
||||
self._connection.set_log_name(self.log_name)
|
||||
|
||||
def set_cached_name_if_unset(self, name: str) -> None:
|
||||
"""Set the cached name of the device if not set."""
|
||||
if not self._cached_name:
|
||||
self._cached_name = name
|
||||
if not self.cached_name:
|
||||
self.cached_name = name
|
||||
self._set_log_name()
|
||||
|
||||
async def connect(
|
||||
@ -357,7 +359,7 @@ class APIClient:
|
||||
) -> None:
|
||||
"""Start connecting to the device."""
|
||||
if self._connection is not None:
|
||||
raise APIConnectionError(f"Already connected to {self._log_name}!")
|
||||
raise APIConnectionError(f"Already connected to {self.log_name}!")
|
||||
|
||||
async def _on_stop(expected_disconnect: bool) -> None:
|
||||
# Hook into on_stop handler to clear connection when stopped
|
||||
@ -365,9 +367,7 @@ class APIClient:
|
||||
if on_stop is not None:
|
||||
await on_stop(expected_disconnect)
|
||||
|
||||
self._connection = APIConnection(
|
||||
self._params, _on_stop, log_name=self._log_name
|
||||
)
|
||||
self._connection = APIConnection(self._params, _on_stop, log_name=self.log_name)
|
||||
|
||||
try:
|
||||
await self._connection.start_connection()
|
||||
@ -377,8 +377,11 @@ class APIClient:
|
||||
except Exception as e:
|
||||
self._connection = None
|
||||
raise UnhandledAPIConnectionError(
|
||||
f"Unexpected error while connecting to {self._log_name}: {e}"
|
||||
f"Unexpected error while connecting to {self.log_name}: {e}"
|
||||
) from e
|
||||
# If we resolved the address, we should set the log name now
|
||||
if self._connection.resolved_addr_info:
|
||||
self._set_log_name()
|
||||
|
||||
async def finish_connection(
|
||||
self,
|
||||
@ -394,8 +397,10 @@ class APIClient:
|
||||
except Exception as e:
|
||||
self._connection = None
|
||||
raise UnhandledAPIConnectionError(
|
||||
f"Unexpected error while connecting to {self._log_name}: {e}"
|
||||
f"Unexpected error while connecting to {self.log_name}: {e}"
|
||||
) from e
|
||||
if received_name := self._connection.received_name:
|
||||
self._set_name_from_device(received_name)
|
||||
|
||||
async def disconnect(self, force: bool = False) -> None:
|
||||
if self._connection is None:
|
||||
@ -408,10 +413,10 @@ class APIClient:
|
||||
def _check_authenticated(self) -> None:
|
||||
connection = self._connection
|
||||
if not connection:
|
||||
raise APIConnectionError(f"Not connected to {self._log_name}!")
|
||||
raise APIConnectionError(f"Not connected to {self.log_name}!")
|
||||
if not connection.is_connected:
|
||||
raise APIConnectionError(
|
||||
f"Authenticated connection not ready yet for {self._log_name}; "
|
||||
f"Authenticated connection not ready yet for {self.log_name}; "
|
||||
f"current state is {connection.connection_state}!"
|
||||
)
|
||||
|
||||
@ -423,11 +428,14 @@ class APIClient:
|
||||
DeviceInfoRequest(), DeviceInfoResponse
|
||||
)
|
||||
info = DeviceInfo.from_pb(resp)
|
||||
self._cached_name = info.name
|
||||
connection.set_log_name(self._log_name)
|
||||
self._set_log_name()
|
||||
self._set_name_from_device(info.name)
|
||||
return info
|
||||
|
||||
def _set_name_from_device(self, name: str) -> None:
|
||||
"""Set the name from a DeviceInfo message."""
|
||||
self.cached_name = name
|
||||
self._set_log_name()
|
||||
|
||||
async def list_entities_services(
|
||||
self,
|
||||
) -> tuple[list[EntityInfo], list[UserService]]:
|
||||
|
@ -68,6 +68,8 @@ cdef class APIConnection:
|
||||
cdef public bint is_connected
|
||||
cdef bint _handshake_complete
|
||||
cdef object _debug_enabled
|
||||
cdef public str received_name
|
||||
cdef public object resolved_addr_info
|
||||
|
||||
cpdef send_message(self, object msg)
|
||||
|
||||
|
@ -49,6 +49,7 @@ from .core import (
|
||||
TimeoutAPIError,
|
||||
)
|
||||
from .model import APIVersion
|
||||
from .zeroconf import ZeroconfManager
|
||||
|
||||
if sys.version_info[:2] < (3, 11):
|
||||
from async_timeout import timeout as asyncio_timeout
|
||||
@ -111,7 +112,7 @@ class ConnectionParams:
|
||||
password: str | None
|
||||
client_info: str
|
||||
keepalive: float
|
||||
zeroconf_instance: hr.ZeroconfInstanceType
|
||||
zeroconf_manager: ZeroconfManager
|
||||
noise_psk: str | None
|
||||
expected_name: str | None
|
||||
|
||||
@ -159,6 +160,8 @@ class APIConnection:
|
||||
"is_connected",
|
||||
"_handshake_complete",
|
||||
"_debug_enabled",
|
||||
"received_name",
|
||||
"resolved_addr_info",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@ -201,10 +204,14 @@ class APIConnection:
|
||||
self.is_connected = False
|
||||
self._handshake_complete = False
|
||||
self._debug_enabled = partial(_LOGGER.isEnabledFor, logging.DEBUG)
|
||||
self.received_name: str = ""
|
||||
self.resolved_addr_info: hr.AddrInfo | None = None
|
||||
|
||||
def set_log_name(self, name: str) -> None:
|
||||
"""Set the friendly log name for this connection."""
|
||||
self.log_name = name
|
||||
if self._frame_helper is not None:
|
||||
self._frame_helper.set_log_name(name)
|
||||
|
||||
def _cleanup(self) -> None:
|
||||
"""Clean up all resources that have been allocated.
|
||||
@ -276,7 +283,7 @@ class APIConnection:
|
||||
return await hr.async_resolve_host(
|
||||
self._params.address,
|
||||
self._params.port,
|
||||
self._params.zeroconf_instance,
|
||||
self._params.zeroconf_manager,
|
||||
)
|
||||
except asyncio_TimeoutError as err:
|
||||
raise ResolveAPIError(
|
||||
@ -427,17 +434,16 @@ class APIConnection:
|
||||
|
||||
self.api_version = api_version
|
||||
expected_name = self._params.expected_name
|
||||
received_name = resp.name
|
||||
if (
|
||||
expected_name is not None
|
||||
and received_name != ""
|
||||
and received_name != expected_name
|
||||
):
|
||||
raise BadNameAPIError(
|
||||
f"Expected '{expected_name}' but server sent "
|
||||
f"a different name: '{received_name}'",
|
||||
received_name,
|
||||
)
|
||||
if received_name := resp.name:
|
||||
if expected_name is not None and received_name != expected_name:
|
||||
raise BadNameAPIError(
|
||||
f"Expected '{expected_name}' but server sent "
|
||||
f"a different name: '{received_name}'",
|
||||
received_name,
|
||||
)
|
||||
|
||||
self.received_name = received_name
|
||||
self.set_log_name(received_name)
|
||||
|
||||
def _async_schedule_keep_alive(self, now: _float) -> None:
|
||||
"""Start the keep alive task."""
|
||||
@ -506,8 +512,8 @@ class APIConnection:
|
||||
async def _do_connect(self) -> None:
|
||||
"""Do the actual connect process."""
|
||||
in_do_connect.set(True)
|
||||
addr = await self._connect_resolve_host()
|
||||
await self._connect_socket_connect(addr)
|
||||
self.resolved_addr_info = await self._connect_resolve_host()
|
||||
await self._connect_socket_connect(self.resolved_addr_info)
|
||||
|
||||
async def start_connection(self) -> None:
|
||||
"""Start the connection process.
|
||||
|
@ -6,35 +6,38 @@ import logging
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
from ipaddress import IPv4Address, IPv6Address
|
||||
from typing import Union, cast
|
||||
from typing import cast
|
||||
|
||||
from zeroconf import IPVersion, Zeroconf
|
||||
from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf
|
||||
from zeroconf import IPVersion
|
||||
from zeroconf.asyncio import AsyncServiceInfo
|
||||
|
||||
from .core import APIConnectionError, ResolveAPIError
|
||||
from .util import address_is_local, host_is_name_part
|
||||
from .zeroconf import ZeroconfManager
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf, None]
|
||||
|
||||
SERVICE_TYPE = "_esphomelib._tcp.local."
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Sockaddr:
|
||||
pass
|
||||
"""Base socket address."""
|
||||
|
||||
address: str
|
||||
port: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IPv4Sockaddr(Sockaddr):
|
||||
address: str
|
||||
port: int
|
||||
"""IPv4 socket address."""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IPv6Sockaddr(Sockaddr):
|
||||
address: str
|
||||
port: int
|
||||
"""IPv6 socket address."""
|
||||
|
||||
flowinfo: int
|
||||
scope_id: int
|
||||
|
||||
@ -44,35 +47,23 @@ class AddrInfo:
|
||||
family: int
|
||||
type: int
|
||||
proto: int
|
||||
sockaddr: Sockaddr
|
||||
sockaddr: IPv4Sockaddr | IPv6Sockaddr
|
||||
|
||||
|
||||
async def _async_zeroconf_get_service_info(
|
||||
zeroconf_instance: ZeroconfInstanceType,
|
||||
zeroconf_manager: ZeroconfManager,
|
||||
service_type: str,
|
||||
service_name: str,
|
||||
timeout: float,
|
||||
) -> AsyncServiceInfo | None:
|
||||
# Use or create zeroconf instance, ensure it's an AsyncZeroconf
|
||||
async_zc_instance: AsyncZeroconf | None = None
|
||||
if zeroconf_instance is None:
|
||||
try:
|
||||
async_zc_instance = AsyncZeroconf()
|
||||
except Exception:
|
||||
raise ResolveAPIError(
|
||||
"Cannot start mDNS sockets, is this a docker container without "
|
||||
"host network mode?"
|
||||
)
|
||||
zc = async_zc_instance.zeroconf
|
||||
elif isinstance(zeroconf_instance, AsyncZeroconf):
|
||||
zc = zeroconf_instance.zeroconf
|
||||
elif isinstance(zeroconf_instance, Zeroconf):
|
||||
zc = zeroconf_instance
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid type passed for zeroconf_instance: {type(zeroconf_instance)}"
|
||||
)
|
||||
|
||||
try:
|
||||
zc = zeroconf_manager.get_async_zeroconf().zeroconf
|
||||
except Exception as exc:
|
||||
raise ResolveAPIError(
|
||||
f"Cannot start mDNS sockets: {exc}, is this a docker container without "
|
||||
"host network mode?"
|
||||
) from exc
|
||||
try:
|
||||
info = AsyncServiceInfo(service_type, service_name)
|
||||
if await info.async_request(zc, int(timeout * 1000)):
|
||||
@ -82,8 +73,7 @@ async def _async_zeroconf_get_service_info(
|
||||
f"Error resolving mDNS {service_name} via mDNS: {exc}"
|
||||
) from exc
|
||||
finally:
|
||||
if async_zc_instance:
|
||||
await async_zc_instance.async_close()
|
||||
await zeroconf_manager.async_close()
|
||||
return info
|
||||
|
||||
|
||||
@ -92,13 +82,13 @@ async def _async_resolve_host_zeroconf(
|
||||
port: int,
|
||||
*,
|
||||
timeout: float = 3.0,
|
||||
zeroconf_instance: ZeroconfInstanceType = None,
|
||||
zeroconf_manager: ZeroconfManager | None = None,
|
||||
) -> list[AddrInfo]:
|
||||
service_name = f"{host}.{SERVICE_TYPE}"
|
||||
|
||||
_LOGGER.debug("Resolving host %s via mDNS", service_name)
|
||||
info = await _async_zeroconf_get_service_info(
|
||||
zeroconf_instance, SERVICE_TYPE, service_name, timeout
|
||||
zeroconf_manager or ZeroconfManager(), SERVICE_TYPE, service_name, timeout
|
||||
)
|
||||
|
||||
if info is None:
|
||||
@ -107,7 +97,7 @@ async def _async_resolve_host_zeroconf(
|
||||
addrs: list[AddrInfo] = []
|
||||
for ip_address in info.ip_addresses_by_version(IPVersion.All):
|
||||
is_ipv6 = ip_address.version == 6
|
||||
sockaddr: Sockaddr
|
||||
sockaddr: IPv6Sockaddr | IPv4Sockaddr
|
||||
if is_ipv6:
|
||||
sockaddr = IPv6Sockaddr(
|
||||
address=str(ip_address),
|
||||
@ -143,7 +133,7 @@ async def _async_resolve_host_getaddrinfo(host: str, port: int) -> list[AddrInfo
|
||||
|
||||
addrs: list[AddrInfo] = []
|
||||
for family, type_, proto, _, raw in res:
|
||||
sockaddr: Sockaddr
|
||||
sockaddr: IPv4Sockaddr | IPv6Sockaddr
|
||||
if family == socket.AF_INET:
|
||||
raw = cast(tuple[str, int], raw)
|
||||
address, port = raw
|
||||
@ -197,17 +187,17 @@ def _async_ip_address_to_addrs(host: str, port: int) -> list[AddrInfo]:
|
||||
async def async_resolve_host(
|
||||
host: str,
|
||||
port: int,
|
||||
zeroconf_instance: ZeroconfInstanceType = None,
|
||||
zeroconf_manager: ZeroconfManager | None = None,
|
||||
) -> AddrInfo:
|
||||
addrs: list[AddrInfo] = []
|
||||
|
||||
zc_error = None
|
||||
if "." not in host or host.endswith(".local"):
|
||||
if host_is_name_part(host) or address_is_local(host):
|
||||
name = host.partition(".")[0]
|
||||
try:
|
||||
addrs.extend(
|
||||
await _async_resolve_host_zeroconf(
|
||||
name, port, zeroconf_instance=zeroconf_instance
|
||||
name, port, zeroconf_manager=zeroconf_manager
|
||||
)
|
||||
)
|
||||
except APIConnectionError as err:
|
||||
|
@ -7,8 +7,6 @@ import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
from zeroconf.asyncio import AsyncZeroconf
|
||||
|
||||
from .api_pb2 import SubscribeLogsResponse # type: ignore
|
||||
from .client import APIClient
|
||||
from .log_runner import async_run
|
||||
@ -29,14 +27,11 @@ async def main(argv: list[str]) -> None:
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
aiozc = AsyncZeroconf()
|
||||
|
||||
cli = APIClient(
|
||||
args.address,
|
||||
args.port,
|
||||
args.password or "",
|
||||
noise_psk=args.noise_psk,
|
||||
zeroconf_instance=aiozc.zeroconf,
|
||||
keepalive=10,
|
||||
)
|
||||
|
||||
@ -46,12 +41,10 @@ async def main(argv: list[str]) -> None:
|
||||
text = message.decode("utf8", "backslashreplace")
|
||||
print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}")
|
||||
|
||||
stop = await async_run(cli, on_log, aio_zeroconf_instance=aiozc)
|
||||
stop = await async_run(cli, on_log)
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(60)
|
||||
await asyncio.Event().wait()
|
||||
finally:
|
||||
await aiozc.async_close()
|
||||
await stop()
|
||||
|
||||
|
||||
|
@ -47,21 +47,16 @@ async def async_run(
|
||||
) -> None:
|
||||
_LOGGER.warning("Disconnected from API")
|
||||
|
||||
passed_in_zeroconf = aio_zeroconf_instance is not None
|
||||
aiozc = aio_zeroconf_instance or AsyncZeroconf()
|
||||
|
||||
logic = ReconnectLogic(
|
||||
client=cli,
|
||||
on_connect=on_connect,
|
||||
on_disconnect=on_disconnect,
|
||||
zeroconf_instance=aiozc.zeroconf,
|
||||
zeroconf_instance=aio_zeroconf_instance,
|
||||
name=name,
|
||||
)
|
||||
await logic.start()
|
||||
|
||||
async def _stop() -> None:
|
||||
if not passed_in_zeroconf:
|
||||
await aiozc.async_close()
|
||||
await logic.stop()
|
||||
await cli.disconnect()
|
||||
|
||||
|
@ -19,6 +19,8 @@ from .core import (
|
||||
RequiresEncryptionAPIError,
|
||||
UnhandledAPIConnectionError,
|
||||
)
|
||||
from .util import address_is_local
|
||||
from .zeroconf import ZeroconfInstanceType
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -62,7 +64,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
client: APIClient,
|
||||
on_connect: Callable[[], Awaitable[None]],
|
||||
on_disconnect: Callable[[bool], Awaitable[None]],
|
||||
zeroconf_instance: zeroconf.Zeroconf,
|
||||
zeroconf_instance: ZeroconfInstanceType | None = None,
|
||||
name: str | None = None,
|
||||
on_connect_error: Callable[[Exception], Awaitable[None]] | None = None,
|
||||
) -> None:
|
||||
@ -74,21 +76,19 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
"""
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self._cli = client
|
||||
self.name: str | None
|
||||
if client.address.endswith(".local"):
|
||||
self.name = client.address[:-6]
|
||||
self._log_name = self.name
|
||||
elif name:
|
||||
self.name: str | None = None
|
||||
if name:
|
||||
self.name = name
|
||||
self._log_name = f"{name} @ {self._cli.address}"
|
||||
self._cli.set_cached_name_if_unset(name)
|
||||
else:
|
||||
self.name = None
|
||||
self._log_name = client.address
|
||||
elif address_is_local(client.address):
|
||||
self.name = client.address.partition(".")[0]
|
||||
if self.name:
|
||||
self._cli.set_cached_name_if_unset(self.name)
|
||||
self._on_connect_cb = on_connect
|
||||
self._on_disconnect_cb = on_disconnect
|
||||
self._on_connect_error_cb = on_connect_error
|
||||
self._zc = zeroconf_instance
|
||||
self._zeroconf_manager = client.zeroconf_manager
|
||||
if zeroconf_instance is not None:
|
||||
self._zeroconf_manager.set_instance(zeroconf_instance)
|
||||
self._ptr_alias: str | None = None
|
||||
self._a_name: str | None = None
|
||||
# Flag to check if the device is connected
|
||||
@ -116,7 +116,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
_LOGGER.info(
|
||||
"Processing %s disconnect from ESPHome API for %s",
|
||||
disconnect_type,
|
||||
self._log_name,
|
||||
self._cli.log_name,
|
||||
)
|
||||
|
||||
# Run disconnect hook
|
||||
@ -172,7 +172,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
_LOGGER.log(
|
||||
level,
|
||||
"Can't connect to ESPHome API for %s: %s (%s)",
|
||||
self._log_name,
|
||||
self._cli.log_name,
|
||||
err,
|
||||
type(err).__name__,
|
||||
# Print stacktrace if unhandled
|
||||
@ -197,7 +197,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
finish_connect_time = time.perf_counter()
|
||||
connect_time = finish_connect_time - start_connect_time
|
||||
_LOGGER.info(
|
||||
"Successfully connected to %s in %0.3fs", self._log_name, connect_time
|
||||
"Successfully connected to %s in %0.3fs", self._cli.log_name, connect_time
|
||||
)
|
||||
self._stop_zc_listen()
|
||||
self._async_set_connection_state_while_locked(ReconnectLogicState.HANDSHAKING)
|
||||
@ -221,7 +221,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
finish_handshake_time = time.perf_counter()
|
||||
handshake_time = finish_handshake_time - finish_connect_time
|
||||
_LOGGER.info(
|
||||
"Successful handshake with %s in %0.3fs", self._log_name, handshake_time
|
||||
"Successful handshake with %s in %0.3fs", self._cli.log_name, handshake_time
|
||||
)
|
||||
self._async_set_connection_state_while_locked(ReconnectLogicState.READY)
|
||||
await self._on_connect_cb()
|
||||
@ -250,7 +250,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
return
|
||||
_LOGGER.debug(
|
||||
"%s: Cancelling existing connect task, to try again now!",
|
||||
self._log_name,
|
||||
self._cli.log_name,
|
||||
)
|
||||
self._connect_task.cancel("Scheduling new connect attempt")
|
||||
self._connect_task = None
|
||||
@ -260,7 +260,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
|
||||
self._connect_task = asyncio.create_task(
|
||||
self._connect_once_or_reschedule(),
|
||||
name=f"{self._log_name}: aioesphomeapi connect",
|
||||
name=f"{self._cli.log_name}: aioesphomeapi connect",
|
||||
)
|
||||
|
||||
def _cancel_connect(self, msg: str) -> None:
|
||||
@ -277,9 +277,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
|
||||
Must only be called from _call_connect_once
|
||||
"""
|
||||
_LOGGER.debug("Trying to connect to %s", self._log_name)
|
||||
_LOGGER.debug("Trying to connect to %s", self._cli.log_name)
|
||||
async with self._connected_lock:
|
||||
_LOGGER.debug("Connected lock acquired for %s", self._log_name)
|
||||
_LOGGER.debug("Connected lock acquired for %s", self._cli.log_name)
|
||||
if (
|
||||
self._connection_state != ReconnectLogicState.DISCONNECTED
|
||||
or self._is_stopped
|
||||
@ -291,9 +291,9 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
wait_time = int(round(min(1.8**tries, 60.0)))
|
||||
if tries == 1:
|
||||
_LOGGER.info(
|
||||
"Trying to connect to %s in the background", self._log_name
|
||||
"Trying to connect to %s in the background", self._cli.log_name
|
||||
)
|
||||
_LOGGER.debug("Retrying %s in %d seconds", self._log_name, wait_time)
|
||||
_LOGGER.debug("Retrying %s in %d seconds", self._cli.log_name, wait_time)
|
||||
if wait_time:
|
||||
# If we are waiting, start listening for mDNS records
|
||||
self._start_zc_listen()
|
||||
@ -311,7 +311,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
"""Stop the connect logic."""
|
||||
self._stop_task = asyncio.create_task(
|
||||
self.stop(),
|
||||
name=f"{self._log_name}: aioesphomeapi reconnect_logic stop_callback",
|
||||
name=f"{self._cli.log_name}: aioesphomeapi reconnect_logic stop_callback",
|
||||
)
|
||||
self._stop_task.add_done_callback(self._remove_stop_task)
|
||||
|
||||
@ -342,6 +342,8 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
ReconnectLogicState.DISCONNECTED
|
||||
)
|
||||
|
||||
await self._zeroconf_manager.async_close()
|
||||
|
||||
def _start_zc_listen(self) -> None:
|
||||
"""Listen for mDNS records.
|
||||
|
||||
@ -352,14 +354,18 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
_LOGGER.debug("Starting zeroconf listener for %s", self.name)
|
||||
self._ptr_alias = f"{self.name}._esphomelib._tcp.local."
|
||||
self._a_name = f"{self.name}.local."
|
||||
self._zc.async_add_listener(self, None)
|
||||
self._zeroconf_manager.get_async_zeroconf().zeroconf.async_add_listener(
|
||||
self, None
|
||||
)
|
||||
self._zc_listening = True
|
||||
|
||||
def _stop_zc_listen(self) -> None:
|
||||
"""Stop listening for zeroconf updates."""
|
||||
if self._zc_listening:
|
||||
_LOGGER.debug("Removing zeroconf listener for %s", self.name)
|
||||
self._zc.async_remove_listener(self)
|
||||
self._zeroconf_manager.get_async_zeroconf().zeroconf.async_remove_listener(
|
||||
self
|
||||
)
|
||||
self._zc_listening = False
|
||||
|
||||
def async_update_records(
|
||||
@ -389,7 +395,7 @@ class ReconnectLogic(zeroconf.RecordUpdateListener):
|
||||
# Tell connection logic to retry connection attempt now (even before connect timer finishes)
|
||||
_LOGGER.debug(
|
||||
"%s: Triggering connect because of received mDNS record %s",
|
||||
self._log_name,
|
||||
self._cli.log_name,
|
||||
record_update.new,
|
||||
)
|
||||
# We can't stop the zeroconf listener here because we are in the middle of
|
||||
|
@ -24,3 +24,27 @@ def fix_float_single_double_conversion(value: float) -> float:
|
||||
l10 = math.ceil(math.log10(abs_val))
|
||||
prec = 7 - l10
|
||||
return round(value, prec)
|
||||
|
||||
|
||||
def host_is_name_part(address: str) -> bool:
|
||||
"""Return True if a host is the name part."""
|
||||
return "." not in address and ":" not in address
|
||||
|
||||
|
||||
def address_is_local(address: str) -> bool:
|
||||
"""Return True if the address is a local address."""
|
||||
return address.removesuffix(".").endswith(".local")
|
||||
|
||||
|
||||
def build_log_name(name: str | None, address: str, resolved_address: str | None) -> str:
|
||||
"""Return a log name for a connection."""
|
||||
if not name and address_is_local(address) or host_is_name_part(address):
|
||||
name = address.partition(".")[0]
|
||||
preferred_address = resolved_address or address
|
||||
if (
|
||||
name
|
||||
and name != preferred_address
|
||||
and not preferred_address.startswith(f"{name}.")
|
||||
):
|
||||
return f"{name} @ {preferred_address}"
|
||||
return preferred_address
|
||||
|
60
aioesphomeapi/zeroconf.py
Normal file
60
aioesphomeapi/zeroconf.py
Normal file
@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
from zeroconf import Zeroconf
|
||||
from zeroconf.asyncio import AsyncZeroconf
|
||||
|
||||
ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf]
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ZeroconfManager:
|
||||
"""Manage the Zeroconf objects.
|
||||
|
||||
This class is used to manage the Zeroconf objects. It is used to create
|
||||
the Zeroconf objects and to close them. It attempts to avoid creating
|
||||
a Zeroconf object unless one is actually needed.
|
||||
"""
|
||||
|
||||
def __init__(self, zeroconf: ZeroconfInstanceType | None = None) -> None:
|
||||
"""Initialize the ZeroconfManager."""
|
||||
self._created = False
|
||||
self._aiozc: AsyncZeroconf | None = None
|
||||
if zeroconf is not None:
|
||||
self.set_instance(zeroconf)
|
||||
|
||||
def set_instance(self, zc: AsyncZeroconf | Zeroconf) -> None:
|
||||
"""Set the AsyncZeroconf instance."""
|
||||
if self._aiozc:
|
||||
if isinstance(zc, AsyncZeroconf) and self._aiozc.zeroconf is zc.zeroconf:
|
||||
return
|
||||
if isinstance(zc, Zeroconf) and self._aiozc.zeroconf is zc:
|
||||
self._aiozc = AsyncZeroconf(zc=zc)
|
||||
return
|
||||
raise RuntimeError("Zeroconf instance already set to a different instance")
|
||||
self._aiozc = zc if isinstance(zc, AsyncZeroconf) else AsyncZeroconf(zc=zc)
|
||||
|
||||
def _create_async_zeroconf(self) -> None:
|
||||
"""Create an AsyncZeroconf instance."""
|
||||
_LOGGER.debug("Creating new AsyncZeroconf instance")
|
||||
self._aiozc = AsyncZeroconf()
|
||||
self._created = True
|
||||
|
||||
def get_async_zeroconf(self) -> AsyncZeroconf:
|
||||
"""Get the AsyncZeroconf instance."""
|
||||
if not self._aiozc:
|
||||
self._create_async_zeroconf()
|
||||
if TYPE_CHECKING:
|
||||
assert self._aiozc is not None
|
||||
return self._aiozc
|
||||
|
||||
async def async_close(self) -> None:
|
||||
"""Close the Zeroconf connection."""
|
||||
if not self._created or not self._aiozc:
|
||||
return
|
||||
await self._aiozc.async_close()
|
||||
self._aiozc = None
|
||||
self._created = False
|
@ -4,7 +4,7 @@ import asyncio
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from functools import partial
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from google.protobuf import message
|
||||
from zeroconf import Zeroconf
|
||||
@ -27,26 +27,30 @@ PROTO_TO_MESSAGE_TYPE = {v: k for k, v in MESSAGE_TYPE_TO_PROTO.items()}
|
||||
|
||||
|
||||
def get_mock_zeroconf() -> MagicMock:
|
||||
return MagicMock(spec=Zeroconf)
|
||||
with patch("zeroconf.Zeroconf.start"):
|
||||
zc = Zeroconf()
|
||||
zc.close = MagicMock()
|
||||
return zc
|
||||
|
||||
|
||||
def get_mock_async_zeroconf() -> MagicMock:
|
||||
mock = MagicMock(spec=AsyncZeroconf)
|
||||
mock.zeroconf = get_mock_zeroconf()
|
||||
mock.async_close = AsyncMock()
|
||||
return mock
|
||||
def get_mock_async_zeroconf() -> AsyncZeroconf:
|
||||
aiozc = AsyncZeroconf(zc=get_mock_zeroconf())
|
||||
aiozc.async_close = AsyncMock()
|
||||
return aiozc
|
||||
|
||||
|
||||
class Estr(str):
|
||||
"""A subclassed string."""
|
||||
|
||||
|
||||
def generate_plaintext_packet(msg: bytes, type_: int) -> bytes:
|
||||
def generate_plaintext_packet(msg: message.Message) -> bytes:
|
||||
type_ = PROTO_TO_MESSAGE_TYPE[msg.__class__]
|
||||
bytes_ = msg.SerializeToString()
|
||||
return (
|
||||
b"\0"
|
||||
+ _cached_varuint_to_bytes(len(msg))
|
||||
+ _cached_varuint_to_bytes(len(bytes_))
|
||||
+ _cached_varuint_to_bytes(type_)
|
||||
+ msg
|
||||
+ bytes_
|
||||
)
|
||||
|
||||
|
||||
@ -99,10 +103,7 @@ def send_plaintext_hello(protocol: APIPlaintextFrameHelper) -> None:
|
||||
hello_response.api_version_major = 1
|
||||
hello_response.api_version_minor = 9
|
||||
hello_response.name = "fake"
|
||||
hello_msg = hello_response.SerializeToString()
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(hello_msg, PROTO_TO_MESSAGE_TYPE[HelloResponse])
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(hello_response))
|
||||
|
||||
|
||||
def send_plaintext_connect_response(
|
||||
@ -110,8 +111,4 @@ def send_plaintext_connect_response(
|
||||
) -> None:
|
||||
connect_response: message.Message = ConnectResponse()
|
||||
connect_response.invalid_password = invalid_password
|
||||
connect_msg = connect_response.SerializeToString()
|
||||
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(connect_msg, PROTO_TO_MESSAGE_TYPE[ConnectResponse])
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(connect_response))
|
||||
|
@ -12,8 +12,14 @@ from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
|
||||
from aioesphomeapi.client import APIClient, ConnectionParams
|
||||
from aioesphomeapi.connection import APIConnection
|
||||
from aioesphomeapi.host_resolver import AddrInfo, IPv4Sockaddr
|
||||
from aioesphomeapi.zeroconf import ZeroconfManager
|
||||
|
||||
from .common import connect, send_plaintext_hello
|
||||
from .common import connect, get_mock_async_zeroconf, send_plaintext_hello
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_zeroconf():
|
||||
return get_mock_async_zeroconf()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -42,7 +48,7 @@ def connection_params() -> ConnectionParams:
|
||||
password=None,
|
||||
client_info="Tests client",
|
||||
keepalive=15.0,
|
||||
zeroconf_instance=None,
|
||||
zeroconf_manager=ZeroconfManager(),
|
||||
noise_psk=None,
|
||||
expected_name=None,
|
||||
)
|
||||
|
@ -141,6 +141,7 @@ def test_plaintext_frame_helper(
|
||||
|
||||
assert type_ == pkt_type
|
||||
assert data == pkt_data
|
||||
helper.close()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -18,6 +18,8 @@ from aioesphomeapi.api_pb2 import (
|
||||
CameraImageResponse,
|
||||
ClimateCommandRequest,
|
||||
CoverCommandRequest,
|
||||
DeviceInfoResponse,
|
||||
DisconnectResponse,
|
||||
ExecuteServiceArgument,
|
||||
ExecuteServiceRequest,
|
||||
FanCommandRequest,
|
||||
@ -34,6 +36,7 @@ from aioesphomeapi.api_pb2 import (
|
||||
)
|
||||
from aioesphomeapi.client import APIClient
|
||||
from aioesphomeapi.connection import APIConnection
|
||||
from aioesphomeapi.core import APIConnectionError
|
||||
from aioesphomeapi.model import (
|
||||
AlarmControlPanelCommand,
|
||||
APIVersion,
|
||||
@ -680,12 +683,7 @@ async def test_bluetooth_disconnect(
|
||||
response: message.Message = BluetoothDeviceConnectionResponse(
|
||||
address=1234, connected=False
|
||||
)
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(
|
||||
response.SerializeToString(),
|
||||
PROTO_TO_MESSAGE_TYPE[BluetoothDeviceConnectionResponse],
|
||||
)
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
await disconnect_task
|
||||
|
||||
|
||||
@ -700,12 +698,7 @@ async def test_bluetooth_pair(
|
||||
pair_task = asyncio.create_task(client.bluetooth_device_pair(1234))
|
||||
await asyncio.sleep(0)
|
||||
response: message.Message = BluetoothDevicePairingResponse(address=1234)
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(
|
||||
response.SerializeToString(),
|
||||
PROTO_TO_MESSAGE_TYPE[BluetoothDevicePairingResponse],
|
||||
)
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
await pair_task
|
||||
|
||||
|
||||
@ -720,12 +713,7 @@ async def test_bluetooth_unpair(
|
||||
unpair_task = asyncio.create_task(client.bluetooth_device_unpair(1234))
|
||||
await asyncio.sleep(0)
|
||||
response: message.Message = BluetoothDeviceUnpairingResponse(address=1234)
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(
|
||||
response.SerializeToString(),
|
||||
PROTO_TO_MESSAGE_TYPE[BluetoothDeviceUnpairingResponse],
|
||||
)
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
await unpair_task
|
||||
|
||||
|
||||
@ -740,10 +728,36 @@ async def test_bluetooth_clear_cache(
|
||||
clear_task = asyncio.create_task(client.bluetooth_device_clear_cache(1234))
|
||||
await asyncio.sleep(0)
|
||||
response: message.Message = BluetoothDeviceClearCacheResponse(address=1234)
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(
|
||||
response.SerializeToString(),
|
||||
PROTO_TO_MESSAGE_TYPE[BluetoothDeviceClearCacheResponse],
|
||||
)
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
await clear_task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_device_info(
|
||||
api_client: tuple[
|
||||
APIClient, APIConnection, asyncio.Transport, APIPlaintextFrameHelper
|
||||
],
|
||||
) -> None:
|
||||
"""Test fetching device info."""
|
||||
client, connection, transport, protocol = api_client
|
||||
assert client.log_name == "mydevice.local"
|
||||
device_info_task = asyncio.create_task(client.device_info())
|
||||
await asyncio.sleep(0)
|
||||
response: message.Message = DeviceInfoResponse(
|
||||
name="realname",
|
||||
friendly_name="My Device",
|
||||
has_deep_sleep=True,
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
device_info = await device_info_task
|
||||
assert device_info.name == "realname"
|
||||
assert device_info.friendly_name == "My Device"
|
||||
assert device_info.has_deep_sleep
|
||||
assert client.log_name == "realname @ 10.0.0.512"
|
||||
disconnect_task = asyncio.create_task(client.disconnect())
|
||||
await asyncio.sleep(0)
|
||||
response: message.Message = DisconnectResponse()
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
await disconnect_task
|
||||
with pytest.raises(APIConnectionError, match="CLOSED"):
|
||||
await client.device_info()
|
||||
|
@ -5,20 +5,11 @@ from ipaddress import ip_address
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from zeroconf import DNSCache
|
||||
from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf
|
||||
|
||||
import aioesphomeapi.host_resolver as hr
|
||||
from aioesphomeapi.core import APIConnectionError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_zeroconf():
|
||||
with patch("aioesphomeapi.host_resolver.AsyncZeroconf") as klass:
|
||||
async_zeroconf = klass.return_value
|
||||
async_zeroconf.async_close = AsyncMock()
|
||||
async_zeroconf.zeroconf.cache = DNSCache()
|
||||
yield async_zeroconf
|
||||
from aioesphomeapi.core import APIConnectionError, ResolveAPIError
|
||||
from aioesphomeapi.zeroconf import ZeroconfManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -52,7 +43,9 @@ async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos):
|
||||
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
|
||||
]
|
||||
info.async_request = AsyncMock(return_value=True)
|
||||
with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info):
|
||||
with patch(
|
||||
"aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info
|
||||
), patch("aioesphomeapi.zeroconf.AsyncZeroconf", return_value=async_zeroconf):
|
||||
ret = await hr._async_resolve_host_zeroconf("asdf", 6052)
|
||||
|
||||
info.async_request.assert_called_once()
|
||||
@ -61,10 +54,8 @@ async def test_resolve_host_zeroconf(async_zeroconf: AsyncZeroconf, addr_infos):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_host_passed_zeroconf_does_not_close(addr_infos):
|
||||
async_zeroconf = AsyncZeroconf(zc=MagicMock())
|
||||
async_zeroconf.async_close = AsyncMock()
|
||||
async_zeroconf.zeroconf.cache = DNSCache()
|
||||
async def test_resolve_host_passed_zeroconf(addr_infos, async_zeroconf):
|
||||
zeroconf_manager = ZeroconfManager()
|
||||
info = MagicMock(auto_spec=AsyncServiceInfo)
|
||||
info.ip_addresses_by_version.return_value = [
|
||||
ip_address(b"\n\x00\x00*"),
|
||||
@ -73,11 +64,10 @@ async def test_resolve_host_passed_zeroconf_does_not_close(addr_infos):
|
||||
info.async_request = AsyncMock(return_value=True)
|
||||
with patch("aioesphomeapi.host_resolver.AsyncServiceInfo", return_value=info):
|
||||
ret = await hr._async_resolve_host_zeroconf(
|
||||
"asdf", 6052, zeroconf_instance=async_zeroconf
|
||||
"asdf", 6052, zeroconf_manager=zeroconf_manager
|
||||
)
|
||||
|
||||
info.async_request.assert_called_once()
|
||||
async_zeroconf.async_close.assert_not_called()
|
||||
assert ret == addr_infos
|
||||
|
||||
|
||||
@ -131,7 +121,7 @@ async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos):
|
||||
resolve_zc.return_value = addr_infos
|
||||
ret = await hr.async_resolve_host("example.local", 6052)
|
||||
|
||||
resolve_zc.assert_called_once_with("example", 6052, zeroconf_instance=None)
|
||||
resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
|
||||
resolve_addr.assert_not_called()
|
||||
assert ret == addr_infos[0]
|
||||
|
||||
@ -144,7 +134,7 @@ async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
|
||||
resolve_addr.return_value = addr_infos
|
||||
ret = await hr.async_resolve_host("example.local", 6052)
|
||||
|
||||
resolve_zc.assert_called_once_with("example", 6052, zeroconf_instance=None)
|
||||
resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
|
||||
resolve_addr.assert_called_once_with("example.local", 6052)
|
||||
assert ret == addr_infos[0]
|
||||
|
||||
@ -189,3 +179,40 @@ async def test_resolve_host_with_address(resolve_addr, resolve_zc):
|
||||
proto=6,
|
||||
sockaddr=hr.IPv4Sockaddr(address="127.0.0.1", port=6052),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_host_zeroconf_service_info_oserror(
|
||||
async_zeroconf: AsyncZeroconf, addr_infos
|
||||
):
|
||||
info = MagicMock(auto_spec=AsyncServiceInfo)
|
||||
info.ip_addresses_by_version.return_value = [
|
||||
ip_address(b"\n\x00\x00*"),
|
||||
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
|
||||
]
|
||||
info.async_request = AsyncMock(return_value=True)
|
||||
with patch(
|
||||
"aioesphomeapi.host_resolver.AsyncServiceInfo.async_request",
|
||||
side_effect=OSError("out of buffers"),
|
||||
), patch(
|
||||
"aioesphomeapi.zeroconf.AsyncZeroconf", return_value=async_zeroconf
|
||||
), pytest.raises(
|
||||
ResolveAPIError, match="out of buffers"
|
||||
):
|
||||
await hr._async_resolve_host_zeroconf("asdf", 6052)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_host_create_zeroconf_oserror(
|
||||
async_zeroconf: AsyncZeroconf, addr_infos
|
||||
):
|
||||
info = MagicMock(auto_spec=AsyncServiceInfo)
|
||||
info.ip_addresses_by_version.return_value = [
|
||||
ip_address(b"\n\x00\x00*"),
|
||||
ip_address(b" \x01\r\xb8\x85\xa3\x00\x00\x00\x00\x8a.\x03ps4"),
|
||||
]
|
||||
info.async_request = AsyncMock(return_value=True)
|
||||
with patch(
|
||||
"aioesphomeapi.zeroconf.AsyncZeroconf", side_effect=OSError("out of buffers")
|
||||
), pytest.raises(ResolveAPIError, match="out of buffers"):
|
||||
await hr._async_resolve_host_zeroconf("asdf", 6052)
|
||||
|
@ -75,21 +75,11 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec
|
||||
|
||||
response: message.Message = SubscribeLogsResponse()
|
||||
response.message = b"Hello world"
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(
|
||||
response.SerializeToString(),
|
||||
PROTO_TO_MESSAGE_TYPE[SubscribeLogsResponse],
|
||||
)
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(response))
|
||||
assert len(messages) == 1
|
||||
assert messages[0].message == b"Hello world"
|
||||
stop_task = asyncio.create_task(stop())
|
||||
await asyncio.sleep(0)
|
||||
disconnect_response = DisconnectResponse()
|
||||
protocol.data_received(
|
||||
generate_plaintext_packet(
|
||||
disconnect_response.SerializeToString(),
|
||||
PROTO_TO_MESSAGE_TYPE[DisconnectResponse],
|
||||
)
|
||||
)
|
||||
protocol.data_received(generate_plaintext_packet(disconnect_response))
|
||||
await stop_task
|
||||
|
@ -41,14 +41,13 @@ async def test_reconnect_logic_name_from_host():
|
||||
async def on_connect() -> None:
|
||||
pass
|
||||
|
||||
rl = ReconnectLogic(
|
||||
ReconnectLogic(
|
||||
client=cli,
|
||||
on_disconnect=on_disconnect,
|
||||
on_connect=on_connect,
|
||||
zeroconf_instance=MagicMock(spec=AsyncZeroconf),
|
||||
)
|
||||
assert rl._log_name == "mydevice"
|
||||
assert cli._log_name == "mydevice"
|
||||
assert cli.log_name == "mydevice.local"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -66,15 +65,14 @@ async def test_reconnect_logic_name_from_host_and_set():
|
||||
async def on_connect() -> None:
|
||||
pass
|
||||
|
||||
rl = ReconnectLogic(
|
||||
ReconnectLogic(
|
||||
client=cli,
|
||||
on_disconnect=on_disconnect,
|
||||
on_connect=on_connect,
|
||||
zeroconf_instance=get_mock_zeroconf(),
|
||||
name="mydevice",
|
||||
)
|
||||
assert rl._log_name == "mydevice"
|
||||
assert cli._log_name == "mydevice"
|
||||
assert cli.log_name == "mydevice.local"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -92,14 +90,13 @@ async def test_reconnect_logic_name_from_address():
|
||||
async def on_connect() -> None:
|
||||
pass
|
||||
|
||||
rl = ReconnectLogic(
|
||||
ReconnectLogic(
|
||||
client=cli,
|
||||
on_disconnect=on_disconnect,
|
||||
on_connect=on_connect,
|
||||
zeroconf_instance=get_mock_zeroconf(),
|
||||
)
|
||||
assert rl._log_name == "1.2.3.4"
|
||||
assert cli._log_name == "1.2.3.4"
|
||||
assert cli.log_name == "1.2.3.4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -117,15 +114,14 @@ async def test_reconnect_logic_name_from_name():
|
||||
async def on_connect() -> None:
|
||||
pass
|
||||
|
||||
rl = ReconnectLogic(
|
||||
ReconnectLogic(
|
||||
client=cli,
|
||||
on_disconnect=on_disconnect,
|
||||
on_connect=on_connect,
|
||||
zeroconf_instance=get_mock_zeroconf(),
|
||||
name="mydevice",
|
||||
)
|
||||
assert rl._log_name == "mydevice @ 1.2.3.4"
|
||||
assert cli._log_name == "mydevice @ 1.2.3.4"
|
||||
assert cli.log_name == "mydevice @ 1.2.3.4"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -164,8 +160,7 @@ async def test_reconnect_logic_state():
|
||||
name="mydevice",
|
||||
on_connect_error=on_connect_fail,
|
||||
)
|
||||
assert rl._log_name == "mydevice @ 1.2.3.4"
|
||||
assert cli._log_name == "mydevice @ 1.2.3.4"
|
||||
assert cli.log_name == "mydevice @ 1.2.3.4"
|
||||
|
||||
with patch.object(cli, "start_connection", side_effect=APIConnectionError):
|
||||
await rl.start()
|
||||
@ -241,8 +236,7 @@ async def test_reconnect_retry():
|
||||
name="mydevice",
|
||||
on_connect_error=on_connect_fail,
|
||||
)
|
||||
assert rl._log_name == "mydevice @ 1.2.3.4"
|
||||
assert cli._log_name == "mydevice @ 1.2.3.4"
|
||||
assert cli.log_name == "mydevice @ 1.2.3.4"
|
||||
|
||||
with patch.object(cli, "start_connection", side_effect=APIConnectionError):
|
||||
await rl.start()
|
||||
@ -338,8 +332,7 @@ async def test_reconnect_zeroconf(
|
||||
name="mydevice",
|
||||
on_connect_error=AsyncMock(),
|
||||
)
|
||||
assert rl._log_name == "mydevice @ 1.2.3.4"
|
||||
assert cli._log_name == "mydevice @ 1.2.3.4"
|
||||
assert cli.log_name == "mydevice @ 1.2.3.4"
|
||||
|
||||
async def slow_connect_fail(*args, **kwargs):
|
||||
await asyncio.sleep(10)
|
||||
|
52
tests/test_zeroconf.py
Normal file
52
tests/test_zeroconf.py
Normal file
@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from zeroconf.asyncio import AsyncZeroconf
|
||||
|
||||
from aioesphomeapi.zeroconf import ZeroconfManager
|
||||
|
||||
from .common import get_mock_async_zeroconf
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_does_not_closed_passed_in_async_instance(async_zeroconf: AsyncZeroconf):
|
||||
"""Test that the passed in instance is not closed."""
|
||||
manager = ZeroconfManager()
|
||||
manager.set_instance(async_zeroconf)
|
||||
await manager.async_close()
|
||||
assert async_zeroconf.async_close.call_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_does_not_closed_passed_in_sync_instance(async_zeroconf: AsyncZeroconf):
|
||||
"""Test that the passed in instance is not closed."""
|
||||
manager = ZeroconfManager()
|
||||
manager.set_instance(async_zeroconf.zeroconf)
|
||||
await manager.async_close()
|
||||
assert async_zeroconf.async_close.call_count == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_closes_created_instance(async_zeroconf: AsyncZeroconf):
|
||||
"""Test that the created instance is closed."""
|
||||
with patch("aioesphomeapi.zeroconf.AsyncZeroconf", return_value=async_zeroconf):
|
||||
manager = ZeroconfManager()
|
||||
assert manager.get_async_zeroconf() is async_zeroconf
|
||||
await manager.async_close()
|
||||
assert async_zeroconf.async_close.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_runtime_error_multiple_instances(async_zeroconf: AsyncZeroconf):
|
||||
"""Test runtime error is raised on multiple instances."""
|
||||
manager = ZeroconfManager(async_zeroconf)
|
||||
new_instance = get_mock_async_zeroconf()
|
||||
with pytest.raises(RuntimeError):
|
||||
manager.set_instance(new_instance)
|
||||
manager.set_instance(async_zeroconf)
|
||||
manager.set_instance(async_zeroconf.zeroconf)
|
||||
manager.set_instance(async_zeroconf)
|
||||
await manager.async_close()
|
||||
assert async_zeroconf.async_close.call_count == 0
|
Loading…
Reference in New Issue
Block a user