Merge branch 'mpls-remove-rtnl-dependency'

Kuniyuki Iwashima says:

====================
mpls: Remove RTNL dependency.

MPLS uses RTNL for two purposes:

  1) to guarantee the lifetime of struct mpls_nh.nh_dev
  2) to protect net->mpls.platform_label

However, neither actually requires RTNL.

If struct mpls_nh holds a refcnt for nh_dev, RTNL is no longer needed to
guarantee its lifetime, and the protection of net->mpls.platform_label
can be provided by a dedicated mutex instead of RTNL.

The series removes RTNL from net/mpls/.

Overview:

  Patch 1 is misc cleanup.

  Patches 2 - 9 are prep to drop RTNL for RTM_{NEW,DEL,GET}ROUTE
  handlers.

  Patches 10 & 11 convert mpls_dump_routes() and RTM_GETNETCONF to RCU.

  Patch 12 replaces RTNL with a new per-netns mutex.

  Patch 13 drops RTNL from RTM_{NEW,DEL,GET}ROUTE.
====================

Link: https://patch.msgid.link/20251029173344.2934622-1-kuniyu@google.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
This commit is contained in:
Jakub Kicinski
2025-11-03 17:40:59 -08:00
5 changed files with 224 additions and 128 deletions

View File

@@ -347,6 +347,11 @@ static inline struct inet6_dev *__in6_dev_get(const struct net_device *dev)
return rcu_dereference_rtnl(dev->ip6_ptr);
}
/* Return the inet6_dev attached to @dev (dev->ip6_ptr), fetched with plain
 * rcu_dereference(). Unlike __in6_dev_get() above, which uses
 * rcu_dereference_rtnl(), holding RTNL is not accepted as a substitute here;
 * presumably the caller must be inside an RCU read-side critical section.
 */
static inline struct inet6_dev *in6_dev_rcu(const struct net_device *dev)
{
return rcu_dereference(dev->ip6_ptr);
}
static inline struct inet6_dev *__in6_dev_get_rtnl_net(const struct net_device *dev)
{
return rtnl_net_dereference(dev_net(dev), dev->ip6_ptr);

View File

@@ -16,6 +16,7 @@ struct netns_mpls {
int default_ttl;
size_t platform_labels;
struct mpls_route __rcu * __rcu *platform_label;
struct mutex platform_mutex;
struct ctl_table_header *ctl;
};

View File

@@ -75,16 +75,23 @@ static void rtmsg_lfib(int event, u32 label, struct mpls_route *rt,
struct nlmsghdr *nlh, struct net *net, u32 portid,
unsigned int nlm_flags);
static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned index)
static struct mpls_route *mpls_route_input(struct net *net, unsigned int index)
{
struct mpls_route *rt = NULL;
struct mpls_route __rcu **platform_label;
if (index < net->mpls.platform_labels) {
struct mpls_route __rcu **platform_label =
rcu_dereference_rtnl(net->mpls.platform_label);
rt = rcu_dereference_rtnl(platform_label[index]);
}
return rt;
platform_label = mpls_dereference(net, net->mpls.platform_label);
return mpls_dereference(net, platform_label[index]);
}
/* Look up the route installed for label @index in @net's platform label
 * table, using plain rcu_dereference() for both the table pointer and the
 * entry (presumably to be called from an RCU read-side critical section).
 *
 * Returns NULL when @index is not below net->mpls.platform_labels; otherwise
 * returns whatever pointer is stored in platform_label[index].
 */
static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned int index)
{
struct mpls_route __rcu **platform_label;
/* Bounds check against the current table size before indexing. */
if (index >= net->mpls.platform_labels)
return NULL;
platform_label = rcu_dereference(net->mpls.platform_label);
return rcu_dereference(platform_label[index]);
}
bool mpls_output_possible(const struct net_device *dev)
@@ -129,25 +136,26 @@ bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu)
}
EXPORT_SYMBOL_GPL(mpls_pkt_too_big);
void mpls_stats_inc_outucastpkts(struct net_device *dev,
void mpls_stats_inc_outucastpkts(struct net *net,
struct net_device *dev,
const struct sk_buff *skb)
{
struct mpls_dev *mdev;
if (skb->protocol == htons(ETH_P_MPLS_UC)) {
mdev = mpls_dev_get(dev);
mdev = mpls_dev_rcu(dev);
if (mdev)
MPLS_INC_STATS_LEN(mdev, skb->len,
tx_packets,
tx_bytes);
} else if (skb->protocol == htons(ETH_P_IP)) {
IP_UPD_PO_STATS(dev_net(dev), IPSTATS_MIB_OUT, skb->len);
IP_UPD_PO_STATS(net, IPSTATS_MIB_OUT, skb->len);
#if IS_ENABLED(CONFIG_IPV6)
} else if (skb->protocol == htons(ETH_P_IPV6)) {
struct inet6_dev *in6dev = __in6_dev_get(dev);
struct inet6_dev *in6dev = in6_dev_rcu(dev);
if (in6dev)
IP6_UPD_PO_STATS(dev_net(dev), in6dev,
IP6_UPD_PO_STATS(net, in6dev,
IPSTATS_MIB_OUT, skb->len);
#endif
}
@@ -342,7 +350,7 @@ static bool mpls_egress(struct net *net, struct mpls_route *rt,
static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
struct packet_type *pt, struct net_device *orig_dev)
{
struct net *net = dev_net(dev);
struct net *net = dev_net_rcu(dev);
struct mpls_shim_hdr *hdr;
const struct mpls_nh *nh;
struct mpls_route *rt;
@@ -357,7 +365,7 @@ static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
/* Careful this entire function runs inside of an rcu critical section */
mdev = mpls_dev_get(dev);
mdev = mpls_dev_rcu(dev);
if (!mdev)
goto drop;
@@ -434,7 +442,7 @@ static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
dec.ttl -= 1;
if (unlikely(!new_header_size && dec.bos)) {
/* Penultimate hop popping */
if (!mpls_egress(dev_net(out_dev), rt, skb, dec))
if (!mpls_egress(net, rt, skb, dec))
goto err;
} else {
bool bos;
@@ -451,7 +459,7 @@ static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
}
}
mpls_stats_inc_outucastpkts(out_dev, skb);
mpls_stats_inc_outucastpkts(net, out_dev, skb);
/* If via wasn't specified then send out using device address */
if (nh->nh_via_table == MPLS_NEIGH_TABLE_UNSPEC)
@@ -466,7 +474,7 @@ static int mpls_forward(struct sk_buff *skb, struct net_device *dev,
return 0;
tx_err:
out_mdev = out_dev ? mpls_dev_get(out_dev) : NULL;
out_mdev = out_dev ? mpls_dev_rcu(out_dev) : NULL;
if (out_mdev)
MPLS_INC_STATS(out_mdev, tx_errors);
goto drop;
@@ -530,10 +538,23 @@ static struct mpls_route *mpls_rt_alloc(u8 num_nh, u8 max_alen, u8 max_labels)
return rt;
}
/* RCU callback that releases an mpls_route.
 *
 * @head is the rt_rcu member embedded in the route. For every nexthop of the
 * route, the reference on nh->nh_dev is dropped via netdev_put() together
 * with its nh_dev_tracker, and then the route itself is kfree()d.
 */
static void mpls_rt_free_rcu(struct rcu_head *head)
{
struct mpls_route *rt;
/* Recover the enclosing route from its embedded rcu_head. */
rt = container_of(head, struct mpls_route, rt_rcu);
/* NOTE(review): 'nh' is presumably declared by change_nexthops() — confirm. */
change_nexthops(rt) {
netdev_put(nh->nh_dev, &nh->nh_dev_tracker);
} endfor_nexthops(rt);
kfree(rt);
}
static void mpls_rt_free(struct mpls_route *rt)
{
if (rt)
kfree_rcu(rt, rt_rcu);
call_rcu(&rt->rt_rcu, mpls_rt_free_rcu);
}
static void mpls_notify_route(struct net *net, unsigned index,
@@ -557,10 +578,8 @@ static void mpls_route_update(struct net *net, unsigned index,
struct mpls_route __rcu **platform_label;
struct mpls_route *rt;
ASSERT_RTNL();
platform_label = rtnl_dereference(net->mpls.platform_label);
rt = rtnl_dereference(platform_label[index]);
platform_label = mpls_dereference(net, net->mpls.platform_label);
rt = mpls_dereference(net, platform_label[index]);
rcu_assign_pointer(platform_label[index], new);
mpls_notify_route(net, index, rt, new, info);
@@ -569,24 +588,23 @@ static void mpls_route_update(struct net *net, unsigned index,
mpls_rt_free(rt);
}
static unsigned find_free_label(struct net *net)
static unsigned int find_free_label(struct net *net)
{
struct mpls_route __rcu **platform_label;
size_t platform_labels;
unsigned index;
unsigned int index;
platform_label = rtnl_dereference(net->mpls.platform_label);
platform_labels = net->mpls.platform_labels;
for (index = MPLS_LABEL_FIRST_UNRESERVED; index < platform_labels;
for (index = MPLS_LABEL_FIRST_UNRESERVED;
index < net->mpls.platform_labels;
index++) {
if (!rtnl_dereference(platform_label[index]))
if (!mpls_route_input(net, index))
return index;
}
return LABEL_NOT_SPECIFIED;
}
#if IS_ENABLED(CONFIG_INET)
static struct net_device *inet_fib_lookup_dev(struct net *net,
struct mpls_nh *nh,
const void *addr)
{
struct net_device *dev;
@@ -599,14 +617,14 @@ static struct net_device *inet_fib_lookup_dev(struct net *net,
return ERR_CAST(rt);
dev = rt->dst.dev;
dev_hold(dev);
netdev_hold(dev, &nh->nh_dev_tracker, GFP_KERNEL);
ip_rt_put(rt);
return dev;
}
#else
static struct net_device *inet_fib_lookup_dev(struct net *net,
struct mpls_nh *nh,
const void *addr)
{
return ERR_PTR(-EAFNOSUPPORT);
@@ -615,6 +633,7 @@ static struct net_device *inet_fib_lookup_dev(struct net *net,
#if IS_ENABLED(CONFIG_IPV6)
static struct net_device *inet6_fib_lookup_dev(struct net *net,
struct mpls_nh *nh,
const void *addr)
{
struct net_device *dev;
@@ -631,13 +650,14 @@ static struct net_device *inet6_fib_lookup_dev(struct net *net,
return ERR_CAST(dst);
dev = dst->dev;
dev_hold(dev);
netdev_hold(dev, &nh->nh_dev_tracker, GFP_KERNEL);
dst_release(dst);
return dev;
}
#else
static struct net_device *inet6_fib_lookup_dev(struct net *net,
struct mpls_nh *nh,
const void *addr)
{
return ERR_PTR(-EAFNOSUPPORT);
@@ -653,16 +673,17 @@ static struct net_device *find_outdev(struct net *net,
if (!oif) {
switch (nh->nh_via_table) {
case NEIGH_ARP_TABLE:
dev = inet_fib_lookup_dev(net, mpls_nh_via(rt, nh));
dev = inet_fib_lookup_dev(net, nh, mpls_nh_via(rt, nh));
break;
case NEIGH_ND_TABLE:
dev = inet6_fib_lookup_dev(net, mpls_nh_via(rt, nh));
dev = inet6_fib_lookup_dev(net, nh, mpls_nh_via(rt, nh));
break;
case NEIGH_LINK_TABLE:
break;
}
} else {
dev = dev_get_by_index(net, oif);
dev = netdev_get_by_index(net, oif,
&nh->nh_dev_tracker, GFP_KERNEL);
}
if (!dev)
@@ -671,8 +692,7 @@ static struct net_device *find_outdev(struct net *net,
if (IS_ERR(dev))
return dev;
/* The caller is holding rtnl anyways, so release the dev reference */
dev_put(dev);
nh->nh_dev = dev;
return dev;
}
@@ -686,20 +706,17 @@ static int mpls_nh_assign_dev(struct net *net, struct mpls_route *rt,
dev = find_outdev(net, rt, nh, oif);
if (IS_ERR(dev)) {
err = PTR_ERR(dev);
dev = NULL;
goto errout;
}
/* Ensure this is a supported device */
err = -EINVAL;
if (!mpls_dev_get(dev))
goto errout;
if (!mpls_dev_get(net, dev))
goto errout_put;
if ((nh->nh_via_table == NEIGH_LINK_TABLE) &&
(dev->addr_len != nh->nh_via_alen))
goto errout;
nh->nh_dev = dev;
goto errout_put;
if (!(dev->flags & IFF_UP)) {
nh->nh_flags |= RTNH_F_DEAD;
@@ -713,6 +730,9 @@ static int mpls_nh_assign_dev(struct net *net, struct mpls_route *rt,
return 0;
errout_put:
netdev_put(nh->nh_dev, &nh->nh_dev_tracker);
nh->nh_dev = NULL;
errout:
return err;
}
@@ -890,7 +910,8 @@ static int mpls_nh_build_multi(struct mpls_route_config *cfg,
struct nlattr *nla_via, *nla_newdst;
int remaining = cfg->rc_mp_len;
int err = 0;
u8 nhs = 0;
rt->rt_nhn = 0;
change_nexthops(rt) {
int attrlen;
@@ -926,11 +947,9 @@ static int mpls_nh_build_multi(struct mpls_route_config *cfg,
rt->rt_nhn_alive--;
rtnh = rtnh_next(rtnh, &remaining);
nhs++;
rt->rt_nhn++;
} endfor_nexthops(rt);
rt->rt_nhn = nhs;
return 0;
errout:
@@ -940,30 +959,28 @@ errout:
static bool mpls_label_ok(struct net *net, unsigned int *index,
struct netlink_ext_ack *extack)
{
bool is_ok = true;
/* Reserved labels may not be set */
if (*index < MPLS_LABEL_FIRST_UNRESERVED) {
NL_SET_ERR_MSG(extack,
"Invalid label - must be MPLS_LABEL_FIRST_UNRESERVED or higher");
is_ok = false;
return false;
}
/* The full 20 bit range may not be supported. */
if (is_ok && *index >= net->mpls.platform_labels) {
if (*index >= net->mpls.platform_labels) {
NL_SET_ERR_MSG(extack,
"Label >= configured maximum in platform_labels");
is_ok = false;
return false;
}
*index = array_index_nospec(*index, net->mpls.platform_labels);
return is_ok;
return true;
}
static int mpls_route_add(struct mpls_route_config *cfg,
struct netlink_ext_ack *extack)
{
struct mpls_route __rcu **platform_label;
struct net *net = cfg->rc_nlinfo.nl_net;
struct mpls_route *rt, *old;
int err = -EINVAL;
@@ -991,8 +1008,7 @@ static int mpls_route_add(struct mpls_route_config *cfg,
}
err = -EEXIST;
platform_label = rtnl_dereference(net->mpls.platform_label);
old = rtnl_dereference(platform_label[index]);
old = mpls_route_input(net, index);
if ((cfg->rc_nlflags & NLM_F_EXCL) && old)
goto errout;
@@ -1103,7 +1119,7 @@ static int mpls_fill_stats_af(struct sk_buff *skb,
struct mpls_dev *mdev;
struct nlattr *nla;
mdev = mpls_dev_get(dev);
mdev = mpls_dev_rcu(dev);
if (!mdev)
return -ENODATA;
@@ -1123,7 +1139,7 @@ static size_t mpls_get_stats_af_size(const struct net_device *dev)
{
struct mpls_dev *mdev;
mdev = mpls_dev_get(dev);
mdev = mpls_dev_rcu(dev);
if (!mdev)
return 0;
@@ -1264,23 +1280,32 @@ static int mpls_netconf_get_devconf(struct sk_buff *in_skb,
if (err < 0)
goto errout;
err = -EINVAL;
if (!tb[NETCONFA_IFINDEX])
if (!tb[NETCONFA_IFINDEX]) {
err = -EINVAL;
goto errout;
}
ifindex = nla_get_s32(tb[NETCONFA_IFINDEX]);
dev = __dev_get_by_index(net, ifindex);
if (!dev)
goto errout;
mdev = mpls_dev_get(dev);
if (!mdev)
goto errout;
err = -ENOBUFS;
skb = nlmsg_new(mpls_netconf_msgsize_devconf(NETCONFA_ALL), GFP_KERNEL);
if (!skb)
if (!skb) {
err = -ENOBUFS;
goto errout;
}
rcu_read_lock();
dev = dev_get_by_index_rcu(net, ifindex);
if (!dev) {
err = -EINVAL;
goto errout_unlock;
}
mdev = mpls_dev_rcu(dev);
if (!mdev) {
err = -EINVAL;
goto errout_unlock;
}
err = mpls_netconf_fill_devconf(skb, mdev,
NETLINK_CB(in_skb).portid,
@@ -1289,12 +1314,19 @@ static int mpls_netconf_get_devconf(struct sk_buff *in_skb,
if (err < 0) {
/* -EMSGSIZE implies BUG in mpls_netconf_msgsize_devconf() */
WARN_ON(err == -EMSGSIZE);
kfree_skb(skb);
goto errout;
goto errout_unlock;
}
err = rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid);
rcu_read_unlock();
errout:
return err;
errout_unlock:
rcu_read_unlock();
kfree_skb(skb);
goto errout;
}
static int mpls_netconf_dump_devconf(struct sk_buff *skb,
@@ -1326,7 +1358,7 @@ static int mpls_netconf_dump_devconf(struct sk_buff *skb,
rcu_read_lock();
for_each_netdev_dump(net, dev, ctx->ifindex) {
mdev = mpls_dev_get(dev);
mdev = mpls_dev_rcu(dev);
if (!mdev)
continue;
err = mpls_netconf_fill_devconf(skb, mdev,
@@ -1438,8 +1470,6 @@ static struct mpls_dev *mpls_add_dev(struct net_device *dev)
int err = -ENOMEM;
int i;
ASSERT_RTNL();
mdev = kzalloc(sizeof(*mdev), GFP_KERNEL);
if (!mdev)
return ERR_PTR(err);
@@ -1481,16 +1511,15 @@ static void mpls_dev_destroy_rcu(struct rcu_head *head)
static int mpls_ifdown(struct net_device *dev, int event)
{
struct mpls_route __rcu **platform_label;
struct net *net = dev_net(dev);
unsigned index;
unsigned int index;
platform_label = rtnl_dereference(net->mpls.platform_label);
for (index = 0; index < net->mpls.platform_labels; index++) {
struct mpls_route *rt = rtnl_dereference(platform_label[index]);
struct mpls_route *rt;
bool nh_del = false;
u8 alive = 0;
rt = mpls_route_input(net, index);
if (!rt)
continue;
@@ -1524,8 +1553,12 @@ static int mpls_ifdown(struct net_device *dev, int event)
change_nexthops(rt) {
unsigned int nh_flags = nh->nh_flags;
if (nh->nh_dev != dev)
if (nh->nh_dev != dev) {
if (nh_del)
netdev_hold(nh->nh_dev, &nh->nh_dev_tracker,
GFP_KERNEL);
goto next;
}
switch (event) {
case NETDEV_DOWN:
@@ -1557,15 +1590,14 @@ next:
static void mpls_ifup(struct net_device *dev, unsigned int flags)
{
struct mpls_route __rcu **platform_label;
struct net *net = dev_net(dev);
unsigned index;
unsigned int index;
u8 alive;
platform_label = rtnl_dereference(net->mpls.platform_label);
for (index = 0; index < net->mpls.platform_labels; index++) {
struct mpls_route *rt = rtnl_dereference(platform_label[index]);
struct mpls_route *rt;
rt = mpls_route_input(net, index);
if (!rt)
continue;
@@ -1592,28 +1624,33 @@ static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
void *ptr)
{
struct net_device *dev = netdev_notifier_info_to_dev(ptr);
struct net *net = dev_net(dev);
struct mpls_dev *mdev;
unsigned int flags;
int err;
mutex_lock(&net->mpls.platform_mutex);
if (event == NETDEV_REGISTER) {
mdev = mpls_add_dev(dev);
if (IS_ERR(mdev))
return notifier_from_errno(PTR_ERR(mdev));
if (IS_ERR(mdev)) {
err = PTR_ERR(mdev);
goto err;
}
return NOTIFY_OK;
goto out;
}
mdev = mpls_dev_get(dev);
mdev = mpls_dev_get(net, dev);
if (!mdev)
return NOTIFY_OK;
goto out;
switch (event) {
case NETDEV_DOWN:
err = mpls_ifdown(dev, event);
if (err)
return notifier_from_errno(err);
goto err;
break;
case NETDEV_UP:
flags = netif_get_flags(dev);
@@ -1629,14 +1666,15 @@ static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
} else {
err = mpls_ifdown(dev, event);
if (err)
return notifier_from_errno(err);
goto err;
}
break;
case NETDEV_UNREGISTER:
err = mpls_ifdown(dev, event);
if (err)
return notifier_from_errno(err);
mdev = mpls_dev_get(dev);
goto err;
mdev = mpls_dev_get(net, dev);
if (mdev) {
mpls_dev_sysctl_unregister(dev, mdev);
RCU_INIT_POINTER(dev->mpls_ptr, NULL);
@@ -1644,16 +1682,23 @@ static int mpls_dev_notify(struct notifier_block *this, unsigned long event,
}
break;
case NETDEV_CHANGENAME:
mdev = mpls_dev_get(dev);
mdev = mpls_dev_get(net, dev);
if (mdev) {
mpls_dev_sysctl_unregister(dev, mdev);
err = mpls_dev_sysctl_register(dev, mdev);
if (err)
return notifier_from_errno(err);
goto err;
}
break;
}
out:
mutex_unlock(&net->mpls.platform_mutex);
return NOTIFY_OK;
err:
mutex_unlock(&net->mpls.platform_mutex);
return notifier_from_errno(err);
}
static struct notifier_block mpls_dev_notifier = {
@@ -1928,6 +1973,7 @@ errout:
static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh,
struct netlink_ext_ack *extack)
{
struct net *net = sock_net(skb->sk);
struct mpls_route_config *cfg;
int err;
@@ -1939,7 +1985,9 @@ static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh,
if (err < 0)
goto out;
mutex_lock(&net->mpls.platform_mutex);
err = mpls_route_del(cfg, extack);
mutex_unlock(&net->mpls.platform_mutex);
out:
kfree(cfg);
@@ -1950,6 +1998,7 @@ out:
static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh,
struct netlink_ext_ack *extack)
{
struct net *net = sock_net(skb->sk);
struct mpls_route_config *cfg;
int err;
@@ -1961,7 +2010,9 @@ static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh,
if (err < 0)
goto out;
mutex_lock(&net->mpls.platform_mutex);
err = mpls_route_add(cfg, extack);
mutex_unlock(&net->mpls.platform_mutex);
out:
kfree(cfg);
@@ -2124,7 +2175,7 @@ static int mpls_valid_fib_dump_req(struct net *net, const struct nlmsghdr *nlh,
if (i == RTA_OIF) {
ifindex = nla_get_u32(tb[i]);
filter->dev = __dev_get_by_index(net, ifindex);
filter->dev = dev_get_by_index_rcu(net, ifindex);
if (!filter->dev)
return -ENODEV;
filter->filter_set = 1;
@@ -2162,20 +2213,19 @@ static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
struct net *net = sock_net(skb->sk);
struct mpls_route __rcu **platform_label;
struct fib_dump_filter filter = {
.rtnl_held = true,
.rtnl_held = false,
};
unsigned int flags = NLM_F_MULTI;
size_t platform_labels;
unsigned int index;
int err;
ASSERT_RTNL();
rcu_read_lock();
if (cb->strict_check) {
int err;
err = mpls_valid_fib_dump_req(net, nlh, &filter, cb);
if (err < 0)
return err;
goto err;
/* for MPLS, there is only 1 table with fixed type and flags.
* If either are set in the filter then return nothing.
@@ -2183,14 +2233,14 @@ static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
if ((filter.table_id && filter.table_id != RT_TABLE_MAIN) ||
(filter.rt_type && filter.rt_type != RTN_UNICAST) ||
filter.flags)
return skb->len;
goto unlock;
}
index = cb->args[0];
if (index < MPLS_LABEL_FIRST_UNRESERVED)
index = MPLS_LABEL_FIRST_UNRESERVED;
platform_label = rtnl_dereference(net->mpls.platform_label);
platform_label = rcu_dereference(net->mpls.platform_label);
platform_labels = net->mpls.platform_labels;
if (filter.filter_set)
@@ -2199,7 +2249,7 @@ static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
for (; index < platform_labels; index++) {
struct mpls_route *rt;
rt = rtnl_dereference(platform_label[index]);
rt = rcu_dereference(platform_label[index]);
if (!rt)
continue;
@@ -2214,7 +2264,13 @@ static int mpls_dump_routes(struct sk_buff *skb, struct netlink_callback *cb)
}
cb->args[0] = index;
unlock:
rcu_read_unlock();
return skb->len;
err:
rcu_read_unlock();
return err;
}
static inline size_t lfib_nlmsg_size(struct mpls_route *rt)
@@ -2345,18 +2401,20 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
u32 portid = NETLINK_CB(in_skb).portid;
u32 in_label = LABEL_NOT_SPECIFIED;
struct nlattr *tb[RTA_MAX + 1];
struct mpls_route *rt = NULL;
u32 labels[MAX_NEW_LABELS];
struct mpls_shim_hdr *hdr;
unsigned int hdr_size = 0;
const struct mpls_nh *nh;
struct net_device *dev;
struct mpls_route *rt;
struct rtmsg *rtm, *r;
struct nlmsghdr *nlh;
struct sk_buff *skb;
u8 n_labels;
int err;
mutex_lock(&net->mpls.platform_mutex);
err = mpls_valid_getroute_req(in_skb, in_nlh, tb, extack);
if (err < 0)
goto errout;
@@ -2378,7 +2436,8 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
}
}
rt = mpls_route_input_rcu(net, in_label);
if (in_label < net->mpls.platform_labels)
rt = mpls_route_input(net, in_label);
if (!rt) {
err = -ENETUNREACH;
goto errout;
@@ -2399,7 +2458,8 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
goto errout_free;
}
return rtnl_unicast(skb, net, portid);
err = rtnl_unicast(skb, net, portid);
goto errout;
}
if (tb[RTA_NEWDST]) {
@@ -2491,12 +2551,14 @@ static int mpls_getroute(struct sk_buff *in_skb, struct nlmsghdr *in_nlh,
err = rtnl_unicast(skb, net, portid);
errout:
mutex_unlock(&net->mpls.platform_mutex);
return err;
nla_put_failure:
nlmsg_cancel(skb, nlh);
err = -EMSGSIZE;
errout_free:
mutex_unlock(&net->mpls.platform_mutex);
kfree_skb(skb);
return err;
}
@@ -2519,10 +2581,13 @@ static int resize_platform_label_table(struct net *net, size_t limit)
/* In case the predefined labels need to be populated */
if (limit > MPLS_LABEL_IPV4NULL) {
struct net_device *lo = net->loopback_dev;
rt0 = mpls_rt_alloc(1, lo->addr_len, 0);
if (IS_ERR(rt0))
goto nort0;
rt0->rt_nh->nh_dev = lo;
netdev_hold(lo, &rt0->rt_nh->nh_dev_tracker, GFP_KERNEL);
rt0->rt_protocol = RTPROT_KERNEL;
rt0->rt_payload_type = MPT_IPV4;
rt0->rt_ttl_propagate = MPLS_TTL_PROP_DEFAULT;
@@ -2533,10 +2598,13 @@ static int resize_platform_label_table(struct net *net, size_t limit)
}
if (limit > MPLS_LABEL_IPV6NULL) {
struct net_device *lo = net->loopback_dev;
rt2 = mpls_rt_alloc(1, lo->addr_len, 0);
if (IS_ERR(rt2))
goto nort2;
rt2->rt_nh->nh_dev = lo;
netdev_hold(lo, &rt2->rt_nh->nh_dev_tracker, GFP_KERNEL);
rt2->rt_protocol = RTPROT_KERNEL;
rt2->rt_payload_type = MPT_IPV6;
rt2->rt_ttl_propagate = MPLS_TTL_PROP_DEFAULT;
@@ -2546,9 +2614,10 @@ static int resize_platform_label_table(struct net *net, size_t limit)
lo->addr_len);
}
rtnl_lock();
mutex_lock(&net->mpls.platform_mutex);
/* Remember the original table */
old = rtnl_dereference(net->mpls.platform_label);
old = mpls_dereference(net, net->mpls.platform_label);
old_limit = net->mpls.platform_labels;
/* Free any labels beyond the new table */
@@ -2579,7 +2648,7 @@ static int resize_platform_label_table(struct net *net, size_t limit)
net->mpls.platform_labels = limit;
rcu_assign_pointer(net->mpls.platform_label, labels);
rtnl_unlock();
mutex_unlock(&net->mpls.platform_mutex);
mpls_rt_free(rt2);
mpls_rt_free(rt0);
@@ -2652,12 +2721,13 @@ static const struct ctl_table mpls_table[] = {
},
};
static int mpls_net_init(struct net *net)
static __net_init int mpls_net_init(struct net *net)
{
size_t table_size = ARRAY_SIZE(mpls_table);
struct ctl_table *table;
int i;
mutex_init(&net->mpls.platform_mutex);
net->mpls.platform_labels = 0;
net->mpls.platform_label = NULL;
net->mpls.ip_ttl_propagate = 1;
@@ -2683,7 +2753,7 @@ static int mpls_net_init(struct net *net)
return 0;
}
static void mpls_net_exit(struct net *net)
static __net_exit void mpls_net_exit(struct net *net)
{
struct mpls_route __rcu **platform_label;
size_t platform_labels;
@@ -2703,16 +2773,20 @@ static void mpls_net_exit(struct net *net)
* As such no additional rcu synchronization is necessary when
* freeing the platform_label table.
*/
rtnl_lock();
platform_label = rtnl_dereference(net->mpls.platform_label);
mutex_lock(&net->mpls.platform_mutex);
platform_label = mpls_dereference(net, net->mpls.platform_label);
platform_labels = net->mpls.platform_labels;
for (index = 0; index < platform_labels; index++) {
struct mpls_route *rt = rtnl_dereference(platform_label[index]);
RCU_INIT_POINTER(platform_label[index], NULL);
struct mpls_route *rt;
rt = mpls_dereference(net, platform_label[index]);
mpls_notify_route(net, index, rt, NULL, NULL);
mpls_rt_free(rt);
}
rtnl_unlock();
mutex_unlock(&net->mpls.platform_mutex);
kvfree(platform_label);
}
@@ -2729,12 +2803,15 @@ static struct rtnl_af_ops mpls_af_ops __read_mostly = {
};
static const struct rtnl_msg_handler mpls_rtnl_msg_handlers[] __initdata_or_module = {
{THIS_MODULE, PF_MPLS, RTM_NEWROUTE, mpls_rtm_newroute, NULL, 0},
{THIS_MODULE, PF_MPLS, RTM_DELROUTE, mpls_rtm_delroute, NULL, 0},
{THIS_MODULE, PF_MPLS, RTM_GETROUTE, mpls_getroute, mpls_dump_routes, 0},
{THIS_MODULE, PF_MPLS, RTM_NEWROUTE, mpls_rtm_newroute, NULL,
RTNL_FLAG_DOIT_UNLOCKED},
{THIS_MODULE, PF_MPLS, RTM_DELROUTE, mpls_rtm_delroute, NULL,
RTNL_FLAG_DOIT_UNLOCKED},
{THIS_MODULE, PF_MPLS, RTM_GETROUTE, mpls_getroute, mpls_dump_routes,
RTNL_FLAG_DOIT_UNLOCKED | RTNL_FLAG_DUMP_UNLOCKED},
{THIS_MODULE, PF_MPLS, RTM_GETNETCONF,
mpls_netconf_get_devconf, mpls_netconf_dump_devconf,
RTNL_FLAG_DUMP_UNLOCKED},
RTNL_FLAG_DOIT_UNLOCKED | RTNL_FLAG_DUMP_UNLOCKED},
};
static int __init mpls_init(void)

View File

@@ -88,6 +88,7 @@ enum mpls_payload_type {
struct mpls_nh { /* next hop label forwarding entry */
struct net_device *nh_dev;
netdevice_tracker nh_dev_tracker;
/* nh_flags is accessed under RCU in the packet path; it is
* modified handling netdev events with rtnl lock held
@@ -184,9 +185,20 @@ static inline struct mpls_entry_decoded mpls_entry_decode(struct mpls_shim_hdr *
return result;
}
static inline struct mpls_dev *mpls_dev_get(const struct net_device *dev)
/* Dereference the RCU-protected pointer @p on the update side: expands to
 * rcu_dereference_protected() with "(net)->mpls.platform_mutex is held" as
 * the lockdep condition.
 */
#define mpls_dereference(net, p) \
rcu_dereference_protected( \
(p), \
lockdep_is_held(&(net)->mpls.platform_mutex))
static inline struct mpls_dev *mpls_dev_rcu(const struct net_device *dev)
{
return rcu_dereference_rtnl(dev->mpls_ptr);
return rcu_dereference(dev->mpls_ptr);
}
/* Return the mpls_dev attached to @dev (dev->mpls_ptr) via
 * mpls_dereference(), i.e. with @net's mpls.platform_mutex as the lockdep
 * condition rather than an RCU read-side section.
 */
static inline struct mpls_dev *mpls_dev_get(const struct net *net,
const struct net_device *dev)
{
return mpls_dereference(net, dev->mpls_ptr);
}
int nla_put_labels(struct sk_buff *skb, int attrtype, u8 labels,
@@ -196,7 +208,8 @@ int nla_get_labels(const struct nlattr *nla, u8 max_labels, u8 *labels,
bool mpls_output_possible(const struct net_device *dev);
unsigned int mpls_dev_mtu(const struct net_device *dev);
bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu);
void mpls_stats_inc_outucastpkts(struct net_device *dev,
void mpls_stats_inc_outucastpkts(struct net *net,
struct net_device *dev,
const struct sk_buff *skb);
#endif /* MPLS_INTERNAL_H */

View File

@@ -53,7 +53,7 @@ static int mpls_xmit(struct sk_buff *skb)
/* Find the output device */
out_dev = dst->dev;
net = dev_net(out_dev);
net = dev_net_rcu(out_dev);
if (!mpls_output_possible(out_dev) ||
!dst->lwtstate || skb_warn_if_lro(skb))
@@ -128,7 +128,7 @@ static int mpls_xmit(struct sk_buff *skb)
bos = false;
}
mpls_stats_inc_outucastpkts(out_dev, skb);
mpls_stats_inc_outucastpkts(net, out_dev, skb);
if (rt) {
if (rt->rt_gw_family == AF_INET6)
@@ -153,7 +153,7 @@ static int mpls_xmit(struct sk_buff *skb)
return LWTUNNEL_XMIT_DONE;
drop:
out_mdev = out_dev ? mpls_dev_get(out_dev) : NULL;
out_mdev = out_dev ? mpls_dev_rcu(out_dev) : NULL;
if (out_mdev)
MPLS_INC_STATS(out_mdev, tx_errors);
kfree_skb(skb);