diff options
author | Petteri Aimonen <jpa@git.mail.kapsi.fi> | 2014-03-10 18:19:38 +0200 |
---|---|---|
committer | Petteri Aimonen <jpa@git.mail.kapsi.fi> | 2014-03-10 18:19:38 +0200 |
commit | bf61d2337b4107b2c37c35bb41c7b809d8f3feb9 (patch) | |
tree | ad3256b7be02a3a80445aba06c6fd4fd27c83b74 | |
parent | 48ac4613724ef0bfb7b0db61871127575b1d9002 (diff) |
More fixes for dynamic allocation
-rw-r--r-- | pb_decode.c | 106 | ||||
-rw-r--r-- | tests/alltypes_pointer/SConscript | 1 | ||||
-rw-r--r-- | tests/alltypes_pointer/decode_alltypes_pointer.c | 52 |
3 files changed, 93 insertions, 66 deletions
diff --git a/pb_decode.c b/pb_decode.c index 30c36b36..65e22af5 100644 --- a/pb_decode.c +++ b/pb_decode.c @@ -472,29 +472,19 @@ static bool checkreturn decode_static_field(pb_istream_t *stream, pb_wire_type_t #ifdef PB_ENABLE_MALLOC /* Allocate storage for the field and store the pointer at iter->pData. * array_size is the number of entries to reserve in an array. */ -static bool checkreturn allocate_field(pb_istream_t *stream, pb_field_iterator_t *iter, size_t array_size) +static bool checkreturn allocate_field(pb_istream_t *stream, void *pData, size_t data_size, size_t array_size) { - void *ptr = *(void**)iter->pData; - size_t size = array_size * iter->pos->data_size; + void *ptr = *(void**)pData; + size_t size = array_size * data_size; + /* Allocate new or expand previous allocation */ + /* Note: on failure the old pointer will remain in the structure, + * the message must be freed by caller also on error return. */ + ptr = realloc(ptr, size); if (ptr == NULL) - { - /* First allocation */ - ptr = malloc(size); - if (ptr == NULL) - PB_RETURN_ERROR(stream, "malloc failed"); - } - else - { - /* Expand previous allocation */ - /* Note: on failure the old pointer will remain in the structure, - * the message must be freed by caller also on error return. */ - ptr = realloc(ptr, size); - if (ptr == NULL) - PB_RETURN_ERROR(stream, "realloc failed"); - } + PB_RETURN_ERROR(stream, "realloc failed"); - *(void**)iter->pData = ptr; + *(void**)pData = ptr; return true; } #endif @@ -522,7 +512,7 @@ static bool checkreturn decode_pointer_field(pb_istream_t *stream, pb_wire_type_ } else { - if (!allocate_field(stream, iter, 1)) + if (!allocate_field(stream, iter->pData, iter->pos->data_size, 1)) return false; return func(stream, iter->pos, *(void**)iter->pData); @@ -547,12 +537,11 @@ static bool checkreturn decode_pointer_field(pb_istream_t *stream, pb_wire_type_ if (*size + 1 > allocated_size) { /* Allocate more storage. This tries to guess the - * number of remaining entries. */ - allocated_size += substream.bytes_left / iter->pos->data_size; - if (*size + 1 > allocated_size) - allocated_size++; /* Division gave zero. */ + * number of remaining entries. Round the division + * upwards. */ + allocated_size += (substream.bytes_left - 1) / iter->pos->data_size + 1; - if (!allocate_field(&substream, iter, allocated_size)) + if (!allocate_field(&substream, iter->pData, iter->pos->data_size, allocated_size)) { status = false; break; @@ -560,7 +549,7 @@ static bool checkreturn decode_pointer_field(pb_istream_t *stream, pb_wire_type_ } /* Decode the array entry */ - pItem = (uint8_t*)iter->pData + iter->pos->data_size * (*size); + pItem = *(uint8_t**)iter->pData + iter->pos->data_size * (*size); if (!func(&substream, iter->pos, pItem)) { status = false; @@ -576,11 +565,26 @@ static bool checkreturn decode_pointer_field(pb_istream_t *stream, pb_wire_type_ { /* Normal repeated field, i.e. only one item at a time. */ size_t *size = (size_t*)iter->pSize; - void *pItem = (uint8_t*)iter->pData + iter->pos->data_size * (*size); + void *pItem; - if (!allocate_field(stream, iter, *size + 1)) + if (!allocate_field(stream, iter->pData, iter->pos->data_size, *size + 1)) return false; + pItem = *(uint8_t**)iter->pData + iter->pos->data_size * (*size); + + /* Clear the new item in case it contains a pointer, or is a submessage. */ + if (PB_LTYPE(type) == PB_LTYPE_STRING) + { + *(char**)pItem = NULL; + } + else if (PB_LTYPE(type) == PB_LTYPE_BYTES) + { + memset(pItem, 0, iter->pos->data_size); + } + else if (PB_LTYPE(type) == PB_LTYPE_SUBMESSAGE) + { + pb_message_set_to_defaults((const pb_field_t *) iter->pos->ptr, pItem); + } (*size)++; return func(stream, iter->pos, pItem); @@ -1026,42 +1030,62 @@ bool checkreturn pb_dec_fixed64(pb_istream_t *stream, const pb_field_t *field, v bool checkreturn pb_dec_bytes(pb_istream_t *stream, const pb_field_t *field, void *dest) { - pb_bytes_array_t *x = (pb_bytes_array_t*)dest; + uint32_t size; + size_t alloc_size; - uint32_t temp; - if (!pb_decode_varint32(stream, &temp)) + if (!pb_decode_varint32(stream, &size)) return false; - x->size = temp; - /* Check length, noting the space taken by the size_t header. */ - if (x->size > field->data_size - offsetof(pb_bytes_array_t, bytes)) - PB_RETURN_ERROR(stream, "bytes overflow"); + /* Space for the size_t header */ + alloc_size = size + offsetof(pb_bytes_array_t, bytes); - return pb_read(stream, x->bytes, x->size); + if (PB_ATYPE(field->type) == PB_ATYPE_POINTER) + { +#ifndef PB_ENABLE_MALLOC + PB_RETURN_ERROR(stream, "no malloc support"); +#else + pb_bytes_ptr_t *bdest = (pb_bytes_ptr_t*)dest; + if (!allocate_field(stream, &bdest->bytes, alloc_size, 1)) + return false; + + bdest->size = size; + return pb_read(stream, bdest->bytes, size); +#endif + } + else + { + pb_bytes_array_t* bdest = (pb_bytes_array_t*)dest; + if (alloc_size > field->data_size) + PB_RETURN_ERROR(stream, "bytes overflow"); + bdest->size = size; + return pb_read(stream, bdest->bytes, size); + } } bool checkreturn pb_dec_string(pb_istream_t *stream, const pb_field_t *field, void *dest) { uint32_t size; + size_t alloc_size; bool status; if (!pb_decode_varint32(stream, &size)) return false; - /* Check length, noting the null terminator */ + /* Space for null terminator */ + alloc_size = size + 1; + if (PB_ATYPE(field->type) == PB_ATYPE_POINTER) { #ifndef PB_ENABLE_MALLOC PB_RETURN_ERROR(stream, "no malloc support"); #else - *(void**)dest = realloc(*(void**)dest, size + 1); - if (*(void**)dest == NULL) - PB_RETURN_ERROR(stream, "out of memory"); + if (!allocate_field(stream, dest, alloc_size, 1)) + return false; dest = *(void**)dest; #endif } else { - if (size + 1 > field->data_size) + if (alloc_size > field->data_size) PB_RETURN_ERROR(stream, "string overflow"); } diff --git a/tests/alltypes_pointer/SConscript b/tests/alltypes_pointer/SConscript index f0103baa..05b4e52d 100644 --- a/tests/alltypes_pointer/SConscript +++ b/tests/alltypes_pointer/SConscript @@ -23,7 +23,6 @@ env.RunTest(enc) env.RunTest("decode_alltypes.output", [dec, "encode_alltypes_pointer.output"]) env.RunTest("decode_alltypes_ref.output", [refdec, "encode_alltypes_pointer.output"]) env.Compare(["encode_alltypes_pointer.output", "$BUILD/alltypes/encode_alltypes.output"]) -env.Compare(["encode_alltypes_pointer_ref.output", "$BUILD/alltypes/encode_alltypes.output"]) # Do the same thing with the optional fields present env.RunTest("optionals.output", enc, ARGS = ['1']) diff --git a/tests/alltypes_pointer/decode_alltypes_pointer.c b/tests/alltypes_pointer/decode_alltypes_pointer.c index 32e34c58..3db48114 100644 --- a/tests/alltypes_pointer/decode_alltypes_pointer.c +++ b/tests/alltypes_pointer/decode_alltypes_pointer.c @@ -7,13 +7,14 @@ #define TEST(x) if (!(x)) { \ printf("Test " #x " failed.\n"); \ - return false; \ + status = false; \ } /* This function is called once from main(), it handles the decoding and checks the fields. */ bool check_alltypes(pb_istream_t *stream, int mode) { + bool status = true; AllTypes alltypes; /* Fill with garbage to better detect initialization errors */ @@ -22,28 +23,31 @@ bool check_alltypes(pb_istream_t *stream, int mode) if (!pb_decode(stream, AllTypes_fields, &alltypes)) return false; - TEST(*alltypes.req_int32 == -1001); - TEST(*alltypes.req_int64 == -1002); - TEST(*alltypes.req_uint32 == 1003); - TEST(*alltypes.req_uint64 == 1004); - TEST(*alltypes.req_sint32 == -1005); - TEST(*alltypes.req_sint64 == -1006); - TEST(*alltypes.req_bool == true); - - TEST(*alltypes.req_fixed32 == 1008); - TEST(*alltypes.req_sfixed32 == -1009); - TEST(*alltypes.req_float == 1010.0f); - - TEST(*alltypes.req_fixed64 == 1011); - TEST(*alltypes.req_sfixed64 == -1012); - TEST(*alltypes.req_double == 1013.0f); - - TEST(strcmp(alltypes.req_string, "1014") == 0); - TEST(alltypes.req_bytes->size == 4); - TEST(memcmp(alltypes.req_bytes->bytes, "1015", 4) == 0); - TEST(strcmp(alltypes.req_submsg->substuff1, "1016") == 0); - TEST(*alltypes.req_submsg->substuff2 == 1016); - TEST(*alltypes.req_submsg->substuff3 == 3); + TEST(alltypes.req_int32 && *alltypes.req_int32 == -1001); + TEST(alltypes.req_int64 && *alltypes.req_int64 == -1002); + TEST(alltypes.req_uint32 && *alltypes.req_uint32 == 1003); + TEST(alltypes.req_uint64 && *alltypes.req_uint64 == 1004); + TEST(alltypes.req_sint32 && *alltypes.req_sint32 == -1005); + TEST(alltypes.req_sint64 && *alltypes.req_sint64 == -1006); + TEST(alltypes.req_bool && *alltypes.req_bool == true); + + TEST(alltypes.req_fixed32 && *alltypes.req_fixed32 == 1008); + TEST(alltypes.req_sfixed32 && *alltypes.req_sfixed32 == -1009); + TEST(alltypes.req_float && *alltypes.req_float == 1010.0f); + + TEST(alltypes.req_fixed64 && *alltypes.req_fixed64 == 1011); + TEST(alltypes.req_sfixed64 && *alltypes.req_sfixed64 == -1012); + TEST(alltypes.req_double && *alltypes.req_double == 1013.0f); + + TEST(alltypes.req_string && strcmp(alltypes.req_string, "1014") == 0); + TEST(alltypes.req_bytes && alltypes.req_bytes->size == 4); + TEST(alltypes.req_bytes && alltypes.req_bytes->bytes + && memcmp(alltypes.req_bytes->bytes, "1015", 4) == 0); + TEST(alltypes.req_submsg && alltypes.req_submsg->substuff1 + && strcmp(alltypes.req_submsg->substuff1, "1016") == 0); + TEST(alltypes.req_submsg && alltypes.req_submsg->substuff2 + && *alltypes.req_submsg->substuff2 == 1016); + /* TEST(*alltypes.req_submsg->substuff3 == 3); Default values are not currently supported for pointer fields */ TEST(*alltypes.req_enum == MyEnum_Truth); #if 0 @@ -180,7 +184,7 @@ bool check_alltypes(pb_istream_t *stream, int mode) TEST(alltypes.end == 1099); #endif - return true; + return status; } int main(int argc, char **argv) |