Merge branch 'bpf-allow-opt-out-from-sk-sk_prot-memory_allocated'

Kuniyuki Iwashima says:

====================
bpf: Allow opt-out from sk->sk_prot->memory_allocated.

This series allows opting out of the global per-protocol memory
accounting when a socket is configured to do so via sysctl or a BPF prog.

This series is the successor of the series below [0], but the changes
are now confined to the net and bpf subsystems.

I discussed this with Roman Gushchin off-list, and he suggested not
mixing two independent subsystems, as it would be cleaner not to
depend on memcg.

So, sk->sk_memcg and memcg code are no longer touched, and instead we
use another hole near sk->sk_prot to store a flag for the pure net
opt-out feature.

Overview of the series:

  patch 1 is misc cleanup
  patch 2 allows opt-out from sk->sk_prot->memory_allocated
  patch 3 introduces net.core.bypass_prot_mem
  patches 4 & 5 support flagging sk->sk_bypass_prot_mem via bpf_setsockopt()
  patch 6 adds a selftest

Thank you very much for all your help, Shakeel, Roman, Martin, and Eric!

[0]: https://lore.kernel.org/bpf/20250920000751.2091731-1-kuniyu@google.com/

Changes:
  v2:
    * Patch 2:
      * Fill kdoc for skc_bypass_prot_mem
    * Patch 6:
      * Fix server fd leak in tcp_create_sockets()
      * Avoid close(0) in check_bypass()

  v1: https://lore.kernel.org/bpf/20251007001120.2661442-1-kuniyu@google.com/
====================

Link: https://patch.msgid.link/20251014235604.3057003-1-kuniyu@google.com
Signed-off-by: Martin KaFai Lau <martin.lau@kernel.org>
Committed by Martin KaFai Lau on 2025-10-16 12:15:10 -07:00
18 changed files with 577 additions and 38 deletions
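For orientation before the file-by-file changes: the BPF side of the opt-out
is driven from a BPF_CGROUP_INET_SOCK_CREATE program that sets the new
SK_BPF_BYPASS_PROT_MEM socket option on each socket created in the cgroup.
Below is a trimmed-down sketch modelled on the selftest's sock_create program
near the end of this commit; the program name and the simplified return
handling are illustrative, not part of the diff:

// SPDX-License-Identifier: GPL-2.0
/* Sketch mirroring the selftest's sock_create program: opt every socket
 * created in this cgroup out of the global per-protocol memory accounting.
 */
#include "bpf_tracing_net.h"
#include <bpf/bpf_helpers.h>

SEC("cgroup/sock_create")
int bypass_prot_mem(struct bpf_sock *ctx)
{
	int val = 1;

	/* SK_BPF_BYPASS_PROT_MEM is accepted only at the SOL_SOCKET level and
	 * only from BPF_CGROUP_INET_SOCK_CREATE programs (see the filter.c
	 * hunk below).
	 */
	if (bpf_setsockopt(ctx, SOL_SOCKET, SK_BPF_BYPASS_PROT_MEM,
			   &val, sizeof(val)))
		return 0;	/* deny socket creation on failure */

	return 1;		/* allow */
}

char LICENSE[] SEC("license") = "GPL";

The program is attached to a cgroup fd with bpf_program__attach_cgroup(), as
run_test() in the selftest does.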


@@ -212,6 +212,14 @@ mem_pcpu_rsv
Per-cpu reserved forward alloc cache size in page units. Default 1MB per CPU.
bypass_prot_mem
---------------
Skip charging socket buffers to the global per-protocol memory
accounting controlled by net.ipv4.tcp_mem, net.ipv4.udp_mem, etc.
Default: 0 (off)
rmem_default
------------

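The sysctl is per network namespace and only affects sockets created after it
is written, since sk_alloc() samples it at socket creation time (see the
net/core/sock.c change below). The selftest toggles it with
write_sysctl("/proc/sys/net/core/bypass_prot_mem", "1"); a minimal userspace
sketch doing the same, with a hypothetical helper name:

/* Sketch only: enable net.core.bypass_prot_mem in the current netns.
 * set_bypass_prot_mem() is a made-up helper equivalent to the selftest's
 * write_sysctl() call.
 */
#include <fcntl.h>
#include <unistd.h>

static int set_bypass_prot_mem(int on)
{
	int fd = open("/proc/sys/net/core/bypass_prot_mem", O_WRONLY);
	int ret = 0;

	if (fd < 0)
		return -1;
	if (write(fd, on ? "1" : "0", 1) != 1)
		ret = -1;
	close(fd);
	return ret;
}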

@@ -17,6 +17,7 @@ struct netns_core {
int sysctl_optmem_max;
u8 sysctl_txrehash;
u8 sysctl_tstamp_allow_data;
u8 sysctl_bypass_prot_mem;
#ifdef CONFIG_PROC_FS
struct prot_inuse __percpu *prot_inuse;


@@ -35,6 +35,9 @@ static inline bool sk_under_memory_pressure(const struct sock *sk)
mem_cgroup_sk_under_memory_pressure(sk))
return true;
if (sk->sk_bypass_prot_mem)
return false;
return !!READ_ONCE(*sk->sk_prot->memory_pressure);
}


@@ -118,6 +118,7 @@ typedef __u64 __bitwise __addrpair;
* @skc_reuseport: %SO_REUSEPORT setting
* @skc_ipv6only: socket is IPV6 only
* @skc_net_refcnt: socket is using net ref counting
* @skc_bypass_prot_mem: bypass the per-protocol memory accounting for skb
* @skc_bound_dev_if: bound device index if != 0
* @skc_bind_node: bind hash linkage for various protocol lookup tables
* @skc_portaddr_node: second hash linkage for UDP/UDP-Lite protocol
@@ -174,6 +175,7 @@ struct sock_common {
unsigned char skc_reuseport:1;
unsigned char skc_ipv6only:1;
unsigned char skc_net_refcnt:1;
unsigned char skc_bypass_prot_mem:1;
int skc_bound_dev_if;
union {
struct hlist_node skc_bind_node;
@@ -381,6 +383,7 @@ struct sock {
#define sk_reuseport __sk_common.skc_reuseport
#define sk_ipv6only __sk_common.skc_ipv6only
#define sk_net_refcnt __sk_common.skc_net_refcnt
#define sk_bypass_prot_mem __sk_common.skc_bypass_prot_mem
#define sk_bound_dev_if __sk_common.skc_bound_dev_if
#define sk_bind_node __sk_common.skc_bind_node
#define sk_prot __sk_common.skc_prot


@@ -303,6 +303,9 @@ static inline bool tcp_under_memory_pressure(const struct sock *sk)
mem_cgroup_sk_under_memory_pressure(sk))
return true;
if (sk->sk_bypass_prot_mem)
return false;
return READ_ONCE(tcp_memory_pressure);
}
/*


@@ -7200,6 +7200,8 @@ enum {
TCP_BPF_SYN_MAC = 1007, /* Copy the MAC, IP[46], and TCP header */
TCP_BPF_SOCK_OPS_CB_FLAGS = 1008, /* Get or Set TCP sock ops flags */
SK_BPF_CB_FLAGS = 1009, /* Get or set sock ops flags in socket */
SK_BPF_BYPASS_PROT_MEM = 1010, /* Get or Set sk->sk_bypass_prot_mem */
};
enum {


@@ -5733,6 +5733,77 @@ static const struct bpf_func_proto bpf_sock_addr_getsockopt_proto = {
.arg5_type = ARG_CONST_SIZE,
};
static int sk_bpf_set_get_bypass_prot_mem(struct sock *sk,
char *optval, int optlen,
bool getopt)
{
int val;
if (optlen != sizeof(int))
return -EINVAL;
if (!sk_has_account(sk))
return -EOPNOTSUPP;
if (getopt) {
*(int *)optval = sk->sk_bypass_prot_mem;
return 0;
}
val = *(int *)optval;
if (val < 0 || val > 1)
return -EINVAL;
sk->sk_bypass_prot_mem = val;
return 0;
}
BPF_CALL_5(bpf_sock_create_setsockopt, struct sock *, sk, int, level,
int, optname, char *, optval, int, optlen)
{
if (level == SOL_SOCKET && optname == SK_BPF_BYPASS_PROT_MEM)
return sk_bpf_set_get_bypass_prot_mem(sk, optval, optlen, false);
return __bpf_setsockopt(sk, level, optname, optval, optlen);
}
static const struct bpf_func_proto bpf_sock_create_setsockopt_proto = {
.func = bpf_sock_create_setsockopt,
.gpl_only = false,
.ret_type = RET_INTEGER,
.arg1_type = ARG_PTR_TO_CTX,
.arg2_type = ARG_ANYTHING,
.arg3_type = ARG_ANYTHING,
.arg4_type = ARG_PTR_TO_MEM | MEM_RDONLY,
.arg5_type = ARG_CONST_SIZE,
};
BPF_CALL_5(bpf_sock_create_getsockopt, struct sock *, sk, int, level,
int, optname, char *, optval, int, optlen)
{
if (level == SOL_SOCKET && optname == SK_BPF_BYPASS_PROT_MEM) {
int err = sk_bpf_set_get_bypass_prot_mem(sk, optval, optlen, true);
if (err)
memset(optval, 0, optlen);
return err;
}
return __bpf_getsockopt(sk, level, optname, optval, optlen);
}
static const struct bpf_func_proto bpf_sock_create_getsockopt_proto = {
.func = bpf_sock_create_getsockopt,
.gpl_only = false,
.ret_type = RET_INTEGER,
.arg1_type = ARG_PTR_TO_CTX,
.arg2_type = ARG_ANYTHING,
.arg3_type = ARG_ANYTHING,
.arg4_type = ARG_PTR_TO_UNINIT_MEM,
.arg5_type = ARG_CONST_SIZE,
};
BPF_CALL_5(bpf_sock_ops_setsockopt, struct bpf_sock_ops_kern *, bpf_sock,
int, level, int, optname, char *, optval, int, optlen)
{
@@ -8062,6 +8133,20 @@ sock_filter_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
return &bpf_sk_storage_get_cg_sock_proto;
case BPF_FUNC_ktime_get_coarse_ns:
return &bpf_ktime_get_coarse_ns_proto;
case BPF_FUNC_setsockopt:
switch (prog->expected_attach_type) {
case BPF_CGROUP_INET_SOCK_CREATE:
return &bpf_sock_create_setsockopt_proto;
default:
return NULL;
}
case BPF_FUNC_getsockopt:
switch (prog->expected_attach_type) {
case BPF_CGROUP_INET_SOCK_CREATE:
return &bpf_sock_create_getsockopt_proto;
default:
return NULL;
}
default:
return bpf_base_func_proto(func_id, prog);
}


@@ -1046,9 +1046,13 @@ static int sock_reserve_memory(struct sock *sk, int bytes)
if (!charged)
return -ENOMEM;
if (sk->sk_bypass_prot_mem)
goto success;
/* pre-charge to forward_alloc */
sk_memory_allocated_add(sk, pages);
allocated = sk_memory_allocated(sk);
/* If the system goes into memory pressure with this
* precharge, give up and return error.
*/
@@ -1057,6 +1061,8 @@ static int sock_reserve_memory(struct sock *sk, int bytes)
mem_cgroup_sk_uncharge(sk, pages);
return -ENOMEM;
}
success:
sk_forward_alloc_add(sk, pages << PAGE_SHIFT);
WRITE_ONCE(sk->sk_reserved_mem,
@@ -2300,8 +2306,13 @@ struct sock *sk_alloc(struct net *net, int family, gfp_t priority,
* why we need sk_prot_creator -acme
*/
sk->sk_prot = sk->sk_prot_creator = prot;
if (READ_ONCE(net->core.sysctl_bypass_prot_mem))
sk->sk_bypass_prot_mem = 1;
sk->sk_kern_sock = kern;
sock_lock_init(sk);
sk->sk_net_refcnt = kern ? 0 : 1;
if (likely(sk->sk_net_refcnt)) {
get_net_track(net, &sk->ns_tracker, priority);
@@ -3145,8 +3156,11 @@ bool sk_page_frag_refill(struct sock *sk, struct page_frag *pfrag)
if (likely(skb_page_frag_refill(32U, pfrag, sk->sk_allocation)))
return true;
-	sk_enter_memory_pressure(sk);
+	if (!sk->sk_bypass_prot_mem)
+		sk_enter_memory_pressure(sk);
sk_stream_moderate_sndbuf(sk);
return false;
}
EXPORT_SYMBOL(sk_page_frag_refill);
@@ -3263,10 +3277,12 @@ int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind)
{
bool memcg_enabled = false, charged = false;
struct proto *prot = sk->sk_prot;
-	long allocated;
+	long allocated = 0;
-	sk_memory_allocated_add(sk, amt);
-	allocated = sk_memory_allocated(sk);
+	if (!sk->sk_bypass_prot_mem) {
+		sk_memory_allocated_add(sk, amt);
+		allocated = sk_memory_allocated(sk);
+	}
if (mem_cgroup_sk_enabled(sk)) {
memcg_enabled = true;
@@ -3275,6 +3291,9 @@ int __sk_mem_raise_allocated(struct sock *sk, int size, int amt, int kind)
goto suppress_allocation;
}
if (!allocated)
return 1;
/* Under limit. */
if (allocated <= sk_prot_mem_limits(sk, 0)) {
sk_leave_memory_pressure(sk);
@@ -3353,7 +3372,8 @@ suppress_allocation:
trace_sock_exceed_buf_limit(sk, prot, allocated, kind);
-	sk_memory_allocated_sub(sk, amt);
+	if (allocated)
+		sk_memory_allocated_sub(sk, amt);
if (charged)
mem_cgroup_sk_uncharge(sk, amt);
@@ -3392,11 +3412,14 @@ EXPORT_SYMBOL(__sk_mem_schedule);
*/
void __sk_mem_reduce_allocated(struct sock *sk, int amount)
{
-	sk_memory_allocated_sub(sk, amount);
if (mem_cgroup_sk_enabled(sk))
mem_cgroup_sk_uncharge(sk, amount);
+	if (sk->sk_bypass_prot_mem)
+		return;
+	sk_memory_allocated_sub(sk, amount);
if (sk_under_global_memory_pressure(sk) &&
(sk_memory_allocated(sk) < sk_prot_mem_limits(sk, 0)))
sk_leave_memory_pressure(sk);


@@ -683,6 +683,15 @@ static struct ctl_table netns_core_table[] = {
.extra1 = SYSCTL_ZERO,
.extra2 = SYSCTL_ONE
},
{
.procname = "bypass_prot_mem",
.data = &init_net.core.sysctl_bypass_prot_mem,
.maxlen = sizeof(u8),
.mode = 0644,
.proc_handler = proc_dou8vec_minmax,
.extra1 = SYSCTL_ZERO,
.extra2 = SYSCTL_ONE
},
/* sysctl_core_net_init() will set the values after this
* to readonly in network namespaces
*/


@@ -755,6 +755,28 @@ EXPORT_SYMBOL(inet_stream_connect);
void __inet_accept(struct socket *sock, struct socket *newsock, struct sock *newsk)
{
/* TODO: use sk_clone_lock() in SCTP and remove protocol checks */
if (mem_cgroup_sockets_enabled &&
(!IS_ENABLED(CONFIG_IP_SCTP) || sk_is_tcp(newsk))) {
gfp_t gfp = GFP_KERNEL | __GFP_NOFAIL;
mem_cgroup_sk_alloc(newsk);
if (mem_cgroup_from_sk(newsk)) {
int amt;
/* The socket has not been accepted yet, no need
* to look at newsk->sk_wmem_queued.
*/
amt = sk_mem_pages(newsk->sk_forward_alloc +
atomic_read(&newsk->sk_rmem_alloc));
if (amt)
mem_cgroup_sk_charge(newsk, amt, gfp);
}
kmem_cache_charge(newsk, gfp);
}
sock_rps_record_flow(newsk);
WARN_ON(!((1 << newsk->sk_state) &
(TCPF_ESTABLISHED | TCPF_SYN_RECV |


@@ -712,31 +712,6 @@ struct sock *inet_csk_accept(struct sock *sk, struct proto_accept_arg *arg)
release_sock(sk);
-	if (mem_cgroup_sockets_enabled) {
-		gfp_t gfp = GFP_KERNEL | __GFP_NOFAIL;
-		int amt = 0;
-		/* atomically get the memory usage, set and charge the
-		 * newsk->sk_memcg.
-		 */
-		lock_sock(newsk);
-		mem_cgroup_sk_alloc(newsk);
-		if (mem_cgroup_from_sk(newsk)) {
-			/* The socket has not been accepted yet, no need
-			 * to look at newsk->sk_wmem_queued.
-			 */
-			amt = sk_mem_pages(newsk->sk_forward_alloc +
-					   atomic_read(&newsk->sk_rmem_alloc));
-		}
-		if (amt)
-			mem_cgroup_sk_charge(newsk, amt, gfp);
-		kmem_cache_charge(newsk, gfp);
-		release_sock(newsk);
-	}
if (req)
reqsk_put(req);


@@ -928,7 +928,8 @@ struct sk_buff *tcp_stream_alloc_skb(struct sock *sk, gfp_t gfp,
}
__kfree_skb(skb);
} else {
-		sk->sk_prot->enter_memory_pressure(sk);
+		if (!sk->sk_bypass_prot_mem)
+			tcp_enter_memory_pressure(sk);
sk_stream_moderate_sndbuf(sk);
}
return NULL;


@@ -3743,12 +3743,17 @@ void sk_forced_mem_schedule(struct sock *sk, int size)
delta = size - sk->sk_forward_alloc;
if (delta <= 0)
return;
amt = sk_mem_pages(delta);
sk_forward_alloc_add(sk, amt << PAGE_SHIFT);
-	sk_memory_allocated_add(sk, amt);
if (mem_cgroup_sk_enabled(sk))
mem_cgroup_sk_charge(sk, amt, gfp_memcg_charge() | __GFP_NOFAIL);
+	if (sk->sk_bypass_prot_mem)
+		return;
+	sk_memory_allocated_add(sk, amt);
}
/* Send a FIN. The caller locks the socket for us.


@@ -1065,11 +1065,12 @@ static void mptcp_enter_memory_pressure(struct sock *sk)
mptcp_for_each_subflow(msk, subflow) {
struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
-		if (first)
+		if (first && !ssk->sk_bypass_prot_mem) {
tcp_enter_memory_pressure(ssk);
+			sk_stream_moderate_sndbuf(ssk);
+			first = false;
+		}
-		first = false;
-		sk_stream_moderate_sndbuf(ssk);
}
__mptcp_sync_sndbuf(sk);
}


@@ -373,7 +373,8 @@ static int tls_do_allocation(struct sock *sk,
if (!offload_ctx->open_record) {
if (unlikely(!skb_page_frag_refill(prepend_size, pfrag,
sk->sk_allocation))) {
-			READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk);
+			if (!sk->sk_bypass_prot_mem)
+				READ_ONCE(sk->sk_prot)->enter_memory_pressure(sk);
sk_stream_moderate_sndbuf(sk);
return -ENOMEM;
}


@@ -7200,6 +7200,7 @@ enum {
TCP_BPF_SYN_MAC = 1007, /* Copy the MAC, IP[46], and TCP header */
TCP_BPF_SOCK_OPS_CB_FLAGS = 1008, /* Get or Set TCP sock ops flags */
SK_BPF_CB_FLAGS = 1009, /* Get or set sock ops flags in socket */
SK_BPF_BYPASS_PROT_MEM = 1010, /* Get or Set sk->sk_bypass_prot_mem */
};
enum {


@@ -0,0 +1,292 @@
// SPDX-License-Identifier: GPL-2.0
/* Copyright 2025 Google LLC */
#include <test_progs.h>
#include "sk_bypass_prot_mem.skel.h"
#include "network_helpers.h"
#define NR_PAGES 32
#define NR_SOCKETS 2
#define BUF_TOTAL (NR_PAGES * 4096 / NR_SOCKETS)
#define BUF_SINGLE 1024
#define NR_SEND (BUF_TOTAL / BUF_SINGLE)
struct test_case {
char name[8];
int family;
int type;
int (*create_sockets)(struct test_case *test_case, int sk[], int len);
long (*get_memory_allocated)(struct test_case *test_case, struct sk_bypass_prot_mem *skel);
};
static int tcp_create_sockets(struct test_case *test_case, int sk[], int len)
{
int server, i, err = 0;
server = start_server(test_case->family, test_case->type, NULL, 0, 0);
if (!ASSERT_GE(server, 0, "start_server_str"))
return server;
/* Keep for-loop so we can change NR_SOCKETS easily. */
for (i = 0; i < len; i += 2) {
sk[i] = connect_to_fd(server, 0);
if (sk[i] < 0) {
ASSERT_GE(sk[i], 0, "connect_to_fd");
err = sk[i];
break;
}
sk[i + 1] = accept(server, NULL, NULL);
if (sk[i + 1] < 0) {
ASSERT_GE(sk[i + 1], 0, "accept");
err = sk[i + 1];
break;
}
}
close(server);
return err;
}
static int udp_create_sockets(struct test_case *test_case, int sk[], int len)
{
int i, j, err, rcvbuf = BUF_TOTAL;
/* Keep for-loop so we can change NR_SOCKETS easily. */
for (i = 0; i < len; i += 2) {
sk[i] = start_server(test_case->family, test_case->type, NULL, 0, 0);
if (sk[i] < 0) {
ASSERT_GE(sk[i], 0, "start_server");
return sk[i];
}
sk[i + 1] = connect_to_fd(sk[i], 0);
if (sk[i + 1] < 0) {
ASSERT_GE(sk[i + 1], 0, "connect_to_fd");
return sk[i + 1];
}
err = connect_fd_to_fd(sk[i], sk[i + 1], 0);
if (err) {
ASSERT_EQ(err, 0, "connect_fd_to_fd");
return err;
}
for (j = 0; j < 2; j++) {
err = setsockopt(sk[i + j], SOL_SOCKET, SO_RCVBUF, &rcvbuf, sizeof(int));
if (err) {
ASSERT_EQ(err, 0, "setsockopt(SO_RCVBUF)");
return err;
}
}
}
return 0;
}
static long get_memory_allocated(struct test_case *test_case,
bool *activated, long *memory_allocated)
{
int sk;
*activated = true;
/* AF_INET and AF_INET6 share the same memory_allocated.
* tcp_init_sock() is called by AF_INET and AF_INET6,
* but udp_lib_init_sock() is inline.
*/
sk = socket(AF_INET, test_case->type, 0);
if (!ASSERT_GE(sk, 0, "get_memory_allocated"))
return -1;
close(sk);
return *memory_allocated;
}
static long tcp_get_memory_allocated(struct test_case *test_case, struct sk_bypass_prot_mem *skel)
{
return get_memory_allocated(test_case,
&skel->bss->tcp_activated,
&skel->bss->tcp_memory_allocated);
}
static long udp_get_memory_allocated(struct test_case *test_case, struct sk_bypass_prot_mem *skel)
{
return get_memory_allocated(test_case,
&skel->bss->udp_activated,
&skel->bss->udp_memory_allocated);
}
static int check_bypass(struct test_case *test_case,
struct sk_bypass_prot_mem *skel, bool bypass)
{
char buf[BUF_SINGLE] = {};
long memory_allocated[2];
int sk[NR_SOCKETS];
int err, i, j;
for (i = 0; i < ARRAY_SIZE(sk); i++)
sk[i] = -1;
err = test_case->create_sockets(test_case, sk, ARRAY_SIZE(sk));
if (err)
goto close;
memory_allocated[0] = test_case->get_memory_allocated(test_case, skel);
/* allocate pages >= NR_PAGES */
for (i = 0; i < ARRAY_SIZE(sk); i++) {
for (j = 0; j < NR_SEND; j++) {
int bytes = send(sk[i], buf, sizeof(buf), 0);
/* Avoid too noisy logs when something failed. */
if (bytes != sizeof(buf)) {
ASSERT_EQ(bytes, sizeof(buf), "send");
if (bytes < 0) {
err = bytes;
goto drain;
}
}
}
}
memory_allocated[1] = test_case->get_memory_allocated(test_case, skel);
if (bypass)
ASSERT_LE(memory_allocated[1], memory_allocated[0] + 10, "bypass");
else
ASSERT_GT(memory_allocated[1], memory_allocated[0] + NR_PAGES, "no bypass");
drain:
if (test_case->type == SOCK_DGRAM) {
/* UDP starts purging sk->sk_receive_queue after one RCU
* grace period, then udp_memory_allocated goes down,
* so drain the queue before close().
*/
for (i = 0; i < ARRAY_SIZE(sk); i++) {
for (j = 0; j < NR_SEND; j++) {
int bytes = recv(sk[i], buf, 1, MSG_DONTWAIT | MSG_TRUNC);
if (bytes == sizeof(buf))
continue;
if (bytes != -1 || errno != EAGAIN)
PRINT_FAIL("bytes: %d, errno: %s\n", bytes, strerror(errno));
break;
}
}
}
close:
for (i = 0; i < ARRAY_SIZE(sk); i++) {
if (sk[i] < 0)
break;
close(sk[i]);
}
return err;
}
static void run_test(struct test_case *test_case)
{
struct sk_bypass_prot_mem *skel;
struct nstoken *nstoken;
int cgroup, err;
skel = sk_bypass_prot_mem__open_and_load();
if (!ASSERT_OK_PTR(skel, "open_and_load"))
return;
skel->bss->nr_cpus = libbpf_num_possible_cpus();
err = sk_bypass_prot_mem__attach(skel);
if (!ASSERT_OK(err, "attach"))
goto destroy_skel;
cgroup = test__join_cgroup("/sk_bypass_prot_mem");
if (!ASSERT_GE(cgroup, 0, "join_cgroup"))
goto destroy_skel;
err = make_netns("sk_bypass_prot_mem");
if (!ASSERT_EQ(err, 0, "make_netns"))
goto close_cgroup;
nstoken = open_netns("sk_bypass_prot_mem");
if (!ASSERT_OK_PTR(nstoken, "open_netns"))
goto remove_netns;
err = check_bypass(test_case, skel, false);
if (!ASSERT_EQ(err, 0, "test_bypass(false)"))
goto close_netns;
err = write_sysctl("/proc/sys/net/core/bypass_prot_mem", "1");
if (!ASSERT_EQ(err, 0, "write_sysctl(1)"))
goto close_netns;
err = check_bypass(test_case, skel, true);
if (!ASSERT_EQ(err, 0, "test_bypass(true by sysctl)"))
goto close_netns;
err = write_sysctl("/proc/sys/net/core/bypass_prot_mem", "0");
if (!ASSERT_EQ(err, 0, "write_sysctl(0)"))
goto close_netns;
skel->links.sock_create = bpf_program__attach_cgroup(skel->progs.sock_create, cgroup);
if (!ASSERT_OK_PTR(skel->links.sock_create, "attach_cgroup(sock_create)"))
goto close_netns;
err = check_bypass(test_case, skel, true);
ASSERT_EQ(err, 0, "test_bypass(true by bpf)");
close_netns:
close_netns(nstoken);
remove_netns:
remove_netns("sk_bypass_prot_mem");
close_cgroup:
close(cgroup);
destroy_skel:
sk_bypass_prot_mem__destroy(skel);
}
static struct test_case test_cases[] = {
{
.name = "TCP ",
.family = AF_INET,
.type = SOCK_STREAM,
.create_sockets = tcp_create_sockets,
.get_memory_allocated = tcp_get_memory_allocated,
},
{
.name = "UDP ",
.family = AF_INET,
.type = SOCK_DGRAM,
.create_sockets = udp_create_sockets,
.get_memory_allocated = udp_get_memory_allocated,
},
{
.name = "TCPv6",
.family = AF_INET6,
.type = SOCK_STREAM,
.create_sockets = tcp_create_sockets,
.get_memory_allocated = tcp_get_memory_allocated,
},
{
.name = "UDPv6",
.family = AF_INET6,
.type = SOCK_DGRAM,
.create_sockets = udp_create_sockets,
.get_memory_allocated = udp_get_memory_allocated,
},
};
void serial_test_sk_bypass_prot_mem(void)
{
int i;
for (i = 0; i < ARRAY_SIZE(test_cases); i++) {
if (test__start_subtest(test_cases[i].name))
run_test(&test_cases[i]);
}
}


@@ -0,0 +1,104 @@
// SPDX-License-Identifier: GPL-2.0
/* Copyright 2025 Google LLC */
#include "bpf_tracing_net.h"
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_tracing.h>
#include <errno.h>
extern int tcp_memory_per_cpu_fw_alloc __ksym;
extern int udp_memory_per_cpu_fw_alloc __ksym;
int nr_cpus;
bool tcp_activated, udp_activated;
long tcp_memory_allocated, udp_memory_allocated;
struct sk_prot {
long *memory_allocated;
int *memory_per_cpu_fw_alloc;
};
static int drain_memory_per_cpu_fw_alloc(__u32 i, struct sk_prot *sk_prot_ctx)
{
int *memory_per_cpu_fw_alloc;
memory_per_cpu_fw_alloc = bpf_per_cpu_ptr(sk_prot_ctx->memory_per_cpu_fw_alloc, i);
if (memory_per_cpu_fw_alloc)
*sk_prot_ctx->memory_allocated += *memory_per_cpu_fw_alloc;
return 0;
}
static long get_memory_allocated(struct sock *_sk, int *memory_per_cpu_fw_alloc)
{
struct sock *sk = bpf_core_cast(_sk, struct sock);
struct sk_prot sk_prot_ctx;
long memory_allocated;
/* net_aligned_data.{tcp,udp}_memory_allocated was not available. */
memory_allocated = sk->__sk_common.skc_prot->memory_allocated->counter;
sk_prot_ctx.memory_allocated = &memory_allocated;
sk_prot_ctx.memory_per_cpu_fw_alloc = memory_per_cpu_fw_alloc;
bpf_loop(nr_cpus, drain_memory_per_cpu_fw_alloc, &sk_prot_ctx, 0);
return memory_allocated;
}
static void fentry_init_sock(struct sock *sk, bool *activated,
long *memory_allocated, int *memory_per_cpu_fw_alloc)
{
if (!*activated)
return;
*memory_allocated = get_memory_allocated(sk, memory_per_cpu_fw_alloc);
*activated = false;
}
SEC("fentry/tcp_init_sock")
int BPF_PROG(fentry_tcp_init_sock, struct sock *sk)
{
fentry_init_sock(sk, &tcp_activated,
&tcp_memory_allocated, &tcp_memory_per_cpu_fw_alloc);
return 0;
}
SEC("fentry/udp_init_sock")
int BPF_PROG(fentry_udp_init_sock, struct sock *sk)
{
fentry_init_sock(sk, &udp_activated,
&udp_memory_allocated, &udp_memory_per_cpu_fw_alloc);
return 0;
}
SEC("cgroup/sock_create")
int sock_create(struct bpf_sock *ctx)
{
int err, val = 1;
err = bpf_setsockopt(ctx, SOL_SOCKET, SK_BPF_BYPASS_PROT_MEM,
&val, sizeof(val));
if (err)
goto err;
val = 0;
err = bpf_getsockopt(ctx, SOL_SOCKET, SK_BPF_BYPASS_PROT_MEM,
&val, sizeof(val));
if (err)
goto err;
if (val != 1) {
err = -EINVAL;
goto err;
}
return 1;
err:
bpf_set_retval(err);
return 0;
}
char LICENSE[] SEC("license") = "GPL";