summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorPetteri Aimonen <jpa@npb.mail.kapsi.fi>2011-08-24 13:52:08 +0000
committerPetteri Aimonen <jpa@npb.mail.kapsi.fi>2011-08-24 13:52:08 +0000
commit9af96cd669f7f9189cdedcef110e7dbc1d526857 (patch)
tree90d0433f2c5543b4367b848424996b83a1908b77
parent646e3c4944ae8e49e92c2ff14765648b1acbe0c3 (diff)
Generator bugfixes
git-svn-id: https://svn.kapsi.fi/jpa/nanopb@970 e3a754e5-d11d-0410-8d38-ebb782a927b9
-rw-r--r--generator/nanopb_generator.py70
1 files changed, 46 insertions, 24 deletions
diff --git a/generator/nanopb_generator.py b/generator/nanopb_generator.py
index f09be346..67c422d9 100644
--- a/generator/nanopb_generator.py
+++ b/generator/nanopb_generator.py
@@ -108,7 +108,8 @@ class Field:
elif desc.type == FieldD.TYPE_ENUM:
self.ltype = 'PB_LTYPE_VARINT'
self.ctype = names_from_type_name(desc.type_name)
- self.default = Names(self.ctype) + self.default
+ if self.default is not None:
+ self.default = self.ctype + self.default
elif desc.type == FieldD.TYPE_STRING:
self.ltype = 'PB_LTYPE_STRING'
if self.max_size is None:
@@ -218,7 +219,7 @@ class Field:
result += '\n pb_membersize(%s, %s[0]),' % (self.struct_name, self.name)
result += ('\n pb_membersize(%s, %s) / pb_membersize(%s, %s[0]),'
% (self.struct_name, self.name, self.struct_name, self.name))
- elif self.ltype == 'PB_LTYPE_BYTES':
+ elif self.htype != 'PB_HTYPE_CALLBACK' and self.ltype == 'PB_LTYPE_BYTES':
result += '\n pb_membersize(%s, bytes),' % self.ctype
result += ' 0,'
else:
@@ -240,24 +241,10 @@ class Message:
self.fields = [Field(self.name, f) for f in desc.field]
self.ordered_fields = self.fields[:]
self.ordered_fields.sort()
-
- def __cmp__(self, other):
- '''Sort messages so that submessages are declared before the message
- that uses them.
- '''
- if self.refers_to(other.name):
- return 1
- elif other.refers_to(self.name):
- return -1
- else:
- return 0
-
- def refers_to(self, name):
- '''Returns True if this message uses the specified type as field type.'''
- for field in self.fields:
- if str(field.ctype) == str(name):
- return True
- return False
+
+ def get_dependencies(self):
+ '''Get list of type names that this structure refers to.'''
+ return [str(field.ctype) for field in self.fields]
def __str__(self):
result = 'typedef struct {\n'
@@ -317,16 +304,52 @@ def parse_file(fdesc):
enums = []
messages = []
+ if fdesc.package:
+ base_name = Names(fdesc.package.split('.'))
+ else:
+ base_name = Names()
+
for enum in fdesc.enum_type:
- enums.append(Enum(Names(), enum))
+ enums.append(Enum(base_name, enum))
- for names, message in iterate_messages(fdesc):
+ for names, message in iterate_messages(fdesc, base_name):
messages.append(Message(names, message))
for enum in message.enum_type:
enums.append(Enum(names, enum))
return enums, messages
+def toposort2(data):
+ '''Topological sort.
+ From http://code.activestate.com/recipes/577413-topological-sort/
+ This function is under the MIT license.
+ '''
+ for k, v in data.items():
+ v.discard(k) # Ignore self dependencies
+ extra_items_in_deps = reduce(set.union, data.values()) - set(data.keys())
+ data.update({item:set() for item in extra_items_in_deps})
+ while True:
+ ordered = set(item for item,dep in data.items() if not dep)
+ if not ordered:
+ break
+ for item in sorted(ordered):
+ yield item
+ data = {item: (dep - ordered) for item,dep in data.items()
+ if item not in ordered}
+ assert not data, "A cyclic dependency exists amongst %r" % data
+
+def sort_dependencies(messages):
+ '''Sort a list of Messages based on dependencies.'''
+ dependencies = {}
+ message_by_name = {}
+ for message in messages:
+ dependencies[str(message.name)] = set(message.get_dependencies())
+ message_by_name[str(message.name)] = message
+
+ for msgname in toposort2(dependencies):
+ if msgname in message_by_name:
+ yield message_by_name[msgname]
+
def generate_header(headername, enums, messages):
'''Generate content for a header file.
Generates strings, which should be concatenated and stored to file.
@@ -344,8 +367,7 @@ def generate_header(headername, enums, messages):
yield str(enum) + '\n\n'
yield '/* Struct definitions */\n'
- messages.sort()
- for msg in messages:
+ for msg in sort_dependencies(messages):
yield msg.types()
yield str(msg) + '\n\n'