diff --git a/src/libnm-platform/nm-linux-platform.c b/src/libnm-platform/nm-linux-platform.c index 40cb439692..4fa63f8065 100644 --- a/src/libnm-platform/nm-linux-platform.c +++ b/src/libnm-platform/nm-linux-platform.c @@ -9721,21 +9721,13 @@ constructed(GObject *_object) nm_platform_get_use_udev(platform) ? "use" : "no", nm_platform_get_cache_tc(platform) ? "use" : "no"); - priv->genl = nl_socket_alloc(); - g_assert(priv->genl); - - nle = nl_connect(priv->genl, NETLINK_GENERIC); - if (nle) { + nle = nl_socket_new(&priv->genl, NETLINK_GENERIC); + if (nle) _LOGE("unable to connect the generic netlink socket \"%s\" (%d)", nm_strerror(nle), -nle); - nl_socket_free(priv->genl); - priv->genl = NULL; - } - priv->nlh = nl_socket_alloc(); - g_assert(priv->nlh); - - nle = nl_connect(priv->nlh, NETLINK_ROUTE); + nle = nl_socket_new(&priv->nlh, NETLINK_ROUTE); g_assert(!nle); + nle = nl_socket_set_passcred(priv->nlh, 1); g_assert(!nle); diff --git a/src/libnm-platform/nm-netlink.c b/src/libnm-platform/nm-netlink.c index 697ae5919d..d7a74c7a0b 100644 --- a/src/libnm-platform/nm-netlink.c +++ b/src/libnm-platform/nm-netlink.c @@ -18,6 +18,16 @@ /*****************************************************************************/ +#define nm_assert_sk(sk) \ + G_STMT_START \ + { \ + const struct nl_sock *_sk = (sk); \ + \ + nm_assert(_sk); \ + nm_assert(_sk->s_fd >= 0); \ + } \ + G_STMT_END + #define NL_SOCK_PASSCRED (1 << 1) #define NL_MSG_PEEK (1 << 3) #define NL_MSG_PEEK_EXPLICIT (1 << 4) @@ -879,30 +889,14 @@ genl_ctrl_resolve(struct nl_sock *sk, const char *name) /*****************************************************************************/ -struct nl_sock * -nl_socket_alloc(void) -{ - struct nl_sock *sk; - - sk = g_slice_new0(struct nl_sock); - - sk->s_fd = -1; - sk->s_local.nl_family = AF_NETLINK; - sk->s_peer.nl_family = AF_NETLINK; - sk->s_seq_expect = sk->s_seq_next = time(NULL); - - return sk; -} - void nl_socket_free(struct nl_sock *sk) { if (!sk) return; - if (sk->s_fd >= 0) - nm_close(sk->s_fd); - g_slice_free(struct nl_sock, sk); + nm_close(sk->s_fd); + nm_g_slice_free(sk); } int @@ -928,8 +922,7 @@ nl_socket_set_passcred(struct nl_sock *sk, int state) { int err; - if (sk->s_fd == -1) - return -NME_NL_BAD_SOCK; + nm_assert_sk(sk); err = setsockopt(sk->s_fd, SOL_SOCKET, SO_PASSCRED, &state, sizeof(state)); if (err < 0) @@ -960,8 +953,7 @@ nlmsg_get_dst(struct nl_msg *msg) int nl_socket_set_nonblocking(const struct nl_sock *sk) { - if (sk->s_fd == -1) - return -NME_NL_BAD_SOCK; + nm_assert_sk(sk); if (fcntl(sk->s_fd, F_SETFL, O_NONBLOCK) < 0) return -nm_errno_from_native(errno); @@ -974,15 +966,14 @@ nl_socket_set_buffer_size(struct nl_sock *sk, int rxbuf, int txbuf) { int err; + nm_assert_sk(sk); + if (rxbuf <= 0) rxbuf = 32768; if (txbuf <= 0) txbuf = 32768; - if (sk->s_fd == -1) - return -NME_NL_BAD_SOCK; - err = setsockopt(sk->s_fd, SOL_SOCKET, SO_SNDBUF, &txbuf, sizeof(txbuf)); if (err < 0) { return -nm_errno_from_native(errno); @@ -1002,8 +993,7 @@ nl_socket_add_memberships(struct nl_sock *sk, int group, ...) int err; va_list ap; - if (sk->s_fd == -1) - return -NME_NL_BAD_SOCK; + nm_assert_sk(sk); va_start(ap, group); @@ -1032,10 +1022,10 @@ nl_socket_add_memberships(struct nl_sock *sk, int group, ...) int nl_socket_set_ext_ack(struct nl_sock *sk, gboolean enable) { - int err, val; + int err; + int val; - if (sk->s_fd == -1) - return -NME_NL_BAD_SOCK; + nm_assert_sk(sk); val = !!enable; err = setsockopt(sk->s_fd, SOL_NETLINK, NETLINK_EXT_ACK, &val, sizeof(val)); @@ -1052,62 +1042,70 @@ nl_socket_disable_msg_peek(struct nl_sock *sk) sk->s_flags &= ~NL_MSG_PEEK; } +/*****************************************************************************/ + int -nl_connect(struct nl_sock *sk, int protocol) +nl_socket_new(struct nl_sock **out_sk, int protocol) { - int err, nmerr; - socklen_t addrlen; - struct sockaddr_nl local = {0}; + nm_auto_nlsock struct nl_sock *sk = NULL; + nm_auto_close int fd = -1; + time_t t; + int err; + int nmerr; + socklen_t addrlen; + struct sockaddr_nl local = {0}; - if (sk->s_fd != -1) - return -NME_NL_BAD_SOCK; + nm_assert(out_sk && !*out_sk); - sk->s_fd = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, protocol); - if (sk->s_fd < 0) { - nmerr = -nm_errno_from_native(errno); - goto errout; - } + fd = socket(AF_NETLINK, SOCK_RAW | SOCK_CLOEXEC, protocol); + if (fd < 0) + return -nm_errno_from_native(errno); + + t = time(NULL); + + sk = g_slice_new(struct nl_sock); + *sk = (struct nl_sock){ + .s_fd = nm_steal_fd(&fd), + .s_local = + { + .nl_pid = 0, + .nl_family = AF_NETLINK, + .nl_groups = 0, + }, + .s_peer = + { + .nl_pid = 0, + .nl_family = AF_NETLINK, + .nl_groups = 0, + }, + .s_seq_expect = t, + .s_seq_next = t, + }; nmerr = nl_socket_set_buffer_size(sk, 0, 0); if (nmerr < 0) - goto errout; - - nm_assert(sk->s_local.nl_pid == 0); + return nmerr; err = bind(sk->s_fd, (struct sockaddr *) &sk->s_local, sizeof(sk->s_local)); - if (err != 0) { - nmerr = -nm_errno_from_native(errno); - goto errout; - } + if (err != 0) + return -nm_errno_from_native(errno); addrlen = sizeof(local); err = getsockname(sk->s_fd, (struct sockaddr *) &local, &addrlen); - if (err < 0) { - nmerr = -nm_errno_from_native(errno); - goto errout; - } + if (err < 0) + return -nm_errno_from_native(errno); - if (addrlen != sizeof(local)) { - nmerr = -NME_UNSPEC; - goto errout; - } + if (addrlen != sizeof(local)) + return -NME_UNSPEC; - if (local.nl_family != AF_NETLINK) { - nmerr = -NME_UNSPEC; - goto errout; - } + if (local.nl_family != AF_NETLINK) + return -NME_UNSPEC; sk->s_local = local; sk->s_proto = protocol; + *out_sk = g_steal_pointer(&sk); return 0; - -errout: - if (sk->s_fd != -1) { - close(sk->s_fd); - sk->s_fd = -1; - } - return nmerr; } /*****************************************************************************/ diff --git a/src/libnm-platform/nm-netlink.h b/src/libnm-platform/nm-netlink.h index 2ac8511393..441a123333 100644 --- a/src/libnm-platform/nm-netlink.h +++ b/src/libnm-platform/nm-netlink.h @@ -488,7 +488,7 @@ nlmsg_put(struct nl_msg *n, uint32_t pid, uint32_t seq, int type, int payload, i struct nl_sock; -struct nl_sock *nl_socket_alloc(void); +int nl_socket_new(struct nl_sock **out_sk, int protocol); void nl_socket_free(struct nl_sock *sk); diff --git a/src/libnm-platform/tests/test-nm-platform.c b/src/libnm-platform/tests/test-nm-platform.c index 9ac69bdeda..9263a6e7d7 100644 --- a/src/libnm-platform/tests/test-nm-platform.c +++ b/src/libnm-platform/tests/test-nm-platform.c @@ -59,7 +59,7 @@ test_use_symbols(void) (void (*)(void)) genlmsg_valid_hdr, (void (*)(void)) genlmsg_parse, (void (*)(void)) genl_ctrl_resolve, - (void (*)(void)) nl_socket_alloc, + (void (*)(void)) nl_socket_new, (void (*)(void)) nl_socket_free, (void (*)(void)) nl_socket_get_fd, (void (*)(void)) nl_socket_get_local_port, @@ -72,7 +72,6 @@ test_use_symbols(void) (void (*)(void)) nl_socket_add_memberships, (void (*)(void)) nl_socket_set_ext_ack, (void (*)(void)) nl_socket_disable_msg_peek, - (void (*)(void)) nl_connect, (void (*)(void)) nl_wait_for_ack, (void (*)(void)) nl_recvmsgs, (void (*)(void)) nl_sendmsg,