diff --git a/src/brei-shared.c b/src/brei-shared.c index 6846bbf..fa1fccd 100644 --- a/src/brei-shared.c +++ b/src/brei-shared.c @@ -396,8 +396,14 @@ brei_dispatch(struct brei_context *brei, const struct brei_header *header = (const struct brei_header *)data; uint32_t msglen = header->msglen; - if (len < msglen) - break; + /* Max message size: 1MiB, plenty enough for the current protocol */ + static const uint32_t max_msglen = 1024 * 1024; + if (msglen < headersize || msglen > max_msglen) { + result = brei_result_new(BREI_CONNECTION_DISCONNECT_REASON_PROTOCOL, + "invalid message length %u", + msglen); + goto error; + } object_id_t object_id = header->sender_id; uint32_t opcode = header->opcode; diff --git a/src/brei-shared.h b/src/brei-shared.h index 6f25d43..d28c505 100644 --- a/src/brei-shared.h +++ b/src/brei-shared.h @@ -26,6 +26,10 @@ #include "config.h" +#ifdef NDEBUG +#warning "This project relies on assert(). #defining NDEBUG is not recommended" +#endif + #include #include #include diff --git a/src/libeis-fd.c b/src/libeis-fd.c index 6ed151e..dcccf0b 100644 --- a/src/libeis-fd.c +++ b/src/libeis-fd.c @@ -28,6 +28,7 @@ #include #include +#include "util-io.h" #include "util-macros.h" #include "util-mem.h" #include "util-sources.h" @@ -83,8 +84,11 @@ eis_backend_fd_add_client(struct eis *eis) return -errno; struct eis_client *client = eis_client_new(eis, fds[0]); - if (client == NULL) + if (client == NULL) { + xclose(fds[0]); + xclose(fds[1]); return -ENOMEM; + } eis_client_unref(client); diff --git a/src/libeis-socket.c b/src/libeis-socket.c index 671b3d5..5ce918b 100644 --- a/src/libeis-socket.c +++ b/src/libeis-socket.c @@ -115,6 +115,10 @@ listener_dispatch(struct source *source, void *data) return; struct eis_client *client = eis_client_new(eis, fd); + if (client == NULL) { + xclose(fd); + return; + } eis_client_unref(client); } @@ -143,11 +147,15 @@ eis_setup_backend_socket(struct eis *eis, const char *socketpath) * socket file. */ _cleanup_free_ char *lockfile = xaprintf("%s.lock", path); _cleanup_close_ int lockfd = - open(lockfile, O_CREAT | O_CLOEXEC | O_RDWR, S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP); - int rc = flock(lockfd, LOCK_EX | LOCK_NB); + xerrno(open(lockfile, O_CREAT | O_CLOEXEC | O_RDWR, S_IRUSR | S_IWUSR)); + int rc; + if (lockfd >= 0) + rc = xerrno(flock(lockfd, LOCK_EX | LOCK_NB)); + else + rc = lockfd; if (rc < 0) { log_error(eis, "Failed to create lockfile %s, is another EIS running?", lockfile); - return -errno; + return -rc; } struct stat st; @@ -177,6 +185,10 @@ eis_setup_backend_socket(struct eis *eis, const char *socketpath) if (bind(sockfd, (struct sockaddr *)&addr, sizeof(addr)) == -1) return -errno; + /* Restrict socket to owner-only access regardless of umask */ + if (fchmod(sockfd, S_IRUSR | S_IWUSR) == -1) + return -errno; + if (listen(sockfd, 2) == -1) return -errno; diff --git a/src/liboeffis.c b/src/liboeffis.c index d516650..b1855d8 100644 --- a/src/liboeffis.c +++ b/src/liboeffis.c @@ -327,7 +327,6 @@ xdp_session_path(char *sender_name, char *token) static int connect_to_eis_returned(sd_bus_message *m, void *userdata, sd_bus_error *ret_error) { - int eisfd; struct oeffis *oeffis = userdata; int rc = sd_bus_message_get_errno(m); @@ -336,27 +335,37 @@ connect_to_eis_returned(sd_bus_message *m, void *userdata, sd_bus_error *ret_err return rc; } - rc = sd_bus_message_read(m, "h", &eisfd); + int sd_eisfd; + rc = sd_bus_message_read(m, "h", &sd_eisfd); if (rc < 0) { oeffis_disconnect(oeffis, "Unable to get fd from portal: %s", strerror(-rc)); return -rc; } /* the fd is owned by the message */ - rc = xerrno(xdup(eisfd)); - if (rc < 0) { - oeffis_disconnect(oeffis, "Failed to dup fd: %s", strerror(-rc)); - return -rc; + _cleanup_close_ int eisfd = xerrno(xdup(sd_eisfd)); + if (eisfd < 0) { + oeffis_disconnect(oeffis, "Failed to dup fd: %s", strerror(-eisfd)); + return -eisfd; } else { - eisfd = rc; - int flags = fcntl(eisfd, F_GETFL, 0); - fcntl(eisfd, F_SETFL, flags | O_NONBLOCK); + int flags = xerrno(fcntl(eisfd, F_GETFL, 0)); + if (flags >= 0) + rc = xerrno(fcntl(eisfd, F_SETFL, flags | O_NONBLOCK)); + else + rc = flags; + + if (rc < 0) { + oeffis_disconnect(oeffis, + "Failed to set the fd to non-blocking: %s", + strerror(-rc)); + return -rc; + } } log_debug("Got fd %d from portal", eisfd); - rc = oeffis_set_eis_fd(oeffis, eisfd); + rc = oeffis_set_eis_fd(oeffis, steal_fd(&eisfd)); if (rc < 0) { oeffis_disconnect(oeffis, "Failed to set the fd: %s", strerror(-rc)); return -rc; diff --git a/src/util-io.c b/src/util-io.c index 7b39b6c..3315ab6 100644 --- a/src/util-io.c +++ b/src/util-io.c @@ -88,6 +88,11 @@ xsend_with_fd(int fd, const void *buf, size_t len, int *fds) if (nfds == 0) return xsend(fd, buf, len); + const size_t MAX_FDS = 32; + if (nfds > MAX_FDS) { + return -EINVAL; + } + char control[CMSG_SPACE(nfds * sizeof(int))]; struct cmsghdr *header = (struct cmsghdr *)control; @@ -126,11 +131,8 @@ struct iobuf { struct iobuf * iobuf_new(size_t size) { - struct iobuf *buf = malloc(sizeof(*buf)); - uint8_t *data = malloc(size); - - assert(buf); - assert(data); + struct iobuf *buf = xalloc(sizeof(*buf)); + uint8_t *data = xalloc(size); *buf = (struct iobuf){ .sz = size, @@ -212,8 +214,7 @@ iobuf_take_fd(struct iobuf *buf) static inline void iobuf_resize(struct iobuf *buf, size_t to_size) { - uint8_t *newdata = realloc(buf->data, to_size); - assert(newdata); + uint8_t *newdata = xrealloc(buf->data, to_size); buf->data = newdata; buf->sz = to_size; @@ -389,6 +390,12 @@ iobuf_recv_from_fd(struct iobuf *buf, int fd) fd++; } } + /* Close any remaining fds that didn't fit in the buffer + * to prevent fd leaks */ + while (*fd != -1) { + xclose(*fd); + fd++; + } } nread += rc; diff --git a/src/util-io.h b/src/util-io.h index 26cf02a..0ea7d63 100644 --- a/src/util-io.h +++ b/src/util-io.h @@ -127,7 +127,7 @@ xerrno(int value) static inline int xclose(int fd) { - if (fd != -1) { + if (fd > -1) { /* Not SYSCALL(), see libei MR!261#note_2131802 */ close(fd); } @@ -223,7 +223,8 @@ xconnect(const char *path) if (!xsnprintf(addr.sun_path, sizeof(addr.sun_path), "%s", path)) return -EINVAL; - int sockfd = xerrno(SYSCALL(socket(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0))); + _cleanup_close_ int sockfd = + xerrno(SYSCALL(socket(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0))); if (sockfd < 0) return sockfd; @@ -231,7 +232,7 @@ xconnect(const char *path) if (rc < 0) return rc; - return sockfd; + return steal_fd(&sockfd); } /** diff --git a/src/util-mem.h b/src/util-mem.h index c02affd..7271f6f 100644 --- a/src/util-mem.h +++ b/src/util-mem.h @@ -96,6 +96,21 @@ _steal(void *ptr) #define steal(ptr_) \ (typeof(*ptr_))_steal(ptr_) +/** + * Resets the pointer content to -1 and returns + * the original value. + */ +static inline int +steal_fd(int *ptr) +{ + if (ptr) { + int original = *ptr; + *ptr = -1; + return original; + } + return -1; +} + /** * Never-failing calloc with a size limit check. */ @@ -117,7 +132,7 @@ xalloc(size_t size) } static inline void * -xrealloc(void *ptr, int size) +xrealloc(void *ptr, size_t size) { void *tmp = realloc(ptr, size); assert(tmp); diff --git a/src/util-memfile.c b/src/util-memfile.c index ea55c46..6f5e334 100644 --- a/src/util-memfile.c +++ b/src/util-memfile.c @@ -63,7 +63,8 @@ memfile_new(const char *data, size_t sz) if (fd < 0) return NULL; - fcntl(fd, F_ADD_SEALS, F_SEAL_SHRINK); + if (fcntl(fd, F_ADD_SEALS, F_SEAL_SHRINK) < 0) + return NULL; int rc; with_signals_blocked(SIGALRM) diff --git a/src/util-memmap.c b/src/util-memmap.c index a106ef2..31445f7 100644 --- a/src/util-memmap.c +++ b/src/util-memmap.c @@ -55,6 +55,9 @@ OBJECT_IMPLEMENT_GETTER(memmap, data, void *); struct memmap * memmap_new(int fd, size_t sz) { + if (sz == 0) + return NULL; + _unref_(memmap) *memmap = memmap_create(NULL); void *map = mmap(NULL, sz, PROT_READ, MAP_PRIVATE, fd, 0); diff --git a/src/util-strings.c b/src/util-strings.c index 656426d..6e6140a 100644 --- a/src/util-strings.c +++ b/src/util-strings.c @@ -150,11 +150,17 @@ strv_join(char **strv, const char *joiner) slen += (count - 1) * strlen(joiner); str = xalloc(slen + 1); /* trailing \0 */ + size_t jlen = strlen(joiner); + size_t offset = 0; for (s = strv; *s; s++) { - strcat(str, *s); + size_t l = strlen(*s); + memcpy(str + offset, *s, l); + offset += l; --count; - if (count > 0) - strcat(str, joiner); + if (count > 0) { + memcpy(str + offset, joiner, jlen); + offset += jlen; + } } return str; @@ -214,9 +220,7 @@ strreplace(const char *string, const char *separator, const char *replacement) destptr += len; } - void *tmp = realloc(r, (destptr - r) + 1); - assert(tmp); - return tmp; + return xrealloc(r, (destptr - r) + 1); } size_t @@ -240,9 +244,7 @@ strv_append_take(char **strv, char **str) size_t len = strv_len(strv) + 1; len = max(len, 2); - char **s = realloc(strv, len * sizeof(*strv)); - if (!s) - abort(); + char **s = xrealloc(strv, len * sizeof(*strv)); s[len - 1] = NULL; s[len - 2] = *str; *str = NULL; @@ -662,4 +664,70 @@ MUNIT_TEST(test_strv_from_mem) return MUNIT_OK; } + +MUNIT_TEST(test_xatou) +{ + unsigned int val; + + munit_assert_true(xatou("0", &val)); + munit_assert_uint(val, ==, 0); + + munit_assert_true(xatou("1", &val)); + munit_assert_uint(val, ==, 1); + + munit_assert_true(xatou("123", &val)); + munit_assert_uint(val, ==, 123); + + /* UINT_MAX is the upper boundary, must succeed */ + munit_assert_true(xatou("4294967295", &val)); + munit_assert_uint(val, ==, UINT_MAX); + + /* UINT_MAX + 1 must fail */ + munit_assert_false(xatou("4294967296", &val)); + + /* Another random value in the range UINT_MAX < val < LONG_MAX */ + munit_assert_false(xatou("8589934592", &val)); + + /* LONG_MAX as string - must fail */ + munit_assert_false(xatou("9223372036854775807", &val)); + + /* Overflow beyond ULONG_MAX - strtoul sets errno, must fail */ + munit_assert_false(xatou("18446744073709551616", &val)); + + /* negative numbers: strtoul wraps "-1" to ULONG_MAX without + * setting errno, but v > UINT_MAX catches it */ + munit_assert_false(xatou("-1", &val)); + + /* invalid strings */ + munit_assert_false(xatou("", &val)); + munit_assert_false(xatou("abc", &val)); + munit_assert_false(xatou("123abc", &val)); + munit_assert_false(xatou("12 34", &val)); + + /* hex via xatou_base */ + munit_assert_true(xatou_base("ff", &val, 16)); + munit_assert_uint(val, ==, 0xff); + + munit_assert_true(xatou_base("0xff", &val, 16)); + munit_assert_uint(val, ==, 0xff); + + munit_assert_true(xatou_base("FFFFFFFF", &val, 16)); + munit_assert_uint(val, ==, UINT_MAX); + + /* hex UINT_MAX + 1 */ + munit_assert_false(xatou_base("100000000", &val, 16)); + + /* octal */ + munit_assert_true(xatou_base("77", &val, 8)); + munit_assert_uint(val, ==, 077); + + munit_assert_true(xatou_base("37777777777", &val, 8)); + munit_assert_uint(val, ==, UINT_MAX); + + /* octal UINT_MAX + 1 */ + munit_assert_false(xatou_base("40000000000", &val, 8)); + + return MUNIT_OK; +} + #endif diff --git a/src/util-strings.h b/src/util-strings.h index 74b7f38..e62df4c 100644 --- a/src/util-strings.h +++ b/src/util-strings.h @@ -183,7 +183,7 @@ xatou_base(const char *str, unsigned int *val, int base) if (*str != '\0' && *endptr != '\0') return false; - if ((long)v < 0) + if (v > UINT_MAX) return false; *val = v; @@ -441,11 +441,11 @@ cmdline_as_str(void) if (sysctl(mib, ARRAY_LENGTH(mib), NULL, &len, NULL, 0)) return NULL; - char *const procargs = malloc(len); + _cleanup_free_ char *procargs = xalloc(len); if (sysctl(mib, ARRAY_LENGTH(mib), procargs, &len, NULL, 0)) return NULL; - return procargs; + return steal(&procargs); #else int fd = open("/proc/self/cmdline", O_RDONLY); if (fd != -1) { diff --git a/test/test_protocol.py b/test/test_protocol.py index 60420b3..1d6d61b 100644 --- a/test/test_protocol.py +++ b/test/test_protocol.py @@ -33,6 +33,7 @@ import time import shlex import signal import socket +import struct import structlog try: @@ -1359,3 +1360,60 @@ class TestEiProtocol: ei.wait_for(lambda: status.disconnected) assert status.disconnected is True + + @pytest.mark.parametrize( + "invalid_msglen", + (0, 1, 15, 1024 * 1024 + 1, 0xFFFFFFFF), + ids=("zero", "one", "header-minus-one", "over-max", "uint32-max"), + ) + def test_invalid_message_length(self, eis, invalid_msglen): + """ + Ensure the server disconnects us if we send a message with an invalid + msglen (less than header size or greater than 1MiB). + """ + ei = eis.ei + ei.dispatch() + ei.init_default_sender_connection() + ei.wait_for_connection() + + assert ei.connection is not None + connection = ei.connection + + @dataclass + class Status: + disconnected: bool = False + reason: int = 0 + explanation: Optional[str] = None + + status = Status() + + def on_disconnected(connection, last_serial, reason, explanation): + status.disconnected = True + status.reason = reason + status.explanation = explanation + + connection.connect("Disconnected", on_disconnected) + + # Craft a raw message with a valid object_id (the connection's) but + # an invalid msglen. Use opcode 0 since the msglen check happens + # before opcode validation. + raw_msg = struct.pack("=QII", connection.object_id, invalid_msglen, 0) + # For oversized messages we only need to send the header - the server + # checks msglen from the header before trying to read the full payload. + try: + ei.send(raw_msg) + ei.dispatch() + time.sleep(0.5) + ei.dispatch() + except (ConnectionResetError, BrokenPipeError): + # The server may have already closed the connection + return + + # If we didn't get a socket error, we should have gotten a + # Disconnected event with a protocol error reason + assert status.disconnected, ( + f"Expected disconnection for invalid msglen {invalid_msglen}" + ) + assert status.reason == EiConnection.EiDisconnectReason.PROTOCOL, ( + status.explanation + ) diff --git a/tools/ei-demo-client.c b/tools/ei-demo-client.c index b12de52..415ecc4 100644 --- a/tools/ei-demo-client.c +++ b/tools/ei-demo-client.c @@ -131,7 +131,8 @@ setup_xkb_keymap(struct ei_keymap *keymap) for (unsigned int evcode = KEY_Q; evcode <= KEY_Y; evcode++) { char utf8[7]; xkb_keysym_t keysym = xkb_state_key_get_one_sym(xkbstate, evcode + 8); - xkb_keysym_to_utf8(keysym, utf8, sizeof(utf8)); + int len = xkb_keysym_to_utf8(keysym, utf8, sizeof(utf8)); + assert(len > 0 && (size_t)len <= sizeof(utf8)); strcat(layout, utf8); } @@ -258,10 +259,16 @@ main(int argc, char **argv) receiver = true; break; case OPT_INTERVAL: - interval = atoi(optarg); + if (!xatou(optarg, &interval)) { + fprintf(stderr, "Invalid interval: %s\n", optarg); + return EXIT_FAILURE; + } break; case OPT_ITERATIONS: - iterations = atoi(optarg); + if (!xatou(optarg, &iterations)) { + fprintf(stderr, "Invalid iterations: %s\n", optarg); + return EXIT_FAILURE; + } break; default: usage(stderr, argv[0]); diff --git a/tools/eis-demo-server.c b/tools/eis-demo-server.c index 9cf04ad..06a9c89 100644 --- a/tools/eis-demo-server.c +++ b/tools/eis-demo-server.c @@ -698,7 +698,10 @@ main(int argc, char **argv) verbose = true; break; case OPT_INTERVAL: - interval = atoi(optarg); + if (!xatou(optarg, &interval)) { + fprintf(stderr, "Invalid interval: %s\n", optarg); + return EXIT_FAILURE; + } break; default: usage(stderr, argv[0]);