diff options
Diffstat (limited to 'drivers/vhost/vhost.c')
| -rw-r--r-- | drivers/vhost/vhost.c | 232 | 
1 files changed, 198 insertions, 34 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 0b99783083f..e05557d5299 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -17,12 +17,13 @@  #include <linux/mm.h>  #include <linux/miscdevice.h>  #include <linux/mutex.h> -#include <linux/workqueue.h>  #include <linux/rcupdate.h>  #include <linux/poll.h>  #include <linux/file.h>  #include <linux/highmem.h>  #include <linux/slab.h> +#include <linux/kthread.h> +#include <linux/cgroup.h>  #include <linux/net.h>  #include <linux/if_packet.h> @@ -37,8 +38,6 @@ enum {  	VHOST_MEMORY_F_LOG = 0x1,  }; -static struct workqueue_struct *vhost_workqueue; -  static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,  			    poll_table *pt)  { @@ -52,23 +51,31 @@ static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,  static int vhost_poll_wakeup(wait_queue_t *wait, unsigned mode, int sync,  			     void *key)  { -	struct vhost_poll *poll; -	poll = container_of(wait, struct vhost_poll, wait); +	struct vhost_poll *poll = container_of(wait, struct vhost_poll, wait); +  	if (!((unsigned long)key & poll->mask))  		return 0; -	queue_work(vhost_workqueue, &poll->work); +	vhost_poll_queue(poll);  	return 0;  }  /* Init poll structure */ -void vhost_poll_init(struct vhost_poll *poll, work_func_t func, -		     unsigned long mask) +void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn, +		     unsigned long mask, struct vhost_dev *dev)  { -	INIT_WORK(&poll->work, func); +	struct vhost_work *work = &poll->work; +  	init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup);  	init_poll_funcptr(&poll->table, vhost_poll_func);  	poll->mask = mask; +	poll->dev = dev; + +	INIT_LIST_HEAD(&work->node); +	work->fn = fn; +	init_waitqueue_head(&work->done); +	work->flushing = 0; +	work->queue_seq = work->done_seq = 0;  }  /* Start polling a file. We add ourselves to file's wait queue. The caller must @@ -92,12 +99,40 @@ void vhost_poll_stop(struct vhost_poll *poll)   * locks that are also used by the callback. */  void vhost_poll_flush(struct vhost_poll *poll)  { -	flush_work(&poll->work); +	struct vhost_work *work = &poll->work; +	unsigned seq; +	int left; +	int flushing; + +	spin_lock_irq(&poll->dev->work_lock); +	seq = work->queue_seq; +	work->flushing++; +	spin_unlock_irq(&poll->dev->work_lock); +	wait_event(work->done, ({ +		   spin_lock_irq(&poll->dev->work_lock); +		   left = seq - work->done_seq <= 0; +		   spin_unlock_irq(&poll->dev->work_lock); +		   left; +	})); +	spin_lock_irq(&poll->dev->work_lock); +	flushing = --work->flushing; +	spin_unlock_irq(&poll->dev->work_lock); +	BUG_ON(flushing < 0);  }  void vhost_poll_queue(struct vhost_poll *poll)  { -	queue_work(vhost_workqueue, &poll->work); +	struct vhost_dev *dev = poll->dev; +	struct vhost_work *work = &poll->work; +	unsigned long flags; + +	spin_lock_irqsave(&dev->work_lock, flags); +	if (list_empty(&work->node)) { +		list_add_tail(&work->node, &dev->work_list); +		work->queue_seq++; +		wake_up_process(dev->worker); +	} +	spin_unlock_irqrestore(&dev->work_lock, flags);  }  static void vhost_vq_reset(struct vhost_dev *dev, @@ -114,7 +149,8 @@ static void vhost_vq_reset(struct vhost_dev *dev,  	vq->used_flags = 0;  	vq->log_used = false;  	vq->log_addr = -1ull; -	vq->hdr_size = 0; +	vq->vhost_hlen = 0; +	vq->sock_hlen = 0;  	vq->private_data = NULL;  	vq->log_base = NULL;  	vq->error_ctx = NULL; @@ -125,10 +161,51 @@ static void vhost_vq_reset(struct vhost_dev *dev,  	vq->log_ctx = NULL;  } +static int vhost_worker(void *data) +{ +	struct vhost_dev *dev = data; +	struct vhost_work *work = NULL; +	unsigned uninitialized_var(seq); + +	for (;;) { +		/* mb paired w/ kthread_stop */ +		set_current_state(TASK_INTERRUPTIBLE); + +		spin_lock_irq(&dev->work_lock); +		if (work) { +			work->done_seq = seq; +			if (work->flushing) +				wake_up_all(&work->done); +		} + +		if (kthread_should_stop()) { +			spin_unlock_irq(&dev->work_lock); +			__set_current_state(TASK_RUNNING); +			return 0; +		} +		if (!list_empty(&dev->work_list)) { +			work = list_first_entry(&dev->work_list, +						struct vhost_work, node); +			list_del_init(&work->node); +			seq = work->queue_seq; +		} else +			work = NULL; +		spin_unlock_irq(&dev->work_lock); + +		if (work) { +			__set_current_state(TASK_RUNNING); +			work->fn(work); +		} else +			schedule(); + +	} +} +  long vhost_dev_init(struct vhost_dev *dev,  		    struct vhost_virtqueue *vqs, int nvqs)  {  	int i; +  	dev->vqs = vqs;  	dev->nvqs = nvqs;  	mutex_init(&dev->mutex); @@ -136,6 +213,9 @@ long vhost_dev_init(struct vhost_dev *dev,  	dev->log_file = NULL;  	dev->memory = NULL;  	dev->mm = NULL; +	spin_lock_init(&dev->work_lock); +	INIT_LIST_HEAD(&dev->work_list); +	dev->worker = NULL;  	for (i = 0; i < dev->nvqs; ++i) {  		dev->vqs[i].dev = dev; @@ -143,9 +223,9 @@ long vhost_dev_init(struct vhost_dev *dev,  		vhost_vq_reset(dev, dev->vqs + i);  		if (dev->vqs[i].handle_kick)  			vhost_poll_init(&dev->vqs[i].poll, -					dev->vqs[i].handle_kick, -					POLLIN); +					dev->vqs[i].handle_kick, POLLIN, dev);  	} +  	return 0;  } @@ -159,12 +239,36 @@ long vhost_dev_check_owner(struct vhost_dev *dev)  /* Caller should have device mutex */  static long vhost_dev_set_owner(struct vhost_dev *dev)  { +	struct task_struct *worker; +	int err;  	/* Is there an owner already? */ -	if (dev->mm) -		return -EBUSY; +	if (dev->mm) { +		err = -EBUSY; +		goto err_mm; +	}  	/* No owner, become one */  	dev->mm = get_task_mm(current); +	worker = kthread_create(vhost_worker, dev, "vhost-%d", current->pid); +	if (IS_ERR(worker)) { +		err = PTR_ERR(worker); +		goto err_worker; +	} + +	dev->worker = worker; +	err = cgroup_attach_task_current_cg(worker); +	if (err) +		goto err_cgroup; +	wake_up_process(worker);	/* avoid contributing to loadavg */ +  	return 0; +err_cgroup: +	kthread_stop(worker); +err_worker: +	if (dev->mm) +		mmput(dev->mm); +	dev->mm = NULL; +err_mm: +	return err;  }  /* Caller should have device mutex */ @@ -217,6 +321,9 @@ void vhost_dev_cleanup(struct vhost_dev *dev)  	if (dev->mm)  		mmput(dev->mm);  	dev->mm = NULL; + +	WARN_ON(!list_empty(&dev->work_list)); +	kthread_stop(dev->worker);  }  static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz) @@ -237,8 +344,8 @@ static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem,  {  	int i; -        if (!mem) -                return 0; +	if (!mem) +		return 0;  	for (i = 0; i < mem->nregions; ++i) {  		struct vhost_memory_region *m = mem->regions + i; @@ -995,9 +1102,9 @@ int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,  }  /* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */ -void vhost_discard_vq_desc(struct vhost_virtqueue *vq) +void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n)  { -	vq->last_avail_idx--; +	vq->last_avail_idx -= n;  }  /* After we've used one of their buffers, we tell them about it.  We'll then @@ -1042,6 +1149,67 @@ int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)  	return 0;  } +static int __vhost_add_used_n(struct vhost_virtqueue *vq, +			    struct vring_used_elem *heads, +			    unsigned count) +{ +	struct vring_used_elem __user *used; +	int start; + +	start = vq->last_used_idx % vq->num; +	used = vq->used->ring + start; +	if (copy_to_user(used, heads, count * sizeof *used)) { +		vq_err(vq, "Failed to write used"); +		return -EFAULT; +	} +	if (unlikely(vq->log_used)) { +		/* Make sure data is seen before log. */ +		smp_wmb(); +		/* Log used ring entry write. */ +		log_write(vq->log_base, +			  vq->log_addr + +			   ((void __user *)used - (void __user *)vq->used), +			  count * sizeof *used); +	} +	vq->last_used_idx += count; +	return 0; +} + +/* After we've used one of their buffers, we tell them about it.  We'll then + * want to notify the guest, using eventfd. */ +int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads, +		     unsigned count) +{ +	int start, n, r; + +	start = vq->last_used_idx % vq->num; +	n = vq->num - start; +	if (n < count) { +		r = __vhost_add_used_n(vq, heads, n); +		if (r < 0) +			return r; +		heads += n; +		count -= n; +	} +	r = __vhost_add_used_n(vq, heads, count); + +	/* Make sure buffer is written before we update index. */ +	smp_wmb(); +	if (put_user(vq->last_used_idx, &vq->used->idx)) { +		vq_err(vq, "Failed to increment used idx"); +		return -EFAULT; +	} +	if (unlikely(vq->log_used)) { +		/* Log used index update. */ +		log_write(vq->log_base, +			  vq->log_addr + offsetof(struct vring_used, idx), +			  sizeof vq->used->idx); +		if (vq->log_ctx) +			eventfd_signal(vq->log_ctx, 1); +	} +	return r; +} +  /* This actually signals the guest, using eventfd. */  void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)  { @@ -1076,6 +1244,15 @@ void vhost_add_used_and_signal(struct vhost_dev *dev,  	vhost_signal(dev, vq);  } +/* multi-buffer version of vhost_add_used_and_signal */ +void vhost_add_used_and_signal_n(struct vhost_dev *dev, +				 struct vhost_virtqueue *vq, +				 struct vring_used_elem *heads, unsigned count) +{ +	vhost_add_used_n(vq, heads, count); +	vhost_signal(dev, vq); +} +  /* OK, now we need to know about added descriptors. */  bool vhost_enable_notify(struct vhost_virtqueue *vq)  { @@ -1100,7 +1277,7 @@ bool vhost_enable_notify(struct vhost_virtqueue *vq)  		return false;  	} -	return avail_idx != vq->last_avail_idx; +	return avail_idx != vq->avail_idx;  }  /* We don't need to be notified again. */ @@ -1115,16 +1292,3 @@ void vhost_disable_notify(struct vhost_virtqueue *vq)  		vq_err(vq, "Failed to enable notification at %p: %d\n",  		       &vq->used->flags, r);  } - -int vhost_init(void) -{ -	vhost_workqueue = create_singlethread_workqueue("vhost"); -	if (!vhost_workqueue) -		return -ENOMEM; -	return 0; -} - -void vhost_cleanup(void) -{ -	destroy_workqueue(vhost_workqueue); -}  |