vhost_user: fix UB on invalid master request

Since VhostUserMsgHeader implements ByteValued, it is supposed to be
safe to construct from any correctly-sized arbitrary byte array.
But that means we can do this:

	let bytes = b"\xFF\xFF\xFF\xFF\x00\x00\x00\x00\x00\x00\x00\x00";
	let header = VhostUserMsgHeader::<MasterReq>::from_slice(bytes).unwrap();
	header.get_code()

constructing an invalid MasterReq, using only functions that are
marked as safe.  Constructing an invalid enum value is undefined
behavior in Rust, so this API is unsound.  This wasn't considered by
the safety comment in VhostUserMsgHeader::get_code, which only
considered the safety of requests that were valid enum variants.

If the vhost-user frontend process sends a message that the backend
doesn't recognise, that's exactly what will happen, so the UB can be
triggered from an external process (but a trusted one).

To fix this, we need to check whether the value is valid _before_
converting it.  Req::is_valid is changed to be a non-instance method,
so it can be called before constructing the Req.
VhostUserMsgHeader::get_code is changed to return a Result, to
accomodate the case where the request number is not a valid value for
R.

Signed-off-by: Alyssa Ross <alyssa.ross@unikie.com>
This commit is contained in:
Alyssa Ross 2022-08-07 08:27:15 +00:00 committed by Stefano Garzarella
parent 7a874476e8
commit 0152e88b42
5 changed files with 80 additions and 74 deletions

View file

@ -6,6 +6,7 @@
### Changed
### Fixed
- [[#135]](https://github.com/rust-vmm/vhost/pull/135) vhost_user: fix UB on invalid master request
### Deprecated

View file

@ -797,13 +797,13 @@ mod tests {
master.reset_owner().unwrap();
let (hdr, rfds) = slave.recv_header().unwrap();
assert_eq!(hdr.get_code(), MasterReq::SET_OWNER);
assert_eq!(hdr.get_code().unwrap(), MasterReq::SET_OWNER);
assert_eq!(hdr.get_size(), 0);
assert_eq!(hdr.get_version(), 0x1);
assert!(rfds.is_none());
let (hdr, rfds) = slave.recv_header().unwrap();
assert_eq!(hdr.get_code(), MasterReq::RESET_OWNER);
assert_eq!(hdr.get_code().unwrap(), MasterReq::RESET_OWNER);
assert_eq!(hdr.get_size(), 0);
assert_eq!(hdr.get_version(), 0x1);
assert!(rfds.is_none());
@ -831,7 +831,7 @@ mod tests {
master.set_owner().unwrap();
let (hdr, rfds) = peer.recv_header().unwrap();
assert_eq!(hdr.get_code(), MasterReq::SET_OWNER);
assert_eq!(hdr.get_code().unwrap(), MasterReq::SET_OWNER);
assert_eq!(hdr.get_size(), 0);
assert_eq!(hdr.get_version(), 0x1);
assert!(rfds.is_none());
@ -866,7 +866,7 @@ mod tests {
master.set_owner().unwrap();
let (hdr, rfds) = peer.recv_header().unwrap();
assert_eq!(hdr.get_code(), MasterReq::SET_OWNER);
assert_eq!(hdr.get_code().unwrap(), MasterReq::SET_OWNER);
assert!(rfds.is_none());
assert!(master.get_protocol_features().is_err());

View file

@ -216,32 +216,32 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
};
let res = match hdr.get_code() {
SlaveReq::CONFIG_CHANGE_MSG => {
Ok(SlaveReq::CONFIG_CHANGE_MSG) => {
self.check_msg_size(&hdr, size, 0)?;
self.backend
.handle_config_change()
.map_err(Error::ReqHandlerError)
}
SlaveReq::FS_MAP => {
Ok(SlaveReq::FS_MAP) => {
let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
// check_attached_files() has validated files
self.backend
.fs_slave_map(&msg, &files.unwrap()[0])
.map_err(Error::ReqHandlerError)
}
SlaveReq::FS_UNMAP => {
Ok(SlaveReq::FS_UNMAP) => {
let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
self.backend
.fs_slave_unmap(&msg)
.map_err(Error::ReqHandlerError)
}
SlaveReq::FS_SYNC => {
Ok(SlaveReq::FS_SYNC) => {
let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
self.backend
.fs_slave_sync(&msg)
.map_err(Error::ReqHandlerError)
}
SlaveReq::FS_IO => {
Ok(SlaveReq::FS_IO) => {
let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
// check_attached_files() has validated files
self.backend
@ -285,7 +285,7 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
files: &Option<Vec<File>>,
) -> Result<()> {
match hdr.get_code() {
SlaveReq::FS_MAP | SlaveReq::FS_IO => {
Ok(SlaveReq::FS_MAP | SlaveReq::FS_IO) => {
// Expect a single file is passed.
match files {
Some(files) if files.len() == 1 => Ok(()),
@ -320,7 +320,7 @@ impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
}
self.check_state()?;
Ok(VhostUserMsgHeader::new(
req.get_code(),
req.get_code()?,
VhostUserHeaderFlag::REPLY.bits(),
mem::size_of::<T>() as u32,
))

View file

@ -14,6 +14,7 @@ use std::marker::PhantomData;
use vm_memory::ByteValued;
use super::{Error, Result};
use crate::VringConfigData;
/// The vhost-user specification uses a field of u32 to store message length.
@ -45,7 +46,7 @@ pub const VHOST_USER_MAX_VRINGS: u64 = 0x8000u64;
pub(super) trait Req:
Clone + Copy + Debug + PartialEq + Eq + PartialOrd + Ord + Send + Sync + Into<u32>
{
fn is_valid(&self) -> bool;
fn is_valid(value: u32) -> bool;
}
/// Type of requests sending from masters to slaves.
@ -150,8 +151,8 @@ impl From<MasterReq> for u32 {
}
impl Req for MasterReq {
fn is_valid(&self) -> bool {
(*self > MasterReq::NOOP) && (*self < MasterReq::MAX_CMD)
fn is_valid(value: u32) -> bool {
(value > MasterReq::NOOP as u32) && (value < MasterReq::MAX_CMD as u32)
}
}
@ -190,8 +191,8 @@ impl From<SlaveReq> for u32 {
}
impl Req for SlaveReq {
fn is_valid(&self) -> bool {
(*self > SlaveReq::NOOP) && (*self < SlaveReq::MAX_CMD)
fn is_valid(value: u32) -> bool {
(value > SlaveReq::NOOP as u32) && (value < SlaveReq::MAX_CMD as u32)
}
}
@ -270,9 +271,13 @@ impl<R: Req> VhostUserMsgHeader<R> {
}
/// Get message type.
pub fn get_code(&self) -> R {
// It's safe because R is marked as repr(u32).
unsafe { std::mem::transmute_copy::<u32, R>(&{ self.request }) }
pub fn get_code(&self) -> Result<R> {
if R::is_valid(self.request) {
// It's safe because R is marked as repr(u32), and the value is valid.
Ok(unsafe { std::mem::transmute_copy::<u32, R>(&{ self.request }) })
} else {
Err(Error::InvalidMessage)
}
}
/// Set message type.
@ -321,7 +326,11 @@ impl<R: Req> VhostUserMsgHeader<R> {
/// Check whether it's the reply message for the request `req`.
pub fn is_reply_for(&self, req: &VhostUserMsgHeader<R>) -> bool {
self.is_reply() && !req.is_reply() && self.get_code() == req.get_code()
if let (Ok(code1), Ok(code2)) = (self.get_code(), req.get_code()) {
self.is_reply() && !req.is_reply() && code1 == code2
} else {
false
}
}
/// Get message size.
@ -351,7 +360,7 @@ unsafe impl<R: Req> ByteValued for VhostUserMsgHeader<R> {}
impl<T: Req> VhostUserMsgValidator for VhostUserMsgHeader<T> {
#[allow(clippy::if_same_then_else)]
fn is_valid(&self) -> bool {
if !self.get_code().is_valid() {
if self.get_code().is_err() {
return false;
} else if self.size as usize > MAX_MSG_SIZE {
return false;
@ -991,38 +1000,32 @@ mod tests {
#[test]
fn check_master_request_code() {
let code = MasterReq::NOOP;
assert!(!code.is_valid());
let code = MasterReq::MAX_CMD;
assert!(!code.is_valid());
assert!(code > MasterReq::NOOP);
assert!(!MasterReq::is_valid(MasterReq::NOOP as _));
assert!(!MasterReq::is_valid(MasterReq::MAX_CMD as _));
assert!(MasterReq::MAX_CMD > MasterReq::NOOP);
let code = MasterReq::GET_FEATURES;
assert!(code.is_valid());
assert!(MasterReq::is_valid(code as _));
assert_eq!(code, code.clone());
let code: MasterReq = unsafe { std::mem::transmute::<u32, MasterReq>(10000u32) };
assert!(!code.is_valid());
assert!(!MasterReq::is_valid(10000));
}
#[test]
fn check_slave_request_code() {
let code = SlaveReq::NOOP;
assert!(!code.is_valid());
let code = SlaveReq::MAX_CMD;
assert!(!code.is_valid());
assert!(code > SlaveReq::NOOP);
assert!(!SlaveReq::is_valid(SlaveReq::NOOP as _));
assert!(!SlaveReq::is_valid(SlaveReq::MAX_CMD as _));
assert!(SlaveReq::MAX_CMD > SlaveReq::NOOP);
let code = SlaveReq::CONFIG_CHANGE_MSG;
assert!(code.is_valid());
assert!(SlaveReq::is_valid(code as _));
assert_eq!(code, code.clone());
let code: SlaveReq = unsafe { std::mem::transmute::<u32, SlaveReq>(10000u32) };
assert!(!code.is_valid());
assert!(!SlaveReq::is_valid(10000));
}
#[test]
fn msg_header_ops() {
let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, 0x100);
assert_eq!(hdr.get_code(), MasterReq::GET_FEATURES);
assert_eq!(hdr.get_code().unwrap(), MasterReq::GET_FEATURES);
hdr.set_code(MasterReq::SET_FEATURES);
assert_eq!(hdr.get_code(), MasterReq::SET_FEATURES);
assert_eq!(hdr.get_code().unwrap(), MasterReq::SET_FEATURES);
assert_eq!(hdr.get_version(), 0x1);
@ -1066,7 +1069,7 @@ mod tests {
// Test Debug, Clone, PartiaEq trait
assert_eq!(hdr, hdr.clone());
assert_eq!(hdr.clone().get_code(), hdr.get_code());
assert_eq!(hdr.clone().get_code().unwrap(), hdr.get_code().unwrap());
assert_eq!(format!("{:?}", hdr.clone()), format!("{:?}", hdr));
}

View file

@ -340,17 +340,17 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
};
match hdr.get_code() {
MasterReq::SET_OWNER => {
Ok(MasterReq::SET_OWNER) => {
self.check_request_size(&hdr, size, 0)?;
let res = self.backend.set_owner();
self.send_ack_message(&hdr, res)?;
}
MasterReq::RESET_OWNER => {
Ok(MasterReq::RESET_OWNER) => {
self.check_request_size(&hdr, size, 0)?;
let res = self.backend.reset_owner();
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_FEATURES => {
Ok(MasterReq::GET_FEATURES) => {
self.check_request_size(&hdr, size, 0)?;
let features = self.backend.get_features()?;
let msg = VhostUserU64::new(features);
@ -358,23 +358,23 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
self.virtio_features = features;
self.update_reply_ack_flag();
}
MasterReq::SET_FEATURES => {
Ok(MasterReq::SET_FEATURES) => {
let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
let res = self.backend.set_features(msg.value);
self.acked_virtio_features = msg.value;
self.update_reply_ack_flag();
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_MEM_TABLE => {
Ok(MasterReq::SET_MEM_TABLE) => {
let res = self.set_mem_table(&hdr, size, &buf, files);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_NUM => {
Ok(MasterReq::SET_VRING_NUM) => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
let res = self.backend.set_vring_num(msg.index, msg.num);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_ADDR => {
Ok(MasterReq::SET_VRING_ADDR) => {
let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?;
let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) {
Some(val) => val,
@ -390,35 +390,35 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_BASE => {
Ok(MasterReq::SET_VRING_BASE) => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
let res = self.backend.set_vring_base(msg.index, msg.num);
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_VRING_BASE => {
Ok(MasterReq::GET_VRING_BASE) => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
let reply = self.backend.get_vring_base(msg.index)?;
self.send_reply_message(&hdr, &reply)?;
}
MasterReq::SET_VRING_CALL => {
Ok(MasterReq::SET_VRING_CALL) => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
let (index, file) = self.handle_vring_fd_request(&buf, files)?;
let res = self.backend.set_vring_call(index, file);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_KICK => {
Ok(MasterReq::SET_VRING_KICK) => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
let (index, file) = self.handle_vring_fd_request(&buf, files)?;
let res = self.backend.set_vring_kick(index, file);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_ERR => {
Ok(MasterReq::SET_VRING_ERR) => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
let (index, file) = self.handle_vring_fd_request(&buf, files)?;
let res = self.backend.set_vring_err(index, file);
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_PROTOCOL_FEATURES => {
Ok(MasterReq::GET_PROTOCOL_FEATURES) => {
self.check_request_size(&hdr, size, 0)?;
let features = self.backend.get_protocol_features()?;
let msg = VhostUserU64::new(features.bits());
@ -426,21 +426,21 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
self.protocol_features = features;
self.update_reply_ack_flag();
}
MasterReq::SET_PROTOCOL_FEATURES => {
Ok(MasterReq::SET_PROTOCOL_FEATURES) => {
let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
let res = self.backend.set_protocol_features(msg.value);
self.acked_protocol_features = msg.value;
self.update_reply_ack_flag();
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_QUEUE_NUM => {
Ok(MasterReq::GET_QUEUE_NUM) => {
self.check_proto_feature(VhostUserProtocolFeatures::MQ)?;
self.check_request_size(&hdr, size, 0)?;
let num = self.backend.get_queue_num()?;
let msg = VhostUserU64::new(num);
self.send_reply_message(&hdr, &msg)?;
}
MasterReq::SET_VRING_ENABLE => {
Ok(MasterReq::SET_VRING_ENABLE) => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
self.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?;
let enable = match msg.num {
@ -452,24 +452,24 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
let res = self.backend.set_vring_enable(msg.index, enable);
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_CONFIG => {
Ok(MasterReq::GET_CONFIG) => {
self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?;
self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
self.get_config(&hdr, &buf)?;
}
MasterReq::SET_CONFIG => {
Ok(MasterReq::SET_CONFIG) => {
self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?;
self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
let res = self.set_config(size, &buf);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_SLAVE_REQ_FD => {
Ok(MasterReq::SET_SLAVE_REQ_FD) => {
self.check_proto_feature(VhostUserProtocolFeatures::SLAVE_REQ)?;
self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
let res = self.set_slave_req_fd(files);
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_INFLIGHT_FD => {
Ok(MasterReq::GET_INFLIGHT_FD) => {
self.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?;
let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
@ -478,21 +478,21 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
self.main_sock
.send_message(&reply_hdr, &inflight, Some(&[file.as_raw_fd()]))?;
}
MasterReq::SET_INFLIGHT_FD => {
Ok(MasterReq::SET_INFLIGHT_FD) => {
self.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?;
let file = take_single_file(files).ok_or(Error::IncorrectFds)?;
let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
let res = self.backend.set_inflight_fd(&msg, file);
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_MAX_MEM_SLOTS => {
Ok(MasterReq::GET_MAX_MEM_SLOTS) => {
self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
self.check_request_size(&hdr, size, 0)?;
let num = self.backend.get_max_mem_slots()?;
let msg = VhostUserU64::new(num);
self.send_reply_message(&hdr, &msg)?;
}
MasterReq::ADD_MEM_REG => {
Ok(MasterReq::ADD_MEM_REG) => {
self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
let mut files = files.ok_or(Error::InvalidParam)?;
if files.len() != 1 {
@ -503,7 +503,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
let res = self.backend.add_mem_region(&msg, files.swap_remove(0));
self.send_ack_message(&hdr, res)?;
}
MasterReq::REM_MEM_REG => {
Ok(MasterReq::REM_MEM_REG) => {
self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
let msg =
@ -683,15 +683,17 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
files: &Option<Vec<File>>,
) -> Result<()> {
match hdr.get_code() {
MasterReq::SET_MEM_TABLE
| MasterReq::SET_VRING_CALL
| MasterReq::SET_VRING_KICK
| MasterReq::SET_VRING_ERR
| MasterReq::SET_LOG_BASE
| MasterReq::SET_LOG_FD
| MasterReq::SET_SLAVE_REQ_FD
| MasterReq::SET_INFLIGHT_FD
| MasterReq::ADD_MEM_REG => Ok(()),
Ok(
MasterReq::SET_MEM_TABLE
| MasterReq::SET_VRING_CALL
| MasterReq::SET_VRING_KICK
| MasterReq::SET_VRING_ERR
| MasterReq::SET_LOG_BASE
| MasterReq::SET_LOG_FD
| MasterReq::SET_SLAVE_REQ_FD
| MasterReq::SET_INFLIGHT_FD
| MasterReq::ADD_MEM_REG,
) => Ok(()),
_ if files.is_some() => Err(Error::InvalidMessage),
_ => Ok(()),
}
@ -737,7 +739,7 @@ impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
}
self.check_state()?;
Ok(VhostUserMsgHeader::new(
req.get_code(),
req.get_code()?,
VhostUserHeaderFlag::REPLY.bits(),
(mem::size_of::<T>() + payload_size) as u32,
))