bpf: clat: ensure data is pulled for direct packet access

There is no guarantee that the part of the packet we want to read or
write via direct packet access is linear. From the documentation of
bpf_skb_pull_data():

  For direct packet access, testing that offsets to access are within
  packet boundaries (test on skb->data_end) is susceptible to fail if
  offsets are invalid, or if the requested data is in non-linear parts
  of the skb. On failure the program can just bail out, or in the case
  of a non-linear buffer, use a helper to make the data available. The
  bpf_skb_load_bytes() helper is a first solution to access the
  data. Another one consists in using bpf_skb_pull_data to pull in
  once the non-linear parts, then retesting and eventually access the
  data.

See: https://gitlab.freedesktop.org/NetworkManager/NetworkManager/-/merge_requests/2107#note_3288979
This commit is contained in:
Beniamino Galvani 2026-01-18 10:12:26 +01:00
parent e6a9a1ab77
commit 79f2e3ffe0

View file

@ -78,6 +78,36 @@ struct ip6_frag {
__u32 identification;
} __attribute__((packed));
#define ensure_header(header, skb, data, data_end, offset) \
_ensure_header((void **) header, (skb), (data), (data_end), sizeof(**(header)), (offset))
/*
* Verifies that the header at offset @offset and with size @size can
* be accessed, and assigns the pointer to @header. In case the data
* is not available, the function tries to pull it. Note that all packet
* pointers must be refreshed after calling this function.
*/
static __always_inline bool
_ensure_header(void **header,
struct __sk_buff *skb,
void **data,
void **data_end,
unsigned size,
unsigned offset)
{
if (*data + offset + size > *data_end) {
bpf_skb_pull_data(skb, offset + size);
*data = SKB_DATA(skb);
*data_end = SKB_DATA_END(skb);
}
if (*data + offset + size > *data_end)
return false;
*header = *data + offset;
return true;
}
/* This function must be declared as inline because the BPF calling
* convention only supports up to 5 function arguments. */
static __always_inline void
@ -206,8 +236,7 @@ rewrite_icmp(struct __sk_buff *skb, const struct ipv6hdr *ip6h)
struct icmp6hdr *icmp6;
__u32 mtu;
icmp = data + sizeof(struct ethhdr) + sizeof(struct iphdr);
if ((icmp + 1) > data_end)
if (!ensure_header(&icmp, skb, &data, &data_end, sizeof(struct ethhdr) + sizeof(struct iphdr)))
return -1;
icmp_buf = *icmp;
@ -495,14 +524,11 @@ clat_handle_v4(struct __sk_buff *skb)
struct iphdr *iph;
struct ethhdr *eth;
eth = data;
if (eth + 1 > data_end)
goto out;
if (eth->h_proto != bpf_htons(ETH_P_IP))
if (!ensure_header(&iph, skb, &data, &data_end, sizeof(struct ethhdr)))
goto out;
iph = data + sizeof(struct ethhdr);
if (iph + 1 > data_end)
eth = data;
if (eth->h_proto != bpf_htons(ETH_P_IP))
goto out;
if (iph->saddr != config.local_v4.s_addr)
@ -561,14 +587,10 @@ clat_handle_v4(struct __sk_buff *skb)
data = SKB_DATA(skb);
data_end = SKB_DATA_END(skb);
eth = data;
if (eth + 1 > data_end)
goto out;
ip6h = (void *) (eth + 1);
if (ip6h + 1 > data_end)
if (!ensure_header(&ip6h, skb, &data, &data_end, sizeof(struct ethhdr)))
goto out;
eth = data;
eth->h_proto = bpf_htons(ETH_P_IPV6);
*ip6h = dst_hdr;
@ -723,8 +745,12 @@ rewrite_icmpv6_inner(struct __sk_buff *skb, __u32 *csum_diff)
* -------------------------------------------------------------------------
*/
icmp6 = data + sizeof(struct ethhdr) + 2 * sizeof(struct ipv6hdr) + sizeof(struct icmp6hdr);
if (icmp6 + 1 > data_end)
if (!ensure_header(&icmp6,
skb,
&data,
&data_end,
sizeof(struct ethhdr) + 2 * sizeof(struct ipv6hdr)
+ sizeof(struct icmp6hdr)))
return -1;
icmp6_buf = *icmp6;
@ -747,8 +773,12 @@ rewrite_icmpv6_inner(struct __sk_buff *skb, __u32 *csum_diff)
data_end = SKB_DATA_END(skb);
data = SKB_DATA(skb);
icmp = data + sizeof(struct ethhdr) + 2 * sizeof(struct ipv6hdr) + sizeof(struct icmp6hdr);
if (icmp + 1 > data_end)
if (!ensure_header(&icmp,
skb,
&data,
&data_end,
sizeof(struct ethhdr) + 2 * sizeof(struct ipv6hdr)
+ sizeof(struct icmp6hdr)))
return -1;
/* Compute the checksum difference between the old ICMPv6 header and the new ICMPv4 one */
@ -779,8 +809,11 @@ rewrite_ipv6_inner(struct __sk_buff *skb, struct iphdr *dst_hdr, __u32 *csum_dif
* ----------------------------------------------------------------
*/
ip6h = data + sizeof(struct ethhdr) + sizeof(struct ipv6hdr) + sizeof(struct icmp6hdr);
if (ip6h + 1 > data_end)
if (!ensure_header(&ip6h,
skb,
&data,
&data_end,
sizeof(struct ethhdr) + sizeof(struct ipv6hdr) + sizeof(struct icmp6hdr)))
return -1;
if (!v6addr_equal(&ip6h->saddr, &config.local_v6))
@ -840,8 +873,11 @@ rewrite_icmpv6(struct __sk_buff *skb, int *out_length_diff)
* ---------------------------------------------
*/
icmp6 = data + sizeof(struct ethhdr) + sizeof(struct ipv6hdr);
if (icmp6 + 1 > data_end)
if (!ensure_header(&icmp6,
skb,
&data,
&data_end,
sizeof(struct ethhdr) + sizeof(struct ipv6hdr)))
return -1;
icmp6_buf = *icmp6;
@ -889,18 +925,18 @@ rewrite_icmpv6(struct __sk_buff *skb, int *out_length_diff)
data_end = SKB_DATA_END(skb);
data = SKB_DATA(skb);
/* Rewrite the ICMPv6 header with the translated ICMPv4 one */
if (!ensure_header(&ip,
skb,
&data,
&data_end,
sizeof(struct ethhdr) + sizeof(struct ipv6hdr) + sizeof(struct icmphdr)))
return -1;
icmp = data + sizeof(struct ethhdr) + sizeof(struct ipv6hdr);
if (icmp + 1 > data_end)
return -1;
/* Rewrite the ICMPv6 header with the translated ICMPv4 one */
*icmp = icmp_buf;
/* Rewrite the inner IPv6 header with the translated IPv4 one */
ip = data + sizeof(struct ethhdr) + sizeof(struct ipv6hdr) + sizeof(struct icmphdr);
if (ip + 1 > data_end)
return -1;
*ip = ip_in_buf;
/* Update the ICMPv4 checksum according to all the changes in headers */
@ -931,14 +967,11 @@ clat_handle_v6(struct __sk_buff *skb)
int length_diff = 0;
bool fragmented = false;
eth = data;
if (eth + 1 > data_end)
goto out;
if (eth->h_proto != bpf_htons(ETH_P_IPV6))
if (!ensure_header(&ip6h, skb, &data, &data_end, sizeof(struct ethhdr)))
goto out;
ip6h = data + sizeof(struct ethhdr);
if (ip6h + 1 > data_end)
eth = data;
if (eth->h_proto != bpf_htons(ETH_P_IPV6))
goto out;
if (!v6addr_equal(&ip6h->daddr, &config.local_v6))
@ -960,10 +993,15 @@ clat_handle_v6(struct __sk_buff *skb)
if (ip6h->nexthdr != IPPROTO_ICMPV6)
goto out;
icmp6 = data + sizeof(struct ethhdr) + sizeof(struct ipv6hdr);
if (icmp6 + 1 > data_end)
if (!ensure_header(&icmp6,
skb,
&data,
&data_end,
sizeof(struct ethhdr) + sizeof(struct ipv6hdr)))
goto out;
ip6h = data + sizeof(struct ethhdr);
if (icmp6->icmp6_type != ICMPV6_DEST_UNREACH && icmp6->icmp6_type != ICMPV6_TIME_EXCEED
&& icmp6->icmp6_type != ICMPV6_PKT_TOOBIG)
goto out;
@ -985,13 +1023,19 @@ clat_handle_v6(struct __sk_buff *skb)
translate_ipv6_header(ip6h, &dst_hdr, addr4, config.local_v4.s_addr);
DBG("v6: incoming pkt from src %pI6c (%pI4)\n", &ip6h->saddr, &addr4);
} else if (ip6h->nexthdr == IPPROTO_FRAGMENT) {
struct ip6_frag *frag = data + sizeof(struct ethhdr) + sizeof(struct ipv6hdr);
struct ip6_frag *frag;
int tot_len;
__u16 offset;
if (frag + 1 > data_end)
if (!ensure_header(&frag,
skb,
&data,
&data_end,
sizeof(struct ethhdr) + sizeof(struct ipv6hdr)))
goto out;
ip6h = data + sizeof(struct ethhdr);
/* Translate into an IPv4 fragmented packet, RFC 6145 5.1.1 */
tot_len = bpf_ntohs(ip6h->payload_len) + sizeof(struct iphdr) - sizeof(struct ip6_frag);
@ -1059,8 +1103,7 @@ clat_handle_v6(struct __sk_buff *skb)
data = SKB_DATA(skb);
data_end = SKB_DATA_END(skb);
ip6h = data + sizeof(struct ethhdr);
if (ip6h + 1 > data_end)
if (!ensure_header(&ip6h, skb, &data, &data_end, sizeof(struct ethhdr)))
goto out;
dst_hdr.tot_len =
@ -1083,14 +1126,10 @@ clat_handle_v6(struct __sk_buff *skb)
data = SKB_DATA(skb);
data_end = SKB_DATA_END(skb);
eth = data;
if (eth + 1 > data_end)
goto out;
iph = (void *) (eth + 1);
if (iph + 1 > data_end)
if (!ensure_header(&iph, skb, &data, &data_end, sizeof(struct ethhdr)))
goto out;
eth = data;
eth->h_proto = bpf_htons(ETH_P_IP);
*iph = dst_hdr;