make set_gpu_socket a default implementation

relax the requirement of the trait when implementing
the `set_gpu_socket` method, make the `set_gpu_socket`
method optional, and ensure that the function `set_gpu_socket`
returns an error if the backend does not implement it.

Fixes #265

Signed-off-by: Dorinda Bassey <dbassey@redhat.com>
This commit is contained in:
Dorinda Bassey 2024-10-18 17:57:44 +02:00 committed by Stefano Garzarella
parent 64cc75a8ab
commit 1350073485
5 changed files with 31 additions and 23 deletions

View file

@ -89,9 +89,14 @@ pub trait VhostUserBackend: Send + Sync {
/// Set handler for communicating with the frontend by the gpu specific backend communication
/// channel.
///
/// This method only exits when the crate feature gpu-socket is enabled, because this is only
/// useful for a gpu device.
fn set_gpu_socket(&self, _gpu_backend: GpuBackend);
/// This function returns a `Result`, returning an error if the backend does not implement this
/// function.
fn set_gpu_socket(&self, _gpu_backend: GpuBackend) -> Result<()> {
Err(std::io::Error::new(
std::io::ErrorKind::Unsupported,
"backend does not support set_gpu_socket() / VHOST_USER_GPU_SET_SOCKET",
))
}
/// Get the map to map queue index to worker thread index.
///
@ -206,9 +211,14 @@ pub trait VhostUserBackendMut: Send + Sync {
/// Set handler for communicating with the frontend by the gpu specific backend communication
/// channel.
///
/// This method only exits when the crate feature gpu-socket is enabled, because this is only
/// useful for a gpu device.
fn set_gpu_socket(&mut self, gpu_backend: GpuBackend);
/// This function returns a `Result`, returning an error if the backend does not implement this
/// function.
fn set_gpu_socket(&mut self, _gpu_backend: GpuBackend) -> Result<()> {
Err(std::io::Error::new(
std::io::ErrorKind::Unsupported,
"backend does not support set_gpu_socket() / VHOST_USER_GPU_SET_SOCKET",
))
}
/// Get the map to map queue index to worker thread index.
///
@ -315,7 +325,7 @@ impl<T: VhostUserBackend> VhostUserBackend for Arc<T> {
self.deref().set_backend_req_fd(backend)
}
fn set_gpu_socket(&self, gpu_backend: GpuBackend) {
fn set_gpu_socket(&self, gpu_backend: GpuBackend) -> Result<()> {
self.deref().set_gpu_socket(gpu_backend)
}
@ -396,7 +406,7 @@ impl<T: VhostUserBackendMut> VhostUserBackend for Mutex<T> {
self.lock().unwrap().set_backend_req_fd(backend)
}
fn set_gpu_socket(&self, gpu_backend: GpuBackend) {
fn set_gpu_socket(&self, gpu_backend: GpuBackend) -> Result<()> {
self.lock().unwrap().set_gpu_socket(gpu_backend)
}
@ -480,7 +490,7 @@ impl<T: VhostUserBackendMut> VhostUserBackend for RwLock<T> {
self.write().unwrap().set_backend_req_fd(backend)
}
fn set_gpu_socket(&self, gpu_backend: GpuBackend) {
fn set_gpu_socket(&self, gpu_backend: GpuBackend) -> Result<()> {
self.write().unwrap().set_gpu_socket(gpu_backend)
}
@ -604,8 +614,6 @@ pub mod tests {
fn set_backend_req_fd(&mut self, _backend: Backend) {}
fn set_gpu_socket(&mut self, _gpu_backend: GpuBackend) {}
fn queues_per_thread(&self) -> Vec<u64> {
vec![1, 1]
}

View file

@ -555,8 +555,10 @@ where
self.backend.set_backend_req_fd(backend);
}
fn set_gpu_socket(&mut self, gpu_backend: GpuBackend) {
self.backend.set_gpu_socket(gpu_backend);
fn set_gpu_socket(&mut self, gpu_backend: GpuBackend) -> VhostUserResult<()> {
self.backend
.set_gpu_socket(gpu_backend)
.map_err(VhostUserError::ReqHandlerError)
}
fn get_inflight_fd(

View file

@ -10,7 +10,6 @@ use std::thread;
use vhost::vhost_user::message::{
VhostUserConfigFlags, VhostUserHeaderFlag, VhostUserInflight, VhostUserProtocolFeatures,
};
use vhost::vhost_user::GpuBackend;
use vhost::vhost_user::{Backend, Frontend, Listener, VhostUserFrontend};
use vhost::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData};
use vhost_user_backend::{VhostUserBackendMut, VhostUserDaemon, VringRwLock};
@ -78,8 +77,6 @@ impl VhostUserBackendMut for MockVhostBackend {
Ok(())
}
fn set_gpu_socket(&mut self, _gpu_backend: GpuBackend) {}
fn update_memory(&mut self, atomic_mem: GuestMemoryAtomic<GuestMemoryMmap>) -> Result<()> {
let mem = atomic_mem.memory();
let region = mem.find_region(GuestAddress(0x100000)).unwrap();