diff --git a/src/util-sources.c b/src/util-sources.c index d096301..a3f8496 100644 --- a/src/util-sources.c +++ b/src/util-sources.c @@ -211,6 +211,24 @@ sink_add_source(struct sink *sink, struct source *source) return 0; } +int +source_enable_write(struct source *source, bool enable) +{ + assert (source->is_active); + + struct epoll_event e = { + .events = EPOLLIN | (enable ? EPOLLOUT : 0), + .data.ptr = source, /* sink_add_source ref'd, so we don't need to here */ + }; + + int rc = xerrno(epoll_ctl(source->sink->epollfd, EPOLL_CTL_MOD, source_get_fd(source), &e)); + if (rc < 0) { + source_unref(source); + return rc; + } + return 0; +} + #if _enable_tests_ #include #include @@ -330,4 +348,67 @@ MUNIT_TEST(test_source_readd) return MUNIT_OK; } + +static void +count_calls(struct source *source, void *user_data) +{ + unsigned int *arg = user_data; + *arg = *arg + 1; +} + +MUNIT_TEST(test_source_write) +{ + _unref_(sink) *sink = sink_new(); + + int fd[2]; + int rc = pipe2(fd, O_CLOEXEC|O_NONBLOCK); + munit_assert_int(rc, !=, -1); + + int read_fd = fd[0]; + int write_fd = fd[1]; + + int dispatch_called = 0; + _unref_(source) *s = source_new(write_fd, count_calls, &dispatch_called); + sink_add_source(sink, s); + sink_dispatch(sink); + sink_dispatch(sink); + sink_dispatch(sink); + + munit_assert_uint(dispatch_called, ==, 0); + + source_enable_write(s, true); + sink_dispatch(sink); + munit_assert_uint(dispatch_called, ==, 1); + sink_dispatch(sink); + munit_assert_uint(dispatch_called, ==, 2); + + /* Fill up the buffer */ + do { + char buf[4096] = {0}; + rc = write(write_fd, buf, sizeof(buf)); + } while (rc != -1); + munit_assert_int(errno, ==, EAGAIN); + + /* Buffer is full, expect our dispatch to NOT be called */ + sink_dispatch(sink); + munit_assert_uint(dispatch_called, ==, 2); + sink_dispatch(sink); + munit_assert_uint(dispatch_called, ==, 2); + + do { + char buf[406]; + rc = read(read_fd, buf, sizeof(buf)); + } while (rc != -1); + munit_assert_int(errno, ==, EAGAIN); + + sink_dispatch(sink); + munit_assert_uint(dispatch_called, ==, 3); + + source_enable_write(s, false); + + sink_dispatch(sink); + munit_assert_uint(dispatch_called, ==, 3); + + return MUNIT_OK; +} #endif diff --git a/src/util-sources.h b/src/util-sources.h index 837bcae..60cab6a 100644 --- a/src/util-sources.h +++ b/src/util-sources.h @@ -30,12 +30,17 @@ #include "config.h" +#include + struct source; struct sink; /** * Callback invoked when the source has data available. userdata is the data - * provided to source_add() + * provided to source_add(). + * + * If source_enable_write() was called, this dispatch function is also called + * when writes are possible (and/or data is available to read at the same time). */ typedef void (*source_dispatch_t)(struct source *source, void *user_data); @@ -86,6 +91,15 @@ source_new(int fd, source_dispatch_t dispatch, void *user_data); void source_never_close_fd(struct source *s); +/** + * Enable or disable write notifications on this source. By default we assume + * our sources only read from the fd and thus their dispatch is only called + * when there's data available to read. + * + * If write is enabled, the dispatch is also called with data available to write. + */ +int +source_enable_write(struct source *source, bool enable); struct sink * sink_new(void);