From 7c243dafb35decec2869f6eed38b7f9ef64611c7 Mon Sep 17 00:00:00 2001 From: Jesse Hills <3060199+jesserockz@users.noreply.github.com> Date: Thu, 16 May 2024 14:11:54 +1200 Subject: [PATCH] [core] Fix some extends cases (#6748) --- esphome/components/substitutions/__init__.py | 4 +-- esphome/config_helpers.py | 27 +++++++++++++++++--- 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/esphome/components/substitutions/__init__.py b/esphome/components/substitutions/__init__.py index 2d3a79ccae..fa52200d46 100644 --- a/esphome/components/substitutions/__init__.py +++ b/esphome/components/substitutions/__init__.py @@ -4,7 +4,7 @@ import esphome.config_validation as cv from esphome import core from esphome.const import CONF_SUBSTITUTIONS, VALID_SUBSTITUTIONS_CHARACTERS from esphome.yaml_util import ESPHomeDataBase, make_data_base -from esphome.config_helpers import merge_config +from esphome.config_helpers import merge_config, Extend, Remove CODEOWNERS = ["@esphome/core"] _LOGGER = logging.getLogger(__name__) @@ -105,7 +105,7 @@ def _substitute_item(substitutions, item, path, ignore_missing): sub = _expand_substitutions(substitutions, item, path, ignore_missing) if sub != item: return sub - elif isinstance(item, core.Lambda): + elif isinstance(item, (core.Lambda, Extend, Remove)): sub = _expand_substitutions(substitutions, item.value, path, ignore_missing) if sub != item: item.value = sub diff --git a/esphome/config_helpers.py b/esphome/config_helpers.py index 7b47e097c8..b5e0b26143 100644 --- a/esphome/config_helpers.py +++ b/esphome/config_helpers.py @@ -8,6 +8,9 @@ class Extend: def __str__(self): return f"!extend {self.value}" + def __repr__(self): + return f"Extend({self.value})" + def __eq__(self, b): """ Check if two Extend objects contain the same ID. @@ -24,6 +27,9 @@ class Remove: def __str__(self): return f"!remove {self.value}" + def __repr__(self): + return f"Remove({self.value})" + def __eq__(self, b): """ Check if two Remove objects contain the same ID. @@ -50,14 +56,19 @@ def merge_config(full_old, full_new): return new res = old.copy() ids = { - v[CONF_ID]: i + v_id: i for i, v in enumerate(res) - if CONF_ID in v and isinstance(v[CONF_ID], str) + if (v_id := v.get(CONF_ID)) and isinstance(v_id, str) } + extend_ids = { + v_id.value: i + for i, v in enumerate(res) + if (v_id := v.get(CONF_ID)) and isinstance(v_id, Extend) + } + ids_to_delete = [] for v in new: - if CONF_ID in v: - new_id = v[CONF_ID] + if new_id := v.get(CONF_ID): if isinstance(new_id, Extend): new_id = new_id.value if new_id in ids: @@ -69,6 +80,14 @@ def merge_config(full_old, full_new): if new_id in ids: ids_to_delete.append(ids[new_id]) continue + elif ( + new_id in extend_ids + ): # When a package is extending a non-packaged item + extend_res = res[extend_ids[new_id]] + extend_res[CONF_ID] = new_id + new_v = merge(v, extend_res) + res[extend_ids[new_id]] = new_v + continue else: ids[new_id] = len(res) res.append(v)