From 4e6d3729e178999daea09421e42ccf022c768509 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Mon, 27 Nov 2023 16:39:24 -0600 Subject: [PATCH] dashboard: Small cleanups to dashboard (#5841) --- esphome/dashboard/entries.py | 2 +- esphome/dashboard/settings.py | 24 +++--- esphome/dashboard/web_server.py | 135 +++++++++++++++++--------------- 3 files changed, 87 insertions(+), 74 deletions(-) diff --git a/esphome/dashboard/entries.py b/esphome/dashboard/entries.py index 8ccfa795d5..ad139b830b 100644 --- a/esphome/dashboard/entries.py +++ b/esphome/dashboard/entries.py @@ -262,7 +262,7 @@ class DashboardEntry: self.state = EntryState.UNKNOWN self._to_dict: dict[str, Any] | None = None - def __repr__(self): + def __repr__(self) -> str: """Return the representation of this entry.""" return ( f"DashboardEntry(path={self.path} " diff --git a/esphome/dashboard/settings.py b/esphome/dashboard/settings.py index 61718298d2..1a5b1620e8 100644 --- a/esphome/dashboard/settings.py +++ b/esphome/dashboard/settings.py @@ -23,45 +23,45 @@ class DashboardSettings: self.cookie_secret: str | None = None self.absolute_config_dir: Path | None = None - def parse_args(self, args): + def parse_args(self, args: Any) -> None: self.on_ha_addon: bool = args.ha_addon - password: str = args.password or os.getenv("PASSWORD", "") + password = args.password or os.getenv("PASSWORD") or "" if not self.on_ha_addon: - self.username: str = args.username or os.getenv("USERNAME", "") + self.username = args.username or os.getenv("USERNAME") or "" self.using_password = bool(password) if self.using_password: self.password_hash = password_hash(password) - self.config_dir: str = args.configuration - self.absolute_config_dir: Path = Path(self.config_dir).resolve() + self.config_dir = args.configuration + self.absolute_config_dir = Path(self.config_dir).resolve() CORE.config_path = os.path.join(self.config_dir, ".") @property - def relative_url(self): - return os.getenv("ESPHOME_DASHBOARD_RELATIVE_URL", "/") + def relative_url(self) -> str: + return os.getenv("ESPHOME_DASHBOARD_RELATIVE_URL") or "/" @property def status_use_ping(self): return get_bool_env("ESPHOME_DASHBOARD_USE_PING") @property - def status_use_mqtt(self): + def status_use_mqtt(self) -> bool: return get_bool_env("ESPHOME_DASHBOARD_USE_MQTT") @property - def using_ha_addon_auth(self): + def using_ha_addon_auth(self) -> bool: if not self.on_ha_addon: return False return not get_bool_env("DISABLE_HA_AUTHENTICATION") @property - def using_auth(self): + def using_auth(self) -> bool: return self.using_password or self.using_ha_addon_auth @property - def streamer_mode(self): + def streamer_mode(self) -> bool: return get_bool_env("ESPHOME_STREAMER_MODE") - def check_password(self, username, password): + def check_password(self, username: str, password: str) -> bool: if not self.using_auth: return True if username != self.username: diff --git a/esphome/dashboard/web_server.py b/esphome/dashboard/web_server.py index 9a9ccb462b..9bbf0b28dc 100644 --- a/esphome/dashboard/web_server.py +++ b/esphome/dashboard/web_server.py @@ -14,12 +14,14 @@ import shutil import subprocess import threading from pathlib import Path -from typing import Any +from typing import Any, Callable, TypeVar +from collections.abc import Iterable import tornado import tornado.concurrent import tornado.gen import tornado.httpserver +import tornado.httputil import tornado.ioloop import tornado.iostream import tornado.netutil @@ -27,9 +29,9 @@ import tornado.process import tornado.queues import tornado.web import tornado.websocket -import tornado.httputil import yaml from tornado.log import access_log +from yaml.nodes import Node from esphome import const, platformio_api, yaml_util from esphome.helpers import get_bool_env, mkdir_p @@ -54,7 +56,7 @@ cookie_authenticated_yes = b"yes" settings = DASHBOARD.settings -def template_args(): +def template_args() -> dict[str, Any]: version = const.__version__ if "b" in version: docs_link = "https://beta.esphome.io/" @@ -73,9 +75,12 @@ def template_args(): } -def authenticated(func): +T = TypeVar("T", bound=Callable[..., Any]) + + +def authenticated(func: T) -> T: @functools.wraps(func) - def decorator(self, *args, **kwargs): + def decorator(self, *args: Any, **kwargs: Any): if not is_authenticated(self): self.redirect("./login") return None @@ -209,7 +214,7 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler): tornado.ioloop.IOLoop.current().spawn_callback(self._redirect_stdout) @property - def is_process_active(self): + def is_process_active(self) -> bool: return self._proc is not None and self._proc.returncode is None @websocket_method("stdin") @@ -398,7 +403,7 @@ class EsphomeUpdateAllHandler(EsphomeCommandWebSocket): class SerialPortRequestHandler(BaseHandler): @authenticated - async def get(self): + async def get(self) -> None: ports = await asyncio.get_running_loop().run_in_executor(None, get_serial_ports) data = [] for port in ports: @@ -418,7 +423,7 @@ class SerialPortRequestHandler(BaseHandler): class WizardRequestHandler(BaseHandler): @authenticated - def post(self): + def post(self) -> None: from esphome import wizard kwargs = { @@ -449,7 +454,7 @@ class WizardRequestHandler(BaseHandler): class ImportRequestHandler(BaseHandler): @authenticated - def post(self): + def post(self) -> None: from esphome.components.dashboard_import import import_config dashboard = DASHBOARD @@ -504,7 +509,7 @@ class ImportRequestHandler(BaseHandler): class DownloadListRequestHandler(BaseHandler): @authenticated @bind_config - def get(self, configuration=None): + def get(self, configuration: str | None = None) -> None: storage_path = ext_storage_path(configuration) storage_json = StorageJSON.load(storage_path) if storage_json is None: @@ -512,26 +517,29 @@ class DownloadListRequestHandler(BaseHandler): return from esphome.components.esp32 import VARIANTS as ESP32_VARIANTS - from esphome.components.esp32 import get_download_types as esp32_types - from esphome.components.esp8266 import get_download_types as esp8266_types - from esphome.components.libretiny import get_download_types as libretiny_types - from esphome.components.rp2040 import get_download_types as rp2040_types downloads = [] - platform = storage_json.target_platform.lower() + platform: str = storage_json.target_platform.lower() if platform == const.PLATFORM_RP2040: + from esphome.components.rp2040 import get_download_types as rp2040_types + downloads = rp2040_types(storage_json) elif platform == const.PLATFORM_ESP8266: + from esphome.components.esp8266 import get_download_types as esp8266_types + downloads = esp8266_types(storage_json) elif platform.upper() in ESP32_VARIANTS: + from esphome.components.esp32 import get_download_types as esp32_types + downloads = esp32_types(storage_json) - elif platform == const.PLATFORM_BK72XX: - downloads = libretiny_types(storage_json) - elif platform == const.PLATFORM_RTL87XX: + elif platform in (const.PLATFORM_RTL87XX, const.PLATFORM_BK72XX): + from esphome.components.libretiny import ( + get_download_types as libretiny_types, + ) + downloads = libretiny_types(storage_json) else: - self.send_error(418) - return + raise ValueError(f"Unknown platform {platform}") self.set_status(200) self.set_header("content-type", "application/json") @@ -551,7 +559,7 @@ class DownloadBinaryRequestHandler(BaseHandler): @authenticated @bind_config - async def get(self, configuration: str | None = None): + async def get(self, configuration: str | None = None) -> None: """Download a binary file.""" loop = asyncio.get_running_loop() compressed = self.get_argument("compressed", "0") == "1" @@ -618,7 +626,7 @@ class DownloadBinaryRequestHandler(BaseHandler): class EsphomeVersionHandler(BaseHandler): @authenticated - def get(self): + def get(self) -> None: self.set_header("Content-Type", "application/json") self.write(json.dumps({"version": const.__version__})) self.finish() @@ -626,7 +634,7 @@ class EsphomeVersionHandler(BaseHandler): class ListDevicesHandler(BaseHandler): @authenticated - async def get(self): + async def get(self) -> None: dashboard = DASHBOARD await dashboard.entries.async_request_update_entries() entries = dashboard.entries.async_all() @@ -656,7 +664,7 @@ class ListDevicesHandler(BaseHandler): class MainRequestHandler(BaseHandler): @authenticated - def get(self): + def get(self) -> None: begin = bool(self.get_argument("begin", False)) self.render( @@ -669,7 +677,7 @@ class MainRequestHandler(BaseHandler): class PrometheusServiceDiscoveryHandler(BaseHandler): @authenticated - async def get(self): + async def get(self) -> None: dashboard = DASHBOARD await dashboard.entries.async_request_update_entries() entries = dashboard.entries.async_all() @@ -698,29 +706,34 @@ class PrometheusServiceDiscoveryHandler(BaseHandler): class BoardsRequestHandler(BaseHandler): @authenticated - def get(self, platform: str): - from esphome.components.bk72xx.boards import BOARDS as BK72XX_BOARDS - from esphome.components.esp32.boards import BOARDS as ESP32_BOARDS - from esphome.components.esp8266.boards import BOARDS as ESP8266_BOARDS - from esphome.components.rp2040.boards import BOARDS as RP2040_BOARDS - from esphome.components.rtl87xx.boards import BOARDS as RTL87XX_BOARDS - - platform_to_boards = { - const.PLATFORM_ESP32: ESP32_BOARDS, - const.PLATFORM_ESP8266: ESP8266_BOARDS, - const.PLATFORM_RP2040: RP2040_BOARDS, - const.PLATFORM_BK72XX: BK72XX_BOARDS, - const.PLATFORM_RTL87XX: RTL87XX_BOARDS, - } + def get(self, platform: str) -> None: # filter all ESP32 variants by requested platform if platform.startswith("esp32"): + from esphome.components.esp32.boards import BOARDS as ESP32_BOARDS + boards = { k: v - for k, v in platform_to_boards[const.PLATFORM_ESP32].items() + for k, v in ESP32_BOARDS.items() if v[const.KEY_VARIANT] == platform.upper() } + elif platform == const.PLATFORM_ESP8266: + from esphome.components.esp8266.boards import BOARDS as ESP8266_BOARDS + + boards = ESP8266_BOARDS + elif platform == const.PLATFORM_RP2040: + from esphome.components.rp2040.boards import BOARDS as RP2040_BOARDS + + boards = RP2040_BOARDS + elif platform == const.PLATFORM_BK72XX: + from esphome.components.bk72xx.boards import BOARDS as BK72XX_BOARDS + + boards = BK72XX_BOARDS + elif platform == const.PLATFORM_RTL87XX: + from esphome.components.rtl87xx.boards import BOARDS as RTL87XX_BOARDS + + boards = RTL87XX_BOARDS else: - boards = platform_to_boards[platform] + raise ValueError(f"Unknown platform {platform}") # map to a {board_name: board_title} dict platform_boards = {key: val[const.KEY_NAME] for key, val in boards.items()} @@ -734,7 +747,7 @@ class BoardsRequestHandler(BaseHandler): class PingRequestHandler(BaseHandler): @authenticated - def get(self): + def get(self) -> None: dashboard = DASHBOARD dashboard.ping_request.set() if settings.status_use_mqtt: @@ -754,7 +767,7 @@ class PingRequestHandler(BaseHandler): class InfoRequestHandler(BaseHandler): @authenticated @bind_config - async def get(self, configuration=None): + async def get(self, configuration: str | None = None) -> None: yaml_path = settings.rel_path(configuration) dashboard = DASHBOARD entry = dashboard.entries.get(yaml_path) @@ -770,7 +783,7 @@ class InfoRequestHandler(BaseHandler): class EditRequestHandler(BaseHandler): @authenticated @bind_config - async def get(self, configuration: str | None = None): + async def get(self, configuration: str | None = None) -> None: """Get the content of a file.""" loop = asyncio.get_running_loop() filename = settings.rel_path(configuration) @@ -788,7 +801,7 @@ class EditRequestHandler(BaseHandler): @authenticated @bind_config - async def post(self, configuration: str | None = None): + async def post(self, configuration: str | None = None) -> None: """Write the content of a file.""" loop = asyncio.get_running_loop() config_file = settings.rel_path(configuration) @@ -805,7 +818,7 @@ class EditRequestHandler(BaseHandler): class DeleteRequestHandler(BaseHandler): @authenticated @bind_config - def post(self, configuration=None): + def post(self, configuration: str | None = None) -> None: config_file = settings.rel_path(configuration) storage_path = ext_storage_path(configuration) @@ -825,20 +838,20 @@ class DeleteRequestHandler(BaseHandler): class UndoDeleteRequestHandler(BaseHandler): @authenticated @bind_config - def post(self, configuration=None): + def post(self, configuration: str | None = None) -> None: config_file = settings.rel_path(configuration) trash_path = trash_storage_path() shutil.move(os.path.join(trash_path, configuration), config_file) class LoginHandler(BaseHandler): - def get(self): + def get(self) -> None: if is_authenticated(self): self.redirect("./") else: self.render_login_page() - def render_login_page(self, error=None): + def render_login_page(self, error: str | None = None) -> None: self.render( "login.template.html", error=error, @@ -847,7 +860,7 @@ class LoginHandler(BaseHandler): **template_args(), ) - def post_ha_addon_login(self): + def post_ha_addon_login(self) -> None: import requests headers = { @@ -874,7 +887,7 @@ class LoginHandler(BaseHandler): self.set_status(401) self.render_login_page(error="Invalid username or password") - def post_native_login(self): + def post_native_login(self) -> None: username = self.get_argument("username", "") password = self.get_argument("password", "") if settings.check_password(username, password): @@ -887,7 +900,7 @@ class LoginHandler(BaseHandler): self.set_status(401) self.render_login_page(error=error_str) - def post(self): + def post(self) -> None: if settings.using_ha_addon_auth: self.post_ha_addon_login() else: @@ -896,14 +909,14 @@ class LoginHandler(BaseHandler): class LogoutHandler(BaseHandler): @authenticated - def get(self): + def get(self) -> None: self.clear_cookie("authenticated") self.redirect("./login") class SecretKeysRequestHandler(BaseHandler): @authenticated - def get(self): + def get(self) -> None: filename = None for secret_filename in const.SECRETS_FILES: @@ -923,10 +936,10 @@ class SecretKeysRequestHandler(BaseHandler): class SafeLoaderIgnoreUnknown(FastestAvailableSafeLoader): - def ignore_unknown(self, node): + def ignore_unknown(self, node: Node) -> str: return f"{node.tag} {node.value}" - def construct_yaml_binary(self, node) -> str: + def construct_yaml_binary(self, node: Node) -> str: return super().construct_yaml_binary(node).decode("ascii") @@ -939,7 +952,7 @@ SafeLoaderIgnoreUnknown.add_constructor( class JsonConfigRequestHandler(BaseHandler): @authenticated @bind_config - async def get(self, configuration=None): + async def get(self, configuration: str | None = None) -> None: filename = settings.rel_path(configuration) if not os.path.isfile(filename): self.send_error(404) @@ -959,7 +972,7 @@ class JsonConfigRequestHandler(BaseHandler): self.finish() -def get_base_frontend_path(): +def get_base_frontend_path() -> str: if ENV_DEV not in os.environ: import esphome_dashboard @@ -973,12 +986,12 @@ def get_base_frontend_path(): return os.path.abspath(os.path.join(os.getcwd(), static_path, "esphome_dashboard")) -def get_static_path(*args): +def get_static_path(*args: Iterable[str]) -> str: return os.path.join(get_base_frontend_path(), "static", *args) @functools.cache -def get_static_file_url(name): +def get_static_file_url(name: str) -> str: base = f"./static/{name}" if ENV_DEV in os.environ: @@ -997,7 +1010,7 @@ def get_static_file_url(name): def make_app(debug=get_bool_env(ENV_DEV)) -> tornado.web.Application: - def log_function(handler): + def log_function(handler: tornado.web.RequestHandler) -> None: if handler.get_status() < 400: log_method = access_log.info