dashboard: convert ping thread to use asyncio (#5749)

This commit is contained in:
J. Nick Koston 2023-11-14 22:55:33 -06:00 committed by GitHub
parent 642db6d92b
commit 20ea8bf06e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 73 additions and 88 deletions

View File

@ -1,18 +1,13 @@
from __future__ import annotations
import asyncio
import threading
class ThreadedAsyncEvent:
"""This is a shim to allow the asyncio event to be used in a threaded context.
When more of the code is moved to asyncio, this can be removed.
"""
class AsyncEvent:
"""This is a shim around asyncio.Event."""
def __init__(self) -> None:
"""Initialize the ThreadedAsyncEvent."""
self.event = threading.Event()
self.async_event: asyncio.Event | None = None
self.loop: asyncio.AbstractEventLoop | None = None
@ -26,31 +21,11 @@ class ThreadedAsyncEvent:
def async_set(self) -> None:
"""Set the asyncio.Event instance."""
self.async_event.set()
self.event.set()
def set(self) -> None:
"""Set the event."""
self.loop.call_soon_threadsafe(self.async_event.set)
self.event.set()
def wait(self) -> None:
"""Wait for the event."""
self.event.wait()
async def async_wait(self) -> None:
"""Wait the event async."""
await self.async_event.wait()
def clear(self) -> None:
"""Clear the event."""
self.loop.call_soon_threadsafe(self.async_event.clear)
self.event.clear()
def async_clear(self) -> None:
"""Clear the event async."""
self.async_event.clear()
self.event.clear()
def is_set(self) -> bool:
"""Return if the event is set."""
return self.event.is_set()

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio
import base64
import binascii
import collections
import datetime
import functools
import gzip
@ -11,14 +10,13 @@ import hashlib
import hmac
import json
import logging
import multiprocessing
import os
import secrets
import shutil
import subprocess
import threading
from pathlib import Path
from typing import Any
from typing import Any, cast
import tornado
import tornado.concurrent
@ -52,9 +50,9 @@ from esphome.zeroconf import (
DashboardImportDiscovery,
DashboardStatus,
)
from .async_adapter import ThreadedAsyncEvent
from .util import friendly_name_slugify, password_hash
from .async_adapter import AsyncEvent
from .util import chunked, friendly_name_slugify, password_hash
_LOGGER = logging.getLogger(__name__)
@ -603,7 +601,7 @@ class ImportRequestHandler(BaseHandler):
encryption,
)
# Make sure the device gets marked online right away
PING_REQUEST.set()
PING_REQUEST.async_set()
except FileExistsError:
self.set_status(500)
self.write("File already exists")
@ -905,15 +903,6 @@ class MainRequestHandler(BaseHandler):
)
def _ping_func(filename, address):
if os.name == "nt":
command = ["ping", "-n", "1", address]
else:
command = ["ping", "-c", "1", address]
rc, _, _ = run_system_command(*command)
return filename, rc == 0
class PrometheusServiceDiscoveryHandler(BaseHandler):
@authenticated
def get(self):
@ -1070,47 +1059,48 @@ class MDNSStatus:
self.aiozc = None
class PingStatusThread(threading.Thread):
def run(self):
with multiprocessing.Pool(processes=8) as pool:
while not STOP_EVENT.wait(2):
# Only do pings if somebody has the dashboard open
async def _async_ping_host(host: str) -> bool:
"""Ping a host."""
ping_command = ["ping", "-n" if os.name == "nt" else "-c", "1"]
process = await asyncio.create_subprocess_exec(
*ping_command,
host,
stdin=asyncio.subprocess.DEVNULL,
stdout=asyncio.subprocess.DEVNULL,
stderr=asyncio.subprocess.DEVNULL,
)
await process.wait()
return process.returncode == 0
def callback(ret):
PING_RESULT[ret[0]] = ret[1]
entries = _list_dashboard_entries()
queue = collections.deque()
for entry in entries:
if entry.address is None:
PING_RESULT[entry.filename] = None
continue
class PingStatus:
def __init__(self) -> None:
"""Initialize the PingStatus class."""
super().__init__()
self._loop = asyncio.get_running_loop()
result = pool.apply_async(
_ping_func, (entry.filename, entry.address), callback=callback
)
queue.append(result)
while queue:
item = queue[0]
if item.ready():
queue.popleft()
continue
try:
item.get(0.1)
except OSError:
# ping not installed
pass
except multiprocessing.TimeoutError:
pass
if STOP_EVENT.is_set():
pool.terminate()
return
PING_REQUEST.wait()
PING_REQUEST.clear()
async def async_run(self) -> None:
"""Run the ping status."""
while not STOP_EVENT.is_set():
# Only ping if the dashboard is open
await PING_REQUEST.async_wait()
PING_REQUEST.async_clear()
entries = await self._loop.run_in_executor(None, _list_dashboard_entries)
to_ping: list[DashboardEntry] = [
entry for entry in entries if entry.address is not None
]
for ping_group in chunked(to_ping, 16):
ping_group = cast(list[DashboardEntry], ping_group)
results = await asyncio.gather(
*(_async_ping_host(entry.address) for entry in ping_group),
return_exceptions=True,
)
for entry, result in zip(ping_group, results):
if isinstance(result, Exception):
result = False
elif isinstance(result, BaseException):
raise result
PING_RESULT[entry.filename] = result
class MqttStatusThread(threading.Thread):
@ -1171,7 +1161,7 @@ class MqttStatusThread(threading.Thread):
class PingRequestHandler(BaseHandler):
@authenticated
def get(self):
PING_REQUEST.set()
PING_REQUEST.async_set()
if settings.status_use_mqtt:
MQTT_PING_REQUEST.set()
self.set_header("content-type", "application/json")
@ -1261,7 +1251,7 @@ class MDNSContainer:
PING_RESULT: dict = {}
IMPORT_RESULT = {}
STOP_EVENT = threading.Event()
PING_REQUEST = ThreadedAsyncEvent()
PING_REQUEST = AsyncEvent()
MQTT_PING_REQUEST = threading.Event()
MDNS_CONTAINER = MDNSContainer()
@ -1561,10 +1551,10 @@ async def async_start_web_server(args):
webbrowser.open(f"http://{args.address}:{args.port}")
mdns_task: asyncio.Task | None = None
ping_status_thread: PingStatusThread | None = None
ping_status_task: asyncio.Task | None = None
if settings.status_use_ping:
ping_status_thread = PingStatusThread()
ping_status_thread.start()
ping_status = PingStatus()
ping_status_task = asyncio.create_task(ping_status.async_run())
else:
mdns_status = MDNSStatus()
await mdns_status.async_refresh_hosts()
@ -1581,9 +1571,9 @@ async def async_start_web_server(args):
finally:
_LOGGER.info("Shutting down...")
STOP_EVENT.set()
PING_REQUEST.set()
if ping_status_thread:
ping_status_thread.join()
PING_REQUEST.async_set()
if ping_status_task:
ping_status_task.cancel()
MDNS_CONTAINER.set_mdns(None)
if mdns_task:
mdns_task.cancel()

View File

@ -1,5 +1,9 @@
import hashlib
import unicodedata
from collections.abc import Iterable
from functools import partial
from itertools import islice
from typing import Any
from esphome.const import ALLOWED_NAME_CHARS
@ -30,3 +34,19 @@ def friendly_name_slugify(value):
.strip("-")
)
return "".join(c for c in value if c in ALLOWED_NAME_CHARS)
def take(take_num: int, iterable: Iterable) -> list[Any]:
"""Return first n items of the iterable as a list.
From itertools recipes
"""
return list(islice(iterable, take_num))
def chunked(iterable: Iterable, chunked_num: int) -> Iterable[Any]:
"""Break *iterable* into lists of length *n*.
From more-itertools
"""
return iter(partial(take, chunked_num, iter(iterable)), [])