mirror of
https://github.com/esphome/esphome.git
synced 2024-11-22 11:47:30 +01:00
Refactor dashboard zeroconf support (#5681)
This commit is contained in:
parent
b978985aa1
commit
fce59819f5
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import binascii
|
import binascii
|
||||||
import codecs
|
import codecs
|
||||||
@ -15,7 +17,6 @@ import shutil
|
|||||||
import subprocess
|
import subprocess
|
||||||
import threading
|
import threading
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import tornado
|
import tornado
|
||||||
import tornado.concurrent
|
import tornado.concurrent
|
||||||
@ -42,7 +43,13 @@ from esphome.storage_json import (
|
|||||||
trash_storage_path,
|
trash_storage_path,
|
||||||
)
|
)
|
||||||
from esphome.util import get_serial_ports, shlex_quote
|
from esphome.util import get_serial_ports, shlex_quote
|
||||||
from esphome.zeroconf import DashboardImportDiscovery, DashboardStatus, EsphomeZeroconf
|
from esphome.zeroconf import (
|
||||||
|
ESPHOME_SERVICE_TYPE,
|
||||||
|
DashboardBrowser,
|
||||||
|
DashboardImportDiscovery,
|
||||||
|
DashboardStatus,
|
||||||
|
EsphomeZeroconf,
|
||||||
|
)
|
||||||
|
|
||||||
from .util import friendly_name_slugify, password_hash
|
from .util import friendly_name_slugify, password_hash
|
||||||
|
|
||||||
@ -517,6 +524,8 @@ class ImportRequestHandler(BaseHandler):
|
|||||||
network,
|
network,
|
||||||
encryption,
|
encryption,
|
||||||
)
|
)
|
||||||
|
# Make sure the device gets marked online right away
|
||||||
|
PING_REQUEST.set()
|
||||||
except FileExistsError:
|
except FileExistsError:
|
||||||
self.set_status(500)
|
self.set_status(500)
|
||||||
self.write("File already exists")
|
self.write("File already exists")
|
||||||
@ -542,13 +551,11 @@ class DownloadListRequestHandler(BaseHandler):
|
|||||||
self.send_error(404)
|
self.send_error(404)
|
||||||
return
|
return
|
||||||
|
|
||||||
from esphome.components.esp32 import (
|
from esphome.components.esp32 import VARIANTS as ESP32_VARIANTS
|
||||||
get_download_types as esp32_types,
|
from esphome.components.esp32 import get_download_types as esp32_types
|
||||||
VARIANTS as ESP32_VARIANTS,
|
|
||||||
)
|
|
||||||
from esphome.components.esp8266 import get_download_types as esp8266_types
|
from esphome.components.esp8266 import get_download_types as esp8266_types
|
||||||
from esphome.components.rp2040 import get_download_types as rp2040_types
|
|
||||||
from esphome.components.libretiny import get_download_types as libretiny_types
|
from esphome.components.libretiny import get_download_types as libretiny_types
|
||||||
|
from esphome.components.rp2040 import get_download_types as rp2040_types
|
||||||
|
|
||||||
downloads = []
|
downloads = []
|
||||||
platform = storage_json.target_platform.lower()
|
platform = storage_json.target_platform.lower()
|
||||||
@ -661,12 +668,21 @@ class DashboardEntry:
|
|||||||
self._storage = None
|
self._storage = None
|
||||||
self._loaded_storage = False
|
self._loaded_storage = False
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (
|
||||||
|
f"DashboardEntry({self.path} "
|
||||||
|
f"address={self.address} "
|
||||||
|
f"web_port={self.web_port} "
|
||||||
|
f"name={self.name} "
|
||||||
|
f"no_mdns={self.no_mdns})"
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def filename(self):
|
def filename(self):
|
||||||
return os.path.basename(self.path)
|
return os.path.basename(self.path)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def storage(self) -> Optional[StorageJSON]:
|
def storage(self) -> StorageJSON | None:
|
||||||
if not self._loaded_storage:
|
if not self._loaded_storage:
|
||||||
self._storage = StorageJSON.load(ext_storage_path(self.filename))
|
self._storage = StorageJSON.load(ext_storage_path(self.filename))
|
||||||
self._loaded_storage = True
|
self._loaded_storage = True
|
||||||
@ -831,10 +847,10 @@ class PrometheusServiceDiscoveryHandler(BaseHandler):
|
|||||||
class BoardsRequestHandler(BaseHandler):
|
class BoardsRequestHandler(BaseHandler):
|
||||||
@authenticated
|
@authenticated
|
||||||
def get(self, platform: str):
|
def get(self, platform: str):
|
||||||
|
from esphome.components.bk72xx.boards import BOARDS as BK72XX_BOARDS
|
||||||
from esphome.components.esp32.boards import BOARDS as ESP32_BOARDS
|
from esphome.components.esp32.boards import BOARDS as ESP32_BOARDS
|
||||||
from esphome.components.esp8266.boards import BOARDS as ESP8266_BOARDS
|
from esphome.components.esp8266.boards import BOARDS as ESP8266_BOARDS
|
||||||
from esphome.components.rp2040.boards import BOARDS as RP2040_BOARDS
|
from esphome.components.rp2040.boards import BOARDS as RP2040_BOARDS
|
||||||
from esphome.components.bk72xx.boards import BOARDS as BK72XX_BOARDS
|
|
||||||
from esphome.components.rtl87xx.boards import BOARDS as RTL87XX_BOARDS
|
from esphome.components.rtl87xx.boards import BOARDS as RTL87XX_BOARDS
|
||||||
|
|
||||||
platform_to_boards = {
|
platform_to_boards = {
|
||||||
@ -865,35 +881,76 @@ class BoardsRequestHandler(BaseHandler):
|
|||||||
|
|
||||||
|
|
||||||
class MDNSStatusThread(threading.Thread):
|
class MDNSStatusThread(threading.Thread):
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize the MDNSStatusThread."""
|
||||||
|
super().__init__()
|
||||||
|
# This is the current mdns state for each host (True, False, None)
|
||||||
|
self.host_mdns_state: dict[str, bool | None] = {}
|
||||||
|
# This is the hostnames to filenames mapping
|
||||||
|
self.host_name_to_filename: dict[str, str] = {}
|
||||||
|
# This is a set of host names to track (i.e no_mdns = false)
|
||||||
|
self.host_name_with_mdns_enabled: set[set] = set()
|
||||||
|
self._refresh_hosts()
|
||||||
|
|
||||||
|
def _refresh_hosts(self):
|
||||||
|
"""Refresh the hosts to track."""
|
||||||
|
entries = _list_dashboard_entries()
|
||||||
|
host_name_with_mdns_enabled = self.host_name_with_mdns_enabled
|
||||||
|
host_mdns_state = self.host_mdns_state
|
||||||
|
host_name_to_filename = self.host_name_to_filename
|
||||||
|
|
||||||
|
for entry in entries:
|
||||||
|
name = entry.name
|
||||||
|
# If no_mdns is set, remove it from the set
|
||||||
|
if entry.no_mdns:
|
||||||
|
host_name_with_mdns_enabled.discard(name)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# We are tracking this host
|
||||||
|
host_name_with_mdns_enabled.add(name)
|
||||||
|
filename = entry.filename
|
||||||
|
|
||||||
|
# If we just adopted/imported this host, we likely
|
||||||
|
# already have a state for it, so we should make sure
|
||||||
|
# to set it so the dashboard shows it as online
|
||||||
|
if name in host_mdns_state:
|
||||||
|
PING_RESULT[filename] = host_mdns_state[name]
|
||||||
|
|
||||||
|
# Make sure the mapping is up to date
|
||||||
|
# so when we get an mdns update we can map it back
|
||||||
|
# to the filename
|
||||||
|
host_name_to_filename[name] = filename
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
global IMPORT_RESULT
|
global IMPORT_RESULT
|
||||||
|
|
||||||
zc = EsphomeZeroconf()
|
zc = EsphomeZeroconf()
|
||||||
|
host_mdns_state = self.host_mdns_state
|
||||||
|
host_name_to_filename = self.host_name_to_filename
|
||||||
|
host_name_with_mdns_enabled = self.host_name_with_mdns_enabled
|
||||||
|
|
||||||
def on_update(dat):
|
def on_update(dat: dict[str, bool | None]) -> None:
|
||||||
for key, b in dat.items():
|
"""Update the global PING_RESULT dict."""
|
||||||
PING_RESULT[key] = b
|
for name, result in dat.items():
|
||||||
|
host_mdns_state[name] = result
|
||||||
|
if name in host_name_with_mdns_enabled:
|
||||||
|
filename = host_name_to_filename[name]
|
||||||
|
PING_RESULT[filename] = result
|
||||||
|
|
||||||
stat = DashboardStatus(zc, on_update)
|
self._refresh_hosts()
|
||||||
imports = DashboardImportDiscovery(zc)
|
stat = DashboardStatus(on_update)
|
||||||
|
imports = DashboardImportDiscovery()
|
||||||
|
browser = DashboardBrowser(
|
||||||
|
zc, ESPHOME_SERVICE_TYPE, [stat.browser_callback, imports.browser_callback]
|
||||||
|
)
|
||||||
|
|
||||||
stat.start()
|
|
||||||
while not STOP_EVENT.is_set():
|
while not STOP_EVENT.is_set():
|
||||||
entries = _list_dashboard_entries()
|
self._refresh_hosts()
|
||||||
hosts = {}
|
|
||||||
for entry in entries:
|
|
||||||
if entry.no_mdns is not True:
|
|
||||||
hosts[entry.filename] = f"{entry.name}.local."
|
|
||||||
|
|
||||||
stat.request_query(hosts)
|
|
||||||
IMPORT_RESULT = imports.import_state
|
IMPORT_RESULT = imports.import_state
|
||||||
|
|
||||||
PING_REQUEST.wait()
|
PING_REQUEST.wait()
|
||||||
PING_REQUEST.clear()
|
PING_REQUEST.clear()
|
||||||
|
|
||||||
stat.stop()
|
browser.cancel()
|
||||||
stat.join()
|
|
||||||
imports.cancel()
|
|
||||||
zc.close()
|
zc.close()
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,130 +1,49 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import socket
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Optional
|
from typing import Callable
|
||||||
|
|
||||||
from zeroconf import (
|
from zeroconf import (
|
||||||
DNSAddress,
|
IPVersion,
|
||||||
DNSOutgoing,
|
|
||||||
DNSQuestion,
|
|
||||||
RecordUpdate,
|
|
||||||
RecordUpdateListener,
|
|
||||||
ServiceBrowser,
|
ServiceBrowser,
|
||||||
|
ServiceInfo,
|
||||||
ServiceStateChange,
|
ServiceStateChange,
|
||||||
Zeroconf,
|
Zeroconf,
|
||||||
current_time_millis,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from esphome.storage_json import StorageJSON, ext_storage_path
|
from esphome.storage_json import StorageJSON, ext_storage_path
|
||||||
|
|
||||||
_CLASS_IN = 1
|
|
||||||
_FLAGS_QR_QUERY = 0x0000 # query
|
|
||||||
_TYPE_A = 1
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class HostResolver(RecordUpdateListener):
|
class HostResolver(ServiceInfo):
|
||||||
"""Resolve a host name to an IP address."""
|
"""Resolve a host name to an IP address."""
|
||||||
|
|
||||||
def __init__(self, name: str):
|
@property
|
||||||
self.name = name
|
def _is_complete(self) -> bool:
|
||||||
self.address: Optional[bytes] = None
|
"""The ServiceInfo has all expected properties."""
|
||||||
|
return bool(self._ipv4_addresses)
|
||||||
def async_update_records(
|
|
||||||
self, zc: Zeroconf, now: float, records: list[RecordUpdate]
|
|
||||||
) -> None:
|
|
||||||
"""Update multiple records in one shot.
|
|
||||||
|
|
||||||
This will run in zeroconf's event loop thread so it
|
|
||||||
must be thread-safe.
|
|
||||||
"""
|
|
||||||
for record_update in records:
|
|
||||||
record, _ = record_update
|
|
||||||
if record is None:
|
|
||||||
continue
|
|
||||||
if record.type == _TYPE_A:
|
|
||||||
assert isinstance(record, DNSAddress)
|
|
||||||
if record.name == self.name:
|
|
||||||
self.address = record.address
|
|
||||||
|
|
||||||
def request(self, zc: Zeroconf, timeout: float) -> bool:
|
|
||||||
now = time.time()
|
|
||||||
delay = 0.2
|
|
||||||
next_ = now + delay
|
|
||||||
last = now + timeout
|
|
||||||
|
|
||||||
try:
|
|
||||||
zc.add_listener(self, None)
|
|
||||||
while self.address is None:
|
|
||||||
if last <= now:
|
|
||||||
# Timeout
|
|
||||||
return False
|
|
||||||
if next_ <= now:
|
|
||||||
out = DNSOutgoing(_FLAGS_QR_QUERY)
|
|
||||||
out.add_question(DNSQuestion(self.name, _TYPE_A, _CLASS_IN))
|
|
||||||
zc.send(out)
|
|
||||||
next_ = now + delay
|
|
||||||
delay *= 2
|
|
||||||
|
|
||||||
time.sleep(min(next_, last) - now)
|
|
||||||
now = time.time()
|
|
||||||
finally:
|
|
||||||
zc.remove_listener(self)
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
class DashboardStatus(threading.Thread):
|
class DashboardStatus:
|
||||||
PING_AFTER = 15 * 1000 # Send new mDNS request after 15 seconds
|
def __init__(self, on_update: Callable[[dict[str, bool | None], []]]) -> None:
|
||||||
OFFLINE_AFTER = PING_AFTER * 2 # Offline if no mDNS response after 30 seconds
|
"""Initialize the dashboard status."""
|
||||||
|
|
||||||
def __init__(self, zc: Zeroconf, on_update) -> None:
|
|
||||||
threading.Thread.__init__(self)
|
|
||||||
self.zc = zc
|
|
||||||
self.query_hosts: set[str] = set()
|
|
||||||
self.key_to_host: dict[str, str] = {}
|
|
||||||
self.stop_event = threading.Event()
|
|
||||||
self.query_event = threading.Event()
|
|
||||||
self.on_update = on_update
|
self.on_update = on_update
|
||||||
|
|
||||||
def request_query(self, hosts: dict[str, str]) -> None:
|
def browser_callback(
|
||||||
self.query_hosts = set(hosts.values())
|
self,
|
||||||
self.key_to_host = hosts
|
zeroconf: Zeroconf,
|
||||||
self.query_event.set()
|
service_type: str,
|
||||||
|
name: str,
|
||||||
def stop(self) -> None:
|
state_change: ServiceStateChange,
|
||||||
self.stop_event.set()
|
) -> None:
|
||||||
self.query_event.set()
|
"""Handle a service update."""
|
||||||
|
short_name = name.partition(".")[0]
|
||||||
def host_status(self, key: str) -> bool:
|
if state_change == ServiceStateChange.Removed:
|
||||||
entries = self.zc.cache.entries_with_name(key)
|
self.on_update({short_name: False})
|
||||||
if not entries:
|
elif state_change in (ServiceStateChange.Updated, ServiceStateChange.Added):
|
||||||
return False
|
self.on_update({short_name: True})
|
||||||
now = current_time_millis()
|
|
||||||
|
|
||||||
return any(
|
|
||||||
(entry.created + DashboardStatus.OFFLINE_AFTER) >= now for entry in entries
|
|
||||||
)
|
|
||||||
|
|
||||||
def run(self) -> None:
|
|
||||||
while not self.stop_event.is_set():
|
|
||||||
self.on_update(
|
|
||||||
{key: self.host_status(host) for key, host in self.key_to_host.items()}
|
|
||||||
)
|
|
||||||
now = current_time_millis()
|
|
||||||
for host in self.query_hosts:
|
|
||||||
entries = self.zc.cache.entries_with_name(host)
|
|
||||||
if not entries or all(
|
|
||||||
(entry.created + DashboardStatus.PING_AFTER) <= now
|
|
||||||
for entry in entries
|
|
||||||
):
|
|
||||||
out = DNSOutgoing(_FLAGS_QR_QUERY)
|
|
||||||
out.add_question(DNSQuestion(host, _TYPE_A, _CLASS_IN))
|
|
||||||
self.zc.send(out)
|
|
||||||
self.query_event.wait()
|
|
||||||
self.query_event.clear()
|
|
||||||
|
|
||||||
|
|
||||||
ESPHOME_SERVICE_TYPE = "_esphomelib._tcp.local."
|
ESPHOME_SERVICE_TYPE = "_esphomelib._tcp.local."
|
||||||
@ -138,7 +57,7 @@ TXT_RECORD_VERSION = b"version"
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DiscoveredImport:
|
class DiscoveredImport:
|
||||||
friendly_name: Optional[str]
|
friendly_name: str | None
|
||||||
device_name: str
|
device_name: str
|
||||||
package_import_url: str
|
package_import_url: str
|
||||||
project_name: str
|
project_name: str
|
||||||
@ -146,15 +65,15 @@ class DiscoveredImport:
|
|||||||
network: str
|
network: str
|
||||||
|
|
||||||
|
|
||||||
|
class DashboardBrowser(ServiceBrowser):
|
||||||
|
"""A class to browse for ESPHome nodes."""
|
||||||
|
|
||||||
|
|
||||||
class DashboardImportDiscovery:
|
class DashboardImportDiscovery:
|
||||||
def __init__(self, zc: Zeroconf) -> None:
|
def __init__(self) -> None:
|
||||||
self.zc = zc
|
|
||||||
self.service_browser = ServiceBrowser(
|
|
||||||
self.zc, ESPHOME_SERVICE_TYPE, [self._on_update]
|
|
||||||
)
|
|
||||||
self.import_state: dict[str, DiscoveredImport] = {}
|
self.import_state: dict[str, DiscoveredImport] = {}
|
||||||
|
|
||||||
def _on_update(
|
def browser_callback(
|
||||||
self,
|
self,
|
||||||
zeroconf: Zeroconf,
|
zeroconf: Zeroconf,
|
||||||
service_type: str,
|
service_type: str,
|
||||||
@ -167,8 +86,6 @@ class DashboardImportDiscovery:
|
|||||||
name,
|
name,
|
||||||
state_change,
|
state_change,
|
||||||
)
|
)
|
||||||
if service_type != ESPHOME_SERVICE_TYPE:
|
|
||||||
return
|
|
||||||
if state_change == ServiceStateChange.Removed:
|
if state_change == ServiceStateChange.Removed:
|
||||||
self.import_state.pop(name, None)
|
self.import_state.pop(name, None)
|
||||||
return
|
return
|
||||||
@ -212,9 +129,6 @@ class DashboardImportDiscovery:
|
|||||||
network=network,
|
network=network,
|
||||||
)
|
)
|
||||||
|
|
||||||
def cancel(self) -> None:
|
|
||||||
self.service_browser.cancel()
|
|
||||||
|
|
||||||
def update_device_mdns(self, node_name: str, version: str):
|
def update_device_mdns(self, node_name: str, version: str):
|
||||||
storage_path = ext_storage_path(node_name + ".yaml")
|
storage_path = ext_storage_path(node_name + ".yaml")
|
||||||
storage_json = StorageJSON.load(storage_path)
|
storage_json = StorageJSON.load(storage_path)
|
||||||
@ -234,7 +148,11 @@ class DashboardImportDiscovery:
|
|||||||
|
|
||||||
class EsphomeZeroconf(Zeroconf):
|
class EsphomeZeroconf(Zeroconf):
|
||||||
def resolve_host(self, host: str, timeout=3.0):
|
def resolve_host(self, host: str, timeout=3.0):
|
||||||
info = HostResolver(host)
|
"""Resolve a host name to an IP address."""
|
||||||
if info.request(self, timeout):
|
name = host.partition(".")[0]
|
||||||
return socket.inet_ntoa(info.address)
|
info = HostResolver(f"{name}.{ESPHOME_SERVICE_TYPE}", ESPHOME_SERVICE_TYPE)
|
||||||
|
if (info.load_from_cache(self) or info.request(self, timeout * 1000)) and (
|
||||||
|
addresses := info.ip_addresses_by_version(IPVersion.V4Only)
|
||||||
|
):
|
||||||
|
return str(addresses[0])
|
||||||
return None
|
return None
|
||||||
|
Loading…
Reference in New Issue
Block a user