diff options
Diffstat (limited to 'generator')
-rw-r--r-- | generator/nanopb_generator.py | 97 |
1 files changed, 85 insertions, 12 deletions
diff --git a/generator/nanopb_generator.py b/generator/nanopb_generator.py index bd365dfd..61f4d7b3 100644 --- a/generator/nanopb_generator.py +++ b/generator/nanopb_generator.py @@ -275,8 +275,53 @@ class Field: return max(self.tag, self.max_size, self.max_count) +class ExtensionRange(Field): + def __init__(self, struct_name, desc, field_options): + '''desc is ExtensionRange''' + self.tag = desc.start + self.struct_name = struct_name + self.name = 'extensions' + self.pbtype = 'EXTENSION' + self.rules = 'OPTIONAL' + self.allocation = 'CALLBACK' + self.ctype = 'pb_extension_t' + self.array_decl = '' + self.default = None + self.max_size = 0 + self.max_count = 0 + + def __str__(self): + return ' pb_extension_t *extensions;' + + def types(self): + return None + + def tags(self): + return '' - +class ExtensionField(Field): + def __init__(self, struct_name, desc, field_options): + self.fullname = struct_name + desc.name + self.extendee_name = names_from_type_name(desc.extendee) + Field.__init__(self, self.fullname + 'struct', desc, field_options) + + def extension_decl(self): + '''Declaration of the extension type in the .pb.h file''' + return 'extern const pb_extension_type_t %s;' % self.fullname + + def extension_def(self): + '''Definition of the extension type in the .pb.c file''' + result = 'typedef struct {\n' + result += str(self) + result += '} %s;\n' % self.struct_name + result += ('static const pb_field_t %s_field = %s;\n' % + (self.fullname, self.pb_field_t(None))) + result += 'const pb_extension_type_t %s = {\n' % self.fullname + result += ' NULL,\n' + result += ' NULL,\n' + result += ' &%s_field\n' % self.fullname + result += '};\n' + return result # --------------------------------------------------------------------------- @@ -294,6 +339,11 @@ class Message: if field_options.type != nanopb_pb2.FT_IGNORE: self.fields.append(Field(self.name, f, field_options)) + if len(desc.extension_range) > 0: + field_options = get_nanopb_suboptions(desc, message_options, self.name + 'extensions') + if field_options.type != nanopb_pb2.FT_IGNORE: + self.fields.append(ExtensionRange(self.name, desc.extension_range[0], field_options)) + self.packed = message_options.packed_struct self.ordered_fields = self.fields[:] self.ordered_fields.sort() @@ -358,9 +408,6 @@ class Message: - - - # --------------------------------------------------------------------------- # Processing of entire .proto files # --------------------------------------------------------------------------- @@ -380,11 +427,23 @@ def iterate_messages(desc, names = Names()): for x in iterate_messages(submsg, sub_names): yield x +def iterate_extensions(desc, names = Names()): + '''Recursively find all extensions. + For each, yield name, FieldDescriptorProto. + ''' + for extension in desc.extension: + yield names, extension + + for subname, subdesc in iterate_messages(desc, names): + for extension in subdesc.extension: + yield subname, extension + def parse_file(fdesc, file_options): - '''Takes a FileDescriptorProto and returns tuple (enum, messages).''' + '''Takes a FileDescriptorProto and returns tuple (enums, messages, extensions).''' enums = [] messages = [] + extensions = [] if fdesc.package: base_name = Names(fdesc.package.split('.')) @@ -402,6 +461,10 @@ def parse_file(fdesc, file_options): enum_options = get_nanopb_suboptions(enum, message_options, names + enum.name) enums.append(Enum(names, enum, enum_options)) + for names, extension in iterate_extensions(fdesc, base_name): + field_options = get_nanopb_suboptions(extension, file_options, names) + extensions.append(ExtensionField(names, extension, field_options)) + # Fix field default values where enum short names are used. for enum in enums: if not enum.options.long_names: @@ -411,7 +474,7 @@ def parse_file(fdesc, file_options): idx = enum.value_longnames.index(field.default) field.default = enum.values[idx][0] - return enums, messages + return enums, messages, extensions def toposort2(data): '''Topological sort. @@ -454,7 +517,7 @@ def make_identifier(headername): result += '_' return result -def generate_header(dependencies, headername, enums, messages, options): +def generate_header(dependencies, headername, enums, messages, extensions, options): '''Generate content for a header file. Generates strings, which should be concatenated and stored to file. ''' @@ -489,6 +552,12 @@ def generate_header(dependencies, headername, enums, messages, options): for msg in sort_dependencies(messages): yield msg.types() yield str(msg) + '\n\n' + + if extensions: + yield '/* Extensions */\n' + for extension in extensions: + yield extension.extension_decl() + yield '\n' yield '/* Default values for struct fields */\n' for msg in messages: @@ -512,7 +581,7 @@ def generate_header(dependencies, headername, enums, messages, options): # End of header yield '\n#endif\n' -def generate_source(headername, enums, messages): +def generate_source(headername, enums, messages, extensions): '''Generate content for a source file.''' yield '/* Automatically generated nanopb constant definitions */\n' @@ -527,7 +596,11 @@ def generate_source(headername, enums, messages): for msg in messages: yield msg.fields_definition() + '\n\n' + + for ext in extensions: + yield ext.extension_def() + '\n' + # Add checks for numeric limits if messages: count_required_fields = lambda m: len([f for f in msg.fields if f.rules == 'REQUIRED']) largest_msg = max(messages, key = count_required_fields) @@ -539,7 +612,6 @@ def generate_source(headername, enums, messages): yield ' setting PB_MAX_REQUIRED_FIELDS to %d or more.\n' % largest_count yield '#endif\n' - # Add checks for numeric limits worst = 0 worst_field = '' checks = [] @@ -724,7 +796,7 @@ def process(filenames, options): # Parse the file file_options = get_nanopb_suboptions(fdesc.file[0], toplevel_options, Names([filename])) - enums, messages = parse_file(fdesc.file[0], file_options) + enums, messages, extensions = parse_file(fdesc.file[0], file_options) noext = os.path.splitext(filename)[0] headername = noext + '.' + options.extension + '.h' @@ -740,11 +812,12 @@ def process(filenames, options): dependencies = [d for d in fdesc.file[0].dependency if d not in excludes] header = open(headername, 'w') - for part in generate_header(dependencies, headerbasename, enums, messages, options): + for part in generate_header(dependencies, headerbasename, enums, + messages, extensions, options): header.write(part) source = open(sourcename, 'w') - for part in generate_source(headerbasename, enums, messages): + for part in generate_source(headerbasename, enums, messages, extensions): source.write(part) return True |