esphome/esphome/zeroconf.py
J. Nick Koston 214b419db2
dashboard: Use mdns cache when available if device connection is OTA (#5724)
* Use mdns or freshen cache when device connection is OTA

Since we already have a service browser running, we likely
already know the IP of the device we want to connect to, so
we can replace OTA with the address to avoid the esphome
app having to look it up again.

* isort

* Fix zeroconf name resolution refactoring error

HostResolver should get the type as the first arg instead
of the name

* no i/o

* tornado support native coros

* lint

* use new tornado start methods

* use new tornado start methods

* use new tornado start methods

* break

* lint

* lint

* typing, missing awaits

* io in executor

* missed one

* fix: missing if

* stale comment

* rename run_command to build_device_command since it does not actually run anything
2023-11-14 21:21:44 -05:00

198 lines
6.6 KiB
Python

from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass
from typing import Callable
from zeroconf import IPVersion, ServiceInfo, ServiceStateChange, Zeroconf
from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf
from esphome.storage_json import StorageJSON, ext_storage_path
_LOGGER = logging.getLogger(__name__)
# Tasks spawned by the import-discovery browser callback are stored here while
# they run; each task discards itself from the set through its done callback.
_BACKGROUND_TASKS: set[asyncio.Task] = set()
class HostResolver(ServiceInfo):
    """ServiceInfo variant used to resolve a host name to an IP address.

    Completeness is redefined to mean "at least one IPv4 address is known".
    """

    @property
    def _is_complete(self) -> bool:
        """Report whether any IPv4 address has been recorded for the host."""
        ipv4_addresses = self._ipv4_addresses
        return bool(ipv4_addresses)
class DashboardStatus:
def __init__(self, on_update: Callable[[dict[str, bool | None], []]]) -> None:
"""Initialize the dashboard status."""
self.on_update = on_update
def browser_callback(
self,
zeroconf: Zeroconf,
service_type: str,
name: str,
state_change: ServiceStateChange,
) -> None:
"""Handle a service update."""
short_name = name.partition(".")[0]
if state_change == ServiceStateChange.Removed:
self.on_update({short_name: False})
elif state_change in (ServiceStateChange.Updated, ServiceStateChange.Added):
self.on_update({short_name: True})
# mDNS service type used for every ESPHome service lookup in this module.
ESPHOME_SERVICE_TYPE = "_esphomelib._tcp.local."
# Keys looked up in a resolved service's TXT properties (keys are bytes).
# The first three must all be present for a device to count as importable.
TXT_RECORD_PACKAGE_IMPORT_URL = b"package_import_url"
TXT_RECORD_PROJECT_NAME = b"project_name"
TXT_RECORD_PROJECT_VERSION = b"project_version"
# Optional keys: network falls back to "wifi" when absent, friendly_name to None.
TXT_RECORD_NETWORK = b"network"
TXT_RECORD_FRIENDLY_NAME = b"friendly_name"
# Version advertised by the device; used to refresh the stored esphome_version.
TXT_RECORD_VERSION = b"version"
@dataclass
class DiscoveredImport:
    """Values decoded from the TXT records of an importable device."""

    # Decoded "friendly_name" TXT value, or None when the record is absent.
    friendly_name: str | None
    # Service name with the ESPHome service-type suffix stripped.
    device_name: str
    # Decoded "package_import_url" TXT value.
    package_import_url: str
    # Decoded "project_name" TXT value.
    project_name: str
    # Decoded "project_version" TXT value.
    project_version: str
    # Decoded "network" TXT value.
    network: str
class DashboardBrowser(AsyncServiceBrowser):
    """Asynchronous service browser for discovering ESPHome nodes."""
class DashboardImportDiscovery:
    """Track devices that advertise dashboard-import details over mDNS.

    ``import_state`` maps the full mDNS service name to the
    ``DiscoveredImport`` built from that service's TXT records.
    """

    def __init__(self) -> None:
        """Start with no discovered imports."""
        self.import_state: dict[str, DiscoveredImport] = {}

    def browser_callback(
        self,
        zeroconf: Zeroconf,
        service_type: str,
        name: str,
        state_change: ServiceStateChange,
    ) -> None:
        """Handle a service state change reported by the browser.

        Removed services are dropped from ``import_state``. Added services,
        and updates for already-tracked services, are resolved from the
        zeroconf cache when possible and otherwise in a background task.
        """
        _LOGGER.debug(
            "service_update: type=%s name=%s state_change=%s",
            service_type,
            name,
            state_change,
        )
        if state_change == ServiceStateChange.Removed:
            self.import_state.pop(name, None)
            return

        if state_change == ServiceStateChange.Updated and name not in self.import_state:
            # Ignore updates for devices that are not in the import state
            return

        info = AsyncServiceInfo(
            service_type,
            name,
        )
        if info.load_from_cache(zeroconf):
            self._process_service_info(name, info)
            return
        # Cache miss: resolve in a task so this callback does not wait, and
        # keep the task in the module-level set until it completes.
        task = asyncio.create_task(
            self._async_process_service_info(zeroconf, info, service_type, name)
        )
        _BACKGROUND_TASKS.add(task)
        task.add_done_callback(_BACKGROUND_TASKS.discard)

    async def _async_process_service_info(
        self, zeroconf: Zeroconf, info: AsyncServiceInfo, service_type: str, name: str
    ) -> None:
        """Request the service info and process it if the request succeeds."""
        if await info.async_request(zeroconf):
            self._process_service_info(name, info)

    def _process_service_info(self, name: str, info: ServiceInfo) -> None:
        """Record an importable device, or refresh a known device's version.

        When any required import TXT record is missing (or has no value), the
        device is not treated as importable; if it advertises a version, that
        version is passed to ``update_device_mdns`` instead.
        """
        _LOGGER.debug("-> resolved info: %s", info)
        if info is None:
            return

        # Strip the service type and the dot that precedes it.
        node_name = name[: -len(ESPHOME_SERVICE_TYPE) - 1]
        properties = info.properties
        required_keys = (
            TXT_RECORD_PACKAGE_IMPORT_URL,
            TXT_RECORD_PROJECT_NAME,
            TXT_RECORD_PROJECT_VERSION,
        )
        # zeroconf reports a TXT key that carries no value as None; such a
        # record is treated as missing so the .decode() calls below cannot
        # fail on it.
        if any(properties.get(key) is None for key in required_keys):
            # Not a dashboard import device
            version = properties.get(TXT_RECORD_VERSION)
            if version is not None:
                self.update_device_mdns(node_name, version.decode())
            return

        network = properties.get(TXT_RECORD_NETWORK)
        friendly_name = properties.get(TXT_RECORD_FRIENDLY_NAME)
        self.import_state[name] = DiscoveredImport(
            friendly_name=None if friendly_name is None else friendly_name.decode(),
            device_name=node_name,
            package_import_url=properties[TXT_RECORD_PACKAGE_IMPORT_URL].decode(),
            project_name=properties[TXT_RECORD_PROJECT_NAME].decode(),
            project_version=properties[TXT_RECORD_PROJECT_VERSION].decode(),
            # An absent or value-less network record falls back to "wifi".
            network="wifi" if network is None else network.decode(),
        )

    def update_device_mdns(self, node_name: str, version: str) -> None:
        """Store the mDNS-advertised version in the device's storage file.

        Nothing is written when no storage file can be loaded for
        ``node_name`` or when the stored version already matches.
        """
        storage_path = ext_storage_path(node_name + ".yaml")
        storage_json = StorageJSON.load(storage_path)

        if storage_json is not None:
            storage_version = storage_json.esphome_version
            if version != storage_version:
                storage_json.esphome_version = version
                storage_json.save(storage_path)
                _LOGGER.info(
                    "Updated %s with mdns version %s (was %s)",
                    node_name,
                    version,
                    storage_version,
                )
def _make_host_resolver(host: str) -> HostResolver:
    """Build a HostResolver for the ESPHome service named after ``host``.

    Only the portion of ``host`` before the first ``.`` is used as the
    service instance name.
    """
    short_name, _, _ = host.partition(".")
    return HostResolver(ESPHOME_SERVICE_TYPE, f"{short_name}.{ESPHOME_SERVICE_TYPE}")
class EsphomeZeroconf(Zeroconf):
    """Zeroconf instance with a helper for resolving ESPHome host names."""

    def resolve_host(self, host: str, timeout: float = 3.0) -> str | None:
        """Return the first IPv4 address known for ``host``, or None.

        The cache is consulted first. A network request is made only on a
        cache miss and when ``timeout`` is truthy; the request receives
        ``timeout * 1000`` (presumably seconds converted to milliseconds).
        """
        resolver = _make_host_resolver(host)
        resolved = resolver.load_from_cache(self)
        if not resolved and timeout:
            resolved = resolver.request(self, timeout * 1000)
        if not resolved:
            return None
        addresses = resolver.ip_addresses_by_version(IPVersion.V4Only)
        if not addresses:
            return None
        return str(addresses[0])
class AsyncEsphomeZeroconf(AsyncZeroconf):
    """AsyncZeroconf instance with a helper for resolving ESPHome host names."""

    async def async_resolve_host(self, host: str, timeout: float = 3.0) -> str | None:
        """Return the first IPv4 address known for ``host``, or None.

        The cache is consulted first. A network request is awaited only on a
        cache miss and when ``timeout`` is truthy; the request receives
        ``timeout * 1000`` (presumably seconds converted to milliseconds).
        """
        resolver = _make_host_resolver(host)
        resolved = resolver.load_from_cache(self.zeroconf)
        if not resolved and timeout:
            resolved = await resolver.async_request(self.zeroconf, timeout * 1000)
        if not resolved:
            return None
        addresses = resolver.ip_addresses_by_version(IPVersion.V4Only)
        if not addresses:
            return None
        return str(addresses[0])