vhost-user: Add support for GET_SHMEM_CONFIG message

Add support for GET_SHMEM_CONFIG message to retrieve
VirtIO Shared Memory Regions configuration.

This is useful when the frontend is unaware of specific
backend type and configuration of the memory layout.

Based on the patch [1] which is just waiting for being
merged.

[1] -
https://lore.kernel.org/all/20251111091058.879669-1-aesteve@redhat.com/

Signed-off-by: Albert Esteve <aesteve@redhat.com>
This commit is contained in:
Albert Esteve 2025-12-11 10:40:50 +01:00 committed by Manos Pitsidianakis
parent 8c00b8829f
commit 03c9524bc4
8 changed files with 262 additions and 2 deletions

View file

@ -3,6 +3,8 @@
## [Unreleased] ## [Unreleased]
### Added ### Added
- [[#339]](https://github.com/rust-vmm/vhost/pull/339) Add support for `GET_SHMEM_CONFIG` message
### Changed ### Changed
### Deprecated ### Deprecated
### Fixed ### Fixed

View file

@ -25,7 +25,7 @@ use std::sync::{Arc, Mutex, RwLock};
use vhost::vhost_user::message::{ use vhost::vhost_user::message::{
VhostTransferStateDirection, VhostTransferStatePhase, VhostUserProtocolFeatures, VhostTransferStateDirection, VhostTransferStatePhase, VhostUserProtocolFeatures,
VhostUserSharedMsg, VhostUserShMemConfig, VhostUserSharedMsg,
}; };
use vhost::vhost_user::Backend; use vhost::vhost_user::Backend;
use vm_memory::bitmap::Bitmap; use vm_memory::bitmap::Bitmap;
@ -180,6 +180,13 @@ pub trait VhostUserBackend: Send + Sync {
"back end does not support state transfer", "back end does not support state transfer",
)) ))
} }
fn get_shmem_config(&self) -> Result<VhostUserShMemConfig> {
Err(std::io::Error::new(
std::io::ErrorKind::Unsupported,
"back end does not support shared memory regions",
))
}
} }
/// Trait without interior mutability for vhost user backend servers to implement concrete services. /// Trait without interior mutability for vhost user backend servers to implement concrete services.
@ -322,6 +329,13 @@ pub trait VhostUserBackendMut: Send + Sync {
"back end does not support state transfer", "back end does not support state transfer",
)) ))
} }
fn get_shmem_config(&self) -> Result<VhostUserShMemConfig> {
Err(std::io::Error::new(
std::io::ErrorKind::Unsupported,
"back end does not support shared memory regions",
))
}
} }
impl<T: VhostUserBackend> VhostUserBackend for Arc<T> { impl<T: VhostUserBackend> VhostUserBackend for Arc<T> {
@ -411,6 +425,10 @@ impl<T: VhostUserBackend> VhostUserBackend for Arc<T> {
fn check_device_state(&self) -> Result<()> { fn check_device_state(&self) -> Result<()> {
self.deref().check_device_state() self.deref().check_device_state()
} }
fn get_shmem_config(&self) -> Result<VhostUserShMemConfig> {
self.deref().get_shmem_config()
}
} }
impl<T: VhostUserBackendMut> VhostUserBackend for Mutex<T> { impl<T: VhostUserBackendMut> VhostUserBackend for Mutex<T> {
@ -503,6 +521,10 @@ impl<T: VhostUserBackendMut> VhostUserBackend for Mutex<T> {
fn check_device_state(&self) -> Result<()> { fn check_device_state(&self) -> Result<()> {
self.lock().unwrap().check_device_state() self.lock().unwrap().check_device_state()
} }
fn get_shmem_config(&self) -> Result<VhostUserShMemConfig> {
self.lock().unwrap().get_shmem_config()
}
} }
impl<T: VhostUserBackendMut> VhostUserBackend for RwLock<T> { impl<T: VhostUserBackendMut> VhostUserBackend for RwLock<T> {
@ -595,6 +617,10 @@ impl<T: VhostUserBackendMut> VhostUserBackend for RwLock<T> {
fn check_device_state(&self) -> Result<()> { fn check_device_state(&self) -> Result<()> {
self.read().unwrap().check_device_state() self.read().unwrap().check_device_state()
} }
fn get_shmem_config(&self) -> Result<VhostUserShMemConfig> {
self.read().unwrap().get_shmem_config()
}
} }
#[cfg(test)] #[cfg(test)]

View file

@ -18,7 +18,7 @@ use crate::bitmap::{BitmapReplace, MemRegionBitmap, MmapLogReg};
use userfaultfd::{Uffd, UffdBuilder}; use userfaultfd::{Uffd, UffdBuilder};
use vhost::vhost_user::message::{ use vhost::vhost_user::message::{
VhostTransferStateDirection, VhostTransferStatePhase, VhostUserConfigFlags, VhostUserLog, VhostTransferStateDirection, VhostTransferStatePhase, VhostUserConfigFlags, VhostUserLog,
VhostUserMemoryRegion, VhostUserProtocolFeatures, VhostUserSharedMsg, VhostUserMemoryRegion, VhostUserProtocolFeatures, VhostUserShMemConfig, VhostUserSharedMsg,
VhostUserSingleMemoryRegion, VhostUserVirtioFeatures, VhostUserVringAddrFlags, VhostUserSingleMemoryRegion, VhostUserVirtioFeatures, VhostUserVringAddrFlags,
VhostUserVringState, VhostUserVringState,
}; };
@ -688,6 +688,12 @@ where
.map_err(VhostUserError::ReqHandlerError) .map_err(VhostUserError::ReqHandlerError)
} }
fn get_shmem_config(&mut self) -> VhostUserResult<VhostUserShMemConfig> {
self.backend
.get_shmem_config()
.map_err(VhostUserError::ReqHandlerError)
}
#[cfg(feature = "postcopy")] #[cfg(feature = "postcopy")]
fn postcopy_advice(&mut self) -> VhostUserResult<File> { fn postcopy_advice(&mut self) -> VhostUserResult<File> {
let mut uffd_builder = UffdBuilder::new(); let mut uffd_builder = UffdBuilder::new();

View file

@ -4,6 +4,7 @@
### Added ### Added
- [[#251]](https://github.com/rust-vmm/vhost/pull/251) Add `SHMEM_MAP` and `SHMEM_UNMAP` support - [[#251]](https://github.com/rust-vmm/vhost/pull/251) Add `SHMEM_MAP` and `SHMEM_UNMAP` support
- [[#339]](https://github.com/rust-vmm/vhost/pull/339) Add support for `GET_SHMEM_CONFIG` message
### Changed ### Changed
### Deprecated ### Deprecated

View file

@ -81,6 +81,7 @@ pub trait VhostUserBackendReqHandler {
fd: File, fd: File,
) -> Result<Option<File>>; ) -> Result<Option<File>>;
fn check_device_state(&self) -> Result<()>; fn check_device_state(&self) -> Result<()>;
fn get_shmem_config(&self) -> Result<VhostUserShMemConfig>;
#[cfg(feature = "postcopy")] #[cfg(feature = "postcopy")]
fn postcopy_advice(&self) -> Result<File>; fn postcopy_advice(&self) -> Result<File>;
#[cfg(feature = "postcopy")] #[cfg(feature = "postcopy")]
@ -146,6 +147,7 @@ pub trait VhostUserBackendReqHandlerMut {
fd: File, fd: File,
) -> Result<Option<File>>; ) -> Result<Option<File>>;
fn check_device_state(&mut self) -> Result<()>; fn check_device_state(&mut self) -> Result<()>;
fn get_shmem_config(&mut self) -> Result<VhostUserShMemConfig>;
#[cfg(feature = "postcopy")] #[cfg(feature = "postcopy")]
fn postcopy_advice(&mut self) -> Result<File>; fn postcopy_advice(&mut self) -> Result<File>;
#[cfg(feature = "postcopy")] #[cfg(feature = "postcopy")]
@ -289,6 +291,10 @@ impl<T: VhostUserBackendReqHandlerMut> VhostUserBackendReqHandler for Mutex<T> {
self.lock().unwrap().check_device_state() self.lock().unwrap().check_device_state()
} }
fn get_shmem_config(&self) -> Result<VhostUserShMemConfig> {
self.lock().unwrap().get_shmem_config()
}
#[cfg(feature = "postcopy")] #[cfg(feature = "postcopy")]
fn postcopy_advice(&self) -> Result<File> { fn postcopy_advice(&self) -> Result<File> {
self.lock().unwrap().postcopy_advice() self.lock().unwrap().postcopy_advice()
@ -679,6 +685,11 @@ impl<S: VhostUserBackendReqHandler> BackendReqHandler<S> {
}; };
self.send_reply_message(&hdr, &msg)?; self.send_reply_message(&hdr, &msg)?;
} }
Ok(FrontendReq::GET_SHMEM_CONFIG) => {
self.check_proto_feature(VhostUserProtocolFeatures::SHMEM)?;
let msg = self.backend.get_shmem_config()?;
self.send_reply_message(&hdr, &msg)?;
}
#[cfg(feature = "postcopy")] #[cfg(feature = "postcopy")]
Ok(FrontendReq::POSTCOPY_ADVISE) => { Ok(FrontendReq::POSTCOPY_ADVISE) => {
self.check_proto_feature(VhostUserProtocolFeatures::PAGEFAULT)?; self.check_proto_feature(VhostUserProtocolFeatures::PAGEFAULT)?;
@ -1038,4 +1049,108 @@ mod tests {
handler.check_state().unwrap_err(); handler.check_state().unwrap_err();
assert!(handler.as_raw_fd() >= 0); assert!(handler.as_raw_fd() >= 0);
} }
// Helper to send GET_SHMEM_CONFIG request and receive response
fn send_get_shmem_config_request(
mut endpoint: Endpoint<VhostUserMsgHeader<FrontendReq>>,
) -> VhostUserShMemConfig {
let hdr = VhostUserMsgHeader::new(FrontendReq::GET_SHMEM_CONFIG, 0, 0);
endpoint.send_message(&hdr, &VhostUserEmpty, None).unwrap();
let (reply_hdr, reply_config, rfds) = endpoint.recv_body::<VhostUserShMemConfig>().unwrap();
assert_eq!(reply_hdr.get_code().unwrap(), FrontendReq::GET_SHMEM_CONFIG);
assert!(reply_hdr.is_reply());
assert!(rfds.is_none());
reply_config
}
// Helper to create handler with SHMEM protocol feature enabled
fn create_handler_with_shmem(
backend: Arc<Mutex<DummyBackendReqHandler>>,
p1: UnixStream,
) -> BackendReqHandler<Mutex<DummyBackendReqHandler>> {
let mut handler = BackendReqHandler::new(
Endpoint::<VhostUserMsgHeader<FrontendReq>>::from_stream(p1),
backend,
);
handler.acked_protocol_features = VhostUserProtocolFeatures::SHMEM.bits();
handler
}
#[test]
fn test_get_shmem_config_multiple_regions() {
let memory_sizes = [
0x1000, 0x2000, 0x3000, 0x4000, 0x5000, 0x6000, 0x7000, 0x8000,
];
let config = VhostUserShMemConfig::new(8, &memory_sizes);
let (p1, p2) = UnixStream::pair().unwrap();
let mut dummy_backend = DummyBackendReqHandler::new();
dummy_backend.set_shmem_config(config);
let mut handler = create_handler_with_shmem(Arc::new(Mutex::new(dummy_backend)), p1);
let handle = std::thread::spawn(move || {
send_get_shmem_config_request(Endpoint::<VhostUserMsgHeader<FrontendReq>>::from_stream(
p2,
))
});
handler.handle_request().unwrap();
let reply_config = handle.join().unwrap();
assert_eq!(reply_config.nregions, 8);
for i in 0..8 {
assert_eq!(reply_config.memory_sizes[i], (i as u64 + 1) * 0x1000);
}
for i in 8..256 {
assert_eq!(reply_config.memory_sizes[i], 0);
}
}
#[test]
fn test_get_shmem_config_non_continuous_regions() {
// Create a configuration with non-continuous regions
let memory_sizes = [0x10000, 0, 0x20000, 0, 0, 0, 0, 0];
let config = VhostUserShMemConfig::new(2, &memory_sizes);
let (p1, p2) = UnixStream::pair().unwrap();
let mut dummy_backend = DummyBackendReqHandler::new();
dummy_backend.set_shmem_config(config);
let mut handler = create_handler_with_shmem(Arc::new(Mutex::new(dummy_backend)), p1);
let handle = std::thread::spawn(move || {
send_get_shmem_config_request(Endpoint::<VhostUserMsgHeader<FrontendReq>>::from_stream(
p2,
))
});
handler.handle_request().unwrap();
let reply_config = handle.join().unwrap();
assert_eq!(reply_config.nregions, 2);
assert_eq!(reply_config.memory_sizes[0], 0x10000);
assert_eq!(reply_config.memory_sizes[1], 0);
assert_eq!(reply_config.memory_sizes[2], 0x20000);
for i in 3..256 {
assert_eq!(reply_config.memory_sizes[i], 0);
}
}
#[test]
fn test_get_shmem_config_feature_not_negotiated() {
// Test that the request fails when SHMEM protocol feature is not negotiated
let (p1, p2) = UnixStream::pair().unwrap();
let backend = Arc::new(Mutex::new(DummyBackendReqHandler::new()));
let mut handler = BackendReqHandler::new(
Endpoint::<VhostUserMsgHeader<FrontendReq>>::from_stream(p1),
backend,
);
let mut frontend_endpoint = Endpoint::<VhostUserMsgHeader<FrontendReq>>::from_stream(p2);
std::thread::spawn(move || {
let hdr = VhostUserMsgHeader::new(FrontendReq::GET_SHMEM_CONFIG, 0, 0);
let _ = frontend_endpoint.send_message(&hdr, &VhostUserEmpty, None);
});
assert!(handler.handle_request().is_err());
}
} }

View file

@ -27,6 +27,7 @@ pub struct DummyBackendReqHandler {
pub vring_enabled: [bool; MAX_QUEUE_NUM], pub vring_enabled: [bool; MAX_QUEUE_NUM],
pub inflight_file: Option<File>, pub inflight_file: Option<File>,
pub shared_file: Option<File>, pub shared_file: Option<File>,
pub shmem_config: Option<VhostUserShMemConfig>,
} }
impl DummyBackendReqHandler { impl DummyBackendReqHandler {
@ -37,6 +38,12 @@ impl DummyBackendReqHandler {
} }
} }
/// Set the shared memory configuration to be returned by `get_shmem_config`
pub fn set_shmem_config(&mut self, config: VhostUserShMemConfig) {
self.acked_protocol_features |= VhostUserProtocolFeatures::SHMEM.bits();
self.shmem_config = Some(config);
}
/// Helper to check if VirtioFeature enabled /// Helper to check if VirtioFeature enabled
fn check_feature(&self, feat: VhostUserVirtioFeatures) -> Result<()> { fn check_feature(&self, feat: VhostUserVirtioFeatures) -> Result<()> {
if self.acked_features & feat.bits() != 0 { if self.acked_features & feat.bits() != 0 {
@ -329,6 +336,15 @@ impl VhostUserBackendReqHandlerMut for DummyBackendReqHandler {
))) )))
} }
fn get_shmem_config(&mut self) -> Result<VhostUserShMemConfig> {
self.shmem_config.ok_or_else(|| {
Error::ReqHandlerError(std::io::Error::new(
std::io::ErrorKind::Unsupported,
"dummy back end does not support shared memory regions",
))
})
}
#[cfg(feature = "postcopy")] #[cfg(feature = "postcopy")]
fn postcopy_advice(&mut self) -> Result<File> { fn postcopy_advice(&mut self) -> Result<File> {
let file = tempfile::tempfile().unwrap(); let file = tempfile::tempfile().unwrap();

View file

@ -79,6 +79,9 @@ pub trait VhostUserFrontend: VhostBackend {
/// Remove a guest memory mapping from vhost. /// Remove a guest memory mapping from vhost.
fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()>; fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()>;
/// Get the shared memory region configuration from the backend.
fn get_shmem_config(&mut self) -> Result<VhostUserShMemConfig>;
/// Sends VHOST_USER_POSTCOPY_ADVISE msg to the backend /// Sends VHOST_USER_POSTCOPY_ADVISE msg to the backend
/// initiating the beginning of the postcopy process. /// initiating the beginning of the postcopy process.
/// Backend will return a userfaultfd. /// Backend will return a userfaultfd.
@ -568,6 +571,16 @@ impl VhostUserFrontend for Frontend {
node.wait_for_ack(&hdr).map_err(|e| e.into()) node.wait_for_ack(&hdr).map_err(|e| e.into())
} }
fn get_shmem_config(&mut self) -> Result<VhostUserShMemConfig> {
let mut node = self.node();
node.check_proto_feature(VhostUserProtocolFeatures::SHMEM)?;
let hdr = node.send_request_header(FrontendReq::GET_SHMEM_CONFIG, None)?;
let config = node.recv_reply::<VhostUserShMemConfig>(&hdr)?;
Ok(config)
}
#[cfg(feature = "postcopy")] #[cfg(feature = "postcopy")]
fn postcopy_advise(&mut self) -> Result<File> { fn postcopy_advise(&mut self) -> Result<File> {
let mut node = self.node(); let mut node = self.node();
@ -1202,4 +1215,45 @@ mod tests {
let tables = vec![VhostUserMemoryRegionInfo::default(); MAX_ATTACHED_FD_ENTRIES + 1]; let tables = vec![VhostUserMemoryRegionInfo::default(); MAX_ATTACHED_FD_ENTRIES + 1];
frontend.set_mem_table(&tables).unwrap_err(); frontend.set_mem_table(&tables).unwrap_err();
} }
#[test]
fn test_frontend_get_shmem_config() {
let (mut frontend, mut peer) = create_pair2();
let expected_config = VhostUserShMemConfig::new(2, &[0x1000, 0x2000]);
let hdr = VhostUserMsgHeader::new(
FrontendReq::GET_SHMEM_CONFIG,
0x4,
std::mem::size_of::<VhostUserShMemConfig>() as u32,
);
peer.send_message(&hdr, &expected_config, None).unwrap();
let config = frontend.get_shmem_config().unwrap();
assert_eq!(config.nregions, 2);
assert_eq!(config.memory_sizes[0], 0x1000);
assert_eq!(config.memory_sizes[1], 0x2000);
let (recv_hdr, rfds) = peer.recv_header().unwrap();
assert_eq!(recv_hdr.get_code().unwrap(), FrontendReq::GET_SHMEM_CONFIG);
assert!(rfds.is_none());
}
#[test]
fn test_frontend_get_shmem_config_no_regions() {
let (mut frontend, mut peer) = create_pair2();
let expected_config = VhostUserShMemConfig::default();
let hdr = VhostUserMsgHeader::new(
FrontendReq::GET_SHMEM_CONFIG,
0x4,
std::mem::size_of::<VhostUserShMemConfig>() as u32,
);
peer.send_message(&hdr, &expected_config, None).unwrap();
let config = frontend.get_shmem_config().unwrap();
assert_eq!(config.nregions, 0);
for i in 0..256 {
assert_eq!(config.memory_sizes[i], 0);
}
}
} }

View file

@ -169,6 +169,8 @@ enum_value! {
/// After transferring state, check the backend for any errors that may have /// After transferring state, check the backend for any errors that may have
/// occurred during the transfer /// occurred during the transfer
CHECK_DEVICE_STATE = 43, CHECK_DEVICE_STATE = 43,
/// Get shared memory regions configuration from the backend.
GET_SHMEM_CONFIG = 44,
} }
} }
@ -688,6 +690,44 @@ impl VhostUserSingleMemoryRegion {
unsafe impl ByteValued for VhostUserSingleMemoryRegion {} unsafe impl ByteValued for VhostUserSingleMemoryRegion {}
impl VhostUserMsgValidator for VhostUserSingleMemoryRegion {} impl VhostUserMsgValidator for VhostUserSingleMemoryRegion {}
/// Get shared memory regions configuration.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct VhostUserShMemConfig {
/// Total number of shared memory regions
pub nregions: u32,
/// Padding for correct alignment
padding: u32,
/// Size of each memory region
pub memory_sizes: [u64; 256],
}
impl Default for VhostUserShMemConfig {
fn default() -> Self {
Self {
nregions: 0,
padding: 0,
memory_sizes: [0; 256],
}
}
}
impl VhostUserShMemConfig {
/// Create a new instance
pub fn new(nregions: u32, memory: &[u64]) -> Self {
let memory_sizes: [u64; 256] = std::array::from_fn(|i| *memory.get(i).unwrap_or(&0));
Self {
nregions,
padding: 0,
memory_sizes,
}
}
}
// SAFETY: Safe because all fields of VhostUserShMemConfig are POD.
unsafe impl ByteValued for VhostUserShMemConfig {}
impl VhostUserMsgValidator for VhostUserShMemConfig {}
/// Vring state descriptor. /// Vring state descriptor.
#[repr(C, packed)] #[repr(C, packed)]
#[derive(Copy, Clone, Default)] #[derive(Copy, Clone, Default)]