diff --git a/esphome/__main__.py b/esphome/__main__.py index 1ec72d9255..b78962c2c0 100644 --- a/esphome/__main__.py +++ b/esphome/__main__.py @@ -18,7 +18,7 @@ from esphome.const import ( CONF_ESPHOME, CONF_PLATFORMIO_OPTIONS, ) -from esphome.core import CORE, EsphomeError, coroutine, coroutine_with_priority +from esphome.core import CORE, EsphomeError, coroutine from esphome.helpers import indent from esphome.util import ( run_external_command, @@ -127,15 +127,16 @@ def wrap_to_code(name, comp): coro = coroutine(comp.to_code) @functools.wraps(comp.to_code) - @coroutine_with_priority(coro.priority) - def wrapped(conf): + async def wrapped(conf): cg.add(cg.LineComment(f"{name}:")) if comp.config_schema is not None: conf_str = yaml_util.dump(conf) conf_str = conf_str.replace("//", "") cg.add(cg.LineComment(indent(conf_str))) - yield coro(conf) + await coro(conf) + if hasattr(coro, "priority"): + wrapped.priority = coro.priority return wrapped @@ -610,7 +611,7 @@ def run_esphome(argv): try: return PRE_CONFIG_ACTIONS[args.command](args) except EsphomeError as e: - _LOGGER.error(e) + _LOGGER.error(e, exc_info=args.verbose) return 1 for conf_path in args.configuration: @@ -628,7 +629,7 @@ def run_esphome(argv): try: rc = POST_CONFIG_ACTIONS[args.command](args, config) except EsphomeError as e: - _LOGGER.error(e) + _LOGGER.error(e, exc_info=args.verbose) return 1 if rc != 0: return rc diff --git a/esphome/automation.py b/esphome/automation.py index eb6cb02532..1cf6bbf542 100644 --- a/esphome/automation.py +++ b/esphome/automation.py @@ -10,7 +10,6 @@ from esphome.const import ( CONF_TYPE_ID, CONF_TIME, ) -from esphome.core import coroutine from esphome.jsonschema import jschema_extractor from esphome.util import Registry @@ -142,27 +141,27 @@ NotCondition = cg.esphome_ns.class_("NotCondition", Condition) @register_condition("and", AndCondition, validate_condition_list) -def and_condition_to_code(config, condition_id, template_arg, args): - conditions = yield build_condition_list(config, template_arg, args) - yield cg.new_Pvariable(condition_id, template_arg, conditions) +async def and_condition_to_code(config, condition_id, template_arg, args): + conditions = await build_condition_list(config, template_arg, args) + return cg.new_Pvariable(condition_id, template_arg, conditions) @register_condition("or", OrCondition, validate_condition_list) -def or_condition_to_code(config, condition_id, template_arg, args): - conditions = yield build_condition_list(config, template_arg, args) - yield cg.new_Pvariable(condition_id, template_arg, conditions) +async def or_condition_to_code(config, condition_id, template_arg, args): + conditions = await build_condition_list(config, template_arg, args) + return cg.new_Pvariable(condition_id, template_arg, conditions) @register_condition("not", NotCondition, validate_potentially_and_condition) -def not_condition_to_code(config, condition_id, template_arg, args): - condition = yield build_condition(config, template_arg, args) - yield cg.new_Pvariable(condition_id, template_arg, condition) +async def not_condition_to_code(config, condition_id, template_arg, args): + condition = await build_condition(config, template_arg, args) + return cg.new_Pvariable(condition_id, template_arg, condition) @register_condition("lambda", LambdaCondition, cv.lambda_) -def lambda_condition_to_code(config, condition_id, template_arg, args): - lambda_ = yield cg.process_lambda(config, args, return_type=bool) - yield cg.new_Pvariable(condition_id, template_arg, lambda_) +async def lambda_condition_to_code(config, condition_id, template_arg, args): + lambda_ = await cg.process_lambda(config, args, return_type=bool) + return cg.new_Pvariable(condition_id, template_arg, lambda_) @register_condition( @@ -177,26 +176,26 @@ def lambda_condition_to_code(config, condition_id, template_arg, args): } ).extend(cv.COMPONENT_SCHEMA), ) -def for_condition_to_code(config, condition_id, template_arg, args): - condition = yield build_condition( +async def for_condition_to_code(config, condition_id, template_arg, args): + condition = await build_condition( config[CONF_CONDITION], cg.TemplateArguments(), [] ) var = cg.new_Pvariable(condition_id, template_arg, condition) - yield cg.register_component(var, config) - templ = yield cg.templatable(config[CONF_TIME], args, cg.uint32) + await cg.register_component(var, config) + templ = await cg.templatable(config[CONF_TIME], args, cg.uint32) cg.add(var.set_time(templ)) - yield var + return var @register_action( "delay", DelayAction, cv.templatable(cv.positive_time_period_milliseconds) ) -def delay_action_to_code(config, action_id, template_arg, args): +async def delay_action_to_code(config, action_id, template_arg, args): var = cg.new_Pvariable(action_id, template_arg) - yield cg.register_component(var, {}) - template_ = yield cg.templatable(config, args, cg.uint32) + await cg.register_component(var, {}) + template_ = await cg.templatable(config, args, cg.uint32) cg.add(var.set_delay(template_)) - yield var + return var @register_action( @@ -211,16 +210,16 @@ def delay_action_to_code(config, action_id, template_arg, args): cv.has_at_least_one_key(CONF_THEN, CONF_ELSE), ), ) -def if_action_to_code(config, action_id, template_arg, args): - conditions = yield build_condition(config[CONF_CONDITION], template_arg, args) +async def if_action_to_code(config, action_id, template_arg, args): + conditions = await build_condition(config[CONF_CONDITION], template_arg, args) var = cg.new_Pvariable(action_id, template_arg, conditions) if CONF_THEN in config: - actions = yield build_action_list(config[CONF_THEN], template_arg, args) + actions = await build_action_list(config[CONF_THEN], template_arg, args) cg.add(var.add_then(actions)) if CONF_ELSE in config: - actions = yield build_action_list(config[CONF_ELSE], template_arg, args) + actions = await build_action_list(config[CONF_ELSE], template_arg, args) cg.add(var.add_else(actions)) - yield var + return var @register_action( @@ -233,12 +232,12 @@ def if_action_to_code(config, action_id, template_arg, args): } ), ) -def while_action_to_code(config, action_id, template_arg, args): - conditions = yield build_condition(config[CONF_CONDITION], template_arg, args) +async def while_action_to_code(config, action_id, template_arg, args): + conditions = await build_condition(config[CONF_CONDITION], template_arg, args) var = cg.new_Pvariable(action_id, template_arg, conditions) - actions = yield build_action_list(config[CONF_THEN], template_arg, args) + actions = await build_action_list(config[CONF_THEN], template_arg, args) cg.add(var.add_then(actions)) - yield var + return var def validate_wait_until(value): @@ -253,17 +252,17 @@ def validate_wait_until(value): @register_action("wait_until", WaitUntilAction, validate_wait_until) -def wait_until_action_to_code(config, action_id, template_arg, args): - conditions = yield build_condition(config[CONF_CONDITION], template_arg, args) +async def wait_until_action_to_code(config, action_id, template_arg, args): + conditions = await build_condition(config[CONF_CONDITION], template_arg, args) var = cg.new_Pvariable(action_id, template_arg, conditions) - yield cg.register_component(var, {}) - yield var + await cg.register_component(var, {}) + return var @register_action("lambda", LambdaAction, cv.lambda_) -def lambda_action_to_code(config, action_id, template_arg, args): - lambda_ = yield cg.process_lambda(config, args, return_type=cg.void) - yield cg.new_Pvariable(action_id, template_arg, lambda_) +async def lambda_action_to_code(config, action_id, template_arg, args): + lambda_ = await cg.process_lambda(config, args, return_type=cg.void) + return cg.new_Pvariable(action_id, template_arg, lambda_) @register_action( @@ -275,54 +274,51 @@ def lambda_action_to_code(config, action_id, template_arg, args): } ), ) -def component_update_action_to_code(config, action_id, template_arg, args): - comp = yield cg.get_variable(config[CONF_ID]) - yield cg.new_Pvariable(action_id, template_arg, comp) +async def component_update_action_to_code(config, action_id, template_arg, args): + comp = await cg.get_variable(config[CONF_ID]) + return cg.new_Pvariable(action_id, template_arg, comp) -@coroutine -def build_action(full_config, template_arg, args): +async def build_action(full_config, template_arg, args): registry_entry, config = cg.extract_registry_entry_config( ACTION_REGISTRY, full_config ) action_id = full_config[CONF_TYPE_ID] builder = registry_entry.coroutine_fun - yield builder(config, action_id, template_arg, args) + ret = await builder(config, action_id, template_arg, args) + return ret -@coroutine -def build_action_list(config, templ, arg_type): +async def build_action_list(config, templ, arg_type): actions = [] for conf in config: - action = yield build_action(conf, templ, arg_type) + action = await build_action(conf, templ, arg_type) actions.append(action) - yield actions + return actions -@coroutine -def build_condition(full_config, template_arg, args): +async def build_condition(full_config, template_arg, args): registry_entry, config = cg.extract_registry_entry_config( CONDITION_REGISTRY, full_config ) action_id = full_config[CONF_TYPE_ID] builder = registry_entry.coroutine_fun - yield builder(config, action_id, template_arg, args) + ret = await builder(config, action_id, template_arg, args) + return ret -@coroutine -def build_condition_list(config, templ, args): +async def build_condition_list(config, templ, args): conditions = [] for conf in config: - condition = yield build_condition(conf, templ, args) + condition = await build_condition(conf, templ, args) conditions.append(condition) - yield conditions + return conditions -@coroutine -def build_automation(trigger, args, config): +async def build_automation(trigger, args, config): arg_types = [arg[0] for arg in args] templ = cg.TemplateArguments(*arg_types) obj = cg.new_Pvariable(config[CONF_AUTOMATION_ID], templ, trigger) - actions = yield build_action_list(config[CONF_THEN], templ, args) + actions = await build_action_list(config[CONF_THEN], templ, args) cg.add(obj.add_actions(actions)) - yield obj + return obj diff --git a/esphome/components/ota/__init__.py b/esphome/components/ota/__init__.py index 25a278f5bf..7ee7ef47ca 100644 --- a/esphome/components/ota/__init__.py +++ b/esphome/components/ota/__init__.py @@ -32,12 +32,12 @@ CONFIG_SCHEMA = cv.Schema( @coroutine_with_priority(50.0) -def to_code(config): +async def to_code(config): var = cg.new_Pvariable(config[CONF_ID]) cg.add(var.set_port(config[CONF_PORT])) cg.add(var.set_auth_password(config[CONF_PASSWORD])) - yield cg.register_component(var, config) + await cg.register_component(var, config) if config[CONF_SAFE_MODE]: condition = var.should_enter_safe_mode( diff --git a/esphome/core/__init__.py b/esphome/core/__init__.py index 47048478ef..1841dfd8be 100644 --- a/esphome/core/__init__.py +++ b/esphome/core/__init__.py @@ -1,24 +1,23 @@ -import functools -import heapq -import inspect import logging - import math import os import re - -# pylint: disable=unused-import, wrong-import-order -from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING # noqa +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple from esphome.const import ( CONF_ARDUINO_VERSION, - SOURCE_FILE_EXTENSIONS, CONF_COMMENT, CONF_ESPHOME, CONF_USE_ADDRESS, CONF_ETHERNET, CONF_WIFI, + SOURCE_FILE_EXTENSIONS, ) +from esphome.coroutine import FakeAwaitable as _FakeAwaitable +from esphome.coroutine import FakeEventLoop as _FakeEventLoop + +# pylint: disable=unused-import +from esphome.coroutine import coroutine, coroutine_with_priority # noqa from esphome.helpers import ensure_unique_string, is_hassio from esphome.util import OrderedDict @@ -431,64 +430,6 @@ class Library: return NotImplemented -def coroutine(func): - return coroutine_with_priority(0.0)(func) - - -def coroutine_with_priority(priority): - def decorator(func): - if getattr(func, "_esphome_coroutine", False): - # If func is already a coroutine, do not re-wrap it (performance) - return func - - @functools.wraps(func) - def _wrapper_generator(*args, **kwargs): - instance_id = kwargs.pop("__esphome_coroutine_instance__") - if not inspect.isgeneratorfunction(func): - # If func is not a generator, return result immediately - yield func(*args, **kwargs) - # pylint: disable=protected-access - CORE._remove_coroutine(instance_id) - return - gen = func(*args, **kwargs) - var = None - try: - while True: - var = gen.send(var) - if inspect.isgenerator(var): - # Yielded generator, equivalent to 'yield from' - x = None - for x in var: - yield None - # Last yield value is the result - var = x - else: - yield var - except StopIteration: - # Stopping iteration - yield var - # pylint: disable=protected-access - CORE._remove_coroutine(instance_id) - - @functools.wraps(func) - def wrapper(*args, **kwargs): - import random - - instance_id = random.randint(0, 2 ** 32) - kwargs["__esphome_coroutine_instance__"] = instance_id - gen = _wrapper_generator(*args, **kwargs) - # pylint: disable=protected-access - CORE._add_active_coroutine(instance_id, gen) - return gen - - # pylint: disable=protected-access - wrapper._esphome_coroutine = True - wrapper.priority = priority - return wrapper - - return decorator - - def find_source_files(file): files = set() directory = os.path.abspath(os.path.dirname(file)) @@ -527,7 +468,7 @@ class EsphomeCore: # The pending tasks in the task queue (mostly for C++ generation) # This is a priority queue (with heapq) # Each item is a tuple of form: (-priority, unique number, task) - self.pending_tasks = [] + self.event_loop = _FakeEventLoop() # Task counter for pending tasks self.task_counter = 0 # The variable cache, for each ID this holds a MockObj of the variable obj @@ -542,9 +483,6 @@ class EsphomeCore: self.build_flags: Set[str] = set() # A set of defines to set for the compile process in esphome/core/defines.h self.defines: Set["Define"] = set() - # A dictionary of started coroutines, used to warn when a coroutine was not - # awaited. - self.active_coroutines: Dict[int, Any] = {} # A set of strings of names of loaded integrations, used to find namespace ID conflicts self.loaded_integrations = set() # A set of component IDs to track what Component subclasses are declared @@ -561,7 +499,7 @@ class EsphomeCore: self.board = None self.raw_config = None self.config = None - self.pending_tasks = [] + self.event_loop = _FakeEventLoop() self.task_counter = 0 self.variables = {} self.main_statements = [] @@ -569,7 +507,6 @@ class EsphomeCore: self.libraries = [] self.build_flags = set() self.defines = set() - self.active_coroutines = {} self.loaded_integrations = set() self.component_ids = set() @@ -596,12 +533,6 @@ class EsphomeCore: return None - def _add_active_coroutine(self, instance_id, obj): - self.active_coroutines[instance_id] = obj - - def _remove_coroutine(self, instance_id): - self.active_coroutines.pop(instance_id) - @property def arduino_version(self) -> str: if self.config is None: @@ -657,50 +588,13 @@ class EsphomeCore: return self.esp_platform == "ESP32" def add_job(self, func, *args, **kwargs): - coro = coroutine(func) - task = coro(*args, **kwargs) - item = (-coro.priority, self.task_counter, task) - self.task_counter += 1 - heapq.heappush(self.pending_tasks, item) - return task + self.event_loop.add_job(func, *args, **kwargs) def flush_tasks(self): - i = 0 - while self.pending_tasks: - i += 1 - if i > 1000000: - raise EsphomeError("Circular dependency detected!") - - inv_priority, num, task = heapq.heappop(self.pending_tasks) - priority = -inv_priority - _LOGGER.debug("Running %s (num %s)", task, num) - try: - next(task) - # Decrease priority over time, so that if this task is blocked - # due to a dependency others will clear the dependency - # This could be improved with a less naive approach - priority -= 1 - item = (-priority, num, task) - heapq.heappush(self.pending_tasks, item) - except StopIteration: - _LOGGER.debug(" -> finished") - - # Print not-awaited coroutines - for obj in self.active_coroutines.values(): - _LOGGER.warning( - "Coroutine '%s' %s was never awaited with 'yield'.", obj.__name__, obj - ) - _LOGGER.warning("Please file a bug report with your configuration.") - if self.active_coroutines: - raise EsphomeError() - if self.component_ids: - comps = ", ".join(f"'{x}'" for x in self.component_ids) - _LOGGER.warning( - "Components %s were never registered. Please create a bug report", comps - ) - _LOGGER.warning("with your configuration.") - raise EsphomeError() - self.active_coroutines.clear() + try: + self.event_loop.flush_tasks() + except RuntimeError as e: + raise EsphomeError(str(e)) from e def add(self, expression): from esphome.cpp_generator import Expression, Statement, statement @@ -779,25 +673,35 @@ class EsphomeCore: _LOGGER.debug("Adding define: %s", define) return define - def get_variable(self, id): + def _get_variable_generator(self, id): + while True: + try: + return self.variables[id] + except KeyError: + _LOGGER.debug("Waiting for variable %s (%r)", id, id) + yield + + async def get_variable(self, id) -> "MockObj": if not isinstance(id, ID): raise ValueError(f"ID {id!r} must be of type ID!") - while True: - if id in self.variables: - yield self.variables[id] - return - _LOGGER.debug("Waiting for variable %s (%r)", id, id) - yield None + # Fast path, check if already registered without awaiting + if id in self.variables: + return self.variables[id] + return await _FakeAwaitable(self._get_variable_generator(id)) - def get_variable_with_full_id(self, id): + def _get_variable_with_full_id_generator(self, id): while True: if id in self.variables: for k, v in self.variables.items(): if k == id: - yield (k, v) - return + return (k, v) _LOGGER.debug("Waiting for variable %s", id) - yield None, None + yield + + async def get_variable_with_full_id(self, id: ID) -> Tuple[ID, "MockObj"]: + if not isinstance(id, ID): + raise ValueError(f"ID {id!r} must be of type ID!") + return await _FakeAwaitable(self._get_variable_with_full_id_generator(id)) def register_variable(self, id, obj): if id in self.variables: diff --git a/esphome/core/config.py b/esphome/core/config.py index 5893f086f2..6bd8c6be0e 100644 --- a/esphome/core/config.py +++ b/esphome/core/config.py @@ -233,7 +233,7 @@ def include_file(path, basename): @coroutine_with_priority(-1000.0) -def add_includes(includes): +async def add_includes(includes): # Add includes at the very end, so that the included files can access global variables for include in includes: path = CORE.relative_config_path(include) @@ -249,7 +249,7 @@ def add_includes(includes): @coroutine_with_priority(-1000.0) -def _esp8266_add_lwip_type(): +async def _esp8266_add_lwip_type(): # If any component has already set this, do not change it if any( flag.startswith("-DPIO_FRAMEWORK_ARDUINO_LWIP2_") for flag in CORE.build_flags @@ -271,25 +271,25 @@ def _esp8266_add_lwip_type(): @coroutine_with_priority(30.0) -def _add_automations(config): +async def _add_automations(config): for conf in config.get(CONF_ON_BOOT, []): trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], conf.get(CONF_PRIORITY)) - yield cg.register_component(trigger, conf) - yield automation.build_automation(trigger, [], conf) + await cg.register_component(trigger, conf) + await automation.build_automation(trigger, [], conf) for conf in config.get(CONF_ON_SHUTDOWN, []): trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID]) - yield cg.register_component(trigger, conf) - yield automation.build_automation(trigger, [], conf) + await cg.register_component(trigger, conf) + await automation.build_automation(trigger, [], conf) for conf in config.get(CONF_ON_LOOP, []): trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID]) - yield cg.register_component(trigger, conf) - yield automation.build_automation(trigger, [], conf) + await cg.register_component(trigger, conf) + await automation.build_automation(trigger, [], conf) @coroutine_with_priority(100.0) -def to_code(config): +async def to_code(config): cg.add_global(cg.global_ns.namespace("esphome").using) cg.add( cg.App.pre_setup( diff --git a/esphome/coroutine.py b/esphome/coroutine.py new file mode 100644 index 0000000000..58f79c6b36 --- /dev/null +++ b/esphome/coroutine.py @@ -0,0 +1,252 @@ +""" +ESPHome's coroutine system. + +The Problem: When running the code generationg, components can depend on variables being registered. +For example, an i2c-based sensor would need the i2c bus component to first be declared before the +codegen can emit code using that variable (or otherwise the C++ won't compile). + +ESPHome's codegen system solves this by using coroutine-like methods. When a component depends on +a variable, it waits for it to be registered using `await cg.get_variable()`. If the variable +hasn't been registered yet, control will be yielded back to another component until the variable +is registered. This leads to a topological sort, solving the dependency problem. + +Importantly, ESPHome only uses the coroutine *syntax*, no actual asyncio event loop is running in +the background. This is so that we can ensure the order of execution is constant for the same +YAML configuration, thus main.cpp only has to be recompiled if the configuration actually changes. + +There are two syntaxes for ESPHome coroutines ("old style" vs "new style" coroutines). + +"new style" - This is very much like coroutines you might be used to: + +```py +async def my_coroutine(config): + var = await cg.get_variable(config[CONF_ID]) + await some_other_coroutine(xyz) + return var +``` + +new style coroutines are `async def` methods that use `await` to await the result of another coroutine, +and can return values using a `return` statement. + +"old style" - This was a hack for when ESPHome still had to run on python 2, but is still compatible + +```py +@coroutine +def my_coroutine(config): + var = yield cg.get_variable(config[CONF_ID]) + yield some_other_coroutine(xyz) + yield var +``` + +Here everything is combined in `yield` expressions. You await other coroutines using `yield` and +the last `yield` expression defines what is returned. +""" + +import collections +import functools +import heapq +import inspect +import logging +import types +from typing import Any, Awaitable, Callable, Generator, Iterator, List, Tuple + +_LOGGER = logging.getLogger(__name__) + + +def coroutine(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]: + """Decorator to apply to methods to convert them to ESPHome coroutines.""" + if getattr(func, "_esphome_coroutine", False): + # If func is already a coroutine, do not re-wrap it (performance) + return func + if inspect.isasyncgenfunction(func): + # Trade-off: In ESPHome, there's not really a use-case for async generators. + # and during the transition to new-style syntax it will happen that a `yield` + # is not replaced properly, so don't accept async generators. + raise ValueError( + f"Async generator functions are not allowed. " + f"Please check whether you've replaced all yields with awaits/returns. " + f"See {func} in {func.__module__}" + ) + if inspect.iscoroutinefunction(func): + # A new-style async-def coroutine function, no conversion needed. + return func + + if inspect.isgeneratorfunction(func): + + @functools.wraps(func) + def coro(*args, **kwargs): + gen = func(*args, **kwargs) + ret = yield from _flatten_generator(gen) + return ret + + else: + # A "normal" function with no `yield` statements, convert to generator + # that includes a yield just so it's also a generator function + @functools.wraps(func) + def coro(*args, **kwargs): + res = func(*args, **kwargs) + yield + return res + + # Add coroutine internal python flag so that it can be awaited from new-style coroutines. + coro = types.coroutine(coro) + # pylint: disable=protected-access + coro._esphome_coroutine = True + return coro + + +def coroutine_with_priority(priority: float): + """Decorator to apply to functions to convert them to ESPHome coroutines. + + :param priority: priority with which to schedule the coroutine, higher priorities run first. + """ + + def decorator(func): + coro = coroutine(func) + coro.priority = priority + return coro + + return decorator + + +def _flatten_generator(gen: Generator[Any, Any, Any]): + to_send = None + while True: + try: + # Run until next yield expression + val = gen.send(to_send) + except StopIteration as e: + # return statement or end of function + + # From py3.3, return with a value is allowed in generators, + # and return value is transported in the value field of the exception. + # If we find a value in the exception, use that as the return value, + # otherwise use the value from the last yield statement ("old style") + ret = to_send if e.value is None else e.value + return ret + + if isinstance(val, collections.abc.Awaitable): + # yielded object that is awaitable (like `yield some_new_style_method()`) + # yield from __await__() like actual coroutines would. + to_send = yield from val.__await__() + elif inspect.isgenerator(val): + # Old style, like `yield cg.get_variable()` + to_send = yield from _flatten_generator(val) + else: + # Could be the last expression from this generator, record this as the return value + to_send = val + # perform a yield so that expressions like `while some_condition(): yield None` + # do not run without yielding control back to the top + yield + + +class FakeAwaitable: + """Convert a generator to an awaitable object. + + Needed for internals of `cg.get_variable`. There we can't use @coroutine because + native coroutines await from types.coroutine() directly without yielding back control to the top + (likely as a performance enhancement). + + If we instead wrap the generator in this FakeAwaitable, control is yielded back to the top + (reason unknown). + """ + + def __init__(self, gen: Generator[Any, Any, Any]) -> None: + self._gen = gen + + def __await__(self): + ret = yield from self._gen + return ret + + +@functools.total_ordering +class _Task: + def __init__( + self, + priority: float, + id_number: int, + iterator: Iterator[None], + original_function: Any, + ): + self.priority = priority + self.id_number = id_number + self.iterator = iterator + self.original_function = original_function + + def with_priority(self, priority: float) -> "_Task": + return _Task(priority, self.id_number, self.iterator, self.original_function) + + @property + def _cmp_tuple(self) -> Tuple[float, int]: + return (-self.priority, self.id_number) + + def __eq__(self, other): + return self._cmp_tuple == other._cmp_tuple + + def __ne__(self, other): + return not (self == other) + + def __lt__(self, other): + return self._cmp_tuple < other._cmp_tuple + + +class FakeEventLoop: + """Emulate an asyncio EventLoop to run some registered coroutine jobs in sequence.""" + + def __init__(self): + self._pending_tasks: List[_Task] = [] + self._task_counter = 0 + + def add_job(self, func, *args, **kwargs): + """Add a job to the task queue, + + Optionally retrieves priority from the function object, and schedules according to that. + """ + if inspect.iscoroutine(func): + raise ValueError("Can only add coroutine functions, not coroutine objects") + if inspect.iscoroutinefunction(func): + coro = func + gen = coro(*args, **kwargs).__await__() + else: + coro = coroutine(func) + gen = coro(*args, **kwargs) + prio = getattr(coro, "priority", 0.0) + task = _Task(prio, self._task_counter, gen, func) + self._task_counter += 1 + heapq.heappush(self._pending_tasks, task) + + def flush_tasks(self): + """Run until all tasks have been completed. + + :raises RuntimeError: if a deadlock is detected. + """ + i = 0 + while self._pending_tasks: + i += 1 + if i > 1000000: + # Detect deadlock/circular dependency by measuring how many times tasks have been + # executed. On the big tests/test1.yaml we only get to a fraction of this, so + # this shouldn't be a problem. + raise RuntimeError( + "Circular dependency detected! " + "Please run with -v option to see what functions failed to " + "complete." + ) + + task: _Task = heapq.heappop(self._pending_tasks) + _LOGGER.debug( + "Running %s in %s (num %s)", + task.original_function.__qualname__, + task.original_function.__module__, + task.id_number, + ) + + try: + next(task.iterator) + # Decrease priority over time, so that if this task is blocked + # due to a dependency others will clear the dependency + # This could be improved with a less naive approach + new_task = task.with_priority(task.priority - 1) + heapq.heappush(self._pending_tasks, new_task) + except StopIteration: + _LOGGER.debug(" -> finished") diff --git a/esphome/cpp_generator.py b/esphome/cpp_generator.py index 999b252dde..d71e0df4d2 100644 --- a/esphome/cpp_generator.py +++ b/esphome/cpp_generator.py @@ -549,8 +549,7 @@ def add_define(name: str, value: SafeExpType = None): CORE.add_define(Define(name, safe_exp(value))) -@coroutine -def get_variable(id_: ID) -> Generator["MockObj", None, None]: +async def get_variable(id_: ID) -> "MockObj": """ Wait for the given ID to be defined in the code generation and return it as a MockObj. @@ -560,12 +559,10 @@ def get_variable(id_: ID) -> Generator["MockObj", None, None]: :param id_: The ID to retrieve :return: The variable as a MockObj. """ - var = yield CORE.get_variable(id_) - yield var + return await CORE.get_variable(id_) -@coroutine -def get_variable_with_full_id(id_: ID) -> Generator[Tuple[ID, "MockObj"], None, None]: +async def get_variable_with_full_id(id_: ID) -> Tuple[ID, "MockObj"]: """ Wait for the given ID to be defined in the code generation and return it as a MockObj. @@ -575,8 +572,7 @@ def get_variable_with_full_id(id_: ID) -> Generator[Tuple[ID, "MockObj"], None, :param id_: The ID to retrieve :return: The variable as a MockObj. """ - full_id, var = yield CORE.get_variable_with_full_id(id_) - yield full_id, var + return await CORE.get_variable_with_full_id(id_) @coroutine @@ -604,7 +600,7 @@ def process_lambda( return parts = value.parts[:] for i, id in enumerate(value.requires_ids): - full_id, var = yield CORE.get_variable_with_full_id(id) + full_id, var = yield get_variable_with_full_id(id) if ( full_id is not None and isinstance(full_id.type, MockObjClass) @@ -675,6 +671,9 @@ class MockObj(Expression): self.op = op def __getattr__(self, attr: str) -> "MockObj": + # prevent python dunder methods being replaced by mock objects + if attr.startswith("__"): + raise AttributeError() next_op = "." if attr.startswith("P") and self.op not in ["::", ""]: attr = attr[1:] diff --git a/requirements_test.txt b/requirements_test.txt index ecfaed5d05..b5cf617fee 100644 --- a/requirements_test.txt +++ b/requirements_test.txt @@ -10,5 +10,6 @@ pre-commit pytest==6.2.4 pytest-cov==2.11.1 pytest-mock==3.5.1 +pytest-asyncio==0.14.0 asyncmock==0.4.2 hypothesis==5.21.0 diff --git a/tests/unit_tests/test_core.py b/tests/unit_tests/test_core.py index 37a4920224..4e60880033 100644 --- a/tests/unit_tests/test_core.py +++ b/tests/unit_tests/test_core.py @@ -490,11 +490,15 @@ class TestEsphomeCore: def test_reset(self, target): """Call reset on target and compare to new instance""" - other = core.EsphomeCore() + other = core.EsphomeCore().__dict__ target.reset() + t = target.__dict__ + # ignore event loop + del other["event_loop"] + del t["event_loop"] - assert target.__dict__ == other.__dict__ + assert t == other def test_address__none(self, target): target.config = {} diff --git a/tests/unit_tests/test_cpp_helpers.py b/tests/unit_tests/test_cpp_helpers.py index c6f37f6b5d..3e317589a9 100644 --- a/tests/unit_tests/test_cpp_helpers.py +++ b/tests/unit_tests/test_cpp_helpers.py @@ -6,25 +6,24 @@ from esphome import const from esphome.cpp_generator import MockObj -def test_gpio_pin_expression__conf_is_none(monkeypatch): - target = ch.gpio_pin_expression(None) - - actual = next(target) +@pytest.mark.asyncio +async def test_gpio_pin_expression__conf_is_none(monkeypatch): + actual = await ch.gpio_pin_expression(None) assert actual is None -def test_gpio_pin_expression__new_pin(monkeypatch): - target = ch.gpio_pin_expression( +@pytest.mark.asyncio +async def test_gpio_pin_expression__new_pin(monkeypatch): + actual = await ch.gpio_pin_expression( {const.CONF_NUMBER: 42, const.CONF_MODE: "input", const.CONF_INVERTED: False} ) - actual = next(target) - assert isinstance(actual, MockObj) -def test_register_component(monkeypatch): +@pytest.mark.asyncio +async def test_register_component(monkeypatch): var = Mock(base="foo.bar") app_mock = Mock(register_component=Mock(return_value=var)) @@ -36,9 +35,7 @@ def test_register_component(monkeypatch): add_mock = Mock() monkeypatch.setattr(ch, "add", add_mock) - target = ch.register_component(var, {}) - - actual = next(target) + actual = await ch.register_component(var, {}) assert actual is var add_mock.assert_called_once() @@ -46,18 +43,19 @@ def test_register_component(monkeypatch): assert core_mock.component_ids == [] -def test_register_component__no_component_id(monkeypatch): +@pytest.mark.asyncio +async def test_register_component__no_component_id(monkeypatch): var = Mock(base="foo.eek") core_mock = Mock(component_ids=["foo.bar"]) monkeypatch.setattr(ch, "CORE", core_mock) with pytest.raises(ValueError, match="Component ID foo.eek was not declared to"): - target = ch.register_component(var, {}) - next(target) + await ch.register_component(var, {}) -def test_register_component__with_setup_priority(monkeypatch): +@pytest.mark.asyncio +async def test_register_component__with_setup_priority(monkeypatch): var = Mock(base="foo.bar") app_mock = Mock(register_component=Mock(return_value=var)) @@ -69,7 +67,7 @@ def test_register_component__with_setup_priority(monkeypatch): add_mock = Mock() monkeypatch.setattr(ch, "add", add_mock) - target = ch.register_component( + actual = await ch.register_component( var, { const.CONF_SETUP_PRIORITY: "123", @@ -77,8 +75,6 @@ def test_register_component__with_setup_priority(monkeypatch): }, ) - actual = next(target) - assert actual is var add_mock.assert_called() assert add_mock.call_count == 3