esphome/script/api_protobuf/api_protobuf.py

868 lines
21 KiB
Python

"""Python 3 script to automatically generate C++ classes for ESPHome's native API.
It's pretty crappy spaghetti code, but it works.
"""
import re
from pathlib import Path
from textwrap import dedent
from subprocess import call
# Generate with
# protoc --python_out=script/api_protobuf -I esphome/components/api/ api_options.proto
import api_options_pb2 as pb
import google.protobuf.descriptor_pb2 as descriptor
cwd = Path(__file__).parent
# esphome/components/api: holds api.proto and receives the generated files.
root = cwd.parent.parent / 'esphome' / 'components' / 'api'
# Temporary serialized FileDescriptorSet written by protoc; unlinked at the
# end of the script.
prot = cwd / 'api.protoc'
# Compile api.proto into a descriptor set. Stop if protoc fails instead of
# parsing a missing file (or a stale one left over from an earlier run).
protoc_status = call(['protoc', '-o', prot, '-I', root, 'api.proto'])
if protoc_status != 0:
    raise SystemExit(f'protoc exited with status {protoc_status}')
content = prot.read_bytes()
d = descriptor.FileDescriptorSet.FromString(content)
def indent_list(text, padding=u' '):
    """Split ``text`` into lines and prefix each line with ``padding``.

    Returns the padded lines as a list, without line terminators.
    """
    return list(map(lambda line: padding + line, text.splitlines()))
def indent(text, padding=u' '):
    """Prefix every line of ``text`` with ``padding`` and rejoin with newlines.

    A trailing newline in ``text`` is not preserved, because
    ``str.splitlines`` drops it.
    """
    padded = [padding + line for line in text.splitlines()]
    return '\n'.join(padded)
def camel_to_snake(name):
    """Convert a CamelCase identifier to snake_case.

    Example: ``ListEntitiesDoneResponse`` -> ``list_entities_done_response``.
    """
    # https://stackoverflow.com/a/1176023
    # First split before each Capitalized word, then between a lower-case
    # letter/digit and a following upper-case letter.
    words_split = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    fully_split = re.sub('([a-z0-9])([A-Z])', r'\1_\2', words_split)
    return fully_split.lower()
class TypeInfo():
    """Describe how one protobuf message field maps onto generated C++ code.

    Wraps a ``FieldDescriptorProto`` and exposes the C++ snippets that
    ``build_message_type`` stitches together: the member declaration, the
    ``case`` labels for the ``decode_*`` switch statements, the ``encode``
    statement and the ``dump_to`` code.

    Subclasses (one per protobuf field type, registered in ``TYPE_INFO``)
    provide ``cpp_type``/``default_value`` and fill in the hooks:

    * ``decode_varint`` / ``decode_length`` / ``decode_32bit`` /
      ``decode_64bit``: a C++ expression that decodes ``value``, or None when
      the type is never read from that wire format;
    * ``encode_func``: name of the ``buffer`` method used to encode the field;
    * ``dump(name)``: C++ that appends the value of expression ``name`` to
      ``out``.
    """

    def __init__(self, field):
        self._field = field

    @property
    def default_value(self):
        """Text placed inside the member's ``{}`` initializer ('' by default)."""
        return ''

    @property
    def name(self):
        """Field name as declared in the .proto file."""
        return self._field.name

    @property
    def arg_name(self):
        """Alias of ``name`` (not used by the generator code in this file)."""
        return self.name

    @property
    def field_name(self):
        """Name of the generated C++ member (same as ``name``)."""
        return self.name

    @property
    def number(self):
        """Protobuf field number; used as the ``case`` label when decoding."""
        return self._field.number

    @property
    def repeated(self):
        """True if the field is declared ``repeated`` (label 3, LABEL_REPEATED)."""
        return self._field.label == 3

    @property
    def cpp_type(self):
        """C++ type of the member. Subclasses must override this."""
        raise NotImplementedError

    @property
    def reference_type(self):
        return f'{self.cpp_type} '

    @property
    def const_reference_type(self):
        return f'{self.cpp_type} '

    @property
    def public_content(self) -> list:
        """Declarations for the generated class's ``public:`` section."""
        return [self.class_member]

    @property
    def protected_content(self) -> list:
        """Declarations for the generated class's ``protected:`` section."""
        return []

    @property
    def class_member(self) -> str:
        """Member declaration, e.g. ``uint32_t key{0}; // NOLINT``."""
        return f'{self.cpp_type} {self.field_name}{{{self.default_value}}}; // NOLINT'

    # Wire-format decode hooks (see class docstring); None = not decodable
    # from that wire format, so no ``case`` is generated for it.
    decode_varint = None
    decode_length = None
    decode_32bit = None
    decode_64bit = None

    def _decode_case(self, content):
        """Return a ``case`` that assigns ``content`` to the member.

        Returns None when ``content`` is None.
        """
        if content is None:
            return None
        return dedent(f'''\
            case {self.number}: {{
            this->{self.field_name} = {content};
            return true;
            }}''')

    @property
    def decode_varint_content(self) -> 'str | None':
        """``case`` for the varint ``decode_varint`` switch, or None."""
        return self._decode_case(self.decode_varint)

    @property
    def decode_length_content(self) -> 'str | None':
        """``case`` for the length-delimited ``decode_length`` switch, or None."""
        return self._decode_case(self.decode_length)

    @property
    def decode_32bit_content(self) -> 'str | None':
        """``case`` for the ``decode_32bit`` switch, or None."""
        return self._decode_case(self.decode_32bit)

    @property
    def decode_64bit_content(self) -> 'str | None':
        """``case`` for the ``decode_64bit`` switch, or None."""
        return self._decode_case(self.decode_64bit)

    # Name of the ProtoWriteBuffer method called by ``encode_content``.
    encode_func = None

    @property
    def encode_content(self):
        """C++ statement that writes this field into ``buffer``."""
        return f'buffer.{self.encode_func}({self.number}, this->{self.field_name});'

    # Subclasses define ``dump(self, name)`` returning C++ code for ``dump_to``.
    dump = None

    @property
    def dump_content(self):
        """C++ that appends this field's name and dumped value, one line, to ``out``."""
        o = f'out.append(" {self.name}: ");\n'
        o += self.dump(f'this->{self.field_name}') + '\n'
        o += 'out.append("\\n");\n'
        return o
# Maps a protobuf field type number (FieldDescriptorProto.type) to the
# TypeInfo subclass that handles it; populated via @register_type.
TYPE_INFO = {}
def register_type(name):
    """Return a class decorator that records the class in TYPE_INFO under ``name``.

    The decorated class itself is returned unchanged.
    """
    def _register(cls):
        TYPE_INFO.update({name: cls})
        return cls

    return _register
@register_type(1)
class DoubleType(TypeInfo):
    """``double`` field, decoded from the 64-bit wire format."""
    encode_func = 'encode_double'
    decode_64bit = 'value.as_double()'
    default_value = '0.0'
    cpp_type = 'double'

    def dump(self, name):
        """C++ that formats ``name`` with ``%g`` into ``buffer`` and appends it."""
        return '\n'.join([f'sprintf(buffer, "%g", {name});', 'out.append(buffer);'])
@register_type(2)
class FloatType(TypeInfo):
    """``float`` field, decoded from the 32-bit wire format."""
    encode_func = 'encode_float'
    decode_32bit = 'value.as_float()'
    default_value = '0.0f'
    cpp_type = 'float'

    def dump(self, name):
        """C++ that formats ``name`` with ``%g`` into ``buffer`` and appends it."""
        return '\n'.join([f'sprintf(buffer, "%g", {name});', 'out.append(buffer);'])
@register_type(3)
class Int64Type(TypeInfo):
    """``int64`` field mapped to C++ ``int64_t`` (varint wire format)."""
    cpp_type = 'int64_t'
    default_value = '0'
    decode_varint = 'value.as_int64()'
    encode_func = 'encode_int64'

    def dump(self, name):
        """Return C++ that prints ``name`` as a signed 64-bit integer."""
        # "%lld" is the signed long long conversion; a bare "%ll" is an
        # incomplete conversion specification. NOTE(review): assumes int64_t
        # is long long on the target toolchains.
        o = f'sprintf(buffer, "%lld", {name});\n'
        o += 'out.append(buffer);'
        return o
@register_type(4)
class UInt64Type(TypeInfo):
    """``uint64`` field mapped to C++ ``uint64_t`` (varint wire format)."""
    cpp_type = 'uint64_t'
    default_value = '0'
    decode_varint = 'value.as_uint64()'
    encode_func = 'encode_uint64'

    def dump(self, name):
        """Return C++ that prints ``name`` as an unsigned 64-bit integer."""
        # "%llu" is the unsigned long long conversion. "%ull" would be parsed
        # as "%u" (reading an unsigned int) followed by a literal "ll".
        o = f'sprintf(buffer, "%llu", {name});\n'
        o += 'out.append(buffer);'
        return o
@register_type(5)
class Int32Type(TypeInfo):
    """``int32`` field mapped to C++ ``int32_t`` (varint wire format)."""
    encode_func = 'encode_int32'
    decode_varint = 'value.as_int32()'
    default_value = '0'
    cpp_type = 'int32_t'

    def dump(self, name):
        """C++ that prints ``name`` with ``%d`` into ``buffer`` and appends it."""
        return f'sprintf(buffer, "%d", {name});\nout.append(buffer);'
@register_type(6)
class Fixed64Type(TypeInfo):
    """``fixed64`` field mapped to C++ ``uint64_t`` (64-bit wire format)."""
    cpp_type = 'uint64_t'
    default_value = '0'
    decode_64bit = 'value.as_fixed64()'
    encode_func = 'encode_fixed64'

    def dump(self, name):
        """Return C++ that prints ``name`` as an unsigned 64-bit integer."""
        # "%llu", not "%ull" (which reads an unsigned int and prints "ll").
        o = f'sprintf(buffer, "%llu", {name});\n'
        o += 'out.append(buffer);'
        return o
@register_type(7)
class Fixed32Type(TypeInfo):
    """``fixed32`` field mapped to C++ ``uint32_t`` (32-bit wire format)."""
    encode_func = 'encode_fixed32'
    decode_32bit = 'value.as_fixed32()'
    default_value = '0'
    cpp_type = 'uint32_t'

    def dump(self, name):
        """C++ that prints ``name`` with ``%u`` into ``buffer`` and appends it."""
        return f'sprintf(buffer, "%u", {name});\nout.append(buffer);'
@register_type(8)
class BoolType(TypeInfo):
    """``bool`` field (varint wire format)."""
    encode_func = 'encode_bool'
    decode_varint = 'value.as_bool()'
    default_value = 'false'
    cpp_type = 'bool'

    def dump(self, name):
        """C++ that appends ``YESNO(name)`` to ``out``."""
        return f'out.append(YESNO({name}));'
@register_type(9)
class StringType(TypeInfo):
    """``string`` field stored as ``std::string`` (length-delimited)."""
    cpp_type = 'std::string'
    default_value = ''
    reference_type = 'std::string &'
    const_reference_type = 'const std::string &'
    decode_length = 'value.as_string()'
    encode_func = 'encode_string'

    def dump(self, name):
        """C++ that appends ``name`` to ``out`` wrapped in single quotes."""
        quote = '"\'"'
        return f'out.append({quote}).append({name}).append({quote});'
@register_type(11)
class MessageType(TypeInfo):
    """Embedded-message field; its C++ type is the generated message class."""
    default_value = ''

    @property
    def cpp_type(self):
        # Drop the leading '.' of the descriptor's type name.
        return self._field.type_name[1:]

    @property
    def reference_type(self):
        return self.cpp_type + ' &'

    @property
    def const_reference_type(self):
        return 'const ' + self.reference_type

    @property
    def encode_func(self):
        return 'encode_message<' + self.cpp_type + '>'

    @property
    def decode_length(self):
        return 'value.as_message<' + self.cpp_type + '>()'

    def dump(self, name):
        """C++ that lets the nested message append its own dump to ``out``."""
        return f'{name}.dump_to(out);'
@register_type(12)
class BytesType(TypeInfo):
    """``bytes`` field stored as ``std::string`` (length-delimited)."""
    cpp_type = 'std::string'
    default_value = ''
    reference_type = 'std::string &'
    const_reference_type = 'const std::string &'
    decode_length = 'value.as_string()'
    encode_func = 'encode_string'

    def dump(self, name):
        """C++ that appends the raw bytes of ``name`` to ``out`` in single quotes."""
        quote = '"\'"'
        return f'out.append({quote}).append({name}).append({quote});'
@register_type(13)
class UInt32Type(TypeInfo):
    """``uint32`` field mapped to C++ ``uint32_t`` (varint wire format)."""
    encode_func = 'encode_uint32'
    decode_varint = 'value.as_uint32()'
    default_value = '0'
    cpp_type = 'uint32_t'

    def dump(self, name):
        """C++ that prints ``name`` with ``%u`` into ``buffer`` and appends it."""
        return f'sprintf(buffer, "%u", {name});\nout.append(buffer);'
@register_type(14)
class EnumType(TypeInfo):
    """Enum field, mapped to the ``Enum<Name>`` C++ type emitted by build_enum_type."""
    default_value = ''

    @property
    def cpp_type(self):
        return f'Enum{self._field.type_name[1:]}'

    @property
    def decode_varint(self):
        return 'value.as_enum<' + self.cpp_type + '>()'

    @property
    def encode_func(self):
        return 'encode_enum<' + self.cpp_type + '>'

    def dump(self, name):
        """C++ that appends the enum value's symbolic name to ``out``."""
        return f'out.append(proto_enum_to_string<{self.cpp_type}>({name}));'
@register_type(15)
class SFixed32Type(TypeInfo):
    """``sfixed32`` field mapped to C++ ``int32_t`` (32-bit wire format)."""
    encode_func = 'encode_sfixed32'
    decode_32bit = 'value.as_sfixed32()'
    default_value = '0'
    cpp_type = 'int32_t'

    def dump(self, name):
        """C++ that prints ``name`` with ``%d`` into ``buffer`` and appends it."""
        return f'sprintf(buffer, "%d", {name});\nout.append(buffer);'
@register_type(16)
class SFixed64Type(TypeInfo):
    """``sfixed64`` field mapped to C++ ``int64_t`` (64-bit wire format)."""
    cpp_type = 'int64_t'
    default_value = '0'
    decode_64bit = 'value.as_sfixed64()'
    encode_func = 'encode_sfixed64'

    def dump(self, name):
        """Return C++ that prints ``name`` as a signed 64-bit integer."""
        # "%lld", not the incomplete "%ll". NOTE(review): assumes int64_t is
        # long long on the target toolchains.
        o = f'sprintf(buffer, "%lld", {name});\n'
        o += 'out.append(buffer);'
        return o
@register_type(17)
class SInt32Type(TypeInfo):
    """``sint32`` field mapped to C++ ``int32_t`` (varint wire format)."""
    encode_func = 'encode_sint32'
    decode_varint = 'value.as_sint32()'
    default_value = '0'
    cpp_type = 'int32_t'

    def dump(self, name):
        """C++ that prints ``name`` with ``%d`` into ``buffer`` and appends it."""
        return f'sprintf(buffer, "%d", {name});\nout.append(buffer);'
@register_type(18)
class SInt64Type(TypeInfo):
    """``sint64`` field mapped to C++ ``int64_t`` (varint wire format)."""
    cpp_type = 'int64_t'
    default_value = '0'
    decode_varint = 'value.as_sint64()'
    # Same naming scheme as SInt32Type's 'encode_sint32' (the previous value
    # 'encode_sin64' was a typo). TODO confirm the method name against proto.h.
    encode_func = 'encode_sint64'

    def dump(self, name):
        """Return C++ that prints ``name`` as a signed 64-bit integer.

        Takes ``name`` like every other ``dump``: TypeInfo.dump_content and
        RepeatedTypeInfo.dump_content both call ``dump(expr)``.
        """
        # "%lld", not the incomplete "%ll".
        o = f'sprintf(buffer, "%lld", {name});\n'
        o += 'out.append(buffer);'
        return o
class RepeatedTypeInfo(TypeInfo):
    """TypeInfo for a ``repeated`` field, stored as a ``std::vector``.

    Per-element handling is delegated to the TypeInfo registered for the
    field's type. Decoding appends each element with ``push_back``, while
    encoding and dumping loop over the vector.
    """

    def __init__(self, field):
        super().__init__(field)
        # TypeInfo describing a single element of the vector.
        element_cls = TYPE_INFO[field.type]
        self._ti = element_cls(field)

    @property
    def cpp_type(self):
        return f'std::vector<{self._ti.cpp_type}>'

    @property
    def reference_type(self):
        return f'{self.cpp_type} &'

    @property
    def const_reference_type(self):
        return f'const {self.cpp_type} &'

    def _push_back_case(self, content):
        """``case`` that appends ``content`` to the vector; None if ``content`` is None."""
        if content is None:
            return None
        return (f'case {self.number}: {{\n'
                f'this->{self.field_name}.push_back({content});\n'
                'return true;\n'
                '}')

    @property
    def decode_varint_content(self) -> str:
        return self._push_back_case(self._ti.decode_varint)

    @property
    def decode_length_content(self) -> str:
        return self._push_back_case(self._ti.decode_length)

    @property
    def decode_32bit_content(self) -> str:
        return self._push_back_case(self._ti.decode_32bit)

    @property
    def decode_64bit_content(self) -> str:
        return self._push_back_case(self._ti.decode_64bit)

    @property
    def _ti_is_bool(self):
        # std::vector<bool> is a packed specialization whose elements cannot be
        # bound to a reference, so loops over it iterate by value instead.
        return isinstance(self._ti, BoolType)

    @property
    def encode_content(self):
        ref = '' if self._ti_is_bool else '&'
        return (f'for (auto {ref}it : this->{self.field_name}) {{\n'
                f'buffer.{self._ti.encode_func}({self.number}, it, true);\n'
                '}')

    @property
    def dump_content(self):
        ref = '' if self._ti_is_bool else '&'
        lines = [
            f'for (const auto {ref}it : this->{self.field_name}) {{',
            f' out.append(" {self.name}: ");',
            indent(self._ti.dump('it')),
            ' out.append("\\n");',
            '}',
        ]
        return '\n'.join(lines) + '\n'
def build_enum_type(desc):
    """Generate the C++ for one protobuf enum descriptor.

    Returns ``(header, source)``:

    * ``header`` is an ``enum Enum<Name> : uint32_t`` declaration.
    * ``source`` is a ``proto_enum_to_string`` specialization that maps each
      value to its name, with ``"UNKNOWN"`` for anything else.
    """
    name = f'Enum{desc.name}'
    members = [f' {v.name} = {v.number},\n' for v in desc.value]
    labels = [f' case {v.name}: return "{v.name}";\n' for v in desc.value]
    header = f'enum {name} : uint32_t {{\n' + ''.join(members) + '};\n'
    source = (
        'template<>\n'
        f'const char *proto_enum_to_string<{name}>({name} value) {{\n'
        ' switch (value) {\n'
        + ''.join(labels)
        + ' default: return "UNKNOWN";\n'
        ' }\n'
        '}\n'
    )
    return header, source
def build_message_type(desc):
    """Generate the C++ class for one protobuf message descriptor.

    ``desc`` is a ``DescriptorProto`` from the compiled api.proto.

    Returns ``(header, source)``:

    * ``header`` is the class declaration (derived from ``ProtoMessage``).
    * ``source`` holds the out-of-line definitions of its ``decode_*``,
      ``encode`` and ``dump_to`` methods.

    A ``decode_*`` override is only emitted when at least one field uses that
    wire format.
    """
    public_content = []
    protected_content = []
    # (hook name, C++ type of its ``value`` parameter) for each wire format.
    # The 64-bit type is spelled Proto64Bit, consistent with Proto32Bit (it
    # used to be "Proto64bit"). NOTE(review): these presumably must match the
    # ProtoMessage virtuals in proto.h, which are not visible here.
    decoders = (
        ('decode_varint', 'ProtoVarInt'),
        ('decode_length', 'ProtoLengthDelimited'),
        ('decode_32bit', 'Proto32Bit'),
        ('decode_64bit', 'Proto64Bit'),
    )
    decode_cases = {hook: [] for hook, _ in decoders}
    encode = []
    dump = []

    for field in desc.field:
        # label 3 is LABEL_REPEATED
        if field.label == 3:
            ti = RepeatedTypeInfo(field)
        else:
            ti = TYPE_INFO[field.type](field)
        protected_content.extend(ti.protected_content)
        public_content.extend(ti.public_content)
        encode.append(ti.encode_content)
        for hook, cases in decode_cases.items():
            case = getattr(ti, f'{hook}_content')
            if case:
                cases.append(case)
        if ti.dump_content:
            dump.append(ti.dump_content)

    cpp = ''
    for hook, value_type in decoders:
        cases = decode_cases[hook]
        if not cases:
            # No field uses this wire format, so no override is emitted.
            continue
        cases.append('default:\n return false;')
        signature = f'{hook}(uint32_t field_id, {value_type} value)'
        cpp += f'bool {desc.name}::{signature} {{\n'
        cpp += ' switch (field_id) {\n'
        cpp += indent('\n'.join(cases), ' ') + '\n'
        cpp += ' }\n'
        cpp += '}\n'
        # Prepending leaves the declarations in reverse hook order.
        protected_content.insert(0, f'bool {signature} override;')

    cpp += f"void {desc.name}::encode(ProtoWriteBuffer buffer) const {{\n"
    cpp += indent('\n'.join(encode)) + '\n'
    cpp += '}\n'
    public_content.append('void encode(ProtoWriteBuffer buffer) const override;')

    cpp += f"void {desc.name}::dump_to(std::string &out) const {{\n"
    if dump:
        # Scratch buffer for the sprintf() calls emitted by the types' dump().
        cpp += " char buffer[64];\n"
        cpp += f' out.append("{desc.name} {{\\n");\n'
        cpp += indent('\n'.join(dump)) + '\n'
        cpp += ' out.append("}");\n'
    else:
        cpp += f' out.append("{desc.name} {{}}");\n'
    cpp += '}\n'
    public_content.append('void dump_to(std::string &out) const override;')

    header = f"class {desc.name} : public ProtoMessage {{\n"
    header += ' public:\n'
    header += indent('\n'.join(public_content)) + '\n'
    header += ' protected:\n'
    header += indent('\n'.join(protected_content)) + '\n'
    header += "};\n"
    return header, cpp
# --- api_pb2.h / api_pb2.cpp ----------------------------------------------
# protoc was given only api.proto above, so the set's first file is api.proto.
# NOTE: ``file`` is also read by the service generation further down.
file = d.file[0]
# ``content`` accumulates api_pb2.h, ``cpp`` accumulates api_pb2.cpp.
content = '''\
#pragma once
#include "proto.h"
namespace esphome {
namespace api {
'''
cpp = '''\
#include "api_pb2.h"
#include "esphome/core/log.h"
namespace esphome {
namespace api {
'''
# Enums are emitted before the message classes (presumably so that message
# members can use the Enum* types).
for enum in file.enum_type:
    s, c = build_enum_type(enum)
    content += s
    cpp += c
# One class declaration (header) plus method definitions (source) per message.
mt = file.message_type
for m in mt:
    s, c = build_message_type(m)
    content += s
    cpp += c
content += '''\
} // namespace api
} // namespace esphome
'''
cpp += '''\
} // namespace api
} // namespace esphome
'''
with open(root / 'api_pb2.h', 'w') as f:
    f.write(content)
with open(root / 'api_pb2.cpp', 'w') as f:
    f.write(cpp)
# Values of the ``source`` message option. build_service_message_type emits
# send_* code for BOTH/SERVER and on_* receive handlers for BOTH/CLIENT.
SOURCE_BOTH, SOURCE_SERVER, SOURCE_CLIENT = range(3)
# Message id -> C++ body of its ``case`` in read_message(); filled in by
# build_service_message_type.
RECEIVE_CASES = {}
# Name of the C++ class whose methods are currently being generated.
class_name = 'APIServerConnectionBase'
# Message name -> preprocessor symbol from its ``ifdef`` option.
ifdefs = {}
def get_opt(desc, opt, default=None):
    """Return extension option ``opt`` from ``desc.options``, or ``default`` if unset.

    ``desc`` is a protobuf descriptor (message or method); ``opt`` is an
    extension handle such as ``pb.id`` from api_options_pb2.
    """
    options = desc.options
    if options.HasExtension(opt):
        return options.Extensions[opt]
    return default
def build_service_message_type(mt):
    """Generate the per-message service glue for message descriptor ``mt``.

    Returns None for messages without the ``id`` option. Otherwise returns
    ``(hout, cout)``: declarations and definitions for the class named by the
    module-level ``class_name``. Depending on the ``source`` option it emits:

    * a ``send_<snake>()`` method (SOURCE_BOTH / SOURCE_SERVER);
    * a virtual ``on_<snake>()`` hook plus a ``read_message`` case body
      stored in ``RECEIVE_CASES[id]`` (SOURCE_BOTH / SOURCE_CLIENT).

    When the ``ifdef`` option is set, everything is wrapped in ``#ifdef`` and
    the guard is recorded in ``ifdefs`` under the message name.
    """
    name = mt.name
    snake = camel_to_snake(name)
    msg_id = get_opt(mt, pb.id)
    if msg_id is None:
        return None
    source = get_opt(mt, pb.source, 0)
    ifdef = get_opt(mt, pb.ifdef)
    log = get_opt(mt, pb.log, True)
    nodelay = get_opt(mt, pb.no_delay, False)
    guarded = ifdef is not None

    decls = []
    defs = []
    if guarded:
        ifdefs[str(name)] = ifdef
        decls.append(f'#ifdef {ifdef}\n')
        defs.append(f'#ifdef {ifdef}\n')

    if source in (SOURCE_BOTH, SOURCE_SERVER):
        # Outgoing direction: send_<snake>() serializes and sends the message.
        send_fn = f'send_{snake}'
        decls.append(f'bool {send_fn}(const {name} &msg);\n')
        defs.append(f'bool {class_name}::{send_fn}(const {name} &msg) {{\n')
        if log:
            defs.append(f' ESP_LOGVV(TAG, "{send_fn}: %s", msg.dump().c_str());\n')
        defs.append(f' this->set_nodelay({str(nodelay).lower()});\n')
        defs.append(f' return this->send_message_<{name}>(msg, {msg_id});\n')
        defs.append('}\n')

    if source in (SOURCE_BOTH, SOURCE_CLIENT):
        # Incoming direction: an overridable on_<snake>() plus the case body
        # that decodes the message and calls it.
        recv_fn = f'on_{snake}'
        decls.append(f'virtual void {recv_fn}(const {name} &value){{}};\n')
        case_lines = []
        if guarded:
            case_lines.append(f'#ifdef {ifdef}\n')
        case_lines.append(f'{name} msg;\n')
        case_lines.append('msg.decode(msg_data, msg_size);\n')
        if log:
            case_lines.append(f'ESP_LOGVV(TAG, "{recv_fn}: %s", msg.dump().c_str());\n')
        case_lines.append(f'this->{recv_fn}(msg);\n')
        if guarded:
            case_lines.append('#endif\n')
        case_lines.append('break;')
        RECEIVE_CASES[msg_id] = ''.join(case_lines)

    if guarded:
        decls.append('#endif\n')
        defs.append('#endif\n')
    return ''.join(decls), ''.join(defs)
# --- api_pb2_service.h / api_pb2_service.cpp ------------------------------
# ``hpp`` accumulates the service header and ``cpp`` (rebound here) the
# service source file.
hpp = '''\
#pragma once
#include "api_pb2.h"
#include "esphome/core/defines.h"
namespace esphome {
namespace api {
'''
cpp = '''\
#include "api_pb2_service.h"
#include "esphome/core/log.h"
namespace esphome {
namespace api {
static const char *TAG = "api.service";
'''
# First class (class_name is still 'APIServerConnectionBase' here): the
# send_*/on_* methods produced by build_service_message_type.
hpp += f'class {class_name} : public ProtoService {{\n'
hpp += ' public:\n'
for mt in file.message_type:
    obj = build_service_message_type(mt)
    # None: the message has no id option, so it gets no service code.
    if obj is None:
        continue
    hout, cout = obj
    hpp += indent(hout) + '\n'
    cpp += cout
# read_message() dispatches on the message id; cases are emitted in
# ascending id order.
cases = list(RECEIVE_CASES.items())
cases.sort()
hpp += ' protected:\n'
hpp += f' bool read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) override;\n'
out = f'bool {class_name}::read_message(uint32_t msg_size, uint32_t msg_type, uint8_t *msg_data) {{\n'
out += f' switch(msg_type) {{\n'
for i, case in cases:
    c = f'case {i}: {{\n'
    c += indent(case) + '\n'
    c += f'}}'
    out += indent(c, ' ') + '\n'
# Unknown message ids make read_message() return false.
out += ' default: \n'
out += ' return false;\n'
out += ' }\n'
out += ' return true;\n'
out += '}\n'
cpp += out
hpp += '};\n'
# Second class: APIServerConnection derives from the Base class generated
# above. It turns each RPC method of the first service in api.proto into a
# pure virtual handler plus an on_* override that calls it.
serv = file.service[0]
class_name = 'APIServerConnection'
hpp += '\n'
hpp += f'class {class_name} : public {class_name}Base {{\n'
hpp += ' public:\n'
# Declarations for the protected section; appended after the loop.
hpp_protected = ''
cpp += '\n'
for m in serv.method:
    func = m.name
    # Strip the leading '.' from the descriptor type names.
    inp = m.input_type[1:]
    ret = m.output_type[1:]
    is_void = ret == 'void'
    snake = camel_to_snake(inp)
    # Base-class hook that receives the decoded request message.
    on_func = f'on_{snake}'
    needs_conn = get_opt(m, pb.needs_setup_connection, True)
    needs_auth = get_opt(m, pb.needs_authentication, True)
    # Reuse the #ifdef guard recorded for the request message type, if any.
    ifdef = ifdefs.get(inp)
    if ifdef is not None:
        hpp += f'#ifdef {ifdef}\n'
        hpp_protected += f'#ifdef {ifdef}\n'
        cpp += f'#ifdef {ifdef}\n'
    hpp_protected += f' void {on_func}(const {inp} &msg) override;\n'
    hpp += f' virtual {ret} {func}(const {inp} &msg) = 0;\n'
    cpp += f'void {class_name}::{on_func}(const {inp} &msg) {{\n'
    # Generated handler body:
    #   1. check connection setup and authentication (unless disabled);
    #   2. call the pure virtual RPC method;
    #   3. unless it returns void, send its result.
    body = ''
    if needs_conn:
        body += 'if (!this->is_connection_setup()) {\n'
        body += ' this->on_no_setup_connection();\n'
        body += ' return;\n'
        body += '}\n'
    if needs_auth:
        body += 'if (!this->is_authenticated()) {\n'
        body += ' this->on_unauthenticated_access();\n'
        body += ' return;\n'
        body += '}\n'
    if is_void:
        body += f'this->{func}(msg);\n'
    else:
        body += f'{ret} ret = this->{func}(msg);\n'
        ret_snake = camel_to_snake(ret)
        body += f'if (!this->send_{ret_snake}(ret)) {{\n'
        body += ' this->on_fatal_error();\n'
        body += '}\n'
    cpp += indent(body) + '\n' + '}\n'
    if ifdef is not None:
        hpp += '#endif\n'
        hpp_protected += '#endif\n'
        cpp += '#endif\n'
hpp += ' protected:\n'
hpp += hpp_protected
hpp += '};\n'
# Close the esphome::api namespaces opened at the top of both service files.
namespace_footer = '} // namespace api\n} // namespace esphome\n'
hpp += namespace_footer
cpp += namespace_footer
# Write the generated service files, then remove the temporary descriptor set
# that protoc produced at the start of the script.
(root / 'api_pb2_service.h').write_text(hpp)
(root / 'api_pb2_service.cpp').write_text(cpp)
prot.unlink()