diff options
-rw-r--r-- | pb_decode.c | 39 | ||||
-rw-r--r-- | tests/mem_release/mem_release.c | 117 | ||||
-rw-r--r-- | tests/mem_release/mem_release.proto | 10 |
3 files changed, 144 insertions, 22 deletions
diff --git a/pb_decode.c b/pb_decode.c index 0677594f..9d25dc61 100644 --- a/pb_decode.c +++ b/pb_decode.c @@ -48,6 +48,7 @@ static bool checkreturn pb_skip_string(pb_istream_t *stream); #ifdef PB_ENABLE_MALLOC static bool checkreturn allocate_field(pb_istream_t *stream, void *pData, size_t data_size, size_t array_size); +static bool checkreturn pb_release_union_field(pb_istream_t *stream, pb_field_iter_t *iter); static void pb_release_single_field(const pb_field_iter_t *iter); #endif @@ -628,6 +629,16 @@ static bool checkreturn decode_callback_field(pb_istream_t *stream, pb_wire_type static bool checkreturn decode_field(pb_istream_t *stream, pb_wire_type_t wire_type, pb_field_iter_t *iter) { +#ifdef PB_ENABLE_MALLOC + /* When decoding an oneof field, check if there is old data that must be + * released first. */ + if (PB_HTYPE(iter->pos->type) == PB_HTYPE_ONEOF) + { + if (!pb_release_union_field(stream, iter)) + return false; + } +#endif + switch (PB_ATYPE(iter->pos->type)) { case PB_ATYPE_STATIC: @@ -936,6 +947,34 @@ bool pb_decode_delimited(pb_istream_t *stream, const pb_field_t fields[], void * } #ifdef PB_ENABLE_MALLOC +/* Given an oneof field, if there has already been a field inside this oneof, + * release it before overwriting with a different one. */ +static bool pb_release_union_field(pb_istream_t *stream, pb_field_iter_t *iter) +{ + pb_size_t old_tag = *(pb_size_t*)iter->pSize; /* Previous which_ value */ + pb_size_t new_tag = iter->pos->tag; /* New which_ value */ + + if (old_tag == 0) + return true; /* Ok, no old data in union */ + + if (old_tag == new_tag) + return true; /* Ok, old data is of same type => merge */ + + /* Release old data. The find can fail if the message struct contains + * invalid data. */ + if (!pb_field_iter_find(iter, old_tag)) + PB_RETURN_ERROR(stream, "invalid union tag"); + + pb_release_single_field(iter); + + /* Restore iterator to where it should be. + * This shouldn't fail unless the pb_field_t structure is corrupted. */ + if (!pb_field_iter_find(iter, new_tag)) + PB_RETURN_ERROR(stream, "iterator error"); + + return true; +} + static void pb_release_single_field(const pb_field_iter_t *iter) { pb_type_t type; diff --git a/tests/mem_release/mem_release.c b/tests/mem_release/mem_release.c index 796df13e..40fdc9e4 100644 --- a/tests/mem_release/mem_release.c +++ b/tests/mem_release/mem_release.c @@ -14,8 +14,31 @@ static char *test_str_arr[] = {"1", "2", ""}; static SubMessage test_msg_arr[] = {SubMessage_init_zero, SubMessage_init_zero}; +static pb_extension_t ext1, ext2; -static bool do_test() +static void fill_TestMessage(TestMessage *msg) +{ + msg->static_req_submsg.dynamic_str = "12345"; + msg->static_req_submsg.dynamic_str_arr_count = 3; + msg->static_req_submsg.dynamic_str_arr = test_str_arr; + msg->static_req_submsg.dynamic_submsg_count = 2; + msg->static_req_submsg.dynamic_submsg = test_msg_arr; + msg->static_req_submsg.dynamic_submsg[1].dynamic_str = "abc"; + msg->static_opt_submsg.dynamic_str = "abc"; + msg->has_static_opt_submsg = true; + msg->dynamic_submsg = &msg->static_req_submsg; + + msg->extensions = &ext1; + ext1.type = &dynamic_ext; + ext1.dest = &msg->static_req_submsg; + ext1.next = &ext2; + ext2.type = &static_ext; + ext2.dest = &msg->static_req_submsg; + ext2.next = NULL; +} + +/* Basic fields, nested submessages, extensions */ +static bool test_TestMessage() { uint8_t buffer[256]; size_t msgsize; @@ -23,25 +46,9 @@ static bool do_test() /* Construct a message with various fields filled in */ { TestMessage msg = TestMessage_init_zero; - pb_extension_t ext1, ext2; pb_ostream_t stream; - msg.static_req_submsg.dynamic_str = "12345"; - msg.static_req_submsg.dynamic_str_arr_count = 3; - msg.static_req_submsg.dynamic_str_arr = test_str_arr; - msg.static_req_submsg.dynamic_submsg_count = 2; - msg.static_req_submsg.dynamic_submsg = test_msg_arr; - msg.static_req_submsg.dynamic_submsg[1].dynamic_str = "abc"; - msg.static_opt_submsg.dynamic_str = "abc"; - msg.has_static_opt_submsg = true; - msg.dynamic_submsg = &msg.static_req_submsg; - - msg.extensions = &ext1; - ext1.type = &dynamic_ext; - ext1.dest = &msg.static_req_submsg; - ext1.next = &ext2; - ext2.type = &static_ext; - ext2.dest = &msg.static_req_submsg; - ext2.next = NULL; + + fill_TestMessage(&msg); stream = pb_ostream_from_buffer(buffer, sizeof(buffer)); if (!pb_encode(&stream, TestMessage_fields, &msg)) @@ -61,8 +68,7 @@ static bool do_test() TestMessage msg = TestMessage_init_zero; pb_istream_t stream; SubMessage ext2_dest; - pb_extension_t ext1, ext2; - + msg.extensions = &ext1; ext1.type = &dynamic_ext; ext1.dest = NULL; @@ -102,9 +108,76 @@ static bool do_test() return true; } +/* Oneofs */ +static bool test_OneofMessage() +{ + uint8_t buffer[256]; + size_t msgsize; + + { + pb_ostream_t stream = pb_ostream_from_buffer(buffer, sizeof(buffer)); + + /* Encode first with TestMessage */ + { + OneofMessage msg = OneofMessage_init_zero; + msg.which_msgs = OneofMessage_msg1_tag; + + fill_TestMessage(&msg.msgs.msg1); + + if (!pb_encode(&stream, OneofMessage_fields, &msg)) + { + fprintf(stderr, "Encode failed: %s\n", PB_GET_ERROR(&stream)); + return false; + } + } + + /* Encode second with SubMessage, invoking 'merge' behaviour */ + { + OneofMessage msg = OneofMessage_init_zero; + msg.which_msgs = OneofMessage_msg2_tag; + + msg.first = 999; + msg.msgs.msg2.dynamic_str = "ABCD"; + msg.last = 888; + + if (!pb_encode(&stream, OneofMessage_fields, &msg)) + { + fprintf(stderr, "Encode failed: %s\n", PB_GET_ERROR(&stream)); + return false; + } + } + msgsize = stream.bytes_written; + } + + { + OneofMessage msg = OneofMessage_init_zero; + pb_istream_t stream = pb_istream_from_buffer(buffer, msgsize); + if (!pb_decode(&stream, OneofMessage_fields, &msg)) + { + fprintf(stderr, "Decode failed: %s\n", PB_GET_ERROR(&stream)); + return false; + } + + TEST(msg.first == 999); + TEST(msg.which_msgs == OneofMessage_msg2_tag); + TEST(msg.msgs.msg2.dynamic_str); + TEST(strcmp(msg.msgs.msg2.dynamic_str, "ABCD") == 0); + TEST(msg.msgs.msg2.dynamic_str_arr == NULL); + TEST(msg.msgs.msg2.dynamic_submsg == NULL); + TEST(msg.last == 888); + + pb_release(OneofMessage_fields, &msg); + TEST(get_alloc_count() == 0); + pb_release(OneofMessage_fields, &msg); + TEST(get_alloc_count() == 0); + } + + return true; +} + int main() { - if (do_test()) + if (test_TestMessage() && test_OneofMessage()) return 0; else return 1; diff --git a/tests/mem_release/mem_release.proto b/tests/mem_release/mem_release.proto index 0db393ae..c3b38c8b 100644 --- a/tests/mem_release/mem_release.proto +++ b/tests/mem_release/mem_release.proto @@ -22,3 +22,13 @@ extend TestMessage optional SubMessage static_ext = 101 [(nanopb).type = FT_STATIC]; } +message OneofMessage +{ + required int32 first = 1; + oneof msgs + { + TestMessage msg1 = 2; + SubMessage msg2 = 3; + } + required int32 last = 4; +} |