diff --git a/coverage_config_x86_64.json b/coverage_config_x86_64.json index 0036d98..f35323c 100644 --- a/coverage_config_x86_64.json +++ b/coverage_config_x86_64.json @@ -1,5 +1,5 @@ { - "coverage_score": 77.4, + "coverage_score": 79.1, "exclude_path": "", "crate_features": "" } diff --git a/src/backend.rs b/src/backend.rs index c4371f9..6fb19dd 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -99,7 +99,7 @@ where /// If an (`EventFd`, `token`) pair is returned, the returned `EventFd` will be monitored for IO /// events by using epoll with the specified `token`. When the returned EventFd is written to, /// the worker thread will exit. - fn exit_event(&self, _thread_index: usize) -> Option<(EventFd, u16)> { + fn exit_event(&self, _thread_index: usize) -> Option { None } @@ -182,7 +182,7 @@ where /// If an (`EventFd`, `token`) pair is returned, the returned `EventFd` will be monitored for IO /// events by using epoll with the specified `token`. When the returned EventFd is written to, /// the worker thread will exit. - fn exit_event(&self, _thread_index: usize) -> Option<(EventFd, u16)> { + fn exit_event(&self, _thread_index: usize) -> Option { None } @@ -249,7 +249,7 @@ where self.deref().queues_per_thread() } - fn exit_event(&self, thread_index: usize) -> Option<(EventFd, u16)> { + fn exit_event(&self, thread_index: usize) -> Option { self.deref().exit_event(thread_index) } @@ -314,7 +314,7 @@ where self.lock().unwrap().queues_per_thread() } - fn exit_event(&self, thread_index: usize) -> Option<(EventFd, u16)> { + fn exit_event(&self, thread_index: usize) -> Option { self.lock().unwrap().exit_event(thread_index) } @@ -380,7 +380,7 @@ where self.read().unwrap().queues_per_thread() } - fn exit_event(&self, thread_index: usize) -> Option<(EventFd, u16)> { + fn exit_event(&self, thread_index: usize) -> Option { self.read().unwrap().exit_event(thread_index) } @@ -475,10 +475,10 @@ pub mod tests { vec![1, 1] } - fn exit_event(&self, _thread_index: usize) -> Option<(EventFd, u16)> { + fn exit_event(&self, _thread_index: usize) -> Option { let event_fd = EventFd::new(0).unwrap(); - Some((event_fd, 0x100)) + Some(event_fd) } fn handle_event( diff --git a/src/event_loop.rs b/src/event_loop.rs index 869df85..6a70158 100644 --- a/src/event_loop.rs +++ b/src/event_loop.rs @@ -70,7 +70,6 @@ where vrings: Vec, thread_id: usize, exit_event_fd: Option, - exit_event_id: Option, phantom: PhantomData, } @@ -86,12 +85,13 @@ where let epoll_file = unsafe { File::from_raw_fd(epoll_fd) }; let handler = match backend.exit_event(thread_id) { - Some((exit_event_fd, exit_event_id)) => { + Some(exit_event_fd) => { + let id = backend.num_queues(); epoll::ctl( epoll_file.as_raw_fd(), epoll::ControlOptions::EPOLL_CTL_ADD, exit_event_fd.as_raw_fd(), - epoll::Event::new(epoll::Events::EPOLLIN, u64::from(exit_event_id)), + epoll::Event::new(epoll::Events::EPOLLIN, id as u64), ) .map_err(VringEpollError::RegisterExitEvent)?; @@ -101,7 +101,6 @@ where vrings, thread_id, exit_event_fd: Some(exit_event_fd), - exit_event_id: Some(exit_event_id), phantom: PhantomData, } } @@ -111,7 +110,6 @@ where vrings, thread_id, exit_event_fd: None, - exit_event_id: None, phantom: PhantomData, }, }; @@ -135,6 +133,38 @@ where fd: RawFd, ev_type: epoll::Events, data: u64, + ) -> result::Result<(), io::Error> { + // `data` range [0...num_queues] is reserved for queues and exit event. + if data <= self.backend.num_queues() as u64 { + Err(io::Error::from_raw_os_error(libc::EINVAL)) + } else { + self.register_event(fd, ev_type, data) + } + } + + /// Unregister an event from the epoll fd. + /// + /// If the event is triggered after this function has been called, the event will be silently + /// dropped. + pub fn unregister_listener( + &self, + fd: RawFd, + ev_type: epoll::Events, + data: u64, + ) -> result::Result<(), io::Error> { + // `data` range [0...num_queues] is reserved for queues and exit event. + if data <= self.backend.num_queues() as u64 { + Err(io::Error::from_raw_os_error(libc::EINVAL)) + } else { + self.unregister_event(fd, ev_type, data) + } + } + + pub(crate) fn register_event( + &self, + fd: RawFd, + ev_type: epoll::Events, + data: u64, ) -> result::Result<(), io::Error> { epoll::ctl( self.epoll_file.as_raw_fd(), @@ -144,11 +174,7 @@ where ) } - /// Unregister an event from the epoll fd. - /// - /// If the event is triggered after this function has been called, the event will be silently - /// dropped. - pub fn unregister_listener( + pub(crate) fn unregister_event( &self, fd: RawFd, ev_type: epoll::Events, @@ -211,7 +237,7 @@ where } fn handle_event(&self, device_event: u16, evset: epoll::Events) -> VringEpollResult { - if self.exit_event_id == Some(device_event) { + if self.exit_event_fd.is_some() && device_event as usize == self.backend.num_queues() { return Ok(true); } @@ -251,21 +277,28 @@ mod tests { let backend = Arc::new(Mutex::new(MockVhostBackend::new())); let handler = VringEpollHandler::new(backend, vec![vring], 0x1).unwrap(); - assert!(handler.exit_event_id.is_some()); let eventfd = EventFd::new(0).unwrap(); handler - .register_listener(eventfd.as_raw_fd(), epoll::Events::EPOLLIN, 1) + .register_listener(eventfd.as_raw_fd(), epoll::Events::EPOLLIN, 3) .unwrap(); // Register an already registered fd. + handler + .register_listener(eventfd.as_raw_fd(), epoll::Events::EPOLLIN, 3) + .unwrap_err(); + // Register an invalid data. handler .register_listener(eventfd.as_raw_fd(), epoll::Events::EPOLLIN, 1) .unwrap_err(); handler - .unregister_listener(eventfd.as_raw_fd(), epoll::Events::EPOLLIN, 1) + .unregister_listener(eventfd.as_raw_fd(), epoll::Events::EPOLLIN, 3) .unwrap(); // unregister an already unregistered fd. + handler + .unregister_listener(eventfd.as_raw_fd(), epoll::Events::EPOLLIN, 3) + .unwrap_err(); + // unregister an invalid data. handler .unregister_listener(eventfd.as_raw_fd(), epoll::Events::EPOLLIN, 1) .unwrap_err(); diff --git a/src/handler.rs b/src/handler.rs index dca7f19..672785b 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -350,7 +350,7 @@ where if shifted_queues_mask & 1u64 == 1u64 { let evt_idx = queues_mask.count_ones() - shifted_queues_mask.count_ones(); self.handlers[thread_index] - .unregister_listener( + .unregister_event( fd.as_raw_fd(), epoll::Events::EPOLLIN, u64::from(evt_idx), @@ -389,11 +389,7 @@ where if shifted_queues_mask & 1u64 == 1u64 { let evt_idx = queues_mask.count_ones() - shifted_queues_mask.count_ones(); self.handlers[thread_index] - .register_listener( - fd.as_raw_fd(), - epoll::Events::EPOLLIN, - u64::from(evt_idx), - ) + .register_event(fd.as_raw_fd(), epoll::Events::EPOLLIN, u64::from(evt_idx)) .map_err(VhostUserError::ReqHandlerError)?; break; }