diff options
Diffstat (limited to 'tools/lguest/lguest.c')
| -rw-r--r-- | tools/lguest/lguest.c | 84 | 
1 files changed, 35 insertions, 49 deletions
diff --git a/tools/lguest/lguest.c b/tools/lguest/lguest.c index fd2f9221b24..07a03452c22 100644 --- a/tools/lguest/lguest.c +++ b/tools/lguest/lguest.c @@ -179,29 +179,6 @@ static struct termios orig_term;  #define wmb() __asm__ __volatile__("" : : : "memory")  #define mb() __asm__ __volatile__("" : : : "memory") -/* - * Convert an iovec element to the given type. - * - * This is a fairly ugly trick: we need to know the size of the type and - * alignment requirement to check the pointer is kosher.  It's also nice to - * have the name of the type in case we report failure. - * - * Typing those three things all the time is cumbersome and error prone, so we - * have a macro which sets them all up and passes to the real function. - */ -#define convert(iov, type) \ -	((type *)_convert((iov), sizeof(type), __alignof__(type), #type)) - -static void *_convert(struct iovec *iov, size_t size, size_t align, -		      const char *name) -{ -	if (iov->iov_len != size) -		errx(1, "Bad iovec size %zu for %s", iov->iov_len, name); -	if ((unsigned long)iov->iov_base % align != 0) -		errx(1, "Bad alignment %p for %s", iov->iov_base, name); -	return iov->iov_base; -} -  /* Wrapper for the last available index.  Makes it easier to change. */  #define lg_last_avail(vq)	((vq)->last_avail_idx) @@ -228,7 +205,8 @@ static bool iov_empty(const struct iovec iov[], unsigned int num_iov)  }  /* Take len bytes from the front of this iovec. */ -static void iov_consume(struct iovec iov[], unsigned num_iov, unsigned len) +static void iov_consume(struct iovec iov[], unsigned num_iov, +			void *dest, unsigned len)  {  	unsigned int i; @@ -236,11 +214,16 @@ static void iov_consume(struct iovec iov[], unsigned num_iov, unsigned len)  		unsigned int used;  		used = iov[i].iov_len < len ? iov[i].iov_len : len; +		if (dest) { +			memcpy(dest, iov[i].iov_base, used); +			dest += used; +		}  		iov[i].iov_base += used;  		iov[i].iov_len -= used;  		len -= used;  	} -	assert(len == 0); +	if (len != 0) +		errx(1, "iovec too short!");  }  /* The device virtqueue descriptors are followed by feature bitmasks. */ @@ -864,7 +847,7 @@ static void console_output(struct virtqueue *vq)  			warn("Write to stdout gave %i (%d)", len, errno);  			break;  		} -		iov_consume(iov, out, len); +		iov_consume(iov, out, NULL, len);  	}  	/* @@ -1591,9 +1574,9 @@ static void blk_request(struct virtqueue *vq)  {  	struct vblk_info *vblk = vq->dev->priv;  	unsigned int head, out_num, in_num, wlen; -	int ret; +	int ret, i;  	u8 *in; -	struct virtio_blk_outhdr *out; +	struct virtio_blk_outhdr out;  	struct iovec iov[vq->vring.num];  	off64_t off; @@ -1603,32 +1586,36 @@ static void blk_request(struct virtqueue *vq)  	 */  	head = wait_for_vq_desc(vq, iov, &out_num, &in_num); -	/* -	 * Every block request should contain at least one output buffer -	 * (detailing the location on disk and the type of request) and one -	 * input buffer (to hold the result). -	 */ -	if (out_num == 0 || in_num == 0) -		errx(1, "Bad virtblk cmd %u out=%u in=%u", -		     head, out_num, in_num); +	/* Copy the output header from the front of the iov (adjusts iov) */ +	iov_consume(iov, out_num, &out, sizeof(out)); + +	/* Find and trim end of iov input array, for our status byte. */ +	in = NULL; +	for (i = out_num + in_num - 1; i >= out_num; i--) { +		if (iov[i].iov_len > 0) { +			in = iov[i].iov_base + iov[i].iov_len - 1; +			iov[i].iov_len--; +			break; +		} +	} +	if (!in) +		errx(1, "Bad virtblk cmd with no room for status"); -	out = convert(&iov[0], struct virtio_blk_outhdr); -	in = convert(&iov[out_num+in_num-1], u8);  	/*  	 * For historical reasons, block operations are expressed in 512 byte  	 * "sectors".  	 */ -	off = out->sector * 512; +	off = out.sector * 512;  	/*  	 * In general the virtio block driver is allowed to try SCSI commands.  	 * It'd be nice if we supported eject, for example, but we don't.  	 */ -	if (out->type & VIRTIO_BLK_T_SCSI_CMD) { +	if (out.type & VIRTIO_BLK_T_SCSI_CMD) {  		fprintf(stderr, "Scsi commands unsupported\n");  		*in = VIRTIO_BLK_S_UNSUPP;  		wlen = sizeof(*in); -	} else if (out->type & VIRTIO_BLK_T_OUT) { +	} else if (out.type & VIRTIO_BLK_T_OUT) {  		/*  		 * Write  		 * @@ -1636,10 +1623,10 @@ static void blk_request(struct virtqueue *vq)  		 * if they try to write past end.  		 */  		if (lseek64(vblk->fd, off, SEEK_SET) != off) -			err(1, "Bad seek to sector %llu", out->sector); +			err(1, "Bad seek to sector %llu", out.sector); -		ret = writev(vblk->fd, iov+1, out_num-1); -		verbose("WRITE to sector %llu: %i\n", out->sector, ret); +		ret = writev(vblk->fd, iov, out_num); +		verbose("WRITE to sector %llu: %i\n", out.sector, ret);  		/*  		 * Grr... Now we know how long the descriptor they sent was, we @@ -1655,7 +1642,7 @@ static void blk_request(struct virtqueue *vq)  		wlen = sizeof(*in);  		*in = (ret >= 0 ? VIRTIO_BLK_S_OK : VIRTIO_BLK_S_IOERR); -	} else if (out->type & VIRTIO_BLK_T_FLUSH) { +	} else if (out.type & VIRTIO_BLK_T_FLUSH) {  		/* Flush */  		ret = fdatasync(vblk->fd);  		verbose("FLUSH fdatasync: %i\n", ret); @@ -1669,10 +1656,9 @@ static void blk_request(struct virtqueue *vq)  		 * if they try to read past end.  		 */  		if (lseek64(vblk->fd, off, SEEK_SET) != off) -			err(1, "Bad seek to sector %llu", out->sector); +			err(1, "Bad seek to sector %llu", out.sector); -		ret = readv(vblk->fd, iov+1, in_num-1); -		verbose("READ from sector %llu: %i\n", out->sector, ret); +		ret = readv(vblk->fd, iov + out_num, in_num);  		if (ret >= 0) {  			wlen = sizeof(*in) + ret;  			*in = VIRTIO_BLK_S_OK; @@ -1758,7 +1744,7 @@ static void rng_input(struct virtqueue *vq)  		len = readv(rng_info->rfd, iov, in_num);  		if (len <= 0)  			err(1, "Read from /dev/random gave %i", len); -		iov_consume(iov, in_num, len); +		iov_consume(iov, in_num, NULL, len);  		totlen += len;  	}  |