brei: abstract string handling in a struct

Makes the code a bit easier to read and understand than the previous
offset mangling.
This commit is contained in:
Peter Hutterer 2023-02-28 19:55:18 +10:00
parent d686d4e281
commit d4980d8579

View file

@ -41,6 +41,26 @@ struct brei_context {
void *log_context;
};
struct brei_string {
uint32_t len;
const char str[];
};
static_assert(sizeof(struct brei_string) == 4, "Unexpected size for brei_string struct");
/**
* For a given string length (including null byte) return
* the number of bytes needed on the protocol, including the
* 4-byte length field.
*/
static inline uint32_t
brei_string_proto_length(uint32_t slen)
{
uint32_t length = sizeof(struct brei_string) + slen;
uint32_t protolen = (length + 3)/4 * 4;
assert(protolen % 4 == 0);
return protolen;
}
static void
brei_context_destroy(struct brei_context *ctx)
{
@ -153,15 +173,6 @@ brei_log_msg(struct brei_context *brei,
va_end(args);
}
/**
* Return the number of int32s required to store count bytes.
*/
static inline uint32_t
bytes_to_int32(uint32_t count)
{
return (uint32_t)(((uint64_t)count + 3)/4);
}
static struct brei_result *
brei_demarshal(struct brei_context *brei, struct iobuf *buf, const char *signature,
size_t *nargs_out, union brei_arg **args_out)
@ -204,28 +215,25 @@ brei_demarshal(struct brei_context *brei, struct iobuf *buf, const char *signatu
arg->h = iobuf_take_fd(buf);
break;
case 's': {
size_t slen = *p++; /* string length includes \0 */
if (slen == 0) {
arg->s = NULL;
break;
}
struct brei_string *s = (struct brei_string *)p;
uint32_t slen32 = bytes_to_int32(slen);
if (end - p < slen32) {
uint32_t protolen = brei_string_proto_length(s->len); /* in bytes */
uint32_t len32 = protolen/4; /* p and end are uint32_t* */
if (end - p < len32) {
return brei_result_new(BREI_CONNECTION_DISCONNECT_REASON_PROTOCOL,
"Invalid string length %zu, only %li bytes remaining", slen, (end - p) * 4);
"Invalid string length %u, only %li bytes remaining", s->len, (end - p) * 4);
}
const char *str = (char*)p;
/* strings must be null-terminated */
if (slen && str[slen - 1] != '\0') {
if (s->len == 0) {
arg->s = NULL;
} else if (s->str[s->len - 1] != '\0') {
return brei_result_new(BREI_CONNECTION_DISCONNECT_REASON_PROTOCOL,
"Message string not zero-terminated");
} else {
arg->s = s->str;
}
arg->s = str;
p += slen32;
p += len32;
break;
}
default:
@ -432,6 +440,19 @@ brei_drain_fd(int fd)
#ifdef _enable_tests_
#include "util-munit.h"
MUNIT_TEST(test_brei_string_proto_length)
{
munit_assert_int(brei_string_proto_length(0), ==, 4);
munit_assert_int(brei_string_proto_length(1), ==, 8);
munit_assert_int(brei_string_proto_length(4), ==, 8);
munit_assert_int(brei_string_proto_length(5), ==, 12);
munit_assert_int(brei_string_proto_length(8), ==, 12);
munit_assert_int(brei_string_proto_length(12), ==, 16);
munit_assert_int(brei_string_proto_length(13), ==, 20);
return MUNIT_OK;
}
static struct brei_result *
brei_marshal_va(struct brei_context *brei, struct iobuf *buf, const char *signature, size_t nargs, ...)
{
@ -523,6 +544,15 @@ brei_send_message_va(int fd, object_id_t id, uint32_t opcode,
return rc;
}
/**
* Return the number of int32s required to store count bytes.
*/
static inline uint32_t
bytes_to_int32(uint32_t count)
{
return (uint32_t)(((uint64_t)count + 3)/4);
}
MUNIT_TEST(test_brei_send_message)
{
int sv[2];
@ -574,7 +604,7 @@ MUNIT_TEST(test_brei_send_message)
{
const char string[12] = "hello wor"; /* tests padding too */
int slen = bytes_to_int32(strlen(string) + 1) * 4;
int slen = bytes_to_int32(strlen0(string)) * 4;
munit_assert_int(slen, ==, sizeof(string));
const int msglen = header_size + 16 + slen; /* 3 * 4 bytes + 4 bytes slen + string length */
@ -591,9 +621,11 @@ MUNIT_TEST(test_brei_send_message)
munit_assert_int(buf[0], ==, id);
munit_assert_int(buf[1], ==, msglen << 16 | opcode);
munit_assert_int(buf[2], ==, -42);
munit_assert_int(buf[3], ==, strlen(string) + 1);
munit_assert_string_equal((const char*)&buf[4], string);
munit_assert_int(memcmp(&buf[4], string, slen), ==, 0);
const struct brei_string *s = (const struct brei_string *)&buf[3];
munit_assert_int(s->len, ==, strlen0(string));
munit_assert_string_equal(s->str, string);
munit_assert_int(memcmp(s->str, string, brei_string_proto_length(s->len) - 4), ==, 0);
munit_assert_int(buf[4 + slen/4], ==, 0xab);
munit_assert_int(buf[5 + slen/4], ==, 0xcdef);
@ -615,10 +647,16 @@ MUNIT_TEST(test_brei_send_message)
munit_assert_int(len, ==, msglen);
munit_assert_int(buf[0], ==, id);
munit_assert_int(buf[1], ==, msglen << 16 | opcode);
munit_assert_int(buf[2], ==, strlen(string1) + 1);
munit_assert_string_equal((const char*)&buf[3], string1);
munit_assert_int(buf[6], ==, strlen(string2) + 1);
munit_assert_string_equal((const char*)&buf[7], string2);
const struct brei_string *s1 = (const struct brei_string *)&buf[2];
munit_assert_int(s1->len, ==, strlen0(string1));
munit_assert_string_equal(s1->str, string1);
munit_assert_int(memcmp(s1->str, string1, brei_string_proto_length(s1->len) - 4), ==, 0);
const struct brei_string *s2 = (const struct brei_string *)&buf[6];
munit_assert_int(s2->len, ==, strlen0(string2));
munit_assert_string_equal(s2->str, string2);
munit_assert_int(memcmp(s2->str, string2, brei_string_proto_length(s2->len) - 4), ==, 0);
}
{