esphome/esphome/yaml_util.py
2024-07-29 14:07:44 +12:00

585 lines
21 KiB
Python

from __future__ import annotations
import fnmatch
import functools
import inspect
from io import TextIOWrapper
import logging
import math
import os
from typing import Any
import uuid
import yaml
from yaml import SafeLoader as PurePythonLoader
import yaml.constructor
try:
from yaml import CSafeLoader as FastestAvailableSafeLoader
except ImportError:
FastestAvailableSafeLoader = PurePythonLoader
from esphome import core
from esphome.config_helpers import Extend, Remove
from esphome.core import (
CORE,
DocumentRange,
EsphomeError,
IPAddress,
Lambda,
MACAddress,
TimePeriod,
)
from esphome.helpers import add_class_to_obj
from esphome.util import OrderedDict, filter_yaml_files
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
# The secret handling in this module is mostly copied from Home Assistant:
# that code works fine, so there is no reason to reinvent the wheel here.

# File name of the secrets file that the `!secret` constructor loads.
SECRET_YAML = "secrets.yaml"
# Cleared together with _SECRET_VALUES by load_yaml() and dump().
# NOTE(review): no code in the visible part of this module stores entries
# in it - confirm whether it is still needed.
_SECRET_CACHE = {}
# Maps str(secret value) -> secret name. Filled by the `!secret` constructor
# and read by is_secret() / the dumper to turn values back into `!secret` tags.
_SECRET_VALUES = {}
class ESPHomeDataBase:
    """Mixin that remembers where in a YAML document a value came from.

    The position is kept in two lazily created private attributes:
    ``_esp_range`` (the document range of the source node) and
    ``_content_offset`` (1 for block scalars written with ``|`` or ``>``).
    """

    @property
    def esp_range(self):
        """Return the stored document range, or ``None`` if none was recorded."""
        try:
            return self._esp_range
        except AttributeError:
            return None

    @property
    def content_offset(self):
        """Return the stored content offset, or ``0`` if none was recorded."""
        try:
            return self._content_offset
        except AttributeError:
            return 0

    def from_node(self, node):
        """Record the position information of the YAML ``node``."""
        # pylint: disable=attribute-defined-outside-init
        self._esp_range = DocumentRange.from_marks(node.start_mark, node.end_mark)
        # Only scalar nodes whose style is one of the block indicators get
        # a content offset of 1.
        is_block_scalar = (
            isinstance(node, yaml.ScalarNode)
            and node.style is not None
            and node.style in "|>"
        )
        if is_block_scalar:
            self._content_offset = 1

    def from_database(self, database):
        """Copy the position information from another ``ESPHomeDataBase``."""
        # pylint: disable=attribute-defined-outside-init
        self._esp_range = database.esp_range
        self._content_offset = database.content_offset
class ESPForceValue:
    """Marker class mixed into values constructed from the ``!force`` tag."""
def make_data_base(value, from_database: ESPHomeDataBase = None):
    """Return ``value`` with ``ESPHomeDataBase`` mixed in, when possible.

    If ``from_database`` is given, its position information is copied onto
    the result. A ``TypeError`` raised along the way is ignored and whatever
    has been built so far is returned.
    """
    result = value
    try:
        result = add_class_to_obj(value, ESPHomeDataBase)
        if from_database is not None:
            result.from_database(from_database)
    except TypeError:
        # Adding the class failed; fall through and return what we have.
        pass
    return result
def _add_data_ref(fn):
    """Decorate a YAML constructor so its result remembers its source node.

    The wrapper calls ``fn(loader, node)``, passes the result through
    ``make_data_base`` and, when that produced an ``ESPHomeDataBase``,
    records the node position on it.
    """

    @functools.wraps(fn)
    def wrapped(loader, node):
        constructed = fn(loader, node)
        if inspect.isgenerator(constructed):
            # Newer PyYAML versions hand back a generator: the first item is
            # the object, the remaining steps must run to finish building it.
            pending = constructed
            constructed = next(pending)
            for _ in pending:
                continue
        tracked = make_data_base(constructed)
        if isinstance(tracked, ESPHomeDataBase):
            tracked.from_node(node)
        return tracked

    return wrapped
class ESPHomeLoaderMixin:
    """Loader mixin that keeps track of line numbers.

    Every constructor is wrapped with ``_add_data_ref`` so that the value it
    builds remembers the document range of the node it was built from,
    whenever ``ESPHomeDataBase`` can be mixed into the value.
    """

    @_add_data_ref
    def construct_yaml_int(self, node):
        """Construct an int via the parent loader and attach its position."""
        return super().construct_yaml_int(node)

    @_add_data_ref
    def construct_yaml_float(self, node):
        """Construct a float via the parent loader and attach its position."""
        return super().construct_yaml_float(node)

    @_add_data_ref
    def construct_yaml_binary(self, node):
        """Construct binary data via the parent loader and attach its position."""
        return super().construct_yaml_binary(node)

    @_add_data_ref
    def construct_yaml_omap(self, node):
        """Construct an ordered map via the parent loader and attach its position."""
        return super().construct_yaml_omap(node)

    @_add_data_ref
    def construct_yaml_str(self, node):
        """Construct a string via the parent loader and attach its position."""
        return super().construct_yaml_str(node)

    @_add_data_ref
    def construct_yaml_seq(self, node):
        """Construct a sequence via the parent loader and attach its position."""
        return super().construct_yaml_seq(node)

    @_add_data_ref
    def construct_yaml_map(self, node):
        """Construct an ``OrderedDict`` from the given mapping node.

        Keys are converted to strings and remember their position. Merge
        keys (``<<``) are resolved after all regular keys, so regular keys
        always win over merged ones.

        Raises:
            yaml.constructor.ConstructorError: for an unhashable key, a
                duplicate key, or a merge value that is neither a mapping
                nor a list of mappings.
        """
        assert isinstance(node, yaml.MappingNode)
        # A list of key-value pairs we find in the current mapping
        pairs = []
        # A list of key-value pairs we find while resolving merges ('<<' key), will be
        # added to pairs in a second pass
        merge_pairs = []
        # A dict of seen keys so far, used to alert the user of duplicate keys and checking
        # which keys to merge.
        # Value of dict items is the start mark of the previous declaration.
        seen_keys = {}

        for key_node, value_node in node.value:
            # merge key is '<<'
            is_merge_key = key_node.tag == "tag:yaml.org,2002:merge"
            # key has no explicit tag set
            is_default_tag = key_node.tag == "tag:yaml.org,2002:value"

            if is_default_tag:
                # Default tag for mapping keys is string
                key_node.tag = "tag:yaml.org,2002:str"

            if not is_merge_key:
                # base case, this is a simple key-value pair
                key = self.construct_object(key_node)
                value = self.construct_object(value_node)

                # Check if key is hashable
                try:
                    hash(key)
                except TypeError:
                    # pylint: disable=raise-missing-from
                    raise yaml.constructor.ConstructorError(
                        f'Invalid key "{key}" (not hashable)', key_node.start_mark
                    )

                key = make_data_base(str(key))
                key.from_node(key_node)

                # Check if it is a duplicate key
                if key in seen_keys:
                    raise yaml.constructor.ConstructorError(
                        f'Duplicate key "{key}"',
                        key_node.start_mark,
                        "NOTE: Previous declaration here:",
                        seen_keys[key],
                    )
                seen_keys[key] = key_node.start_mark

                # Add to pairs
                pairs.append((key, value))
                continue

            # This is a merge key, resolve value and add to merge_pairs
            value = self.construct_object(value_node)
            if isinstance(value, dict):
                # base case, copy directly to merge_pairs
                # direct merge, like "<<: {some_key: some_value}"
                merge_pairs.extend(value.items())
            elif isinstance(value, list):
                # sequence merge, like "<<: [{some_key: some_value}, {other_key: some_value}]"
                for item in value:
                    if not isinstance(item, dict):
                        raise yaml.constructor.ConstructorError(
                            "While constructing a mapping",
                            node.start_mark,
                            f"Expected a mapping for merging, but found {type(item)}",
                            value_node.start_mark,
                        )
                    merge_pairs.extend(item.items())
            else:
                raise yaml.constructor.ConstructorError(
                    "While constructing a mapping",
                    node.start_mark,
                    f"Expected a mapping or list of mappings for merging, but found {type(value)}",
                    value_node.start_mark,
                )

        if merge_pairs:
            # We found some merge keys along the way, merge them into base pairs
            # https://yaml.org/type/merge.html
            # Construct a new merge set with values overridden by current mapping or earlier
            # sequence entries removed
            for key, value in merge_pairs:
                if key in seen_keys:
                    # key already in the current map or from an earlier merge sequence entry,
                    # do not override
                    #
                    # "... each of its key/value pairs is inserted into the current mapping,
                    # unless the key already exists in it."
                    #
                    # "If the value associated with the merge key is a sequence, then this sequence
                    #  is expected to contain mapping nodes and each of these nodes is merged in
                    #  turn according to its order in the sequence. Keys in mapping nodes earlier
                    #  in the sequence override keys specified in later mapping nodes."
                    continue
                pairs.append((key, value))
                # Add key node to seen keys, for sequence merge values.
                seen_keys[key] = None

        return OrderedDict(pairs)

    @_add_data_ref
    def construct_env_var(self, node):
        """Resolve an ``!env_var`` tag to the value of an environment variable.

        The tag value is split on whitespace: the first token is the
        variable name, any further tokens (re-joined with single spaces)
        form the default used when the variable is not set.

        Raises:
            yaml.MarkedYAMLError: if no variable name was given, or the
                variable is not set and no default was given.
        """
        args = node.value.split()
        if not args:
            # An empty tag value used to surface as a bare IndexError on
            # args[0]; report it with the document position instead.
            raise yaml.MarkedYAMLError(
                "Environment variable name missing", node.start_mark
            )

        # Check for a default value
        if len(args) > 1:
            return os.getenv(args[0], " ".join(args[1:]))
        if args[0] in os.environ:
            return os.environ[args[0]]
        raise yaml.MarkedYAMLError(
            f"Environment variable '{node.value}' not defined", node.start_mark
        )

    @property
    def _directory(self):
        """Directory of the file this loader is reading (from ``self.name``)."""
        return os.path.dirname(self.name)

    def _rel_path(self, *args):
        """Join ``args`` onto the directory of the file being loaded."""
        return os.path.join(self._directory, *args)

    @_add_data_ref
    def construct_secret(self, node):
        """Resolve a ``!secret`` tag by looking it up in ``secrets.yaml``.

        The secrets file next to the file being loaded is tried first. If
        that fails and the file being loaded is not the main configuration
        (``CORE.config_path``), the secrets file next to the main
        configuration is tried as well. The resolved value is remembered in
        ``_SECRET_VALUES`` under ``str(value)``.

        Raises:
            EsphomeError: if no secrets file could be loaded.
            yaml.MarkedYAMLError: if the secret name is not defined.
        """
        try:
            secrets = _load_yaml_internal(self._rel_path(SECRET_YAML))
        except EsphomeError as e:
            if self.name == CORE.config_path:
                # Already at the main config; there is no fallback location.
                raise
            try:
                main_config_dir = os.path.dirname(CORE.config_path)
                main_secret_yml = os.path.join(main_config_dir, SECRET_YAML)
                secrets = _load_yaml_internal(main_secret_yml)
            except EsphomeError as er:
                raise EsphomeError(f"{e}\n{er}") from er

        if node.value not in secrets:
            raise yaml.MarkedYAMLError(
                f"Secret '{node.value}' not defined", node.start_mark
            )
        val = secrets[node.value]
        _SECRET_VALUES[str(val)] = node.value
        return val

    @_add_data_ref
    def construct_include(self, node):
        """Resolve an ``!include`` tag by loading another YAML file.

        The tag is either a scalar holding the file path, or a mapping with
        a mandatory ``file`` entry and an optional ``vars`` entry. The
        included content goes through a substitution pass using ``vars``
        plus any ``defaults`` declared by the included file.

        Raises:
            yaml.MarkedYAMLError: if the mapping form has no ``file`` entry.
        """

        def extract_file_vars(node):
            """Return ``(file, vars)`` from the mapping form of the tag."""
            fields = self.construct_yaml_map(node)
            file = fields.get("file")
            if file is None:
                raise yaml.MarkedYAMLError("Must include 'file'", node.start_mark)
            given_vars = fields.get("vars")
            if given_vars:
                given_vars = {k: str(v) for k, v in given_vars.items()}
            return file, given_vars

        def substitute_vars(config, subs):
            """Run a substitution pass over ``config`` using ``subs``."""
            from esphome.components import substitutions
            from esphome.const import CONF_DEFAULTS, CONF_SUBSTITUTIONS

            org_subs = None
            result = config
            if not isinstance(config, dict):
                # when the included yaml contains a list or a scalar
                # wrap it into an OrderedDict because do_substitution_pass expects it
                result = OrderedDict([("yaml", config)])
            elif CONF_SUBSTITUTIONS in result:
                # Set the file's own substitutions aside; they are restored below.
                org_subs = result.pop(CONF_SUBSTITUTIONS)

            defaults = {}
            if CONF_DEFAULTS in result:
                defaults = result.pop(CONF_DEFAULTS)

            result[CONF_SUBSTITUTIONS] = subs
            # Defaults only fill in names the caller did not provide.
            for k, v in defaults.items():
                if k not in result[CONF_SUBSTITUTIONS]:
                    result[CONF_SUBSTITUTIONS][k] = v

            # Ignore missing vars that refer to the top level substitutions
            substitutions.do_substitution_pass(result, None, ignore_missing=True)
            result.pop(CONF_SUBSTITUTIONS)

            if not isinstance(config, dict):
                result = result["yaml"]  # unwrap the result
            elif org_subs:
                result[CONF_SUBSTITUTIONS] = org_subs
            return result

        if isinstance(node, yaml.nodes.MappingNode):
            file, include_vars = extract_file_vars(node)
        else:
            file, include_vars = node.value, None

        result = _load_yaml_internal(self._rel_path(file))
        if not include_vars:
            include_vars = {}
        result = substitute_vars(result, include_vars)
        return result

    @_add_data_ref
    def construct_include_dir_list(self, node):
        """Load every ``*.yaml`` file under a directory into a list."""
        files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml"))
        return [_load_yaml_internal(f) for f in files]

    @_add_data_ref
    def construct_include_dir_merge_list(self, node):
        """Concatenate the lists found in every ``*.yaml`` file under a directory.

        Files whose top-level content is not a list are skipped.
        """
        files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml"))
        merged_list = []
        for fname in files:
            loaded_yaml = _load_yaml_internal(fname)
            if isinstance(loaded_yaml, list):
                merged_list.extend(loaded_yaml)
        return merged_list

    @_add_data_ref
    def construct_include_dir_named(self, node):
        """Load every ``*.yaml`` file under a directory, keyed by file stem."""
        files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml"))
        mapping = OrderedDict()
        for fname in files:
            filename = os.path.splitext(os.path.basename(fname))[0]
            mapping[filename] = _load_yaml_internal(fname)
        return mapping

    @_add_data_ref
    def construct_include_dir_merge_named(self, node):
        """Merge the mappings found in every ``*.yaml`` file under a directory.

        Files whose top-level content is not a mapping are skipped; later
        files override keys of earlier ones.
        """
        files = filter_yaml_files(_find_files(self._rel_path(node.value), "*.yaml"))
        mapping = OrderedDict()
        for fname in files:
            loaded_yaml = _load_yaml_internal(fname)
            if isinstance(loaded_yaml, dict):
                mapping.update(loaded_yaml)
        return mapping

    @_add_data_ref
    def construct_lambda(self, node):
        """Wrap the node value in a ``Lambda``."""
        return Lambda(str(node.value))

    @_add_data_ref
    def construct_force(self, node):
        """Construct the scalar and mark it with ``ESPForceValue``."""
        obj = self.construct_scalar(node)
        return add_class_to_obj(obj, ESPForceValue)

    @_add_data_ref
    def construct_extend(self, node):
        """Wrap the node value in an ``Extend``."""
        return Extend(str(node.value))

    @_add_data_ref
    def construct_remove(self, node):
        """Wrap the node value in a ``Remove``."""
        return Remove(str(node.value))
class ESPHomeLoader(ESPHomeLoaderMixin, FastestAvailableSafeLoader):
    """Loader class that keeps track of line numbers.

    Built on ``FastestAvailableSafeLoader``: PyYAML's ``CSafeLoader`` when
    it can be imported, otherwise the pure Python ``SafeLoader``.
    """
class ESPHomePurePythonLoader(ESPHomeLoaderMixin, PurePythonLoader):
    """Loader class that keeps track of line numbers.

    Always built on PyYAML's pure Python ``SafeLoader``; parse_yaml() falls
    back to it because its exceptions are more readable.
    """
# Register every constructor on both loader flavours, in a fixed order.
for _loader in (ESPHomeLoader, ESPHomePurePythonLoader):
    for _tag, _constructor_name in (
        ("tag:yaml.org,2002:int", "construct_yaml_int"),
        ("tag:yaml.org,2002:float", "construct_yaml_float"),
        ("tag:yaml.org,2002:binary", "construct_yaml_binary"),
        ("tag:yaml.org,2002:omap", "construct_yaml_omap"),
        ("tag:yaml.org,2002:str", "construct_yaml_str"),
        ("tag:yaml.org,2002:seq", "construct_yaml_seq"),
        ("tag:yaml.org,2002:map", "construct_yaml_map"),
        ("!env_var", "construct_env_var"),
        ("!secret", "construct_secret"),
        ("!include", "construct_include"),
        ("!include_dir_list", "construct_include_dir_list"),
        ("!include_dir_merge_list", "construct_include_dir_merge_list"),
        ("!include_dir_named", "construct_include_dir_named"),
        ("!include_dir_merge_named", "construct_include_dir_merge_named"),
        ("!lambda", "construct_lambda"),
        ("!force", "construct_force"),
        ("!extend", "construct_extend"),
        ("!remove", "construct_remove"),
    ):
        _loader.add_constructor(_tag, getattr(_loader, _constructor_name))
def load_yaml(fname: str, clear_secrets: bool = True) -> Any:
    """Load the YAML file ``fname`` and return its content.

    Unless ``clear_secrets`` is false, the module-level secret registries
    are emptied before loading.
    """
    if clear_secrets:
        for registry in (_SECRET_VALUES, _SECRET_CACHE):
            registry.clear()
    return _load_yaml_internal(fname)
def parse_yaml(file_name: str, file_handle: TextIOWrapper) -> Any:
    """Parse YAML from an open file handle.

    ``ESPHomeLoader`` is tried first. If it fails with an ``EsphomeError``
    the stream is rewound and parsed again with ``ESPHomePurePythonLoader``.
    """
    try:
        parsed = _load_yaml_internal_with_type(ESPHomeLoader, file_name, file_handle)
    except EsphomeError:
        # The Python loader produces more readable exceptions, so run it
        # over the same input, starting again from the beginning.
        file_handle.seek(0, 0)
        parsed = _load_yaml_internal_with_type(
            ESPHomePurePythonLoader, file_name, file_handle
        )
    return parsed
def _load_yaml_internal(fname: str) -> Any:
    """Open ``fname`` as UTF-8 text and parse it as YAML.

    Raises:
        EsphomeError: if the file cannot be read or decoded.
    """
    try:
        with open(fname, encoding="utf-8") as stream:
            parsed = parse_yaml(fname, stream)
    except (OSError, UnicodeDecodeError) as err:
        raise EsphomeError(f"Error reading file {fname}: {err}") from err
    return parsed
def _load_yaml_internal_with_type(
loader_type: type[ESPHomeLoader] | type[ESPHomePurePythonLoader],
fname: str,
content: TextIOWrapper,
) -> Any:
"""Load a YAML file."""
loader = loader_type(content)
loader.name = fname
try:
return loader.get_single_data() or OrderedDict()
except yaml.YAMLError as exc:
raise EsphomeError(exc) from exc
finally:
loader.dispose()
def dump(dict_, show_secrets=False):
    """Serialise ``dict_`` to a YAML string using ``ESPHomeDumper``.

    With ``show_secrets`` set, the secret registries are emptied first, so
    the dumper no longer has any values to replace with ``!secret`` tags.
    """
    if show_secrets:
        for registry in (_SECRET_VALUES, _SECRET_CACHE):
            registry.clear()
    return yaml.dump(
        dict_, Dumper=ESPHomeDumper, default_flow_style=False, allow_unicode=True
    )
def _is_file_valid(name):
"""Decide if a file is valid."""
return not name.startswith(".")
def _find_files(directory, pattern):
"""Recursively load files in a directory."""
for root, dirs, files in os.walk(directory, topdown=True):
dirs[:] = [d for d in dirs if _is_file_valid(d)]
for basename in files:
if _is_file_valid(basename) and fnmatch.fnmatch(basename, pattern):
filename = os.path.join(root, basename)
yield filename
def is_secret(value):
    """Return the secret name recorded for ``value``, or ``None``.

    The lookup key is ``str(value)``; a ``KeyError`` or ``ValueError``
    raised during the lookup also results in ``None``.
    """
    try:
        return _SECRET_VALUES.get(str(value))
    except (KeyError, ValueError):
        return None
class ESPHomeDumper(yaml.SafeDumper):
    """YAML dumper that keeps mapping order and re-emits secrets as ``!secret``."""

    def represent_mapping(self, tag, mapping, flow_style=None):
        """Build a mapping node whose entries keep the order of ``mapping``."""
        pairs = []
        node = yaml.MappingNode(tag, pairs, flow_style=flow_style)
        if self.alias_key is not None:
            self.represented_objects[self.alias_key] = node

        entries = list(mapping.items()) if hasattr(mapping, "items") else mapping
        # Stays true only while every key and value is an unstyled scalar.
        all_plain = True
        for raw_key, raw_value in entries:
            key_node = self.represent_data(raw_key)
            value_node = self.represent_data(raw_value)
            for child in (key_node, value_node):
                if not isinstance(child, yaml.ScalarNode) or child.style:
                    all_plain = False
            pairs.append((key_node, value_node))

        if flow_style is None:
            if self.default_flow_style is None:
                node.flow_style = all_plain
            else:
                node.flow_style = self.default_flow_style
        return node

    def represent_secret(self, value):
        """Emit ``!secret <name>`` for a value recorded in ``_SECRET_VALUES``."""
        secret_name = _SECRET_VALUES[str(value)]
        return self.represent_scalar(tag="!secret", value=secret_name)

    def represent_stringify(self, value):
        """Emit ``str(value)`` as a string scalar, or a ``!secret`` reference."""
        if not is_secret(value):
            return self.represent_scalar(tag="tag:yaml.org,2002:str", value=str(value))
        return self.represent_secret(value)

    # pylint: disable=arguments-renamed
    def represent_bool(self, value):
        """Emit a boolean as lowercase ``true`` / ``false``."""
        if value:
            text = "true"
        else:
            text = "false"
        return self.represent_scalar("tag:yaml.org,2002:bool", text)

    # pylint: disable=arguments-renamed
    def represent_int(self, value):
        """Emit an int scalar, or a ``!secret`` reference."""
        if not is_secret(value):
            return self.represent_scalar(tag="tag:yaml.org,2002:int", value=str(value))
        return self.represent_secret(value)

    # pylint: disable=arguments-renamed
    def represent_float(self, value):
        """Emit a float scalar, or a ``!secret`` reference."""
        if is_secret(value):
            return self.represent_secret(value)

        if math.isnan(value):
            text = ".nan"
        elif math.isinf(value):
            text = "-.inf" if value <= 0 else ".inf"
        else:
            text = repr(value).lower()
            # repr() can give a float without a decimal part, e.g.
            # repr(1e17) == '1e17'. That is not a valid representation for
            # the `!!float` tag, so insert '.0' before the 'e' symbol.
            if "e" in text and "." not in text:
                text = text.replace("e", ".0e", 1)
        return self.represent_scalar(tag="tag:yaml.org,2002:float", value=text)

    def represent_lambda(self, value):
        """Emit a lambda body as a ``!lambda`` block scalar, or a ``!secret``."""
        body = value.value
        if is_secret(body):
            return self.represent_secret(body)
        return self.represent_scalar(tag="!lambda", value=body, style="|")

    def represent_id(self, value):
        """Emit the ``id`` attribute of ``value`` as a string, or a ``!secret``."""
        identifier = value.id
        if is_secret(identifier):
            return self.represent_secret(identifier)
        return self.represent_stringify(identifier)
# Register the representers on ESPHomeDumper, in a fixed order.
for _data_type, _representer in (
    (
        dict,
        lambda dumper, value: dumper.represent_mapping("tag:yaml.org,2002:map", value),
    ),
    (
        list,
        lambda dumper, value: dumper.represent_sequence("tag:yaml.org,2002:seq", value),
    ),
    (bool, ESPHomeDumper.represent_bool),
    (str, ESPHomeDumper.represent_stringify),
    (int, ESPHomeDumper.represent_int),
    (float, ESPHomeDumper.represent_float),
    (IPAddress, ESPHomeDumper.represent_stringify),
    (MACAddress, ESPHomeDumper.represent_stringify),
    (TimePeriod, ESPHomeDumper.represent_stringify),
    (Lambda, ESPHomeDumper.represent_lambda),
    (core.ID, ESPHomeDumper.represent_id),
    (uuid.UUID, ESPHomeDumper.represent_stringify),
):
    ESPHomeDumper.add_multi_representer(_data_type, _representer)