dashboard: Small cleanups to dashboard (#5841)

This commit is contained in:
J. Nick Koston 2023-11-27 16:39:24 -06:00 committed by GitHub
parent 460362b11f
commit 4e6d3729e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 87 additions and 74 deletions

View File

@ -262,7 +262,7 @@ class DashboardEntry:
self.state = EntryState.UNKNOWN
self._to_dict: dict[str, Any] | None = None
def __repr__(self):
def __repr__(self) -> str:
"""Return the representation of this entry."""
return (
f"DashboardEntry(path={self.path} "

View File

@ -23,45 +23,45 @@ class DashboardSettings:
self.cookie_secret: str | None = None
self.absolute_config_dir: Path | None = None
def parse_args(self, args):
def parse_args(self, args: Any) -> None:
self.on_ha_addon: bool = args.ha_addon
password: str = args.password or os.getenv("PASSWORD", "")
password = args.password or os.getenv("PASSWORD") or ""
if not self.on_ha_addon:
self.username: str = args.username or os.getenv("USERNAME", "")
self.username = args.username or os.getenv("USERNAME") or ""
self.using_password = bool(password)
if self.using_password:
self.password_hash = password_hash(password)
self.config_dir: str = args.configuration
self.absolute_config_dir: Path = Path(self.config_dir).resolve()
self.config_dir = args.configuration
self.absolute_config_dir = Path(self.config_dir).resolve()
CORE.config_path = os.path.join(self.config_dir, ".")
@property
def relative_url(self):
return os.getenv("ESPHOME_DASHBOARD_RELATIVE_URL", "/")
def relative_url(self) -> str:
return os.getenv("ESPHOME_DASHBOARD_RELATIVE_URL") or "/"
@property
def status_use_ping(self):
return get_bool_env("ESPHOME_DASHBOARD_USE_PING")
@property
def status_use_mqtt(self):
def status_use_mqtt(self) -> bool:
return get_bool_env("ESPHOME_DASHBOARD_USE_MQTT")
@property
def using_ha_addon_auth(self):
def using_ha_addon_auth(self) -> bool:
if not self.on_ha_addon:
return False
return not get_bool_env("DISABLE_HA_AUTHENTICATION")
@property
def using_auth(self):
def using_auth(self) -> bool:
return self.using_password or self.using_ha_addon_auth
@property
def streamer_mode(self):
def streamer_mode(self) -> bool:
return get_bool_env("ESPHOME_STREAMER_MODE")
def check_password(self, username, password):
def check_password(self, username: str, password: str) -> bool:
if not self.using_auth:
return True
if username != self.username:

View File

@ -14,12 +14,14 @@ import shutil
import subprocess
import threading
from pathlib import Path
from typing import Any
from typing import Any, Callable, TypeVar
from collections.abc import Iterable
import tornado
import tornado.concurrent
import tornado.gen
import tornado.httpserver
import tornado.httputil
import tornado.ioloop
import tornado.iostream
import tornado.netutil
@ -27,9 +29,9 @@ import tornado.process
import tornado.queues
import tornado.web
import tornado.websocket
import tornado.httputil
import yaml
from tornado.log import access_log
from yaml.nodes import Node
from esphome import const, platformio_api, yaml_util
from esphome.helpers import get_bool_env, mkdir_p
@ -54,7 +56,7 @@ cookie_authenticated_yes = b"yes"
settings = DASHBOARD.settings
def template_args():
def template_args() -> dict[str, Any]:
version = const.__version__
if "b" in version:
docs_link = "https://beta.esphome.io/"
@ -73,9 +75,12 @@ def template_args():
}
def authenticated(func):
T = TypeVar("T", bound=Callable[..., Any])
def authenticated(func: T) -> T:
@functools.wraps(func)
def decorator(self, *args, **kwargs):
def decorator(self, *args: Any, **kwargs: Any):
if not is_authenticated(self):
self.redirect("./login")
return None
@ -209,7 +214,7 @@ class EsphomeCommandWebSocket(tornado.websocket.WebSocketHandler):
tornado.ioloop.IOLoop.current().spawn_callback(self._redirect_stdout)
@property
def is_process_active(self):
def is_process_active(self) -> bool:
return self._proc is not None and self._proc.returncode is None
@websocket_method("stdin")
@ -398,7 +403,7 @@ class EsphomeUpdateAllHandler(EsphomeCommandWebSocket):
class SerialPortRequestHandler(BaseHandler):
@authenticated
async def get(self):
async def get(self) -> None:
ports = await asyncio.get_running_loop().run_in_executor(None, get_serial_ports)
data = []
for port in ports:
@ -418,7 +423,7 @@ class SerialPortRequestHandler(BaseHandler):
class WizardRequestHandler(BaseHandler):
@authenticated
def post(self):
def post(self) -> None:
from esphome import wizard
kwargs = {
@ -449,7 +454,7 @@ class WizardRequestHandler(BaseHandler):
class ImportRequestHandler(BaseHandler):
@authenticated
def post(self):
def post(self) -> None:
from esphome.components.dashboard_import import import_config
dashboard = DASHBOARD
@ -504,7 +509,7 @@ class ImportRequestHandler(BaseHandler):
class DownloadListRequestHandler(BaseHandler):
@authenticated
@bind_config
def get(self, configuration=None):
def get(self, configuration: str | None = None) -> None:
storage_path = ext_storage_path(configuration)
storage_json = StorageJSON.load(storage_path)
if storage_json is None:
@ -512,26 +517,29 @@ class DownloadListRequestHandler(BaseHandler):
return
from esphome.components.esp32 import VARIANTS as ESP32_VARIANTS
from esphome.components.esp32 import get_download_types as esp32_types
from esphome.components.esp8266 import get_download_types as esp8266_types
from esphome.components.libretiny import get_download_types as libretiny_types
from esphome.components.rp2040 import get_download_types as rp2040_types
downloads = []
platform = storage_json.target_platform.lower()
platform: str = storage_json.target_platform.lower()
if platform == const.PLATFORM_RP2040:
from esphome.components.rp2040 import get_download_types as rp2040_types
downloads = rp2040_types(storage_json)
elif platform == const.PLATFORM_ESP8266:
from esphome.components.esp8266 import get_download_types as esp8266_types
downloads = esp8266_types(storage_json)
elif platform.upper() in ESP32_VARIANTS:
from esphome.components.esp32 import get_download_types as esp32_types
downloads = esp32_types(storage_json)
elif platform == const.PLATFORM_BK72XX:
downloads = libretiny_types(storage_json)
elif platform == const.PLATFORM_RTL87XX:
elif platform in (const.PLATFORM_RTL87XX, const.PLATFORM_BK72XX):
from esphome.components.libretiny import (
get_download_types as libretiny_types,
)
downloads = libretiny_types(storage_json)
else:
self.send_error(418)
return
raise ValueError(f"Unknown platform {platform}")
self.set_status(200)
self.set_header("content-type", "application/json")
@ -551,7 +559,7 @@ class DownloadBinaryRequestHandler(BaseHandler):
@authenticated
@bind_config
async def get(self, configuration: str | None = None):
async def get(self, configuration: str | None = None) -> None:
"""Download a binary file."""
loop = asyncio.get_running_loop()
compressed = self.get_argument("compressed", "0") == "1"
@ -618,7 +626,7 @@ class DownloadBinaryRequestHandler(BaseHandler):
class EsphomeVersionHandler(BaseHandler):
@authenticated
def get(self):
def get(self) -> None:
self.set_header("Content-Type", "application/json")
self.write(json.dumps({"version": const.__version__}))
self.finish()
@ -626,7 +634,7 @@ class EsphomeVersionHandler(BaseHandler):
class ListDevicesHandler(BaseHandler):
@authenticated
async def get(self):
async def get(self) -> None:
dashboard = DASHBOARD
await dashboard.entries.async_request_update_entries()
entries = dashboard.entries.async_all()
@ -656,7 +664,7 @@ class ListDevicesHandler(BaseHandler):
class MainRequestHandler(BaseHandler):
@authenticated
def get(self):
def get(self) -> None:
begin = bool(self.get_argument("begin", False))
self.render(
@ -669,7 +677,7 @@ class MainRequestHandler(BaseHandler):
class PrometheusServiceDiscoveryHandler(BaseHandler):
@authenticated
async def get(self):
async def get(self) -> None:
dashboard = DASHBOARD
await dashboard.entries.async_request_update_entries()
entries = dashboard.entries.async_all()
@ -698,29 +706,34 @@ class PrometheusServiceDiscoveryHandler(BaseHandler):
class BoardsRequestHandler(BaseHandler):
@authenticated
def get(self, platform: str):
from esphome.components.bk72xx.boards import BOARDS as BK72XX_BOARDS
from esphome.components.esp32.boards import BOARDS as ESP32_BOARDS
from esphome.components.esp8266.boards import BOARDS as ESP8266_BOARDS
from esphome.components.rp2040.boards import BOARDS as RP2040_BOARDS
from esphome.components.rtl87xx.boards import BOARDS as RTL87XX_BOARDS
platform_to_boards = {
const.PLATFORM_ESP32: ESP32_BOARDS,
const.PLATFORM_ESP8266: ESP8266_BOARDS,
const.PLATFORM_RP2040: RP2040_BOARDS,
const.PLATFORM_BK72XX: BK72XX_BOARDS,
const.PLATFORM_RTL87XX: RTL87XX_BOARDS,
}
def get(self, platform: str) -> None:
# filter all ESP32 variants by requested platform
if platform.startswith("esp32"):
from esphome.components.esp32.boards import BOARDS as ESP32_BOARDS
boards = {
k: v
for k, v in platform_to_boards[const.PLATFORM_ESP32].items()
for k, v in ESP32_BOARDS.items()
if v[const.KEY_VARIANT] == platform.upper()
}
elif platform == const.PLATFORM_ESP8266:
from esphome.components.esp8266.boards import BOARDS as ESP8266_BOARDS
boards = ESP8266_BOARDS
elif platform == const.PLATFORM_RP2040:
from esphome.components.rp2040.boards import BOARDS as RP2040_BOARDS
boards = RP2040_BOARDS
elif platform == const.PLATFORM_BK72XX:
from esphome.components.bk72xx.boards import BOARDS as BK72XX_BOARDS
boards = BK72XX_BOARDS
elif platform == const.PLATFORM_RTL87XX:
from esphome.components.rtl87xx.boards import BOARDS as RTL87XX_BOARDS
boards = RTL87XX_BOARDS
else:
boards = platform_to_boards[platform]
raise ValueError(f"Unknown platform {platform}")
# map to a {board_name: board_title} dict
platform_boards = {key: val[const.KEY_NAME] for key, val in boards.items()}
@ -734,7 +747,7 @@ class BoardsRequestHandler(BaseHandler):
class PingRequestHandler(BaseHandler):
@authenticated
def get(self):
def get(self) -> None:
dashboard = DASHBOARD
dashboard.ping_request.set()
if settings.status_use_mqtt:
@ -754,7 +767,7 @@ class PingRequestHandler(BaseHandler):
class InfoRequestHandler(BaseHandler):
@authenticated
@bind_config
async def get(self, configuration=None):
async def get(self, configuration: str | None = None) -> None:
yaml_path = settings.rel_path(configuration)
dashboard = DASHBOARD
entry = dashboard.entries.get(yaml_path)
@ -770,7 +783,7 @@ class InfoRequestHandler(BaseHandler):
class EditRequestHandler(BaseHandler):
@authenticated
@bind_config
async def get(self, configuration: str | None = None):
async def get(self, configuration: str | None = None) -> None:
"""Get the content of a file."""
loop = asyncio.get_running_loop()
filename = settings.rel_path(configuration)
@ -788,7 +801,7 @@ class EditRequestHandler(BaseHandler):
@authenticated
@bind_config
async def post(self, configuration: str | None = None):
async def post(self, configuration: str | None = None) -> None:
"""Write the content of a file."""
loop = asyncio.get_running_loop()
config_file = settings.rel_path(configuration)
@ -805,7 +818,7 @@ class EditRequestHandler(BaseHandler):
class DeleteRequestHandler(BaseHandler):
@authenticated
@bind_config
def post(self, configuration=None):
def post(self, configuration: str | None = None) -> None:
config_file = settings.rel_path(configuration)
storage_path = ext_storage_path(configuration)
@ -825,20 +838,20 @@ class DeleteRequestHandler(BaseHandler):
class UndoDeleteRequestHandler(BaseHandler):
@authenticated
@bind_config
def post(self, configuration=None):
def post(self, configuration: str | None = None) -> None:
config_file = settings.rel_path(configuration)
trash_path = trash_storage_path()
shutil.move(os.path.join(trash_path, configuration), config_file)
class LoginHandler(BaseHandler):
def get(self):
def get(self) -> None:
if is_authenticated(self):
self.redirect("./")
else:
self.render_login_page()
def render_login_page(self, error=None):
def render_login_page(self, error: str | None = None) -> None:
self.render(
"login.template.html",
error=error,
@ -847,7 +860,7 @@ class LoginHandler(BaseHandler):
**template_args(),
)
def post_ha_addon_login(self):
def post_ha_addon_login(self) -> None:
import requests
headers = {
@ -874,7 +887,7 @@ class LoginHandler(BaseHandler):
self.set_status(401)
self.render_login_page(error="Invalid username or password")
def post_native_login(self):
def post_native_login(self) -> None:
username = self.get_argument("username", "")
password = self.get_argument("password", "")
if settings.check_password(username, password):
@ -887,7 +900,7 @@ class LoginHandler(BaseHandler):
self.set_status(401)
self.render_login_page(error=error_str)
def post(self):
def post(self) -> None:
if settings.using_ha_addon_auth:
self.post_ha_addon_login()
else:
@ -896,14 +909,14 @@ class LoginHandler(BaseHandler):
class LogoutHandler(BaseHandler):
@authenticated
def get(self):
def get(self) -> None:
self.clear_cookie("authenticated")
self.redirect("./login")
class SecretKeysRequestHandler(BaseHandler):
@authenticated
def get(self):
def get(self) -> None:
filename = None
for secret_filename in const.SECRETS_FILES:
@ -923,10 +936,10 @@ class SecretKeysRequestHandler(BaseHandler):
class SafeLoaderIgnoreUnknown(FastestAvailableSafeLoader):
def ignore_unknown(self, node):
def ignore_unknown(self, node: Node) -> str:
return f"{node.tag} {node.value}"
def construct_yaml_binary(self, node) -> str:
def construct_yaml_binary(self, node: Node) -> str:
return super().construct_yaml_binary(node).decode("ascii")
@ -939,7 +952,7 @@ SafeLoaderIgnoreUnknown.add_constructor(
class JsonConfigRequestHandler(BaseHandler):
@authenticated
@bind_config
async def get(self, configuration=None):
async def get(self, configuration: str | None = None) -> None:
filename = settings.rel_path(configuration)
if not os.path.isfile(filename):
self.send_error(404)
@ -959,7 +972,7 @@ class JsonConfigRequestHandler(BaseHandler):
self.finish()
def get_base_frontend_path():
def get_base_frontend_path() -> str:
if ENV_DEV not in os.environ:
import esphome_dashboard
@ -973,12 +986,12 @@ def get_base_frontend_path():
return os.path.abspath(os.path.join(os.getcwd(), static_path, "esphome_dashboard"))
def get_static_path(*args):
def get_static_path(*args: Iterable[str]) -> str:
return os.path.join(get_base_frontend_path(), "static", *args)
@functools.cache
def get_static_file_url(name):
def get_static_file_url(name: str) -> str:
base = f"./static/{name}"
if ENV_DEV in os.environ:
@ -997,7 +1010,7 @@ def get_static_file_url(name):
def make_app(debug=get_bool_env(ENV_DEV)) -> tornado.web.Application:
def log_function(handler):
def log_function(handler: tornado.web.RequestHandler) -> None:
if handler.get_status() < 400:
log_method = access_log.info