diff options
Diffstat (limited to 'net/ipv4/inet_diag.c')
| -rw-r--r-- | net/ipv4/inet_diag.c | 146 | 
1 files changed, 78 insertions, 68 deletions
diff --git a/net/ipv4/inet_diag.c b/net/ipv4/inet_diag.c index 46d1e7199a8..570e61f9611 100644 --- a/net/ipv4/inet_diag.c +++ b/net/ipv4/inet_diag.c @@ -46,9 +46,6 @@ struct inet_diag_entry {  	u16 userlocks;  }; -#define INET_DIAG_PUT(skb, attrtype, attrlen) \ -	RTA_DATA(__RTA_PUT(skb, attrtype, attrlen)) -  static DEFINE_MUTEX(inet_diag_table_mutex);  static const struct inet_diag_handler *inet_diag_lock_handler(int proto) @@ -78,24 +75,22 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,  	const struct inet_sock *inet = inet_sk(sk);  	struct inet_diag_msg *r;  	struct nlmsghdr  *nlh; +	struct nlattr *attr;  	void *info = NULL; -	struct inet_diag_meminfo  *minfo = NULL; -	unsigned char	 *b = skb_tail_pointer(skb);  	const struct inet_diag_handler *handler;  	int ext = req->idiag_ext;  	handler = inet_diag_table[req->sdiag_protocol];  	BUG_ON(handler == NULL); -	nlh = NLMSG_PUT(skb, pid, seq, unlh->nlmsg_type, sizeof(*r)); -	nlh->nlmsg_flags = nlmsg_flags; +	nlh = nlmsg_put(skb, pid, seq, unlh->nlmsg_type, sizeof(*r), +			nlmsg_flags); +	if (!nlh) +		return -EMSGSIZE; -	r = NLMSG_DATA(nlh); +	r = nlmsg_data(nlh);  	BUG_ON(sk->sk_state == TCP_TIME_WAIT); -	if (ext & (1 << (INET_DIAG_MEMINFO - 1))) -		minfo = INET_DIAG_PUT(skb, INET_DIAG_MEMINFO, sizeof(*minfo)); -  	r->idiag_family = sk->sk_family;  	r->idiag_state = sk->sk_state;  	r->idiag_timer = 0; @@ -113,7 +108,8 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,  	 * hence this needs to be included regardless of socket family.  	 */  	if (ext & (1 << (INET_DIAG_TOS - 1))) -		RTA_PUT_U8(skb, INET_DIAG_TOS, inet->tos); +		if (nla_put_u8(skb, INET_DIAG_TOS, inet->tos) < 0) +			goto errout;  #if IS_ENABLED(CONFIG_IPV6)  	if (r->idiag_family == AF_INET6) { @@ -121,24 +117,31 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,  		*(struct in6_addr *)r->id.idiag_src = np->rcv_saddr;  		*(struct in6_addr *)r->id.idiag_dst = np->daddr; +  		if (ext & (1 << (INET_DIAG_TCLASS - 1))) -			RTA_PUT_U8(skb, INET_DIAG_TCLASS, np->tclass); +			if (nla_put_u8(skb, INET_DIAG_TCLASS, np->tclass) < 0) +				goto errout;  	}  #endif  	r->idiag_uid = sock_i_uid(sk);  	r->idiag_inode = sock_i_ino(sk); -	if (minfo) { -		minfo->idiag_rmem = sk_rmem_alloc_get(sk); -		minfo->idiag_wmem = sk->sk_wmem_queued; -		minfo->idiag_fmem = sk->sk_forward_alloc; -		minfo->idiag_tmem = sk_wmem_alloc_get(sk); +	if (ext & (1 << (INET_DIAG_MEMINFO - 1))) { +		struct inet_diag_meminfo minfo = { +			.idiag_rmem = sk_rmem_alloc_get(sk), +			.idiag_wmem = sk->sk_wmem_queued, +			.idiag_fmem = sk->sk_forward_alloc, +			.idiag_tmem = sk_wmem_alloc_get(sk), +		}; + +		if (nla_put(skb, INET_DIAG_MEMINFO, sizeof(minfo), &minfo) < 0) +			goto errout;  	}  	if (ext & (1 << (INET_DIAG_SKMEMINFO - 1)))  		if (sock_diag_put_meminfo(sk, skb, INET_DIAG_SKMEMINFO)) -			goto rtattr_failure; +			goto errout;  	if (icsk == NULL) {  		handler->idiag_get_info(sk, r, NULL); @@ -165,16 +168,20 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,  	}  #undef EXPIRES_IN_MS -	if (ext & (1 << (INET_DIAG_INFO - 1))) -		info = INET_DIAG_PUT(skb, INET_DIAG_INFO, sizeof(struct tcp_info)); - -	if ((ext & (1 << (INET_DIAG_CONG - 1))) && icsk->icsk_ca_ops) { -		const size_t len = strlen(icsk->icsk_ca_ops->name); +	if (ext & (1 << (INET_DIAG_INFO - 1))) { +		attr = nla_reserve(skb, INET_DIAG_INFO, +				   sizeof(struct tcp_info)); +		if (!attr) +			goto errout; -		strcpy(INET_DIAG_PUT(skb, INET_DIAG_CONG, len + 1), -		       icsk->icsk_ca_ops->name); +		info = nla_data(attr);  	} +	if ((ext & (1 << (INET_DIAG_CONG - 1))) && icsk->icsk_ca_ops) +		if (nla_put_string(skb, INET_DIAG_CONG, +				   icsk->icsk_ca_ops->name) < 0) +			goto errout; +  	handler->idiag_get_info(sk, r, info);  	if (sk->sk_state < TCP_TIME_WAIT && @@ -182,12 +189,10 @@ int inet_sk_diag_fill(struct sock *sk, struct inet_connection_sock *icsk,  		icsk->icsk_ca_ops->get_info(sk, ext, skb);  out: -	nlh->nlmsg_len = skb_tail_pointer(skb) - b; -	return skb->len; +	return nlmsg_end(skb, nlh); -rtattr_failure: -nlmsg_failure: -	nlmsg_trim(skb, b); +errout: +	nlmsg_cancel(skb, nlh);  	return -EMSGSIZE;  }  EXPORT_SYMBOL_GPL(inet_sk_diag_fill); @@ -208,14 +213,15 @@ static int inet_twsk_diag_fill(struct inet_timewait_sock *tw,  {  	long tmo;  	struct inet_diag_msg *r; -	const unsigned char *previous_tail = skb_tail_pointer(skb); -	struct nlmsghdr *nlh = NLMSG_PUT(skb, pid, seq, -					 unlh->nlmsg_type, sizeof(*r)); +	struct nlmsghdr *nlh; -	r = NLMSG_DATA(nlh); -	BUG_ON(tw->tw_state != TCP_TIME_WAIT); +	nlh = nlmsg_put(skb, pid, seq, unlh->nlmsg_type, sizeof(*r), +			nlmsg_flags); +	if (!nlh) +		return -EMSGSIZE; -	nlh->nlmsg_flags = nlmsg_flags; +	r = nlmsg_data(nlh); +	BUG_ON(tw->tw_state != TCP_TIME_WAIT);  	tmo = tw->tw_ttd - jiffies;  	if (tmo < 0) @@ -245,11 +251,8 @@ static int inet_twsk_diag_fill(struct inet_timewait_sock *tw,  		*(struct in6_addr *)r->id.idiag_dst = tw6->tw_v6_daddr;  	}  #endif -	nlh->nlmsg_len = skb_tail_pointer(skb) - previous_tail; -	return skb->len; -nlmsg_failure: -	nlmsg_trim(skb, previous_tail); -	return -EMSGSIZE; + +	return nlmsg_end(skb, nlh);  }  static int sk_diag_fill(struct sock *sk, struct sk_buff *skb, @@ -269,16 +272,17 @@ int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *in_s  	int err;  	struct sock *sk;  	struct sk_buff *rep; +	struct net *net = sock_net(in_skb->sk);  	err = -EINVAL;  	if (req->sdiag_family == AF_INET) { -		sk = inet_lookup(&init_net, hashinfo, req->id.idiag_dst[0], +		sk = inet_lookup(net, hashinfo, req->id.idiag_dst[0],  				 req->id.idiag_dport, req->id.idiag_src[0],  				 req->id.idiag_sport, req->id.idiag_if);  	}  #if IS_ENABLED(CONFIG_IPV6)  	else if (req->sdiag_family == AF_INET6) { -		sk = inet6_lookup(&init_net, hashinfo, +		sk = inet6_lookup(net, hashinfo,  				  (struct in6_addr *)req->id.idiag_dst,  				  req->id.idiag_dport,  				  (struct in6_addr *)req->id.idiag_src, @@ -298,23 +302,23 @@ int inet_diag_dump_one_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *in_s  	if (err)  		goto out; -	err = -ENOMEM; -	rep = alloc_skb(NLMSG_SPACE((sizeof(struct inet_diag_msg) + -				     sizeof(struct inet_diag_meminfo) + -				     sizeof(struct tcp_info) + 64)), -			GFP_KERNEL); -	if (!rep) +	rep = nlmsg_new(sizeof(struct inet_diag_msg) + +			sizeof(struct inet_diag_meminfo) + +			sizeof(struct tcp_info) + 64, GFP_KERNEL); +	if (!rep) { +		err = -ENOMEM;  		goto out; +	}  	err = sk_diag_fill(sk, rep, req,  			   NETLINK_CB(in_skb).pid,  			   nlh->nlmsg_seq, 0, nlh);  	if (err < 0) {  		WARN_ON(err == -EMSGSIZE); -		kfree_skb(rep); +		nlmsg_free(rep);  		goto out;  	} -	err = netlink_unicast(sock_diag_nlsk, rep, NETLINK_CB(in_skb).pid, +	err = netlink_unicast(net->diag_nlsk, rep, NETLINK_CB(in_skb).pid,  			      MSG_DONTWAIT);  	if (err > 0)  		err = 0; @@ -592,15 +596,16 @@ static int inet_diag_fill_req(struct sk_buff *skb, struct sock *sk,  {  	const struct inet_request_sock *ireq = inet_rsk(req);  	struct inet_sock *inet = inet_sk(sk); -	unsigned char *b = skb_tail_pointer(skb);  	struct inet_diag_msg *r;  	struct nlmsghdr *nlh;  	long tmo; -	nlh = NLMSG_PUT(skb, pid, seq, unlh->nlmsg_type, sizeof(*r)); -	nlh->nlmsg_flags = NLM_F_MULTI; -	r = NLMSG_DATA(nlh); +	nlh = nlmsg_put(skb, pid, seq, unlh->nlmsg_type, sizeof(*r), +			NLM_F_MULTI); +	if (!nlh) +		return -EMSGSIZE; +	r = nlmsg_data(nlh);  	r->idiag_family = sk->sk_family;  	r->idiag_state = TCP_SYN_RECV;  	r->idiag_timer = 1; @@ -628,13 +633,8 @@ static int inet_diag_fill_req(struct sk_buff *skb, struct sock *sk,  		*(struct in6_addr *)r->id.idiag_dst = inet6_rsk(req)->rmt_addr;  	}  #endif -	nlh->nlmsg_len = skb_tail_pointer(skb) - b; - -	return skb->len; -nlmsg_failure: -	nlmsg_trim(skb, b); -	return -1; +	return nlmsg_end(skb, nlh);  }  static int inet_diag_dump_reqs(struct sk_buff *skb, struct sock *sk, @@ -725,6 +725,7 @@ void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,  {  	int i, num;  	int s_i, s_num; +	struct net *net = sock_net(skb->sk);  	s_i = cb->args[1];  	s_num = num = cb->args[2]; @@ -744,6 +745,9 @@ void inet_diag_dump_icsk(struct inet_hashinfo *hashinfo, struct sk_buff *skb,  			sk_nulls_for_each(sk, node, &ilb->head) {  				struct inet_sock *inet = inet_sk(sk); +				if (!net_eq(sock_net(sk), net)) +					continue; +  				if (num < s_num) {  					num++;  					continue; @@ -814,6 +818,8 @@ skip_listen_ht:  		sk_nulls_for_each(sk, node, &head->chain) {  			struct inet_sock *inet = inet_sk(sk); +			if (!net_eq(sock_net(sk), net)) +				continue;  			if (num < s_num)  				goto next_normal;  			if (!(r->idiag_states & (1 << sk->sk_state))) @@ -840,6 +846,8 @@ next_normal:  			inet_twsk_for_each(tw, node,  				    &head->twchain) { +				if (!net_eq(twsk_net(tw), net)) +					continue;  				if (num < s_num)  					goto next_dying; @@ -892,7 +900,7 @@ static int inet_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)  	if (nlmsg_attrlen(cb->nlh, hdrlen))  		bc = nlmsg_find_attr(cb->nlh, hdrlen, INET_DIAG_REQ_BYTECODE); -	return __inet_diag_dump(skb, cb, (struct inet_diag_req_v2 *)NLMSG_DATA(cb->nlh), bc); +	return __inet_diag_dump(skb, cb, nlmsg_data(cb->nlh), bc);  }  static inline int inet_diag_type2proto(int type) @@ -909,7 +917,7 @@ static inline int inet_diag_type2proto(int type)  static int inet_diag_dump_compat(struct sk_buff *skb, struct netlink_callback *cb)  { -	struct inet_diag_req *rc = NLMSG_DATA(cb->nlh); +	struct inet_diag_req *rc = nlmsg_data(cb->nlh);  	struct inet_diag_req_v2 req;  	struct nlattr *bc = NULL;  	int hdrlen = sizeof(struct inet_diag_req); @@ -929,7 +937,7 @@ static int inet_diag_dump_compat(struct sk_buff *skb, struct netlink_callback *c  static int inet_diag_get_exact_compat(struct sk_buff *in_skb,  			       const struct nlmsghdr *nlh)  { -	struct inet_diag_req *rc = NLMSG_DATA(nlh); +	struct inet_diag_req *rc = nlmsg_data(nlh);  	struct inet_diag_req_v2 req;  	req.sdiag_family = rc->idiag_family; @@ -944,6 +952,7 @@ static int inet_diag_get_exact_compat(struct sk_buff *in_skb,  static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)  {  	int hdrlen = sizeof(struct inet_diag_req); +	struct net *net = sock_net(skb->sk);  	if (nlh->nlmsg_type >= INET_DIAG_GETSOCK_MAX ||  	    nlmsg_len(nlh) < hdrlen) @@ -964,7 +973,7 @@ static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)  			struct netlink_dump_control c = {  				.dump = inet_diag_dump_compat,  			}; -			return netlink_dump_start(sock_diag_nlsk, skb, nlh, &c); +			return netlink_dump_start(net->diag_nlsk, skb, nlh, &c);  		}  	} @@ -974,6 +983,7 @@ static int inet_diag_rcv_msg_compat(struct sk_buff *skb, struct nlmsghdr *nlh)  static int inet_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)  {  	int hdrlen = sizeof(struct inet_diag_req_v2); +	struct net *net = sock_net(skb->sk);  	if (nlmsg_len(h) < hdrlen)  		return -EINVAL; @@ -992,11 +1002,11 @@ static int inet_diag_handler_dump(struct sk_buff *skb, struct nlmsghdr *h)  			struct netlink_dump_control c = {  				.dump = inet_diag_dump,  			}; -			return netlink_dump_start(sock_diag_nlsk, skb, h, &c); +			return netlink_dump_start(net->diag_nlsk, skb, h, &c);  		}  	} -	return inet_diag_get_exact(skb, h, (struct inet_diag_req_v2 *)NLMSG_DATA(h)); +	return inet_diag_get_exact(skb, h, nlmsg_data(h));  }  static const struct sock_diag_handler inet_diag_handler = {  |