diff --git a/crates/vhost/src/vhost_kern/mod.rs b/crates/vhost/src/vhost_kern/mod.rs index 42a1450..a437e8c 100644 --- a/crates/vhost/src/vhost_kern/mod.rs +++ b/crates/vhost/src/vhost_kern/mod.rs @@ -200,20 +200,15 @@ impl VhostBackend for T { /// /// # Arguments /// * `queue_index` - Index of the queue to set addresses for. - /// * `config_data` - Vring config data. + /// * `config_data` - Vring config data, addresses of desc_table, avail_ring + /// and used_ring are in the guest address space. fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> { if !self.is_valid(config_data) { return Err(Error::InvalidQueue); } - let vring_addr = vhost_vring_addr { - index: queue_index as u32, - flags: config_data.flags, - desc_user_addr: config_data.desc_table_addr, - used_user_addr: config_data.used_ring_addr, - avail_user_addr: config_data.avail_ring_addr, - log_guest_addr: config_data.get_log_addr(), - }; + // The addresses are converted into the host address space. + let vring_addr = config_data.to_vhost_vring_addr(queue_index, self.mem())?; // This ioctl is called on a valid vhost fd and has its // return value checked. @@ -428,3 +423,34 @@ impl VhostIotlbMsgParser for vhost_msg_v2 { Ok(()) } } + +impl VringConfigData { + /// Convert the config (guest address space) into vhost_vring_addr + /// (host address space). + pub fn to_vhost_vring_addr( + &self, + queue_index: usize, + mem: &AS, + ) -> Result { + let desc_addr = mem + .memory() + .get_host_address(GuestAddress(self.desc_table_addr)) + .map_err(|_| Error::DescriptorTableAddress)?; + let avail_addr = mem + .memory() + .get_host_address(GuestAddress(self.avail_ring_addr)) + .map_err(|_| Error::AvailAddress)?; + let used_addr = mem + .memory() + .get_host_address(GuestAddress(self.used_ring_addr)) + .map_err(|_| Error::UsedAddress)?; + Ok(vhost_vring_addr { + index: queue_index as u32, + flags: self.flags, + desc_user_addr: desc_addr as u64, + used_user_addr: used_addr as u64, + avail_user_addr: avail_addr as u64, + log_guest_addr: self.get_log_addr(), + }) + } +}