From 102d14e3b54dfe573224409c4537aac52cdc0689 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Benn=C3=A9e?= Date: Thu, 14 Jul 2022 16:42:32 +0100 Subject: [PATCH] vhost_user: re-factor feature/proto bit checking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We can wrap up the feature checking into helpers to reduce the amount of boilerplate code while enabling us to use a more idiomatic ?; exit path on error. Signed-off-by: Alex Bennée --- src/vhost_user/dummy_slave.rs | 36 +++++++++++---- src/vhost_user/master.rs | 65 ++++++++++---------------- src/vhost_user/slave_req_handler.rs | 71 +++++++++++------------------ 3 files changed, 78 insertions(+), 94 deletions(-) diff --git a/src/vhost_user/dummy_slave.rs b/src/vhost_user/dummy_slave.rs index 222a5bb..783c5de 100644 --- a/src/vhost_user/dummy_slave.rs +++ b/src/vhost_user/dummy_slave.rs @@ -35,6 +35,24 @@ impl DummySlaveReqHandler { ..Default::default() } } + + /// Helper to check if VirtioFeature enabled + fn check_feature(&self, feat: VhostUserVirtioFeatures) -> Result<()> { + if self.acked_features & feat.bits() != 0 { + Ok(()) + } else { + Err(Error::InvalidOperation) + } + } + + /// Helper to check is VhostUserProtocolFeatures enabled + fn check_proto_feature(&self, feat: VhostUserProtocolFeatures) -> Result<()> { + if self.acked_protocol_features & feat.bits() != 0 { + Ok(()) + } else { + Err(Error::InvalidOperation) + } + } } impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { @@ -190,9 +208,9 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> { // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES // has been negotiated. - if self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 { - return Err(Error::InvalidOperation); - } else if index as usize >= self.queue_num || index as usize > self.queue_num { + self.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?; + + if index as usize >= self.queue_num || index as usize > self.queue_num { return Err(Error::InvalidParam); } @@ -210,9 +228,9 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { size: u32, _flags: VhostUserConfigFlags, ) -> Result> { - if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { - return Err(Error::InvalidOperation); - } else if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset) + self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?; + + if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset) || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET || size + offset > VHOST_USER_CONFIG_SIZE { @@ -225,9 +243,9 @@ impl VhostUserSlaveReqHandlerMut for DummySlaveReqHandler { fn set_config(&mut self, offset: u32, buf: &[u8], _flags: VhostUserConfigFlags) -> Result<()> { let size = buf.len() as u32; - if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { - return Err(Error::InvalidOperation); - } else if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset) + self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?; + + if !(VHOST_USER_CONFIG_OFFSET..VHOST_USER_CONFIG_SIZE).contains(&offset) || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET || size + offset > VHOST_USER_CONFIG_SIZE { diff --git a/src/vhost_user/master.rs b/src/vhost_user/master.rs index 711b673..d883bb0 100644 --- a/src/vhost_user/master.rs +++ b/src/vhost_user/master.rs @@ -339,10 +339,7 @@ impl VhostBackend for Master { impl VhostUserMaster for Master { fn get_protocol_features(&mut self) -> Result { let mut node = self.node(); - let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); - if node.virtio_features & flag == 0 { - return error_code(VhostUserError::InvalidOperation); - } + node.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?; let hdr = node.send_request_header(MasterReq::GET_PROTOCOL_FEATURES, None)?; let val = node.recv_reply::(&hdr)?; node.protocol_features = val.value; @@ -356,10 +353,7 @@ impl VhostUserMaster for Master { fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()> { let mut node = self.node(); - let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); - if node.virtio_features & flag == 0 { - return error_code(VhostUserError::InvalidOperation); - } + node.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?; let val = VhostUserU64::new(features.bits()); let hdr = node.send_request_with_body(MasterReq::SET_PROTOCOL_FEATURES, &val, None)?; // Don't wait for ACK here because the protocol feature negotiation process hasn't been @@ -371,9 +365,7 @@ impl VhostUserMaster for Master { fn get_queue_num(&mut self) -> Result { let mut node = self.node(); - if !node.is_feature_mq_available() { - return error_code(VhostUserError::InvalidOperation); - } + node.check_proto_feature(VhostUserProtocolFeatures::MQ)?; let hdr = node.send_request_header(MasterReq::GET_QUEUE_NUM, None)?; let val = node.recv_reply::(&hdr)?; @@ -413,9 +405,7 @@ impl VhostUserMaster for Master { let mut node = self.node(); // depends on VhostUserProtocolFeatures::CONFIG - if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { - return error_code(VhostUserError::InvalidOperation); - } + node.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?; // vhost-user spec states that: // "Master payload: virtio device config space" @@ -448,9 +438,7 @@ impl VhostUserMaster for Master { let mut node = self.node(); // depends on VhostUserProtocolFeatures::CONFIG - if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { - return error_code(VhostUserError::InvalidOperation); - } + node.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?; let hdr = node.send_request_with_payload(MasterReq::SET_CONFIG, &body, buf, None)?; node.wait_for_ack(&hdr).map_err(|e| e.into()) @@ -458,9 +446,7 @@ impl VhostUserMaster for Master { fn set_slave_request_fd(&mut self, fd: &dyn AsRawFd) -> Result<()> { let mut node = self.node(); - if node.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 { - return error_code(VhostUserError::InvalidOperation); - } + node.check_proto_feature(VhostUserProtocolFeatures::SLAVE_REQ)?; let fds = [fd.as_raw_fd()]; let hdr = node.send_request_header(MasterReq::SET_SLAVE_REQ_FD, Some(&fds))?; node.wait_for_ack(&hdr).map_err(|e| e.into()) @@ -471,9 +457,7 @@ impl VhostUserMaster for Master { inflight: &VhostUserInflight, ) -> Result<(VhostUserInflight, File)> { let mut node = self.node(); - if node.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() == 0 { - return error_code(VhostUserError::InvalidOperation); - } + node.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?; let hdr = node.send_request_with_body(MasterReq::GET_INFLIGHT_FD, inflight, None)?; let (inflight, files) = node.recv_reply_with_files::(&hdr)?; @@ -486,9 +470,7 @@ impl VhostUserMaster for Master { fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, fd: RawFd) -> Result<()> { let mut node = self.node(); - if node.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() == 0 { - return error_code(VhostUserError::InvalidOperation); - } + node.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?; if inflight.mmap_size == 0 || inflight.num_queues == 0 || inflight.queue_size == 0 || fd < 0 { @@ -501,10 +483,7 @@ impl VhostUserMaster for Master { fn get_max_mem_slots(&mut self) -> Result { let mut node = self.node(); - if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() == 0 - { - return error_code(VhostUserError::InvalidOperation); - } + node.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; let hdr = node.send_request_header(MasterReq::GET_MAX_MEM_SLOTS, None)?; let val = node.recv_reply::(&hdr)?; @@ -514,10 +493,7 @@ impl VhostUserMaster for Master { fn add_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()> { let mut node = self.node(); - if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() == 0 - { - return error_code(VhostUserError::InvalidOperation); - } + node.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; if region.memory_size == 0 || region.mmap_handle < 0 { return error_code(VhostUserError::InvalidParam); } @@ -535,10 +511,7 @@ impl VhostUserMaster for Master { fn remove_mem_region(&mut self, region: &VhostUserMemoryRegionInfo) -> Result<()> { let mut node = self.node(); - if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() == 0 - { - return error_code(VhostUserError::InvalidOperation); - } + node.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; if region.memory_size == 0 { return error_code(VhostUserError::InvalidParam); } @@ -753,8 +726,20 @@ impl MasterInternal { Ok(()) } - fn is_feature_mq_available(&self) -> bool { - self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() != 0 + fn check_feature(&self, feat: VhostUserVirtioFeatures) -> VhostUserResult<()> { + if self.virtio_features & feat.bits() != 0 { + Ok(()) + } else { + Err(VhostUserError::InvalidOperation) + } + } + + fn check_proto_feature(&self, feat: VhostUserProtocolFeatures) -> VhostUserResult<()> { + if self.acked_protocol_features & feat.bits() != 0 { + Ok(()) + } else { + Err(VhostUserError::InvalidOperation) + } } fn check_state(&self) -> VhostUserResult<()> { diff --git a/src/vhost_user/slave_req_handler.rs b/src/vhost_user/slave_req_handler.rs index f43159a..518af43 100644 --- a/src/vhost_user/slave_req_handler.rs +++ b/src/vhost_user/slave_req_handler.rs @@ -269,6 +269,22 @@ impl SlaveReqHandler { } } + fn check_feature(&self, feat: VhostUserVirtioFeatures) -> Result<()> { + if self.acked_virtio_features & feat.bits() != 0 { + Ok(()) + } else { + Err(Error::InvalidOperation) + } + } + + fn check_proto_feature(&self, feat: VhostUserProtocolFeatures) -> Result<()> { + if self.acked_protocol_features & feat.bits() != 0 { + Ok(()) + } else { + Err(Error::InvalidOperation) + } + } + /// Create a vhost-user slave endpoint from a connected socket. pub fn from_stream(socket: UnixStream, backend: Arc) -> Self { Self::new(Endpoint::from_stream(socket), backend) @@ -416,9 +432,7 @@ impl SlaveReqHandler { self.send_ack_message(&hdr, res)?; } MasterReq::GET_QUEUE_NUM => { - if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 { - return Err(Error::InvalidOperation); - } + self.check_proto_feature(VhostUserProtocolFeatures::MQ)?; self.check_request_size(&hdr, size, 0)?; let num = self.backend.get_queue_num()?; let msg = VhostUserU64::new(num); @@ -426,11 +440,7 @@ impl SlaveReqHandler { } MasterReq::SET_VRING_ENABLE => { let msg = self.extract_request_body::(&hdr, size, &buf)?; - if self.acked_virtio_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() - == 0 - { - return Err(Error::InvalidOperation); - } + self.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?; let enable = match msg.num { 1 => true, 0 => false, @@ -441,34 +451,24 @@ impl SlaveReqHandler { self.send_ack_message(&hdr, res)?; } MasterReq::GET_CONFIG => { - if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { - return Err(Error::InvalidOperation); - } + self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?; self.check_request_size(&hdr, size, hdr.get_size() as usize)?; self.get_config(&hdr, &buf)?; } MasterReq::SET_CONFIG => { - if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 { - return Err(Error::InvalidOperation); - } + self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?; self.check_request_size(&hdr, size, hdr.get_size() as usize)?; let res = self.set_config(size, &buf); self.send_ack_message(&hdr, res)?; } MasterReq::SET_SLAVE_REQ_FD => { - if self.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 { - return Err(Error::InvalidOperation); - } + self.check_proto_feature(VhostUserProtocolFeatures::SLAVE_REQ)?; self.check_request_size(&hdr, size, hdr.get_size() as usize)?; let res = self.set_slave_req_fd(files); self.send_ack_message(&hdr, res)?; } MasterReq::GET_INFLIGHT_FD => { - if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() - == 0 - { - return Err(Error::InvalidOperation); - } + self.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?; let msg = self.extract_request_body::(&hdr, size, &buf)?; let (inflight, file) = self.backend.get_inflight_fd(&msg)?; @@ -477,35 +477,21 @@ impl SlaveReqHandler { .send_message(&reply_hdr, &inflight, Some(&[file.as_raw_fd()]))?; } MasterReq::SET_INFLIGHT_FD => { - if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() - == 0 - { - return Err(Error::InvalidOperation); - } + self.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?; let file = take_single_file(files).ok_or(Error::IncorrectFds)?; let msg = self.extract_request_body::(&hdr, size, &buf)?; let res = self.backend.set_inflight_fd(&msg, file); self.send_ack_message(&hdr, res)?; } MasterReq::GET_MAX_MEM_SLOTS => { - if self.acked_protocol_features - & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() - == 0 - { - return Err(Error::InvalidOperation); - } + self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; self.check_request_size(&hdr, size, 0)?; let num = self.backend.get_max_mem_slots()?; let msg = VhostUserU64::new(num); self.send_reply_message(&hdr, &msg)?; } MasterReq::ADD_MEM_REG => { - if self.acked_protocol_features - & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() - == 0 - { - return Err(Error::InvalidOperation); - } + self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; let mut files = files.ok_or(Error::InvalidParam)?; if files.len() != 1 { return Err(Error::InvalidParam); @@ -516,12 +502,7 @@ impl SlaveReqHandler { self.send_ack_message(&hdr, res)?; } MasterReq::REM_MEM_REG => { - if self.acked_protocol_features - & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() - == 0 - { - return Err(Error::InvalidOperation); - } + self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; let msg = self.extract_request_body::(&hdr, size, &buf)?;