diff options
author | Petteri Aimonen <jpa@git.mail.kapsi.fi> | 2015-09-12 14:46:00 +0300 |
---|---|---|
committer | Petteri Aimonen <jpa@git.mail.kapsi.fi> | 2015-09-12 14:46:00 +0300 |
commit | 35dff3367452f89a1d8d483d0f8f601d89d78937 (patch) | |
tree | a5f8975be6db0bc734c19b38f253c42c9c9f7aa9 | |
parent | 936cfdc675c2dc3580c2459e8b1773a1d0bf9a8b (diff) |
Refactor the generator logic into a ProtoFile class.
In preparation for multi-file support in generator.
No functional changes yet.
-rwxr-xr-x | generator/nanopb_generator.py | 582 |
1 files changed, 297 insertions, 285 deletions
diff --git a/generator/nanopb_generator.py b/generator/nanopb_generator.py index 2b1d63e5..6d06c2cd 100755 --- a/generator/nanopb_generator.py +++ b/generator/nanopb_generator.py @@ -471,7 +471,7 @@ class Field: return max(self.tag, self.max_size, self.max_count) - def encoded_size(self, allmsgs): + def encoded_size(self, dependencies): '''Return the maximum size that this field can take when encoded, including the field tag. If the size cannot be determined, returns None.''' @@ -480,15 +480,14 @@ class Field: return None if self.pbtype == 'MESSAGE': - for msg in allmsgs: - if msg.name == self.submsgname: - encsize = msg.encoded_size(allmsgs) - if encsize is None: - return None # Submessage size is indeterminate - - # Include submessage length prefix - encsize += varint_max_size(encsize.upperlimit()) - break + if str(self.submsgname) in dependencies: + submsg = dependencies[str(self.submsgname)] + encsize = submsg.encoded_size(dependencies) + if encsize is None: + return None # Submessage size is indeterminate + + # Include submessage length prefix + encsize += varint_max_size(encsize.upperlimit()) else: # Submessage cannot be found, this currently occurs when # the submessage type is defined in a different file. @@ -831,7 +830,6 @@ class Message: # Processing of entire .proto files # --------------------------------------------------------------------------- - def iterate_messages(desc, names = Names()): '''Recursively find all messages. For each, yield name, DescriptorProto.''' if hasattr(desc, 'message_type'): @@ -857,57 +855,6 @@ def iterate_extensions(desc, names = Names()): for extension in subdesc.extension: yield subname, extension -def parse_file(fdesc, file_options): - '''Takes a FileDescriptorProto and returns tuple (enums, messages, extensions).''' - - enums = [] - messages = [] - extensions = [] - - if fdesc.package: - base_name = Names(fdesc.package.split('.')) - else: - base_name = Names() - - for enum in fdesc.enum_type: - enum_options = get_nanopb_suboptions(enum, file_options, base_name + enum.name) - enums.append(Enum(base_name, enum, enum_options)) - - for names, message in iterate_messages(fdesc, base_name): - message_options = get_nanopb_suboptions(message, file_options, names) - - if message_options.skip_message: - continue - - messages.append(Message(names, message, message_options)) - for enum in message.enum_type: - enum_options = get_nanopb_suboptions(enum, message_options, names + enum.name) - enums.append(Enum(names, enum, enum_options)) - - for names, extension in iterate_extensions(fdesc, base_name): - field_options = get_nanopb_suboptions(extension, file_options, names + extension.name) - if field_options.type != nanopb_pb2.FT_IGNORE: - extensions.append(ExtensionField(names, extension, field_options)) - - # Fix field default values where enum short names are used. - for enum in enums: - if not enum.options.long_names: - for message in messages: - for field in message.fields: - if field.default in enum.value_longnames: - idx = enum.value_longnames.index(field.default) - field.default = enum.values[idx][0] - - # Fix field data types where enums have negative values. - for enum in enums: - if not enum.has_negative(): - for message in messages: - for field in message.fields: - if field.pbtype == 'ENUM' and field.ctype == enum.names: - field.pbtype = 'UENUM' - - return enums, messages, extensions - def toposort2(data): '''Topological sort. From http://code.activestate.com/recipes/577413-topological-sort/ @@ -949,231 +896,299 @@ def make_identifier(headername): result += '_' return result -def generate_header(dependencies, headername, enums, messages, extensions, options): - '''Generate content for a header file. - Generates strings, which should be concatenated and stored to file. - ''' - - yield '/* Automatically generated nanopb header */\n' - if options.notimestamp: - yield '/* Generated by %s */\n\n' % (nanopb_version) - else: - yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime()) +class ProtoFile: + def __init__(self, fdesc, file_options): + '''Takes a FileDescriptorProto and parses it.''' + self.fdesc = fdesc + self.file_options = file_options + self.dependencies = {} + self.parse() + + # Some of types used in this file probably come from the file itself. + # Thus it has implicit dependency on itself. + self.add_dependency(self) + + def parse(self): + self.enums = [] + self.messages = [] + self.extensions = [] + + if self.fdesc.package: + base_name = Names(self.fdesc.package.split('.')) + else: + base_name = Names() - symbol = make_identifier(headername) - yield '#ifndef PB_%s_INCLUDED\n' % symbol - yield '#define PB_%s_INCLUDED\n' % symbol - try: - yield options.libformat % ('pb.h') - except TypeError: - # no %s specified - use whatever was passed in as options.libformat - yield options.libformat - yield '\n' + for enum in self.fdesc.enum_type: + enum_options = get_nanopb_suboptions(enum, self.file_options, base_name + enum.name) + self.enums.append(Enum(base_name, enum, enum_options)) + + for names, message in iterate_messages(self.fdesc, base_name): + message_options = get_nanopb_suboptions(message, self.file_options, names) + + if message_options.skip_message: + continue + + self.messages.append(Message(names, message, message_options)) + for enum in message.enum_type: + enum_options = get_nanopb_suboptions(enum, message_options, names + enum.name) + self.enums.append(Enum(names, enum, enum_options)) + + for names, extension in iterate_extensions(self.fdesc, base_name): + field_options = get_nanopb_suboptions(extension, self.file_options, names + extension.name) + if field_options.type != nanopb_pb2.FT_IGNORE: + self.extensions.append(ExtensionField(names, extension, field_options)) - for dependency in dependencies: - noext = os.path.splitext(dependency)[0] - yield options.genformat % (noext + options.extension + '.h') + def add_dependency(self, other): + for enum in other.enums: + self.dependencies[str(enum.names)] = enum + + for msg in other.messages: + self.dependencies[str(msg.name)] = msg + + # Fix field default values where enum short names are used. + for enum in other.enums: + if not enum.options.long_names: + for message in self.messages: + for field in message.fields: + if field.default in enum.value_longnames: + idx = enum.value_longnames.index(field.default) + field.default = enum.values[idx][0] + + # Fix field data types where enums have negative values. + for enum in other.enums: + if not enum.has_negative(): + for message in self.messages: + for field in message.fields: + if field.pbtype == 'ENUM' and field.ctype == enum.names: + field.pbtype = 'UENUM' + + def generate_header(self, includes, headername, options): + '''Generate content for a header file. + Generates strings, which should be concatenated and stored to file. + ''' + + yield '/* Automatically generated nanopb header */\n' + if options.notimestamp: + yield '/* Generated by %s */\n\n' % (nanopb_version) + else: + yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime()) + + symbol = make_identifier(headername) + yield '#ifndef PB_%s_INCLUDED\n' % symbol + yield '#define PB_%s_INCLUDED\n' % symbol + try: + yield options.libformat % ('pb.h') + except TypeError: + # no %s specified - use whatever was passed in as options.libformat + yield options.libformat yield '\n' - - yield '#if PB_PROTO_HEADER_VERSION != 30\n' - yield '#error Regenerate this file with the current version of nanopb generator.\n' - yield '#endif\n' - yield '\n' - - yield '#ifdef __cplusplus\n' - yield 'extern "C" {\n' - yield '#endif\n\n' - - yield '/* Enum definitions */\n' - for enum in enums: - yield str(enum) + '\n\n' - - yield '/* Struct definitions */\n' - for msg in sort_dependencies(messages): - yield msg.types() - yield str(msg) + '\n\n' - - if extensions: - yield '/* Extensions */\n' - for extension in extensions: - yield extension.extension_decl() + + for incfile in includes: + noext = os.path.splitext(incfile)[0] + yield options.genformat % (noext + options.extension + '.h') + yield '\n' + + yield '#if PB_PROTO_HEADER_VERSION != 30\n' + yield '#error Regenerate this file with the current version of nanopb generator.\n' + yield '#endif\n' yield '\n' + + yield '#ifdef __cplusplus\n' + yield 'extern "C" {\n' + yield '#endif\n\n' - yield '/* Default values for struct fields */\n' - for msg in messages: - yield msg.default_decl(True) - yield '\n' - - yield '/* Initializer values for message structs */\n' - for msg in messages: - identifier = '%s_init_default' % msg.name - yield '#define %-40s %s\n' % (identifier, msg.get_initializer(False)) - for msg in messages: - identifier = '%s_init_zero' % msg.name - yield '#define %-40s %s\n' % (identifier, msg.get_initializer(True)) - yield '\n' - - yield '/* Field tags (for use in manual encoding/decoding) */\n' - for msg in sort_dependencies(messages): - for field in msg.fields: - yield field.tags() - for extension in extensions: - yield extension.tags() - yield '\n' - - yield '/* Struct field encoding specification for nanopb */\n' - for msg in messages: - yield msg.fields_declaration() + '\n' - yield '\n' - - yield '/* Maximum encoded size of messages (where known) */\n' - for msg in messages: - msize = msg.encoded_size(messages) - if msize is not None: - identifier = '%s_size' % msg.name - yield '#define %-40s %s\n' % (identifier, msize) - yield '\n' - - yield '/* Message IDs (where set with "msgid" option) */\n' - - yield '#ifdef PB_MSGID\n' - for msg in messages: - if hasattr(msg,'msgid'): - yield '#define PB_MSG_%d %s\n' % (msg.msgid, msg.name) - yield '\n' - - symbol = make_identifier(headername.split('.')[0]) - yield '#define %s_MESSAGES \\\n' % symbol - - for msg in messages: - m = "-1" - msize = msg.encoded_size(messages) - if msize is not None: - m = msize - if hasattr(msg,'msgid'): - yield '\tPB_MSG(%d,%s,%s) \\\n' % (msg.msgid, m, msg.name) - yield '\n' - - for msg in messages: - if hasattr(msg,'msgid'): - yield '#define %s_msgid %d\n' % (msg.name, msg.msgid) - yield '\n' - - yield '#endif\n\n' - - - yield '#ifdef __cplusplus\n' - yield '} /* extern "C" */\n' - yield '#endif\n' - - # End of header - yield '\n#endif\n' + if self.enums: + yield '/* Enum definitions */\n' + for enum in self.enums: + yield str(enum) + '\n\n' + + if self.messages: + yield '/* Struct definitions */\n' + for msg in sort_dependencies(self.messages): + yield msg.types() + yield str(msg) + '\n\n' + + if self.extensions: + yield '/* Extensions */\n' + for extension in self.extensions: + yield extension.extension_decl() + yield '\n' + + if self.messages: + yield '/* Default values for struct fields */\n' + for msg in self.messages: + yield msg.default_decl(True) + yield '\n' + + yield '/* Initializer values for message structs */\n' + for msg in self.messages: + identifier = '%s_init_default' % msg.name + yield '#define %-40s %s\n' % (identifier, msg.get_initializer(False)) + for msg in self.messages: + identifier = '%s_init_zero' % msg.name + yield '#define %-40s %s\n' % (identifier, msg.get_initializer(True)) + yield '\n' + + yield '/* Field tags (for use in manual encoding/decoding) */\n' + for msg in sort_dependencies(self.messages): + for field in msg.fields: + yield field.tags() + for extension in self.extensions: + yield extension.tags() + yield '\n' + + yield '/* Struct field encoding specification for nanopb */\n' + for msg in self.messages: + yield msg.fields_declaration() + '\n' + yield '\n' + + yield '/* Maximum encoded size of messages (where known) */\n' + for msg in self.messages: + msize = msg.encoded_size(self.dependencies) + if msize is not None: + identifier = '%s_size' % msg.name + yield '#define %-40s %s\n' % (identifier, msize) + yield '\n' + + yield '/* Message IDs (where set with "msgid" option) */\n' + + yield '#ifdef PB_MSGID\n' + for msg in self.messages: + if hasattr(msg,'msgid'): + yield '#define PB_MSG_%d %s\n' % (msg.msgid, msg.name) + yield '\n' + + symbol = make_identifier(headername.split('.')[0]) + yield '#define %s_MESSAGES \\\n' % symbol + + for msg in self.messages: + m = "-1" + msize = msg.encoded_size(self.dependencies) + if msize is not None: + m = msize + if hasattr(msg,'msgid'): + yield '\tPB_MSG(%d,%s,%s) \\\n' % (msg.msgid, m, msg.name) + yield '\n' + + for msg in self.messages: + if hasattr(msg,'msgid'): + yield '#define %s_msgid %d\n' % (msg.name, msg.msgid) + yield '\n' -def generate_source(headername, enums, messages, extensions, options): - '''Generate content for a source file.''' - - yield '/* Automatically generated nanopb constant definitions */\n' - if options.notimestamp: - yield '/* Generated by %s */\n\n' % (nanopb_version) - else: - yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime()) - yield options.genformat % (headername) - yield '\n' - - yield '#if PB_PROTO_HEADER_VERSION != 30\n' - yield '#error Regenerate this file with the current version of nanopb generator.\n' - yield '#endif\n' - yield '\n' - - for msg in messages: - yield msg.default_decl(False) - - yield '\n\n' - - for msg in messages: - yield msg.fields_definition() + '\n\n' - - for ext in extensions: - yield ext.extension_def() + '\n' - - # Add checks for numeric limits - if messages: - largest_msg = max(messages, key = lambda m: m.count_required_fields()) - largest_count = largest_msg.count_required_fields() - if largest_count > 64: - yield '\n/* Check that missing required fields will be properly detected */\n' - yield '#if PB_MAX_REQUIRED_FIELDS < %d\n' % largest_count - yield '#error Properly detecting missing required fields in %s requires \\\n' % largest_msg.name - yield ' setting PB_MAX_REQUIRED_FIELDS to %d or more.\n' % largest_count - yield '#endif\n' - - worst = 0 - worst_field = '' - checks = [] - checks_msgnames = [] - for msg in messages: - checks_msgnames.append(msg.name) - for field in msg.fields: - status = field.largest_field_value() - if isinstance(status, (str, unicode)): - checks.append(status) - elif status > worst: - worst = status - worst_field = str(field.struct_name) + '.' + str(field.name) - - if worst > 255 or checks: - yield '\n/* Check that field information fits in pb_field_t */\n' - - if worst > 65535 or checks: - yield '#if !defined(PB_FIELD_32BIT)\n' - if worst > 65535: - yield '#error Field descriptor for %s is too large. Define PB_FIELD_32BIT to fix this.\n' % worst_field - else: - assertion = ' && '.join(str(c) + ' < 65536' for c in checks) - msgs = '_'.join(str(n) for n in checks_msgnames) - yield '/* If you get an error here, it means that you need to define PB_FIELD_32BIT\n' - yield ' * compile-time option. You can do that in pb.h or on compiler command line.\n' - yield ' * \n' - yield ' * The reason you need to do this is that some of your messages contain tag\n' - yield ' * numbers or field sizes that are larger than what can fit in 8 or 16 bit\n' - yield ' * field descriptors.\n' - yield ' */\n' - yield 'PB_STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_32BIT_FOR_MESSAGES_%s)\n'%(assertion,msgs) yield '#endif\n\n' + + yield '#ifdef __cplusplus\n' + yield '} /* extern "C" */\n' + yield '#endif\n' + + # End of header + yield '\n#endif\n' + + def generate_source(self, headername, options): + '''Generate content for a source file.''' + + yield '/* Automatically generated nanopb constant definitions */\n' + if options.notimestamp: + yield '/* Generated by %s */\n\n' % (nanopb_version) + else: + yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime()) + yield options.genformat % (headername) + yield '\n' + + yield '#if PB_PROTO_HEADER_VERSION != 30\n' + yield '#error Regenerate this file with the current version of nanopb generator.\n' + yield '#endif\n' + yield '\n' + + for msg in self.messages: + yield msg.default_decl(False) + + yield '\n\n' + + for msg in self.messages: + yield msg.fields_definition() + '\n\n' + + for ext in self.extensions: + yield ext.extension_def() + '\n' + + # Add checks for numeric limits + if self.messages: + largest_msg = max(self.messages, key = lambda m: m.count_required_fields()) + largest_count = largest_msg.count_required_fields() + if largest_count > 64: + yield '\n/* Check that missing required fields will be properly detected */\n' + yield '#if PB_MAX_REQUIRED_FIELDS < %d\n' % largest_count + yield '#error Properly detecting missing required fields in %s requires \\\n' % largest_msg.name + yield ' setting PB_MAX_REQUIRED_FIELDS to %d or more.\n' % largest_count + yield '#endif\n' + + worst = 0 + worst_field = '' + checks = [] + checks_msgnames = [] + for msg in self.messages: + checks_msgnames.append(msg.name) + for field in msg.fields: + status = field.largest_field_value() + if isinstance(status, (str, unicode)): + checks.append(status) + elif status > worst: + worst = status + worst_field = str(field.struct_name) + '.' + str(field.name) + + if worst > 255 or checks: + yield '\n/* Check that field information fits in pb_field_t */\n' + + if worst > 65535 or checks: + yield '#if !defined(PB_FIELD_32BIT)\n' + if worst > 65535: + yield '#error Field descriptor for %s is too large. Define PB_FIELD_32BIT to fix this.\n' % worst_field + else: + assertion = ' && '.join(str(c) + ' < 65536' for c in checks) + msgs = '_'.join(str(n) for n in checks_msgnames) + yield '/* If you get an error here, it means that you need to define PB_FIELD_32BIT\n' + yield ' * compile-time option. You can do that in pb.h or on compiler command line.\n' + yield ' * \n' + yield ' * The reason you need to do this is that some of your messages contain tag\n' + yield ' * numbers or field sizes that are larger than what can fit in 8 or 16 bit\n' + yield ' * field descriptors.\n' + yield ' */\n' + yield 'PB_STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_32BIT_FOR_MESSAGES_%s)\n'%(assertion,msgs) + yield '#endif\n\n' + + if worst < 65536: + yield '#if !defined(PB_FIELD_16BIT) && !defined(PB_FIELD_32BIT)\n' + if worst > 255: + yield '#error Field descriptor for %s is too large. Define PB_FIELD_16BIT to fix this.\n' % worst_field + else: + assertion = ' && '.join(str(c) + ' < 256' for c in checks) + msgs = '_'.join(str(n) for n in checks_msgnames) + yield '/* If you get an error here, it means that you need to define PB_FIELD_16BIT\n' + yield ' * compile-time option. You can do that in pb.h or on compiler command line.\n' + yield ' * \n' + yield ' * The reason you need to do this is that some of your messages contain tag\n' + yield ' * numbers or field sizes that are larger than what can fit in the default\n' + yield ' * 8 bit descriptors.\n' + yield ' */\n' + yield 'PB_STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_16BIT_FOR_MESSAGES_%s)\n'%(assertion,msgs) + yield '#endif\n\n' + + # Add check for sizeof(double) + has_double = False + for msg in self.messages: + for field in msg.fields: + if field.ctype == 'double': + has_double = True + + if has_double: + yield '\n' + yield '/* On some platforms (such as AVR), double is really float.\n' + yield ' * These are not directly supported by nanopb, but see example_avr_double.\n' + yield ' * To get rid of this error, remove any double fields from your .proto.\n' + yield ' */\n' + yield 'PB_STATIC_ASSERT(sizeof(double) == 8, DOUBLE_MUST_BE_8_BYTES)\n' - if worst < 65536: - yield '#if !defined(PB_FIELD_16BIT) && !defined(PB_FIELD_32BIT)\n' - if worst > 255: - yield '#error Field descriptor for %s is too large. Define PB_FIELD_16BIT to fix this.\n' % worst_field - else: - assertion = ' && '.join(str(c) + ' < 256' for c in checks) - msgs = '_'.join(str(n) for n in checks_msgnames) - yield '/* If you get an error here, it means that you need to define PB_FIELD_16BIT\n' - yield ' * compile-time option. You can do that in pb.h or on compiler command line.\n' - yield ' * \n' - yield ' * The reason you need to do this is that some of your messages contain tag\n' - yield ' * numbers or field sizes that are larger than what can fit in the default\n' - yield ' * 8 bit descriptors.\n' - yield ' */\n' - yield 'PB_STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_16BIT_FOR_MESSAGES_%s)\n'%(assertion,msgs) - yield '#endif\n\n' - - # Add check for sizeof(double) - has_double = False - for msg in messages: - for field in msg.fields: - if field.ctype == 'double': - has_double = True - - if has_double: yield '\n' - yield '/* On some platforms (such as AVR), double is really float.\n' - yield ' * These are not directly supported by nanopb, but see example_avr_double.\n' - yield ' * To get rid of this error, remove any double fields from your .proto.\n' - yield ' */\n' - yield 'PB_STATIC_ASSERT(sizeof(double) == 8, DOUBLE_MUST_BE_8_BYTES)\n' - - yield '\n' # --------------------------------------------------------------------------- # Options parsing for the .proto files @@ -1338,7 +1353,7 @@ def process_file(filename, fdesc, options): # Parse the file file_options = get_nanopb_suboptions(fdesc, toplevel_options, Names([filename])) - enums, messages, extensions = parse_file(fdesc, file_options) + f = ProtoFile(fdesc, file_options) # Decide the file names noext = os.path.splitext(filename)[0] @@ -1349,13 +1364,10 @@ def process_file(filename, fdesc, options): # List of .proto files that should not be included in the C header file # even if they are mentioned in the source .proto. excludes = ['nanopb.proto', 'google/protobuf/descriptor.proto'] + options.exclude - dependencies = [d for d in fdesc.dependency if d not in excludes] + includes = [d for d in fdesc.dependency if d not in excludes] - headerdata = ''.join(generate_header(dependencies, headerbasename, enums, - messages, extensions, options)) - - sourcedata = ''.join(generate_source(headerbasename, enums, - messages, extensions, options)) + headerdata = ''.join(f.generate_header(includes, headerbasename, options)) + sourcedata = ''.join(f.generate_source(headerbasename, options)) # Check if there were any lines in .options that did not match a member unmatched = [n for n,o in Globals.separate_options if n not in Globals.matched_namemasks] |