diff options
| -rw-r--r-- | include/net/scm.h | 22 | ||||
| -rw-r--r-- | net/unix/af_unix.c | 45 | 
2 files changed, 48 insertions, 19 deletions
diff --git a/include/net/scm.h b/include/net/scm.h index 745460fa2f0..68e1e481658 100644 --- a/include/net/scm.h +++ b/include/net/scm.h @@ -53,6 +53,14 @@ static __inline__ void scm_set_cred(struct scm_cookie *scm,  	cred_to_ucred(pid, cred, &scm->creds);  } +static __inline__ void scm_set_cred_noref(struct scm_cookie *scm, +				    struct pid *pid, const struct cred *cred) +{ +	scm->pid  = pid; +	scm->cred = cred; +	cred_to_ucred(pid, cred, &scm->creds); +} +  static __inline__ void scm_destroy_cred(struct scm_cookie *scm)  {  	put_pid(scm->pid); @@ -70,6 +78,15 @@ static __inline__ void scm_destroy(struct scm_cookie *scm)  		__scm_destroy(scm);  } +static __inline__ void scm_release(struct scm_cookie *scm) +{ +	/* keep ref on pid and cred */ +	scm->pid = NULL; +	scm->cred = NULL; +	if (scm->fp) +		__scm_destroy(scm); +} +  static __inline__ int scm_send(struct socket *sock, struct msghdr *msg,  			       struct scm_cookie *scm)  { @@ -108,15 +125,14 @@ static __inline__ void scm_recv(struct socket *sock, struct msghdr *msg,  	if (!msg->msg_control) {  		if (test_bit(SOCK_PASSCRED, &sock->flags) || scm->fp)  			msg->msg_flags |= MSG_CTRUNC; -		scm_destroy(scm); +		if (scm && scm->fp) +			__scm_destroy(scm);  		return;  	}  	if (test_bit(SOCK_PASSCRED, &sock->flags))  		put_cmsg(msg, SOL_SOCKET, SCM_CREDENTIALS, sizeof(scm->creds), &scm->creds); -	scm_destroy_cred(scm); -  	scm_passec(sock, msg, scm);  	if (!scm->fp) diff --git a/net/unix/af_unix.c b/net/unix/af_unix.c index ec68e1c05b8..e6d9d1014ed 100644 --- a/net/unix/af_unix.c +++ b/net/unix/af_unix.c @@ -1378,11 +1378,17 @@ static int unix_attach_fds(struct scm_cookie *scm, struct sk_buff *skb)  	return max_level;  } -static int unix_scm_to_skb(struct scm_cookie *scm, struct sk_buff *skb, bool send_fds) +static int unix_scm_to_skb(struct scm_cookie *scm, struct sk_buff *skb, +			   bool send_fds, bool ref)  {  	int err = 0; -	UNIXCB(skb).pid  = get_pid(scm->pid); -	UNIXCB(skb).cred = get_cred(scm->cred); +	if (ref) { +		UNIXCB(skb).pid  = get_pid(scm->pid); +		UNIXCB(skb).cred = get_cred(scm->cred); +	} else { +		UNIXCB(skb).pid  = scm->pid; +		UNIXCB(skb).cred = scm->cred; +	}  	UNIXCB(skb).fp = NULL;  	if (scm->fp && send_fds)  		err = unix_attach_fds(scm, skb); @@ -1407,7 +1413,7 @@ static int unix_dgram_sendmsg(struct kiocb *kiocb, struct socket *sock,  	int namelen = 0; /* fake GCC */  	int err;  	unsigned hash; -	struct sk_buff *skb; +	struct sk_buff *skb = NULL;  	long timeo;  	struct scm_cookie tmp_scm;  	int max_level; @@ -1448,7 +1454,7 @@ static int unix_dgram_sendmsg(struct kiocb *kiocb, struct socket *sock,  	if (skb == NULL)  		goto out; -	err = unix_scm_to_skb(siocb->scm, skb, true); +	err = unix_scm_to_skb(siocb->scm, skb, true, false);  	if (err < 0)  		goto out_free;  	max_level = err + 1; @@ -1544,7 +1550,7 @@ restart:  	unix_state_unlock(other);  	other->sk_data_ready(other, len);  	sock_put(other); -	scm_destroy(siocb->scm); +	scm_release(siocb->scm);  	return len;  out_unlock: @@ -1554,7 +1560,8 @@ out_free:  out:  	if (other)  		sock_put(other); -	scm_destroy(siocb->scm); +	if (skb == NULL) +		scm_destroy(siocb->scm);  	return err;  } @@ -1566,7 +1573,7 @@ static int unix_stream_sendmsg(struct kiocb *kiocb, struct socket *sock,  	struct sock *sk = sock->sk;  	struct sock *other = NULL;  	int err, size; -	struct sk_buff *skb; +	struct sk_buff *skb = NULL;  	int sent = 0;  	struct scm_cookie tmp_scm;  	bool fds_sent = false; @@ -1631,11 +1638,11 @@ static int unix_stream_sendmsg(struct kiocb *kiocb, struct socket *sock,  		size = min_t(int, size, skb_tailroom(skb)); -		/* Only send the fds in the first buffer */ -		err = unix_scm_to_skb(siocb->scm, skb, !fds_sent); +		/* Only send the fds and no ref to pid in the first buffer */ +		err = unix_scm_to_skb(siocb->scm, skb, !fds_sent, fds_sent);  		if (err < 0) {  			kfree_skb(skb); -			goto out_err; +			goto out;  		}  		max_level = err + 1;  		fds_sent = true; @@ -1643,7 +1650,7 @@ static int unix_stream_sendmsg(struct kiocb *kiocb, struct socket *sock,  		err = memcpy_fromiovec(skb_put(skb, size), msg->msg_iov, size);  		if (err) {  			kfree_skb(skb); -			goto out_err; +			goto out;  		}  		unix_state_lock(other); @@ -1660,7 +1667,10 @@ static int unix_stream_sendmsg(struct kiocb *kiocb, struct socket *sock,  		sent += size;  	} -	scm_destroy(siocb->scm); +	if (skb) +		scm_release(siocb->scm); +	else +		scm_destroy(siocb->scm);  	siocb->scm = NULL;  	return sent; @@ -1673,7 +1683,9 @@ pipe_err:  		send_sig(SIGPIPE, current, 0);  	err = -EPIPE;  out_err: -	scm_destroy(siocb->scm); +	if (skb == NULL) +		scm_destroy(siocb->scm); +out:  	siocb->scm = NULL;  	return sent ? : err;  } @@ -1777,7 +1789,7 @@ static int unix_dgram_recvmsg(struct kiocb *iocb, struct socket *sock,  		siocb->scm = &tmp_scm;  		memset(&tmp_scm, 0, sizeof(tmp_scm));  	} -	scm_set_cred(siocb->scm, UNIXCB(skb).pid, UNIXCB(skb).cred); +	scm_set_cred_noref(siocb->scm, UNIXCB(skb).pid, UNIXCB(skb).cred);  	unix_set_secdata(siocb->scm, skb);  	if (!(flags & MSG_PEEK)) { @@ -1939,7 +1951,8 @@ static int unix_stream_recvmsg(struct kiocb *iocb, struct socket *sock,  			}  		} else {  			/* Copy credentials */ -			scm_set_cred(siocb->scm, UNIXCB(skb).pid, UNIXCB(skb).cred); +			scm_set_cred_noref(siocb->scm, UNIXCB(skb).pid, +					   UNIXCB(skb).cred);  			check_creds = 1;  		}  |