mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-13 10:43:59 +01:00
Fix host resolution when local dns does not resolve mdns (#636)
This commit is contained in:
parent
c1a0500ecb
commit
634c739048
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
from ipaddress import IPv4Address, IPv6Address
|
||||
@ -12,8 +13,12 @@ from zeroconf.asyncio import AsyncServiceInfo, AsyncZeroconf
|
||||
|
||||
from .core import APIConnectionError, ResolveAPIError
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
ZeroconfInstanceType = Union[Zeroconf, AsyncZeroconf, None]
|
||||
|
||||
SERVICE_TYPE = "_esphomelib._tcp.local."
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Sockaddr:
|
||||
@ -89,11 +94,11 @@ async def _async_resolve_host_zeroconf(
|
||||
timeout: float = 3.0,
|
||||
zeroconf_instance: ZeroconfInstanceType = None,
|
||||
) -> list[AddrInfo]:
|
||||
service_type = "_esphomelib._tcp.local."
|
||||
service_name = f"{host}.{service_type}"
|
||||
service_name = f"{host}.{SERVICE_TYPE}"
|
||||
|
||||
_LOGGER.debug("Resolving host %s via mDNS", service_name)
|
||||
info = await _async_zeroconf_get_service_info(
|
||||
zeroconf_instance, service_type, service_name, timeout
|
||||
zeroconf_instance, SERVICE_TYPE, service_name, timeout
|
||||
)
|
||||
|
||||
if info is None:
|
||||
@ -197,8 +202,8 @@ async def async_resolve_host(
|
||||
addrs: list[AddrInfo] = []
|
||||
|
||||
zc_error = None
|
||||
if host.endswith(".local"):
|
||||
name = host[: -len(".local")]
|
||||
if "." not in host or host.endswith(".local"):
|
||||
name = host.partition(".")[0]
|
||||
try:
|
||||
addrs.extend(
|
||||
await _async_resolve_host_zeroconf(
|
||||
|
@ -7,6 +7,8 @@ import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
from zeroconf.asyncio import AsyncZeroconf
|
||||
|
||||
from .api_pb2 import SubscribeLogsResponse # type: ignore
|
||||
from .client import APIClient
|
||||
from .log_runner import async_run
|
||||
@ -27,11 +29,14 @@ async def main(argv: list[str]) -> None:
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
aiozc = AsyncZeroconf()
|
||||
|
||||
cli = APIClient(
|
||||
args.address,
|
||||
args.port,
|
||||
args.password or "",
|
||||
noise_psk=args.noise_psk,
|
||||
zeroconf_instance=aiozc.zeroconf,
|
||||
keepalive=10,
|
||||
)
|
||||
|
||||
@ -41,11 +46,12 @@ async def main(argv: list[str]) -> None:
|
||||
text = message.decode("utf8", "backslashreplace")
|
||||
print(f"[{time_.hour:02}:{time_.minute:02}:{time_.second:02}]{text}")
|
||||
|
||||
stop = await async_run(cli, on_log)
|
||||
stop = await async_run(cli, on_log, aio_zeroconf_instance=aiozc)
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(60)
|
||||
finally:
|
||||
await aiozc.async_close()
|
||||
await stop()
|
||||
|
||||
|
||||
|
@ -34,12 +34,15 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec
|
||||
class PatchableAPIClient(APIClient):
|
||||
pass
|
||||
|
||||
async_zeroconf = get_mock_async_zeroconf()
|
||||
|
||||
cli = PatchableAPIClient(
|
||||
address=Estr("1.2.3.4"),
|
||||
port=6052,
|
||||
password=None,
|
||||
noise_psk=None,
|
||||
expected_name=Estr("fake"),
|
||||
zeroconf_instance=async_zeroconf.zeroconf,
|
||||
)
|
||||
messages = []
|
||||
|
||||
@ -60,8 +63,6 @@ async def test_log_runner(event_loop: asyncio.AbstractEventLoop, conn: APIConnec
|
||||
await original_subscribe_logs(*args, **kwargs)
|
||||
subscribed.set()
|
||||
|
||||
async_zeroconf = get_mock_async_zeroconf()
|
||||
|
||||
with patch.object(event_loop, "sock_connect"), patch.object(
|
||||
loop, "create_connection", side_effect=_create_mock_transport_protocol
|
||||
), patch.object(cli, "subscribe_logs", _wait_subscribe_cli):
|
||||
|
Loading…
Reference in New Issue
Block a user