diff --git a/coverage_config_x86_64.json b/coverage_config_x86_64.json index 374cba7..38e6dec 100644 --- a/coverage_config_x86_64.json +++ b/coverage_config_x86_64.json @@ -1,5 +1,5 @@ { - "coverage_score": 85.1, + "coverage_score": 85.0, "exclude_path": "", "crate_features": "" } diff --git a/src/lib.rs b/src/lib.rs index 4851ed9..c65a19e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -106,7 +106,39 @@ where }) } - /// Connect to the vhost-user socket and run a dedicated thread handling all requests coming + /// Run a dedicated thread handling all requests coming through the socket. + /// This runs in an infinite loop that should be terminating once the other + /// end of the socket (the VMM) hangs up. + /// + /// This function is the common code for starting a new daemon, no matter if + /// it acts as a client or a server. + fn start_daemon( + &mut self, + mut handler: SlaveReqHandler>>, + ) -> Result<()> { + let handle = thread::Builder::new() + .name(self.name.clone()) + .spawn(move || loop { + handler.handle_request().map_err(Error::HandleRequest)?; + }) + .map_err(Error::StartDaemon)?; + + self.main_thread = Some(handle); + + Ok(()) + } + + /// Connect to the vhost-user socket and run a dedicated thread handling + /// all requests coming through this socket. This runs in an infinite loop + /// that should be terminating once the other end of the socket (the VMM) + /// hangs up. + pub fn start_client(&mut self, socket_path: &str) -> Result<()> { + let slave_handler = SlaveReqHandler::connect(socket_path, self.handler.clone()) + .map_err(Error::CreateSlaveReqHandler)?; + self.start_daemon(slave_handler) + } + + /// Listen to the vhost-user socket and run a dedicated thread handling all requests coming /// through this socket. /// /// This runs in an infinite loop that should be terminating once the other end of the socket @@ -116,19 +148,8 @@ where pub fn start(&mut self, listener: Listener) -> Result<()> { let mut slave_listener = SlaveListener::new(listener, self.handler.clone()) .map_err(Error::CreateSlaveListener)?; - let mut slave_handler = self.accept(&mut slave_listener)?; - let handle = thread::Builder::new() - .name(self.name.clone()) - .spawn(move || loop { - slave_handler - .handle_request() - .map_err(Error::HandleRequest)?; - }) - .map_err(Error::StartDaemon)?; - - self.main_thread = Some(handle); - - Ok(()) + let slave_handler = self.accept(&mut slave_listener)?; + self.start_daemon(slave_handler) } fn accept( @@ -171,7 +192,7 @@ where mod tests { use super::backend::tests::MockVhostBackend; use super::*; - use std::os::unix::net::UnixStream; + use std::os::unix::net::{UnixListener, UnixStream}; use std::sync::Barrier; use vm_memory::{GuestAddress, GuestMemoryAtomic, GuestMemoryMmap}; @@ -209,4 +230,41 @@ mod tests { daemon.wait().unwrap(); thread.join().unwrap(); } + + #[test] + fn test_new_daemon_client() { + let mem = GuestMemoryAtomic::new( + GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0x100000), 0x10000)]).unwrap(), + ); + let backend = Arc::new(Mutex::new(MockVhostBackend::new())); + let mut daemon = VhostUserDaemon::new("test".to_owned(), backend, mem).unwrap(); + + let handlers = daemon.get_epoll_handlers(); + assert_eq!(handlers.len(), 2); + + let barrier = Arc::new(Barrier::new(2)); + let tmpdir = tempfile::tempdir().unwrap(); + let mut path = tmpdir.path().to_path_buf(); + path.push("socket"); + + let barrier2 = barrier.clone(); + let path1 = path.clone(); + let thread = thread::spawn(move || { + let listener = UnixListener::bind(&path1).unwrap(); + barrier2.wait(); + let (stream, _) = listener.accept().unwrap(); + barrier2.wait(); + drop(stream) + }); + + barrier.wait(); + daemon + .start_client(path.as_path().to_str().unwrap()) + .unwrap(); + barrier.wait(); + // Above process generates a `HandleRequest(PartialMessage)` error. + daemon.wait().unwrap_err(); + daemon.wait().unwrap(); + thread.join().unwrap(); + } }