diff --git a/crates/vhost/CHANGELOG.md b/crates/vhost/CHANGELOG.md index b53087d..2c636ce 100644 --- a/crates/vhost/CHANGELOG.md +++ b/crates/vhost/CHANGELOG.md @@ -6,6 +6,7 @@ ### Changed ### Fixed +- [[#135]](https://github.com/rust-vmm/vhost/pull/135) vhost_user: fix UB on invalid master request ### Deprecated diff --git a/crates/vhost/src/vhost_user/master.rs b/crates/vhost/src/vhost_user/master.rs index 4ca821d..8170718 100644 --- a/crates/vhost/src/vhost_user/master.rs +++ b/crates/vhost/src/vhost_user/master.rs @@ -797,13 +797,13 @@ mod tests { master.reset_owner().unwrap(); let (hdr, rfds) = slave.recv_header().unwrap(); - assert_eq!(hdr.get_code(), MasterReq::SET_OWNER); + assert_eq!(hdr.get_code().unwrap(), MasterReq::SET_OWNER); assert_eq!(hdr.get_size(), 0); assert_eq!(hdr.get_version(), 0x1); assert!(rfds.is_none()); let (hdr, rfds) = slave.recv_header().unwrap(); - assert_eq!(hdr.get_code(), MasterReq::RESET_OWNER); + assert_eq!(hdr.get_code().unwrap(), MasterReq::RESET_OWNER); assert_eq!(hdr.get_size(), 0); assert_eq!(hdr.get_version(), 0x1); assert!(rfds.is_none()); @@ -831,7 +831,7 @@ mod tests { master.set_owner().unwrap(); let (hdr, rfds) = peer.recv_header().unwrap(); - assert_eq!(hdr.get_code(), MasterReq::SET_OWNER); + assert_eq!(hdr.get_code().unwrap(), MasterReq::SET_OWNER); assert_eq!(hdr.get_size(), 0); assert_eq!(hdr.get_version(), 0x1); assert!(rfds.is_none()); @@ -866,7 +866,7 @@ mod tests { master.set_owner().unwrap(); let (hdr, rfds) = peer.recv_header().unwrap(); - assert_eq!(hdr.get_code(), MasterReq::SET_OWNER); + assert_eq!(hdr.get_code().unwrap(), MasterReq::SET_OWNER); assert!(rfds.is_none()); assert!(master.get_protocol_features().is_err()); diff --git a/crates/vhost/src/vhost_user/master_req_handler.rs b/crates/vhost/src/vhost_user/master_req_handler.rs index 9873b5b..3c6a489 100644 --- a/crates/vhost/src/vhost_user/master_req_handler.rs +++ b/crates/vhost/src/vhost_user/master_req_handler.rs @@ -216,32 +216,32 @@ impl MasterReqHandler { }; let res = match hdr.get_code() { - SlaveReq::CONFIG_CHANGE_MSG => { + Ok(SlaveReq::CONFIG_CHANGE_MSG) => { self.check_msg_size(&hdr, size, 0)?; self.backend .handle_config_change() .map_err(Error::ReqHandlerError) } - SlaveReq::FS_MAP => { + Ok(SlaveReq::FS_MAP) => { let msg = self.extract_msg_body::(&hdr, size, &buf)?; // check_attached_files() has validated files self.backend .fs_slave_map(&msg, &files.unwrap()[0]) .map_err(Error::ReqHandlerError) } - SlaveReq::FS_UNMAP => { + Ok(SlaveReq::FS_UNMAP) => { let msg = self.extract_msg_body::(&hdr, size, &buf)?; self.backend .fs_slave_unmap(&msg) .map_err(Error::ReqHandlerError) } - SlaveReq::FS_SYNC => { + Ok(SlaveReq::FS_SYNC) => { let msg = self.extract_msg_body::(&hdr, size, &buf)?; self.backend .fs_slave_sync(&msg) .map_err(Error::ReqHandlerError) } - SlaveReq::FS_IO => { + Ok(SlaveReq::FS_IO) => { let msg = self.extract_msg_body::(&hdr, size, &buf)?; // check_attached_files() has validated files self.backend @@ -285,7 +285,7 @@ impl MasterReqHandler { files: &Option>, ) -> Result<()> { match hdr.get_code() { - SlaveReq::FS_MAP | SlaveReq::FS_IO => { + Ok(SlaveReq::FS_MAP | SlaveReq::FS_IO) => { // Expect a single file is passed. match files { Some(files) if files.len() == 1 => Ok(()), @@ -320,7 +320,7 @@ impl MasterReqHandler { } self.check_state()?; Ok(VhostUserMsgHeader::new( - req.get_code(), + req.get_code()?, VhostUserHeaderFlag::REPLY.bits(), mem::size_of::() as u32, )) diff --git a/crates/vhost/src/vhost_user/message.rs b/crates/vhost/src/vhost_user/message.rs index 35e1e30..b2882bc 100644 --- a/crates/vhost/src/vhost_user/message.rs +++ b/crates/vhost/src/vhost_user/message.rs @@ -14,6 +14,7 @@ use std::marker::PhantomData; use vm_memory::ByteValued; +use super::{Error, Result}; use crate::VringConfigData; /// The vhost-user specification uses a field of u32 to store message length. @@ -45,7 +46,7 @@ pub const VHOST_USER_MAX_VRINGS: u64 = 0x8000u64; pub(super) trait Req: Clone + Copy + Debug + PartialEq + Eq + PartialOrd + Ord + Send + Sync + Into { - fn is_valid(&self) -> bool; + fn is_valid(value: u32) -> bool; } /// Type of requests sending from masters to slaves. @@ -150,8 +151,8 @@ impl From for u32 { } impl Req for MasterReq { - fn is_valid(&self) -> bool { - (*self > MasterReq::NOOP) && (*self < MasterReq::MAX_CMD) + fn is_valid(value: u32) -> bool { + (value > MasterReq::NOOP as u32) && (value < MasterReq::MAX_CMD as u32) } } @@ -190,8 +191,8 @@ impl From for u32 { } impl Req for SlaveReq { - fn is_valid(&self) -> bool { - (*self > SlaveReq::NOOP) && (*self < SlaveReq::MAX_CMD) + fn is_valid(value: u32) -> bool { + (value > SlaveReq::NOOP as u32) && (value < SlaveReq::MAX_CMD as u32) } } @@ -270,9 +271,13 @@ impl VhostUserMsgHeader { } /// Get message type. - pub fn get_code(&self) -> R { - // It's safe because R is marked as repr(u32). - unsafe { std::mem::transmute_copy::(&{ self.request }) } + pub fn get_code(&self) -> Result { + if R::is_valid(self.request) { + // It's safe because R is marked as repr(u32), and the value is valid. + Ok(unsafe { std::mem::transmute_copy::(&{ self.request }) }) + } else { + Err(Error::InvalidMessage) + } } /// Set message type. @@ -321,7 +326,11 @@ impl VhostUserMsgHeader { /// Check whether it's the reply message for the request `req`. pub fn is_reply_for(&self, req: &VhostUserMsgHeader) -> bool { - self.is_reply() && !req.is_reply() && self.get_code() == req.get_code() + if let (Ok(code1), Ok(code2)) = (self.get_code(), req.get_code()) { + self.is_reply() && !req.is_reply() && code1 == code2 + } else { + false + } } /// Get message size. @@ -351,7 +360,7 @@ unsafe impl ByteValued for VhostUserMsgHeader {} impl VhostUserMsgValidator for VhostUserMsgHeader { #[allow(clippy::if_same_then_else)] fn is_valid(&self) -> bool { - if !self.get_code().is_valid() { + if self.get_code().is_err() { return false; } else if self.size as usize > MAX_MSG_SIZE { return false; @@ -991,38 +1000,32 @@ mod tests { #[test] fn check_master_request_code() { - let code = MasterReq::NOOP; - assert!(!code.is_valid()); - let code = MasterReq::MAX_CMD; - assert!(!code.is_valid()); - assert!(code > MasterReq::NOOP); + assert!(!MasterReq::is_valid(MasterReq::NOOP as _)); + assert!(!MasterReq::is_valid(MasterReq::MAX_CMD as _)); + assert!(MasterReq::MAX_CMD > MasterReq::NOOP); let code = MasterReq::GET_FEATURES; - assert!(code.is_valid()); + assert!(MasterReq::is_valid(code as _)); assert_eq!(code, code.clone()); - let code: MasterReq = unsafe { std::mem::transmute::(10000u32) }; - assert!(!code.is_valid()); + assert!(!MasterReq::is_valid(10000)); } #[test] fn check_slave_request_code() { - let code = SlaveReq::NOOP; - assert!(!code.is_valid()); - let code = SlaveReq::MAX_CMD; - assert!(!code.is_valid()); - assert!(code > SlaveReq::NOOP); + assert!(!SlaveReq::is_valid(SlaveReq::NOOP as _)); + assert!(!SlaveReq::is_valid(SlaveReq::MAX_CMD as _)); + assert!(SlaveReq::MAX_CMD > SlaveReq::NOOP); let code = SlaveReq::CONFIG_CHANGE_MSG; - assert!(code.is_valid()); + assert!(SlaveReq::is_valid(code as _)); assert_eq!(code, code.clone()); - let code: SlaveReq = unsafe { std::mem::transmute::(10000u32) }; - assert!(!code.is_valid()); + assert!(!SlaveReq::is_valid(10000)); } #[test] fn msg_header_ops() { let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, 0x100); - assert_eq!(hdr.get_code(), MasterReq::GET_FEATURES); + assert_eq!(hdr.get_code().unwrap(), MasterReq::GET_FEATURES); hdr.set_code(MasterReq::SET_FEATURES); - assert_eq!(hdr.get_code(), MasterReq::SET_FEATURES); + assert_eq!(hdr.get_code().unwrap(), MasterReq::SET_FEATURES); assert_eq!(hdr.get_version(), 0x1); @@ -1066,7 +1069,7 @@ mod tests { // Test Debug, Clone, PartiaEq trait assert_eq!(hdr, hdr.clone()); - assert_eq!(hdr.clone().get_code(), hdr.get_code()); + assert_eq!(hdr.clone().get_code().unwrap(), hdr.get_code().unwrap()); assert_eq!(format!("{:?}", hdr.clone()), format!("{:?}", hdr)); } diff --git a/crates/vhost/src/vhost_user/slave_req_handler.rs b/crates/vhost/src/vhost_user/slave_req_handler.rs index ffde25c..db6b078 100644 --- a/crates/vhost/src/vhost_user/slave_req_handler.rs +++ b/crates/vhost/src/vhost_user/slave_req_handler.rs @@ -340,17 +340,17 @@ impl SlaveReqHandler { }; match hdr.get_code() { - MasterReq::SET_OWNER => { + Ok(MasterReq::SET_OWNER) => { self.check_request_size(&hdr, size, 0)?; let res = self.backend.set_owner(); self.send_ack_message(&hdr, res)?; } - MasterReq::RESET_OWNER => { + Ok(MasterReq::RESET_OWNER) => { self.check_request_size(&hdr, size, 0)?; let res = self.backend.reset_owner(); self.send_ack_message(&hdr, res)?; } - MasterReq::GET_FEATURES => { + Ok(MasterReq::GET_FEATURES) => { self.check_request_size(&hdr, size, 0)?; let features = self.backend.get_features()?; let msg = VhostUserU64::new(features); @@ -358,23 +358,23 @@ impl SlaveReqHandler { self.virtio_features = features; self.update_reply_ack_flag(); } - MasterReq::SET_FEATURES => { + Ok(MasterReq::SET_FEATURES) => { let msg = self.extract_request_body::(&hdr, size, &buf)?; let res = self.backend.set_features(msg.value); self.acked_virtio_features = msg.value; self.update_reply_ack_flag(); self.send_ack_message(&hdr, res)?; } - MasterReq::SET_MEM_TABLE => { + Ok(MasterReq::SET_MEM_TABLE) => { let res = self.set_mem_table(&hdr, size, &buf, files); self.send_ack_message(&hdr, res)?; } - MasterReq::SET_VRING_NUM => { + Ok(MasterReq::SET_VRING_NUM) => { let msg = self.extract_request_body::(&hdr, size, &buf)?; let res = self.backend.set_vring_num(msg.index, msg.num); self.send_ack_message(&hdr, res)?; } - MasterReq::SET_VRING_ADDR => { + Ok(MasterReq::SET_VRING_ADDR) => { let msg = self.extract_request_body::(&hdr, size, &buf)?; let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) { Some(val) => val, @@ -390,35 +390,35 @@ impl SlaveReqHandler { ); self.send_ack_message(&hdr, res)?; } - MasterReq::SET_VRING_BASE => { + Ok(MasterReq::SET_VRING_BASE) => { let msg = self.extract_request_body::(&hdr, size, &buf)?; let res = self.backend.set_vring_base(msg.index, msg.num); self.send_ack_message(&hdr, res)?; } - MasterReq::GET_VRING_BASE => { + Ok(MasterReq::GET_VRING_BASE) => { let msg = self.extract_request_body::(&hdr, size, &buf)?; let reply = self.backend.get_vring_base(msg.index)?; self.send_reply_message(&hdr, &reply)?; } - MasterReq::SET_VRING_CALL => { + Ok(MasterReq::SET_VRING_CALL) => { self.check_request_size(&hdr, size, mem::size_of::())?; let (index, file) = self.handle_vring_fd_request(&buf, files)?; let res = self.backend.set_vring_call(index, file); self.send_ack_message(&hdr, res)?; } - MasterReq::SET_VRING_KICK => { + Ok(MasterReq::SET_VRING_KICK) => { self.check_request_size(&hdr, size, mem::size_of::())?; let (index, file) = self.handle_vring_fd_request(&buf, files)?; let res = self.backend.set_vring_kick(index, file); self.send_ack_message(&hdr, res)?; } - MasterReq::SET_VRING_ERR => { + Ok(MasterReq::SET_VRING_ERR) => { self.check_request_size(&hdr, size, mem::size_of::())?; let (index, file) = self.handle_vring_fd_request(&buf, files)?; let res = self.backend.set_vring_err(index, file); self.send_ack_message(&hdr, res)?; } - MasterReq::GET_PROTOCOL_FEATURES => { + Ok(MasterReq::GET_PROTOCOL_FEATURES) => { self.check_request_size(&hdr, size, 0)?; let features = self.backend.get_protocol_features()?; let msg = VhostUserU64::new(features.bits()); @@ -426,21 +426,21 @@ impl SlaveReqHandler { self.protocol_features = features; self.update_reply_ack_flag(); } - MasterReq::SET_PROTOCOL_FEATURES => { + Ok(MasterReq::SET_PROTOCOL_FEATURES) => { let msg = self.extract_request_body::(&hdr, size, &buf)?; let res = self.backend.set_protocol_features(msg.value); self.acked_protocol_features = msg.value; self.update_reply_ack_flag(); self.send_ack_message(&hdr, res)?; } - MasterReq::GET_QUEUE_NUM => { + Ok(MasterReq::GET_QUEUE_NUM) => { self.check_proto_feature(VhostUserProtocolFeatures::MQ)?; self.check_request_size(&hdr, size, 0)?; let num = self.backend.get_queue_num()?; let msg = VhostUserU64::new(num); self.send_reply_message(&hdr, &msg)?; } - MasterReq::SET_VRING_ENABLE => { + Ok(MasterReq::SET_VRING_ENABLE) => { let msg = self.extract_request_body::(&hdr, size, &buf)?; self.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?; let enable = match msg.num { @@ -452,24 +452,24 @@ impl SlaveReqHandler { let res = self.backend.set_vring_enable(msg.index, enable); self.send_ack_message(&hdr, res)?; } - MasterReq::GET_CONFIG => { + Ok(MasterReq::GET_CONFIG) => { self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?; self.check_request_size(&hdr, size, hdr.get_size() as usize)?; self.get_config(&hdr, &buf)?; } - MasterReq::SET_CONFIG => { + Ok(MasterReq::SET_CONFIG) => { self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?; self.check_request_size(&hdr, size, hdr.get_size() as usize)?; let res = self.set_config(size, &buf); self.send_ack_message(&hdr, res)?; } - MasterReq::SET_SLAVE_REQ_FD => { + Ok(MasterReq::SET_SLAVE_REQ_FD) => { self.check_proto_feature(VhostUserProtocolFeatures::SLAVE_REQ)?; self.check_request_size(&hdr, size, hdr.get_size() as usize)?; let res = self.set_slave_req_fd(files); self.send_ack_message(&hdr, res)?; } - MasterReq::GET_INFLIGHT_FD => { + Ok(MasterReq::GET_INFLIGHT_FD) => { self.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?; let msg = self.extract_request_body::(&hdr, size, &buf)?; @@ -478,21 +478,21 @@ impl SlaveReqHandler { self.main_sock .send_message(&reply_hdr, &inflight, Some(&[file.as_raw_fd()]))?; } - MasterReq::SET_INFLIGHT_FD => { + Ok(MasterReq::SET_INFLIGHT_FD) => { self.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?; let file = take_single_file(files).ok_or(Error::IncorrectFds)?; let msg = self.extract_request_body::(&hdr, size, &buf)?; let res = self.backend.set_inflight_fd(&msg, file); self.send_ack_message(&hdr, res)?; } - MasterReq::GET_MAX_MEM_SLOTS => { + Ok(MasterReq::GET_MAX_MEM_SLOTS) => { self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; self.check_request_size(&hdr, size, 0)?; let num = self.backend.get_max_mem_slots()?; let msg = VhostUserU64::new(num); self.send_reply_message(&hdr, &msg)?; } - MasterReq::ADD_MEM_REG => { + Ok(MasterReq::ADD_MEM_REG) => { self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; let mut files = files.ok_or(Error::InvalidParam)?; if files.len() != 1 { @@ -503,7 +503,7 @@ impl SlaveReqHandler { let res = self.backend.add_mem_region(&msg, files.swap_remove(0)); self.send_ack_message(&hdr, res)?; } - MasterReq::REM_MEM_REG => { + Ok(MasterReq::REM_MEM_REG) => { self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; let msg = @@ -683,15 +683,17 @@ impl SlaveReqHandler { files: &Option>, ) -> Result<()> { match hdr.get_code() { - MasterReq::SET_MEM_TABLE - | MasterReq::SET_VRING_CALL - | MasterReq::SET_VRING_KICK - | MasterReq::SET_VRING_ERR - | MasterReq::SET_LOG_BASE - | MasterReq::SET_LOG_FD - | MasterReq::SET_SLAVE_REQ_FD - | MasterReq::SET_INFLIGHT_FD - | MasterReq::ADD_MEM_REG => Ok(()), + Ok( + MasterReq::SET_MEM_TABLE + | MasterReq::SET_VRING_CALL + | MasterReq::SET_VRING_KICK + | MasterReq::SET_VRING_ERR + | MasterReq::SET_LOG_BASE + | MasterReq::SET_LOG_FD + | MasterReq::SET_SLAVE_REQ_FD + | MasterReq::SET_INFLIGHT_FD + | MasterReq::ADD_MEM_REG, + ) => Ok(()), _ if files.is_some() => Err(Error::InvalidMessage), _ => Ok(()), } @@ -737,7 +739,7 @@ impl SlaveReqHandler { } self.check_state()?; Ok(VhostUserMsgHeader::new( - req.get_code(), + req.get_code()?, VhostUserHeaderFlag::REPLY.bits(), (mem::size_of::() + payload_size) as u32, ))