diff --git a/ebpf/xdp_lb.c b/ebpf/xdp_lb.c
index f2a9997a94..2ee924572d 100644
--- a/ebpf/xdp_lb.c
+++ b/ebpf/xdp_lb.c
@@ -23,6 +23,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -66,18 +67,18 @@ struct bpf_map_def SEC("maps") cpus_count = {
     .max_entries = 1,
 };
 
-static int __always_inline filter_ipv4(struct xdp_md *ctx, void *data, __u64 nh_off, void *data_end)
+static int __always_inline hash_ipv4(void *data, void *data_end)
 {
-    struct iphdr *iph = data + nh_off;
+    struct iphdr *iph = data;
+    if ((void *)(iph + 1) > data_end)
+        return XDP_PASS;
+
     __u32 key0 = 0;
     __u32 cpu_dest;
     __u32 *cpu_max = bpf_map_lookup_elem(&cpus_count, &key0);
     __u32 *cpu_selected;
     __u32 cpu_hash;
 
-    if ((void *)(iph + 1) > data_end)
-        return XDP_PASS;
-
     /* IP-pairs hit same CPU */
     cpu_hash = iph->saddr + iph->daddr;
     cpu_hash = SuperFastHash((char *)&cpu_hash, 4, INITVAL);
@@ -94,18 +95,18 @@ static int __always_inline filter_ipv4(struct xdp_md *ctx, void *data, __u64 nh_
     }
 }
 
-static int __always_inline filter_ipv6(struct xdp_md *ctx, void *data, __u64 nh_off, void *data_end)
+static int __always_inline hash_ipv6(void *data, void *data_end)
 {
-    struct ipv6hdr *ip6h = data + nh_off;
+    struct ipv6hdr *ip6h = data;
+    if ((void *)(ip6h + 1) > data_end)
+        return XDP_PASS;
+
     __u32 key0 = 0;
     __u32 cpu_dest;
     __u32 *cpu_max = bpf_map_lookup_elem(&cpus_count, &key0);
     __u32 *cpu_selected;
     __u32 cpu_hash;
 
-    if ((void *)(ip6h + 1) > data_end)
-        return XDP_PASS;
-
     /* IP-pairs hit same CPU */
     cpu_hash = ip6h->saddr.s6_addr32[0] + ip6h->daddr.s6_addr32[0];
     cpu_hash += ip6h->saddr.s6_addr32[1] + ip6h->daddr.s6_addr32[1];
@@ -127,6 +128,75 @@ static int __always_inline filter_ipv6(struct xdp_md *ctx, void *data, __u64 nh_
     return XDP_PASS;
 }
 
+static int __always_inline filter_gre(struct xdp_md *ctx, void *data, __u64 nh_off, void *data_end)
+{
+    struct iphdr *iph = data + nh_off;
+    __be16 proto;
+    struct gre_hdr {
+        __be16 flags;
+        __be16 proto;
+    };
+
+    nh_off += sizeof(struct iphdr);
+    struct gre_hdr *grhdr = (struct gre_hdr *)(iph + 1);
+
+    if ((void *)(grhdr + 1) > data_end)
+        return XDP_PASS;
+
+    if (grhdr->flags & (GRE_VERSION|GRE_ROUTING))
+        return XDP_PASS;
+
+    nh_off += 4;
+    proto = grhdr->proto;
+    if (grhdr->flags & GRE_CSUM)
+        nh_off += 4;
+    if (grhdr->flags & GRE_KEY)
+        nh_off += 4;
+    if (grhdr->flags & GRE_SEQ)
+        nh_off += 4;
+
+    if (data + nh_off > data_end)
+        return XDP_PASS;
+    if (bpf_xdp_adjust_head(ctx, 0 + nh_off))
+        return XDP_PASS;
+
+    data = (void *)(long)ctx->data;
+    data_end = (void *)(long)ctx->data_end;
+
+    if (proto == __constant_htons(ETH_P_8021Q)) {
+        struct vlan_hdr *vhdr = (struct vlan_hdr *)(data);
+        if ((void *)(vhdr + 1) > data_end)
+            return XDP_PASS;
+        proto = vhdr->h_vlan_encapsulated_proto;
+        nh_off += sizeof(struct vlan_hdr);
+    }
+
+    if (proto == __constant_htons(ETH_P_IP)) {
+        return hash_ipv4(data, data_end);
+    } else if (proto == __constant_htons(ETH_P_IPV6)) {
+        return hash_ipv6(data, data_end);
+    } else
+        return XDP_PASS;
+}
+
+static int __always_inline filter_ipv4(struct xdp_md *ctx, void *data, __u64 nh_off, void *data_end)
+{
+    struct iphdr *iph = data + nh_off;
+    if ((void *)(iph + 1) > data_end)
+        return XDP_PASS;
+
+    if (iph->protocol == IPPROTO_GRE) {
+        return filter_gre(ctx, data, nh_off, data_end);
+    }
+    return hash_ipv4(data + nh_off, data_end);
+}
+
+static int __always_inline filter_ipv6(struct xdp_md *ctx, void *data, __u64 nh_off, void *data_end)
+{
+    struct ipv6hdr *ip6h = data + nh_off;
+    return hash_ipv6((void *)ip6h, data_end);
+}
+
 int SEC("xdp") xdp_loadfilter(struct xdp_md *ctx)
 {
     void *data_end = (void *)(long)ctx->data_end;
@@ -141,6 +211,12 @@ int SEC("xdp") xdp_loadfilter(struct xdp_md *ctx)
 
     h_proto = eth->h_proto;
 
+#if 0
+    if (h_proto != __constant_htons(ETH_P_IP)) {
+        char fmt[] = "Current proto: %u\n";
+        bpf_trace_printk(fmt, sizeof(fmt), h_proto);
+    }
+#endif
     if (h_proto == __constant_htons(ETH_P_8021Q) || h_proto == __constant_htons(ETH_P_8021AD)) {
         struct vlan_hdr *vhdr;