misc: Check that get_slice() returned a big enough slice

This should be guaranteed by GuestMemory and GuestMemoryRegion, but
those traits are currently safe, so add checks to guard against
incorrect implementations of them.

Signed-off-by: Demi Marie Obenour <demiobenour@gmail.com>
This commit is contained in:
Demi Marie Obenour 2025-06-27 22:10:37 -04:00 committed by Rob Bradford
parent 969a3b57a3
commit 2be304b392
5 changed files with 26 additions and 23 deletions

View file

@ -440,8 +440,9 @@ impl Request {
let origin_ptr = mem
.get_slice(data_addr, data_len)
.map_err(ExecuteError::GetHostAddress)?
.ptr_guard();
.map_err(ExecuteError::GetHostAddress)?;
assert!(origin_ptr.len() >= data_len);
let origin_ptr = origin_ptr.ptr_guard();
// Verify the buffer alignment.
// In case it's not properly aligned, an intermediate buffer is

View file

@ -429,22 +429,20 @@ impl PvmemcontrolBusDevice {
/// [`range_base`, `range_base` + `range_len`) is present in the guest
fn operate_on_memory_range<F>(&self, addr: u64, length: u64, f: F) -> result::Result<(), Error>
where
F: FnOnce(*mut libc::c_void, libc::size_t) -> libc::c_int,
F: FnOnce(*mut libc::c_void, usize) -> libc::c_int,
{
let memory = self.mem.memory();
let range_base = GuestAddress(addr);
let range_len = usize::try_from(length).map_err(|_| Error::InvalidRequest)?;
// assume guest memory is not interleaved with vmm memory on the host.
if !memory.check_range(range_base, range_len) {
let Ok(slice) = memory.get_slice(range_base, range_len) else {
return Err(Error::GuestMemory(GuestMemoryError::InvalidGuestAddress(
range_base,
)));
}
let hva = memory
.get_host_address(range_base)
.map_err(Error::GuestMemory)?;
let res = f(hva as *mut libc::c_void, range_len as libc::size_t);
};
assert!(slice.len() >= range_len);
let res = f(slice.ptr_guard_mut().as_ptr() as _, slice.len());
if res != 0 {
return Err(Error::LibcFail(io::Error::last_os_error()));
}

View file

@ -67,8 +67,9 @@ impl TxVirtio {
let buf = desc_chain
.memory()
.get_slice(desc_addr, desc.len() as usize)
.map_err(NetQueuePairError::GuestMemory)?
.ptr_guard_mut();
.map_err(NetQueuePairError::GuestMemory)?;
assert!(buf.len() >= desc.len() as usize);
let buf = buf.ptr_guard_mut();
let iovec = libc::iovec {
iov_base: buf.as_ptr() as *mut libc::c_void,
iov_len: desc.len() as libc::size_t,
@ -208,8 +209,9 @@ impl RxVirtio {
let buf = desc_chain
.memory()
.get_slice(desc_addr, desc.len() as usize)
.map_err(NetQueuePairError::GuestMemory)?
.ptr_guard_mut();
.map_err(NetQueuePairError::GuestMemory)?;
assert!(buf.len() >= desc.len() as usize);
let buf = buf.ptr_guard_mut();
let iovec = libc::iovec {
iov_base: buf.as_ptr() as *mut libc::c_void,
iov_len: desc.len() as libc::size_t,

View file

@ -174,12 +174,15 @@ impl BalloonEpollHandler {
range_len: usize,
advice: libc::c_int,
) -> result::Result<(), Error> {
let hva = memory
.get_host_address(range_base)
let slice = memory
.get_slice(range_base, range_len)
.map_err(Error::GuestMemory)?;
assert!(slice.len() >= range_len);
let res =
// SAFETY: Need unsafe to do syscall madvise
unsafe { libc::madvise(hva as *mut libc::c_void, range_len as libc::size_t, advice) };
// SAFETY: FFI call with valid arguments, guaranteed by VolatileSlice
unsafe {
libc::madvise(slice.ptr_guard_mut().as_ptr() as *mut libc::c_void,
range_len as libc::size_t, advice) };
if res != 0 {
return Err(Error::MadviseFail(io::Error::last_os_error()));
}

View file

@ -19,7 +19,7 @@ use vhost::{VhostBackend, VringConfigData};
use virtio_queue::desc::RawDescriptor;
use virtio_queue::{Queue, QueueT};
use vm_device::dma_mapping::ExternalDmaMapping;
use vm_memory::{GuestAddress, GuestAddressSpace, GuestMemory, GuestMemoryAtomic};
use vm_memory::{GuestAddress, GuestAddressSpace, GuestMemoryAtomic};
use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
use vm_virtio::{AccessPlatform, Translatable};
use vmm_sys_util::eventfd::EventFd;
@ -27,7 +27,7 @@ use vmm_sys_util::eventfd::EventFd;
use crate::{
ActivateError, ActivateResult, DEVICE_ACKNOWLEDGE, DEVICE_DRIVER, DEVICE_DRIVER_OK,
DEVICE_FEATURES_OK, GuestMemoryMmap, VIRTIO_F_IOMMU_PLATFORM, VirtioCommon, VirtioDevice,
VirtioInterrupt, VirtioInterruptType,
VirtioInterrupt, VirtioInterruptType, get_host_address_range,
};
#[derive(Error, Debug)]
@ -548,11 +548,10 @@ impl<M: GuestAddressSpace> VdpaDmaMapping<M> {
impl<M: GuestAddressSpace + Sync + Send> ExternalDmaMapping for VdpaDmaMapping<M> {
fn map(&self, iova: u64, gpa: u64, size: u64) -> result::Result<(), io::Error> {
let usize_size = size.try_into().unwrap();
let mem = self.memory.memory();
let guest_addr = GuestAddress(gpa);
let user_addr = if mem.check_range(guest_addr, size as usize) {
mem.get_host_address(guest_addr).unwrap() as *const u8
} else {
let Some(user_addr) = get_host_address_range(&*mem, guest_addr, usize_size) else {
return Err(io::Error::other(format!(
"failed to convert guest address 0x{gpa:x} into \
host user virtual address"
@ -563,7 +562,7 @@ impl<M: GuestAddressSpace + Sync + Send> ExternalDmaMapping for VdpaDmaMapping<M
"DMA map iova 0x{:x}, gpa 0x{:x}, size 0x{:x}, host_addr 0x{:x}",
iova, gpa, size, user_addr as u64
);
// SAFETY: check_range() and get_host_address() guarantee that
// SAFETY: get_host_address_range() guarantees that
// user_addr points to `size` bytes of memory.
unsafe {
self.device