diff options
Diffstat (limited to 'generator/nanopb_generator.py')
-rwxr-xr-x | generator/nanopb_generator.py | 264 |
1 files changed, 192 insertions, 72 deletions
diff --git a/generator/nanopb_generator.py b/generator/nanopb_generator.py index 85cb413b..60944038 100755 --- a/generator/nanopb_generator.py +++ b/generator/nanopb_generator.py @@ -171,6 +171,7 @@ class Field: '''desc is FieldDescriptorProto''' self.tag = desc.number self.struct_name = struct_name + self.union_name = None self.name = desc.name self.default = None self.max_size = None @@ -300,57 +301,91 @@ class Field: if self.pbtype == 'BYTES' and self.allocation == 'STATIC': result = 'typedef PB_BYTES_ARRAY_T(%d) %s;\n' % (self.max_size, self.ctype) else: - result = None + result = '' return result - def get_initializer(self, null_init): - '''Return literal expression for this field's default value.''' - + def get_dependencies(self): + '''Get list of type names used by this field.''' + if self.allocation == 'STATIC': + return [str(self.ctype)] + else: + return [] + + def get_initializer(self, null_init, inner_init_only = False): + '''Return literal expression for this field's default value. + null_init: If True, initialize to a 0 value instead of default from .proto + inner_init_only: If True, exclude initialization for any count/has fields + ''' + + inner_init = None if self.pbtype == 'MESSAGE': if null_init: - return '%s_init_zero' % self.ctype + inner_init = '%s_init_zero' % self.ctype else: - return '%s_init_default' % self.ctype - - if self.default is None or null_init: + inner_init = '%s_init_default' % self.ctype + elif self.default is None or null_init: if self.pbtype == 'STRING': - return '""' + inner_init = '""' elif self.pbtype == 'BYTES': - return '{0, {0}}' + inner_init = '{0, {0}}' elif self.pbtype == 'ENUM': - return '(%s)0' % self.ctype + inner_init = '(%s)0' % self.ctype else: - return '0' - - default = str(self.default) - - if self.pbtype == 'STRING': - default = default.encode('utf-8').encode('string_escape') - default = default.replace('"', '\\"') - default = '"' + default + '"' - elif self.pbtype == 'BYTES': - data = default.decode('string_escape') - data = ['0x%02x' % ord(c) for c in data] - if len(data) == 0: - default = '{0, {0}}' + inner_init = '0' + else: + if self.pbtype == 'STRING': + inner_init = self.default.encode('utf-8').encode('string_escape') + inner_init = inner_init.replace('"', '\\"') + inner_init = '"' + inner_init + '"' + elif self.pbtype == 'BYTES': + data = str(self.default).decode('string_escape') + data = ['0x%02x' % ord(c) for c in data] + if len(data) == 0: + inner_init = '{0, {0}}' + else: + inner_init = '{%d, {%s}}' % (len(data), ','.join(data)) + elif self.pbtype in ['FIXED32', 'UINT32']: + inner_init = str(self.default) + 'u' + elif self.pbtype in ['FIXED64', 'UINT64']: + inner_init = str(self.default) + 'ull' + elif self.pbtype in ['SFIXED64', 'INT64']: + inner_init = str(self.default) + 'll' else: - default = '{%d, {%s}}' % (len(data), ','.join(data)) - elif self.pbtype in ['FIXED32', 'UINT32']: - default += 'u' - elif self.pbtype in ['FIXED64', 'UINT64']: - default += 'ull' - elif self.pbtype in ['SFIXED64', 'INT64']: - default += 'll' + inner_init = str(self.default) - return default - + if inner_init_only: + return inner_init + + outer_init = None + if self.allocation == 'STATIC': + if self.rules == 'REPEATED': + outer_init = '0, {' + outer_init += ', '.join([inner_init] * self.max_count) + outer_init += '}' + elif self.rules == 'OPTIONAL': + outer_init = 'false, ' + inner_init + else: + outer_init = inner_init + elif self.allocation == 'POINTER': + if self.rules == 'REPEATED': + outer_init = '0, NULL' + else: + outer_init = 'NULL' + elif self.allocation == 'CALLBACK': + if self.pbtype == 'EXTENSION': + outer_init = 'NULL' + else: + outer_init = '{{NULL}, NULL}' + + return outer_init + def default_decl(self, declaration_only = False): '''Return definition for this field's default value.''' if self.default is None: return None ctype = self.ctype - default = self.get_initializer(False) + default = self.get_initializer(False, True) array_decl = '' if self.pbtype == 'STRING': @@ -375,7 +410,13 @@ class Field: '''Return the pb_field_t initializer to use in the constant array. prev_field_name is the name of the previous field or None. ''' - result = ' PB_FIELD(%3d, ' % self.tag + + if self.rules == 'ONEOF': + result = ' PB_ONEOF_FIELD(%s, ' % self.union_name + else: + result = ' PB_FIELD(' + + result += '%3d, ' % self.tag result += '%-8s, ' % self.pbtype result += '%s, ' % self.rules result += '%-8s, ' % self.allocation @@ -403,6 +444,8 @@ class Field: if self.pbtype == 'MESSAGE': if self.rules == 'REPEATED' and self.allocation == 'STATIC': return 'pb_membersize(%s, %s[0])' % (self.struct_name, self.name) + elif self.rules == 'ONEOF': + return 'pb_membersize(%s, %s.%s)' % (self.struct_name, self.union_name, self.name) else: return 'pb_membersize(%s, %s)' % (self.struct_name, self.name) @@ -535,6 +578,71 @@ class ExtensionField(Field): # --------------------------------------------------------------------------- +# Generation of oneofs (unions) +# --------------------------------------------------------------------------- + +class OneOf(Field): + def __init__(self, oneof_desc): + self.name = oneof_desc.name + self.ctype = 'union' + self.fields = [] + + def add_field(self, field): + if field.allocation == 'CALLBACK': + raise Exception("Callback fields inside of oneof are not supported" + + " (field %s)" % field.fullname) + + field.union_name = self.name + field.rules = 'ONEOF' + self.fields.append(field) + self.fields.sort(key = lambda f: f.tag) + + # Sort by the lowest tag number inside union + self.tag = min([f.tag for f in self.fields]) + + def __cmp__(self, other): + return cmp(self.tag, other.tag) + + def __str__(self): + result = '' + if self.fields: + result += ' pb_size_t which_' + self.name + ";\n" + result += ' union {\n' + for f in self.fields: + result += ' ' + str(f).replace('\n', '\n ') + '\n' + result += ' } ' + self.name + ';' + return result + + def types(self): + return ''.join([f.types() for f in self.fields]) + + def get_dependencies(self): + deps = [] + for f in self.fields: + deps += f.get_dependencies() + return deps + + def get_initializer(self, null_init): + return '0, {' + self.fields[0].get_initializer(null_init) + '}' + + def default_decl(self, declaration_only = False): + return None + + def tags(self): + return '\n'.join([f.tags() for f in self.fields]) + + def pb_field_t(self, prev_field_name): + prev_field_name = prev_field_name or self.name + result = ',\n'.join([f.pb_field_t(prev_field_name) for f in self.fields]) + return result + + def largest_field_value(self): + return max([f.largest_field_value() for f in self.fields]) + + def encoded_size(self, allmsgs): + return max([f.encoded_size(allmsgs) for f in self.fields]) + +# --------------------------------------------------------------------------- # Generation of messages (structures) # --------------------------------------------------------------------------- @@ -543,11 +651,24 @@ class Message: def __init__(self, names, desc, message_options): self.name = names self.fields = [] - + self.oneofs = [] + + if hasattr(desc, 'oneof_decl'): + for f in desc.oneof_decl: + oneof = OneOf(f) + self.oneofs.append(oneof) + self.fields.append(oneof) + for f in desc.field: field_options = get_nanopb_suboptions(f, message_options, self.name + f.name) - if field_options.type != nanopb_pb2.FT_IGNORE: - self.fields.append(Field(self.name, f, field_options)) + if field_options.type == nanopb_pb2.FT_IGNORE: + continue + + field = Field(self.name, f, field_options) + if hasattr(f, 'oneof_index') and f.HasField('oneof_index'): + self.oneofs[f.oneof_index].add_field(field) + else: + self.fields.append(field) if len(desc.extension_range) > 0: field_options = get_nanopb_suboptions(desc, message_options, self.name + 'extensions') @@ -561,7 +682,10 @@ class Message: def get_dependencies(self): '''Get list of type names that this structure refers to.''' - return [str(field.ctype) for field in self.fields if field.allocation == 'STATIC'] + deps = [] + for f in self.fields: + deps += f.get_dependencies() + return deps def __str__(self): result = 'typedef struct _%s {\n' % self.name @@ -586,39 +710,15 @@ class Message: return result def types(self): - result = "" - for field in self.fields: - types = field.types() - if types is not None: - result += types + '\n' - return result - + return ''.join([f.types() for f in self.fields]) + def get_initializer(self, null_init): if not self.ordered_fields: return '{0}' parts = [] for field in self.ordered_fields: - if field.allocation == 'STATIC': - if field.rules == 'REPEATED': - parts.append('0') - parts.append('{' - + ', '.join([field.get_initializer(null_init)] * field.max_count) - + '}') - elif field.rules == 'OPTIONAL': - parts.append('false') - parts.append(field.get_initializer(null_init)) - else: - parts.append(field.get_initializer(null_init)) - elif field.allocation == 'POINTER': - if field.rules == 'REPEATED': - parts.append('0') - parts.append('NULL') - elif field.allocation == 'CALLBACK': - if field.pbtype == 'EXTENSION': - parts.append('NULL') - else: - parts.append('{{NULL}, NULL}') + parts.append(field.get_initializer(null_init)) return '{' + ', '.join(parts) + '}' def default_decl(self, declaration_only = False): @@ -629,18 +729,39 @@ class Message: result += default + '\n' return result + def count_required_fields(self): + '''Returns number of required fields inside this message''' + count = 0 + for f in self.fields: + if f not in self.oneofs: + if f.rules == 'REQUIRED': + count += 1 + return count + + def count_all_fields(self): + count = 0 + for f in self.fields: + if f in self.oneofs: + count += len(f.fields) + else: + count += 1 + return count + def fields_declaration(self): - result = 'extern const pb_field_t %s_fields[%d];' % (self.name, len(self.fields) + 1) + result = 'extern const pb_field_t %s_fields[%d];' % (self.name, self.count_all_fields() + 1) return result def fields_definition(self): - result = 'const pb_field_t %s_fields[%d] = {\n' % (self.name, len(self.fields) + 1) + result = 'const pb_field_t %s_fields[%d] = {\n' % (self.name, self.count_all_fields() + 1) prev = None for field in self.ordered_fields: result += field.pb_field_t(prev) result += ',\n' - prev = field.name + if isinstance(field, OneOf): + prev = field.name + '.' + field.fields[-1].name + else: + prev = field.name result += ' PB_LAST_FIELD\n};' return result @@ -894,9 +1015,8 @@ def generate_source(headername, enums, messages, extensions, options): # Add checks for numeric limits if messages: - count_required_fields = lambda m: len([f for f in msg.fields if f.rules == 'REQUIRED']) - largest_msg = max(messages, key = count_required_fields) - largest_count = count_required_fields(largest_msg) + largest_msg = max(messages, key = lambda m: m.count_required_fields()) + largest_count = largest_msg.count_required_fields() if largest_count > 64: yield '\n/* Check that missing required fields will be properly detected */\n' yield '#if PB_MAX_REQUIRED_FIELDS < %d\n' % largest_count |