util: fix receiving of multiple fds

Initial implementation only handled this correctly if the fds were over
multiple message headers, not multiple fds in the same message header.

Signed-off-by: Peter Hutterer <peter.hutterer@who-t.net>
This commit is contained in:
Peter Hutterer 2020-08-24 11:01:01 +10:00
parent 65c069e6a0
commit e238a5d049
2 changed files with 44 additions and 17 deletions

View file

@ -114,10 +114,14 @@ xread_with_fds(int fd, void *buf, size_t count, int **fds)
hdr->cmsg_type != SCM_RIGHTS)
continue;
size_t nfds = (hdr->cmsg_len - CMSG_LEN(0)) / sizeof (int);
int *fd = (int *)CMSG_DATA(hdr);
fd_return[idx++] = *fd;
if (idx >= MAX_FDS)
break;
for (size_t i = 0; i < nfds; i++) {
fd_return[idx++] = *fd;
fd++;
if (idx >= MAX_FDS)
break;
}
}
fd_return[idx] = -1;
*fds = steal(&fd_return);

View file

@ -28,6 +28,7 @@
#include <sys/socket.h>
#include "util-io.h"
#include "util-strings.h"
#include "util-munit.h"
MUNIT_TEST(test_iobuf_new)
@ -202,20 +203,31 @@ MUNIT_TEST(test_pass_fd)
_cleanup_close_ int left = fds[0];
_cleanup_close_ int right = fds[1];
_cleanup_fclose_ FILE *fp = tmpfile();
FILE *fps[4];
int sendfds[ARRAY_LENGTH(fps) + 1];
for (size_t idx = 0; idx < ARRAY_LENGTH(fps); idx++) {
FILE *fp = tmpfile();
munit_assert_not_null(fp);
fps[idx] = fp;
sendfds[idx] = fileno(fp);
sendfds[idx + 1] = -1;
}
/* actual message data to be sent */
char data[] = "some data\n";
/* Send the fd from left to right */
int sendfds[2] = { fileno(fp), -1 };
int sendrc = xsend_with_fd(left, data, sizeof(data), sendfds);
munit_assert_int(sendrc, ==, sizeof(data));
/* Write some data to the file on it's real fd */
char buf[] = "foo\n";
fwrite(buf, sizeof(buf), 1, fp);
fflush(fp);
for (size_t idx = 0; idx < ARRAY_LENGTH(fps); idx++) {
_cleanup_free_ char *buf = xaprintf("foo %zd\n", idx);
FILE *fp = fps[idx];
fwrite(buf, strlen(buf) + 1, 1, fp);
fflush(fp);
}
/* Receive the fd on the right */
_cleanup_free_ int *recvfds = NULL;
@ -225,16 +237,27 @@ MUNIT_TEST(test_pass_fd)
munit_assert_string_equal(recvbuf, data);
munit_assert_ptr_not_null(recvfds);
munit_assert_int(recvfds[0], !=, -1);
munit_assert_int(recvfds[1], ==, -1);
munit_assert_int(recvfds[1], !=, -1);
munit_assert_int(recvfds[2], !=, -1);
munit_assert_int(recvfds[3], !=, -1);
munit_assert_int(recvfds[4], ==, -1);
/* Now check that we can read "foo" from the passed fd */
_cleanup_close_ int passed_fd = recvfds[0];
off_t off = lseek(passed_fd, 0, SEEK_SET);
munit_assert_int(off, ==, 0);
char readbuf[64];
int readrc = xread(passed_fd, readbuf, sizeof(readbuf));
munit_assert_int(readrc, ==, sizeof(buf));
munit_assert_string_equal(readbuf, buf);
/* Now check that we can read "foo N" from the passed fd */
for (size_t idx = 0; idx < ARRAY_LENGTH(fps); idx++) {
_cleanup_close_ int passed_fd = recvfds[idx];
off_t off = lseek(passed_fd, 0, SEEK_SET);
munit_assert_int(off, ==, 0);
char readbuf[64];
int readrc = xread(passed_fd, readbuf, sizeof(readbuf));
_cleanup_free_ char *expected = xaprintf("foo %zd\n", idx);
munit_assert_int(readrc, ==, strlen(expected) + 1);
munit_assert_string_equal(readbuf, expected);
/* cleanup */
FILE *fp = fps[idx];
fclose(fp);
}
return MUNIT_OK;
}