From 2419bc367820b985d153d7ffdc304949cc242adb Mon Sep 17 00:00:00 2001 From: Otto Winter Date: Thu, 17 Jun 2021 21:54:14 +0200 Subject: [PATCH] Improve config final validation (#1917) --- esphome/components/cse7766/sensor.py | 9 +- esphome/components/dfplayer/__init__.py | 9 +- esphome/components/gps/__init__.py | 5 +- esphome/components/http_request/__init__.py | 68 +++++++-------- esphome/components/rc522_spi/__init__.py | 9 +- esphome/components/rtttl/__init__.py | 24 +++-- esphome/components/sim800l/__init__.py | 7 +- esphome/components/spi/__init__.py | 28 ++++-- esphome/components/uart/__init__.py | 97 +++++++++++++-------- esphome/components/wifi/__init__.py | 22 ++--- esphome/config.py | 76 ++++++++++------ esphome/config_validation.py | 8 +- esphome/core/__init__.py | 10 +-- esphome/cpp_helpers.py | 3 +- esphome/final_validate.py | 57 ++++++++++++ esphome/loader.py | 10 ++- esphome/storage_json.py | 5 +- esphome/types.py | 18 ++++ 18 files changed, 303 insertions(+), 162 deletions(-) create mode 100644 esphome/final_validate.py create mode 100644 esphome/types.py diff --git a/esphome/components/cse7766/sensor.py b/esphome/components/cse7766/sensor.py index 98cf4da96d..4ccb346efd 100644 --- a/esphome/components/cse7766/sensor.py +++ b/esphome/components/cse7766/sensor.py @@ -45,6 +45,9 @@ CONFIG_SCHEMA = ( .extend(cv.polling_component_schema("60s")) .extend(uart.UART_DEVICE_SCHEMA) ) +FINAL_VALIDATE_SCHEMA = uart.final_validate_device_schema( + "cse7766", baud_rate=4800, require_rx=True +) async def to_code(config): @@ -64,9 +67,3 @@ async def to_code(config): conf = config[CONF_POWER] sens = await sensor.new_sensor(conf) cg.add(var.set_power_sensor(sens)) - - -def validate(config, item_config): - uart.validate_device( - "cse7766", config, item_config, baud_rate=4800, require_tx=False - ) diff --git a/esphome/components/dfplayer/__init__.py b/esphome/components/dfplayer/__init__.py index 6af83888ab..3cdfc8ab85 100644 --- a/esphome/components/dfplayer/__init__.py +++ b/esphome/components/dfplayer/__init__.py @@ -68,6 +68,9 @@ CONFIG_SCHEMA = cv.All( } ).extend(uart.UART_DEVICE_SCHEMA) ) +FINAL_VALIDATE_SCHEMA = uart.final_validate_device_schema( + "dfplayer", baud_rate=9600, require_tx=True +) async def to_code(config): @@ -80,12 +83,6 @@ async def to_code(config): await automation.build_automation(trigger, [], conf) -def validate(config, item_config): - uart.validate_device( - "dfplayer", config, item_config, baud_rate=9600, require_rx=False - ) - - @automation.register_action( "dfplayer.play_next", NextAction, diff --git a/esphome/components/gps/__init__.py b/esphome/components/gps/__init__.py index de3eae1115..2867aa7325 100644 --- a/esphome/components/gps/__init__.py +++ b/esphome/components/gps/__init__.py @@ -62,6 +62,7 @@ CONFIG_SCHEMA = ( .extend(cv.polling_component_schema("20s")) .extend(uart.UART_DEVICE_SCHEMA) ) +FINAL_VALIDATE_SCHEMA = uart.final_validate_device_schema("gps", require_rx=True) async def to_code(config): @@ -95,7 +96,3 @@ async def to_code(config): # https://platformio.org/lib/show/1655/TinyGPSPlus cg.add_library("1655", "1.0.2") # TinyGPSPlus, has name conflict - - -def validate(config, item_config): - uart.validate_device("gps", config, item_config, require_tx=False) diff --git a/esphome/components/http_request/__init__.py b/esphome/components/http_request/__init__.py index 650602150a..f4475060df 100644 --- a/esphome/components/http_request/__init__.py +++ b/esphome/components/http_request/__init__.py @@ -2,6 +2,7 @@ import urllib.parse as urlparse import esphome.codegen as cg import esphome.config_validation as cv +import esphome.final_validate as fv from esphome import automation from esphome.const import ( CONF_ID, @@ -14,7 +15,6 @@ from esphome.const import ( CONF_URL, ) from esphome.core import CORE, Lambda -from esphome.core.config import PLATFORMIO_ESP8266_LUT DEPENDENCIES = ["network"] AUTO_LOAD = ["json"] @@ -36,29 +36,6 @@ CONF_VERIFY_SSL = "verify_ssl" CONF_ON_RESPONSE = "on_response" -def validate_framework(config): - if CORE.is_esp32: - return config - - version = "RECOMMENDED" - if CONF_ARDUINO_VERSION in CORE.raw_config[CONF_ESPHOME]: - version = CORE.raw_config[CONF_ESPHOME][CONF_ARDUINO_VERSION] - - if version in ["LATEST", "DEV"]: - return config - - framework = ( - PLATFORMIO_ESP8266_LUT[version] - if version in PLATFORMIO_ESP8266_LUT - else version - ) - if framework < ARDUINO_VERSION_ESP8266["2.5.1"]: - raise cv.Invalid( - "This component is not supported on arduino framework version below 2.5.1" - ) - return config - - def validate_url(value): value = cv.string(value) try: @@ -92,19 +69,36 @@ def validate_secure_url(config): return config -CONFIG_SCHEMA = ( - cv.Schema( - { - cv.GenerateID(): cv.declare_id(HttpRequestComponent), - cv.Optional(CONF_USERAGENT, "ESPHome"): cv.string, - cv.Optional( - CONF_TIMEOUT, default="5s" - ): cv.positive_time_period_milliseconds, - } - ) - .add_extra(validate_framework) - .extend(cv.COMPONENT_SCHEMA) -) +CONFIG_SCHEMA = cv.Schema( + { + cv.GenerateID(): cv.declare_id(HttpRequestComponent), + cv.Optional(CONF_USERAGENT, "ESPHome"): cv.string, + cv.Optional(CONF_TIMEOUT, default="5s"): cv.positive_time_period_milliseconds, + } +).extend(cv.COMPONENT_SCHEMA) + + +def validate_framework(config): + if CORE.is_esp32: + return + + # only for ESP8266 + path = [CONF_ESPHOME, CONF_ARDUINO_VERSION] + version: str = fv.full_config.get().get_config_for_path(path) + + reverse_map = {v: k for k, v in ARDUINO_VERSION_ESP8266.items()} + framework_version = reverse_map.get(version) + if framework_version is None or framework_version == "dev": + return + + if framework_version < "2.5.1": + raise cv.Invalid( + "This component is not supported on arduino framework version below 2.5.1", + path=[cv.ROOT_CONFIG_PATH] + path, + ) + + +FINAL_VALIDATE_SCHEMA = cv.Schema(validate_framework) async def to_code(config): diff --git a/esphome/components/rc522_spi/__init__.py b/esphome/components/rc522_spi/__init__.py index 68b1e64145..77b0a99662 100644 --- a/esphome/components/rc522_spi/__init__.py +++ b/esphome/components/rc522_spi/__init__.py @@ -19,13 +19,12 @@ CONFIG_SCHEMA = cv.All( ).extend(spi.spi_device_schema(cs_pin_required=True)) ) +FINAL_VALIDATE_SCHEMA = spi.final_validate_device_schema( + "rc522_spi", require_miso=True, require_mosi=True +) + async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) await rc522.setup_rc522(var, config) await spi.register_spi_device(var, config) - - -def validate(config, item_config): - # validate given SPI hub is suitable for rc522_spi, it needs both miso and mosi - spi.validate_device("rc522_spi", config, item_config, True, True) diff --git a/esphome/components/rtttl/__init__.py b/esphome/components/rtttl/__init__.py index 7f860fe3d7..e9453896ac 100644 --- a/esphome/components/rtttl/__init__.py +++ b/esphome/components/rtttl/__init__.py @@ -1,6 +1,7 @@ import logging import esphome.codegen as cg import esphome.config_validation as cv +import esphome.final_validate as fv from esphome import automation from esphome.components.output import FloatOutput from esphome.const import CONF_ID, CONF_OUTPUT, CONF_PLATFORM, CONF_TRIGGER_ID @@ -36,12 +37,8 @@ CONFIG_SCHEMA = cv.Schema( ).extend(cv.COMPONENT_SCHEMA) -def validate(config, item_config): - # Not adding this to FloatOutput as this is the only component which needs `update_frequency` - - parent_config = config.get_config_by_id(item_config[CONF_OUTPUT]) - platform = parent_config[CONF_PLATFORM] - +def validate_parent_output_config(value): + platform = value.get(CONF_PLATFORM) PWM_GOOD = ["esp8266_pwm", "ledc"] PWM_BAD = [ "ac_dimmer ", @@ -55,14 +52,25 @@ def validate(config, item_config): ] if platform in PWM_BAD: - raise ValueError(f"Component rtttl cannot use {platform} as output component") + raise cv.Invalid(f"Component rtttl cannot use {platform} as output component") if platform not in PWM_GOOD: _LOGGER.warning( - "Component rtttl is not known to work with the selected output type. Make sure this output supports custom frequency output method." + "Component rtttl is not known to work with the selected output type. " + "Make sure this output supports custom frequency output method." ) +FINAL_VALIDATE_SCHEMA = cv.Schema( + { + cv.Required(CONF_OUTPUT): fv.id_declaration_match_schema( + validate_parent_output_config + ) + }, + extra=cv.ALLOW_EXTRA, +) + + async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) await cg.register_component(var, config) diff --git a/esphome/components/sim800l/__init__.py b/esphome/components/sim800l/__init__.py index 40c011a769..0887b8640f 100644 --- a/esphome/components/sim800l/__init__.py +++ b/esphome/components/sim800l/__init__.py @@ -40,6 +40,9 @@ CONFIG_SCHEMA = cv.All( .extend(cv.polling_component_schema("5s")) .extend(uart.UART_DEVICE_SCHEMA) ) +FINAL_VALIDATE_SCHEMA = uart.final_validate_device_schema( + "sim800l", baud_rate=9600, require_tx=True, require_rx=True +) async def to_code(config): @@ -54,10 +57,6 @@ async def to_code(config): ) -def validate(config, item_config): - uart.validate_device("sim800l", config, item_config, baud_rate=9600) - - SIM800L_SEND_SMS_SCHEMA = cv.Schema( { cv.GenerateID(): cv.use_id(Sim800LComponent), diff --git a/esphome/components/spi/__init__.py b/esphome/components/spi/__init__.py index e6e073c4a4..803a45814c 100644 --- a/esphome/components/spi/__init__.py +++ b/esphome/components/spi/__init__.py @@ -1,5 +1,6 @@ import esphome.codegen as cg import esphome.config_validation as cv +import esphome.final_validate as fv from esphome import pins from esphome.const import ( CONF_CLK_PIN, @@ -69,9 +70,24 @@ async def register_spi_device(var, config): cg.add(var.set_cs_pin(pin)) -def validate_device(name, config, item_config, require_mosi, require_miso): - spi_config = config.get_config_by_id(item_config[CONF_SPI_ID]) - if require_mosi and CONF_MISO_PIN not in spi_config: - raise ValueError(f"Component {name} requires parent spi to declare miso_pin") - if require_miso and CONF_MOSI_PIN not in spi_config: - raise ValueError(f"Component {name} requires parent spi to declare mosi_pin") +def final_validate_device_schema(name: str, *, require_mosi: bool, require_miso: bool): + hub_schema = {} + if require_miso: + hub_schema[ + cv.Required( + CONF_MISO_PIN, + msg=f"Component {name} requires this spi bus to declare a miso_pin", + ) + ] = cv.valid + if require_mosi: + hub_schema[ + cv.Required( + CONF_MOSI_PIN, + msg=f"Component {name} requires this spi bus to declare a mosi_pin", + ) + ] = cv.valid + + return cv.Schema( + {cv.Required(CONF_SPI_ID): fv.id_declaration_match_schema(hub_schema)}, + extra=cv.ALLOW_EXTRA, + ) diff --git a/esphome/components/uart/__init__.py b/esphome/components/uart/__init__.py index aaed333e34..d2fcac2cb6 100644 --- a/esphome/components/uart/__init__.py +++ b/esphome/components/uart/__init__.py @@ -1,5 +1,8 @@ +from typing import Optional + import esphome.codegen as cg import esphome.config_validation as cv +import esphome.final_validate as fv from esphome import pins, automation from esphome.const import ( CONF_BAUD_RATE, @@ -92,42 +95,6 @@ async def to_code(config): cg.add(var.set_parity(config[CONF_PARITY])) -def validate_device( - name, config, item_config, baud_rate=None, require_tx=True, require_rx=True -): - if not hasattr(config, "uart_devices"): - config.uart_devices = {} - devices = config.uart_devices - - uart_config = config.get_config_by_id(item_config[CONF_UART_ID]) - - uart_id = uart_config[CONF_ID] - device = devices.setdefault(uart_id, {}) - - if require_tx: - if CONF_TX_PIN not in uart_config: - raise ValueError(f"Component {name} requires parent uart to declare tx_pin") - if CONF_TX_PIN in device: - raise ValueError( - f"Component {name} cannot use the same uart.{CONF_TX_PIN} as component {device[CONF_TX_PIN]} is already using it" - ) - device[CONF_TX_PIN] = name - - if require_rx: - if CONF_RX_PIN not in uart_config: - raise ValueError(f"Component {name} requires parent uart to declare rx_pin") - if CONF_RX_PIN in device: - raise ValueError( - f"Component {name} cannot use the same uart.{CONF_RX_PIN} as component {device[CONF_RX_PIN]} is already using it" - ) - device[CONF_RX_PIN] = name - - if baud_rate and uart_config[CONF_BAUD_RATE] != baud_rate: - raise ValueError( - f"Component {name} requires parent uart baud rate be {baud_rate}" - ) - - # A schema to use for all UART devices, all UART integrations must extend this! UART_DEVICE_SCHEMA = cv.Schema( { @@ -135,6 +102,64 @@ UART_DEVICE_SCHEMA = cv.Schema( } ) +KEY_UART_DEVICES = "uart_devices" + + +def final_validate_device_schema( + name: str, + *, + baud_rate: Optional[int] = None, + require_tx: bool = False, + require_rx: bool = False, +): + def validate_baud_rate(value): + if value != baud_rate: + raise cv.Invalid( + f"Component {name} required baud rate {baud_rate} for the uart bus" + ) + return value + + def validate_pin(opt, device): + def validator(value): + if opt in device: + raise cv.Invalid( + f"The uart {opt} is used both by {name} and {device[opt]}, " + f"but can only be used by one. Please create a new uart bus for {name}." + ) + device[opt] = name + return value + + return validator + + def validate_hub(hub_config): + hub_schema = {} + uart_id = hub_config[CONF_ID] + devices = fv.full_config.get().data.setdefault(KEY_UART_DEVICES, {}) + device = devices.setdefault(uart_id, {}) + + if require_tx: + hub_schema[ + cv.Required( + CONF_TX_PIN, + msg=f"Component {name} requires this uart bus to declare a tx_pin", + ) + ] = validate_pin(CONF_TX_PIN, device) + if require_rx: + hub_schema[ + cv.Required( + CONF_RX_PIN, + msg=f"Component {name} requires this uart bus to declare a rx_pin", + ) + ] = validate_pin(CONF_RX_PIN, device) + if baud_rate is not None: + hub_schema[cv.Required(CONF_BAUD_RATE)] = validate_baud_rate + return cv.Schema(hub_schema, extra=cv.ALLOW_EXTRA)(hub_config) + + return cv.Schema( + {cv.Required(CONF_UART_ID): fv.id_declaration_match_schema(validate_hub)}, + extra=cv.ALLOW_EXTRA, + ) + async def register_uart_device(var, config): """Register a UART device, setting up all the internal values. diff --git a/esphome/components/wifi/__init__.py b/esphome/components/wifi/__init__.py index 13d698d245..54307388d6 100644 --- a/esphome/components/wifi/__init__.py +++ b/esphome/components/wifi/__init__.py @@ -1,5 +1,6 @@ import esphome.codegen as cg import esphome.config_validation as cv +import esphome.final_validate as fv from esphome import automation from esphome.automation import Condition from esphome.components.network import add_mdns_library @@ -137,16 +138,17 @@ WIFI_NETWORK_STA = WIFI_NETWORK_BASE.extend( ) -def validate(config, item_config): - if ( - (CONF_NETWORKS in item_config) - and (item_config[CONF_NETWORKS] == []) - and (CONF_AP not in item_config) - ): - if "esp32_improv" not in config: - raise ValueError( - "Please specify at least an SSID or an Access Point to create." - ) +def final_validate(config): + has_sta = bool(config.get(CONF_NETWORKS, True)) + has_ap = CONF_AP in config + has_improv = "esp32_improv" in fv.full_config.get() + if (not has_sta) and (not has_ap) and (not has_improv): + raise cv.Invalid( + "Please specify at least an SSID or an Access Point to create." + ) + + +FINAL_VALIDATE_SCHEMA = cv.Schema(final_validate) def _validate(config): diff --git a/esphome/config.py b/esphome/config.py index fcd2fac90f..93413a009c 100644 --- a/esphome/config.py +++ b/esphome/config.py @@ -21,11 +21,13 @@ from esphome.helpers import indent from esphome.util import safe_print, OrderedDict from typing import List, Optional, Tuple, Union -from esphome.core import ConfigType from esphome.loader import get_component, get_platform, ComponentManifest from esphome.yaml_util import is_secret, ESPHomeDataBase, ESPForceValue from esphome.voluptuous_schema import ExtraKeysInvalid from esphome.log import color, Fore +import esphome.final_validate as fv +import esphome.config_validation as cv +from esphome.types import ConfigType, ConfigPathType, ConfigFragmentType _LOGGER = logging.getLogger(__name__) @@ -54,7 +56,7 @@ def _path_begins_with(path, other): # type: (ConfigPath, ConfigPath) -> bool return path[: len(other)] == other -class Config(OrderedDict): +class Config(OrderedDict, fv.FinalValidateConfig): def __init__(self): super().__init__() # A list of voluptuous errors @@ -65,6 +67,7 @@ class Config(OrderedDict): self.output_paths = [] # type: List[Tuple[ConfigPath, str]] # A list of components ids with the config path self.declare_ids = [] # type: List[Tuple[core.ID, ConfigPath]] + self._data = {} def add_error(self, error): # type: (vol.Invalid) -> None @@ -72,6 +75,12 @@ class Config(OrderedDict): for err in error.errors: self.add_error(err) return + if cv.ROOT_CONFIG_PATH in error.path: + # Root value means that the path before the root should be ignored + last_root = max( + i for i, v in enumerate(error.path) if v is cv.ROOT_CONFIG_PATH + ) + error.path = error.path[last_root + 1 :] self.errors.append(error) @contextmanager @@ -140,13 +149,16 @@ class Config(OrderedDict): return doc_range - def get_nested_item(self, path): - # type: (ConfigPath) -> ConfigType + def get_nested_item( + self, path: ConfigPathType, raise_error: bool = False + ) -> ConfigFragmentType: data = self for item_index in path: try: data = data[item_index] except (KeyError, IndexError, TypeError): + if raise_error: + raise return {} return data @@ -163,11 +175,20 @@ class Config(OrderedDict): part.append(item_index) return part - def get_config_by_id(self, id): + def get_path_for_id(self, id: core.ID): + """Return the config fragment where the given ID is declared.""" for declared_id, path in self.declare_ids: if declared_id.id == str(id): - return self.get_nested_item(path[:-1]) - return None + return path + raise KeyError(f"ID {id} not found in configuration") + + def get_config_for_path(self, path: ConfigPathType) -> ConfigFragmentType: + return self.get_nested_item(path, raise_error=True) + + @property + def data(self): + """Return temporary data used by final validation functions.""" + return self._data def iter_ids(config, path=None): @@ -189,23 +210,22 @@ def do_id_pass(result): # type: (Config) -> None from esphome.cpp_generator import MockObjClass from esphome.cpp_types import Component - declare_ids = result.declare_ids # type: List[Tuple[core.ID, ConfigPath]] searching_ids = [] # type: List[Tuple[core.ID, ConfigPath]] for id, path in iter_ids(result): if id.is_declaration: if id.id is not None: # Look for duplicate definitions - match = next((v for v in declare_ids if v[0].id == id.id), None) + match = next((v for v in result.declare_ids if v[0].id == id.id), None) if match is not None: opath = "->".join(str(v) for v in match[1]) result.add_str_error(f"ID {id.id} redefined! Check {opath}", path) continue - declare_ids.append((id, path)) + result.declare_ids.append((id, path)) else: searching_ids.append((id, path)) # Resolve default ids after manual IDs - for id, _ in declare_ids: - id.resolve([v[0].id for v in declare_ids]) + for id, _ in result.declare_ids: + id.resolve([v[0].id for v in result.declare_ids]) if isinstance(id.type, MockObjClass) and id.type.inherits_from(Component): CORE.component_ids.add(id.id) @@ -213,7 +233,7 @@ def do_id_pass(result): # type: (Config) -> None for id, path in searching_ids: if id.id is not None: # manually declared - match = next((v[0] for v in declare_ids if v[0].id == id.id), None) + match = next((v[0] for v in result.declare_ids if v[0].id == id.id), None) if match is None or not match.is_manual: # No declared ID with this name import difflib @@ -224,7 +244,7 @@ def do_id_pass(result): # type: (Config) -> None ) # Find candidates matches = difflib.get_close_matches( - id.id, [v[0].id for v in declare_ids if v[0].is_manual] + id.id, [v[0].id for v in result.declare_ids if v[0].is_manual] ) if matches: matches_s = ", ".join(f'"{x}"' for x in matches) @@ -245,7 +265,7 @@ def do_id_pass(result): # type: (Config) -> None if id.id is None and id.type is not None: matches = [] - for v in declare_ids: + for v in result.declare_ids: if v[0] is None or not isinstance(v[0].type, MockObjClass): continue inherits = v[0].type.inherits_from(id.type) @@ -278,8 +298,6 @@ def do_id_pass(result): # type: (Config) -> None def recursive_check_replaceme(value): - import esphome.config_validation as cv - if isinstance(value, list): return cv.Schema([recursive_check_replaceme])(value) if isinstance(value, dict): @@ -558,14 +576,16 @@ def validate_config(config, command_line_substitutions): # 7. Final validation if not result.errors: # Inter - components validation - for path, conf, comp in validate_queue: - if comp.config_schema is None: + token = fv.full_config.set(result) + + for path, _, comp in validate_queue: + if comp.final_validate_schema is None: continue - if callable(comp.validate): - try: - comp.validate(result, result.get_nested_item(path)) - except ValueError as err: - result.add_str_error(err, path) + conf = result.get_nested_item(path) + with result.catch_error(path): + comp.final_validate_schema(conf) + + fv.full_config.reset(token) return result @@ -621,8 +641,12 @@ def _format_vol_invalid(ex, config): ) elif "extra keys not allowed" in str(ex): message += "[{}] is an invalid option for [{}].".format(ex.path[-1], paren) - elif "required key not provided" in str(ex): - message += "'{}' is a required option for [{}].".format(ex.path[-1], paren) + elif isinstance(ex, vol.RequiredFieldInvalid): + if ex.msg == "required key not provided": + message += "'{}' is a required option for [{}].".format(ex.path[-1], paren) + else: + # Required has set a custom error message + message += ex.msg else: message += humanize_error(config, ex) diff --git a/esphome/config_validation.py b/esphome/config_validation.py index 2cdb6b0b76..7292cc3af5 100644 --- a/esphome/config_validation.py +++ b/esphome/config_validation.py @@ -75,6 +75,9 @@ Inclusive = vol.Inclusive ALLOW_EXTRA = vol.ALLOW_EXTRA UNDEFINED = vol.UNDEFINED RequiredFieldInvalid = vol.RequiredFieldInvalid +# this sentinel object can be placed in an 'Invalid' path to say +# the rest of the error path is relative to the root config path +ROOT_CONFIG_PATH = object() RESERVED_IDS = [ # C++ keywords http://en.cppreference.com/w/cpp/keyword @@ -218,8 +221,8 @@ class Required(vol.Required): - *not* the `config.get(CONF_)` syntax. """ - def __init__(self, key): - super().__init__(key) + def __init__(self, key, msg=None): + super().__init__(key, msg=msg) def check_not_templatable(value): @@ -1073,6 +1076,7 @@ def invalid(message): def valid(value): + """A validator that is always valid and returns the value as-is.""" return value diff --git a/esphome/core/__init__.py b/esphome/core/__init__.py index 1841dfd8be..df98e1b150 100644 --- a/esphome/core/__init__.py +++ b/esphome/core/__init__.py @@ -2,7 +2,7 @@ import logging import math import os import re -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple from esphome.const import ( CONF_ARDUINO_VERSION, @@ -23,6 +23,7 @@ from esphome.util import OrderedDict if TYPE_CHECKING: from ..cpp_generator import MockObj, MockObjClass, Statement + from ..types import ConfigType _LOGGER = logging.getLogger(__name__) @@ -462,9 +463,9 @@ class EsphomeCore: # The board that's used (for example nodemcuv2) self.board: Optional[str] = None # The full raw configuration - self.raw_config: Optional[ConfigType] = None + self.raw_config: Optional["ConfigType"] = None # The validated configuration, this is None until the config has been validated - self.config: Optional[ConfigType] = None + self.config: Optional["ConfigType"] = None # The pending tasks in the task queue (mostly for C++ generation) # This is a priority queue (with heapq) # Each item is a tuple of form: (-priority, unique number, task) @@ -752,6 +753,3 @@ class EnumValue: CORE = EsphomeCore() - -ConfigType = Dict[str, Any] -CoreType = EsphomeCore diff --git a/esphome/cpp_helpers.py b/esphome/cpp_helpers.py index 1c52f38e50..1d66eabf6c 100644 --- a/esphome/cpp_helpers.py +++ b/esphome/cpp_helpers.py @@ -8,7 +8,8 @@ from esphome.const import ( ) # pylint: disable=unused-import -from esphome.core import coroutine, ID, CORE, ConfigType +from esphome.core import coroutine, ID, CORE +from esphome.types import ConfigType from esphome.cpp_generator import RawExpression, add, get_variable from esphome.cpp_types import App, GPIOPin from esphome.util import Registry, RegistryEntry diff --git a/esphome/final_validate.py b/esphome/final_validate.py new file mode 100644 index 0000000000..50fdbaf3f4 --- /dev/null +++ b/esphome/final_validate.py @@ -0,0 +1,57 @@ +from abc import ABC, abstractmethod, abstractproperty +from typing import Dict, Any +import contextvars + +from esphome.types import ConfigFragmentType, ID, ConfigPathType +import esphome.config_validation as cv + + +class FinalValidateConfig(ABC): + @abstractproperty + def data(self) -> Dict[str, Any]: + """A dictionary that can be used by post validation functions to store + global data during the validation phase. Each component should store its + data under a unique key + """ + + @abstractmethod + def get_path_for_id(self, id: ID) -> ConfigPathType: + """Get the config path a given ID has been declared in. + + This is the location under the _validated_ config (for example, with cv.ensure_list applied) + Raises KeyError if the id was not declared in the configuration. + """ + + @abstractmethod + def get_config_for_path(self, path: ConfigPathType) -> ConfigFragmentType: + """Get the config fragment for the given global path. + + Raises KeyError if a key in the path does not exist. + """ + + +FinalValidateConfig.register(dict) + +# Context variable tracking the full config for some final validation functions. +full_config: contextvars.ContextVar[FinalValidateConfig] = contextvars.ContextVar( + "full_config" +) + + +def id_declaration_match_schema(schema): + """A final-validation schema function that applies a schema to the outer config fragment of an + ID declaration. + + This validator must be applied to ID values. + """ + if not isinstance(schema, cv.Schema): + schema = cv.Schema(schema, extra=cv.ALLOW_EXTRA) + + def validator(value): + fconf = full_config.get() + path = fconf.get_path_for_id(value)[:-1] + declaration_config = fconf.get_config_for_path(path) + with cv.prepend_path([cv.ROOT_CONFIG_PATH] + path): + return schema(declaration_config) + + return validator diff --git a/esphome/loader.py b/esphome/loader.py index d9d407d787..f74fc6367d 100644 --- a/esphome/loader.py +++ b/esphome/loader.py @@ -12,6 +12,7 @@ from pathlib import Path from esphome.const import ESP_PLATFORMS, SOURCE_FILE_EXTENSIONS import esphome.core.config from esphome.core import CORE +from esphome.types import ConfigType _LOGGER = logging.getLogger(__name__) @@ -81,8 +82,13 @@ class ComponentManifest: return getattr(self.module, "CODEOWNERS", []) @property - def validate(self): - return getattr(self.module, "validate", None) + def final_validate_schema(self) -> Optional[Callable[[ConfigType], None]]: + """Components can declare a `FINAL_VALIDATE_SCHEMA` cv.Schema that gets called + after the main validation. In that function checks across components can be made. + + Note that the function can't mutate the configuration - no changes are saved + """ + return getattr(self.module, "FINAL_VALIDATE_SCHEMA", None) @property def source_files(self) -> Dict[Path, SourceFile]: diff --git a/esphome/storage_json.py b/esphome/storage_json.py index 6f81e0d96a..c0fbc6edf7 100644 --- a/esphome/storage_json.py +++ b/esphome/storage_json.py @@ -4,14 +4,13 @@ from datetime import datetime import json import logging import os +from typing import Any, Optional, List from esphome import const from esphome.core import CORE from esphome.helpers import write_file_if_changed -# pylint: disable=unused-import, wrong-import-order -from esphome.core import CoreType -from typing import Any, Optional, List +from esphome.types import CoreType _LOGGER = logging.getLogger(__name__) diff --git a/esphome/types.py b/esphome/types.py new file mode 100644 index 0000000000..6bbfb00ce6 --- /dev/null +++ b/esphome/types.py @@ -0,0 +1,18 @@ +"""This helper module tracks commonly used types in the esphome python codebase.""" +from typing import Dict, Union, List + +from esphome.core import ID, Lambda, EsphomeCore + +ConfigFragmentType = Union[ + str, + int, + float, + None, + Dict[Union[str, int], "ConfigFragmentType"], + List["ConfigFragmentType"], + ID, + Lambda, +] +ConfigType = Dict[str, ConfigFragmentType] +CoreType = EsphomeCore +ConfigPathType = Union[str, int]