From 7713d43bc3d448358a04393c4e44dd12a768bdea Mon Sep 17 00:00:00 2001 From: Petteri Aimonen Date: Sun, 4 Jan 2015 19:39:37 +0200 Subject: Implement support for oneofs (C unions). Basic test included, should probably add an oneof to the AllTypes test also. Update issue 131 Status: Started --- generator/nanopb_generator.py | 264 ++++++++++++++++++++++++++++++------------ pb.h | 18 +++ pb_common.c | 7 ++ pb_decode.c | 12 +- pb_encode.c | 11 ++ tests/oneof/SConscript | 22 ++++ tests/oneof/decode_oneof.c | 72 ++++++++++++ tests/oneof/encode_oneof.c | 64 ++++++++++ tests/oneof/oneof.proto | 18 +++ 9 files changed, 415 insertions(+), 73 deletions(-) create mode 100644 tests/oneof/SConscript create mode 100644 tests/oneof/decode_oneof.c create mode 100644 tests/oneof/encode_oneof.c create mode 100644 tests/oneof/oneof.proto diff --git a/generator/nanopb_generator.py b/generator/nanopb_generator.py index 85cb413b..60944038 100755 --- a/generator/nanopb_generator.py +++ b/generator/nanopb_generator.py @@ -171,6 +171,7 @@ class Field: '''desc is FieldDescriptorProto''' self.tag = desc.number self.struct_name = struct_name + self.union_name = None self.name = desc.name self.default = None self.max_size = None @@ -300,57 +301,91 @@ class Field: if self.pbtype == 'BYTES' and self.allocation == 'STATIC': result = 'typedef PB_BYTES_ARRAY_T(%d) %s;\n' % (self.max_size, self.ctype) else: - result = None + result = '' return result - def get_initializer(self, null_init): - '''Return literal expression for this field's default value.''' - + def get_dependencies(self): + '''Get list of type names used by this field.''' + if self.allocation == 'STATIC': + return [str(self.ctype)] + else: + return [] + + def get_initializer(self, null_init, inner_init_only = False): + '''Return literal expression for this field's default value. + null_init: If True, initialize to a 0 value instead of default from .proto + inner_init_only: If True, exclude initialization for any count/has fields + ''' + + inner_init = None if self.pbtype == 'MESSAGE': if null_init: - return '%s_init_zero' % self.ctype + inner_init = '%s_init_zero' % self.ctype else: - return '%s_init_default' % self.ctype - - if self.default is None or null_init: + inner_init = '%s_init_default' % self.ctype + elif self.default is None or null_init: if self.pbtype == 'STRING': - return '""' + inner_init = '""' elif self.pbtype == 'BYTES': - return '{0, {0}}' + inner_init = '{0, {0}}' elif self.pbtype == 'ENUM': - return '(%s)0' % self.ctype + inner_init = '(%s)0' % self.ctype else: - return '0' - - default = str(self.default) - - if self.pbtype == 'STRING': - default = default.encode('utf-8').encode('string_escape') - default = default.replace('"', '\\"') - default = '"' + default + '"' - elif self.pbtype == 'BYTES': - data = default.decode('string_escape') - data = ['0x%02x' % ord(c) for c in data] - if len(data) == 0: - default = '{0, {0}}' + inner_init = '0' + else: + if self.pbtype == 'STRING': + inner_init = self.default.encode('utf-8').encode('string_escape') + inner_init = inner_init.replace('"', '\\"') + inner_init = '"' + inner_init + '"' + elif self.pbtype == 'BYTES': + data = str(self.default).decode('string_escape') + data = ['0x%02x' % ord(c) for c in data] + if len(data) == 0: + inner_init = '{0, {0}}' + else: + inner_init = '{%d, {%s}}' % (len(data), ','.join(data)) + elif self.pbtype in ['FIXED32', 'UINT32']: + inner_init = str(self.default) + 'u' + elif self.pbtype in ['FIXED64', 'UINT64']: + inner_init = str(self.default) + 'ull' + elif self.pbtype in ['SFIXED64', 'INT64']: + inner_init = str(self.default) + 'll' else: - default = '{%d, {%s}}' % (len(data), ','.join(data)) - elif self.pbtype in ['FIXED32', 'UINT32']: - default += 'u' - elif self.pbtype in ['FIXED64', 'UINT64']: - default += 'ull' - elif self.pbtype in ['SFIXED64', 'INT64']: - default += 'll' + inner_init = str(self.default) - return default - + if inner_init_only: + return inner_init + + outer_init = None + if self.allocation == 'STATIC': + if self.rules == 'REPEATED': + outer_init = '0, {' + outer_init += ', '.join([inner_init] * self.max_count) + outer_init += '}' + elif self.rules == 'OPTIONAL': + outer_init = 'false, ' + inner_init + else: + outer_init = inner_init + elif self.allocation == 'POINTER': + if self.rules == 'REPEATED': + outer_init = '0, NULL' + else: + outer_init = 'NULL' + elif self.allocation == 'CALLBACK': + if self.pbtype == 'EXTENSION': + outer_init = 'NULL' + else: + outer_init = '{{NULL}, NULL}' + + return outer_init + def default_decl(self, declaration_only = False): '''Return definition for this field's default value.''' if self.default is None: return None ctype = self.ctype - default = self.get_initializer(False) + default = self.get_initializer(False, True) array_decl = '' if self.pbtype == 'STRING': @@ -375,7 +410,13 @@ class Field: '''Return the pb_field_t initializer to use in the constant array. prev_field_name is the name of the previous field or None. ''' - result = ' PB_FIELD(%3d, ' % self.tag + + if self.rules == 'ONEOF': + result = ' PB_ONEOF_FIELD(%s, ' % self.union_name + else: + result = ' PB_FIELD(' + + result += '%3d, ' % self.tag result += '%-8s, ' % self.pbtype result += '%s, ' % self.rules result += '%-8s, ' % self.allocation @@ -403,6 +444,8 @@ class Field: if self.pbtype == 'MESSAGE': if self.rules == 'REPEATED' and self.allocation == 'STATIC': return 'pb_membersize(%s, %s[0])' % (self.struct_name, self.name) + elif self.rules == 'ONEOF': + return 'pb_membersize(%s, %s.%s)' % (self.struct_name, self.union_name, self.name) else: return 'pb_membersize(%s, %s)' % (self.struct_name, self.name) @@ -534,6 +577,71 @@ class ExtensionField(Field): return result +# --------------------------------------------------------------------------- +# Generation of oneofs (unions) +# --------------------------------------------------------------------------- + +class OneOf(Field): + def __init__(self, oneof_desc): + self.name = oneof_desc.name + self.ctype = 'union' + self.fields = [] + + def add_field(self, field): + if field.allocation == 'CALLBACK': + raise Exception("Callback fields inside of oneof are not supported" + + " (field %s)" % field.fullname) + + field.union_name = self.name + field.rules = 'ONEOF' + self.fields.append(field) + self.fields.sort(key = lambda f: f.tag) + + # Sort by the lowest tag number inside union + self.tag = min([f.tag for f in self.fields]) + + def __cmp__(self, other): + return cmp(self.tag, other.tag) + + def __str__(self): + result = '' + if self.fields: + result += ' pb_size_t which_' + self.name + ";\n" + result += ' union {\n' + for f in self.fields: + result += ' ' + str(f).replace('\n', '\n ') + '\n' + result += ' } ' + self.name + ';' + return result + + def types(self): + return ''.join([f.types() for f in self.fields]) + + def get_dependencies(self): + deps = [] + for f in self.fields: + deps += f.get_dependencies() + return deps + + def get_initializer(self, null_init): + return '0, {' + self.fields[0].get_initializer(null_init) + '}' + + def default_decl(self, declaration_only = False): + return None + + def tags(self): + return '\n'.join([f.tags() for f in self.fields]) + + def pb_field_t(self, prev_field_name): + prev_field_name = prev_field_name or self.name + result = ',\n'.join([f.pb_field_t(prev_field_name) for f in self.fields]) + return result + + def largest_field_value(self): + return max([f.largest_field_value() for f in self.fields]) + + def encoded_size(self, allmsgs): + return max([f.encoded_size(allmsgs) for f in self.fields]) + # --------------------------------------------------------------------------- # Generation of messages (structures) # --------------------------------------------------------------------------- @@ -543,11 +651,24 @@ class Message: def __init__(self, names, desc, message_options): self.name = names self.fields = [] - + self.oneofs = [] + + if hasattr(desc, 'oneof_decl'): + for f in desc.oneof_decl: + oneof = OneOf(f) + self.oneofs.append(oneof) + self.fields.append(oneof) + for f in desc.field: field_options = get_nanopb_suboptions(f, message_options, self.name + f.name) - if field_options.type != nanopb_pb2.FT_IGNORE: - self.fields.append(Field(self.name, f, field_options)) + if field_options.type == nanopb_pb2.FT_IGNORE: + continue + + field = Field(self.name, f, field_options) + if hasattr(f, 'oneof_index') and f.HasField('oneof_index'): + self.oneofs[f.oneof_index].add_field(field) + else: + self.fields.append(field) if len(desc.extension_range) > 0: field_options = get_nanopb_suboptions(desc, message_options, self.name + 'extensions') @@ -561,7 +682,10 @@ class Message: def get_dependencies(self): '''Get list of type names that this structure refers to.''' - return [str(field.ctype) for field in self.fields if field.allocation == 'STATIC'] + deps = [] + for f in self.fields: + deps += f.get_dependencies() + return deps def __str__(self): result = 'typedef struct _%s {\n' % self.name @@ -586,39 +710,15 @@ class Message: return result def types(self): - result = "" - for field in self.fields: - types = field.types() - if types is not None: - result += types + '\n' - return result - + return ''.join([f.types() for f in self.fields]) + def get_initializer(self, null_init): if not self.ordered_fields: return '{0}' parts = [] for field in self.ordered_fields: - if field.allocation == 'STATIC': - if field.rules == 'REPEATED': - parts.append('0') - parts.append('{' - + ', '.join([field.get_initializer(null_init)] * field.max_count) - + '}') - elif field.rules == 'OPTIONAL': - parts.append('false') - parts.append(field.get_initializer(null_init)) - else: - parts.append(field.get_initializer(null_init)) - elif field.allocation == 'POINTER': - if field.rules == 'REPEATED': - parts.append('0') - parts.append('NULL') - elif field.allocation == 'CALLBACK': - if field.pbtype == 'EXTENSION': - parts.append('NULL') - else: - parts.append('{{NULL}, NULL}') + parts.append(field.get_initializer(null_init)) return '{' + ', '.join(parts) + '}' def default_decl(self, declaration_only = False): @@ -629,18 +729,39 @@ class Message: result += default + '\n' return result + def count_required_fields(self): + '''Returns number of required fields inside this message''' + count = 0 + for f in self.fields: + if f not in self.oneofs: + if f.rules == 'REQUIRED': + count += 1 + return count + + def count_all_fields(self): + count = 0 + for f in self.fields: + if f in self.oneofs: + count += len(f.fields) + else: + count += 1 + return count + def fields_declaration(self): - result = 'extern const pb_field_t %s_fields[%d];' % (self.name, len(self.fields) + 1) + result = 'extern const pb_field_t %s_fields[%d];' % (self.name, self.count_all_fields() + 1) return result def fields_definition(self): - result = 'const pb_field_t %s_fields[%d] = {\n' % (self.name, len(self.fields) + 1) + result = 'const pb_field_t %s_fields[%d] = {\n' % (self.name, self.count_all_fields() + 1) prev = None for field in self.ordered_fields: result += field.pb_field_t(prev) result += ',\n' - prev = field.name + if isinstance(field, OneOf): + prev = field.name + '.' + field.fields[-1].name + else: + prev = field.name result += ' PB_LAST_FIELD\n};' return result @@ -894,9 +1015,8 @@ def generate_source(headername, enums, messages, extensions, options): # Add checks for numeric limits if messages: - count_required_fields = lambda m: len([f for f in msg.fields if f.rules == 'REQUIRED']) - largest_msg = max(messages, key = count_required_fields) - largest_count = count_required_fields(largest_msg) + largest_msg = max(messages, key = lambda m: m.count_required_fields()) + largest_count = largest_msg.count_required_fields() if largest_count > 64: yield '\n/* Check that missing required fields will be properly detected */\n' yield '#if PB_MAX_REQUIRED_FIELDS < %d\n' % largest_count diff --git a/pb.h b/pb.h index 5a3340aa..af58dca3 100644 --- a/pb.h +++ b/pb.h @@ -183,6 +183,7 @@ typedef uint8_t pb_type_t; #define PB_HTYPE_REQUIRED 0x00 #define PB_HTYPE_OPTIONAL 0x10 #define PB_HTYPE_REPEATED 0x20 +#define PB_HTYPE_ONEOF 0x30 #define PB_HTYPE_MASK 0x30 /**** Field allocation types ****/ @@ -502,6 +503,23 @@ struct pb_extension_s { PB_DATAOFFSET_ ## placement(message, field, prevfield), \ PB_LTYPE_MAP_ ## type, ptr) +/* Field description for oneof fields. This requires taking into account the + * union name also, that's why a separate set of macros is needed. + */ +#define PB_ONEOF_STATIC(u, tag, st, m, fd, ltype, ptr) \ + {tag, PB_ATYPE_STATIC | PB_HTYPE_ONEOF | ltype, \ + fd, pb_delta(st, which_ ## u, u.m), \ + pb_membersize(st, u.m), 0, ptr} + +#define PB_ONEOF_POINTER(u, tag, st, m, fd, ltype, ptr) \ + {tag, PB_ATYPE_POINTER | PB_HTYPE_ONEOF | ltype, \ + fd, pb_delta(st, which_ ## u, u.m), \ + pb_membersize(st, u.m), 0, ptr} + +#define PB_ONEOF_FIELD(union_name, tag, type, rules, allocation, placement, message, field, prevfield, ptr) \ + PB_ ## rules ## _ ## allocation(union_name, tag, message, field, \ + PB_DATAOFFSET_ ## placement(message, union_name.field, prevfield), \ + PB_LTYPE_MAP_ ## type, ptr) /* These macros are used for giving out error messages. * They are mostly a debugging aid; the main error information diff --git a/pb_common.c b/pb_common.c index a9cade63..98964850 100644 --- a/pb_common.c +++ b/pb_common.c @@ -54,6 +54,13 @@ bool pb_field_iter_next(pb_field_iter_t *iter) * The data_size only applies to the dynamically allocated area. */ prev_size = sizeof(void*); } + else if (PB_HTYPE(prev_field->type) == PB_HTYPE_ONEOF && + PB_HTYPE(iter->pos->type) == PB_HTYPE_ONEOF) + { + /* Don't advance pointers inside unions */ + prev_size = 0; + iter->pData = (char*)iter->pData - prev_field->data_offset; + } if (PB_HTYPE(prev_field->type) == PB_HTYPE_REQUIRED) { diff --git a/pb_decode.c b/pb_decode.c index 5982c8e5..542fdc4c 100644 --- a/pb_decode.c +++ b/pb_decode.c @@ -393,6 +393,10 @@ static bool checkreturn decode_static_field(pb_istream_t *stream, pb_wire_type_t return func(stream, iter->pos, pItem); } + case PB_HTYPE_ONEOF: + *(pb_size_t*)iter->pSize = iter->pos->tag; + return func(stream, iter->pos, iter->pData); + default: PB_RETURN_ERROR(stream, "invalid field type"); } @@ -470,6 +474,7 @@ static bool checkreturn decode_pointer_field(pb_istream_t *stream, pb_wire_type_ { case PB_HTYPE_REQUIRED: case PB_HTYPE_OPTIONAL: + case PB_HTYPE_ONEOF: if (PB_LTYPE(type) == PB_LTYPE_SUBMESSAGE && *(void**)iter->pData != NULL) { @@ -477,6 +482,11 @@ static bool checkreturn decode_pointer_field(pb_istream_t *stream, pb_wire_type_ pb_release_single_field(iter); } + if (PB_HTYPE(type) == PB_HTYPE_ONEOF) + { + *(pb_size_t*)iter->pSize = iter->pos->tag; + } + if (PB_LTYPE(type) == PB_LTYPE_STRING || PB_LTYPE(type) == PB_LTYPE_BYTES) { @@ -562,7 +572,7 @@ static bool checkreturn decode_pointer_field(pb_istream_t *stream, pb_wire_type_ initialize_pointer_field(pItem, iter); return func(stream, iter->pos, pItem); } - + default: PB_RETURN_ERROR(stream, "invalid field type"); } diff --git a/pb_encode.c b/pb_encode.c index cef98861..cc372b8f 100644 --- a/pb_encode.c +++ b/pb_encode.c @@ -250,6 +250,17 @@ static bool checkreturn encode_basic_field(pb_ostream_t *stream, return false; break; + case PB_HTYPE_ONEOF: + if (*(const pb_size_t*)pSize == field->tag) + { + if (!pb_encode_tag_for_field(stream, field)) + return false; + + if (!func(stream, field, pData)) + return false; + } + break; + default: PB_RETURN_ERROR(stream, "invalid field type"); } diff --git a/tests/oneof/SConscript b/tests/oneof/SConscript new file mode 100644 index 00000000..19845278 --- /dev/null +++ b/tests/oneof/SConscript @@ -0,0 +1,22 @@ +# Test the 'oneof' feature for generating C unions. + +Import('env') + +env.NanopbProto('oneof') + +enc = env.Program(['encode_oneof.c', + 'oneof.pb.c', + '$COMMON/pb_encode.o', + '$COMMON/pb_common.o']) + +dec = env.Program(['decode_oneof.c', + 'oneof.pb.c', + '$COMMON/pb_decode.o', + '$COMMON/pb_common.o']) + +env.RunTest("message1.pb", enc, ARGS = ['1']) +env.RunTest("message1.txt", [dec, 'message1.pb'], ARGS = ['1']) +env.RunTest("message2.pb", enc, ARGS = ['2']) +env.RunTest("message2.txt", [dec, 'message2.pb'], ARGS = ['2']) +env.RunTest("message3.pb", enc, ARGS = ['3']) +env.RunTest("message3.txt", [dec, 'message3.pb'], ARGS = ['3']) diff --git a/tests/oneof/decode_oneof.c b/tests/oneof/decode_oneof.c new file mode 100644 index 00000000..e94becc7 --- /dev/null +++ b/tests/oneof/decode_oneof.c @@ -0,0 +1,72 @@ +/* Decode a message using oneof fields */ + +#include +#include +#include +#include "oneof.pb.h" +#include "test_helpers.h" +#include "unittests.h" + +int main(int argc, char **argv) +{ + uint8_t buffer[OneOfMessage_size]; + OneOfMessage msg = OneOfMessage_init_zero; + pb_istream_t stream; + size_t count; + int option; + + if (argc != 2) + { + fprintf(stderr, "Usage: encode_oneof [number]\n"); + return 1; + } + option = atoi(argv[1]); + + SET_BINARY_MODE(stdin); + count = fread(buffer, 1, sizeof(buffer), stdin); + + if (!feof(stdin)) + { + printf("Message does not fit in buffer\n"); + return 1; + } + + stream = pb_istream_from_buffer(buffer, count); + + if (!pb_decode(&stream, OneOfMessage_fields, &msg)) + { + printf("Decoding failed: %s\n", PB_GET_ERROR(&stream)); + return 1; + } + + { + int status = 0; + + /* Check that the basic fields work normally */ + TEST(msg.prefix == 123); + TEST(msg.suffix == 321); + + /* Check that we got the right oneof according to command line */ + if (option == 1) + { + TEST(msg.which_values == OneOfMessage_first_tag); + TEST(msg.values.first == 999); + } + else if (option == 2) + { + TEST(msg.which_values == OneOfMessage_second_tag); + TEST(strcmp(msg.values.second, "abcd") == 0); + } + else if (option == 3) + { + TEST(msg.which_values == OneOfMessage_third_tag); + TEST(msg.values.third.array[0] == 1); + TEST(msg.values.third.array[1] == 2); + TEST(msg.values.third.array[2] == 3); + TEST(msg.values.third.array[3] == 4); + TEST(msg.values.third.array[4] == 5); + } + + return status; + } +} \ No newline at end of file diff --git a/tests/oneof/encode_oneof.c b/tests/oneof/encode_oneof.c new file mode 100644 index 00000000..913d2d43 --- /dev/null +++ b/tests/oneof/encode_oneof.c @@ -0,0 +1,64 @@ +/* Encode a message using oneof fields */ + +#include +#include +#include +#include "oneof.pb.h" +#include "test_helpers.h" + +int main(int argc, char **argv) +{ + uint8_t buffer[OneOfMessage_size]; + OneOfMessage msg = OneOfMessage_init_zero; + pb_ostream_t stream; + int option; + + if (argc != 2) + { + fprintf(stderr, "Usage: encode_oneof [number]\n"); + return 1; + } + option = atoi(argv[1]); + + /* Prefix and suffix are used to test that the union does not disturb + * other fields in the same message. */ + msg.prefix = 123; + + /* We encode one of the 'values' fields based on command line argument */ + if (option == 1) + { + msg.which_values = OneOfMessage_first_tag; + msg.values.first = 999; + } + else if (option == 2) + { + msg.which_values = OneOfMessage_second_tag; + strcpy(msg.values.second, "abcd"); + } + else if (option == 3) + { + msg.which_values = OneOfMessage_third_tag; + msg.values.third.array_count = 5; + msg.values.third.array[0] = 1; + msg.values.third.array[1] = 2; + msg.values.third.array[2] = 3; + msg.values.third.array[3] = 4; + msg.values.third.array[4] = 5; + } + + msg.suffix = 321; + + stream = pb_ostream_from_buffer(buffer, sizeof(buffer)); + + if (pb_encode(&stream, OneOfMessage_fields, &msg)) + { + SET_BINARY_MODE(stdout); + fwrite(buffer, 1, stream.bytes_written, stdout); + return 0; + } + else + { + fprintf(stderr, "Encoding failed: %s\n", PB_GET_ERROR(&stream)); + return 1; + } +} diff --git a/tests/oneof/oneof.proto b/tests/oneof/oneof.proto new file mode 100644 index 00000000..a89ef131 --- /dev/null +++ b/tests/oneof/oneof.proto @@ -0,0 +1,18 @@ +import 'nanopb.proto'; + +message SubMessage +{ + repeated int32 array = 1 [(nanopb).max_count = 8]; +} + +message OneOfMessage +{ + required int32 prefix = 1; + oneof values + { + int32 first = 5; + string second = 6 [(nanopb).max_size = 8]; + SubMessage third = 7; + } + required int32 suffix = 99; +} -- cgit 1.2.3-korg