diff --git a/Cargo.toml b/Cargo.toml index 835be1f..f568ab9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,9 +10,9 @@ license = "Apache-2.0" [dependencies] libc = ">=0.2.39" log = ">=0.4.6" -vhost = { version = "0.2", features = ["vhost-user-slave"] } +vhost = { version = "0.3", features = ["vhost-user-slave"] } virtio-bindings = "0.1" -virtio-queue = { git = "https://github.com/rust-vmm/vm-virtio", rev = "cc1fa35" } +virtio-queue = "0.1" vm-memory = {version = "0.7", features = ["backend-mmap", "backend-atomic"]} vmm-sys-util = "0.9" diff --git a/src/vring.rs b/src/vring.rs index 7594cca..374890e 100644 --- a/src/vring.rs +++ b/src/vring.rs @@ -16,62 +16,29 @@ use virtio_queue::{Error as VirtQueError, Queue}; use vm_memory::{GuestAddress, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryMmap}; use vmm_sys_util::eventfd::EventFd; -/// Struct to hold a shared reference to the underlying `VringState` object. -pub enum VringStateGuard<'a, M: GuestAddressSpace> { - /// A `MutexGuard` for a `VringState` object. - MutexGuard(MutexGuard<'a, VringState>), - /// A `ReadGuard` for a `VringState` object. - RwLockReadGuard(RwLockReadGuard<'a, VringState>), +/// Trait for objects returned by `VringT::get_ref()`. +pub trait VringStateGuard<'a, M: GuestAddressSpace> { + /// Type for guard returned by `VringT::get_ref()`. + type G: Deref>; } -impl<'a, M: GuestAddressSpace> Deref for VringStateGuard<'a, M> { - type Target = VringState; - - fn deref(&self) -> &Self::Target { - match self { - VringStateGuard::MutexGuard(v) => v.deref(), - VringStateGuard::RwLockReadGuard(v) => v.deref(), - } - } +/// Trait for objects returned by `VringT::get_mut()`. +pub trait VringStateMutGuard<'a, M: GuestAddressSpace> { + /// Type for guard returned by `VringT::get_mut()`. + type G: DerefMut>; } -/// Struct to hold an exclusive reference to the underlying `VringState` object. -pub enum VringStateMutGuard<'a, M: GuestAddressSpace> { - /// A `MutexGuard` for a `VringState` object. - MutexGuard(MutexGuard<'a, VringState>), - /// A `WriteGuard` for a `VringState` object. - RwLockWriteGuard(RwLockWriteGuard<'a, VringState>), -} - -impl<'a, M: GuestAddressSpace> Deref for VringStateMutGuard<'a, M> { - type Target = VringState; - - fn deref(&self) -> &Self::Target { - match self { - VringStateMutGuard::MutexGuard(v) => v.deref(), - VringStateMutGuard::RwLockWriteGuard(v) => v.deref(), - } - } -} - -impl<'a, M: GuestAddressSpace> DerefMut for VringStateMutGuard<'a, M> { - fn deref_mut(&mut self) -> &mut Self::Target { - match self { - VringStateMutGuard::MutexGuard(v) => v.deref_mut(), - VringStateMutGuard::RwLockWriteGuard(v) => v.deref_mut(), - } - } -} - -pub trait VringT { +pub trait VringT: + for<'a> VringStateGuard<'a, M> + for<'a> VringStateMutGuard<'a, M> +{ /// Create a new instance of Vring. fn new(mem: M, max_queue_size: u16) -> Self; /// Get an immutable reference to the kick event fd. - fn get_ref(&self) -> VringStateGuard; + fn get_ref(&self) -> >::G; /// Get a mutable reference to the kick event fd. - fn get_mut(&self) -> VringStateMutGuard; + fn get_mut(&self) -> >::G; /// Add an used descriptor into the used queue. fn add_used(&self, desc_index: u16, len: u32) -> Result<(), VirtQueError>; @@ -276,19 +243,27 @@ impl VringMutex { } } -impl VringT for VringMutex { +impl<'a, M: 'a + GuestAddressSpace> VringStateGuard<'a, M> for VringMutex { + type G = MutexGuard<'a, VringState>; +} + +impl<'a, M: 'a + GuestAddressSpace> VringStateMutGuard<'a, M> for VringMutex { + type G = MutexGuard<'a, VringState>; +} + +impl VringT for VringMutex { fn new(mem: M, max_queue_size: u16) -> Self { VringMutex { state: Arc::new(Mutex::new(VringState::new(mem, max_queue_size))), } } - fn get_ref(&self) -> VringStateGuard { - VringStateGuard::MutexGuard(self.state.lock().unwrap()) + fn get_ref(&self) -> >::G { + self.state.lock().unwrap() } - fn get_mut(&self) -> VringStateMutGuard { - VringStateMutGuard::MutexGuard(self.lock()) + fn get_mut(&self) -> >::G { + self.lock() } fn add_used(&self, desc_index: u16, len: u32) -> Result<(), VirtQueError> { @@ -370,19 +345,27 @@ impl VringRwLock { } } -impl VringT for VringRwLock { +impl<'a, M: 'a + GuestAddressSpace> VringStateGuard<'a, M> for VringRwLock { + type G = RwLockReadGuard<'a, VringState>; +} + +impl<'a, M: 'a + GuestAddressSpace> VringStateMutGuard<'a, M> for VringRwLock { + type G = RwLockWriteGuard<'a, VringState>; +} + +impl VringT for VringRwLock { fn new(mem: M, max_queue_size: u16) -> Self { VringRwLock { state: Arc::new(RwLock::new(VringState::new(mem, max_queue_size))), } } - fn get_ref(&self) -> VringStateGuard { - VringStateGuard::RwLockReadGuard(self.state.read().unwrap()) + fn get_ref(&self) -> >::G { + self.state.read().unwrap() } - fn get_mut(&self) -> VringStateMutGuard { - VringStateMutGuard::RwLockWriteGuard(self.write_lock()) + fn get_mut(&self) -> >::G { + self.write_lock() } fn add_used(&self, desc_index: u16, len: u32) -> Result<(), VirtQueError> { @@ -467,7 +450,7 @@ mod tests { let vring = VringMutex::new(mem, 0x1000); assert!(vring.get_ref().get_kick().is_none()); - assert_eq!(vring.get_ref().enabled, false); + assert_eq!(vring.get_mut().enabled, false); assert_eq!(vring.lock().queue.ready(), false); assert_eq!(vring.lock().queue.state.event_idx_enabled, false); @@ -514,7 +497,7 @@ mod tests { let eventfd = EventFd::new(0).unwrap(); let file = unsafe { File::from_raw_fd(eventfd.as_raw_fd()) }; - assert!(vring.get_ref().kick.is_none()); + assert!(vring.get_mut().kick.is_none()); assert_eq!(vring.read_kick().unwrap(), true); vring.set_kick(Some(file)); eventfd.write(1).unwrap();