From e5a5f1fe346c9d4c8ee58dad7a412646d353c557 Mon Sep 17 00:00:00 2001 From: Liu Jiang Date: Sat, 18 Dec 2021 12:41:48 +0800 Subject: [PATCH] Refine VringStateGuard and VringStateMutGuard Previously VringStateGuard and VringStateMutGuard are defined as enum, which limits the extensibility of the interface. So convert them into traits by using the High Rank Trait Bound tricky. Signed-off-by: Liu Jiang --- Cargo.toml | 4 +-- src/vring.rs | 99 ++++++++++++++++++++++------------------------------ 2 files changed, 43 insertions(+), 60 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 835be1f..f568ab9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,9 +10,9 @@ license = "Apache-2.0" [dependencies] libc = ">=0.2.39" log = ">=0.4.6" -vhost = { version = "0.2", features = ["vhost-user-slave"] } +vhost = { version = "0.3", features = ["vhost-user-slave"] } virtio-bindings = "0.1" -virtio-queue = { git = "https://github.com/rust-vmm/vm-virtio", rev = "cc1fa35" } +virtio-queue = "0.1" vm-memory = {version = "0.7", features = ["backend-mmap", "backend-atomic"]} vmm-sys-util = "0.9" diff --git a/src/vring.rs b/src/vring.rs index 7594cca..374890e 100644 --- a/src/vring.rs +++ b/src/vring.rs @@ -16,62 +16,29 @@ use virtio_queue::{Error as VirtQueError, Queue}; use vm_memory::{GuestAddress, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryMmap}; use vmm_sys_util::eventfd::EventFd; -/// Struct to hold a shared reference to the underlying `VringState` object. -pub enum VringStateGuard<'a, M: GuestAddressSpace> { - /// A `MutexGuard` for a `VringState` object. - MutexGuard(MutexGuard<'a, VringState>), - /// A `ReadGuard` for a `VringState` object. - RwLockReadGuard(RwLockReadGuard<'a, VringState>), +/// Trait for objects returned by `VringT::get_ref()`. +pub trait VringStateGuard<'a, M: GuestAddressSpace> { + /// Type for guard returned by `VringT::get_ref()`. + type G: Deref>; } -impl<'a, M: GuestAddressSpace> Deref for VringStateGuard<'a, M> { - type Target = VringState; - - fn deref(&self) -> &Self::Target { - match self { - VringStateGuard::MutexGuard(v) => v.deref(), - VringStateGuard::RwLockReadGuard(v) => v.deref(), - } - } +/// Trait for objects returned by `VringT::get_mut()`. +pub trait VringStateMutGuard<'a, M: GuestAddressSpace> { + /// Type for guard returned by `VringT::get_mut()`. + type G: DerefMut>; } -/// Struct to hold an exclusive reference to the underlying `VringState` object. -pub enum VringStateMutGuard<'a, M: GuestAddressSpace> { - /// A `MutexGuard` for a `VringState` object. - MutexGuard(MutexGuard<'a, VringState>), - /// A `WriteGuard` for a `VringState` object. - RwLockWriteGuard(RwLockWriteGuard<'a, VringState>), -} - -impl<'a, M: GuestAddressSpace> Deref for VringStateMutGuard<'a, M> { - type Target = VringState; - - fn deref(&self) -> &Self::Target { - match self { - VringStateMutGuard::MutexGuard(v) => v.deref(), - VringStateMutGuard::RwLockWriteGuard(v) => v.deref(), - } - } -} - -impl<'a, M: GuestAddressSpace> DerefMut for VringStateMutGuard<'a, M> { - fn deref_mut(&mut self) -> &mut Self::Target { - match self { - VringStateMutGuard::MutexGuard(v) => v.deref_mut(), - VringStateMutGuard::RwLockWriteGuard(v) => v.deref_mut(), - } - } -} - -pub trait VringT { +pub trait VringT: + for<'a> VringStateGuard<'a, M> + for<'a> VringStateMutGuard<'a, M> +{ /// Create a new instance of Vring. fn new(mem: M, max_queue_size: u16) -> Self; /// Get an immutable reference to the kick event fd. - fn get_ref(&self) -> VringStateGuard; + fn get_ref(&self) -> >::G; /// Get a mutable reference to the kick event fd. - fn get_mut(&self) -> VringStateMutGuard; + fn get_mut(&self) -> >::G; /// Add an used descriptor into the used queue. fn add_used(&self, desc_index: u16, len: u32) -> Result<(), VirtQueError>; @@ -276,19 +243,27 @@ impl VringMutex { } } -impl VringT for VringMutex { +impl<'a, M: 'a + GuestAddressSpace> VringStateGuard<'a, M> for VringMutex { + type G = MutexGuard<'a, VringState>; +} + +impl<'a, M: 'a + GuestAddressSpace> VringStateMutGuard<'a, M> for VringMutex { + type G = MutexGuard<'a, VringState>; +} + +impl VringT for VringMutex { fn new(mem: M, max_queue_size: u16) -> Self { VringMutex { state: Arc::new(Mutex::new(VringState::new(mem, max_queue_size))), } } - fn get_ref(&self) -> VringStateGuard { - VringStateGuard::MutexGuard(self.state.lock().unwrap()) + fn get_ref(&self) -> >::G { + self.state.lock().unwrap() } - fn get_mut(&self) -> VringStateMutGuard { - VringStateMutGuard::MutexGuard(self.lock()) + fn get_mut(&self) -> >::G { + self.lock() } fn add_used(&self, desc_index: u16, len: u32) -> Result<(), VirtQueError> { @@ -370,19 +345,27 @@ impl VringRwLock { } } -impl VringT for VringRwLock { +impl<'a, M: 'a + GuestAddressSpace> VringStateGuard<'a, M> for VringRwLock { + type G = RwLockReadGuard<'a, VringState>; +} + +impl<'a, M: 'a + GuestAddressSpace> VringStateMutGuard<'a, M> for VringRwLock { + type G = RwLockWriteGuard<'a, VringState>; +} + +impl VringT for VringRwLock { fn new(mem: M, max_queue_size: u16) -> Self { VringRwLock { state: Arc::new(RwLock::new(VringState::new(mem, max_queue_size))), } } - fn get_ref(&self) -> VringStateGuard { - VringStateGuard::RwLockReadGuard(self.state.read().unwrap()) + fn get_ref(&self) -> >::G { + self.state.read().unwrap() } - fn get_mut(&self) -> VringStateMutGuard { - VringStateMutGuard::RwLockWriteGuard(self.write_lock()) + fn get_mut(&self) -> >::G { + self.write_lock() } fn add_used(&self, desc_index: u16, len: u32) -> Result<(), VirtQueError> { @@ -467,7 +450,7 @@ mod tests { let vring = VringMutex::new(mem, 0x1000); assert!(vring.get_ref().get_kick().is_none()); - assert_eq!(vring.get_ref().enabled, false); + assert_eq!(vring.get_mut().enabled, false); assert_eq!(vring.lock().queue.ready(), false); assert_eq!(vring.lock().queue.state.event_idx_enabled, false); @@ -514,7 +497,7 @@ mod tests { let eventfd = EventFd::new(0).unwrap(); let file = unsafe { File::from_raw_fd(eventfd.as_raw_fd()) }; - assert!(vring.get_ref().kick.is_none()); + assert!(vring.get_mut().kick.is_none()); assert_eq!(vring.read_kick().unwrap(), true); vring.set_kick(Some(file)); eventfd.write(1).unwrap();