libei/src/brei-shared.c
Peter Hutterer 46681e2855 Add SPDX identifiers to all source files
Signed-off-by: Peter Hutterer <peter.hutterer@who-t.net>
2022-03-03 00:27:36 +00:00

271 lines
6.8 KiB
C

/* SPDX-License-Identifier: MIT */
/*
* Copyright © 2020 Red Hat, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the "Software"),
* to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense,
* and/or sell copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice (including the next
* paragraph) shall be included in all copies or substantial portions of the
* Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
* DEALINGS IN THE SOFTWARE.
*/
#include "config.h"
#include <stdbool.h>
#include "util-mem.h"
#include "util-io.h"
#include "proto/ei.pb-c.h"
#include "brei-shared.h"
/*
 * Internal wrapper around the public struct brei_message.
 * brei_dispatch() passes &base to its callback and brei_message_take_fd()
 * casts that pointer back to this type, so base must stay the first member.
 */
struct brei_message_private {
struct brei_message base;
/* Points at brei_dispatch()'s current received-fd slot; reset to -1 by
 * brei_message_take_fd() once the fd has been taken. */
int *fd;
};
/**
 * Take the file descriptor attached to a message.
 *
 * m must be the base member of a struct brei_message_private, i.e. a
 * message handed to the callback by brei_dispatch(). After this call the
 * message's fd slot holds -1. brei_dispatch() will therefore no longer
 * close it, and will fetch the next received fd for the following message.
 *
 * @return The fd stored in the slot, or -1 if the slot was empty
 */
int
brei_message_take_fd(struct brei_message *m)
{
	struct brei_message_private *priv = (struct brei_message_private *)m;
	int *slot = priv->fd;
	const int taken = *slot;

	*slot = -1;
	return taken;
}
/* Cleanup handler behind _cleanup_packet_: frees the Packet returned by
 * packet__unpack(), if there is one. */
static void
packet_cleanup(Packet **pp)
{
	Packet *unpacked = *pp;

	if (unpacked == NULL)
		return;
	packet__free_unpacked(unpacked, NULL);
}
#define _cleanup_packet_ _cleanup_(packet_cleanup)
/**
 * The BREI (i.e. the EI protocol) prefixes every message with a
 * fixed-length Packet header that carries the length of the message that
 * follows. This helper decodes the header at data and returns the
 * position of the message payload that follows it.
 *
 * The header is always decoded from a fixed number of bytes starting at
 * data (the packed size of a Packet). The caller must make sure that many
 * bytes are readable.
 *
 * @param data Start of the header to decode
 * @param[out] msglen Set to the length in bytes of the message payload
 * announced by the header
 * @param[out] consumed Set to the number of header bytes that precede the
 * payload
 * @return The position of the message payload (data + header size), or
 * NULL if the header could not be decoded. On NULL, msglen and consumed
 * are left untouched.
 */
static const char *
brei_next_message(const char *data, size_t *msglen, size_t *consumed)
{
	/* Every message is prefixed by a fixed-length packet message which
	 * contains the length of the next message. Packets are always the
	 * same length, so we only need to calculate the size once.
	 */
	static size_t packetlen = 0;
	if (packetlen == 0) {
		Packet f = PACKET__INIT;
		f.length = 0xffff;
		packetlen = packet__get_packed_size(&f);
		assert(packetlen >= 5);
	}

	_cleanup_packet_ Packet *packet =
		packet__unpack(NULL, packetlen, (const unsigned char *)data);
	if (!packet)
		return NULL; /* out-params deliberately not written */

	*msglen = packet->length;
	*consumed = packetlen;
	return data + packetlen;
}
/**
 * Read whatever is currently available on fd and hand each framed message
 * in it to callback, one at a time.
 *
 * Messages are not buffered across calls. A message whose header or
 * payload is not completely contained in the data read here is treated as
 * malformed.
 *
 * @param fd The file descriptor to read from
 * @param callback Invoked once per message. Must return the number of
 * payload bytes it consumed (greater than 0 and at most msg->len), or a
 * negative errno to abort dispatching.
 * @param user_data Passed through to callback unchanged
 *
 * @return 0 on success or when no data was available (-EAGAIN from the
 * read); -ECANCELED when the read returned 0 bytes; -EBADMSG when the data
 * is not a well-formed sequence of messages or the callback reports an
 * impossible byte count; otherwise the negative error from the read or
 * the callback.
 */
int
brei_dispatch(int fd,
	      int (*callback)(struct brei_message *m, void *user_data),
	      void *user_data)
{
	_cleanup_iobuf_ struct iobuf *buf = iobuf_new(64);
	int rc = iobuf_recv_from_fd(buf, fd);
	if (rc == -EAGAIN)
		return 0;
	if (rc == 0)
		return -ECANCELED;
	if (rc < 0)
		return rc;

	_cleanup_close_ int recvfd = -1;
	size_t idx = 0;

	/* The checks below keep idx <= iobuf_len(buf), so this ends exactly
	 * when every received byte has been dispatched. */
	while (idx < iobuf_len(buf)) {
		const char *data = iobuf_data(buf) + idx;
		size_t len = iobuf_len(buf) - idx;
		size_t headerbytes = 0;
		size_t msglen = 0;

		/* NOTE(review): brei_next_message() decodes a fixed-size header
		 * without being told len, so fewer than one header's worth of
		 * trailing bytes is still over-read. Fixing that needs a length
		 * argument on brei_next_message(). */
		const char *msgdata = brei_next_message(data, &msglen, &headerbytes);
		if (!msgdata)
			return -EBADMSG; /* undecodable header */

		/* The header and its complete payload must both be in the buffer.
		 * The data comes from the remote end, so this is an error return,
		 * not an assert. */
		if (headerbytes > len || msglen > len - headerbytes)
			return -EBADMSG;

		/* This is a bit messy because it's just blu tacked on.
		 * Our protocol passes maximum of one fd per message. We
		 * take whatever next fd is and pass it along. Where the
		 * parser takes it (brei_message_take_fd()) it gets set to
		 * -1 and we take the next fd for the next message.
		 */
		if (recvfd == -1)
			recvfd = iobuf_take_fd(buf);

		struct brei_message_private msg = {
			.base.data = msgdata,
			.base.len = msglen,
			.fd = &recvfd,
		};

		/* Actual message parsing is done by the caller */
		int consumed = callback(&msg.base, user_data);
		if (consumed < 0)
			return consumed;

		/* 0 would never advance (infinite loop); more than msglen would
		 * move idx past the end of the buffer. */
		assert(consumed != 0);
		if (consumed == 0 || (size_t)consumed > msglen)
			return -EBADMSG;

		idx += headerbytes + (size_t)consumed;
	}

	return 0;
}
void
brei_drain_fd(int fd)
{
_cleanup_iobuf_ struct iobuf *buf = iobuf_new(1024);
int rc;
while ((rc = iobuf_recv_from_fd(buf, fd)) > 0)
;;
}
#ifdef _enable_tests_
#include "src/util-munit.h"
MUNIT_TEST(test_proto_next_message)
{
	/* Invalid packet: first byte 0xaa, the rest is 0xab filler. The
	 * buffer is unsigned char so 0xaa is stored without an
	 * implementation-defined conversion to a (possibly signed) char. */
	unsigned char data[64];
	memset(data, 0xab, sizeof(data));
	data[0] = 0xaa;

	size_t msglen = 0xab;
	size_t consumed = 0xbc;
	const char *rval = brei_next_message((const char *)data, &msglen, &consumed);
	munit_assert_ptr_null(rval);
	/* On failure the out-parameters must be left untouched */
	munit_assert_size(msglen, ==, 0xab);
	munit_assert_size(consumed, ==, 0xbc);

	/* Now try four valid headers packed back to back */
	Packet f = PACKET__INIT;
	f.length = 0xcd;
	size_t packetlen = packet__get_packed_size(&f);
	unsigned char buf[packetlen * 4];
	for (int i = 0; i < 4; i++)
		packet__pack(&f, buf + packetlen * i);

	const char *ptr = (const char *)buf;
	for (int i = 0; i < 4; i++) {
		const char *next_packet = brei_next_message(ptr, &msglen, &consumed);
		munit_assert_ptr_equal(next_packet, buf + (i + 1) * packetlen);
		munit_assert_size(consumed, ==, packetlen);
		munit_assert_size(msglen, ==, 0xcd);
		ptr += consumed;
	}

	return MUNIT_OK;
}
/* Test callback for brei_dispatch(): copies the whole payload into the
 * caller-supplied buffer (user_data) and reports all of it as consumed. */
static int
brei_dispatch_cb(struct brei_message *msg, void *user_data)
{
	memcpy((char *)user_data, msg->data, msg->len);

	return msg->len;
}
static inline void
send_data(int fd, const char *data, size_t data_size)
{
Packet f = PACKET__INIT;
f.length = 1;
size_t packetlen = packet__get_packed_size(&f);
/* note: data is null-terminated, we copy all of it but only use
datalen to check truncation works */
unsigned char buf[1024] = {0};
f.length = data_size;
packet__pack(&f, buf);
memcpy(buf + packetlen, data, strlen(data)); /* intentionally strlen */
int rc = xsend(fd, buf, packetlen + data_size);
munit_assert_int(rc, ==, packetlen + data_size);
}
MUNIT_TEST(test_brei_dispatch)
{
int sv[2];
int rc = socketpair(AF_UNIX, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0, sv);
munit_assert_int(rc, ==, 0);
int sock_read = sv[0];
int sock_write = sv[1];
{
/* Packet one: just an 'x' */
char return_buffer[1024] = {0};
send_data(sock_write, "x", 1);
int rc = brei_dispatch(sock_read, brei_dispatch_cb, return_buffer);
munit_assert_int(rc, ==, 0);
munit_assert_string_equal(return_buffer, "x");
}
{
/* Packet two: 'foobar' */
char return_buffer[1024] = {0};
send_data(sock_write, "foobar", 6);
int rc = brei_dispatch(sock_read, brei_dispatch_cb, return_buffer);
munit_assert_int(rc, ==, 0);
munit_assert_string_equal(return_buffer, "foobar");
}
{
/* Packet three: 'foobar' but last char truncated */
char return_buffer[1024] = {0};
send_data(sock_write, "foobar", 5); /* truncated */
int rc = brei_dispatch(sock_read, brei_dispatch_cb, return_buffer);
munit_assert_int(rc, ==, 0);
munit_assert_string_equal(return_buffer, "fooba");
}
return MUNIT_OK;
}
#endif