util: add sending and receiving fds to the io utilities

Signed-off-by: Peter Hutterer <peter.hutterer@who-t.net>
This commit is contained in:
Peter Hutterer 2020-08-21 11:49:53 +10:00
parent 9e42b579d9
commit 56ca4b4ac7
2 changed files with 152 additions and 0 deletions

View file

@ -28,10 +28,13 @@
#include <assert.h>
#include <errno.h>
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <sys/socket.h>
#include "util-mem.h"
/**
* Wrapper to convert an errno-setting syscall into a
* value-or-negative-errno.
@ -53,6 +56,12 @@ xclose(int fd) {
return -1;
}
DEFINE_TRIVIAL_CLEANUP_FUNC(int, close);
#define _cleanup_close_ _cleanup_(closep)
DEFINE_TRIVIAL_CLEANUP_FUNC(FILE *, fclose);
#define _cleanup_fclose_ _cleanup_(fclosep)
/**
* Wrapper around read(). Returns the number of bytes read or a negative
* errno on failure.
@ -63,6 +72,58 @@ xread(int fd, void *buf, size_t count)
return xerrno(read(fd, buf, count));
}
/**
* Wrapper around read(). Returns the number of bytes read or a negative
* errno on failure. Any fds passed along with the message
* are stored in the -1-terminated allocated fds array, to be freed by the
* caller. Where no fds were passed, the array is NULL.
*/
static inline int
xread_with_fds(int fd, void *buf, size_t count, int **fds)
{
const size_t MAX_FDS = 32;
union {
struct cmsghdr header;
char control[CMSG_SPACE(MAX_FDS * sizeof(int))];
} ctrl;
struct iovec iov = {
.iov_base = buf,
.iov_len = count,
};
struct msghdr msg = {
.msg_name = NULL,
.msg_namelen = 0,
.msg_iov = &iov,
.msg_iovlen = 1,
.msg_control = ctrl.control,
.msg_controllen = sizeof(ctrl.control),
};
int received = xerrno(recvmsg(fd, &msg, 0));
if (received > 0) {
*fds = NULL;
_cleanup_free_ int *fd_return = calloc(MAX_FDS + 1, sizeof(int));
size_t idx = 0;
for (struct cmsghdr *hdr = CMSG_FIRSTHDR(&msg); hdr; hdr = CMSG_NXTHDR(&msg, hdr)) {
if (hdr->cmsg_level != SOL_SOCKET ||
hdr->cmsg_type != SCM_RIGHTS)
continue;
int *fd = (int *)CMSG_DATA(hdr);
fd_return[idx++] = *fd;
if (idx >= MAX_FDS)
break;
}
fd_return[idx] = -1;
*fds = steal(&fd_return);
}
return received;
}
/**
* Wrapper around write(). Returns the number of bytes written or a negative
* errno on failure.
@ -83,6 +144,51 @@ xsend(int fd, const void *buf, size_t len)
return xerrno(send(fd, buf, len, MSG_NOSIGNAL));
}
/**
* Wrapper around send() that always sets MSG_NOSIGNAL and allows appending
* file descriptors to the message.
*
* @param fds Array of file descriptors, terminated by -1.
*/
static inline int
xsend_with_fd(int fd, const void *buf, size_t len, int *fds)
{
size_t nfds = 0;
for (nfds = 0; fds != NULL && fds[nfds] != -1; nfds++) {
/* noop */
}
if (nfds == 0)
return xsend(fd, buf, len);
union {
struct cmsghdr header;
char control[CMSG_SPACE(nfds * sizeof(int))];
} ctrl;
struct iovec iov = {
.iov_base = (void*)buf,
.iov_len = len,
};
struct msghdr msg = {
.msg_name = NULL,
.msg_namelen = 0,
.msg_iov = &iov,
.msg_iovlen = 1,
.msg_control = ctrl.control,
.msg_controllen = sizeof(ctrl.control),
};
ctrl.header.cmsg_len = CMSG_LEN(nfds * sizeof(int));
ctrl.header.cmsg_level = SOL_SOCKET;
ctrl.header.cmsg_type = SCM_RIGHTS;
memcpy(CMSG_DATA(CMSG_FIRSTHDR(&msg)), fds, nfds * sizeof(int));
return xerrno(sendmsg(fd, &msg, MSG_NOSIGNAL));
}
/* consider this struct opaque */
struct iobuf {
size_t sz;

View file

@ -23,6 +23,7 @@
#include "config.h"
#include <stdio.h>
#include <sys/types.h>
#include <sys/socket.h>
@ -165,6 +166,51 @@ MUNIT_TEST(test_iobuf_append_fd)
return MUNIT_OK;
}
MUNIT_TEST(test_pass_fd)
{
int fds[2];
int rc = socketpair(AF_UNIX, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0, fds);
munit_assert_int(rc, ==, 0);
_cleanup_close_ int left = fds[0];
_cleanup_close_ int right = fds[1];
_cleanup_fclose_ FILE *fp = tmpfile();
/* actual message data to be sent */
char data[] = "some data\n";
/* Send the fd from left to right */
int sendfds[2] = { fileno(fp), -1 };
int sendrc = xsend_with_fd(left, data, sizeof(data), sendfds);
munit_assert_int(sendrc, ==, sizeof(data));
/* Write some data to the file on it's real fd */
char buf[] = "foo\n";
fwrite(buf, sizeof(buf), 1, fp);
fflush(fp);
/* Receive the fd on the right */
_cleanup_free_ int *recvfds = NULL;
char recvbuf[sizeof(data)];
int recvrc = xread_with_fds(right, recvbuf, sizeof(recvbuf), &recvfds);
munit_assert_int(recvrc, ==, sizeof(data));
munit_assert_string_equal(recvbuf, data);
munit_assert_ptr_not_null(recvfds);
munit_assert_int(recvfds[0], !=, -1);
munit_assert_int(recvfds[1], ==, -1);
/* Now check that we can read "foo" from the passed fd */
_cleanup_close_ int passed_fd = recvfds[0];
off_t off = lseek(passed_fd, 0, SEEK_SET);
munit_assert_int(off, ==, 0);
char readbuf[64];
int readrc = xread(passed_fd, readbuf, sizeof(readbuf));
munit_assert_int(readrc, ==, sizeof(buf));
munit_assert_string_equal(readbuf, buf);
return MUNIT_OK;
}
int
main(int argc, char* argv[MUNIT_ARRAY_PARAM(argc + 1)])
{