Fix host resolution when local dns does not resolve mdns (#636)

This commit is contained in:
J. Nick Koston 2023-11-11 14:48:12 -06:00 committed by GitHub
parent c1a0500ecb
commit 634c739048
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 8 deletions

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import contextlib
import logging
import socket
from dataclasses import dataclass
from ipaddress import IPv4Address, IPv6Address
@ -12,8 +13,12 @@ from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf
from .core import APIConnectionError, ResolveAPIError
_LOGGER = logging.getLogger(__name__)
ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf, None]
SERVICE_TYPE = "_esphomelib._tcp.local."
@dataclass(frozen=True)
class Sockaddr:
@ -89,11 +94,11 @@ async def _async_resolve_host_zeroconf(
timeout: float = 3.0,
zeroconf_instance: ZeroconfInstanceType = None,
) -> list[AddrInfo]:
service_type = "_esphomelib._tcp.local."
service_name = f"{host}.{service_type}"
service_name = f"{host}.{SERVICE_TYPE}"
_LOGGER.debug("Resolving host %s via mDNS", service_name)
info = await _async_zeroconf_get_service_info(
zeroconf_instance, service_type, service_name, timeout
zeroconf_instance, SERVICE_TYPE, service_name, timeout
)
if info is None:
@ -197,8 +202,8 @@ async def async_resolve_host(
addrs: list[AddrInfo] = []
zc_error = None
if host.endswith(".local"):
name = host[: -len(".local")]
if "." not in host or host.endswith(".local"):
name = host.partition(".")[0]
try:
addrs.extend(
await _async_resolve_host_zeroconf(

View File

@ -7,6 +7,8 @@ import logging
import sys
from datetime import datetime
from zeroconf.asyncio import AsyncZeroconf
from .api_pb2 import SubscribeLogsResponse # type: ignore
from .client import APIClient
from .log_runner import async_run
@ -27,11 +29,14 @@ async def main(argv: list[str]) -> None:
datefmt="%Y-%m-%d %H:%M:%S",
)
aiozc = AsyncZeroconf()
cli = APIClient(
args.address,
args.port,
args.password or "",
noise_psk=args.noise_psk,
zeroconf_instance=aiozc.zeroconf,
keepalive=10,
)
@ -41,11 +46,12 @@ async def main(argv: list[str]) -> None:
text = message.decode("utf8", "backslashreplace")
print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}")
stop = await async_run(cli, on_log)
stop = await async_run(cli, on_log, aio_zeroconf_instance=aiozc)
try:
while True:
await asyncio.sleep(60)
finally:
await aiozc.async_close()
await stop()

View File

@ -34,12 +34,15 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec
class PatchableAPIClient(APIClient):
pass
async_zeroconf = get_mock_async_zeroconf()
cli = PatchableAPIClient(
address=Estr("1.2.3.4"),
port=6052,
password=None,
noise_psk=None,
expected_name=Estr("fake"),
zeroconf_instance=async_zeroconf.zeroconf,
)
messages = []
@ -60,8 +63,6 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec
await original_subscribe_logs(*args, **kwargs)
subscribed.set()
async_zeroconf = get_mock_async_zeroconf()
with patch.object(event_loop, "sock_connect"), patch.object(
loop, "create_connection", side_effect=_create_mock_transport_protocol
), patch.object(cli, "subscribe_logs", _wait_subscribe_cli):