aboutsummaryrefslogtreecommitdiffstats
path: root/generator/nanopb_generator.py
diff options
context:
space:
mode:
authorPetteri Aimonen <jpa@github.mail.kapsi.fi>2015-09-26 12:08:55 +0300
committerPetteri Aimonen <jpa@github.mail.kapsi.fi>2015-09-26 12:08:55 +0300
commit403706c63c1b2196960c755cd4c61bd179a04c6e (patch)
tree8233f003e317cb48d2db2756ebd58b495a9a2f47 /generator/nanopb_generator.py
parent6cd4f20b0601ce7a68d555087ce0a62f837026ed (diff)
parent9c9f7f14e71c3b49877cca8eccd448cabea6306d (diff)
Merge pull request #169 from kylemanna/python3
Add proper Python3 support to the generator
Diffstat (limited to 'generator/nanopb_generator.py')
-rwxr-xr-xgenerator/nanopb_generator.py101
1 files changed, 64 insertions, 37 deletions
diff --git a/generator/nanopb_generator.py b/generator/nanopb_generator.py
index 4fd3a4d4..7fe0db95 100755
--- a/generator/nanopb_generator.py
+++ b/generator/nanopb_generator.py
@@ -1,10 +1,13 @@
-#!/usr/bin/python
+#!/usr/bin/env python
+
+from __future__ import unicode_literals
'''Generate header file for nanopb from a ProtoBuf FileDescriptorSet.'''
nanopb_version = "nanopb-0.3.5-dev"
import sys
import re
+from functools import reduce
try:
# Add some dummy imports to keep packaging tools happy.
@@ -82,7 +85,14 @@ class Names:
return '_'.join(self.parts)
def __add__(self, other):
- if isinstance(other, (str, unicode)):
+ # The fdesc names are unicode and need to be handled for
+ # python2 and python3
+ try:
+ realstr = unicode
+ except NameError:
+ realstr = str
+
+ if isinstance(other, realstr):
return Names(self.parts + (other,))
elif isinstance(other, tuple):
return Names(self.parts + other)
@@ -123,7 +133,7 @@ class EncodedSize:
self.symbols = symbols
def __add__(self, other):
- if isinstance(other, (int, long)):
+ if isinstance(other, int):
return EncodedSize(self.value + other, self.symbols)
elif isinstance(other, (str, Names)):
return EncodedSize(self.value, self.symbols + [str(other)])
@@ -133,7 +143,7 @@ class EncodedSize:
raise ValueError("Cannot add size: " + repr(other))
def __mul__(self, other):
- if isinstance(other, (int, long)):
+ if isinstance(other, int):
return EncodedSize(self.value * other, [str(other) + '*' + s for s in self.symbols])
else:
raise ValueError("Cannot multiply size: " + repr(other))
@@ -192,6 +202,24 @@ class Enum:
return result
+class FieldMaxSize:
+ def __init__(self, worst = 0, checks = [], field_name = 'undefined'):
+ if isinstance(worst, list):
+ self.worst = max(i for i in worst if i is not None)
+ else:
+ self.worst = worst
+
+ self.worst_field = field_name
+ self.checks = checks
+
+ def extend(self, extend, field_name = None):
+ self.worst = max(self.worst, extend.worst)
+
+ if self.worst == extend.worst:
+ self.worst_field = extend.worst_field
+
+ self.checks.extend(extend.checks)
+
class Field:
def __init__(self, struct_name, desc, field_options):
'''desc is FieldDescriptorProto'''
@@ -260,7 +288,7 @@ class Field:
raise NotImplementedError(field_options.type)
# Decide the C data type to use in the struct.
- if datatypes.has_key(desc.type):
+ if desc.type in datatypes:
self.ctype, self.pbtype, self.enc_size, isa = datatypes[desc.type]
# Override the field size if user wants to use smaller integers
@@ -295,8 +323,8 @@ class Field:
else:
raise NotImplementedError(desc.type)
- def __cmp__(self, other):
- return cmp(self.tag, other.tag)
+ def __lt__(self, other):
+ return self.tag < other.tag
def __str__(self):
result = ''
@@ -360,12 +388,10 @@ class Field:
inner_init = '0'
else:
if self.pbtype == 'STRING':
- inner_init = self.default.encode('utf-8').encode('string_escape')
- inner_init = inner_init.replace('"', '\\"')
+ inner_init = self.default.replace('"', '\\"')
inner_init = '"' + inner_init + '"'
elif self.pbtype == 'BYTES':
- data = str(self.default).decode('string_escape')
- data = ['0x%02x' % ord(c) for c in data]
+ data = ['0x%02x' % ord(c) for c in self.default]
if len(data) == 0:
inner_init = '{0, {0}}'
else:
@@ -467,15 +493,18 @@ class Field:
def largest_field_value(self):
'''Determine if this field needs 16bit or 32bit pb_field_t structure to compile properly.
Returns numeric value or a C-expression for assert.'''
+ check = []
if self.pbtype == 'MESSAGE':
if self.rules == 'REPEATED' and self.allocation == 'STATIC':
- return 'pb_membersize(%s, %s[0])' % (self.struct_name, self.name)
+ check.append('pb_membersize(%s, %s[0])' % (self.struct_name, self.name))
elif self.rules == 'ONEOF':
- return 'pb_membersize(%s, %s.%s)' % (self.struct_name, self.union_name, self.name)
+ check.append('pb_membersize(%s, %s.%s)' % (self.struct_name, self.union_name, self.name))
else:
- return 'pb_membersize(%s, %s)' % (self.struct_name, self.name)
+ check.append('pb_membersize(%s, %s)' % (self.struct_name, self.name))
- return max(self.tag, self.max_size, self.max_count)
+ return FieldMaxSize([self.tag, self.max_size, self.max_count],
+ check,
+ ('%s.%s' % (self.struct_name, self.name)))
def encoded_size(self, dependencies):
'''Return the maximum size that this field can take when encoded,
@@ -639,9 +668,6 @@ class OneOf(Field):
# Sort by the lowest tag number inside union
self.tag = min([f.tag for f in self.fields])
- def __cmp__(self, other):
- return cmp(self.tag, other.tag)
-
def __str__(self):
result = ''
if self.fields:
@@ -675,7 +701,10 @@ class OneOf(Field):
return result
def largest_field_value(self):
- return max([f.largest_field_value() for f in self.fields])
+ largest = FieldMaxSize()
+ for f in self.fields:
+ largest.extend(f.largest_field_value())
+ return largest
def encoded_size(self, dependencies):
largest = EncodedSize(0)
@@ -875,17 +904,17 @@ def toposort2(data):
From http://code.activestate.com/recipes/577413-topological-sort/
This function is under the MIT license.
'''
- for k, v in data.items():
+ for k, v in list(data.items()):
v.discard(k) # Ignore self dependencies
- extra_items_in_deps = reduce(set.union, data.values(), set()) - set(data.keys())
+ extra_items_in_deps = reduce(set.union, list(data.values()), set()) - set(data.keys())
data.update(dict([(item, set()) for item in extra_items_in_deps]))
while True:
- ordered = set(item for item,dep in data.items() if not dep)
+ ordered = set(item for item,dep in list(data.items()) if not dep)
if not ordered:
break
for item in sorted(ordered):
yield item
- data = dict([(item, (dep - ordered)) for item,dep in data.items()
+ data = dict([(item, (dep - ordered)) for item,dep in list(data.items())
if item not in ordered])
assert not data, "A cyclic dependency exists amongst %r" % data
@@ -1136,20 +1165,17 @@ class ProtoFile:
yield '#error Properly detecting missing required fields in %s requires \\\n' % largest_msg.name
yield ' setting PB_MAX_REQUIRED_FIELDS to %d or more.\n' % largest_count
yield '#endif\n'
-
- worst = 0
- worst_field = ''
- checks = []
+
+ max_field = FieldMaxSize()
checks_msgnames = []
for msg in self.messages:
checks_msgnames.append(msg.name)
for field in msg.fields:
- status = field.largest_field_value()
- if isinstance(status, (str, unicode)):
- checks.append(status)
- elif status > worst:
- worst = status
- worst_field = str(field.struct_name) + '.' + str(field.name)
+ max_field.extend(field.largest_field_value())
+
+ worst = max_field.worst
+ worst_field = max_field.worst_field
+ checks = max_field.checks
if worst > 255 or checks:
yield '\n/* Check that field information fits in pb_field_t */\n'
@@ -1237,7 +1263,7 @@ def read_options_file(infile):
try:
text_format.Merge(parts[1], opts)
- except Exception, e:
+ except Exception as e:
sys.stderr.write("%s:%d: " % (infile.name, i + 1) +
"Unparseable option line: '%s'. " % line +
"Error: %s\n" % str(e))
@@ -1439,14 +1465,15 @@ def main_cli():
def main_plugin():
'''Main function when invoked as a protoc plugin.'''
- import sys
+ import io, sys
if sys.platform == "win32":
import os, msvcrt
# Set stdin and stdout to binary mode
msvcrt.setmode(sys.stdin.fileno(), os.O_BINARY)
msvcrt.setmode(sys.stdout.fileno(), os.O_BINARY)
- data = sys.stdin.read()
+ data = io.open(sys.stdin.fileno(), "rb").read()
+
request = plugin_pb2.CodeGeneratorRequest.FromString(data)
try:
@@ -1489,7 +1516,7 @@ def main_plugin():
f.name = results['sourcename']
f.content = results['sourcedata']
- sys.stdout.write(response.SerializeToString())
+ io.open(sys.stdout.fileno(), "wb").write(response.SerializeToString())
if __name__ == '__main__':
# Check if we are running as a plugin under protoc