From 5040a9146dd48a31c076427c231cd1f6dc8d0e59 Mon Sep 17 00:00:00 2001 From: cvwillegen Date: Tue, 28 May 2024 11:39:23 +0200 Subject: [PATCH] Use enum validation in code generator --- esphome/components/stepper/__init__.py | 32 ++++++++++---------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/esphome/components/stepper/__init__.py b/esphome/components/stepper/__init__.py index 243d7ebe2f..0c997fc2b6 100644 --- a/esphome/components/stepper/__init__.py +++ b/esphome/components/stepper/__init__.py @@ -68,29 +68,22 @@ def validate_speed(value): return value -def validate_rotation(value): - value = cv.string(value) - - if value in ("both"): - return 0 - - if value in ("cw", "clockwise"): - return 1 - - if value in ("ccw", "counterclockwise", "counter-clockwise"): - return -1 - - raise cv.Invalid( - f"Expected rotation as 'both', 'cw', 'ccw', 'clockwise', 'counterclockwise', 'counter-clockwise', got {value}" - ) - +Rotation = stepper_ns.enum("Rotation") +ROTATIONS = { + "BOTH": Rotation.ROTATION_BOTH, + "CW": Rotation.ROTATION_CW, + "CLOCKWISE": Rotation.ROTATION_CW, + "CCW": Rotation.ROTATION_CCW, + "COUNTERCLOCKWISE": Rotation.ROTATION_CCW, + "COUNTER-CLOCKWISE": Rotation.ROTATION_CCW, +} STEPPER_SCHEMA = cv.Schema( { cv.Required(CONF_MAX_SPEED): validate_speed, cv.Optional(CONF_ACCELERATION, default="inf"): validate_acceleration, cv.Optional(CONF_DECELERATION, default="inf"): validate_acceleration, - cv.Optional(CONF_ROTATION, default="both"): validate_rotation, + cv.Optional(CONF_ROTATION, default="both"): cv.enum(ROTATIONS, upper=True), } ) @@ -208,15 +201,14 @@ async def stepper_set_deceleration_to_code(config, action_id, template_arg, args cv.Schema( { cv.Required(CONF_ID): cv.use_id(Stepper), - cv.Required(CONF_ROTATION): cv.templatable(validate_rotation), + cv.Required(CONF_ROTATION): cv.enum(ROTATIONS, upper=True), } ), ) async def stepper_set_rotation_to_code(config, action_id, template_arg, args): paren = await cg.get_variable(config[CONF_ID]) var = cg.new_Pvariable(action_id, template_arg, paren) - template_ = await cg.templatable(config[CONF_ROTATION], args, cg.int32) - cg.add(var.set_rotation(template_)) + cg.add(var.set_rotation(config[CONF_ROTATION])) return var