From 0152e88b424b87c601effd0fb6a423b5015fcc03 Mon Sep 17 00:00:00 2001 From: Alyssa Ross Date: Sun, 7 Aug 2022 08:27:15 +0000 Subject: [PATCH] vhost_user: fix UB on invalid master request Since VhostUserMsgHeader implements ByteValued, it is supposed to be safe to construct from any correctly-sized arbitrary byte array. But that means we can do this: let bytes = b"\xFF\xFF\xFF\xFF\x00\x00\x00\x00\x00\x00\x00\x00"; let header = VhostUserMsgHeader::::from_slice(bytes).unwrap(); header.get_code() constructing an invalid MasterReq, using only functions that are marked as safe. Constructing an invalid enum value is undefined behavior in Rust, so this API is unsound. This wasn't considered by the safety comment in VhostUserMsgHeader::get_code, which only considered the safety of requests that were valid enum variants. If the vhost-user frontend process sends a message that the backend doesn't recognise, that's exactly what will happen, so the UB can be triggered from an external process (but a trusted one). To fix this, we need to check whether the value is valid _before_ converting it. Req::is_valid is changed to be a non-instance method, so it can be called before constructing the Req. VhostUserMsgHeader::get_code is changed to return a Result, to accomodate the case where the request number is not a valid value for R. Signed-off-by: Alyssa Ross --- crates/vhost/CHANGELOG.md | 1 + crates/vhost/src/vhost_user/master.rs | 8 +-- .../src/vhost_user/master_req_handler.rs | 14 ++-- crates/vhost/src/vhost_user/message.rs | 61 ++++++++-------- .../vhost/src/vhost_user/slave_req_handler.rs | 70 ++++++++++--------- 5 files changed, 80 insertions(+), 74 deletions(-) diff --git a/crates/vhost/CHANGELOG.md b/crates/vhost/CHANGELOG.md index b53087d..2c636ce 100644 --- a/crates/vhost/CHANGELOG.md +++ b/crates/vhost/CHANGELOG.md @@ -6,6 +6,7 @@ ### Changed ### Fixed +- [[#135]](https://github.com/rust-vmm/vhost/pull/135) vhost_user: fix UB on invalid master request ### Deprecated diff --git a/crates/vhost/src/vhost_user/master.rs b/crates/vhost/src/vhost_user/master.rs index 4ca821d..8170718 100644 --- a/crates/vhost/src/vhost_user/master.rs +++ b/crates/vhost/src/vhost_user/master.rs @@ -797,13 +797,13 @@ mod tests { master.reset_owner().unwrap(); let (hdr, rfds) = slave.recv_header().unwrap(); - assert_eq!(hdr.get_code(), MasterReq::SET_OWNER); + assert_eq!(hdr.get_code().unwrap(), MasterReq::SET_OWNER); assert_eq!(hdr.get_size(), 0); assert_eq!(hdr.get_version(), 0x1); assert!(rfds.is_none()); let (hdr, rfds) = slave.recv_header().unwrap(); - assert_eq!(hdr.get_code(), MasterReq::RESET_OWNER); + assert_eq!(hdr.get_code().unwrap(), MasterReq::RESET_OWNER); assert_eq!(hdr.get_size(), 0); assert_eq!(hdr.get_version(), 0x1); assert!(rfds.is_none()); @@ -831,7 +831,7 @@ mod tests { master.set_owner().unwrap(); let (hdr, rfds) = peer.recv_header().unwrap(); - assert_eq!(hdr.get_code(), MasterReq::SET_OWNER); + assert_eq!(hdr.get_code().unwrap(), MasterReq::SET_OWNER); assert_eq!(hdr.get_size(), 0); assert_eq!(hdr.get_version(), 0x1); assert!(rfds.is_none()); @@ -866,7 +866,7 @@ mod tests { master.set_owner().unwrap(); let (hdr, rfds) = peer.recv_header().unwrap(); - assert_eq!(hdr.get_code(), MasterReq::SET_OWNER); + assert_eq!(hdr.get_code().unwrap(), MasterReq::SET_OWNER); assert!(rfds.is_none()); assert!(master.get_protocol_features().is_err()); diff --git a/crates/vhost/src/vhost_user/master_req_handler.rs b/crates/vhost/src/vhost_user/master_req_handler.rs index 9873b5b..3c6a489 100644 --- a/crates/vhost/src/vhost_user/master_req_handler.rs +++ b/crates/vhost/src/vhost_user/master_req_handler.rs @@ -216,32 +216,32 @@ impl MasterReqHandler { }; let res = match hdr.get_code() { - SlaveReq::CONFIG_CHANGE_MSG => { + Ok(SlaveReq::CONFIG_CHANGE_MSG) => { self.check_msg_size(&hdr, size, 0)?; self.backend .handle_config_change() .map_err(Error::ReqHandlerError) } - SlaveReq::FS_MAP => { + Ok(SlaveReq::FS_MAP) => { let msg = self.extract_msg_body::(&hdr, size, &buf)?; // check_attached_files() has validated files self.backend .fs_slave_map(&msg, &files.unwrap()[0]) .map_err(Error::ReqHandlerError) } - SlaveReq::FS_UNMAP => { + Ok(SlaveReq::FS_UNMAP) => { let msg = self.extract_msg_body::(&hdr, size, &buf)?; self.backend .fs_slave_unmap(&msg) .map_err(Error::ReqHandlerError) } - SlaveReq::FS_SYNC => { + Ok(SlaveReq::FS_SYNC) => { let msg = self.extract_msg_body::(&hdr, size, &buf)?; self.backend .fs_slave_sync(&msg) .map_err(Error::ReqHandlerError) } - SlaveReq::FS_IO => { + Ok(SlaveReq::FS_IO) => { let msg = self.extract_msg_body::(&hdr, size, &buf)?; // check_attached_files() has validated files self.backend @@ -285,7 +285,7 @@ impl MasterReqHandler { files: &Option>, ) -> Result<()> { match hdr.get_code() { - SlaveReq::FS_MAP | SlaveReq::FS_IO => { + Ok(SlaveReq::FS_MAP | SlaveReq::FS_IO) => { // Expect a single file is passed. match files { Some(files) if files.len() == 1 => Ok(()), @@ -320,7 +320,7 @@ impl MasterReqHandler { } self.check_state()?; Ok(VhostUserMsgHeader::new( - req.get_code(), + req.get_code()?, VhostUserHeaderFlag::REPLY.bits(), mem::size_of::() as u32, )) diff --git a/crates/vhost/src/vhost_user/message.rs b/crates/vhost/src/vhost_user/message.rs index 35e1e30..b2882bc 100644 --- a/crates/vhost/src/vhost_user/message.rs +++ b/crates/vhost/src/vhost_user/message.rs @@ -14,6 +14,7 @@ use std::marker::PhantomData; use vm_memory::ByteValued; +use super::{Error, Result}; use crate::VringConfigData; /// The vhost-user specification uses a field of u32 to store message length. @@ -45,7 +46,7 @@ pub const VHOST_USER_MAX_VRINGS: u64 = 0x8000u64; pub(super) trait Req: Clone + Copy + Debug + PartialEq + Eq + PartialOrd + Ord + Send + Sync + Into { - fn is_valid(&self) -> bool; + fn is_valid(value: u32) -> bool; } /// Type of requests sending from masters to slaves. @@ -150,8 +151,8 @@ impl From for u32 { } impl Req for MasterReq { - fn is_valid(&self) -> bool { - (*self > MasterReq::NOOP) && (*self < MasterReq::MAX_CMD) + fn is_valid(value: u32) -> bool { + (value > MasterReq::NOOP as u32) && (value < MasterReq::MAX_CMD as u32) } } @@ -190,8 +191,8 @@ impl From for u32 { } impl Req for SlaveReq { - fn is_valid(&self) -> bool { - (*self > SlaveReq::NOOP) && (*self < SlaveReq::MAX_CMD) + fn is_valid(value: u32) -> bool { + (value > SlaveReq::NOOP as u32) && (value < SlaveReq::MAX_CMD as u32) } } @@ -270,9 +271,13 @@ impl VhostUserMsgHeader { } /// Get message type. - pub fn get_code(&self) -> R { - // It's safe because R is marked as repr(u32). - unsafe { std::mem::transmute_copy::(&{ self.request }) } + pub fn get_code(&self) -> Result { + if R::is_valid(self.request) { + // It's safe because R is marked as repr(u32), and the value is valid. + Ok(unsafe { std::mem::transmute_copy::(&{ self.request }) }) + } else { + Err(Error::InvalidMessage) + } } /// Set message type. @@ -321,7 +326,11 @@ impl VhostUserMsgHeader { /// Check whether it's the reply message for the request `req`. pub fn is_reply_for(&self, req: &VhostUserMsgHeader) -> bool { - self.is_reply() && !req.is_reply() && self.get_code() == req.get_code() + if let (Ok(code1), Ok(code2)) = (self.get_code(), req.get_code()) { + self.is_reply() && !req.is_reply() && code1 == code2 + } else { + false + } } /// Get message size. @@ -351,7 +360,7 @@ unsafe impl ByteValued for VhostUserMsgHeader {} impl VhostUserMsgValidator for VhostUserMsgHeader { #[allow(clippy::if_same_then_else)] fn is_valid(&self) -> bool { - if !self.get_code().is_valid() { + if self.get_code().is_err() { return false; } else if self.size as usize > MAX_MSG_SIZE { return false; @@ -991,38 +1000,32 @@ mod tests { #[test] fn check_master_request_code() { - let code = MasterReq::NOOP; - assert!(!code.is_valid()); - let code = MasterReq::MAX_CMD; - assert!(!code.is_valid()); - assert!(code > MasterReq::NOOP); + assert!(!MasterReq::is_valid(MasterReq::NOOP as _)); + assert!(!MasterReq::is_valid(MasterReq::MAX_CMD as _)); + assert!(MasterReq::MAX_CMD > MasterReq::NOOP); let code = MasterReq::GET_FEATURES; - assert!(code.is_valid()); + assert!(MasterReq::is_valid(code as _)); assert_eq!(code, code.clone()); - let code: MasterReq = unsafe { std::mem::transmute::(10000u32) }; - assert!(!code.is_valid()); + assert!(!MasterReq::is_valid(10000)); } #[test] fn check_slave_request_code() { - let code = SlaveReq::NOOP; - assert!(!code.is_valid()); - let code = SlaveReq::MAX_CMD; - assert!(!code.is_valid()); - assert!(code > SlaveReq::NOOP); + assert!(!SlaveReq::is_valid(SlaveReq::NOOP as _)); + assert!(!SlaveReq::is_valid(SlaveReq::MAX_CMD as _)); + assert!(SlaveReq::MAX_CMD > SlaveReq::NOOP); let code = SlaveReq::CONFIG_CHANGE_MSG; - assert!(code.is_valid()); + assert!(SlaveReq::is_valid(code as _)); assert_eq!(code, code.clone()); - let code: SlaveReq = unsafe { std::mem::transmute::(10000u32) }; - assert!(!code.is_valid()); + assert!(!SlaveReq::is_valid(10000)); } #[test] fn msg_header_ops() { let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, 0x100); - assert_eq!(hdr.get_code(), MasterReq::GET_FEATURES); + assert_eq!(hdr.get_code().unwrap(), MasterReq::GET_FEATURES); hdr.set_code(MasterReq::SET_FEATURES); - assert_eq!(hdr.get_code(), MasterReq::SET_FEATURES); + assert_eq!(hdr.get_code().unwrap(), MasterReq::SET_FEATURES); assert_eq!(hdr.get_version(), 0x1); @@ -1066,7 +1069,7 @@ mod tests { // Test Debug, Clone, PartiaEq trait assert_eq!(hdr, hdr.clone()); - assert_eq!(hdr.clone().get_code(), hdr.get_code()); + assert_eq!(hdr.clone().get_code().unwrap(), hdr.get_code().unwrap()); assert_eq!(format!("{:?}", hdr.clone()), format!("{:?}", hdr)); } diff --git a/crates/vhost/src/vhost_user/slave_req_handler.rs b/crates/vhost/src/vhost_user/slave_req_handler.rs index ffde25c..db6b078 100644 --- a/crates/vhost/src/vhost_user/slave_req_handler.rs +++ b/crates/vhost/src/vhost_user/slave_req_handler.rs @@ -340,17 +340,17 @@ impl SlaveReqHandler { }; match hdr.get_code() { - MasterReq::SET_OWNER => { + Ok(MasterReq::SET_OWNER) => { self.check_request_size(&hdr, size, 0)?; let res = self.backend.set_owner(); self.send_ack_message(&hdr, res)?; } - MasterReq::RESET_OWNER => { + Ok(MasterReq::RESET_OWNER) => { self.check_request_size(&hdr, size, 0)?; let res = self.backend.reset_owner(); self.send_ack_message(&hdr, res)?; } - MasterReq::GET_FEATURES => { + Ok(MasterReq::GET_FEATURES) => { self.check_request_size(&hdr, size, 0)?; let features = self.backend.get_features()?; let msg = VhostUserU64::new(features); @@ -358,23 +358,23 @@ impl SlaveReqHandler { self.virtio_features = features; self.update_reply_ack_flag(); } - MasterReq::SET_FEATURES => { + Ok(MasterReq::SET_FEATURES) => { let msg = self.extract_request_body::(&hdr, size, &buf)?; let res = self.backend.set_features(msg.value); self.acked_virtio_features = msg.value; self.update_reply_ack_flag(); self.send_ack_message(&hdr, res)?; } - MasterReq::SET_MEM_TABLE => { + Ok(MasterReq::SET_MEM_TABLE) => { let res = self.set_mem_table(&hdr, size, &buf, files); self.send_ack_message(&hdr, res)?; } - MasterReq::SET_VRING_NUM => { + Ok(MasterReq::SET_VRING_NUM) => { let msg = self.extract_request_body::(&hdr, size, &buf)?; let res = self.backend.set_vring_num(msg.index, msg.num); self.send_ack_message(&hdr, res)?; } - MasterReq::SET_VRING_ADDR => { + Ok(MasterReq::SET_VRING_ADDR) => { let msg = self.extract_request_body::(&hdr, size, &buf)?; let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) { Some(val) => val, @@ -390,35 +390,35 @@ impl SlaveReqHandler { ); self.send_ack_message(&hdr, res)?; } - MasterReq::SET_VRING_BASE => { + Ok(MasterReq::SET_VRING_BASE) => { let msg = self.extract_request_body::(&hdr, size, &buf)?; let res = self.backend.set_vring_base(msg.index, msg.num); self.send_ack_message(&hdr, res)?; } - MasterReq::GET_VRING_BASE => { + Ok(MasterReq::GET_VRING_BASE) => { let msg = self.extract_request_body::(&hdr, size, &buf)?; let reply = self.backend.get_vring_base(msg.index)?; self.send_reply_message(&hdr, &reply)?; } - MasterReq::SET_VRING_CALL => { + Ok(MasterReq::SET_VRING_CALL) => { self.check_request_size(&hdr, size, mem::size_of::())?; let (index, file) = self.handle_vring_fd_request(&buf, files)?; let res = self.backend.set_vring_call(index, file); self.send_ack_message(&hdr, res)?; } - MasterReq::SET_VRING_KICK => { + Ok(MasterReq::SET_VRING_KICK) => { self.check_request_size(&hdr, size, mem::size_of::())?; let (index, file) = self.handle_vring_fd_request(&buf, files)?; let res = self.backend.set_vring_kick(index, file); self.send_ack_message(&hdr, res)?; } - MasterReq::SET_VRING_ERR => { + Ok(MasterReq::SET_VRING_ERR) => { self.check_request_size(&hdr, size, mem::size_of::())?; let (index, file) = self.handle_vring_fd_request(&buf, files)?; let res = self.backend.set_vring_err(index, file); self.send_ack_message(&hdr, res)?; } - MasterReq::GET_PROTOCOL_FEATURES => { + Ok(MasterReq::GET_PROTOCOL_FEATURES) => { self.check_request_size(&hdr, size, 0)?; let features = self.backend.get_protocol_features()?; let msg = VhostUserU64::new(features.bits()); @@ -426,21 +426,21 @@ impl SlaveReqHandler { self.protocol_features = features; self.update_reply_ack_flag(); } - MasterReq::SET_PROTOCOL_FEATURES => { + Ok(MasterReq::SET_PROTOCOL_FEATURES) => { let msg = self.extract_request_body::(&hdr, size, &buf)?; let res = self.backend.set_protocol_features(msg.value); self.acked_protocol_features = msg.value; self.update_reply_ack_flag(); self.send_ack_message(&hdr, res)?; } - MasterReq::GET_QUEUE_NUM => { + Ok(MasterReq::GET_QUEUE_NUM) => { self.check_proto_feature(VhostUserProtocolFeatures::MQ)?; self.check_request_size(&hdr, size, 0)?; let num = self.backend.get_queue_num()?; let msg = VhostUserU64::new(num); self.send_reply_message(&hdr, &msg)?; } - MasterReq::SET_VRING_ENABLE => { + Ok(MasterReq::SET_VRING_ENABLE) => { let msg = self.extract_request_body::(&hdr, size, &buf)?; self.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?; let enable = match msg.num { @@ -452,24 +452,24 @@ impl SlaveReqHandler { let res = self.backend.set_vring_enable(msg.index, enable); self.send_ack_message(&hdr, res)?; } - MasterReq::GET_CONFIG => { + Ok(MasterReq::GET_CONFIG) => { self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?; self.check_request_size(&hdr, size, hdr.get_size() as usize)?; self.get_config(&hdr, &buf)?; } - MasterReq::SET_CONFIG => { + Ok(MasterReq::SET_CONFIG) => { self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?; self.check_request_size(&hdr, size, hdr.get_size() as usize)?; let res = self.set_config(size, &buf); self.send_ack_message(&hdr, res)?; } - MasterReq::SET_SLAVE_REQ_FD => { + Ok(MasterReq::SET_SLAVE_REQ_FD) => { self.check_proto_feature(VhostUserProtocolFeatures::SLAVE_REQ)?; self.check_request_size(&hdr, size, hdr.get_size() as usize)?; let res = self.set_slave_req_fd(files); self.send_ack_message(&hdr, res)?; } - MasterReq::GET_INFLIGHT_FD => { + Ok(MasterReq::GET_INFLIGHT_FD) => { self.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?; let msg = self.extract_request_body::(&hdr, size, &buf)?; @@ -478,21 +478,21 @@ impl SlaveReqHandler { self.main_sock .send_message(&reply_hdr, &inflight, Some(&[file.as_raw_fd()]))?; } - MasterReq::SET_INFLIGHT_FD => { + Ok(MasterReq::SET_INFLIGHT_FD) => { self.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?; let file = take_single_file(files).ok_or(Error::IncorrectFds)?; let msg = self.extract_request_body::(&hdr, size, &buf)?; let res = self.backend.set_inflight_fd(&msg, file); self.send_ack_message(&hdr, res)?; } - MasterReq::GET_MAX_MEM_SLOTS => { + Ok(MasterReq::GET_MAX_MEM_SLOTS) => { self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; self.check_request_size(&hdr, size, 0)?; let num = self.backend.get_max_mem_slots()?; let msg = VhostUserU64::new(num); self.send_reply_message(&hdr, &msg)?; } - MasterReq::ADD_MEM_REG => { + Ok(MasterReq::ADD_MEM_REG) => { self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; let mut files = files.ok_or(Error::InvalidParam)?; if files.len() != 1 { @@ -503,7 +503,7 @@ impl SlaveReqHandler { let res = self.backend.add_mem_region(&msg, files.swap_remove(0)); self.send_ack_message(&hdr, res)?; } - MasterReq::REM_MEM_REG => { + Ok(MasterReq::REM_MEM_REG) => { self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; let msg = @@ -683,15 +683,17 @@ impl SlaveReqHandler { files: &Option>, ) -> Result<()> { match hdr.get_code() { - MasterReq::SET_MEM_TABLE - | MasterReq::SET_VRING_CALL - | MasterReq::SET_VRING_KICK - | MasterReq::SET_VRING_ERR - | MasterReq::SET_LOG_BASE - | MasterReq::SET_LOG_FD - | MasterReq::SET_SLAVE_REQ_FD - | MasterReq::SET_INFLIGHT_FD - | MasterReq::ADD_MEM_REG => Ok(()), + Ok( + MasterReq::SET_MEM_TABLE + | MasterReq::SET_VRING_CALL + | MasterReq::SET_VRING_KICK + | MasterReq::SET_VRING_ERR + | MasterReq::SET_LOG_BASE + | MasterReq::SET_LOG_FD + | MasterReq::SET_SLAVE_REQ_FD + | MasterReq::SET_INFLIGHT_FD + | MasterReq::ADD_MEM_REG, + ) => Ok(()), _ if files.is_some() => Err(Error::InvalidMessage), _ => Ok(()), } @@ -737,7 +739,7 @@ impl SlaveReqHandler { } self.check_state()?; Ok(VhostUserMsgHeader::new( - req.get_code(), + req.get_code()?, VhostUserHeaderFlag::REPLY.bits(), (mem::size_of::() + payload_size) as u32, ))