Introduce new async-def coroutine syntax (#1657)

This commit is contained in:
Otto Winter 2021-05-17 07:14:15 +02:00 committed by GitHub
parent 95ed3e9d46
commit d4686c0fb1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 391 additions and 238 deletions

View File

@ -18,7 +18,7 @@ from esphome.const import (
CONF_ESPHOME, CONF_ESPHOME,
CONF_PLATFORMIO_OPTIONS, CONF_PLATFORMIO_OPTIONS,
) )
from esphome.core import CORE, EsphomeError, coroutine, coroutine_with_priority from esphome.core import CORE, EsphomeError, coroutine
from esphome.helpers import indent from esphome.helpers import indent
from esphome.util import ( from esphome.util import (
run_external_command, run_external_command,
@ -127,15 +127,16 @@ def wrap_to_code(name, comp):
coro = coroutine(comp.to_code) coro = coroutine(comp.to_code)
@functools.wraps(comp.to_code) @functools.wraps(comp.to_code)
@coroutine_with_priority(coro.priority) async def wrapped(conf):
def wrapped(conf):
cg.add(cg.LineComment(f"{name}:")) cg.add(cg.LineComment(f"{name}:"))
if comp.config_schema is not None: if comp.config_schema is not None:
conf_str = yaml_util.dump(conf) conf_str = yaml_util.dump(conf)
conf_str = conf_str.replace("//", "") conf_str = conf_str.replace("//", "")
cg.add(cg.LineComment(indent(conf_str))) cg.add(cg.LineComment(indent(conf_str)))
yield coro(conf) await coro(conf)
if hasattr(coro, "priority"):
wrapped.priority = coro.priority
return wrapped return wrapped
@ -610,7 +611,7 @@ def run_esphome(argv):
try: try:
return PRE_CONFIG_ACTIONS[args.command](args) return PRE_CONFIG_ACTIONS[args.command](args)
except EsphomeError as e: except EsphomeError as e:
_LOGGER.error(e) _LOGGER.error(e, exc_info=args.verbose)
return 1 return 1
for conf_path in args.configuration: for conf_path in args.configuration:
@ -628,7 +629,7 @@ def run_esphome(argv):
try: try:
rc = POST_CONFIG_ACTIONS[args.command](args, config) rc = POST_CONFIG_ACTIONS[args.command](args, config)
except EsphomeError as e: except EsphomeError as e:
_LOGGER.error(e) _LOGGER.error(e, exc_info=args.verbose)
return 1 return 1
if rc != 0: if rc != 0:
return rc return rc

View File

@ -10,7 +10,6 @@ from esphome.const import (
CONF_TYPE_ID, CONF_TYPE_ID,
CONF_TIME, CONF_TIME,
) )
from esphome.core import coroutine
from esphome.jsonschema import jschema_extractor from esphome.jsonschema import jschema_extractor
from esphome.util import Registry from esphome.util import Registry
@ -142,27 +141,27 @@ NotCondition = cg.esphome_ns.class_("NotCondition", Condition)
@register_condition("and", AndCondition, validate_condition_list) @register_condition("and", AndCondition, validate_condition_list)
def and_condition_to_code(config, condition_id, template_arg, args): async def and_condition_to_code(config, condition_id, template_arg, args):
conditions = yield build_condition_list(config, template_arg, args) conditions = await build_condition_list(config, template_arg, args)
yield cg.new_Pvariable(condition_id, template_arg, conditions) return cg.new_Pvariable(condition_id, template_arg, conditions)
@register_condition("or", OrCondition, validate_condition_list) @register_condition("or", OrCondition, validate_condition_list)
def or_condition_to_code(config, condition_id, template_arg, args): async def or_condition_to_code(config, condition_id, template_arg, args):
conditions = yield build_condition_list(config, template_arg, args) conditions = await build_condition_list(config, template_arg, args)
yield cg.new_Pvariable(condition_id, template_arg, conditions) return cg.new_Pvariable(condition_id, template_arg, conditions)
@register_condition("not", NotCondition, validate_potentially_and_condition) @register_condition("not", NotCondition, validate_potentially_and_condition)
def not_condition_to_code(config, condition_id, template_arg, args): async def not_condition_to_code(config, condition_id, template_arg, args):
condition = yield build_condition(config, template_arg, args) condition = await build_condition(config, template_arg, args)
yield cg.new_Pvariable(condition_id, template_arg, condition) return cg.new_Pvariable(condition_id, template_arg, condition)
@register_condition("lambda", LambdaCondition, cv.lambda_) @register_condition("lambda", LambdaCondition, cv.lambda_)
def lambda_condition_to_code(config, condition_id, template_arg, args): async def lambda_condition_to_code(config, condition_id, template_arg, args):
lambda_ = yield cg.process_lambda(config, args, return_type=bool) lambda_ = await cg.process_lambda(config, args, return_type=bool)
yield cg.new_Pvariable(condition_id, template_arg, lambda_) return cg.new_Pvariable(condition_id, template_arg, lambda_)
@register_condition( @register_condition(
@ -177,26 +176,26 @@ def lambda_condition_to_code(config, condition_id, template_arg, args):
} }
).extend(cv.COMPONENT_SCHEMA), ).extend(cv.COMPONENT_SCHEMA),
) )
def for_condition_to_code(config, condition_id, template_arg, args): async def for_condition_to_code(config, condition_id, template_arg, args):
condition = yield build_condition( condition = await build_condition(
config[CONF_CONDITION], cg.TemplateArguments(), [] config[CONF_CONDITION], cg.TemplateArguments(), []
) )
var = cg.new_Pvariable(condition_id, template_arg, condition) var = cg.new_Pvariable(condition_id, template_arg, condition)
yield cg.register_component(var, config) await cg.register_component(var, config)
templ = yield cg.templatable(config[CONF_TIME], args, cg.uint32) templ = await cg.templatable(config[CONF_TIME], args, cg.uint32)
cg.add(var.set_time(templ)) cg.add(var.set_time(templ))
yield var return var
@register_action( @register_action(
"delay", DelayAction, cv.templatable(cv.positive_time_period_milliseconds) "delay", DelayAction, cv.templatable(cv.positive_time_period_milliseconds)
) )
def delay_action_to_code(config, action_id, template_arg, args): async def delay_action_to_code(config, action_id, template_arg, args):
var = cg.new_Pvariable(action_id, template_arg) var = cg.new_Pvariable(action_id, template_arg)
yield cg.register_component(var, {}) await cg.register_component(var, {})
template_ = yield cg.templatable(config, args, cg.uint32) template_ = await cg.templatable(config, args, cg.uint32)
cg.add(var.set_delay(template_)) cg.add(var.set_delay(template_))
yield var return var
@register_action( @register_action(
@ -211,16 +210,16 @@ def delay_action_to_code(config, action_id, template_arg, args):
cv.has_at_least_one_key(CONF_THEN, CONF_ELSE), cv.has_at_least_one_key(CONF_THEN, CONF_ELSE),
), ),
) )
def if_action_to_code(config, action_id, template_arg, args): async def if_action_to_code(config, action_id, template_arg, args):
conditions = yield build_condition(config[CONF_CONDITION], template_arg, args) conditions = await build_condition(config[CONF_CONDITION], template_arg, args)
var = cg.new_Pvariable(action_id, template_arg, conditions) var = cg.new_Pvariable(action_id, template_arg, conditions)
if CONF_THEN in config: if CONF_THEN in config:
actions = yield build_action_list(config[CONF_THEN], template_arg, args) actions = await build_action_list(config[CONF_THEN], template_arg, args)
cg.add(var.add_then(actions)) cg.add(var.add_then(actions))
if CONF_ELSE in config: if CONF_ELSE in config:
actions = yield build_action_list(config[CONF_ELSE], template_arg, args) actions = await build_action_list(config[CONF_ELSE], template_arg, args)
cg.add(var.add_else(actions)) cg.add(var.add_else(actions))
yield var return var
@register_action( @register_action(
@ -233,12 +232,12 @@ def if_action_to_code(config, action_id, template_arg, args):
} }
), ),
) )
def while_action_to_code(config, action_id, template_arg, args): async def while_action_to_code(config, action_id, template_arg, args):
conditions = yield build_condition(config[CONF_CONDITION], template_arg, args) conditions = await build_condition(config[CONF_CONDITION], template_arg, args)
var = cg.new_Pvariable(action_id, template_arg, conditions) var = cg.new_Pvariable(action_id, template_arg, conditions)
actions = yield build_action_list(config[CONF_THEN], template_arg, args) actions = await build_action_list(config[CONF_THEN], template_arg, args)
cg.add(var.add_then(actions)) cg.add(var.add_then(actions))
yield var return var
def validate_wait_until(value): def validate_wait_until(value):
@ -253,17 +252,17 @@ def validate_wait_until(value):
@register_action("wait_until", WaitUntilAction, validate_wait_until) @register_action("wait_until", WaitUntilAction, validate_wait_until)
def wait_until_action_to_code(config, action_id, template_arg, args): async def wait_until_action_to_code(config, action_id, template_arg, args):
conditions = yield build_condition(config[CONF_CONDITION], template_arg, args) conditions = await build_condition(config[CONF_CONDITION], template_arg, args)
var = cg.new_Pvariable(action_id, template_arg, conditions) var = cg.new_Pvariable(action_id, template_arg, conditions)
yield cg.register_component(var, {}) await cg.register_component(var, {})
yield var return var
@register_action("lambda", LambdaAction, cv.lambda_) @register_action("lambda", LambdaAction, cv.lambda_)
def lambda_action_to_code(config, action_id, template_arg, args): async def lambda_action_to_code(config, action_id, template_arg, args):
lambda_ = yield cg.process_lambda(config, args, return_type=cg.void) lambda_ = await cg.process_lambda(config, args, return_type=cg.void)
yield cg.new_Pvariable(action_id, template_arg, lambda_) return cg.new_Pvariable(action_id, template_arg, lambda_)
@register_action( @register_action(
@ -275,54 +274,51 @@ def lambda_action_to_code(config, action_id, template_arg, args):
} }
), ),
) )
def component_update_action_to_code(config, action_id, template_arg, args): async def component_update_action_to_code(config, action_id, template_arg, args):
comp = yield cg.get_variable(config[CONF_ID]) comp = await cg.get_variable(config[CONF_ID])
yield cg.new_Pvariable(action_id, template_arg, comp) return cg.new_Pvariable(action_id, template_arg, comp)
@coroutine async def build_action(full_config, template_arg, args):
def build_action(full_config, template_arg, args):
registry_entry, config = cg.extract_registry_entry_config( registry_entry, config = cg.extract_registry_entry_config(
ACTION_REGISTRY, full_config ACTION_REGISTRY, full_config
) )
action_id = full_config[CONF_TYPE_ID] action_id = full_config[CONF_TYPE_ID]
builder = registry_entry.coroutine_fun builder = registry_entry.coroutine_fun
yield builder(config, action_id, template_arg, args) ret = await builder(config, action_id, template_arg, args)
return ret
@coroutine async def build_action_list(config, templ, arg_type):
def build_action_list(config, templ, arg_type):
actions = [] actions = []
for conf in config: for conf in config:
action = yield build_action(conf, templ, arg_type) action = await build_action(conf, templ, arg_type)
actions.append(action) actions.append(action)
yield actions return actions
@coroutine async def build_condition(full_config, template_arg, args):
def build_condition(full_config, template_arg, args):
registry_entry, config = cg.extract_registry_entry_config( registry_entry, config = cg.extract_registry_entry_config(
CONDITION_REGISTRY, full_config CONDITION_REGISTRY, full_config
) )
action_id = full_config[CONF_TYPE_ID] action_id = full_config[CONF_TYPE_ID]
builder = registry_entry.coroutine_fun builder = registry_entry.coroutine_fun
yield builder(config, action_id, template_arg, args) ret = await builder(config, action_id, template_arg, args)
return ret
@coroutine async def build_condition_list(config, templ, args):
def build_condition_list(config, templ, args):
conditions = [] conditions = []
for conf in config: for conf in config:
condition = yield build_condition(conf, templ, args) condition = await build_condition(conf, templ, args)
conditions.append(condition) conditions.append(condition)
yield conditions return conditions
@coroutine async def build_automation(trigger, args, config):
def build_automation(trigger, args, config):
arg_types = [arg[0] for arg in args] arg_types = [arg[0] for arg in args]
templ = cg.TemplateArguments(*arg_types) templ = cg.TemplateArguments(*arg_types)
obj = cg.new_Pvariable(config[CONF_AUTOMATION_ID], templ, trigger) obj = cg.new_Pvariable(config[CONF_AUTOMATION_ID], templ, trigger)
actions = yield build_action_list(config[CONF_THEN], templ, args) actions = await build_action_list(config[CONF_THEN], templ, args)
cg.add(obj.add_actions(actions)) cg.add(obj.add_actions(actions))
yield obj return obj

View File

@ -32,12 +32,12 @@ CONFIG_SCHEMA = cv.Schema(
@coroutine_with_priority(50.0) @coroutine_with_priority(50.0)
def to_code(config): async def to_code(config):
var = cg.new_Pvariable(config[CONF_ID]) var = cg.new_Pvariable(config[CONF_ID])
cg.add(var.set_port(config[CONF_PORT])) cg.add(var.set_port(config[CONF_PORT]))
cg.add(var.set_auth_password(config[CONF_PASSWORD])) cg.add(var.set_auth_password(config[CONF_PASSWORD]))
yield cg.register_component(var, config) await cg.register_component(var, config)
if config[CONF_SAFE_MODE]: if config[CONF_SAFE_MODE]:
condition = var.should_enter_safe_mode( condition = var.should_enter_safe_mode(

View File

@ -1,24 +1,23 @@
import functools
import heapq
import inspect
import logging import logging
import math import math
import os import os
import re import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
# pylint: disable=unused-import, wrong-import-order
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING # noqa
from esphome.const import ( from esphome.const import (
CONF_ARDUINO_VERSION, CONF_ARDUINO_VERSION,
SOURCE_FILE_EXTENSIONS,
CONF_COMMENT, CONF_COMMENT,
CONF_ESPHOME, CONF_ESPHOME,
CONF_USE_ADDRESS, CONF_USE_ADDRESS,
CONF_ETHERNET, CONF_ETHERNET,
CONF_WIFI, CONF_WIFI,
SOURCE_FILE_EXTENSIONS,
) )
from esphome.coroutine import FakeAwaitable as _FakeAwaitable
from esphome.coroutine import FakeEventLoop as _FakeEventLoop
# pylint: disable=unused-import
from esphome.coroutine import coroutine, coroutine_with_priority # noqa
from esphome.helpers import ensure_unique_string, is_hassio from esphome.helpers import ensure_unique_string, is_hassio
from esphome.util import OrderedDict from esphome.util import OrderedDict
@ -431,64 +430,6 @@ class Library:
return NotImplemented return NotImplemented
def coroutine(func):
return coroutine_with_priority(0.0)(func)
def coroutine_with_priority(priority):
def decorator(func):
if getattr(func, "_esphome_coroutine", False):
# If func is already a coroutine, do not re-wrap it (performance)
return func
@functools.wraps(func)
def _wrapper_generator(*args, **kwargs):
instance_id = kwargs.pop("__esphome_coroutine_instance__")
if not inspect.isgeneratorfunction(func):
# If func is not a generator, return result immediately
yield func(*args, **kwargs)
# pylint: disable=protected-access
CORE._remove_coroutine(instance_id)
return
gen = func(*args, **kwargs)
var = None
try:
while True:
var = gen.send(var)
if inspect.isgenerator(var):
# Yielded generator, equivalent to 'yield from'
x = None
for x in var:
yield None
# Last yield value is the result
var = x
else:
yield var
except StopIteration:
# Stopping iteration
yield var
# pylint: disable=protected-access
CORE._remove_coroutine(instance_id)
@functools.wraps(func)
def wrapper(*args, **kwargs):
import random
instance_id = random.randint(0, 2 ** 32)
kwargs["__esphome_coroutine_instance__"] = instance_id
gen = _wrapper_generator(*args, **kwargs)
# pylint: disable=protected-access
CORE._add_active_coroutine(instance_id, gen)
return gen
# pylint: disable=protected-access
wrapper._esphome_coroutine = True
wrapper.priority = priority
return wrapper
return decorator
def find_source_files(file): def find_source_files(file):
files = set() files = set()
directory = os.path.abspath(os.path.dirname(file)) directory = os.path.abspath(os.path.dirname(file))
@ -527,7 +468,7 @@ class EsphomeCore:
# The pending tasks in the task queue (mostly for C++ generation) # The pending tasks in the task queue (mostly for C++ generation)
# This is a priority queue (with heapq) # This is a priority queue (with heapq)
# Each item is a tuple of form: (-priority, unique number, task) # Each item is a tuple of form: (-priority, unique number, task)
self.pending_tasks = [] self.event_loop = _FakeEventLoop()
# Task counter for pending tasks # Task counter for pending tasks
self.task_counter = 0 self.task_counter = 0
# The variable cache, for each ID this holds a MockObj of the variable obj # The variable cache, for each ID this holds a MockObj of the variable obj
@ -542,9 +483,6 @@ class EsphomeCore:
self.build_flags: Set[str] = set() self.build_flags: Set[str] = set()
# A set of defines to set for the compile process in esphome/core/defines.h # A set of defines to set for the compile process in esphome/core/defines.h
self.defines: Set["Define"] = set() self.defines: Set["Define"] = set()
# A dictionary of started coroutines, used to warn when a coroutine was not
# awaited.
self.active_coroutines: Dict[int, Any] = {}
# A set of strings of names of loaded integrations, used to find namespace ID conflicts # A set of strings of names of loaded integrations, used to find namespace ID conflicts
self.loaded_integrations = set() self.loaded_integrations = set()
# A set of component IDs to track what Component subclasses are declared # A set of component IDs to track what Component subclasses are declared
@ -561,7 +499,7 @@ class EsphomeCore:
self.board = None self.board = None
self.raw_config = None self.raw_config = None
self.config = None self.config = None
self.pending_tasks = [] self.event_loop = _FakeEventLoop()
self.task_counter = 0 self.task_counter = 0
self.variables = {} self.variables = {}
self.main_statements = [] self.main_statements = []
@ -569,7 +507,6 @@ class EsphomeCore:
self.libraries = [] self.libraries = []
self.build_flags = set() self.build_flags = set()
self.defines = set() self.defines = set()
self.active_coroutines = {}
self.loaded_integrations = set() self.loaded_integrations = set()
self.component_ids = set() self.component_ids = set()
@ -596,12 +533,6 @@ class EsphomeCore:
return None return None
def _add_active_coroutine(self, instance_id, obj):
self.active_coroutines[instance_id] = obj
def _remove_coroutine(self, instance_id):
self.active_coroutines.pop(instance_id)
@property @property
def arduino_version(self) -> str: def arduino_version(self) -> str:
if self.config is None: if self.config is None:
@ -657,50 +588,13 @@ class EsphomeCore:
return self.esp_platform == "ESP32" return self.esp_platform == "ESP32"
def add_job(self, func, *args, **kwargs): def add_job(self, func, *args, **kwargs):
coro = coroutine(func) self.event_loop.add_job(func, *args, **kwargs)
task = coro(*args, **kwargs)
item = (-coro.priority, self.task_counter, task)
self.task_counter += 1
heapq.heappush(self.pending_tasks, item)
return task
def flush_tasks(self): def flush_tasks(self):
i = 0
while self.pending_tasks:
i += 1
if i > 1000000:
raise EsphomeError("Circular dependency detected!")
inv_priority, num, task = heapq.heappop(self.pending_tasks)
priority = -inv_priority
_LOGGER.debug("Running %s (num %s)", task, num)
try: try:
next(task) self.event_loop.flush_tasks()
# Decrease priority over time, so that if this task is blocked except RuntimeError as e:
# due to a dependency others will clear the dependency raise EsphomeError(str(e)) from e
# This could be improved with a less naive approach
priority -= 1
item = (-priority, num, task)
heapq.heappush(self.pending_tasks, item)
except StopIteration:
_LOGGER.debug(" -> finished")
# Print not-awaited coroutines
for obj in self.active_coroutines.values():
_LOGGER.warning(
"Coroutine '%s' %s was never awaited with 'yield'.", obj.__name__, obj
)
_LOGGER.warning("Please file a bug report with your configuration.")
if self.active_coroutines:
raise EsphomeError()
if self.component_ids:
comps = ", ".join(f"'{x}'" for x in self.component_ids)
_LOGGER.warning(
"Components %s were never registered. Please create a bug report", comps
)
_LOGGER.warning("with your configuration.")
raise EsphomeError()
self.active_coroutines.clear()
def add(self, expression): def add(self, expression):
from esphome.cpp_generator import Expression, Statement, statement from esphome.cpp_generator import Expression, Statement, statement
@ -779,25 +673,35 @@ class EsphomeCore:
_LOGGER.debug("Adding define: %s", define) _LOGGER.debug("Adding define: %s", define)
return define return define
def get_variable(self, id): def _get_variable_generator(self, id):
while True:
try:
return self.variables[id]
except KeyError:
_LOGGER.debug("Waiting for variable %s (%r)", id, id)
yield
async def get_variable(self, id) -> "MockObj":
if not isinstance(id, ID): if not isinstance(id, ID):
raise ValueError(f"ID {id!r} must be of type ID!") raise ValueError(f"ID {id!r} must be of type ID!")
while True: # Fast path, check if already registered without awaiting
if id in self.variables: if id in self.variables:
yield self.variables[id] return self.variables[id]
return return await _FakeAwaitable(self._get_variable_generator(id))
_LOGGER.debug("Waiting for variable %s (%r)", id, id)
yield None
def get_variable_with_full_id(self, id): def _get_variable_with_full_id_generator(self, id):
while True: while True:
if id in self.variables: if id in self.variables:
for k, v in self.variables.items(): for k, v in self.variables.items():
if k == id: if k == id:
yield (k, v) return (k, v)
return
_LOGGER.debug("Waiting for variable %s", id) _LOGGER.debug("Waiting for variable %s", id)
yield None, None yield
async def get_variable_with_full_id(self, id: ID) -> Tuple[ID, "MockObj"]:
if not isinstance(id, ID):
raise ValueError(f"ID {id!r} must be of type ID!")
return await _FakeAwaitable(self._get_variable_with_full_id_generator(id))
def register_variable(self, id, obj): def register_variable(self, id, obj):
if id in self.variables: if id in self.variables:

View File

@ -233,7 +233,7 @@ def include_file(path, basename):
@coroutine_with_priority(-1000.0) @coroutine_with_priority(-1000.0)
def add_includes(includes): async def add_includes(includes):
# Add includes at the very end, so that the included files can access global variables # Add includes at the very end, so that the included files can access global variables
for include in includes: for include in includes:
path = CORE.relative_config_path(include) path = CORE.relative_config_path(include)
@ -249,7 +249,7 @@ def add_includes(includes):
@coroutine_with_priority(-1000.0) @coroutine_with_priority(-1000.0)
def _esp8266_add_lwip_type(): async def _esp8266_add_lwip_type():
# If any component has already set this, do not change it # If any component has already set this, do not change it
if any( if any(
flag.startswith("-DPIO_FRAMEWORK_ARDUINO_LWIP2_") for flag in CORE.build_flags flag.startswith("-DPIO_FRAMEWORK_ARDUINO_LWIP2_") for flag in CORE.build_flags
@ -271,25 +271,25 @@ def _esp8266_add_lwip_type():
@coroutine_with_priority(30.0) @coroutine_with_priority(30.0)
def _add_automations(config): async def _add_automations(config):
for conf in config.get(CONF_ON_BOOT, []): for conf in config.get(CONF_ON_BOOT, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], conf.get(CONF_PRIORITY)) trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID], conf.get(CONF_PRIORITY))
yield cg.register_component(trigger, conf) await cg.register_component(trigger, conf)
yield automation.build_automation(trigger, [], conf) await automation.build_automation(trigger, [], conf)
for conf in config.get(CONF_ON_SHUTDOWN, []): for conf in config.get(CONF_ON_SHUTDOWN, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID]) trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID])
yield cg.register_component(trigger, conf) await cg.register_component(trigger, conf)
yield automation.build_automation(trigger, [], conf) await automation.build_automation(trigger, [], conf)
for conf in config.get(CONF_ON_LOOP, []): for conf in config.get(CONF_ON_LOOP, []):
trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID]) trigger = cg.new_Pvariable(conf[CONF_TRIGGER_ID])
yield cg.register_component(trigger, conf) await cg.register_component(trigger, conf)
yield automation.build_automation(trigger, [], conf) await automation.build_automation(trigger, [], conf)
@coroutine_with_priority(100.0) @coroutine_with_priority(100.0)
def to_code(config): async def to_code(config):
cg.add_global(cg.global_ns.namespace("esphome").using) cg.add_global(cg.global_ns.namespace("esphome").using)
cg.add( cg.add(
cg.App.pre_setup( cg.App.pre_setup(

252
esphome/coroutine.py Normal file
View File

@ -0,0 +1,252 @@
"""
ESPHome's coroutine system.
The Problem: When running the code generationg, components can depend on variables being registered.
For example, an i2c-based sensor would need the i2c bus component to first be declared before the
codegen can emit code using that variable (or otherwise the C++ won't compile).
ESPHome's codegen system solves this by using coroutine-like methods. When a component depends on
a variable, it waits for it to be registered using `await cg.get_variable()`. If the variable
hasn't been registered yet, control will be yielded back to another component until the variable
is registered. This leads to a topological sort, solving the dependency problem.
Importantly, ESPHome only uses the coroutine *syntax*, no actual asyncio event loop is running in
the background. This is so that we can ensure the order of execution is constant for the same
YAML configuration, thus main.cpp only has to be recompiled if the configuration actually changes.
There are two syntaxes for ESPHome coroutines ("old style" vs "new style" coroutines).
"new style" - This is very much like coroutines you might be used to:
```py
async def my_coroutine(config):
var = await cg.get_variable(config[CONF_ID])
await some_other_coroutine(xyz)
return var
```
new style coroutines are `async def` methods that use `await` to await the result of another coroutine,
and can return values using a `return` statement.
"old style" - This was a hack for when ESPHome still had to run on python 2, but is still compatible
```py
@coroutine
def my_coroutine(config):
var = yield cg.get_variable(config[CONF_ID])
yield some_other_coroutine(xyz)
yield var
```
Here everything is combined in `yield` expressions. You await other coroutines using `yield` and
the last `yield` expression defines what is returned.
"""
import collections
import functools
import heapq
import inspect
import logging
import types
from typing import Any, Awaitable, Callable, Generator, Iterator, List, Tuple
_LOGGER = logging.getLogger(__name__)
def coroutine(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
"""Decorator to apply to methods to convert them to ESPHome coroutines."""
if getattr(func, "_esphome_coroutine", False):
# If func is already a coroutine, do not re-wrap it (performance)
return func
if inspect.isasyncgenfunction(func):
# Trade-off: In ESPHome, there's not really a use-case for async generators.
# and during the transition to new-style syntax it will happen that a `yield`
# is not replaced properly, so don't accept async generators.
raise ValueError(
f"Async generator functions are not allowed. "
f"Please check whether you've replaced all yields with awaits/returns. "
f"See {func} in {func.__module__}"
)
if inspect.iscoroutinefunction(func):
# A new-style async-def coroutine function, no conversion needed.
return func
if inspect.isgeneratorfunction(func):
@functools.wraps(func)
def coro(*args, **kwargs):
gen = func(*args, **kwargs)
ret = yield from _flatten_generator(gen)
return ret
else:
# A "normal" function with no `yield` statements, convert to generator
# that includes a yield just so it's also a generator function
@functools.wraps(func)
def coro(*args, **kwargs):
res = func(*args, **kwargs)
yield
return res
# Add coroutine internal python flag so that it can be awaited from new-style coroutines.
coro = types.coroutine(coro)
# pylint: disable=protected-access
coro._esphome_coroutine = True
return coro
def coroutine_with_priority(priority: float):
"""Decorator to apply to functions to convert them to ESPHome coroutines.
:param priority: priority with which to schedule the coroutine, higher priorities run first.
"""
def decorator(func):
coro = coroutine(func)
coro.priority = priority
return coro
return decorator
def _flatten_generator(gen: Generator[Any, Any, Any]):
to_send = None
while True:
try:
# Run until next yield expression
val = gen.send(to_send)
except StopIteration as e:
# return statement or end of function
# From py3.3, return with a value is allowed in generators,
# and return value is transported in the value field of the exception.
# If we find a value in the exception, use that as the return value,
# otherwise use the value from the last yield statement ("old style")
ret = to_send if e.value is None else e.value
return ret
if isinstance(val, collections.abc.Awaitable):
# yielded object that is awaitable (like `yield some_new_style_method()`)
# yield from __await__() like actual coroutines would.
to_send = yield from val.__await__()
elif inspect.isgenerator(val):
# Old style, like `yield cg.get_variable()`
to_send = yield from _flatten_generator(val)
else:
# Could be the last expression from this generator, record this as the return value
to_send = val
# perform a yield so that expressions like `while some_condition(): yield None`
# do not run without yielding control back to the top
yield
class FakeAwaitable:
"""Convert a generator to an awaitable object.
Needed for internals of `cg.get_variable`. There we can't use @coroutine because
native coroutines await from types.coroutine() directly without yielding back control to the top
(likely as a performance enhancement).
If we instead wrap the generator in this FakeAwaitable, control is yielded back to the top
(reason unknown).
"""
def __init__(self, gen: Generator[Any, Any, Any]) -> None:
self._gen = gen
def __await__(self):
ret = yield from self._gen
return ret
@functools.total_ordering
class _Task:
def __init__(
self,
priority: float,
id_number: int,
iterator: Iterator[None],
original_function: Any,
):
self.priority = priority
self.id_number = id_number
self.iterator = iterator
self.original_function = original_function
def with_priority(self, priority: float) -> "_Task":
return _Task(priority, self.id_number, self.iterator, self.original_function)
@property
def _cmp_tuple(self) -> Tuple[float, int]:
return (-self.priority, self.id_number)
def __eq__(self, other):
return self._cmp_tuple == other._cmp_tuple
def __ne__(self, other):
return not (self == other)
def __lt__(self, other):
return self._cmp_tuple < other._cmp_tuple
class FakeEventLoop:
"""Emulate an asyncio EventLoop to run some registered coroutine jobs in sequence."""
def __init__(self):
self._pending_tasks: List[_Task] = []
self._task_counter = 0
def add_job(self, func, *args, **kwargs):
"""Add a job to the task queue,
Optionally retrieves priority from the function object, and schedules according to that.
"""
if inspect.iscoroutine(func):
raise ValueError("Can only add coroutine functions, not coroutine objects")
if inspect.iscoroutinefunction(func):
coro = func
gen = coro(*args, **kwargs).__await__()
else:
coro = coroutine(func)
gen = coro(*args, **kwargs)
prio = getattr(coro, "priority", 0.0)
task = _Task(prio, self._task_counter, gen, func)
self._task_counter += 1
heapq.heappush(self._pending_tasks, task)
def flush_tasks(self):
"""Run until all tasks have been completed.
:raises RuntimeError: if a deadlock is detected.
"""
i = 0
while self._pending_tasks:
i += 1
if i > 1000000:
# Detect deadlock/circular dependency by measuring how many times tasks have been
# executed. On the big tests/test1.yaml we only get to a fraction of this, so
# this shouldn't be a problem.
raise RuntimeError(
"Circular dependency detected! "
"Please run with -v option to see what functions failed to "
"complete."
)
task: _Task = heapq.heappop(self._pending_tasks)
_LOGGER.debug(
"Running %s in %s (num %s)",
task.original_function.__qualname__,
task.original_function.__module__,
task.id_number,
)
try:
next(task.iterator)
# Decrease priority over time, so that if this task is blocked
# due to a dependency others will clear the dependency
# This could be improved with a less naive approach
new_task = task.with_priority(task.priority - 1)
heapq.heappush(self._pending_tasks, new_task)
except StopIteration:
_LOGGER.debug(" -> finished")

View File

@ -549,8 +549,7 @@ def add_define(name: str, value: SafeExpType = None):
CORE.add_define(Define(name, safe_exp(value))) CORE.add_define(Define(name, safe_exp(value)))
@coroutine async def get_variable(id_: ID) -> "MockObj":
def get_variable(id_: ID) -> Generator["MockObj", None, None]:
""" """
Wait for the given ID to be defined in the code generation and Wait for the given ID to be defined in the code generation and
return it as a MockObj. return it as a MockObj.
@ -560,12 +559,10 @@ def get_variable(id_: ID) -> Generator["MockObj", None, None]:
:param id_: The ID to retrieve :param id_: The ID to retrieve
:return: The variable as a MockObj. :return: The variable as a MockObj.
""" """
var = yield CORE.get_variable(id_) return await CORE.get_variable(id_)
yield var
@coroutine async def get_variable_with_full_id(id_: ID) -> Tuple[ID, "MockObj"]:
def get_variable_with_full_id(id_: ID) -> Generator[Tuple[ID, "MockObj"], None, None]:
""" """
Wait for the given ID to be defined in the code generation and Wait for the given ID to be defined in the code generation and
return it as a MockObj. return it as a MockObj.
@ -575,8 +572,7 @@ def get_variable_with_full_id(id_: ID) -> Generator[Tuple[ID, "MockObj"], None,
:param id_: The ID to retrieve :param id_: The ID to retrieve
:return: The variable as a MockObj. :return: The variable as a MockObj.
""" """
full_id, var = yield CORE.get_variable_with_full_id(id_) return await CORE.get_variable_with_full_id(id_)
yield full_id, var
@coroutine @coroutine
@ -604,7 +600,7 @@ def process_lambda(
return return
parts = value.parts[:] parts = value.parts[:]
for i, id in enumerate(value.requires_ids): for i, id in enumerate(value.requires_ids):
full_id, var = yield CORE.get_variable_with_full_id(id) full_id, var = yield get_variable_with_full_id(id)
if ( if (
full_id is not None full_id is not None
and isinstance(full_id.type, MockObjClass) and isinstance(full_id.type, MockObjClass)
@ -675,6 +671,9 @@ class MockObj(Expression):
self.op = op self.op = op
def __getattr__(self, attr: str) -> "MockObj": def __getattr__(self, attr: str) -> "MockObj":
# prevent python dunder methods being replaced by mock objects
if attr.startswith("__"):
raise AttributeError()
next_op = "." next_op = "."
if attr.startswith("P") and self.op not in ["::", ""]: if attr.startswith("P") and self.op not in ["::", ""]:
attr = attr[1:] attr = attr[1:]

View File

@ -10,5 +10,6 @@ pre-commit
pytest==6.2.4 pytest==6.2.4
pytest-cov==2.11.1 pytest-cov==2.11.1
pytest-mock==3.5.1 pytest-mock==3.5.1
pytest-asyncio==0.14.0
asyncmock==0.4.2 asyncmock==0.4.2
hypothesis==5.21.0 hypothesis==5.21.0

View File

@ -490,11 +490,15 @@ class TestEsphomeCore:
def test_reset(self, target): def test_reset(self, target):
"""Call reset on target and compare to new instance""" """Call reset on target and compare to new instance"""
other = core.EsphomeCore() other = core.EsphomeCore().__dict__
target.reset() target.reset()
t = target.__dict__
# ignore event loop
del other["event_loop"]
del t["event_loop"]
assert target.__dict__ == other.__dict__ assert t == other
def test_address__none(self, target): def test_address__none(self, target):
target.config = {} target.config = {}

View File

@ -6,25 +6,24 @@ from esphome import const
from esphome.cpp_generator import MockObj from esphome.cpp_generator import MockObj
def test_gpio_pin_expression__conf_is_none(monkeypatch): @pytest.mark.asyncio
target = ch.gpio_pin_expression(None) async def test_gpio_pin_expression__conf_is_none(monkeypatch):
actual = await ch.gpio_pin_expression(None)
actual = next(target)
assert actual is None assert actual is None
def test_gpio_pin_expression__new_pin(monkeypatch): @pytest.mark.asyncio
target = ch.gpio_pin_expression( async def test_gpio_pin_expression__new_pin(monkeypatch):
actual = await ch.gpio_pin_expression(
{const.CONF_NUMBER: 42, const.CONF_MODE: "input", const.CONF_INVERTED: False} {const.CONF_NUMBER: 42, const.CONF_MODE: "input", const.CONF_INVERTED: False}
) )
actual = next(target)
assert isinstance(actual, MockObj) assert isinstance(actual, MockObj)
def test_register_component(monkeypatch): @pytest.mark.asyncio
async def test_register_component(monkeypatch):
var = Mock(base="foo.bar") var = Mock(base="foo.bar")
app_mock = Mock(register_component=Mock(return_value=var)) app_mock = Mock(register_component=Mock(return_value=var))
@ -36,9 +35,7 @@ def test_register_component(monkeypatch):
add_mock = Mock() add_mock = Mock()
monkeypatch.setattr(ch, "add", add_mock) monkeypatch.setattr(ch, "add", add_mock)
target = ch.register_component(var, {}) actual = await ch.register_component(var, {})
actual = next(target)
assert actual is var assert actual is var
add_mock.assert_called_once() add_mock.assert_called_once()
@ -46,18 +43,19 @@ def test_register_component(monkeypatch):
assert core_mock.component_ids == [] assert core_mock.component_ids == []
def test_register_component__no_component_id(monkeypatch): @pytest.mark.asyncio
async def test_register_component__no_component_id(monkeypatch):
var = Mock(base="foo.eek") var = Mock(base="foo.eek")
core_mock = Mock(component_ids=["foo.bar"]) core_mock = Mock(component_ids=["foo.bar"])
monkeypatch.setattr(ch, "CORE", core_mock) monkeypatch.setattr(ch, "CORE", core_mock)
with pytest.raises(ValueError, match="Component ID foo.eek was not declared to"): with pytest.raises(ValueError, match="Component ID foo.eek was not declared to"):
target = ch.register_component(var, {}) await ch.register_component(var, {})
next(target)
def test_register_component__with_setup_priority(monkeypatch): @pytest.mark.asyncio
async def test_register_component__with_setup_priority(monkeypatch):
var = Mock(base="foo.bar") var = Mock(base="foo.bar")
app_mock = Mock(register_component=Mock(return_value=var)) app_mock = Mock(register_component=Mock(return_value=var))
@ -69,7 +67,7 @@ def test_register_component__with_setup_priority(monkeypatch):
add_mock = Mock() add_mock = Mock()
monkeypatch.setattr(ch, "add", add_mock) monkeypatch.setattr(ch, "add", add_mock)
target = ch.register_component( actual = await ch.register_component(
var, var,
{ {
const.CONF_SETUP_PRIORITY: "123", const.CONF_SETUP_PRIORITY: "123",
@ -77,8 +75,6 @@ def test_register_component__with_setup_priority(monkeypatch):
}, },
) )
actual = next(target)
assert actual is var assert actual is var
add_mock.assert_called() add_mock.assert_called()
assert add_mock.call_count == 3 assert add_mock.call_count == 3