Fix host resolution when local dns does not resolve mdns (#636)

This commit is contained in:
J. Nick Koston 2023-11-11 14:48:12 -06:00 committed by GitHub
parent c1a0500ecb
commit 634c739048
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 8 deletions

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
import logging
import socket import socket
from dataclasses import dataclass from dataclasses import dataclass
from ipaddress import IPv4Address, IPv6Address from ipaddress import IPv4Address, IPv6Address
@ -12,8 +13,12 @@ from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf
from .core import APIConnectionError, ResolveAPIError from .core import APIConnectionError, ResolveAPIError
_LOGGER = logging.getLogger(__name__)
ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf, None] ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf, None]
SERVICE_TYPE = "_esphomelib._tcp.local."
@dataclass(frozen=True) @dataclass(frozen=True)
class Sockaddr: class Sockaddr:
@ -89,11 +94,11 @@ async def _async_resolve_host_zeroconf(
timeout: float = 3.0, timeout: float = 3.0,
zeroconf_instance: ZeroconfInstanceType = None, zeroconf_instance: ZeroconfInstanceType = None,
) -> list[AddrInfo]: ) -> list[AddrInfo]:
service_type = "_esphomelib._tcp.local." service_name = f"{host}.{SERVICE_TYPE}"
service_name = f"{host}.{service_type}"
_LOGGER.debug("Resolving host %s via mDNS", service_name)
info = await _async_zeroconf_get_service_info( info = await _async_zeroconf_get_service_info(
zeroconf_instance, service_type, service_name, timeout zeroconf_instance, SERVICE_TYPE, service_name, timeout
) )
if info is None: if info is None:
@ -197,8 +202,8 @@ async def async_resolve_host(
addrs: list[AddrInfo] = [] addrs: list[AddrInfo] = []
zc_error = None zc_error = None
if host.endswith(".local"): if "." not in host or host.endswith(".local"):
name = host[: -len(".local")] name = host.partition(".")[0]
try: try:
addrs.extend( addrs.extend(
await _async_resolve_host_zeroconf( await _async_resolve_host_zeroconf(

View File

@ -7,6 +7,8 @@ import logging
import sys import sys
from datetime import datetime from datetime import datetime
from zeroconf.asyncio import AsyncZeroconf
from .api_pb2 import SubscribeLogsResponse # type: ignore from .api_pb2 import SubscribeLogsResponse # type: ignore
from .client import APIClient from .client import APIClient
from .log_runner import async_run from .log_runner import async_run
@ -27,11 +29,14 @@ async def main(argv: list[str]) -> None:
datefmt="%Y-%m-%d %H:%M:%S", datefmt="%Y-%m-%d %H:%M:%S",
) )
aiozc = AsyncZeroconf()
cli = APIClient( cli = APIClient(
args.address, args.address,
args.port, args.port,
args.password or "", args.password or "",
noise_psk=args.noise_psk, noise_psk=args.noise_psk,
zeroconf_instance=aiozc.zeroconf,
keepalive=10, keepalive=10,
) )
@ -41,11 +46,12 @@ async def main(argv: list[str]) -> None:
text = message.decode("utf8", "backslashreplace") text = message.decode("utf8", "backslashreplace")
print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}") print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}")
stop = await async_run(cli, on_log) stop = await async_run(cli, on_log, aio_zeroconf_instance=aiozc)
try: try:
while True: while True:
await asyncio.sleep(60) await asyncio.sleep(60)
finally: finally:
await aiozc.async_close()
await stop() await stop()

View File

@ -34,12 +34,15 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec
class PatchableAPIClient(APIClient): class PatchableAPIClient(APIClient):
pass pass
async_zeroconf = get_mock_async_zeroconf()
cli = PatchableAPIClient( cli = PatchableAPIClient(
address=Estr("1.2.3.4"), address=Estr("1.2.3.4"),
port=6052, port=6052,
password=None, password=None,
noise_psk=None, noise_psk=None,
expected_name=Estr("fake"), expected_name=Estr("fake"),
zeroconf_instance=async_zeroconf.zeroconf,
) )
messages = [] messages = []
@ -60,8 +63,6 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec
await original_subscribe_logs(*args, **kwargs) await original_subscribe_logs(*args, **kwargs)
subscribed.set() subscribed.set()
async_zeroconf = get_mock_async_zeroconf()
with patch.object(event_loop, "sock_connect"), patch.object( with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol loop, "create_connection", side_effect=_create_mock_transport_protocol
), patch.object(cli, "subscribe_logs", _wait_subscribe_cli): ), patch.object(cli, "subscribe_logs", _wait_subscribe_cli):