From fce59819f5627dcf32267b3da20ead54cc8acd69 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 6 Nov 2023 16:07:59 -0600 Subject: [PATCH] Refactor dashboard zeroconf support (#5681) --- esphome/dashboard/dashboard.py | 109 ++++++++++++++++------ esphome/zeroconf.py | 162 ++++++++------------------------- 2 files changed, 123 insertions(+), 148 deletions(-) diff --git a/esphome/dashboard/dashboard.py b/esphome/dashboard/dashboard.py index ce8976cb0f..f6eb079430 100644 --- a/esphome/dashboard/dashboard.py +++ b/esphome/dashboard/dashboard.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import binascii import codecs @@ -15,7 +17,6 @@ import shutil import subprocess import threading from pathlib import Path -from typing import Optional import tornado import tornado.concurrent @@ -42,7 +43,13 @@ from esphome.storage_json import ( trash_storage_path, ) from esphome.util import get_serial_ports, shlex_quote -from esphome.zeroconf import DashboardImportDiscovery, DashboardStatus, EsphomeZeroconf +from esphome.zeroconf import ( + ESPHOME_SERVICE_TYPE, + DashboardBrowser, + DashboardImportDiscovery, + DashboardStatus, + EsphomeZeroconf, +) from .util import friendly_name_slugify, password_hash @@ -517,6 +524,8 @@ class ImportRequestHandler(BaseHandler): network, encryption, ) + # Make sure the device gets marked online right away + PING_REQUEST.set() except FileExistsError: self.set_status(500) self.write("File already exists") @@ -542,13 +551,11 @@ class DownloadListRequestHandler(BaseHandler): self.send_error(404) return - from esphome.components.esp32 import ( - get_download_types as esp32_types, - VARIANTS as ESP32_VARIANTS, - ) + from esphome.components.esp32 import VARIANTS as ESP32_VARIANTS + from esphome.components.esp32 import get_download_types as esp32_types from esphome.components.esp8266 import get_download_types as esp8266_types - from esphome.components.rp2040 import get_download_types as rp2040_types from esphome.components.libretiny import get_download_types as libretiny_types + from esphome.components.rp2040 import get_download_types as rp2040_types downloads = [] platform = storage_json.target_platform.lower() @@ -661,12 +668,21 @@ class DashboardEntry: self._storage = None self._loaded_storage = False + def __repr__(self): + return ( + f"DashboardEntry({self.path} " + f"address={self.address} " + f"web_port={self.web_port} " + f"name={self.name} " + f"no_mdns={self.no_mdns})" + ) + @property def filename(self): return os.path.basename(self.path) @property - def storage(self) -> Optional[StorageJSON]: + def storage(self) -> StorageJSON | None: if not self._loaded_storage: self._storage = StorageJSON.load(ext_storage_path(self.filename)) self._loaded_storage = True @@ -831,10 +847,10 @@ class PrometheusServiceDiscoveryHandler(BaseHandler): class BoardsRequestHandler(BaseHandler): @authenticated def get(self, platform: str): + from esphome.components.bk72xx.boards import BOARDS as BK72XX_BOARDS from esphome.components.esp32.boards import BOARDS as ESP32_BOARDS from esphome.components.esp8266.boards import BOARDS as ESP8266_BOARDS from esphome.components.rp2040.boards import BOARDS as RP2040_BOARDS - from esphome.components.bk72xx.boards import BOARDS as BK72XX_BOARDS from esphome.components.rtl87xx.boards import BOARDS as RTL87XX_BOARDS platform_to_boards = { @@ -865,35 +881,76 @@ class BoardsRequestHandler(BaseHandler): class MDNSStatusThread(threading.Thread): + def __init__(self): + """Initialize the MDNSStatusThread.""" + super().__init__() + # This is the current mdns state for each host (True, False, None) + self.host_mdns_state: dict[str, bool | None] = {} + # This is the hostnames to filenames mapping + self.host_name_to_filename: dict[str, str] = {} + # This is a set of host names to track (i.e no_mdns = false) + self.host_name_with_mdns_enabled: set[set] = set() + self._refresh_hosts() + + def _refresh_hosts(self): + """Refresh the hosts to track.""" + entries = _list_dashboard_entries() + host_name_with_mdns_enabled = self.host_name_with_mdns_enabled + host_mdns_state = self.host_mdns_state + host_name_to_filename = self.host_name_to_filename + + for entry in entries: + name = entry.name + # If no_mdns is set, remove it from the set + if entry.no_mdns: + host_name_with_mdns_enabled.discard(name) + continue + + # We are tracking this host + host_name_with_mdns_enabled.add(name) + filename = entry.filename + + # If we just adopted/imported this host, we likely + # already have a state for it, so we should make sure + # to set it so the dashboard shows it as online + if name in host_mdns_state: + PING_RESULT[filename] = host_mdns_state[name] + + # Make sure the mapping is up to date + # so when we get an mdns update we can map it back + # to the filename + host_name_to_filename[name] = filename + def run(self): global IMPORT_RESULT zc = EsphomeZeroconf() + host_mdns_state = self.host_mdns_state + host_name_to_filename = self.host_name_to_filename + host_name_with_mdns_enabled = self.host_name_with_mdns_enabled - def on_update(dat): - for key, b in dat.items(): - PING_RESULT[key] = b + def on_update(dat: dict[str, bool | None]) -> None: + """Update the global PING_RESULT dict.""" + for name, result in dat.items(): + host_mdns_state[name] = result + if name in host_name_with_mdns_enabled: + filename = host_name_to_filename[name] + PING_RESULT[filename] = result - stat = DashboardStatus(zc, on_update) - imports = DashboardImportDiscovery(zc) + self._refresh_hosts() + stat = DashboardStatus(on_update) + imports = DashboardImportDiscovery() + browser = DashboardBrowser( + zc, ESPHOME_SERVICE_TYPE, [stat.browser_callback, imports.browser_callback] + ) - stat.start() while not STOP_EVENT.is_set(): - entries = _list_dashboard_entries() - hosts = {} - for entry in entries: - if entry.no_mdns is not True: - hosts[entry.filename] = f"{entry.name}.local." - - stat.request_query(hosts) + self._refresh_hosts() IMPORT_RESULT = imports.import_state - PING_REQUEST.wait() PING_REQUEST.clear() - stat.stop() - stat.join() - imports.cancel() + browser.cancel() zc.close() diff --git a/esphome/zeroconf.py b/esphome/zeroconf.py index 14dd740a96..d20111ce20 100644 --- a/esphome/zeroconf.py +++ b/esphome/zeroconf.py @@ -1,130 +1,49 @@ +from __future__ import annotations + import logging -import socket -import threading -import time from dataclasses import dataclass -from typing import Optional +from typing import Callable from zeroconf import ( - DNSAddress, - DNSOutgoing, - DNSQuestion, - RecordUpdate, - RecordUpdateListener, + IPVersion, ServiceBrowser, + ServiceInfo, ServiceStateChange, Zeroconf, - current_time_millis, ) from esphome.storage_json import StorageJSON, ext_storage_path -_CLASS_IN = 1 -_FLAGS_QR_QUERY = 0x0000 # query -_TYPE_A = 1 _LOGGER = logging.getLogger(__name__) -class HostResolver(RecordUpdateListener): +class HostResolver(ServiceInfo): """Resolve a host name to an IP address.""" - def __init__(self, name: str): - self.name = name - self.address: Optional[bytes] = None - - def async_update_records( - self, zc: Zeroconf, now: float, records: list[RecordUpdate] - ) -> None: - """Update multiple records in one shot. - - This will run in zeroconf's event loop thread so it - must be thread-safe. - """ - for record_update in records: - record, _ = record_update - if record is None: - continue - if record.type == _TYPE_A: - assert isinstance(record, DNSAddress) - if record.name == self.name: - self.address = record.address - - def request(self, zc: Zeroconf, timeout: float) -> bool: - now = time.time() - delay = 0.2 - next_ = now + delay - last = now + timeout - - try: - zc.add_listener(self, None) - while self.address is None: - if last <= now: - # Timeout - return False - if next_ <= now: - out = DNSOutgoing(_FLAGS_QR_QUERY) - out.add_question(DNSQuestion(self.name, _TYPE_A, _CLASS_IN)) - zc.send(out) - next_ = now + delay - delay *= 2 - - time.sleep(min(next_, last) - now) - now = time.time() - finally: - zc.remove_listener(self) - - return True + @property + def _is_complete(self) -> bool: + """The ServiceInfo has all expected properties.""" + return bool(self._ipv4_addresses) -class DashboardStatus(threading.Thread): - PING_AFTER = 15 * 1000 # Send new mDNS request after 15 seconds - OFFLINE_AFTER = PING_AFTER * 2 # Offline if no mDNS response after 30 seconds - - def __init__(self, zc: Zeroconf, on_update) -> None: - threading.Thread.__init__(self) - self.zc = zc - self.query_hosts: set[str] = set() - self.key_to_host: dict[str, str] = {} - self.stop_event = threading.Event() - self.query_event = threading.Event() +class DashboardStatus: + def __init__(self, on_update: Callable[[dict[str, bool | None], []]]) -> None: + """Initialize the dashboard status.""" self.on_update = on_update - def request_query(self, hosts: dict[str, str]) -> None: - self.query_hosts = set(hosts.values()) - self.key_to_host = hosts - self.query_event.set() - - def stop(self) -> None: - self.stop_event.set() - self.query_event.set() - - def host_status(self, key: str) -> bool: - entries = self.zc.cache.entries_with_name(key) - if not entries: - return False - now = current_time_millis() - - return any( - (entry.created + DashboardStatus.OFFLINE_AFTER) >= now for entry in entries - ) - - def run(self) -> None: - while not self.stop_event.is_set(): - self.on_update( - {key: self.host_status(host) for key, host in self.key_to_host.items()} - ) - now = current_time_millis() - for host in self.query_hosts: - entries = self.zc.cache.entries_with_name(host) - if not entries or all( - (entry.created + DashboardStatus.PING_AFTER) <= now - for entry in entries - ): - out = DNSOutgoing(_FLAGS_QR_QUERY) - out.add_question(DNSQuestion(host, _TYPE_A, _CLASS_IN)) - self.zc.send(out) - self.query_event.wait() - self.query_event.clear() + def browser_callback( + self, + zeroconf: Zeroconf, + service_type: str, + name: str, + state_change: ServiceStateChange, + ) -> None: + """Handle a service update.""" + short_name = name.partition(".")[0] + if state_change == ServiceStateChange.Removed: + self.on_update({short_name: False}) + elif state_change in (ServiceStateChange.Updated, ServiceStateChange.Added): + self.on_update({short_name: True}) ESPHOME_SERVICE_TYPE = "_esphomelib._tcp.local." @@ -138,7 +57,7 @@ TXT_RECORD_VERSION = b"version" @dataclass class DiscoveredImport: - friendly_name: Optional[str] + friendly_name: str | None device_name: str package_import_url: str project_name: str @@ -146,15 +65,15 @@ class DiscoveredImport: network: str +class DashboardBrowser(ServiceBrowser): + """A class to browse for ESPHome nodes.""" + + class DashboardImportDiscovery: - def __init__(self, zc: Zeroconf) -> None: - self.zc = zc - self.service_browser = ServiceBrowser( - self.zc, ESPHOME_SERVICE_TYPE, [self._on_update] - ) + def __init__(self) -> None: self.import_state: dict[str, DiscoveredImport] = {} - def _on_update( + def browser_callback( self, zeroconf: Zeroconf, service_type: str, @@ -167,8 +86,6 @@ class DashboardImportDiscovery: name, state_change, ) - if service_type != ESPHOME_SERVICE_TYPE: - return if state_change == ServiceStateChange.Removed: self.import_state.pop(name, None) return @@ -212,9 +129,6 @@ class DashboardImportDiscovery: network=network, ) - def cancel(self) -> None: - self.service_browser.cancel() - def update_device_mdns(self, node_name: str, version: str): storage_path = ext_storage_path(node_name + ".yaml") storage_json = StorageJSON.load(storage_path) @@ -234,7 +148,11 @@ class DashboardImportDiscovery: class EsphomeZeroconf(Zeroconf): def resolve_host(self, host: str, timeout=3.0): - info = HostResolver(host) - if info.request(self, timeout): - return socket.inet_ntoa(info.address) + """Resolve a host name to an IP address.""" + name = host.partition(".")[0] + info = HostResolver(f"{name}.{ESPHOME_SERVICE_TYPE}", ESPHOME_SERVICE_TYPE) + if (info.load_from_cache(self) or info.request(self, timeout * 1000)) and ( + addresses := info.ip_addresses_by_version(IPVersion.V4Only) + ): + return str(addresses[0]) return None