diff --git a/vm-virtio/Cargo.toml b/vm-virtio/Cargo.toml index e18ff585b..7c193a90c 100644 --- a/vm-virtio/Cargo.toml +++ b/vm-virtio/Cargo.toml @@ -27,8 +27,8 @@ vmm-sys-util = ">=0.3.1" [dependencies.vhost_rs] path = "../vhost_rs" -features = ["vhost-user-master"] +features = ["vhost-user-master", "vhost-user-slave"] [dependencies.vm-memory] git = "https://github.com/rust-vmm/vm-memory" -features = ["backend-mmap"] +features = ["backend-mmap", "backend-atomic"] diff --git a/vm-virtio/src/lib.rs b/vm-virtio/src/lib.rs index 9154eb549..7d0bb2794 100755 --- a/vm-virtio/src/lib.rs +++ b/vm-virtio/src/lib.rs @@ -32,7 +32,7 @@ mod iommu; pub mod net; pub mod net_util; mod pmem; -mod queue; +pub mod queue; mod rng; pub mod vsock; diff --git a/vm-virtio/src/queue.rs b/vm-virtio/src/queue.rs index aec691350..29a1ac99f 100644 --- a/vm-virtio/src/queue.rs +++ b/vm-virtio/src/queue.rs @@ -9,6 +9,8 @@ // SPDX-License-Identifier: Apache-2.0 AND BSD-3-Clause use std::cmp::min; +use std::convert::TryInto; +use std::fmt::{self, Display}; use std::num::Wrapping; use std::sync::atomic::{fence, Ordering}; use std::sync::Arc; @@ -20,6 +22,26 @@ use vm_memory::{ pub(super) const VIRTQ_DESC_F_NEXT: u16 = 0x1; pub(super) const VIRTQ_DESC_F_WRITE: u16 = 0x2; +pub(super) const VIRTQ_DESC_F_INDIRECT: u16 = 0x4; + +#[derive(Debug)] +pub enum Error { + GuestMemoryError, + InvalidIndirectDescriptor, + InvalidChain, +} + +impl Display for Error { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + use self::Error::*; + + match self { + GuestMemoryError => write!(f, "error accessing guest memory"), + InvalidChain => write!(f, "invalid descriptor chain"), + InvalidIndirectDescriptor => write!(f, "invalid indirect descriptor"), + } + } +} // GuestMemoryMmap::read_obj() will be used to fetch the descriptor, // which has an explicit constraint that the entire descriptor doesn't @@ -76,7 +98,7 @@ unsafe impl ByteValued for Descriptor {} #[derive(Clone)] pub struct DescriptorChain<'a> { desc_table: GuestAddress, - queue_size: u16, + table_size: u16, ttl: u16, // used to prevent infinite chain cycles iommu_mapping_cb: Option>, @@ -104,11 +126,11 @@ impl<'a> DescriptorChain<'a> { pub fn checked_new( mem: &GuestMemoryMmap, desc_table: GuestAddress, - queue_size: u16, + table_size: u16, index: u16, iommu_mapping_cb: Option>, ) -> Option { - if index >= queue_size { + if index >= table_size { return None; } @@ -138,8 +160,8 @@ impl<'a> DescriptorChain<'a> { let chain = DescriptorChain { mem, desc_table, - queue_size, - ttl: queue_size, + table_size, + ttl: table_size, index, addr: GuestAddress(desc_addr), len: desc.len, @@ -155,12 +177,59 @@ impl<'a> DescriptorChain<'a> { } } + pub fn new_from_indirect(&self) -> Result { + if !self.is_indirect() { + return Err(Error::InvalidIndirectDescriptor); + } + + let desc_head = self.addr; + self.mem + .checked_offset(desc_head, 16) + .ok_or(Error::GuestMemoryError)?; + + // These reads can't fail unless Guest memory is hopelessly broken. + let desc = match self.mem.read_obj::(desc_head) { + Ok(ret) => ret, + Err(_) => return Err(Error::GuestMemoryError), + }; + + // Translate address if necessary + let (desc_addr, iommu_mapping_cb) = + if let Some(iommu_mapping_cb) = self.iommu_mapping_cb.clone() { + ( + (iommu_mapping_cb)(desc.addr).unwrap(), + Some(iommu_mapping_cb), + ) + } else { + (desc.addr, None) + }; + + let chain = DescriptorChain { + mem: self.mem, + desc_table: self.addr, + table_size: (self.len / 16).try_into().unwrap(), + ttl: (self.len / 16).try_into().unwrap(), + index: 0, + addr: GuestAddress(desc_addr), + len: desc.len, + flags: desc.flags, + next: desc.next, + iommu_mapping_cb, + }; + + if !chain.is_valid() { + return Err(Error::InvalidChain); + } + + Ok(chain) + } + fn is_valid(&self) -> bool { !(self .mem .checked_offset(self.addr, self.len as usize) .is_none() - || (self.has_next() && self.next >= self.queue_size)) + || (self.has_next() && self.next >= self.table_size)) } /// Gets if this descriptor chain has another descriptor chain linked after it. @@ -176,6 +245,10 @@ impl<'a> DescriptorChain<'a> { self.flags & VIRTQ_DESC_F_WRITE != 0 } + pub fn is_indirect(&self) -> bool { + self.flags & VIRTQ_DESC_F_INDIRECT != 0 + } + /// Gets the next descriptor in this descriptor chain, if there is one. /// /// Note that this is distinct from the next descriptor chain returned by `AvailIter`, which is @@ -185,7 +258,7 @@ impl<'a> DescriptorChain<'a> { DescriptorChain::checked_new( self.mem, self.desc_table, - self.queue_size, + self.table_size, self.next, self.iommu_mapping_cb.clone(), ) @@ -825,8 +898,8 @@ pub(crate) mod tests { assert_eq!(c.mem as *const GuestMemoryMmap, m as *const GuestMemoryMmap); assert_eq!(c.desc_table, vq.start()); - assert_eq!(c.queue_size, 16); - assert_eq!(c.ttl, c.queue_size); + assert_eq!(c.table_size, 16); + assert_eq!(c.ttl, c.table_size); assert_eq!(c.index, 0); assert_eq!(c.addr, GuestAddress(0x1000)); assert_eq!(c.len, 0x1000); @@ -837,6 +910,37 @@ pub(crate) mod tests { } } + #[test] + fn test_new_from_descriptor_chain() { + let m = &GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x10000)]).unwrap(); + let vq = VirtQueue::new(GuestAddress(0), m, 16); + + // create a chain with a descriptor pointing to an indirect table + vq.dtable[0].addr.set(0x1000); + vq.dtable[0].len.set(0x1000); + vq.dtable[0].next.set(0); + vq.dtable[0].flags.set(VIRTQ_DESC_F_INDIRECT); + + let c = DescriptorChain::checked_new(m, vq.start(), 16, 0, None).unwrap(); + assert!(c.is_indirect()); + + // create an indirect table with 4 chained descriptors + let mut indirect_table = Vec::with_capacity(4 as usize); + for j in 0..4 { + let desc = VirtqDesc::new(GuestAddress(0x1000 + (j * 16)), m); + desc.set(0x1000, 0x1000, VIRTQ_DESC_F_NEXT, (j + 1) as u16); + indirect_table.push(desc); + } + + // try to iterate through the indirect table descriptors + let mut i = c.new_from_indirect().unwrap(); + for j in 0..4 { + assert_eq!(i.flags, VIRTQ_DESC_F_NEXT); + assert_eq!(i.next, j + 1); + i = i.next_descriptor().unwrap(); + } + } + #[test] fn test_queue_and_iterator() { let m = &GuestMemoryMmap::from_ranges(&[(GuestAddress(0), 0x10000)]).unwrap();