diff --git a/esphome/components/switch/__init__.py b/esphome/components/switch/__init__.py index 71a16439cd..54ad2b852e 100644 --- a/esphome/components/switch/__init__.py +++ b/esphome/components/switch/__init__.py @@ -5,6 +5,7 @@ from esphome.automation import Condition, maybe_simple_id from esphome.components import mqtt from esphome.const import ( CONF_DEVICE_CLASS, + CONF_ENTITY_CATEGORY, CONF_ID, CONF_INVERTED, CONF_MQTT_ID, @@ -16,6 +17,7 @@ from esphome.const import ( DEVICE_CLASS_SWITCH, ) from esphome.core import CORE, coroutine_with_priority +from esphome.cpp_generator import MockObjClass from esphome.cpp_helpers import setup_entity CODEOWNERS = ["@esphome/core"] @@ -45,6 +47,8 @@ SwitchTurnOffTrigger = switch_ns.class_( icon = cv.icon +validate_device_class = cv.one_of(*DEVICE_CLASSES, lower=True) + SWITCH_SCHEMA = cv.ENTITY_BASE_SCHEMA.extend(cv.MQTT_COMMAND_COMPONENT_SCHEMA).extend( { @@ -60,10 +64,40 @@ SWITCH_SCHEMA = cv.ENTITY_BASE_SCHEMA.extend(cv.MQTT_COMMAND_COMPONENT_SCHEMA).e cv.GenerateID(CONF_TRIGGER_ID): cv.declare_id(SwitchTurnOffTrigger), } ), - cv.Optional(CONF_DEVICE_CLASS): cv.one_of(*DEVICE_CLASSES, lower=True), + cv.Optional(CONF_DEVICE_CLASS): validate_device_class, } ) +_UNDEF = object() + + +def switch_schema( + class_: MockObjClass = _UNDEF, + *, + entity_category: str = _UNDEF, + device_class: str = _UNDEF, +): + schema = SWITCH_SCHEMA + if class_ is not _UNDEF: + schema = schema.extend({cv.GenerateID(): cv.declare_id(class_)}) + if entity_category is not _UNDEF: + schema = schema.extend( + { + cv.Optional( + CONF_ENTITY_CATEGORY, default=entity_category + ): cv.entity_category + } + ) + if device_class is not _UNDEF: + schema = schema.extend( + { + cv.Optional( + CONF_DEVICE_CLASS, default=device_class + ): validate_device_class + } + ) + return schema + async def setup_switch_core_(var, config): await setup_entity(var, config) @@ -92,6 +126,12 @@ async def register_switch(var, config): await setup_switch_core_(var, config) +async def new_switch(config, *args): + var = cg.new_Pvariable(config[CONF_ID], *args) + await register_switch(var, config) + return var + + SWITCH_ACTION_SCHEMA = maybe_simple_id( { cv.Required(CONF_ID): cv.use_id(Switch),