diff --git a/script/api_protobuf/api_protobuf.py b/script/api_protobuf/api_protobuf.py index b1292095d8..a2bc3abf64 100755 --- a/script/api_protobuf/api_protobuf.py +++ b/script/api_protobuf/api_protobuf.py @@ -17,28 +17,22 @@ then run this script with python3 and the files will be generated, they still need to be formatted """ -import re import os +import re +import sys +from abc import ABC, abstractmethod from pathlib import Path -from textwrap import dedent from subprocess import call +from textwrap import dedent # Generate with # protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto - import aioesphomeapi.api_options_pb2 as pb import google.protobuf.descriptor_pb2 as descriptor -file_header = "// This file was automatically generated with a tool.\n" -file_header += "// See scripts/api_protobuf/api_protobuf.py\n" - -cwd = Path(__file__).resolve().parent -root = cwd.parent.parent / "esphome" / "components" / "api" -prot = root / "api.protoc" -call(["protoc", "-o", str(prot), "-I", str(root), "api.proto"]) -content = prot.read_bytes() - -d = descriptor.FileDescriptorSet.FromString(content) +FILE_HEADER = """// This file was automatically generated with a tool. +// See scripts/api_protobuf/api_protobuf.py +""" def indent_list(text, padding=" "): @@ -64,7 +58,7 @@ def camel_to_snake(name): return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() -class TypeInfo: +class TypeInfo(ABC): def __init__(self, field): self._field = field @@ -186,10 +180,12 @@ class TypeInfo: def dump_content(self): o = f'out.append(" {self.name}: ");\n' o += self.dump(f"this->{self.field_name}") + "\n" - o += f'out.append("\\n");\n' + o += 'out.append("\\n");\n' return o - dump = None + @abstractmethod + def dump(self, name: str): + pass TYPE_INFO = {} @@ -212,7 +208,7 @@ class DoubleType(TypeInfo): def dump(self, name): o = f'sprintf(buffer, "%g", {name});\n' - o += f"out.append(buffer);" + o += "out.append(buffer);" return o @@ -225,7 +221,7 @@ class FloatType(TypeInfo): def dump(self, name): o = f'sprintf(buffer, "%g", {name});\n' - o += f"out.append(buffer);" + o += "out.append(buffer);" return o @@ -238,7 +234,7 @@ class Int64Type(TypeInfo): def dump(self, name): o = f'sprintf(buffer, "%lld", {name});\n' - o += f"out.append(buffer);" + o += "out.append(buffer);" return o @@ -251,7 +247,7 @@ class UInt64Type(TypeInfo): def dump(self, name): o = f'sprintf(buffer, "%llu", {name});\n' - o += f"out.append(buffer);" + o += "out.append(buffer);" return o @@ -264,7 +260,7 @@ class Int32Type(TypeInfo): def dump(self, name): o = f'sprintf(buffer, "%" PRId32, {name});\n' - o += f"out.append(buffer);" + o += "out.append(buffer);" return o @@ -277,7 +273,7 @@ class Fixed64Type(TypeInfo): def dump(self, name): o = f'sprintf(buffer, "%llu", {name});\n' - o += f"out.append(buffer);" + o += "out.append(buffer);" return o @@ -290,7 +286,7 @@ class Fixed32Type(TypeInfo): def dump(self, name): o = f'sprintf(buffer, "%" PRIu32, {name});\n' - o += f"out.append(buffer);" + o += "out.append(buffer);" return o @@ -372,7 +368,7 @@ class UInt32Type(TypeInfo): def dump(self, name): o = f'sprintf(buffer, "%" PRIu32, {name});\n' - o += f"out.append(buffer);" + o += "out.append(buffer);" return o @@ -406,7 +402,7 @@ class SFixed32Type(TypeInfo): def dump(self, name): o = f'sprintf(buffer, "%" PRId32, {name});\n' - o += f"out.append(buffer);" + o += "out.append(buffer);" return o @@ -419,7 +415,7 @@ class SFixed64Type(TypeInfo): def dump(self, name): o = f'sprintf(buffer, "%lld", {name});\n' - o += f"out.append(buffer);" + o += "out.append(buffer);" return o @@ -432,7 +428,7 @@ class SInt32Type(TypeInfo): def dump(self, name): o = f'sprintf(buffer, "%" PRId32, {name});\n' - o += f"out.append(buffer);" + o += "out.append(buffer);" return o @@ -445,7 +441,7 @@ class SInt64Type(TypeInfo): def dump(self, name): o = f'sprintf(buffer, "%lld", {name});\n' - o += f"out.append(buffer);" + o += "out.append(buffer);" return o @@ -527,7 +523,7 @@ class RepeatedTypeInfo(TypeInfo): def encode_content(self): o = f"for (auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n" o += f" buffer.{self._ti.encode_func}({self.number}, it, true);\n" - o += f"}}" + o += "}" return o @property @@ -535,10 +531,13 @@ class RepeatedTypeInfo(TypeInfo): o = f'for (const auto {"" if self._ti_is_bool else "&"}it : this->{self.field_name}) {{\n' o += f' out.append(" {self.name}: ");\n' o += indent(self._ti.dump("it")) + "\n" - o += f' out.append("\\n");\n' - o += f"}}\n" + o += ' out.append("\\n");\n' + o += "}\n" return o + def dump(self, _: str): + pass + def build_enum_type(desc): name = desc.name @@ -547,17 +546,17 @@ def build_enum_type(desc): out += f" {v.name} = {v.number},\n" out += "};\n" - cpp = f"#ifdef HAS_PROTO_MESSAGE_DUMP\n" + cpp = "#ifdef HAS_PROTO_MESSAGE_DUMP\n" cpp += f"template<> const char *proto_enum_to_string(enums::{name} value) {{\n" - cpp += f" switch (value) {{\n" + cpp += " switch (value) {\n" for v in desc.value: cpp += f" case enums::{v.name}:\n" cpp += f' return "{v.name}";\n' - cpp += f" default:\n" - cpp += f' return "UNKNOWN";\n' - cpp += f" }}\n" - cpp += f"}}\n" - cpp += f"#endif\n" + cpp += " default:\n" + cpp += ' return "UNKNOWN";\n' + cpp += " }\n" + cpp += "}\n" + cpp += "#endif\n" return out, cpp @@ -652,10 +651,10 @@ def build_message_type(desc): o += f" {dump[0]} " else: o += "\n" - o += f" __attribute__((unused)) char buffer[64];\n" + o += " __attribute__((unused)) char buffer[64];\n" o += f' out.append("{desc.name} {{\\n");\n' o += indent("\n".join(dump)) + "\n" - o += f' out.append("}}");\n' + o += ' out.append("}");\n' else: o2 = f'out.append("{desc.name} {{}}");' if len(o) + len(o2) + 3 < 120: @@ -664,9 +663,9 @@ def build_message_type(desc): o += "\n" o += f" {o2}\n" o += "}\n" - cpp += f"#ifdef HAS_PROTO_MESSAGE_DUMP\n" + cpp += "#ifdef HAS_PROTO_MESSAGE_DUMP\n" cpp += o - cpp += f"#endif\n" + cpp += "#endif\n" prot = "#ifdef HAS_PROTO_MESSAGE_DUMP\n" prot += "void dump_to(std::string &out) const override;\n" prot += "#endif\n" @@ -684,71 +683,12 @@ def build_message_type(desc): return out, cpp -file = d.file[0] -content = file_header -content += """\ -#pragma once - -#include "proto.h" - -namespace esphome { -namespace api { - -""" - -cpp = file_header -cpp += """\ -#include "api_pb2.h" -#include "esphome/core/log.h" - -#include - -namespace esphome { -namespace api { - -""" - -content += "namespace enums {\n\n" - -for enum in file.enum_type: - s, c = build_enum_type(enum) - content += s - cpp += c - -content += "\n} // namespace enums\n\n" - -mt = file.message_type - -for m in mt: - s, c = build_message_type(m) - content += s - cpp += c - -content += """\ - -} // namespace api -} // namespace esphome -""" -cpp += """\ - -} // namespace api -} // namespace esphome -""" - -with open(root / "api_pb2.h", "w") as f: - f.write(content) - -with open(root / "api_pb2.cpp", "w") as f: - f.write(cpp) - SOURCE_BOTH = 0 SOURCE_SERVER = 1 SOURCE_CLIENT = 2 RECEIVE_CASES = {} -class_name = "APIServerConnectionBase" - ifdefs = {} @@ -768,7 +708,6 @@ def build_service_message_type(mt): ifdef = get_opt(mt, pb.ifdef) log = get_opt(mt, pb.log, True) - nodelay = get_opt(mt, pb.no_delay, False) hout = "" cout = "" @@ -781,14 +720,14 @@ def build_service_message_type(mt): # Generate send func = f"send_{snake}" hout += f"bool {func}(const {mt.name} &msg);\n" - cout += f"bool {class_name}::{func}(const {mt.name} &msg) {{\n" + cout += f"bool APIServerConnectionBase::{func}(const {mt.name} &msg) {{\n" if log: - cout += f"#ifdef HAS_PROTO_MESSAGE_DUMP\n" + cout += "#ifdef HAS_PROTO_MESSAGE_DUMP\n" cout += f' ESP_LOGVV(TAG, "{func}: %s", msg.dump().c_str());\n' - cout += f"#endif\n" + cout += "#endif\n" # cout += f' this->set_nodelay({str(nodelay).lower()});\n' cout += f" return this->send_message_<{mt.name}>(msg, {id_});\n" - cout += f"}}\n" + cout += "}\n" if source in (SOURCE_BOTH, SOURCE_CLIENT): # Generate receive func = f"on_{snake}" @@ -797,169 +736,242 @@ def build_service_message_type(mt): if ifdef is not None: case += f"#ifdef {ifdef}\n" case += f"{mt.name} msg;\n" - case += f"msg.decode(msg_data, msg_size);\n" + case += "msg.decode(msg_data, msg_size);\n" if log: - case += f"#ifdef HAS_PROTO_MESSAGE_DUMP\n" + case += "#ifdef HAS_PROTO_MESSAGE_DUMP\n" case += f'ESP_LOGVV(TAG, "{func}: %s", msg.dump().c_str());\n' - case += f"#endif\n" + case += "#endif\n" case += f"this->{func}(msg);\n" if ifdef is not None: - case += f"#endif\n" + case += "#endif\n" case += "break;" RECEIVE_CASES[id_] = case if ifdef is not None: - hout += f"#endif\n" - cout += f"#endif\n" + hout += "#endif\n" + cout += "#endif\n" return hout, cout -hpp = file_header -hpp += """\ -#pragma once +def main(): + cwd = Path(__file__).resolve().parent + root = cwd.parent.parent / "esphome" / "components" / "api" + prot_file = root / "api.protoc" + call(["protoc", "-o", str(prot_file), "-I", str(root), "api.proto"]) + proto_content = prot_file.read_bytes() -#include "api_pb2.h" -#include "esphome/core/defines.h" + # pylint: disable-next=no-member + d = descriptor.FileDescriptorSet.FromString(proto_content) -namespace esphome { -namespace api { + file = d.file[0] + content = FILE_HEADER + content += """\ + #pragma once -""" + #include "proto.h" -cpp = file_header -cpp += """\ -#include "api_pb2_service.h" -#include "esphome/core/log.h" + namespace esphome { + namespace api { -namespace esphome { -namespace api { + """ -static const char *const TAG = "api.service"; + cpp = FILE_HEADER + cpp += """\ + #include "api_pb2.h" + #include "esphome/core/log.h" -""" + #include -hpp += f"class {class_name} : public ProtoService {{\n" -hpp += " public:\n" + namespace esphome { + namespace api { -for mt in file.message_type: - obj = build_service_message_type(mt) - if obj is None: - continue - hout, cout = obj - hpp += indent(hout) + "\n" - cpp += cout + """ -cases = list(RECEIVE_CASES.items()) -cases.sort() -hpp += " protected:\n" -hpp += f" bool read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) override;\n" -out = f"bool {class_name}::read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) {{\n" -out += f" switch (msg_type) {{\n" -for i, case in cases: - c = f"case {i}: {{\n" - c += indent(case) + "\n" - c += f"}}" - out += indent(c, " ") + "\n" -out += " default:\n" -out += " return false;\n" -out += " }\n" -out += " return true;\n" -out += "}\n" -cpp += out -hpp += "};\n" + content += "namespace enums {\n\n" -serv = file.service[0] -class_name = "APIServerConnection" -hpp += "\n" -hpp += f"class {class_name} : public {class_name}Base {{\n" -hpp += " public:\n" -hpp_protected = "" -cpp += "\n" + for enum in file.enum_type: + s, c = build_enum_type(enum) + content += s + cpp += c -m = serv.method[0] -for m in serv.method: - func = m.name - inp = m.input_type[1:] - ret = m.output_type[1:] - is_void = ret == "void" - snake = camel_to_snake(inp) - on_func = f"on_{snake}" - needs_conn = get_opt(m, pb.needs_setup_connection, True) - needs_auth = get_opt(m, pb.needs_authentication, True) + content += "\n} // namespace enums\n\n" - ifdef = ifdefs.get(inp, None) + mt = file.message_type - if ifdef is not None: - hpp += f"#ifdef {ifdef}\n" - hpp_protected += f"#ifdef {ifdef}\n" - cpp += f"#ifdef {ifdef}\n" + for m in mt: + s, c = build_message_type(m) + content += s + cpp += c - hpp_protected += f" void {on_func}(const {inp} &msg) override;\n" - hpp += f" virtual {ret} {func}(const {inp} &msg) = 0;\n" - cpp += f"void {class_name}::{on_func}(const {inp} &msg) {{\n" - body = "" - if needs_conn: - body += "if (!this->is_connection_setup()) {\n" - body += " this->on_no_setup_connection();\n" - body += " return;\n" - body += "}\n" - if needs_auth: - body += "if (!this->is_authenticated()) {\n" - body += " this->on_unauthenticated_access();\n" - body += " return;\n" - body += "}\n" + content += """\ - if is_void: - body += f"this->{func}(msg);\n" - else: - body += f"{ret} ret = this->{func}(msg);\n" - ret_snake = camel_to_snake(ret) - body += f"if (!this->send_{ret_snake}(ret)) {{\n" - body += f" this->on_fatal_error();\n" - body += "}\n" - cpp += indent(body) + "\n" + "}\n" + } // namespace api + } // namespace esphome + """ + cpp += """\ - if ifdef is not None: - hpp += f"#endif\n" - hpp_protected += f"#endif\n" - cpp += f"#endif\n" + } // namespace api + } // namespace esphome + """ -hpp += " protected:\n" -hpp += hpp_protected -hpp += "};\n" + with open(root / "api_pb2.h", "w", encoding="utf-8") as f: + f.write(content) -hpp += """\ + with open(root / "api_pb2.cpp", "w", encoding="utf-8") as f: + f.write(cpp) -} // namespace api -} // namespace esphome -""" -cpp += """\ + hpp = FILE_HEADER + hpp += """\ + #pragma once -} // namespace api -} // namespace esphome -""" + #include "api_pb2.h" + #include "esphome/core/defines.h" -with open(root / "api_pb2_service.h", "w") as f: - f.write(hpp) + namespace esphome { + namespace api { -with open(root / "api_pb2_service.cpp", "w") as f: - f.write(cpp) + """ -prot.unlink() + cpp = FILE_HEADER + cpp += """\ + #include "api_pb2_service.h" + #include "esphome/core/log.h" -try: - import clang_format + namespace esphome { + namespace api { - def exec_clang_format(path): - clang_format_path = os.path.join( - os.path.dirname(clang_format.__file__), "data", "bin", "clang-format" - ) - call([clang_format_path, "-i", path]) + static const char *const TAG = "api.service"; - exec_clang_format(root / "api_pb2_service.h") - exec_clang_format(root / "api_pb2_service.cpp") - exec_clang_format(root / "api_pb2.h") - exec_clang_format(root / "api_pb2.cpp") -except ImportError: - pass + """ + + class_name = "APIServerConnectionBase" + + hpp += f"class {class_name} : public ProtoService {{\n" + hpp += " public:\n" + + for mt in file.message_type: + obj = build_service_message_type(mt) + if obj is None: + continue + hout, cout = obj + hpp += indent(hout) + "\n" + cpp += cout + + cases = list(RECEIVE_CASES.items()) + cases.sort() + hpp += " protected:\n" + hpp += " bool read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) override;\n" + out = f"bool {class_name}::read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) {{\n" + out += " switch (msg_type) {\n" + for i, case in cases: + c = f"case {i}: {{\n" + c += indent(case) + "\n" + c += "}" + out += indent(c, " ") + "\n" + out += " default:\n" + out += " return false;\n" + out += " }\n" + out += " return true;\n" + out += "}\n" + cpp += out + hpp += "};\n" + + serv = file.service[0] + class_name = "APIServerConnection" + hpp += "\n" + hpp += f"class {class_name} : public {class_name}Base {{\n" + hpp += " public:\n" + hpp_protected = "" + cpp += "\n" + + m = serv.method[0] + for m in serv.method: + func = m.name + inp = m.input_type[1:] + ret = m.output_type[1:] + is_void = ret == "void" + snake = camel_to_snake(inp) + on_func = f"on_{snake}" + needs_conn = get_opt(m, pb.needs_setup_connection, True) + needs_auth = get_opt(m, pb.needs_authentication, True) + + ifdef = ifdefs.get(inp, None) + + if ifdef is not None: + hpp += f"#ifdef {ifdef}\n" + hpp_protected += f"#ifdef {ifdef}\n" + cpp += f"#ifdef {ifdef}\n" + + hpp_protected += f" void {on_func}(const {inp} &msg) override;\n" + hpp += f" virtual {ret} {func}(const {inp} &msg) = 0;\n" + cpp += f"void {class_name}::{on_func}(const {inp} &msg) {{\n" + body = "" + if needs_conn: + body += "if (!this->is_connection_setup()) {\n" + body += " this->on_no_setup_connection();\n" + body += " return;\n" + body += "}\n" + if needs_auth: + body += "if (!this->is_authenticated()) {\n" + body += " this->on_unauthenticated_access();\n" + body += " return;\n" + body += "}\n" + + if is_void: + body += f"this->{func}(msg);\n" + else: + body += f"{ret} ret = this->{func}(msg);\n" + ret_snake = camel_to_snake(ret) + body += f"if (!this->send_{ret_snake}(ret)) {{\n" + body += " this->on_fatal_error();\n" + body += "}\n" + cpp += indent(body) + "\n" + "}\n" + + if ifdef is not None: + hpp += "#endif\n" + hpp_protected += "#endif\n" + cpp += "#endif\n" + + hpp += " protected:\n" + hpp += hpp_protected + hpp += "};\n" + + hpp += """\ + + } // namespace api + } // namespace esphome + """ + cpp += """\ + + } // namespace api + } // namespace esphome + """ + + with open(root / "api_pb2_service.h", "w", encoding="utf-8") as f: + f.write(hpp) + + with open(root / "api_pb2_service.cpp", "w", encoding="utf-8") as f: + f.write(cpp) + + prot_file.unlink() + + try: + import clang_format + + def exec_clang_format(path): + clang_format_path = os.path.join( + os.path.dirname(clang_format.__file__), "data", "bin", "clang-format" + ) + call([clang_format_path, "-i", path]) + + exec_clang_format(root / "api_pb2_service.h") + exec_clang_format(root / "api_pb2_service.cpp") + exec_clang_format(root / "api_pb2.h") + exec_clang_format(root / "api_pb2.cpp") + except ImportError: + pass + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/script/build_language_schema.py b/script/build_language_schema.py index fb2010fe3e..fc6ccadc5f 100644 --- a/script/build_language_schema.py +++ b/script/build_language_schema.py @@ -61,6 +61,7 @@ solve_registry = [] def get_component_names(): + # pylint: disable-next=redefined-outer-name,reimported from esphome.loader import CORE_COMPONENTS_PATH component_names = ["esphome", "sensor", "esp32", "esp8266"] @@ -82,9 +83,12 @@ def load_components(): components[domain] = get_component(domain) +# pylint: disable=wrong-import-position from esphome.const import CONF_TYPE, KEY_CORE from esphome.core import CORE +# pylint: enable=wrong-import-position + CORE.data[KEY_CORE] = {} load_components() @@ -114,7 +118,7 @@ def write_file(name, obj): def delete_extra_files(keep_names): for d in os.listdir(args.output_path): - if d.endswith(".json") and not d[:-5] in keep_names: + if d.endswith(".json") and d[:-5] not in keep_names: os.remove(os.path.join(args.output_path, d)) print(f"Deleted {d}") @@ -552,11 +556,11 @@ def shrink(): s = f"{domain}.{schema_name}" if ( not s.endswith("." + S_CONFIG_SCHEMA) - and s not in referenced_schemas.keys() + and s not in referenced_schemas and not is_platform_schema(s) ): print(f"Removing {s}") - output[domain][S_SCHEMAS].pop(schema_name) + domain_schemas[S_SCHEMAS].pop(schema_name) def build_schema(): @@ -564,7 +568,7 @@ def build_schema(): # check esphome was not loaded globally (IDE auto imports) if len(ejs.extended_schemas) == 0: - raise Exception( + raise LookupError( "no data collected. Did you globally import an ESPHome component?" ) @@ -703,7 +707,7 @@ def convert(schema, config_var, path): if schema_instance is schema: assert S_CONFIG_VARS not in config_var assert S_EXTENDS not in config_var - if not S_TYPE in config_var: + if S_TYPE not in config_var: config_var[S_TYPE] = S_SCHEMA # assert config_var[S_TYPE] == S_SCHEMA @@ -765,9 +769,9 @@ def convert(schema, config_var, path): elif schema == automation.validate_potentially_and_condition: config_var[S_TYPE] = "registry" config_var["registry"] = "condition" - elif schema == cv.int_ or schema == cv.int_range: + elif schema in (cv.int_, cv.int_range): config_var[S_TYPE] = "integer" - elif schema == cv.string or schema == cv.string_strict or schema == cv.valid_name: + elif schema in (cv.string, cv.string_strict, cv.valid_name): config_var[S_TYPE] = "string" elif isinstance(schema, vol.Schema): @@ -779,6 +783,7 @@ def convert(schema, config_var, path): config_var |= pin_validators[repr_schema] config_var[S_TYPE] = "pin" + # pylint: disable-next=too-many-nested-blocks elif repr_schema in ejs.hidden_schemas: schema_type = ejs.hidden_schemas[repr_schema] @@ -869,7 +874,7 @@ def convert(schema, config_var, path): config_var["use_id_type"] = str(data.base) config_var[S_TYPE] = "use_id" else: - raise Exception("Unknown extracted schema type") + raise TypeError("Unknown extracted schema type") elif config_var.get("key") == "GeneratedID": if path.startswith("i2c/CONFIG_SCHEMA/") and path.endswith("/id"): config_var["id_type"] = { @@ -884,7 +889,7 @@ def convert(schema, config_var, path): elif path == "pins/esp32/val 1/id": config_var["id_type"] = "pin" else: - raise Exception("Cannot determine id_type for " + path) + raise TypeError("Cannot determine id_type for " + path) elif repr_schema in ejs.registry_schemas: solve_registry.append((ejs.registry_schemas[repr_schema], config_var)) @@ -948,11 +953,7 @@ def convert_keys(converted, schema, path): result["key"] = "GeneratedID" elif isinstance(k, cv.Required): result["key"] = "Required" - elif ( - isinstance(k, cv.Optional) - or isinstance(k, cv.Inclusive) - or isinstance(k, cv.Exclusive) - ): + elif isinstance(k, (cv.Optional, cv.Inclusive, cv.Exclusive)): result["key"] = "Optional" else: converted["key"] = "String" diff --git a/script/bump-version.py b/script/bump-version.py index 3e1e473c4b..a55bb65cd6 100755 --- a/script/bump-version.py +++ b/script/bump-version.py @@ -2,7 +2,6 @@ import argparse import re -import subprocess from dataclasses import dataclass import sys @@ -40,12 +39,12 @@ class Version: def sub(path, pattern, repl, expected_count=1): - with open(path) as fh: + with open(path, encoding="utf-8") as fh: content = fh.read() content, count = re.subn(pattern, repl, content, flags=re.MULTILINE) if expected_count is not None: assert count == expected_count, f"Pattern {pattern} replacement failed!" - with open(path, "w") as fh: + with open(path, "w", encoding="utf-8") as fh: fh.write(content) diff --git a/script/ci-custom.py b/script/ci-custom.py index cc9bdcadbb..41ce030d48 100755 --- a/script/ci-custom.py +++ b/script/ci-custom.py @@ -1,10 +1,8 @@ #!/usr/bin/env python3 -from helpers import styled, print_error_for_file, git_ls_files, filter_changed import argparse import codecs import collections -import colorama import fnmatch import functools import os.path @@ -12,6 +10,9 @@ import re import sys import time +import colorama +from helpers import filter_changed, git_ls_files, print_error_for_file, styled + sys.path.append(os.path.dirname(__file__)) @@ -30,31 +31,6 @@ def find_all(a_str, sub): column += len(sub) -colorama.init() - -parser = argparse.ArgumentParser() -parser.add_argument( - "files", nargs="*", default=[], help="files to be processed (regex on path)" -) -parser.add_argument( - "-c", "--changed", action="store_true", help="Only run on changed files" -) -parser.add_argument( - "--print-slowest", action="store_true", help="Print the slowest checks" -) -args = parser.parse_args() - -EXECUTABLE_BIT = git_ls_files() -files = list(EXECUTABLE_BIT.keys()) -# Match against re -file_name_re = re.compile("|".join(args.files)) -files = [p for p in files if file_name_re.search(p)] - -if args.changed: - files = filter_changed(files) - -files.sort() - file_types = ( ".h", ".c", @@ -86,6 +62,30 @@ ignore_types = (".ico", ".png", ".woff", ".woff2", "") LINT_FILE_CHECKS = [] LINT_CONTENT_CHECKS = [] LINT_POST_CHECKS = [] +EXECUTABLE_BIT = {} + +errors = collections.defaultdict(list) + + +def add_errors(fname, errs): + if not isinstance(errs, list): + errs = [errs] + for err in errs: + if err is None: + continue + try: + lineno, col, msg = err + except ValueError: + lineno = 1 + col = 1 + msg = err + if not isinstance(msg, str): + raise ValueError("Error is not instance of string!") + if not isinstance(lineno, int): + raise ValueError("Line number is not an int!") + if not isinstance(col, int): + raise ValueError("Column number is not an int!") + errors[fname].append((lineno, col, msg)) def run_check(lint_obj, fname, *args): @@ -155,7 +155,7 @@ def lint_re_check(regex, **kwargs): def decorator(func): @functools.wraps(func) def new_func(fname, content): - errors = [] + errs = [] for match in prog.finditer(content): if "NOLINT" in match.group(0): continue @@ -165,8 +165,8 @@ def lint_re_check(regex, **kwargs): err = func(fname, match) if err is None: continue - errors.append((lineno, col + 1, err)) - return errors + errs.append((lineno, col + 1, err)) + return errs return decor(new_func) @@ -182,13 +182,13 @@ def lint_content_find_check(find, only_first=False, **kwargs): find_ = find if callable(find): find_ = find(fname, content) - errors = [] + errs = [] for line, col in find_all(content, find_): err = func(fname) - errors.append((line + 1, col + 1, err)) + errs.append((line + 1, col + 1, err)) if only_first: break - return errors + return errs return decor(new_func) @@ -235,8 +235,8 @@ def lint_executable_bit(fname): ex = EXECUTABLE_BIT[fname] if ex != 100644: return ( - "File has invalid executable bit {}. If running from a windows machine please " - "see disabling executable bit in git.".format(ex) + f"File has invalid executable bit {ex}. If running from a windows machine please " + "see disabling executable bit in git." ) return None @@ -285,8 +285,8 @@ def lint_no_defines(fname, match): s = highlight(f"static const uint8_t {match.group(1)} = {match.group(2)};") return ( "#define macros for integer constants are not allowed, please use " - "{} style instead (replace uint8_t with the appropriate " - "datatype). See also Google style guide.".format(s) + f"{s} style instead (replace uint8_t with the appropriate " + "datatype). See also Google style guide." ) @@ -296,11 +296,11 @@ def lint_no_long_delays(fname, match): if duration_ms < 50: return None return ( - "{} - long calls to delay() are not allowed in ESPHome because everything executes " - "in one thread. Calling delay() will block the main thread and slow down ESPHome.\n" + f"{highlight(match.group(0).strip())} - long calls to delay() are not allowed " + "in ESPHome because everything executes in one thread. Calling delay() will " + "block the main thread and slow down ESPHome.\n" "If there's no way to work around the delay() and it doesn't execute often, please add " "a '// NOLINT' comment to the line." - "".format(highlight(match.group(0).strip())) ) @@ -311,28 +311,28 @@ def lint_const_ordered(fname, content): Reason: Otherwise people add it to the end, and then that results in merge conflicts. """ lines = content.splitlines() - errors = [] + errs = [] for start in ["CONF_", "ICON_", "UNIT_"]: matching = [ (i + 1, line) for i, line in enumerate(lines) if line.startswith(start) ] ordered = list(sorted(matching, key=lambda x: x[1].replace("_", " "))) ordered = [(mi, ol) for (mi, _), (_, ol) in zip(matching, ordered)] - for (mi, ml), (oi, ol) in zip(matching, ordered): - if ml == ol: + for (mi, mline), (_, ol) in zip(matching, ordered): + if mline == ol: continue - target = next(i for i, l in ordered if l == ml) - target_text = next(l for i, l in matching if target == i) - errors.append( + target = next(i for i, line in ordered if line == mline) + target_text = next(line for i, line in matching if target == i) + errs.append( ( mi, 1, - f"Constant {highlight(ml)} is not ordered, please make sure all " + f"Constant {highlight(mline)} is not ordered, please make sure all " f"constants are ordered. See line {mi} (should go to line {target}, " f"{target_text})", ) ) - return errors + return errs @lint_re_check(r'^\s*CONF_([A-Z_0-9a-z]+)\s+=\s+[\'"](.*?)[\'"]\s*?$', include=["*.py"]) @@ -344,15 +344,14 @@ def lint_conf_matches(fname, match): if const_norm == value_norm: return None return ( - "Constant {} does not match value {}! Please make sure the constant's name matches its " - "value!" - "".format(highlight("CONF_" + const), highlight(value)) + f"Constant {highlight('CONF_' + const)} does not match value {highlight(value)}! " + "Please make sure the constant's name matches its value!" ) CONF_RE = r'^(CONF_[a-zA-Z0-9_]+)\s*=\s*[\'"].*?[\'"]\s*?$' -with codecs.open("esphome/const.py", "r", encoding="utf-8") as f_handle: - constants_content = f_handle.read() +with codecs.open("esphome/const.py", "r", encoding="utf-8") as const_f_handle: + constants_content = const_f_handle.read() CONSTANTS = [m.group(1) for m in re.finditer(CONF_RE, constants_content, re.MULTILINE)] CONSTANTS_USES = collections.defaultdict(list) @@ -365,8 +364,8 @@ def lint_conf_from_const_py(fname, match): CONSTANTS_USES[name].append(fname) return None return ( - "Constant {} has already been defined in const.py - please import the constant from " - "const.py directly.".format(highlight(name)) + f"Constant {highlight(name)} has already been defined in const.py - " + "please import the constant from const.py directly." ) @@ -473,16 +472,15 @@ def lint_no_byte_datatype(fname, match): @lint_post_check def lint_constants_usage(): - errors = [] + errs = [] for constant, uses in CONSTANTS_USES.items(): if len(uses) < 4: continue - errors.append( - "Constant {} is defined in {} files. Please move all definitions of the " - "constant to const.py (Uses: {})" - "".format(highlight(constant), len(uses), ", ".join(uses)) + errs.append( + f"Constant {highlight(constant)} is defined in {len(uses)} files. Please move all definitions of the " + f"constant to const.py (Uses: {', '.join(uses)})" ) - return errors + return errs def relative_cpp_search_text(fname, content): @@ -553,7 +551,7 @@ def lint_namespace(fname, content): return ( "Invalid namespace found in C++ file. All integration C++ files should put all " "functions in a separate namespace that matches the integration's name. " - "Please make sure the file contains {}".format(highlight(search)) + f"Please make sure the file contains {highlight(search)}" ) @@ -639,66 +637,73 @@ def lint_log_in_header(fname): ) -errors = collections.defaultdict(list) +def main(): + colorama.init() + parser = argparse.ArgumentParser() + parser.add_argument( + "files", nargs="*", default=[], help="files to be processed (regex on path)" + ) + parser.add_argument( + "-c", "--changed", action="store_true", help="Only run on changed files" + ) + parser.add_argument( + "--print-slowest", action="store_true", help="Print the slowest checks" + ) + args = parser.parse_args() -def add_errors(fname, errs): - if not isinstance(errs, list): - errs = [errs] - for err in errs: - if err is None: + global EXECUTABLE_BIT + EXECUTABLE_BIT = git_ls_files() + files = list(EXECUTABLE_BIT.keys()) + # Match against re + file_name_re = re.compile("|".join(args.files)) + files = [p for p in files if file_name_re.search(p)] + + if args.changed: + files = filter_changed(files) + + files.sort() + + for fname in files: + _, ext = os.path.splitext(fname) + run_checks(LINT_FILE_CHECKS, fname, fname) + if ext in ignore_types: continue try: - lineno, col, msg = err - except ValueError: - lineno = 1 - col = 1 - msg = err - if not isinstance(msg, str): - raise ValueError("Error is not instance of string!") - if not isinstance(lineno, int): - raise ValueError("Line number is not an int!") - if not isinstance(col, int): - raise ValueError("Column number is not an int!") - errors[fname].append((lineno, col, msg)) + with codecs.open(fname, "r", encoding="utf-8") as f_handle: + content = f_handle.read() + except UnicodeDecodeError: + add_errors( + fname, + "File is not readable as UTF-8. Please set your editor to UTF-8 mode.", + ) + continue + run_checks(LINT_CONTENT_CHECKS, fname, fname, content) + run_checks(LINT_POST_CHECKS, "POST") -for fname in files: - _, ext = os.path.splitext(fname) - run_checks(LINT_FILE_CHECKS, fname, fname) - if ext in ignore_types: - continue - try: - with codecs.open(fname, "r", encoding="utf-8") as f_handle: - content = f_handle.read() - except UnicodeDecodeError: - add_errors( - fname, - "File is not readable as UTF-8. Please set your editor to UTF-8 mode.", + for f, errs in sorted(errors.items()): + bold = functools.partial(styled, colorama.Style.BRIGHT) + bold_red = functools.partial(styled, (colorama.Style.BRIGHT, colorama.Fore.RED)) + err_str = ( + f"{bold(f'{f}:{lineno}:{col}:')} {bold_red('lint:')} {msg}\n" + for lineno, col, msg in errs ) - continue - run_checks(LINT_CONTENT_CHECKS, fname, fname, content) + print_error_for_file(f, "\n".join(err_str)) -run_checks(LINT_POST_CHECKS, "POST") + if args.print_slowest: + lint_times = [] + for lint in LINT_FILE_CHECKS + LINT_CONTENT_CHECKS + LINT_POST_CHECKS: + durations = lint.get("durations", []) + lint_times.append((sum(durations), len(durations), lint["func"].__name__)) + lint_times.sort(key=lambda x: -x[0]) + for i in range(min(len(lint_times), 10)): + dur, invocations, name = lint_times[i] + print(f" - '{name}' took {dur:.2f}s total (ran on {invocations} files)") + print(f"Total time measured: {sum(x[0] for x in lint_times):.2f}s") -for f, errs in sorted(errors.items()): - bold = functools.partial(styled, colorama.Style.BRIGHT) - bold_red = functools.partial(styled, (colorama.Style.BRIGHT, colorama.Fore.RED)) - err_str = ( - f"{bold(f'{f}:{lineno}:{col}:')} {bold_red('lint:')} {msg}\n" - for lineno, col, msg in errs - ) - print_error_for_file(f, "\n".join(err_str)) + return len(errors) -if args.print_slowest: - lint_times = [] - for lint in LINT_FILE_CHECKS + LINT_CONTENT_CHECKS + LINT_POST_CHECKS: - durations = lint.get("durations", []) - lint_times.append((sum(durations), len(durations), lint["func"].__name__)) - lint_times.sort(key=lambda x: -x[0]) - for i in range(min(len(lint_times), 10)): - dur, invocations, name = lint_times[i] - print(f" - '{name}' took {dur:.2f}s total (ran on {invocations} files)") - print(f"Total time measured: {sum(x[0] for x in lint_times):.2f}s") -sys.exit(len(errors)) +if __name__ == "__main__": + sys.exit(main()) diff --git a/script/helpers.py b/script/helpers.py index c042362aeb..b1908e9875 100644 --- a/script/helpers.py +++ b/script/helpers.py @@ -1,10 +1,11 @@ -import colorama +import json import os.path import re import subprocess -import json from pathlib import Path +import colorama + root_path = os.path.abspath(os.path.normpath(os.path.join(__file__, "..", ".."))) basepath = os.path.join(root_path, "esphome") temp_folder = os.path.join(root_path, ".temp") @@ -44,7 +45,7 @@ def build_all_include(): content = "\n".join(headers) p = Path(temp_header_file) p.parent.mkdir(exist_ok=True) - p.write_text(content) + p.write_text(content, encoding="utf-8") def walk_files(path): @@ -54,14 +55,14 @@ def walk_files(path): def get_output(*args): - proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - output, err = proc.communicate() + with subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as proc: + output, _ = proc.communicate() return output.decode("utf-8") def get_err(*args): - proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) - output, err = proc.communicate() + with subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as proc: + _, err = proc.communicate() return err.decode("utf-8") @@ -78,7 +79,7 @@ def changed_files(): merge_base = splitlines_no_ends(get_output(*command))[0] break # pylint: disable=bare-except - except: + except: # noqa: E722 pass else: raise ValueError("Git not configured") @@ -103,7 +104,7 @@ def filter_changed(files): def filter_grep(files, value): matched = [] for file in files: - with open(file) as handle: + with open(file, encoding="utf-8") as handle: contents = handle.read() if value in contents: matched.append(file) @@ -114,8 +115,8 @@ def git_ls_files(patterns=None): command = ["git", "ls-files", "-s"] if patterns is not None: command.extend(patterns) - proc = subprocess.Popen(command, stdout=subprocess.PIPE) - output, err = proc.communicate() + with subprocess.Popen(command, stdout=subprocess.PIPE) as proc: + output, _ = proc.communicate() lines = [x.split() for x in output.decode("utf-8").splitlines()] return {s[3].strip(): int(s[0]) for s in lines} diff --git a/script/sync-device_class.py b/script/sync-device_class.py index ae6f4be0c8..8f91b97997 100755 --- a/script/sync-device_class.py +++ b/script/sync-device_class.py @@ -2,6 +2,7 @@ import re +# pylint: disable=import-error from homeassistant.components.binary_sensor import BinarySensorDeviceClass from homeassistant.components.button import ButtonDeviceClass from homeassistant.components.cover import CoverDeviceClass @@ -9,6 +10,8 @@ from homeassistant.components.number import NumberDeviceClass from homeassistant.components.sensor import SensorDeviceClass from homeassistant.components.switch import SwitchDeviceClass +# pylint: enable=import-error + BLOCKLIST = ( # requires special support on HA side "enum", @@ -25,10 +28,10 @@ DOMAINS = { def sub(path, pattern, repl): - with open(path) as handle: + with open(path, encoding="utf-8") as handle: content = handle.read() content = re.sub(pattern, repl, content, flags=re.MULTILINE) - with open(path, "w") as handle: + with open(path, "w", encoding="utf-8") as handle: handle.write(content)