Lint the script folder files (#5991)

This commit is contained in:
Jesse Hills 2023-12-22 20:03:47 +13:00 committed by GitHub
parent 676ae6b26e
commit d2d0058386
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 398 additions and 377 deletions

View File

@ -17,28 +17,22 @@ then run this script with python3 and the files
will be generated, they still need to be formatted
"""
import re
import os
import re
import sys
from abc import ABC, abstractmethod
from pathlib import Path
from textwrap import dedent
from subprocess import call
from textwrap import dedent
# Generate with
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
import aioesphomeapi.api_options_pb2 as pb
import google.protobuf.descriptor_pb2 as descriptor
file_header = "// This file was automatically generated with a tool.\n"
file_header += "// See scripts/api_protobuf/api_protobuf.py\n"
cwd = Path(__file__).resolve().parent
root = cwd.parent.parent / "esphome" / "components" / "api"
prot = root / "api.protoc"
call(["protoc", "-o", str(prot), "-I", str(root), "api.proto"])
content = prot.read_bytes()
d = descriptor.FileDescriptorSet.FromString(content)
FILE_HEADER = """// This file was automatically generated with a tool.
// See scripts/api_protobuf/api_protobuf.py
"""
def indent_list(text, padding=" "):
@ -64,7 +58,7 @@ def camel_to_snake(name):
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
class TypeInfo:
class TypeInfo(ABC):
def __init__(self, field):
self._field = field
@ -186,10 +180,12 @@ class TypeInfo:
def dump_content(self):
o = f'out.append(" {self.name}: ");\n'
o += self.dump(f"this->{self.field_name}") + "\n"
o += f'out.append("\\n");\n'
o += 'out.append("\\n");\n'
return o
dump = None
@abstractmethod
def dump(self, name: str):
pass
TYPE_INFO = {}
@ -212,7 +208,7 @@ class DoubleType(TypeInfo):
def dump(self, name):
o = f'sprintf(buffer, "%g", {name});\n'
o += f"out.append(buffer);"
o += "out.append(buffer);"
return o
@ -225,7 +221,7 @@ class FloatType(TypeInfo):
def dump(self, name):
o = f'sprintf(buffer, "%g", {name});\n'
o += f"out.append(buffer);"
o += "out.append(buffer);"
return o
@ -238,7 +234,7 @@ class Int64Type(TypeInfo):
def dump(self, name):
o = f'sprintf(buffer, "%lld", {name});\n'
o += f"out.append(buffer);"
o += "out.append(buffer);"
return o
@ -251,7 +247,7 @@ class UInt64Type(TypeInfo):
def dump(self, name):
o = f'sprintf(buffer, "%llu", {name});\n'
o += f"out.append(buffer);"
o += "out.append(buffer);"
return o
@ -264,7 +260,7 @@ class Int32Type(TypeInfo):
def dump(self, name):
o = f'sprintf(buffer, "%" PRId32, {name});\n'
o += f"out.append(buffer);"
o += "out.append(buffer);"
return o
@ -277,7 +273,7 @@ class Fixed64Type(TypeInfo):
def dump(self, name):
o = f'sprintf(buffer, "%llu", {name});\n'
o += f"out.append(buffer);"
o += "out.append(buffer);"
return o
@ -290,7 +286,7 @@ class Fixed32Type(TypeInfo):
def dump(self, name):
o = f'sprintf(buffer, "%" PRIu32, {name});\n'
o += f"out.append(buffer);"
o += "out.append(buffer);"
return o
@ -372,7 +368,7 @@ class UInt32Type(TypeInfo):
def dump(self, name):
o = f'sprintf(buffer, "%" PRIu32, {name});\n'
o += f"out.append(buffer);"
o += "out.append(buffer);"
return o
@ -406,7 +402,7 @@ class SFixed32Type(TypeInfo):
def dump(self, name):
o = f'sprintf(buffer, "%" PRId32, {name});\n'
o += f"out.append(buffer);"
o += "out.append(buffer);"
return o
@ -419,7 +415,7 @@ class SFixed64Type(TypeInfo):
def dump(self, name):
o = f'sprintf(buffer, "%lld", {name});\n'
o += f"out.append(buffer);"
o += "out.append(buffer);"
return o
@ -432,7 +428,7 @@ class SInt32Type(TypeInfo):
def dump(self, name):
o = f'sprintf(buffer, "%" PRId32, {name});\n'
o += f"out.append(buffer);"
o += "out.append(buffer);"
return o
@ -445,7 +441,7 @@ class SInt64Type(TypeInfo):
def dump(self, name):
o = f'sprintf(buffer, "%lld", {name});\n'
o += f"out.append(buffer);"
o += "out.append(buffer);"
return o
@ -527,7 +523,7 @@ class RepeatedTypeInfo(TypeInfo):
def encode_content(self):
o = f"for (auto {'' if self._ti_is_bool else '&'}it : this->{self.field_name}) {{\n"
o += f" buffer.{self._ti.encode_func}({self.number}, it, true);\n"
o += f"}}"
o += "}"
return o
@property
@ -535,10 +531,13 @@ class RepeatedTypeInfo(TypeInfo):
o = f'for (const auto {"" if self._ti_is_bool else "&"}it : this->{self.field_name}) {{\n'
o += f' out.append(" {self.name}: ");\n'
o += indent(self._ti.dump("it")) + "\n"
o += f' out.append("\\n");\n'
o += f"}}\n"
o += ' out.append("\\n");\n'
o += "}\n"
return o
def dump(self, _: str):
pass
def build_enum_type(desc):
name = desc.name
@ -547,17 +546,17 @@ def build_enum_type(desc):
out += f" {v.name} = {v.number},\n"
out += "};\n"
cpp = f"#ifdef HAS_PROTO_MESSAGE_DUMP\n"
cpp = "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
cpp += f"template<> const char *proto_enum_to_string<enums::{name}>(enums::{name} value) {{\n"
cpp += f" switch (value) {{\n"
cpp += " switch (value) {\n"
for v in desc.value:
cpp += f" case enums::{v.name}:\n"
cpp += f' return "{v.name}";\n'
cpp += f" default:\n"
cpp += f' return "UNKNOWN";\n'
cpp += f" }}\n"
cpp += f"}}\n"
cpp += f"#endif\n"
cpp += " default:\n"
cpp += ' return "UNKNOWN";\n'
cpp += " }\n"
cpp += "}\n"
cpp += "#endif\n"
return out, cpp
@ -652,10 +651,10 @@ def build_message_type(desc):
o += f" {dump[0]} "
else:
o += "\n"
o += f" __attribute__((unused)) char buffer[64];\n"
o += " __attribute__((unused)) char buffer[64];\n"
o += f' out.append("{desc.name} {{\\n");\n'
o += indent("\n".join(dump)) + "\n"
o += f' out.append("}}");\n'
o += ' out.append("}");\n'
else:
o2 = f'out.append("{desc.name} {{}}");'
if len(o) + len(o2) + 3 < 120:
@ -664,9 +663,9 @@ def build_message_type(desc):
o += "\n"
o += f" {o2}\n"
o += "}\n"
cpp += f"#ifdef HAS_PROTO_MESSAGE_DUMP\n"
cpp += "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
cpp += o
cpp += f"#endif\n"
cpp += "#endif\n"
prot = "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
prot += "void dump_to(std::string &out) const override;\n"
prot += "#endif\n"
@ -684,71 +683,12 @@ def build_message_type(desc):
return out, cpp
file = d.file[0]
content = file_header
content += """\
#pragma once
#include "proto.h"
namespace esphome {
namespace api {
"""
cpp = file_header
cpp += """\
#include "api_pb2.h"
#include "esphome/core/log.h"
#include <cinttypes>
namespace esphome {
namespace api {
"""
content += "namespace enums {\n\n"
for enum in file.enum_type:
s, c = build_enum_type(enum)
content += s
cpp += c
content += "\n} // namespace enums\n\n"
mt = file.message_type
for m in mt:
s, c = build_message_type(m)
content += s
cpp += c
content += """\
} // namespace api
} // namespace esphome
"""
cpp += """\
} // namespace api
} // namespace esphome
"""
with open(root / "api_pb2.h", "w") as f:
f.write(content)
with open(root / "api_pb2.cpp", "w") as f:
f.write(cpp)
SOURCE_BOTH = 0
SOURCE_SERVER = 1
SOURCE_CLIENT = 2
RECEIVE_CASES = {}
class_name = "APIServerConnectionBase"
ifdefs = {}
@ -768,7 +708,6 @@ def build_service_message_type(mt):
ifdef = get_opt(mt, pb.ifdef)
log = get_opt(mt, pb.log, True)
nodelay = get_opt(mt, pb.no_delay, False)
hout = ""
cout = ""
@ -781,14 +720,14 @@ def build_service_message_type(mt):
# Generate send
func = f"send_{snake}"
hout += f"bool {func}(const {mt.name} &msg);\n"
cout += f"bool {class_name}::{func}(const {mt.name} &msg) {{\n"
cout += f"bool APIServerConnectionBase::{func}(const {mt.name} &msg) {{\n"
if log:
cout += f"#ifdef HAS_PROTO_MESSAGE_DUMP\n"
cout += "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
cout += f' ESP_LOGVV(TAG, "{func}: %s", msg.dump().c_str());\n'
cout += f"#endif\n"
cout += "#endif\n"
# cout += f' this->set_nodelay({str(nodelay).lower()});\n'
cout += f" return this->send_message_<{mt.name}>(msg, {id_});\n"
cout += f"}}\n"
cout += "}\n"
if source in (SOURCE_BOTH, SOURCE_CLIENT):
# Generate receive
func = f"on_{snake}"
@ -797,169 +736,242 @@ def build_service_message_type(mt):
if ifdef is not None:
case += f"#ifdef {ifdef}\n"
case += f"{mt.name} msg;\n"
case += f"msg.decode(msg_data, msg_size);\n"
case += "msg.decode(msg_data, msg_size);\n"
if log:
case += f"#ifdef HAS_PROTO_MESSAGE_DUMP\n"
case += "#ifdef HAS_PROTO_MESSAGE_DUMP\n"
case += f'ESP_LOGVV(TAG, "{func}: %s", msg.dump().c_str());\n'
case += f"#endif\n"
case += "#endif\n"
case += f"this->{func}(msg);\n"
if ifdef is not None:
case += f"#endif\n"
case += "#endif\n"
case += "break;"
RECEIVE_CASES[id_] = case
if ifdef is not None:
hout += f"#endif\n"
cout += f"#endif\n"
hout += "#endif\n"
cout += "#endif\n"
return hout, cout
hpp = file_header
hpp += """\
#pragma once
def main():
cwd = Path(__file__).resolve().parent
root = cwd.parent.parent / "esphome" / "components" / "api"
prot_file = root / "api.protoc"
call(["protoc", "-o", str(prot_file), "-I", str(root), "api.proto"])
proto_content = prot_file.read_bytes()
#include "api_pb2.h"
#include "esphome/core/defines.h"
# pylint: disable-next=no-member
d = descriptor.FileDescriptorSet.FromString(proto_content)
namespace esphome {
namespace api {
file = d.file[0]
content = FILE_HEADER
content += """\
#pragma once
"""
#include "proto.h"
cpp = file_header
cpp += """\
#include "api_pb2_service.h"
#include "esphome/core/log.h"
namespace esphome {
namespace api {
namespace esphome {
namespace api {
"""
static const char *const TAG = "api.service";
cpp = FILE_HEADER
cpp += """\
#include "api_pb2.h"
#include "esphome/core/log.h"
"""
#include <cinttypes>
hpp += f"class {class_name} : public ProtoService {{\n"
hpp += " public:\n"
namespace esphome {
namespace api {
for mt in file.message_type:
obj = build_service_message_type(mt)
if obj is None:
continue
hout, cout = obj
hpp += indent(hout) + "\n"
cpp += cout
"""
cases = list(RECEIVE_CASES.items())
cases.sort()
hpp += " protected:\n"
hpp += f" bool read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) override;\n"
out = f"bool {class_name}::read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) {{\n"
out += f" switch (msg_type) {{\n"
for i, case in cases:
c = f"case {i}: {{\n"
c += indent(case) + "\n"
c += f"}}"
out += indent(c, " ") + "\n"
out += " default:\n"
out += " return false;\n"
out += " }\n"
out += " return true;\n"
out += "}\n"
cpp += out
hpp += "};\n"
content += "namespace enums {\n\n"
serv = file.service[0]
class_name = "APIServerConnection"
hpp += "\n"
hpp += f"class {class_name} : public {class_name}Base {{\n"
hpp += " public:\n"
hpp_protected = ""
cpp += "\n"
for enum in file.enum_type:
s, c = build_enum_type(enum)
content += s
cpp += c
m = serv.method[0]
for m in serv.method:
func = m.name
inp = m.input_type[1:]
ret = m.output_type[1:]
is_void = ret == "void"
snake = camel_to_snake(inp)
on_func = f"on_{snake}"
needs_conn = get_opt(m, pb.needs_setup_connection, True)
needs_auth = get_opt(m, pb.needs_authentication, True)
content += "\n} // namespace enums\n\n"
ifdef = ifdefs.get(inp, None)
mt = file.message_type
if ifdef is not None:
hpp += f"#ifdef {ifdef}\n"
hpp_protected += f"#ifdef {ifdef}\n"
cpp += f"#ifdef {ifdef}\n"
for m in mt:
s, c = build_message_type(m)
content += s
cpp += c
hpp_protected += f" void {on_func}(const {inp} &msg) override;\n"
hpp += f" virtual {ret} {func}(const {inp} &msg) = 0;\n"
cpp += f"void {class_name}::{on_func}(const {inp} &msg) {{\n"
body = ""
if needs_conn:
body += "if (!this->is_connection_setup()) {\n"
body += " this->on_no_setup_connection();\n"
body += " return;\n"
body += "}\n"
if needs_auth:
body += "if (!this->is_authenticated()) {\n"
body += " this->on_unauthenticated_access();\n"
body += " return;\n"
body += "}\n"
content += """\
if is_void:
body += f"this->{func}(msg);\n"
else:
body += f"{ret} ret = this->{func}(msg);\n"
ret_snake = camel_to_snake(ret)
body += f"if (!this->send_{ret_snake}(ret)) {{\n"
body += f" this->on_fatal_error();\n"
body += "}\n"
cpp += indent(body) + "\n" + "}\n"
} // namespace api
} // namespace esphome
"""
cpp += """\
if ifdef is not None:
hpp += f"#endif\n"
hpp_protected += f"#endif\n"
cpp += f"#endif\n"
} // namespace api
} // namespace esphome
"""
hpp += " protected:\n"
hpp += hpp_protected
hpp += "};\n"
with open(root / "api_pb2.h", "w", encoding="utf-8") as f:
f.write(content)
hpp += """\
with open(root / "api_pb2.cpp", "w", encoding="utf-8") as f:
f.write(cpp)
} // namespace api
} // namespace esphome
"""
cpp += """\
hpp = FILE_HEADER
hpp += """\
#pragma once
} // namespace api
} // namespace esphome
"""
#include "api_pb2.h"
#include "esphome/core/defines.h"
with open(root / "api_pb2_service.h", "w") as f:
f.write(hpp)
namespace esphome {
namespace api {
with open(root / "api_pb2_service.cpp", "w") as f:
f.write(cpp)
"""
prot.unlink()
cpp = FILE_HEADER
cpp += """\
#include "api_pb2_service.h"
#include "esphome/core/log.h"
try:
import clang_format
namespace esphome {
namespace api {
def exec_clang_format(path):
clang_format_path = os.path.join(
os.path.dirname(clang_format.__file__), "data", "bin", "clang-format"
)
call([clang_format_path, "-i", path])
static const char *const TAG = "api.service";
exec_clang_format(root / "api_pb2_service.h")
exec_clang_format(root / "api_pb2_service.cpp")
exec_clang_format(root / "api_pb2.h")
exec_clang_format(root / "api_pb2.cpp")
except ImportError:
pass
"""
class_name = "APIServerConnectionBase"
hpp += f"class {class_name} : public ProtoService {{\n"
hpp += " public:\n"
for mt in file.message_type:
obj = build_service_message_type(mt)
if obj is None:
continue
hout, cout = obj
hpp += indent(hout) + "\n"
cpp += cout
cases = list(RECEIVE_CASES.items())
cases.sort()
hpp += " protected:\n"
hpp += " bool read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) override;\n"
out = f"bool {class_name}::read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) {{\n"
out += " switch (msg_type) {\n"
for i, case in cases:
c = f"case {i}: {{\n"
c += indent(case) + "\n"
c += "}"
out += indent(c, " ") + "\n"
out += " default:\n"
out += " return false;\n"
out += " }\n"
out += " return true;\n"
out += "}\n"
cpp += out
hpp += "};\n"
serv = file.service[0]
class_name = "APIServerConnection"
hpp += "\n"
hpp += f"class {class_name} : public {class_name}Base {{\n"
hpp += " public:\n"
hpp_protected = ""
cpp += "\n"
m = serv.method[0]
for m in serv.method:
func = m.name
inp = m.input_type[1:]
ret = m.output_type[1:]
is_void = ret == "void"
snake = camel_to_snake(inp)
on_func = f"on_{snake}"
needs_conn = get_opt(m, pb.needs_setup_connection, True)
needs_auth = get_opt(m, pb.needs_authentication, True)
ifdef = ifdefs.get(inp, None)
if ifdef is not None:
hpp += f"#ifdef {ifdef}\n"
hpp_protected += f"#ifdef {ifdef}\n"
cpp += f"#ifdef {ifdef}\n"
hpp_protected += f" void {on_func}(const {inp} &msg) override;\n"
hpp += f" virtual {ret} {func}(const {inp} &msg) = 0;\n"
cpp += f"void {class_name}::{on_func}(const {inp} &msg) {{\n"
body = ""
if needs_conn:
body += "if (!this->is_connection_setup()) {\n"
body += " this->on_no_setup_connection();\n"
body += " return;\n"
body += "}\n"
if needs_auth:
body += "if (!this->is_authenticated()) {\n"
body += " this->on_unauthenticated_access();\n"
body += " return;\n"
body += "}\n"
if is_void:
body += f"this->{func}(msg);\n"
else:
body += f"{ret} ret = this->{func}(msg);\n"
ret_snake = camel_to_snake(ret)
body += f"if (!this->send_{ret_snake}(ret)) {{\n"
body += " this->on_fatal_error();\n"
body += "}\n"
cpp += indent(body) + "\n" + "}\n"
if ifdef is not None:
hpp += "#endif\n"
hpp_protected += "#endif\n"
cpp += "#endif\n"
hpp += " protected:\n"
hpp += hpp_protected
hpp += "};\n"
hpp += """\
} // namespace api
} // namespace esphome
"""
cpp += """\
} // namespace api
} // namespace esphome
"""
with open(root / "api_pb2_service.h", "w", encoding="utf-8") as f:
f.write(hpp)
with open(root / "api_pb2_service.cpp", "w", encoding="utf-8") as f:
f.write(cpp)
prot_file.unlink()
try:
import clang_format
def exec_clang_format(path):
clang_format_path = os.path.join(
os.path.dirname(clang_format.__file__), "data", "bin", "clang-format"
)
call([clang_format_path, "-i", path])
exec_clang_format(root / "api_pb2_service.h")
exec_clang_format(root / "api_pb2_service.cpp")
exec_clang_format(root / "api_pb2.h")
exec_clang_format(root / "api_pb2.cpp")
except ImportError:
pass
if __name__ == "__main__":
sys.exit(main())

View File

@ -61,6 +61,7 @@ solve_registry = []
def get_component_names():
# pylint: disable-next=redefined-outer-name,reimported
from esphome.loader import CORE_COMPONENTS_PATH
component_names = ["esphome", "sensor", "esp32", "esp8266"]
@ -82,9 +83,12 @@ def load_components():
components[domain] = get_component(domain)
# pylint: disable=wrong-import-position
from esphome.const import CONF_TYPE, KEY_CORE
from esphome.core import CORE
# pylint: enable=wrong-import-position
CORE.data[KEY_CORE] = {}
load_components()
@ -114,7 +118,7 @@ def write_file(name, obj):
def delete_extra_files(keep_names):
for d in os.listdir(args.output_path):
if d.endswith(".json") and not d[:-5] in keep_names:
if d.endswith(".json") and d[:-5] not in keep_names:
os.remove(os.path.join(args.output_path, d))
print(f"Deleted {d}")
@ -552,11 +556,11 @@ def shrink():
s = f"{domain}.{schema_name}"
if (
not s.endswith("." + S_CONFIG_SCHEMA)
and s not in referenced_schemas.keys()
and s not in referenced_schemas
and not is_platform_schema(s)
):
print(f"Removing {s}")
output[domain][S_SCHEMAS].pop(schema_name)
domain_schemas[S_SCHEMAS].pop(schema_name)
def build_schema():
@ -564,7 +568,7 @@ def build_schema():
# check esphome was not loaded globally (IDE auto imports)
if len(ejs.extended_schemas) == 0:
raise Exception(
raise LookupError(
"no data collected. Did you globally import an ESPHome component?"
)
@ -703,7 +707,7 @@ def convert(schema, config_var, path):
if schema_instance is schema:
assert S_CONFIG_VARS not in config_var
assert S_EXTENDS not in config_var
if not S_TYPE in config_var:
if S_TYPE not in config_var:
config_var[S_TYPE] = S_SCHEMA
# assert config_var[S_TYPE] == S_SCHEMA
@ -765,9 +769,9 @@ def convert(schema, config_var, path):
elif schema == automation.validate_potentially_and_condition:
config_var[S_TYPE] = "registry"
config_var["registry"] = "condition"
elif schema == cv.int_ or schema == cv.int_range:
elif schema in (cv.int_, cv.int_range):
config_var[S_TYPE] = "integer"
elif schema == cv.string or schema == cv.string_strict or schema == cv.valid_name:
elif schema in (cv.string, cv.string_strict, cv.valid_name):
config_var[S_TYPE] = "string"
elif isinstance(schema, vol.Schema):
@ -779,6 +783,7 @@ def convert(schema, config_var, path):
config_var |= pin_validators[repr_schema]
config_var[S_TYPE] = "pin"
# pylint: disable-next=too-many-nested-blocks
elif repr_schema in ejs.hidden_schemas:
schema_type = ejs.hidden_schemas[repr_schema]
@ -869,7 +874,7 @@ def convert(schema, config_var, path):
config_var["use_id_type"] = str(data.base)
config_var[S_TYPE] = "use_id"
else:
raise Exception("Unknown extracted schema type")
raise TypeError("Unknown extracted schema type")
elif config_var.get("key") == "GeneratedID":
if path.startswith("i2c/CONFIG_SCHEMA/") and path.endswith("/id"):
config_var["id_type"] = {
@ -884,7 +889,7 @@ def convert(schema, config_var, path):
elif path == "pins/esp32/val 1/id":
config_var["id_type"] = "pin"
else:
raise Exception("Cannot determine id_type for " + path)
raise TypeError("Cannot determine id_type for " + path)
elif repr_schema in ejs.registry_schemas:
solve_registry.append((ejs.registry_schemas[repr_schema], config_var))
@ -948,11 +953,7 @@ def convert_keys(converted, schema, path):
result["key"] = "GeneratedID"
elif isinstance(k, cv.Required):
result["key"] = "Required"
elif (
isinstance(k, cv.Optional)
or isinstance(k, cv.Inclusive)
or isinstance(k, cv.Exclusive)
):
elif isinstance(k, (cv.Optional, cv.Inclusive, cv.Exclusive)):
result["key"] = "Optional"
else:
converted["key"] = "String"

View File

@ -2,7 +2,6 @@
import argparse
import re
import subprocess
from dataclasses import dataclass
import sys
@ -40,12 +39,12 @@ class Version:
def sub(path, pattern, repl, expected_count=1):
with open(path) as fh:
with open(path, encoding="utf-8") as fh:
content = fh.read()
content, count = re.subn(pattern, repl, content, flags=re.MULTILINE)
if expected_count is not None:
assert count == expected_count, f"Pattern {pattern} replacement failed!"
with open(path, "w") as fh:
with open(path, "w", encoding="utf-8") as fh:
fh.write(content)

View File

@ -1,10 +1,8 @@
#!/usr/bin/env python3
from helpers import styled, print_error_for_file, git_ls_files, filter_changed
import argparse
import codecs
import collections
import colorama
import fnmatch
import functools
import os.path
@ -12,6 +10,9 @@ import re
import sys
import time
import colorama
from helpers import filter_changed, git_ls_files, print_error_for_file, styled
sys.path.append(os.path.dirname(__file__))
@ -30,31 +31,6 @@ def find_all(a_str, sub):
column += len(sub)
colorama.init()
parser = argparse.ArgumentParser()
parser.add_argument(
"files", nargs="*", default=[], help="files to be processed (regex on path)"
)
parser.add_argument(
"-c", "--changed", action="store_true", help="Only run on changed files"
)
parser.add_argument(
"--print-slowest", action="store_true", help="Print the slowest checks"
)
args = parser.parse_args()
EXECUTABLE_BIT = git_ls_files()
files = list(EXECUTABLE_BIT.keys())
# Match against re
file_name_re = re.compile("|".join(args.files))
files = [p for p in files if file_name_re.search(p)]
if args.changed:
files = filter_changed(files)
files.sort()
file_types = (
".h",
".c",
@ -86,6 +62,30 @@ ignore_types = (".ico", ".png", ".woff", ".woff2", "")
LINT_FILE_CHECKS = []
LINT_CONTENT_CHECKS = []
LINT_POST_CHECKS = []
EXECUTABLE_BIT = {}
errors = collections.defaultdict(list)
def add_errors(fname, errs):
if not isinstance(errs, list):
errs = [errs]
for err in errs:
if err is None:
continue
try:
lineno, col, msg = err
except ValueError:
lineno = 1
col = 1
msg = err
if not isinstance(msg, str):
raise ValueError("Error is not instance of string!")
if not isinstance(lineno, int):
raise ValueError("Line number is not an int!")
if not isinstance(col, int):
raise ValueError("Column number is not an int!")
errors[fname].append((lineno, col, msg))
def run_check(lint_obj, fname, *args):
@ -155,7 +155,7 @@ def lint_re_check(regex, **kwargs):
def decorator(func):
@functools.wraps(func)
def new_func(fname, content):
errors = []
errs = []
for match in prog.finditer(content):
if "NOLINT" in match.group(0):
continue
@ -165,8 +165,8 @@ def lint_re_check(regex, **kwargs):
err = func(fname, match)
if err is None:
continue
errors.append((lineno, col + 1, err))
return errors
errs.append((lineno, col + 1, err))
return errs
return decor(new_func)
@ -182,13 +182,13 @@ def lint_content_find_check(find, only_first=False, **kwargs):
find_ = find
if callable(find):
find_ = find(fname, content)
errors = []
errs = []
for line, col in find_all(content, find_):
err = func(fname)
errors.append((line + 1, col + 1, err))
errs.append((line + 1, col + 1, err))
if only_first:
break
return errors
return errs
return decor(new_func)
@ -235,8 +235,8 @@ def lint_executable_bit(fname):
ex = EXECUTABLE_BIT[fname]
if ex != 100644:
return (
"File has invalid executable bit {}. If running from a windows machine please "
"see disabling executable bit in git.".format(ex)
f"File has invalid executable bit {ex}. If running from a windows machine please "
"see disabling executable bit in git."
)
return None
@ -285,8 +285,8 @@ def lint_no_defines(fname, match):
s = highlight(f"static const uint8_t {match.group(1)} = {match.group(2)};")
return (
"#define macros for integer constants are not allowed, please use "
"{} style instead (replace uint8_t with the appropriate "
"datatype). See also Google style guide.".format(s)
f"{s} style instead (replace uint8_t with the appropriate "
"datatype). See also Google style guide."
)
@ -296,11 +296,11 @@ def lint_no_long_delays(fname, match):
if duration_ms < 50:
return None
return (
"{} - long calls to delay() are not allowed in ESPHome because everything executes "
"in one thread. Calling delay() will block the main thread and slow down ESPHome.\n"
f"{highlight(match.group(0).strip())} - long calls to delay() are not allowed "
"in ESPHome because everything executes in one thread. Calling delay() will "
"block the main thread and slow down ESPHome.\n"
"If there's no way to work around the delay() and it doesn't execute often, please add "
"a '// NOLINT' comment to the line."
"".format(highlight(match.group(0).strip()))
)
@ -311,28 +311,28 @@ def lint_const_ordered(fname, content):
Reason: Otherwise people add it to the end, and then that results in merge conflicts.
"""
lines = content.splitlines()
errors = []
errs = []
for start in ["CONF_", "ICON_", "UNIT_"]:
matching = [
(i + 1, line) for i, line in enumerate(lines) if line.startswith(start)
]
ordered = list(sorted(matching, key=lambda x: x[1].replace("_", " ")))
ordered = [(mi, ol) for (mi, _), (_, ol) in zip(matching, ordered)]
for (mi, ml), (oi, ol) in zip(matching, ordered):
if ml == ol:
for (mi, mline), (_, ol) in zip(matching, ordered):
if mline == ol:
continue
target = next(i for i, l in ordered if l == ml)
target_text = next(l for i, l in matching if target == i)
errors.append(
target = next(i for i, line in ordered if line == mline)
target_text = next(line for i, line in matching if target == i)
errs.append(
(
mi,
1,
f"Constant {highlight(ml)} is not ordered, please make sure all "
f"Constant {highlight(mline)} is not ordered, please make sure all "
f"constants are ordered. See line {mi} (should go to line {target}, "
f"{target_text})",
)
)
return errors
return errs
@lint_re_check(r'^\s*CONF_([A-Z_0-9a-z]+)\s+=\s+[\'"](.*?)[\'"]\s*?$', include=["*.py"])
@ -344,15 +344,14 @@ def lint_conf_matches(fname, match):
if const_norm == value_norm:
return None
return (
"Constant {} does not match value {}! Please make sure the constant's name matches its "
"value!"
"".format(highlight("CONF_" + const), highlight(value))
f"Constant {highlight('CONF_' + const)} does not match value {highlight(value)}! "
"Please make sure the constant's name matches its value!"
)
CONF_RE = r'^(CONF_[a-zA-Z0-9_]+)\s*=\s*[\'"].*?[\'"]\s*?$'
with codecs.open("esphome/const.py", "r", encoding="utf-8") as f_handle:
constants_content = f_handle.read()
with codecs.open("esphome/const.py", "r", encoding="utf-8") as const_f_handle:
constants_content = const_f_handle.read()
CONSTANTS = [m.group(1) for m in re.finditer(CONF_RE, constants_content, re.MULTILINE)]
CONSTANTS_USES = collections.defaultdict(list)
@ -365,8 +364,8 @@ def lint_conf_from_const_py(fname, match):
CONSTANTS_USES[name].append(fname)
return None
return (
"Constant {} has already been defined in const.py - please import the constant from "
"const.py directly.".format(highlight(name))
f"Constant {highlight(name)} has already been defined in const.py - "
"please import the constant from const.py directly."
)
@ -473,16 +472,15 @@ def lint_no_byte_datatype(fname, match):
@lint_post_check
def lint_constants_usage():
errors = []
errs = []
for constant, uses in CONSTANTS_USES.items():
if len(uses) < 4:
continue
errors.append(
"Constant {} is defined in {} files. Please move all definitions of the "
"constant to const.py (Uses: {})"
"".format(highlight(constant), len(uses), ", ".join(uses))
errs.append(
f"Constant {highlight(constant)} is defined in {len(uses)} files. Please move all definitions of the "
f"constant to const.py (Uses: {', '.join(uses)})"
)
return errors
return errs
def relative_cpp_search_text(fname, content):
@ -553,7 +551,7 @@ def lint_namespace(fname, content):
return (
"Invalid namespace found in C++ file. All integration C++ files should put all "
"functions in a separate namespace that matches the integration's name. "
"Please make sure the file contains {}".format(highlight(search))
f"Please make sure the file contains {highlight(search)}"
)
@ -639,66 +637,73 @@ def lint_log_in_header(fname):
)
errors = collections.defaultdict(list)
def main():
colorama.init()
parser = argparse.ArgumentParser()
parser.add_argument(
"files", nargs="*", default=[], help="files to be processed (regex on path)"
)
parser.add_argument(
"-c", "--changed", action="store_true", help="Only run on changed files"
)
parser.add_argument(
"--print-slowest", action="store_true", help="Print the slowest checks"
)
args = parser.parse_args()
def add_errors(fname, errs):
if not isinstance(errs, list):
errs = [errs]
for err in errs:
if err is None:
global EXECUTABLE_BIT
EXECUTABLE_BIT = git_ls_files()
files = list(EXECUTABLE_BIT.keys())
# Match against re
file_name_re = re.compile("|".join(args.files))
files = [p for p in files if file_name_re.search(p)]
if args.changed:
files = filter_changed(files)
files.sort()
for fname in files:
_, ext = os.path.splitext(fname)
run_checks(LINT_FILE_CHECKS, fname, fname)
if ext in ignore_types:
continue
try:
lineno, col, msg = err
except ValueError:
lineno = 1
col = 1
msg = err
if not isinstance(msg, str):
raise ValueError("Error is not instance of string!")
if not isinstance(lineno, int):
raise ValueError("Line number is not an int!")
if not isinstance(col, int):
raise ValueError("Column number is not an int!")
errors[fname].append((lineno, col, msg))
with codecs.open(fname, "r", encoding="utf-8") as f_handle:
content = f_handle.read()
except UnicodeDecodeError:
add_errors(
fname,
"File is not readable as UTF-8. Please set your editor to UTF-8 mode.",
)
continue
run_checks(LINT_CONTENT_CHECKS, fname, fname, content)
run_checks(LINT_POST_CHECKS, "POST")
for fname in files:
_, ext = os.path.splitext(fname)
run_checks(LINT_FILE_CHECKS, fname, fname)
if ext in ignore_types:
continue
try:
with codecs.open(fname, "r", encoding="utf-8") as f_handle:
content = f_handle.read()
except UnicodeDecodeError:
add_errors(
fname,
"File is not readable as UTF-8. Please set your editor to UTF-8 mode.",
for f, errs in sorted(errors.items()):
bold = functools.partial(styled, colorama.Style.BRIGHT)
bold_red = functools.partial(styled, (colorama.Style.BRIGHT, colorama.Fore.RED))
err_str = (
f"{bold(f'{f}:{lineno}:{col}:')} {bold_red('lint:')} {msg}\n"
for lineno, col, msg in errs
)
continue
run_checks(LINT_CONTENT_CHECKS, fname, fname, content)
print_error_for_file(f, "\n".join(err_str))
run_checks(LINT_POST_CHECKS, "POST")
if args.print_slowest:
lint_times = []
for lint in LINT_FILE_CHECKS + LINT_CONTENT_CHECKS + LINT_POST_CHECKS:
durations = lint.get("durations", [])
lint_times.append((sum(durations), len(durations), lint["func"].__name__))
lint_times.sort(key=lambda x: -x[0])
for i in range(min(len(lint_times), 10)):
dur, invocations, name = lint_times[i]
print(f" - '{name}' took {dur:.2f}s total (ran on {invocations} files)")
print(f"Total time measured: {sum(x[0] for x in lint_times):.2f}s")
for f, errs in sorted(errors.items()):
bold = functools.partial(styled, colorama.Style.BRIGHT)
bold_red = functools.partial(styled, (colorama.Style.BRIGHT, colorama.Fore.RED))
err_str = (
f"{bold(f'{f}:{lineno}:{col}:')} {bold_red('lint:')} {msg}\n"
for lineno, col, msg in errs
)
print_error_for_file(f, "\n".join(err_str))
return len(errors)
if args.print_slowest:
lint_times = []
for lint in LINT_FILE_CHECKS + LINT_CONTENT_CHECKS + LINT_POST_CHECKS:
durations = lint.get("durations", [])
lint_times.append((sum(durations), len(durations), lint["func"].__name__))
lint_times.sort(key=lambda x: -x[0])
for i in range(min(len(lint_times), 10)):
dur, invocations, name = lint_times[i]
print(f" - '{name}' took {dur:.2f}s total (ran on {invocations} files)")
print(f"Total time measured: {sum(x[0] for x in lint_times):.2f}s")
sys.exit(len(errors))
if __name__ == "__main__":
sys.exit(main())

View File

@ -1,10 +1,11 @@
import colorama
import json
import os.path
import re
import subprocess
import json
from pathlib import Path
import colorama
root_path = os.path.abspath(os.path.normpath(os.path.join(__file__, "..", "..")))
basepath = os.path.join(root_path, "esphome")
temp_folder = os.path.join(root_path, ".temp")
@ -44,7 +45,7 @@ def build_all_include():
content = "\n".join(headers)
p = Path(temp_header_file)
p.parent.mkdir(exist_ok=True)
p.write_text(content)
p.write_text(content, encoding="utf-8")
def walk_files(path):
@ -54,14 +55,14 @@ def walk_files(path):
def get_output(*args):
proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, err = proc.communicate()
with subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as proc:
output, _ = proc.communicate()
return output.decode("utf-8")
def get_err(*args):
proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
output, err = proc.communicate()
with subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as proc:
_, err = proc.communicate()
return err.decode("utf-8")
@ -78,7 +79,7 @@ def changed_files():
merge_base = splitlines_no_ends(get_output(*command))[0]
break
# pylint: disable=bare-except
except:
except: # noqa: E722
pass
else:
raise ValueError("Git not configured")
@ -103,7 +104,7 @@ def filter_changed(files):
def filter_grep(files, value):
matched = []
for file in files:
with open(file) as handle:
with open(file, encoding="utf-8") as handle:
contents = handle.read()
if value in contents:
matched.append(file)
@ -114,8 +115,8 @@ def git_ls_files(patterns=None):
command = ["git", "ls-files", "-s"]
if patterns is not None:
command.extend(patterns)
proc = subprocess.Popen(command, stdout=subprocess.PIPE)
output, err = proc.communicate()
with subprocess.Popen(command, stdout=subprocess.PIPE) as proc:
output, _ = proc.communicate()
lines = [x.split() for x in output.decode("utf-8").splitlines()]
return {s[3].strip(): int(s[0]) for s in lines}

View File

@ -2,6 +2,7 @@
import re
# pylint: disable=import-error
from homeassistant.components.binary_sensor import BinarySensorDeviceClass
from homeassistant.components.button import ButtonDeviceClass
from homeassistant.components.cover import CoverDeviceClass
@ -9,6 +10,8 @@ from homeassistant.components.number import NumberDeviceClass
from homeassistant.components.sensor import SensorDeviceClass
from homeassistant.components.switch import SwitchDeviceClass
# pylint: enable=import-error
BLOCKLIST = (
# requires special support on HA side
"enum",
@ -25,10 +28,10 @@ DOMAINS = {
def sub(path, pattern, repl):
with open(path) as handle:
with open(path, encoding="utf-8") as handle:
content = handle.read()
content = re.sub(pattern, repl, content, flags=re.MULTILINE)
with open(path, "w") as handle:
with open(path, "w", encoding="utf-8") as handle:
handle.write(content)