diff options
Diffstat (limited to 'drivers/vhost/vhost.c')
| -rw-r--r-- | drivers/vhost/vhost.c | 51 | 
1 files changed, 49 insertions, 2 deletions
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index 8b5a1b33d0f..94701ff3a23 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -212,6 +212,45 @@ static int vhost_worker(void *data)  	}  } +/* Helper to allocate iovec buffers for all vqs. */ +static long vhost_dev_alloc_iovecs(struct vhost_dev *dev) +{ +	int i; +	for (i = 0; i < dev->nvqs; ++i) { +		dev->vqs[i].indirect = kmalloc(sizeof *dev->vqs[i].indirect * +					       UIO_MAXIOV, GFP_KERNEL); +		dev->vqs[i].log = kmalloc(sizeof *dev->vqs[i].log * UIO_MAXIOV, +					  GFP_KERNEL); +		dev->vqs[i].heads = kmalloc(sizeof *dev->vqs[i].heads * +					    UIO_MAXIOV, GFP_KERNEL); + +		if (!dev->vqs[i].indirect || !dev->vqs[i].log || +			!dev->vqs[i].heads) +			goto err_nomem; +	} +	return 0; +err_nomem: +	for (; i >= 0; --i) { +		kfree(dev->vqs[i].indirect); +		kfree(dev->vqs[i].log); +		kfree(dev->vqs[i].heads); +	} +	return -ENOMEM; +} + +static void vhost_dev_free_iovecs(struct vhost_dev *dev) +{ +	int i; +	for (i = 0; i < dev->nvqs; ++i) { +		kfree(dev->vqs[i].indirect); +		dev->vqs[i].indirect = NULL; +		kfree(dev->vqs[i].log); +		dev->vqs[i].log = NULL; +		kfree(dev->vqs[i].heads); +		dev->vqs[i].heads = NULL; +	} +} +  long vhost_dev_init(struct vhost_dev *dev,  		    struct vhost_virtqueue *vqs, int nvqs)  { @@ -229,6 +268,9 @@ long vhost_dev_init(struct vhost_dev *dev,  	dev->worker = NULL;  	for (i = 0; i < dev->nvqs; ++i) { +		dev->vqs[i].log = NULL; +		dev->vqs[i].indirect = NULL; +		dev->vqs[i].heads = NULL;  		dev->vqs[i].dev = dev;  		mutex_init(&dev->vqs[i].mutex);  		vhost_vq_reset(dev, dev->vqs + i); @@ -295,6 +337,10 @@ static long vhost_dev_set_owner(struct vhost_dev *dev)  	if (err)  		goto err_cgroup; +	err = vhost_dev_alloc_iovecs(dev); +	if (err) +		goto err_cgroup; +  	return 0;  err_cgroup:  	kthread_stop(worker); @@ -345,6 +391,7 @@ void vhost_dev_cleanup(struct vhost_dev *dev)  			fput(dev->vqs[i].call);  		vhost_vq_reset(dev, dev->vqs + i);  	} +	vhost_dev_free_iovecs(dev);  	if (dev->log_ctx)  		eventfd_ctx_put(dev->log_ctx);  	dev->log_ctx = NULL; @@ -372,7 +419,7 @@ static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz)  	/* Make sure 64 bit math will not overflow. */  	if (a > ULONG_MAX - (unsigned long)log_base ||  	    a + (unsigned long)log_base > ULONG_MAX) -		return -EFAULT; +		return 0;  	return access_ok(VERIFY_WRITE, log_base + a,  			 (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8); @@ -957,7 +1004,7 @@ static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,  	}  	ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect, -			     ARRAY_SIZE(vq->indirect)); +			     UIO_MAXIOV);  	if (unlikely(ret < 0)) {  		vq_err(vq, "Translation failure %d in indirect.\n", ret);  		return ret;  |