"""Python 3 script to automatically generate C++ classes for ESPHome's native API.

Compiles ``api.proto`` with ``protoc`` into a binary descriptor set, then
writes four files into ``esphome/components/api``:

* ``api_pb2.h`` / ``api_pb2.cpp`` -- one C++ enum per proto enum and one
  ``ProtoMessage`` subclass per proto message (fields, ``decode_*``,
  ``encode`` and ``dump_to``).
* ``api_pb2_service.h`` / ``api_pb2_service.cpp`` -- the
  ``APIServerConnectionBase`` send/receive dispatch and the
  ``APIServerConnection`` handlers generated from the proto ``service``.

Requires ``protoc`` on ``PATH``.
"""
import re
from pathlib import Path
from subprocess import call
from textwrap import dedent
from typing import Optional

# Generate with
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
import api_options_pb2 as pb
import google.protobuf.descriptor_pb2 as descriptor

cwd = Path(__file__).parent
root = cwd.parent.parent / 'esphome' / 'components' / 'api'
prot = cwd / 'api.protoc'

# Compile api.proto into a binary FileDescriptorSet we can introspect.
if call(['protoc', '-o', prot, '-I', root, 'api.proto']) != 0:
    raise SystemExit('protoc failed to compile api.proto')
content = prot.read_bytes()
# The descriptor set is only an intermediate artifact; remove it as soon as
# it has been read so that a failure further down does not leave it behind.
prot.unlink()
d = descriptor.FileDescriptorSet.FromString(content)

# Value of FieldDescriptorProto.LABEL_REPEATED.
LABEL_REPEATED = 3


def indent_list(text, padding=' '):
    """Return the lines of *text* as a list, each prefixed with *padding*."""
    return [padding + line for line in text.splitlines()]


def indent(text, padding=' '):
    """Return *text* with every line prefixed with *padding*."""
    return '\n'.join(indent_list(text, padding))


def camel_to_snake(name):
    """Convert ``CamelCase`` to ``snake_case`` (``HelloRequest`` -> ``hello_request``)."""
    # https://stackoverflow.com/a/1176023
    s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()


def _sprintf_dump(fmt, name):
    """Return C++ that printf-formats *name* with *fmt* into ``buffer`` and appends it to ``out``."""
    return f'sprintf(buffer, "{fmt}", {name});\nout.append(buffer);'


def _case_clause(number, statement):
    """Return a C++ ``case`` clause for field *number* that runs *statement* then returns true."""
    return dedent(f'''\
        case {number}: {{
         {statement}
         return true;
        }}''')


class TypeInfo:
    """Describe how one proto field maps onto generated C++.

    Concrete subclasses are registered in ``TYPE_INFO`` by protobuf field type
    number and provide ``cpp_type``, ``default_value``, the ``decode_*`` C++
    expressions, ``encode_func`` and ``dump``.
    """

    def __init__(self, field):
        # field: a descriptor_pb2.FieldDescriptorProto
        self._field = field

    @property
    def default_value(self):
        return ''

    @property
    def name(self):
        return self._field.name

    @property
    def arg_name(self):
        return self.name

    @property
    def field_name(self):
        return self.name

    @property
    def number(self):
        return self._field.number

    @property
    def repeated(self):
        return self._field.label == LABEL_REPEATED

    @property
    def cpp_type(self):
        raise NotImplementedError

    @property
    def reference_type(self):
        return f'{self.cpp_type} '

    @property
    def const_reference_type(self):
        return f'{self.cpp_type} '

    @property
    def public_content(self) -> list:
        return [self.class_member]

    @property
    def protected_content(self) -> list:
        return []

    @property
    def class_member(self) -> str:
        return f'{self.cpp_type} {self.field_name}{{{self.default_value}}}; // NOLINT'

    def _decode_case(self, content) -> Optional[str]:
        """Return the ``case`` clause assigning decoded *content*, or None when unsupported."""
        if content is None:
            return None
        return _case_clause(self.number, f'this->{self.field_name} = {content};')

    @property
    def decode_varint_content(self) -> Optional[str]:
        return self._decode_case(self.decode_varint)

    decode_varint = None

    @property
    def decode_length_content(self) -> Optional[str]:
        return self._decode_case(self.decode_length)

    decode_length = None

    @property
    def decode_32bit_content(self) -> Optional[str]:
        return self._decode_case(self.decode_32bit)

    decode_32bit = None

    @property
    def decode_64bit_content(self) -> Optional[str]:
        return self._decode_case(self.decode_64bit)

    decode_64bit = None

    @property
    def encode_content(self):
        return f'buffer.{self.encode_func}({self.number}, this->{self.field_name});'

    encode_func = None

    @property
    def dump_content(self):
        o = f'out.append(" {self.name}: ");\n'
        o += self.dump(f'this->{self.field_name}') + '\n'
        o += 'out.append("\\n");\n'
        return o

    dump = None


# Maps FieldDescriptorProto type number -> TypeInfo subclass.
TYPE_INFO = {}


def register_type(name):
    """Return a class decorator that registers a TypeInfo subclass for type number *name*."""
    def func(value):
        TYPE_INFO[name] = value
        return value
    return func


@register_type(1)
class DoubleType(TypeInfo):
    cpp_type = 'double'
    default_value = '0.0'
    decode_64bit = 'value.as_double()'
    encode_func = 'encode_double'

    def dump(self, name):
        return _sprintf_dump('%g', name)


@register_type(2)
class FloatType(TypeInfo):
    cpp_type = 'float'
    default_value = '0.0f'
    decode_32bit = 'value.as_float()'
    encode_func = 'encode_float'

    def dump(self, name):
        return _sprintf_dump('%g', name)


@register_type(3)
class Int64Type(TypeInfo):
    cpp_type = 'int64_t'
    default_value = '0'
    decode_varint = 'value.as_int64()'
    encode_func = 'encode_int64'

    def dump(self, name):
        # "%ll" alone is not a valid printf conversion.
        return _sprintf_dump('%lld', name)


@register_type(4)
class UInt64Type(TypeInfo):
    cpp_type = 'uint64_t'
    default_value = '0'
    decode_varint = 'value.as_uint64()'
    encode_func = 'encode_uint64'

    def dump(self, name):
        # "%ull" would print "%u" followed by a literal "ll".
        return _sprintf_dump('%llu', name)


@register_type(5)
class Int32Type(TypeInfo):
    cpp_type = 'int32_t'
    default_value = '0'
    decode_varint = 'value.as_int32()'
    encode_func = 'encode_int32'

    def dump(self, name):
        return _sprintf_dump('%d', name)


@register_type(6)
class Fixed64Type(TypeInfo):
    cpp_type = 'uint64_t'
    default_value = '0'
    decode_64bit = 'value.as_fixed64()'
    encode_func = 'encode_fixed64'

    def dump(self, name):
        return _sprintf_dump('%llu', name)


@register_type(7)
class Fixed32Type(TypeInfo):
    cpp_type = 'uint32_t'
    default_value = '0'
    decode_32bit = 'value.as_fixed32()'
    encode_func = 'encode_fixed32'

    def dump(self, name):
        return _sprintf_dump('%u', name)


@register_type(8)
class BoolType(TypeInfo):
    cpp_type = 'bool'
    default_value = 'false'
    decode_varint = 'value.as_bool()'
    encode_func = 'encode_bool'

    def dump(self, name):
        return f'out.append(YESNO({name}));'


@register_type(9)
class StringType(TypeInfo):
    cpp_type = 'std::string'
    default_value = ''
    reference_type = 'std::string &'
    const_reference_type = 'const std::string &'
    decode_length = 'value.as_string()'
    encode_func = 'encode_string'

    def dump(self, name):
        return f'out.append("\'").append({name}).append("\'");'


@register_type(11)
class MessageType(TypeInfo):
    @property
    def cpp_type(self):
        # type_name is fully qualified with a leading '.'
        return self._field.type_name[1:]

    default_value = ''

    @property
    def reference_type(self):
        return f'{self.cpp_type} &'

    @property
    def const_reference_type(self):
        return f'const {self.cpp_type} &'

    @property
    def encode_func(self):
        return f'encode_message<{self.cpp_type}>'

    @property
    def decode_length(self):
        return f'value.as_message<{self.cpp_type}>()'

    def dump(self, name):
        return f'{name}.dump_to(out);'


@register_type(12)
class BytesType(TypeInfo):
    cpp_type = 'std::string'
    default_value = ''
    reference_type = 'std::string &'
    const_reference_type = 'const std::string &'
    decode_length = 'value.as_string()'
    encode_func = 'encode_string'

    def dump(self, name):
        return f'out.append("\'").append({name}).append("\'");'


@register_type(13)
class UInt32Type(TypeInfo):
    cpp_type = 'uint32_t'
    default_value = '0'
    decode_varint = 'value.as_uint32()'
    encode_func = 'encode_uint32'

    def dump(self, name):
        return _sprintf_dump('%u', name)


@register_type(14)
class EnumType(TypeInfo):
    @property
    def cpp_type(self):
        return "Enum" + self._field.type_name[1:]

    @property
    def decode_varint(self):
        return f'value.as_enum<{self.cpp_type}>()'

    default_value = ''

    @property
    def encode_func(self):
        return f'encode_enum<{self.cpp_type}>'

    def dump(self, name):
        return f'out.append(proto_enum_to_string<{self.cpp_type}>({name}));'


@register_type(15)
class SFixed32Type(TypeInfo):
    cpp_type = 'int32_t'
    default_value = '0'
    decode_32bit = 'value.as_sfixed32()'
    encode_func = 'encode_sfixed32'

    def dump(self, name):
        return _sprintf_dump('%d', name)


@register_type(16)
class SFixed64Type(TypeInfo):
    cpp_type = 'int64_t'
    default_value = '0'
    decode_64bit = 'value.as_sfixed64()'
    encode_func = 'encode_sfixed64'

    def dump(self, name):
        return _sprintf_dump('%lld', name)


@register_type(17)
class SInt32Type(TypeInfo):
    cpp_type = 'int32_t'
    default_value = '0'
    decode_varint = 'value.as_sint32()'
    encode_func = 'encode_sint32'

    def dump(self, name):
        return _sprintf_dump('%d', name)


@register_type(18)
class SInt64Type(TypeInfo):
    cpp_type = 'int64_t'
    default_value = '0'
    decode_varint = 'value.as_sint64()'
    encode_func = 'encode_sint64'

    def dump(self, name):
        return _sprintf_dump('%lld', name)


class RepeatedTypeInfo(TypeInfo):
    """TypeInfo for ``repeated`` fields, generated as a ``std::vector`` of the element type."""

    def __init__(self, field):
        super().__init__(field)
        # TypeInfo describing a single element of the vector.
        self._ti = TYPE_INFO[field.type](field)

    @property
    def cpp_type(self):
        return f'std::vector<{self._ti.cpp_type}>'

    @property
    def reference_type(self):
        return f'{self.cpp_type} &'

    @property
    def const_reference_type(self):
        return f'const {self.cpp_type} &'

    def _decode_case(self, content) -> Optional[str]:
        """Return the ``case`` clause appending decoded *content*, or None when unsupported."""
        if content is None:
            return None
        return _case_clause(self.number, f'this->{self.field_name}.push_back({content});')

    @property
    def decode_varint_content(self) -> Optional[str]:
        return self._decode_case(self._ti.decode_varint)

    @property
    def decode_length_content(self) -> Optional[str]:
        return self._decode_case(self._ti.decode_length)

    @property
    def decode_32bit_content(self) -> Optional[str]:
        return self._decode_case(self._ti.decode_32bit)

    @property
    def decode_64bit_content(self) -> Optional[str]:
        return self._decode_case(self._ti.decode_64bit)

    @property
    def _ti_is_bool(self):
        # std::vector is specialized for bool, reference does not work
        return isinstance(self._ti, BoolType)

    @property
    def encode_content(self):
        ref = '' if self._ti_is_bool else '&'
        return (f'for (auto {ref}it : this->{self.field_name}) {{\n'
                f' buffer.{self._ti.encode_func}({self.number}, it, true);\n'
                '}')

    @property
    def dump_content(self):
        ref = '' if self._ti_is_bool else '&'
        o = f'for (const auto {ref}it : this->{self.field_name}) {{\n'
        o += f' out.append(" {self.name}: ");\n'
        o += indent(self._ti.dump('it')) + '\n'
        o += ' out.append("\\n");\n'
        o += '}\n'
        return o


def build_enum_type(desc):
    """Return (header, source) C++ for one proto enum descriptor.

    The header declares ``Enum<Name>``; the source specializes
    ``proto_enum_to_string`` for it.
    """
    name = "Enum" + desc.name
    out = f"enum {name} : uint32_t {{\n"
    for v in desc.value:
        out += f' {v.name} = {v.number},\n'
    out += '};\n'

    cpp = "template<>\n"
    cpp += f"const char *proto_enum_to_string<{name}>({name} value) {{\n"
    cpp += " switch (value) {\n"
    for v in desc.value:
        cpp += f' case {v.name}: return "{v.name}";\n'
    cpp += ' default: return "UNKNOWN";\n'
    cpp += ' }\n'
    cpp += '}\n'
    return out, cpp


def _build_decode_method(message_name, method, value_type, cases):
    """Return (definition, declaration) C++ for one ``decode_*`` override.

    *cases* are the ``case`` clauses collected for that wire type; a
    ``default`` clause returning false is appended after them.
    """
    body = '\n'.join(list(cases) + ['default:\n return false;'])
    o = f'bool {message_name}::{method}(uint32_t field_id, {value_type} value) {{\n'
    o += ' switch (field_id) {\n'
    o += indent(body, ' ') + '\n'
    o += ' }\n'
    o += '}\n'
    declaration = f'bool {method}(uint32_t field_id, {value_type} value) override;'
    return o, declaration


def build_message_type(desc):
    """Return (header, source) C++ for one proto message descriptor.

    The header declares a ``ProtoMessage`` subclass with one member per field;
    the source implements the needed ``decode_*`` overrides plus ``encode``
    and ``dump_to``.
    """
    public_content = []
    protected_content = []
    decode_varint = []
    decode_length = []
    decode_32bit = []
    decode_64bit = []
    encode = []
    dump = []

    for field in desc.field:
        if field.label == LABEL_REPEATED:
            ti = RepeatedTypeInfo(field)
        else:
            ti = TYPE_INFO[field.type](field)
        protected_content.extend(ti.protected_content)
        public_content.extend(ti.public_content)
        encode.append(ti.encode_content)

        for cases, clause in ((decode_varint, ti.decode_varint_content),
                              (decode_length, ti.decode_length_content),
                              (decode_32bit, ti.decode_32bit_content),
                              (decode_64bit, ti.decode_64bit_content)):
            if clause:
                cases.append(clause)

        dump_content = ti.dump_content
        if dump_content:
            dump.append(dump_content)

    cpp = ''
    decoders = (
        ('decode_varint', 'ProtoVarInt', decode_varint),
        ('decode_length', 'ProtoLengthDelimited', decode_length),
        ('decode_32bit', 'Proto32Bit', decode_32bit),
        # Spelled like Proto32Bit; the previous "Proto64bit" was inconsistent.
        ('decode_64bit', 'Proto64Bit', decode_64bit),
    )
    for method, value_type, cases in decoders:
        # Only override the decoders for wire types this message actually uses.
        if not cases:
            continue
        definition, declaration = _build_decode_method(desc.name, method, value_type, cases)
        cpp += definition
        # Each declaration goes to the front, so they end up in reverse order.
        protected_content.insert(0, declaration)

    o = f"void {desc.name}::encode(ProtoWriteBuffer buffer) const {{\n"
    o += indent('\n'.join(encode)) + '\n'
    o += '}\n'
    cpp += o
    public_content.append('void encode(ProtoWriteBuffer buffer) const override;')

    o = f"void {desc.name}::dump_to(std::string &out) const {{\n"
    if dump:
        o += " char buffer[64];\n"
        o += f' out.append("{desc.name} {{\\n");\n'
        o += indent('\n'.join(dump)) + '\n'
        o += ' out.append("}");\n'
    else:
        o += f' out.append("{desc.name} {{}}");\n'
    o += '}\n'
    cpp += o
    public_content.append('void dump_to(std::string &out) const override;')

    out = f"class {desc.name} : public ProtoMessage {{\n"
    out += ' public:\n'
    out += indent('\n'.join(public_content)) + '\n'
    out += ' protected:\n'
    out += indent('\n'.join(protected_content)) + '\n'
    out += "};\n"
    return out, cpp


file = d.file[0]
content = '''\
#pragma once

#include "proto.h"

namespace esphome {
namespace api {

'''

cpp = '''\
#include "api_pb2.h"
#include "esphome/core/log.h"

namespace esphome {
namespace api {

'''

for enum in file.enum_type:
    s, c = build_enum_type(enum)
    content += s
    cpp += c

for m in file.message_type:
    s, c = build_message_type(m)
    content += s
    cpp += c

content += '''\
} // namespace api
} // namespace esphome
'''
cpp += '''\
} // namespace api
} // namespace esphome
'''

with open(root / 'api_pb2.h', 'w') as f:
    f.write(content)

with open(root / 'api_pb2.cpp', 'w') as f:
    f.write(cpp)

# Values of the `source` message option (who sends the message).
SOURCE_BOTH = 0
SOURCE_SERVER = 1
SOURCE_CLIENT = 2

# Message id -> C++ body of its `case` in read_message(); filled by
# build_service_message_type().
RECEIVE_CASES = {}

class_name = 'APIServerConnectionBase'

# Message name -> preprocessor symbol guarding it (from the `ifdef` option).
ifdefs = {}


def get_opt(desc, opt, default=None):
    """Return the custom option extension *opt* of *desc*, or *default* if it is unset."""
    if not desc.options.HasExtension(opt):
        return default
    return desc.options.Extensions[opt]


def build_service_message_type(mt):
    """Return (header, source) C++ for the send/receive plumbing of message *mt*.

    Returns None for messages without an ``id`` option. As side effects this
    records the message's ``ifdef`` in ``ifdefs`` and, for client-sent
    messages, its dispatch code in ``RECEIVE_CASES``.
    """
    snake = camel_to_snake(mt.name)
    id_ = get_opt(mt, pb.id)
    if id_ is None:
        return None

    source = get_opt(mt, pb.source, 0)

    ifdef = get_opt(mt, pb.ifdef)
    log = get_opt(mt, pb.log, True)
    nodelay = get_opt(mt, pb.no_delay, False)
    hout = ''
    cout = ''

    if ifdef is not None:
        ifdefs[str(mt.name)] = ifdef
        hout += f'#ifdef {ifdef}\n'
        cout += f'#ifdef {ifdef}\n'

    if source in (SOURCE_BOTH, SOURCE_SERVER):
        # Generate send
        func = f'send_{snake}'
        hout += f'bool {func}(const {mt.name} &msg);\n'
        cout += f'bool {class_name}::{func}(const {mt.name} &msg) {{\n'
        if log:
            cout += f' ESP_LOGVV(TAG, "{func}: %s", msg.dump().c_str());\n'
        cout += f' this->set_nodelay({str(nodelay).lower()});\n'
        cout += f' return this->send_message_<{mt.name}>(msg, {id_});\n'
        cout += '}\n'
    if source in (SOURCE_BOTH, SOURCE_CLIENT):
        # Generate receive
        func = f'on_{snake}'
        hout += f'virtual void {func}(const {mt.name} &value){{}};\n'
        case = ''
        if ifdef is not None:
            case += f'#ifdef {ifdef}\n'
        case += f'{mt.name} msg;\n'
        case += 'msg.decode(msg_data, msg_size);\n'
        if log:
            case += f'ESP_LOGVV(TAG, "{func}: %s", msg.dump().c_str());\n'
        case += f'this->{func}(msg);\n'
        if ifdef is not None:
            case += '#endif\n'
        case += 'break;'
        RECEIVE_CASES[id_] = case

    if ifdef is not None:
        hout += '#endif\n'
        cout += '#endif\n'

    return hout, cout


hpp = '''\
#pragma once

#include "api_pb2.h"
#include "esphome/core/defines.h"

namespace esphome {
namespace api {

'''

cpp = '''\
#include "api_pb2_service.h"
#include "esphome/core/log.h"

namespace esphome {
namespace api {

static const char *TAG = "api.service";

'''

hpp += f'class {class_name} : public ProtoService {{\n'
hpp += ' public:\n'

for mt in file.message_type:
    obj = build_service_message_type(mt)
    if obj is None:
        continue
    hout, cout = obj
    hpp += indent(hout) + '\n'
    cpp += cout

# Dispatch by ascending message id.
cases = sorted(RECEIVE_CASES.items())

hpp += ' protected:\n'
hpp += ' bool read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) override;\n'
out = f'bool {class_name}::read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) {{\n'
out += ' switch(msg_type) {\n'
for i, case in cases:
    c = f'case {i}: {{\n'
    c += indent(case) + '\n'
    c += '}'
    out += indent(c, ' ') + '\n'
out += ' default: \n'
out += ' return false;\n'
out += ' }\n'
out += ' return true;\n'
out += '}\n'
cpp += out
hpp += '};\n'

serv = file.service[0]
class_name = 'APIServerConnection'
hpp += '\n'
hpp += f'class {class_name} : public {class_name}Base {{\n'
hpp += ' public:\n'
hpp_protected = ''
cpp += '\n'

for m in serv.method:
    func = m.name
    inp = m.input_type[1:]
    ret = m.output_type[1:]
    is_void = ret == 'void'
    snake = camel_to_snake(inp)
    on_func = f'on_{snake}'
    needs_conn = get_opt(m, pb.needs_setup_connection, True)
    needs_auth = get_opt(m, pb.needs_authentication, True)

    # Handlers share the #ifdef guard of their input message, if any.
    ifdef = ifdefs.get(inp)

    if ifdef is not None:
        hpp += f'#ifdef {ifdef}\n'
        hpp_protected += f'#ifdef {ifdef}\n'
        cpp += f'#ifdef {ifdef}\n'

    hpp_protected += f' void {on_func}(const {inp} &msg) override;\n'
    hpp += f' virtual {ret} {func}(const {inp} &msg) = 0;\n'
    cpp += f'void {class_name}::{on_func}(const {inp} &msg) {{\n'
    body = ''
    if needs_conn:
        body += 'if (!this->is_connection_setup()) {\n'
        body += ' this->on_no_setup_connection();\n'
        body += ' return;\n'
        body += '}\n'
    if needs_auth:
        body += 'if (!this->is_authenticated()) {\n'
        body += ' this->on_unauthenticated_access();\n'
        body += ' return;\n'
        body += '}\n'

    if is_void:
        body += f'this->{func}(msg);\n'
    else:
        body += f'{ret} ret = this->{func}(msg);\n'
        ret_snake = camel_to_snake(ret)
        body += f'if (!this->send_{ret_snake}(ret)) {{\n'
        body += ' this->on_fatal_error();\n'
        body += '}\n'
    cpp += indent(body) + '\n' + '}\n'

    if ifdef is not None:
        hpp += '#endif\n'
        hpp_protected += '#endif\n'
        cpp += '#endif\n'

hpp += ' protected:\n'
hpp += hpp_protected
hpp += '};\n'

hpp += '''\
} // namespace api
} // namespace esphome
'''
cpp += '''\
} // namespace api
} // namespace esphome
'''

with open(root / 'api_pb2_service.h', 'w') as f:
    f.write(hpp)

with open(root / 'api_pb2_service.cpp', 'w') as f:
    f.write(cpp)