diff --git a/pci/src/msi.rs b/pci/src/msi.rs index fda44819b..05494cdac 100644 --- a/pci/src/msi.rs +++ b/pci/src/msi.rs @@ -214,6 +214,8 @@ impl MsiConfig { if let Err(e) = self.interrupt_source_group.mask(idx as InterruptIndex) { error!("Failed masking vector: {:?}", e); } + } else if let Err(e) = self.interrupt_source_group.unmask(idx as InterruptIndex) { + error!("Failed unmasking vector: {:?}", e); } } diff --git a/pci/src/msix.rs b/pci/src/msix.rs index e8020e00b..a83218dca 100644 --- a/pci/src/msix.rs +++ b/pci/src/msix.rs @@ -116,6 +116,9 @@ impl MsixConfig { if let Err(e) = self.interrupt_source_group.mask(idx as InterruptIndex) { error!("Failed masking vector: {:?}", e); } + } else if let Err(e) = self.interrupt_source_group.unmask(idx as InterruptIndex) + { + error!("Failed unmasking vector: {:?}", e); } } diff --git a/vmm/src/interrupt.rs b/vmm/src/interrupt.rs index ee43173f4..fad2ea691 100644 --- a/vmm/src/interrupt.rs +++ b/vmm/src/interrupt.rs @@ -9,6 +9,7 @@ use kvm_ioctls::VmFd; use std::collections::HashMap; use std::io; use std::mem::size_of; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; use vm_allocator::SystemAllocator; use vm_device::interrupt::{ @@ -52,6 +53,7 @@ pub fn vec_with_array_field(count: usize) -> Vec { pub struct InterruptRoute { pub gsi: u32, pub irq_fd: EventFd, + registered: AtomicBool, } impl InterruptRoute { @@ -61,25 +63,43 @@ impl InterruptRoute { .allocate_gsi() .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Failed allocating new GSI"))?; - Ok(InterruptRoute { gsi, irq_fd }) + Ok(InterruptRoute { + gsi, + irq_fd, + registered: AtomicBool::new(false), + }) } pub fn enable(&self, vm: &Arc) -> Result<()> { - vm.register_irqfd(&self.irq_fd, self.gsi).map_err(|e| { - io::Error::new( - io::ErrorKind::Other, - format!("Failed registering irq_fd: {}", e), - ) - }) + if !self.registered.load(Ordering::SeqCst) { + vm.register_irqfd(&self.irq_fd, self.gsi).map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!("Failed registering irq_fd: {}", e), + ) + })?; + + // Update internals to track the irq_fd as "registered". + self.registered.store(true, Ordering::SeqCst); + } + + Ok(()) } pub fn disable(&self, vm: &Arc) -> Result<()> { - vm.unregister_irqfd(&self.irq_fd, self.gsi).map_err(|e| { - io::Error::new( - io::ErrorKind::Other, - format!("Failed unregistering irq_fd: {}", e), - ) - }) + if self.registered.load(Ordering::SeqCst) { + vm.unregister_irqfd(&self.irq_fd, self.gsi).map_err(|e| { + io::Error::new( + io::ErrorKind::Other, + format!("Failed unregistering irq_fd: {}", e), + ) + })?; + + // Update internals to track the irq_fd as "unregistered". + self.registered.store(false, Ordering::SeqCst); + } + + Ok(()) } } @@ -234,11 +254,29 @@ impl InterruptSourceGroup for MsiInterruptGroup { } fn mask(&self, index: InterruptIndex) -> Result<()> { - self.mask_kvm_entry(index, true) + self.mask_kvm_entry(index, true)?; + + if let Some(route) = self.irq_routes.get(&index) { + return route.disable(&self.vm_fd); + } + + Err(io::Error::new( + io::ErrorKind::Other, + format!("mask: Invalid interrupt index {}", index), + )) } fn unmask(&self, index: InterruptIndex) -> Result<()> { - self.mask_kvm_entry(index, false) + self.mask_kvm_entry(index, false)?; + + if let Some(route) = self.irq_routes.get(&index) { + return route.enable(&self.vm_fd); + } + + Err(io::Error::new( + io::ErrorKind::Other, + format!("unmask: Invalid interrupt index {}", index), + )) } }