diff --git a/src/brei-shared.c b/src/brei-shared.c index 8ddae00..5751786 100644 --- a/src/brei-shared.c +++ b/src/brei-shared.c @@ -49,12 +49,6 @@ struct brei_header { } _packed_; static_assert(sizeof(struct brei_header) == 16, "Unexpected size for brei_header struct"); -struct brei_string { - uint32_t len; - const char str[]; -}; -static_assert(sizeof(struct brei_string) == 4, "Unexpected size for brei_string struct"); - /** * For a given string length (including null byte) return * the number of bytes needed on the protocol, including the @@ -63,12 +57,13 @@ static_assert(sizeof(struct brei_string) == 4, "Unexpected size for brei_string static inline uint32_t brei_string_proto_length(uint32_t slen) { - uint32_t length = sizeof(struct brei_string) + slen; + uint32_t length = 4 + slen; uint32_t protolen = (length + 3)/4 * 4; assert(protolen % 4 == 0); return protolen; } + static void brei_context_destroy(struct brei_context *ctx) { @@ -183,7 +178,7 @@ brei_log_msg(struct brei_context *brei, static struct brei_result * brei_demarshal(struct brei_context *brei, struct iobuf *buf, const char *signature, - size_t *nargs_out, union brei_arg **args_out) + size_t *nargs_out, union brei_arg **args_out, char ***strings_out) { size_t nargs = strlen(signature); if (nargs > 256) { @@ -193,11 +188,15 @@ brei_demarshal(struct brei_context *brei, struct iobuf *buf, const char *signatu /* This over-allocates if we have more than one char per type but meh */ _cleanup_free_ union brei_arg *args = xalloc(nargs * sizeof(*args)); + /* This over-allocates since not all args are strings but meh. + Needs to be NULL-terminated for strv_freep to work */ + _cleanup_(strv_freep) char **strings = xalloc((nargs + 1) * sizeof(*strings)); const char *s = signature; union brei_arg *arg = args; uint32_t *p = (uint32_t*)iobuf_data(buf); uint32_t *end = (uint32_t*)iobuf_data_end(buf); + size_t nstrings = 0; nargs = 0; while (*s) { @@ -223,24 +222,29 @@ brei_demarshal(struct brei_context *brei, struct iobuf *buf, const char *signatu arg->h = iobuf_take_fd(buf); break; case 's': { - struct brei_string *s = (struct brei_string *)p; - + uint32_t slen = *p; uint32_t remaining = end - p; - uint32_t protolen = brei_string_proto_length(s->len); /* in bytes */ + uint32_t protolen = brei_string_proto_length(slen); /* in bytes */ uint32_t len32 = protolen/4; /* p and end are uint32_t* */ if (remaining < len32) { return brei_result_new(BREI_CONNECTION_DISCONNECT_REASON_PROTOCOL, - "Invalid string length %u, only %u bytes remaining", s->len, remaining * 4); + "Invalid string length %u, only %u bytes remaining", slen, remaining * 4); } - if (s->len == 0) { + + if (slen == 0) { arg->s = NULL; - } else if (s->str[s->len - 1] != '\0') { - return brei_result_new(BREI_CONNECTION_DISCONNECT_REASON_PROTOCOL, - "Message string not zero-terminated"); } else { - arg->s = s->str; + _cleanup_free_ char *str = xalloc(slen); + memcpy(str, p + 1, slen); + if (str[slen - 1] != '\0') { + return brei_result_new(BREI_CONNECTION_DISCONNECT_REASON_PROTOCOL, + "Message string not zero-terminated"); + } + strings[nstrings] = steal(&str); + arg->s = strings[nstrings]; + nstrings++; } p += len32; break; @@ -255,6 +259,7 @@ brei_demarshal(struct brei_context *brei, struct iobuf *buf, const char *signatu } *args_out = steal(&args); + *strings_out = steal(&strings); *nargs_out = nargs; return NULL; @@ -412,10 +417,11 @@ brei_dispatch(struct brei_context *brei, iobuf_pop(buf, headersize); /* Demarshal the protocol into a set of arguments */ - _cleanup_free_ union brei_arg * args = NULL; + _cleanup_free_ union brei_arg *args = NULL; + _cleanup_(strv_freep) char **strings = NULL; const char *signature = interface->incoming[opcode].signature; size_t nargs = 0; - result = brei_demarshal(brei, buf, signature, &nargs, &args); + result = brei_demarshal(brei, buf, signature, &nargs, &args, &strings); if (result) goto error; @@ -491,8 +497,9 @@ MUNIT_TEST(test_brei_marshal) } _cleanup_free_ union brei_arg *args = NULL; + _cleanup_(strv_freep) char **strings = NULL; size_t nargs = 0; - _unref_(brei_result) *result = brei_demarshal(brei, buf, "noiusf", &nargs, &args); + _unref_(brei_result) *result = brei_demarshal(brei, buf, "noiusf", &nargs, &args, &strings); munit_assert_ptr_null(result); munit_assert_int(nargs, ==, 6); @@ -503,6 +510,10 @@ MUNIT_TEST(test_brei_marshal) munit_assert_string_equal(args[4].s, str); munit_assert_double_equal(args[5].f, 1.45, 3 /* precision */); + /* make sure strings is filled in as expected and null-terminated */ + munit_assert_ptr_equal(args[4].s, strings[0]); + munit_assert_ptr_null(strings[1]); + return MUNIT_OK; } @@ -521,8 +532,9 @@ MUNIT_TEST(test_brei_marshal_bad_sig) } _cleanup_free_ union brei_arg *args = NULL; + _cleanup_(strv_freep) char **strings = NULL; size_t nargs = 789; - _unref_(brei_result) *result = brei_demarshal(brei, buf, "nxoiusf", &nargs, &args); + _unref_(brei_result) *result = brei_demarshal(brei, buf, "nxoiusf", &nargs, &args, &strings); munit_assert_ptr_not_null(result); munit_assert_int(brei_result_get_reason(result), ==, BREI_CONNECTION_DISCONNECT_REASON_PROTOCOL); @@ -672,10 +684,13 @@ MUNIT_TEST(test_brei_send_message) munit_assert_int(buf[4], ==, -42); - const struct brei_string *s = (const struct brei_string *)&buf[5]; - munit_assert_int(s->len, ==, strlen0(string)); - munit_assert_string_equal(s->str, string); - munit_assert_int(memcmp(s->str, string, brei_string_proto_length(s->len) - 4), ==, 0); + uint32_t slen = buf[5]; + munit_assert_int(slen, ==, strlen0(string)); + char protostring[sizeof(string)] = {0}; + assert(brei_string_proto_length(slen) - 4 == sizeof(protostring)); + memcpy(protostring, &buf[6], brei_string_proto_length(slen) - 4); + munit_assert_string_equal(protostring, string); + munit_assert_int(memcmp(protostring, string, brei_string_proto_length(slen) - 4), ==, 0); munit_assert_int(buf[6 + string_len/4], ==, 0xab); munit_assert_int(buf[8 + string_len/4], ==, 0xcdef); @@ -703,15 +718,21 @@ MUNIT_TEST(test_brei_send_message) munit_assert_int(header->msglen, ==, msglen); munit_assert_int(header->opcode, ==, opcode); - const struct brei_string *s1 = (const struct brei_string*)&buf[4]; - munit_assert_int(s1->len, ==, strlen0(string1)); - munit_assert_string_equal(s1->str, string1); - munit_assert_int(memcmp(s1->str, string1, brei_string_proto_length(s1->len) - 4), ==, 0); + uint32_t s1len = buf[4]; + munit_assert_int(s1len, ==, strlen0(string1)); + char protostring1[sizeof(string1)] = {0}; + assert(brei_string_proto_length(s1len) - 4 == sizeof(protostring1)); + memcpy(protostring1, &buf[5], brei_string_proto_length(s1len) - 4); + munit_assert_string_equal(protostring1, string1); + munit_assert_int(memcmp(protostring1, string1, brei_string_proto_length(s1len) - 4), ==, 0); - const struct brei_string *s2 = (const struct brei_string *)&buf[8]; - munit_assert_int(s2->len, ==, strlen0(string2)); - munit_assert_string_equal(s2->str, string2); - munit_assert_int(memcmp(s2->str, string2, brei_string_proto_length(s2->len) - 4), ==, 0); + uint32_t s2len = buf[8]; + munit_assert_int(s2len, ==, strlen0(string2)); + char protostring2[sizeof(string2)] = {0}; + assert(brei_string_proto_length(s2len) - 4 == sizeof(protostring2)); + memcpy(protostring2, &buf[9], brei_string_proto_length(s2len) - 4); + munit_assert_string_equal(protostring2, string2); + munit_assert_int(memcmp(protostring2, string2, brei_string_proto_length(s2len) - 4), ==, 0); } {