Diffstat (limited to 'generator/nanopb_generator.py')
-rwxr-xr-x  generator/nanopb_generator.py | 582
 1 file changed, 297 insertions(+), 285 deletions(-)
diff --git a/generator/nanopb_generator.py b/generator/nanopb_generator.py
index 2b1d63e5..6d06c2cd 100755
--- a/generator/nanopb_generator.py
+++ b/generator/nanopb_generator.py
@@ -471,7 +471,7 @@ class Field:
return max(self.tag, self.max_size, self.max_count)
- def encoded_size(self, allmsgs):
+ def encoded_size(self, dependencies):
'''Return the maximum size that this field can take when encoded,
including the field tag. If the size cannot be determined, returns
None.'''
@@ -480,15 +480,14 @@ class Field:
return None
if self.pbtype == 'MESSAGE':
- for msg in allmsgs:
- if msg.name == self.submsgname:
- encsize = msg.encoded_size(allmsgs)
- if encsize is None:
- return None # Submessage size is indeterminate
-
- # Include submessage length prefix
- encsize += varint_max_size(encsize.upperlimit())
- break
+ if str(self.submsgname) in dependencies:
+ submsg = dependencies[str(self.submsgname)]
+ encsize = submsg.encoded_size(dependencies)
+ if encsize is None:
+ return None # Submessage size is indeterminate
+
+ # Include submessage length prefix
+ encsize += varint_max_size(encsize.upperlimit())
else:
# Submessage cannot be found, this currently occurs when
# the submessage type is defined in a different file.
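
The hunk above swaps the linear search over allmsgs for a dict lookup keyed by the stringified submessage name. A minimal sketch of the new resolution path, using a hypothetical stand-in for the real Message objects (a plain int instead of the generator's size objects):

    # dependencies maps str(type name) -> object exposing encoded_size()
    class StubSubMessage:
        def encoded_size(self, dependencies):
            return 12      # pretend the submessage encodes to at most 12 bytes

    dependencies = {'example.SubMessage': StubSubMessage()}

    submsgname = 'example.SubMessage'
    if submsgname in dependencies:
        encsize = dependencies[submsgname].encoded_size(dependencies)
        # the real code then adds varint_max_size(encsize.upperlimit()) for the
        # length prefix; omitted here because the stub returns a bare int
    else:
        encsize = None     # the type lives in a file not registered as a dependency

    print(encsize)         # 12
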
@@ -831,7 +830,6 @@ class Message:
# Processing of entire .proto files
# ---------------------------------------------------------------------------
-
def iterate_messages(desc, names = Names()):
'''Recursively find all messages. For each, yield name, DescriptorProto.'''
if hasattr(desc, 'message_type'):
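
As the docstring says, iterate_messages recursively finds all messages (presumably walking message_type and each message's nested_type) and yields a (Names, DescriptorProto) pair for each. A hedged usage sketch, assuming fdesc is a FileDescriptorProto already produced by protoc:

    for names, message in iterate_messages(fdesc, Names(['mypackage'])):
        # names is e.g. mypackage.Outer, then mypackage.Outer.Inner, ...
        print(names, len(message.field))
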
@@ -857,57 +855,6 @@ def iterate_extensions(desc, names = Names()):
for extension in subdesc.extension:
yield subname, extension
-def parse_file(fdesc, file_options):
- '''Takes a FileDescriptorProto and returns tuple (enums, messages, extensions).'''
-
- enums = []
- messages = []
- extensions = []
-
- if fdesc.package:
- base_name = Names(fdesc.package.split('.'))
- else:
- base_name = Names()
-
- for enum in fdesc.enum_type:
- enum_options = get_nanopb_suboptions(enum, file_options, base_name + enum.name)
- enums.append(Enum(base_name, enum, enum_options))
-
- for names, message in iterate_messages(fdesc, base_name):
- message_options = get_nanopb_suboptions(message, file_options, names)
-
- if message_options.skip_message:
- continue
-
- messages.append(Message(names, message, message_options))
- for enum in message.enum_type:
- enum_options = get_nanopb_suboptions(enum, message_options, names + enum.name)
- enums.append(Enum(names, enum, enum_options))
-
- for names, extension in iterate_extensions(fdesc, base_name):
- field_options = get_nanopb_suboptions(extension, file_options, names + extension.name)
- if field_options.type != nanopb_pb2.FT_IGNORE:
- extensions.append(ExtensionField(names, extension, field_options))
-
- # Fix field default values where enum short names are used.
- for enum in enums:
- if not enum.options.long_names:
- for message in messages:
- for field in message.fields:
- if field.default in enum.value_longnames:
- idx = enum.value_longnames.index(field.default)
- field.default = enum.values[idx][0]
-
- # Fix field data types where enums have negative values.
- for enum in enums:
- if not enum.has_negative():
- for message in messages:
- for field in message.fields:
- if field.pbtype == 'ENUM' and field.ctype == enum.names:
- field.pbtype = 'UENUM'
-
- return enums, messages, extensions
-
def toposort2(data):
'''Topological sort.
From http://code.activestate.com/recipes/577413-topological-sort/
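
toposort2 (unchanged by this commit) takes a dict mapping each name to the set of names it depends on and yields them so that dependency-free items come out first; sort_dependencies presumably builds on it to emit embedded struct types before the structs that use them. A small usage sketch with made-up message names, assuming the toposort2 defined above is in scope:

    data = {
        'Envelope': set(['Header', 'Payload']),
        'Payload':  set(['Header']),
        'Header':   set(),
    }
    print(list(toposort2(data)))   # dependency-free names come out before their users
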
@@ -949,231 +896,299 @@ def make_identifier(headername):
result += '_'
return result
-def generate_header(dependencies, headername, enums, messages, extensions, options):
- '''Generate content for a header file.
- Generates strings, which should be concatenated and stored to file.
- '''
-
- yield '/* Automatically generated nanopb header */\n'
- if options.notimestamp:
- yield '/* Generated by %s */\n\n' % (nanopb_version)
- else:
- yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())
+class ProtoFile:
+ def __init__(self, fdesc, file_options):
+ '''Takes a FileDescriptorProto and parses it.'''
+ self.fdesc = fdesc
+ self.file_options = file_options
+ self.dependencies = {}
+ self.parse()
+
+        # Some of the types used in this file probably come from the file itself.
+        # Thus it has an implicit dependency on itself.
+ self.add_dependency(self)
+
+ def parse(self):
+ self.enums = []
+ self.messages = []
+ self.extensions = []
+
+ if self.fdesc.package:
+ base_name = Names(self.fdesc.package.split('.'))
+ else:
+ base_name = Names()
- symbol = make_identifier(headername)
- yield '#ifndef PB_%s_INCLUDED\n' % symbol
- yield '#define PB_%s_INCLUDED\n' % symbol
- try:
- yield options.libformat % ('pb.h')
- except TypeError:
- # no %s specified - use whatever was passed in as options.libformat
- yield options.libformat
- yield '\n'
+ for enum in self.fdesc.enum_type:
+ enum_options = get_nanopb_suboptions(enum, self.file_options, base_name + enum.name)
+ self.enums.append(Enum(base_name, enum, enum_options))
+
+ for names, message in iterate_messages(self.fdesc, base_name):
+ message_options = get_nanopb_suboptions(message, self.file_options, names)
+
+ if message_options.skip_message:
+ continue
+
+ self.messages.append(Message(names, message, message_options))
+ for enum in message.enum_type:
+ enum_options = get_nanopb_suboptions(enum, message_options, names + enum.name)
+ self.enums.append(Enum(names, enum, enum_options))
+
+ for names, extension in iterate_extensions(self.fdesc, base_name):
+ field_options = get_nanopb_suboptions(extension, self.file_options, names + extension.name)
+ if field_options.type != nanopb_pb2.FT_IGNORE:
+ self.extensions.append(ExtensionField(names, extension, field_options))
- for dependency in dependencies:
- noext = os.path.splitext(dependency)[0]
- yield options.genformat % (noext + options.extension + '.h')
+ def add_dependency(self, other):
+ for enum in other.enums:
+ self.dependencies[str(enum.names)] = enum
+
+ for msg in other.messages:
+ self.dependencies[str(msg.name)] = msg
+
+ # Fix field default values where enum short names are used.
+ for enum in other.enums:
+ if not enum.options.long_names:
+ for message in self.messages:
+ for field in message.fields:
+ if field.default in enum.value_longnames:
+ idx = enum.value_longnames.index(field.default)
+ field.default = enum.values[idx][0]
+
+ # Fix field data types where enums have negative values.
+ for enum in other.enums:
+ if not enum.has_negative():
+ for message in self.messages:
+ for field in message.fields:
+ if field.pbtype == 'ENUM' and field.ctype == enum.names:
+ field.pbtype = 'UENUM'
+
+ def generate_header(self, includes, headername, options):
+ '''Generate content for a header file.
+ Generates strings, which should be concatenated and stored to file.
+ '''
+
+ yield '/* Automatically generated nanopb header */\n'
+ if options.notimestamp:
+ yield '/* Generated by %s */\n\n' % (nanopb_version)
+ else:
+ yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())
+
+ symbol = make_identifier(headername)
+ yield '#ifndef PB_%s_INCLUDED\n' % symbol
+ yield '#define PB_%s_INCLUDED\n' % symbol
+ try:
+ yield options.libformat % ('pb.h')
+ except TypeError:
+ # no %s specified - use whatever was passed in as options.libformat
+ yield options.libformat
yield '\n'
-
- yield '#if PB_PROTO_HEADER_VERSION != 30\n'
- yield '#error Regenerate this file with the current version of nanopb generator.\n'
- yield '#endif\n'
- yield '\n'
-
- yield '#ifdef __cplusplus\n'
- yield 'extern "C" {\n'
- yield '#endif\n\n'
-
- yield '/* Enum definitions */\n'
- for enum in enums:
- yield str(enum) + '\n\n'
-
- yield '/* Struct definitions */\n'
- for msg in sort_dependencies(messages):
- yield msg.types()
- yield str(msg) + '\n\n'
-
- if extensions:
- yield '/* Extensions */\n'
- for extension in extensions:
- yield extension.extension_decl()
+
+ for incfile in includes:
+ noext = os.path.splitext(incfile)[0]
+ yield options.genformat % (noext + options.extension + '.h')
+ yield '\n'
+
+ yield '#if PB_PROTO_HEADER_VERSION != 30\n'
+ yield '#error Regenerate this file with the current version of nanopb generator.\n'
+ yield '#endif\n'
yield '\n'
+
+ yield '#ifdef __cplusplus\n'
+ yield 'extern "C" {\n'
+ yield '#endif\n\n'
- yield '/* Default values for struct fields */\n'
- for msg in messages:
- yield msg.default_decl(True)
- yield '\n'
-
- yield '/* Initializer values for message structs */\n'
- for msg in messages:
- identifier = '%s_init_default' % msg.name
- yield '#define %-40s %s\n' % (identifier, msg.get_initializer(False))
- for msg in messages:
- identifier = '%s_init_zero' % msg.name
- yield '#define %-40s %s\n' % (identifier, msg.get_initializer(True))
- yield '\n'
-
- yield '/* Field tags (for use in manual encoding/decoding) */\n'
- for msg in sort_dependencies(messages):
- for field in msg.fields:
- yield field.tags()
- for extension in extensions:
- yield extension.tags()
- yield '\n'
-
- yield '/* Struct field encoding specification for nanopb */\n'
- for msg in messages:
- yield msg.fields_declaration() + '\n'
- yield '\n'
-
- yield '/* Maximum encoded size of messages (where known) */\n'
- for msg in messages:
- msize = msg.encoded_size(messages)
- if msize is not None:
- identifier = '%s_size' % msg.name
- yield '#define %-40s %s\n' % (identifier, msize)
- yield '\n'
-
- yield '/* Message IDs (where set with "msgid" option) */\n'
-
- yield '#ifdef PB_MSGID\n'
- for msg in messages:
- if hasattr(msg,'msgid'):
- yield '#define PB_MSG_%d %s\n' % (msg.msgid, msg.name)
- yield '\n'
-
- symbol = make_identifier(headername.split('.')[0])
- yield '#define %s_MESSAGES \\\n' % symbol
-
- for msg in messages:
- m = "-1"
- msize = msg.encoded_size(messages)
- if msize is not None:
- m = msize
- if hasattr(msg,'msgid'):
- yield '\tPB_MSG(%d,%s,%s) \\\n' % (msg.msgid, m, msg.name)
- yield '\n'
-
- for msg in messages:
- if hasattr(msg,'msgid'):
- yield '#define %s_msgid %d\n' % (msg.name, msg.msgid)
- yield '\n'
-
- yield '#endif\n\n'
-
-
- yield '#ifdef __cplusplus\n'
- yield '} /* extern "C" */\n'
- yield '#endif\n'
-
- # End of header
- yield '\n#endif\n'
+ if self.enums:
+ yield '/* Enum definitions */\n'
+ for enum in self.enums:
+ yield str(enum) + '\n\n'
+
+ if self.messages:
+ yield '/* Struct definitions */\n'
+ for msg in sort_dependencies(self.messages):
+ yield msg.types()
+ yield str(msg) + '\n\n'
+
+ if self.extensions:
+ yield '/* Extensions */\n'
+ for extension in self.extensions:
+ yield extension.extension_decl()
+ yield '\n'
+
+ if self.messages:
+ yield '/* Default values for struct fields */\n'
+ for msg in self.messages:
+ yield msg.default_decl(True)
+ yield '\n'
+
+ yield '/* Initializer values for message structs */\n'
+ for msg in self.messages:
+ identifier = '%s_init_default' % msg.name
+ yield '#define %-40s %s\n' % (identifier, msg.get_initializer(False))
+ for msg in self.messages:
+ identifier = '%s_init_zero' % msg.name
+ yield '#define %-40s %s\n' % (identifier, msg.get_initializer(True))
+ yield '\n'
+
+ yield '/* Field tags (for use in manual encoding/decoding) */\n'
+ for msg in sort_dependencies(self.messages):
+ for field in msg.fields:
+ yield field.tags()
+ for extension in self.extensions:
+ yield extension.tags()
+ yield '\n'
+
+ yield '/* Struct field encoding specification for nanopb */\n'
+ for msg in self.messages:
+ yield msg.fields_declaration() + '\n'
+ yield '\n'
+
+ yield '/* Maximum encoded size of messages (where known) */\n'
+ for msg in self.messages:
+ msize = msg.encoded_size(self.dependencies)
+ if msize is not None:
+ identifier = '%s_size' % msg.name
+ yield '#define %-40s %s\n' % (identifier, msize)
+ yield '\n'
+
+ yield '/* Message IDs (where set with "msgid" option) */\n'
+
+ yield '#ifdef PB_MSGID\n'
+ for msg in self.messages:
+ if hasattr(msg,'msgid'):
+ yield '#define PB_MSG_%d %s\n' % (msg.msgid, msg.name)
+ yield '\n'
+
+ symbol = make_identifier(headername.split('.')[0])
+ yield '#define %s_MESSAGES \\\n' % symbol
+
+ for msg in self.messages:
+ m = "-1"
+ msize = msg.encoded_size(self.dependencies)
+ if msize is not None:
+ m = msize
+ if hasattr(msg,'msgid'):
+ yield '\tPB_MSG(%d,%s,%s) \\\n' % (msg.msgid, m, msg.name)
+ yield '\n'
+
+ for msg in self.messages:
+ if hasattr(msg,'msgid'):
+ yield '#define %s_msgid %d\n' % (msg.name, msg.msgid)
+ yield '\n'
-def generate_source(headername, enums, messages, extensions, options):
- '''Generate content for a source file.'''
-
- yield '/* Automatically generated nanopb constant definitions */\n'
- if options.notimestamp:
- yield '/* Generated by %s */\n\n' % (nanopb_version)
- else:
- yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())
- yield options.genformat % (headername)
- yield '\n'
-
- yield '#if PB_PROTO_HEADER_VERSION != 30\n'
- yield '#error Regenerate this file with the current version of nanopb generator.\n'
- yield '#endif\n'
- yield '\n'
-
- for msg in messages:
- yield msg.default_decl(False)
-
- yield '\n\n'
-
- for msg in messages:
- yield msg.fields_definition() + '\n\n'
-
- for ext in extensions:
- yield ext.extension_def() + '\n'
-
- # Add checks for numeric limits
- if messages:
- largest_msg = max(messages, key = lambda m: m.count_required_fields())
- largest_count = largest_msg.count_required_fields()
- if largest_count > 64:
- yield '\n/* Check that missing required fields will be properly detected */\n'
- yield '#if PB_MAX_REQUIRED_FIELDS < %d\n' % largest_count
- yield '#error Properly detecting missing required fields in %s requires \\\n' % largest_msg.name
- yield ' setting PB_MAX_REQUIRED_FIELDS to %d or more.\n' % largest_count
- yield '#endif\n'
-
- worst = 0
- worst_field = ''
- checks = []
- checks_msgnames = []
- for msg in messages:
- checks_msgnames.append(msg.name)
- for field in msg.fields:
- status = field.largest_field_value()
- if isinstance(status, (str, unicode)):
- checks.append(status)
- elif status > worst:
- worst = status
- worst_field = str(field.struct_name) + '.' + str(field.name)
-
- if worst > 255 or checks:
- yield '\n/* Check that field information fits in pb_field_t */\n'
-
- if worst > 65535 or checks:
- yield '#if !defined(PB_FIELD_32BIT)\n'
- if worst > 65535:
- yield '#error Field descriptor for %s is too large. Define PB_FIELD_32BIT to fix this.\n' % worst_field
- else:
- assertion = ' && '.join(str(c) + ' < 65536' for c in checks)
- msgs = '_'.join(str(n) for n in checks_msgnames)
- yield '/* If you get an error here, it means that you need to define PB_FIELD_32BIT\n'
- yield ' * compile-time option. You can do that in pb.h or on compiler command line.\n'
- yield ' * \n'
- yield ' * The reason you need to do this is that some of your messages contain tag\n'
- yield ' * numbers or field sizes that are larger than what can fit in 8 or 16 bit\n'
- yield ' * field descriptors.\n'
- yield ' */\n'
- yield 'PB_STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_32BIT_FOR_MESSAGES_%s)\n'%(assertion,msgs)
yield '#endif\n\n'
+
+ yield '#ifdef __cplusplus\n'
+ yield '} /* extern "C" */\n'
+ yield '#endif\n'
+
+ # End of header
+ yield '\n#endif\n'
+
+ def generate_source(self, headername, options):
+ '''Generate content for a source file.'''
+
+ yield '/* Automatically generated nanopb constant definitions */\n'
+ if options.notimestamp:
+ yield '/* Generated by %s */\n\n' % (nanopb_version)
+ else:
+ yield '/* Generated by %s at %s. */\n\n' % (nanopb_version, time.asctime())
+ yield options.genformat % (headername)
+ yield '\n'
+
+ yield '#if PB_PROTO_HEADER_VERSION != 30\n'
+ yield '#error Regenerate this file with the current version of nanopb generator.\n'
+ yield '#endif\n'
+ yield '\n'
+
+ for msg in self.messages:
+ yield msg.default_decl(False)
+
+ yield '\n\n'
+
+ for msg in self.messages:
+ yield msg.fields_definition() + '\n\n'
+
+ for ext in self.extensions:
+ yield ext.extension_def() + '\n'
+
+ # Add checks for numeric limits
+ if self.messages:
+ largest_msg = max(self.messages, key = lambda m: m.count_required_fields())
+ largest_count = largest_msg.count_required_fields()
+ if largest_count > 64:
+ yield '\n/* Check that missing required fields will be properly detected */\n'
+ yield '#if PB_MAX_REQUIRED_FIELDS < %d\n' % largest_count
+ yield '#error Properly detecting missing required fields in %s requires \\\n' % largest_msg.name
+ yield ' setting PB_MAX_REQUIRED_FIELDS to %d or more.\n' % largest_count
+ yield '#endif\n'
+
+ worst = 0
+ worst_field = ''
+ checks = []
+ checks_msgnames = []
+ for msg in self.messages:
+ checks_msgnames.append(msg.name)
+ for field in msg.fields:
+ status = field.largest_field_value()
+ if isinstance(status, (str, unicode)):
+ checks.append(status)
+ elif status > worst:
+ worst = status
+ worst_field = str(field.struct_name) + '.' + str(field.name)
+
+ if worst > 255 or checks:
+ yield '\n/* Check that field information fits in pb_field_t */\n'
+
+ if worst > 65535 or checks:
+ yield '#if !defined(PB_FIELD_32BIT)\n'
+ if worst > 65535:
+ yield '#error Field descriptor for %s is too large. Define PB_FIELD_32BIT to fix this.\n' % worst_field
+ else:
+ assertion = ' && '.join(str(c) + ' < 65536' for c in checks)
+ msgs = '_'.join(str(n) for n in checks_msgnames)
+ yield '/* If you get an error here, it means that you need to define PB_FIELD_32BIT\n'
+ yield ' * compile-time option. You can do that in pb.h or on compiler command line.\n'
+ yield ' * \n'
+ yield ' * The reason you need to do this is that some of your messages contain tag\n'
+ yield ' * numbers or field sizes that are larger than what can fit in 8 or 16 bit\n'
+ yield ' * field descriptors.\n'
+ yield ' */\n'
+ yield 'PB_STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_32BIT_FOR_MESSAGES_%s)\n'%(assertion,msgs)
+ yield '#endif\n\n'
+
+ if worst < 65536:
+ yield '#if !defined(PB_FIELD_16BIT) && !defined(PB_FIELD_32BIT)\n'
+ if worst > 255:
+ yield '#error Field descriptor for %s is too large. Define PB_FIELD_16BIT to fix this.\n' % worst_field
+ else:
+ assertion = ' && '.join(str(c) + ' < 256' for c in checks)
+ msgs = '_'.join(str(n) for n in checks_msgnames)
+ yield '/* If you get an error here, it means that you need to define PB_FIELD_16BIT\n'
+ yield ' * compile-time option. You can do that in pb.h or on compiler command line.\n'
+ yield ' * \n'
+ yield ' * The reason you need to do this is that some of your messages contain tag\n'
+ yield ' * numbers or field sizes that are larger than what can fit in the default\n'
+ yield ' * 8 bit descriptors.\n'
+ yield ' */\n'
+ yield 'PB_STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_16BIT_FOR_MESSAGES_%s)\n'%(assertion,msgs)
+ yield '#endif\n\n'
+
+ # Add check for sizeof(double)
+ has_double = False
+ for msg in self.messages:
+ for field in msg.fields:
+ if field.ctype == 'double':
+ has_double = True
+
+ if has_double:
+ yield '\n'
+ yield '/* On some platforms (such as AVR), double is really float.\n'
+ yield ' * These are not directly supported by nanopb, but see example_avr_double.\n'
+ yield ' * To get rid of this error, remove any double fields from your .proto.\n'
+ yield ' */\n'
+ yield 'PB_STATIC_ASSERT(sizeof(double) == 8, DOUBLE_MUST_BE_8_BYTES)\n'
- if worst < 65536:
- yield '#if !defined(PB_FIELD_16BIT) && !defined(PB_FIELD_32BIT)\n'
- if worst > 255:
- yield '#error Field descriptor for %s is too large. Define PB_FIELD_16BIT to fix this.\n' % worst_field
- else:
- assertion = ' && '.join(str(c) + ' < 256' for c in checks)
- msgs = '_'.join(str(n) for n in checks_msgnames)
- yield '/* If you get an error here, it means that you need to define PB_FIELD_16BIT\n'
- yield ' * compile-time option. You can do that in pb.h or on compiler command line.\n'
- yield ' * \n'
- yield ' * The reason you need to do this is that some of your messages contain tag\n'
- yield ' * numbers or field sizes that are larger than what can fit in the default\n'
- yield ' * 8 bit descriptors.\n'
- yield ' */\n'
- yield 'PB_STATIC_ASSERT((%s), YOU_MUST_DEFINE_PB_FIELD_16BIT_FOR_MESSAGES_%s)\n'%(assertion,msgs)
- yield '#endif\n\n'
-
- # Add check for sizeof(double)
- has_double = False
- for msg in messages:
- for field in msg.fields:
- if field.ctype == 'double':
- has_double = True
-
- if has_double:
yield '\n'
- yield '/* On some platforms (such as AVR), double is really float.\n'
- yield ' * These are not directly supported by nanopb, but see example_avr_double.\n'
- yield ' * To get rid of this error, remove any double fields from your .proto.\n'
- yield ' */\n'
- yield 'PB_STATIC_ASSERT(sizeof(double) == 8, DOUBLE_MUST_BE_8_BYTES)\n'
-
- yield '\n'
# ---------------------------------------------------------------------------
# Options parsing for the .proto files
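
With the hunk above, header and source generation become methods of ProtoFile but still return generators of strings that the caller joins into file contents. A rough driving sketch, with the descriptor and options objects assumed to come from protoc and the nanopb options machinery as in the rest of this file:

    f = ProtoFile(fdesc, file_options)       # parses enums, messages, extensions
    headerdata = ''.join(f.generate_header(includes, 'example.pb.h', options))
    sourcedata = ''.join(f.generate_source('example.pb.h', options))
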
@@ -1338,7 +1353,7 @@ def process_file(filename, fdesc, options):
# Parse the file
file_options = get_nanopb_suboptions(fdesc, toplevel_options, Names([filename]))
- enums, messages, extensions = parse_file(fdesc, file_options)
+ f = ProtoFile(fdesc, file_options)
# Decide the file names
noext = os.path.splitext(filename)[0]
@@ -1349,13 +1364,10 @@ def process_file(filename, fdesc, options):
# List of .proto files that should not be included in the C header file
# even if they are mentioned in the source .proto.
excludes = ['nanopb.proto', 'google/protobuf/descriptor.proto'] + options.exclude
- dependencies = [d for d in fdesc.dependency if d not in excludes]
+ includes = [d for d in fdesc.dependency if d not in excludes]
- headerdata = ''.join(generate_header(dependencies, headerbasename, enums,
- messages, extensions, options))
-
- sourcedata = ''.join(generate_source(headerbasename, enums,
- messages, extensions, options))
+ headerdata = ''.join(f.generate_header(includes, headerbasename, options))
+ sourcedata = ''.join(f.generate_source(headerbasename, options))
# Check if there were any lines in .options that did not match a member
unmatched = [n for n,o in Globals.separate_options if n not in Globals.matched_namemasks]
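
The intended payoff of the dependencies dict shows up when a message embeds a type imported from another .proto: once the importing side registers the imported file with add_dependency, encoded_size can resolve the submessage and a _size define can be emitted. A hypothetical sketch (common_fdesc, api_fdesc and the option objects are made-up names for descriptors of two files where api.proto imports common.proto):

    common = ProtoFile(common_fdesc, common_options)
    api = ProtoFile(api_fdesc, api_options)
    api.add_dependency(common)        # imported submessage types become resolvable
    for msg in api.messages:
        print(msg.name, msg.encoded_size(api.dependencies))   # None only if still unresolved
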