diff --git a/Cargo.lock b/Cargo.lock index e6b02914c..95f2cf879 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -573,6 +573,7 @@ dependencies = [ "hypervisor", "libc", "log", + "num_enum", "pci", "serde", "thiserror", @@ -1366,6 +1367,7 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56" dependencies = [ + "proc-macro-crate", "proc-macro2", "quote", "syn 2.0.66", diff --git a/Cargo.toml b/Cargo.toml index aba775daf..9bb3bdb67 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,6 +69,7 @@ igvm = ["mshv", "vmm/igvm"] io_uring = ["vmm/io_uring"] kvm = ["vmm/kvm"] mshv = ["vmm/mshv"] +pvmemcontrol = ["vmm/pvmemcontrol"] sev_snp = ["igvm", "mshv", "vmm/sev_snp"] tdx = ["vmm/tdx"] tracing = ["tracer/tracing", "vmm/tracing"] diff --git a/devices/Cargo.toml b/devices/Cargo.toml index 5a7be2615..ee16c9eeb 100644 --- a/devices/Cargo.toml +++ b/devices/Cargo.toml @@ -14,13 +14,18 @@ event_monitor = { path = "../event_monitor" } hypervisor = { path = "../hypervisor" } libc = "0.2.153" log = "0.4.22" +num_enum = "0.7.2" pci = { path = "../pci" } serde = { version = "1.0.197", features = ["derive"] } thiserror = "1.0.62" tpm = { path = "../tpm" } vm-allocator = { path = "../vm-allocator" } vm-device = { path = "../vm-device" } -vm-memory = "0.14.1" +vm-memory = { version = "0.14.1", features = [ + "backend-atomic", + "backend-bitmap", + "backend-mmap", +] } vm-migration = { path = "../vm-migration" } vmm-sys-util = "0.12.1" @@ -29,3 +34,4 @@ arch = { path = "../arch" } [features] default = [] +pvmemcontrol = [] diff --git a/devices/src/lib.rs b/devices/src/lib.rs index 260b1ecae..4a63fbf50 100644 --- a/devices/src/lib.rs +++ b/devices/src/lib.rs @@ -23,6 +23,8 @@ pub mod interrupt_controller; #[cfg(target_arch = "x86_64")] pub mod ioapic; pub mod legacy; +#[cfg(feature = "pvmemcontrol")] +pub mod pvmemcontrol; pub mod pvpanic; pub mod tpm; diff --git 
a/devices/src/pvmemcontrol.rs b/devices/src/pvmemcontrol.rs
new file mode 100644
index 000000000..cc7b37fb2
--- /dev/null
+++ b/devices/src/pvmemcontrol.rs
@@ -0,0 +1,819 @@
+// Copyright © 2024 Google LLC
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+use num_enum::TryFromPrimitive;
+use pci::{
+    BarReprogrammingParams, PciBarConfiguration, PciBarPrefetchable, PciBarRegionType,
+    PciClassCode, PciConfiguration, PciDevice, PciDeviceError, PciHeaderType, PciSubclass,
+};
+use std::{
+    collections::HashMap,
+    ffi::CString,
+    io, result,
+    sync::{Arc, Barrier, Mutex, RwLock},
+};
+use thiserror::Error;
+use vm_allocator::{page_size::get_page_size, AddressAllocator, SystemAllocator};
+use vm_device::{BusDeviceSync, Resource};
+use vm_memory::{
+    bitmap::AtomicBitmap, Address, ByteValued, Bytes, GuestAddress, GuestAddressSpace, GuestMemory,
+    GuestMemoryAtomic, GuestMemoryError, GuestMemoryMmap, Le32, Le64,
+};
+use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable};
+
+// PCI vendor/device IDs written into the device's PCI configuration space.
+const PVMEMCONTROL_VENDOR_ID: u16 = 0x1ae0;
+const PVMEMCONTROL_DEVICE_ID: u16 = 0x0087;
+
+// PCI subsystem vendor/device IDs written into the same configuration space.
+const PVMEMCONTROL_SUBSYSTEM_VENDOR_ID: u16 = 0x1ae0;
+const PVMEMCONTROL_SUBSYSTEM_ID: u16 = 0x011F;
+
+// Interface version; returned to the guest in `arg0`/`arg1` of the `Info`
+// response.
+const MAJOR_VERSION: u64 = 1;
+const MINOR_VERSION: u64 = 0;
+
+/// Errors raised while serving guest accesses to the pvmemcontrol device.
+///
+/// `InvalidArgument`, `UnknownFunctionCode`, `LibcFail` and `GuestMemory` are
+/// turned into an errno-carrying response for the guest by the request
+/// handler; the remaining variants are only logged and no response is
+/// written.
+#[derive(Error, Debug)]
+pub enum Error {
+    // device errors
+    #[error("Guest gave us bad memory addresses: {0}")]
+    GuestMemory(#[source] GuestMemoryError),
+    #[error("Guest sent us invalid request")]
+    InvalidRequest,
+
+    #[error("Guest sent us invalid command: {0}")]
+    InvalidCommand(u32),
+    #[error("Guest sent us invalid connection: {0}")]
+    InvalidConnection(u32),
+
+    // pvmemcontrol errors
+    #[error("Request contains invalid arguments: {0}")]
+    InvalidArgument(u64),
+    #[error("Unknown function code: {0}")]
+    UnknownFunctionCode(u64),
+    #[error("Libc call fail: {0}")]
+    LibcFail(#[source] std::io::Error),
+}
+
+/// PCI subclass value advertised by the device (`0x80`, "other").
+#[derive(Copy, Clone)]
+enum PvmemcontrolSubclass {
+    Other = 0x80,
+}
+
+impl
PciSubclass for PvmemcontrolSubclass {
+    fn get_register_value(&self) -> u8 {
+        *self as u8
+    }
+}
+
+/// Values exchanged through the transport's command register: the commands a
+/// guest may write, plus the `Ack`/`Error` statuses the device stores there.
+///
+/// Every value has its most significant bit clear, which keeps commands
+/// disjoint from guest connection ids (those always have that bit set).
+#[repr(u32)]
+#[derive(PartialEq, Eq, Copy, Clone, TryFromPrimitive)]
+enum PvmemcontrolTransportCommand {
+    Reset = 0x060f_e6d2,
+    Register = 0x0e35_9539,
+    Ready = 0x0ca8_d227,
+    Disconnect = 0x030f_5da0,
+    Ack = 0x03cf_5196,
+    Error = 0x01fb_a249,
+}
+
+/// Payload of a `Register` command: guest-physical address of the buffer the
+/// guest wants to use for requests and responses.
+#[repr(C)]
+#[derive(Copy, Clone)]
+struct PvmemcontrolTransportRegister {
+    buf_phys_addr: Le64,
+}
+
+/// Payload returned for a successful `Register`: the connection id assigned
+/// to the registered buffer, padded explicitly to 8 bytes.
+#[repr(C)]
+#[derive(Copy, Clone)]
+struct PvmemcontrolTransportRegisterResponse {
+    command: Le32,
+    _padding: u32,
+}
+
+/// The 8-byte payload area of the transport. Both views cover all 8 bytes, so
+/// initializing either one initializes the whole payload.
+#[repr(C)]
+#[derive(Copy, Clone)]
+union PvmemcontrolTransportUnion {
+    register: PvmemcontrolTransportRegister,
+    register_response: PvmemcontrolTransportRegisterResponse,
+}
+
+/// Register file exposed to the guest: the payload followed by the 32-bit
+/// command/status register.
+///
+/// The struct is copied to the guest byte-for-byte via `ByteValued`, which
+/// demands that every byte is initialized and that any bit pattern is a valid
+/// value. `command` is therefore stored as a raw little-endian word instead of
+/// as `PvmemcontrolTransportCommand`, and the 4 bytes that `repr(C)` would
+/// otherwise insert as tail padding are an explicit, zeroed field. Size (16)
+/// and alignment (8) are the same as with an implicit tail.
+#[repr(C)]
+#[derive(Copy, Clone)]
+struct PvmemcontrolTransport {
+    payload: PvmemcontrolTransportUnion,
+    command: Le32,
+    _padding: u32,
+}
+
+const PVMEMCONTROL_DEVICE_MMIO_SIZE: u64 = std::mem::size_of::<PvmemcontrolTransport>() as u64;
+const PVMEMCONTROL_DEVICE_MMIO_ALIGN: u64 = std::mem::align_of::<PvmemcontrolTransport>() as u64;
+
+impl PvmemcontrolTransport {
+    /// Builds a transport with an all-zero payload and `status` in the command
+    /// register, so that no uninitialized byte is ever exposed to the guest.
+    fn with_zeroed_payload(status: PvmemcontrolTransportCommand) -> Self {
+        PvmemcontrolTransport {
+            payload: PvmemcontrolTransportUnion {
+                register: PvmemcontrolTransportRegister {
+                    buf_phys_addr: 0.into(),
+                },
+            },
+            command: (status as u32).into(),
+            _padding: 0,
+        }
+    }
+
+    /// Transport contents reporting success without a payload.
+    fn ack() -> Self {
+        Self::with_zeroed_payload(PvmemcontrolTransportCommand::Ack)
+    }
+
+    /// Transport contents reporting failure.
+    fn error() -> Self {
+        Self::with_zeroed_payload(PvmemcontrolTransportCommand::Error)
+    }
+
+    /// Transport contents acknowledging a `Register`, carrying the connection
+    /// id (`command`) the guest must write to the command register to submit
+    /// requests on the registered buffer.
+    fn register_response(command: u32) -> Self {
+        PvmemcontrolTransport {
+            payload: PvmemcontrolTransportUnion {
+                register_response: PvmemcontrolTransportRegisterResponse {
+                    command: command.into(),
+                    _padding: 0,
+                },
+            },
+            command: (PvmemcontrolTransportCommand::Ack as u32).into(),
+            _padding: 0,
+        }
+    }
+
+    /// Reinterprets the payload as a `Register` argument.
+    ///
+    /// # Safety
+    ///
+    /// This reads a union field. The caller asserts that interpreting the
+    /// current payload bytes as a `PvmemcontrolTransportRegister` is what the
+    /// device protocol calls for.
+    unsafe fn as_register(self) -> PvmemcontrolTransportRegister {
+        self.payload.register
+    }
+}
+
+// SAFETY: Contains no references and does not have compiler-inserted padding
+unsafe impl ByteValued for PvmemcontrolTransportUnion {}
+//
SAFETY: Contains no references and does not have compiler-inserted padding
+unsafe impl ByteValued for PvmemcontrolTransport {}
+
+/// Operation selected by the `func_code` field of a guest request.
+#[repr(u64)]
+#[derive(Copy, Clone, TryFromPrimitive, Debug)]
+enum FunctionCode {
+    Info = 0,
+    Dontneed = 1,
+    Remove = 2,
+    Free = 3,
+    Pageout = 4,
+    Dontdump = 5,
+    SetVMAAnonName = 6,
+    Mlock = 7,
+    Munlock = 8,
+    MprotectNone = 9,
+    MprotectR = 10,
+    MprotectW = 11,
+    MprotectRW = 12,
+    Mergeable = 13,
+    Unmergeable = 14,
+}
+
+/// Request record read from the guest buffer; all fields are little-endian.
+#[repr(C)]
+#[derive(Copy, Clone, Debug, Default)]
+struct PvmemcontrolReq {
+    func_code: Le64,
+    addr: Le64,
+    length: Le64,
+    arg: Le64,
+}
+
+// SAFETY: it only has data and has no implicit padding.
+unsafe impl ByteValued for PvmemcontrolReq {}
+
+/// Response record written back to the guest buffer; all fields are
+/// little-endian.
+#[repr(C)]
+#[derive(Copy, Clone, Default)]
+struct PvmemcontrolResp {
+    ret_errno: Le32,
+    ret_code: Le32,
+    ret_value: Le64,
+    arg0: Le64,
+    arg1: Le64,
+}
+
+impl std::fmt::Debug for PvmemcontrolResp {
+    /// Prints only the two status words, converted to native endianness.
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let errno = self.ret_errno.to_native();
+        let code = self.ret_code.to_native();
+        write!(
+            f,
+            "PvmemcontrolResp {{ ret_errno: {}, ret_code: {}, .. }}",
+            errno, code
+        )
+    }
+}
+
+// SAFETY: it only has data and has no implicit padding.
+unsafe impl ByteValued for PvmemcontrolResp {}
+
+/// Identifier the guest writes to the command register to submit the request
+/// stored in one of its registered buffers.
+///
+/// Connection ids start at 0x8000_0000, i.e. they always have the most
+/// significant bit set. Transport commands never do, so the two value spaces
+/// cannot conflict.
+#[derive(Hash, Clone, Copy, PartialEq, Eq, Debug)]
+pub struct GuestConnection {
+    command: u32,
+}
+
+impl Default for GuestConnection {
+    /// The first connection id, 0x8000_0000.
+    fn default() -> Self {
+        GuestConnection::new(0x8000_0000)
+    }
+}
+
+impl GuestConnection {
+    fn new(command: u32) -> Self {
+        Self { command }
+    }
+
+    /// Returns the following connection id, wrapping from `u32::MAX` back to
+    /// the first id so the most significant bit stays set.
+    fn next(&self) -> Self {
+        let GuestConnection { command } = *self;
+
+        if command == u32::MAX {
+            GuestConnection::default()
+        } else {
+            GuestConnection::new(command + 1)
+        }
+    }
+}
+
+impl TryFrom<u32> for GuestConnection {
+    type Error = Error;
+
+    /// Accepts only values with the most significant bit set; anything else
+    /// is rejected as `Error::InvalidConnection`.
+    fn try_from(value: u32) -> Result<Self, Self::Error> {
+        if (value & 0x8000_0000) != 0 {
+            Ok(GuestConnection::new(value))
+        } else {
+            Err(Error::InvalidConnection(value))
+        }
+    }
+}
+
+/// Bookkeeping while the guest is still registering buffers.
+struct PercpuInitState {
+    /// Registered buffers, keyed by the connection id handed to the guest.
+    port_buf_map: HashMap<GuestConnection, GuestAddress>,
+    /// First candidate id to try for the next registration.
+    next_conn: GuestConnection,
+}
+
+impl PercpuInitState {
+    fn new() -> Self {
+        PercpuInitState {
+            port_buf_map: HashMap::new(),
+            next_conn: GuestConnection::default(),
+        }
+    }
+}
+
+/// Protocol state of the device.
+enum PvmemcontrolState {
+    /// Buffers are being registered.
+    PercpuInit(PercpuInitState),
+    /// Registration finished; requests are served for the mapped connections.
+    Ready(HashMap<GuestConnection, GuestAddress>),
+    /// Protocol violation or disconnect; only `Reset` leaves this state.
+    Broken,
+}
+
+/// Guest-visible transport registers plus the protocol state behind them.
+pub struct PvmemcontrolDevice {
+    transport: PvmemcontrolTransport,
+    state: PvmemcontrolState,
+}
+
+impl PvmemcontrolDevice {
+    fn new(transport: PvmemcontrolTransport, state: PvmemcontrolState) -> Self {
+        PvmemcontrolDevice { transport, state }
+    }
+}
+
+impl PvmemcontrolDevice {
+    /// Handles `Register`: validates the guest buffer, assigns it an unused
+    /// connection id and reports that id back through the transport. Any
+    /// failure leaves the device `Broken`.
+    fn register_percpu_buf(
+        guest_memory: &GuestMemoryAtomic<GuestMemoryMmap<AtomicBitmap>>,
+        mut state: PercpuInitState,
+        PvmemcontrolTransportRegister { buf_phys_addr }: PvmemcontrolTransportRegister,
+    ) -> Self {
+        // access to this address is checked: the buffer must be able to hold
+        // the larger of a request and a response
+        let buf_phys_addr = GuestAddress(buf_phys_addr.into());
+        if !guest_memory.memory().check_range(
+            buf_phys_addr,
+            std::mem::size_of::<PvmemcontrolResp>().max(std::mem::size_of::<PvmemcontrolReq>()),
+        ) {
+            warn!("guest sent invalid phys addr {:#x}",
buf_phys_addr.0);
+            return PvmemcontrolDevice::new(
+                PvmemcontrolTransport::error(),
+                PvmemcontrolState::Broken,
+            );
+        }
+
+        let conn = {
+            // find an unused connection id, starting at `state.next_conn`;
+            // fail once the search has wrapped around the whole id space
+            let mut next_conn = state.next_conn;
+            while state.port_buf_map.contains_key(&next_conn) {
+                next_conn = next_conn.next();
+                if next_conn == state.next_conn {
+                    warn!("connections exhausted");
+                    return PvmemcontrolDevice::new(
+                        PvmemcontrolTransport::error(),
+                        PvmemcontrolState::Broken,
+                    );
+                }
+            }
+            next_conn
+        };
+        state.next_conn = conn.next();
+        state.port_buf_map.insert(conn, buf_phys_addr);
+
+        // inform guest of the connection
+        let response = PvmemcontrolTransport::register_response(conn.command);
+
+        PvmemcontrolDevice::new(response, PvmemcontrolState::PercpuInit(state))
+    }
+
+    /// Device after `Reset`: acknowledged, with an empty registration table.
+    fn reset() -> Self {
+        PvmemcontrolDevice::new(
+            PvmemcontrolTransport::ack(),
+            PvmemcontrolState::PercpuInit(PercpuInitState::new()),
+        )
+    }
+
+    /// Device in the `Broken` state, reporting `Error` to the guest.
+    fn error() -> Self {
+        PvmemcontrolDevice::new(PvmemcontrolTransport::error(), PvmemcontrolState::Broken)
+    }
+
+    /// Device after `Ready`: acknowledged, serving the registered buffers.
+    fn ready(PercpuInitState { port_buf_map, .. }: PercpuInitState) -> Self {
+        PvmemcontrolDevice::new(
+            PvmemcontrolTransport::ack(),
+            PvmemcontrolState::Ready(port_buf_map),
+        )
+    }
+
+    /// Executes a transport command and replaces `self` with the resulting
+    /// device (new state plus the response visible in the transport).
+    fn run_command(
+        &mut self,
+        guest_memory: &GuestMemoryAtomic<GuestMemoryMmap<AtomicBitmap>>,
+        command: PvmemcontrolTransportCommand,
+    ) {
+        // Take the state out; `Broken` is only a placeholder because `*self`
+        // is overwritten below on every path.
+        let state = std::mem::replace(&mut self.state, PvmemcontrolState::Broken);
+
+        *self = match command {
+            PvmemcontrolTransportCommand::Reset => Self::reset(),
+            PvmemcontrolTransportCommand::Register => {
+                if let PvmemcontrolState::PercpuInit(state) = state {
+                    // SAFETY: By device protocol. If driver is wrong the device
+                    // can enter a Broken state, but the behavior is still sound.
+                    Self::register_percpu_buf(guest_memory, state, unsafe {
+                        self.transport.as_register()
+                    })
+                } else {
+                    debug!("received register without reset");
+                    Self::error()
+                }
+            }
+            PvmemcontrolTransportCommand::Ready => {
+                if let PvmemcontrolState::PercpuInit(state) = state {
+                    Self::ready(state)
+                } else {
+                    debug!("received ready without reset");
+                    Self::error()
+                }
+            }
+            PvmemcontrolTransportCommand::Disconnect => Self::error(),
+            PvmemcontrolTransportCommand::Ack => {
+                debug!("received ack as command");
+                Self::error()
+            }
+            PvmemcontrolTransportCommand::Error => {
+                debug!("received error as command");
+                Self::error()
+            }
+        }
+    }
+
+    /// read from the transport
+    ///
+    /// Copies transport bytes starting at `offset` into `data`. Bytes of
+    /// `data` that would fall past the end of the transport are left
+    /// untouched.
+    fn read_transport(&self, offset: u64, data: &mut [u8]) {
+        self.transport
+            .as_slice()
+            .iter()
+            .skip(offset as usize)
+            .zip(data.iter_mut())
+            .for_each(|(src, dest)| *dest = *src)
+    }
+
+    /// can only write to transport payload
+    /// command is a special register that needs separate dispatching
+    ///
+    /// Bytes of `data` that would fall past the end of the payload are
+    /// dropped.
+    fn write_transport(&mut self, offset: u64, data: &[u8]) {
+        self.transport
+            .payload
+            .as_mut_slice()
+            .iter_mut()
+            .skip(offset as usize)
+            .zip(data.iter())
+            .for_each(|(dest, src)| *dest = *src)
+    }
+
+    /// Returns the buffer registered for `conn`, or `None` when the device is
+    /// not `Ready` or the id is unknown.
+    fn find_connection(&self, conn: GuestConnection) -> Option<GuestAddress> {
+        match &self.state {
+            PvmemcontrolState::Ready(map) => map.get(&conn).copied(),
+            _ => None,
+        }
+    }
+}
+
+/// MMIO side of the device: the transport behind a lock, plus the guest
+/// memory the requests operate on.
+pub struct PvmemcontrolBusDevice {
+    mem: GuestMemoryAtomic<GuestMemoryMmap<AtomicBitmap>>,
+    dev: RwLock<PvmemcontrolDevice>,
+}
+
+/// PCI side of the device: configuration space and allocated BARs.
+pub struct PvmemcontrolPciDevice {
+    id: String,
+    configuration: PciConfiguration,
+    bar_regions: Vec<PciBarConfiguration>,
+}
+
+impl PvmemcontrolBusDevice {
+    /// f is called with the host address of `range_base` and only when
+    /// [`range_base`, `range_base` + `range_len`) is present in the guest
+    ///
+    /// A non-zero return value from `f` is reported as `Error::LibcFail`
+    /// carrying the current OS error.
+    fn operate_on_memory_range<F>(&self, addr: u64, length: u64, f: F) -> result::Result<(), Error>
+    where
+        F: FnOnce(*mut libc::c_void, libc::size_t) -> libc::c_int,
+    {
+        let memory = self.mem.memory();
+        let range_base = GuestAddress(addr);
+        let range_len =
usize::try_from(length).map_err(|_| Error::InvalidRequest)?;
+
+        // assume guest memory is not interleaved with vmm memory on the host.
+        // NOTE(review): only the host address of `range_base` is resolved, so
+        // this also presumes the whole range is contiguous in host virtual
+        // memory.
+        if !memory.check_range(range_base, range_len) {
+            return Err(Error::GuestMemory(GuestMemoryError::InvalidGuestAddress(
+                range_base,
+            )));
+        }
+        let hva = memory
+            .get_host_address(range_base)
+            .map_err(Error::GuestMemory)?;
+        let res = f(hva as *mut libc::c_void, range_len as libc::size_t);
+        if res != 0 {
+            return Err(Error::LibcFail(io::Error::last_os_error()));
+        }
+        Ok(())
+    }
+
+    /// Applies `madvise(advice)` to the guest range.
+    fn madvise(&self, addr: u64, length: u64, advice: libc::c_int) -> result::Result<(), Error> {
+        self.operate_on_memory_range(addr, length, |base, len| {
+            // SAFETY: [`base`, `base` + `len`) is guest memory
+            unsafe { libc::madvise(base, len, advice) }
+        })
+    }
+
+    /// Locks the guest range with `mlock2`. With `on_fault` set, pages are
+    /// locked lazily as they are faulted in (`MLOCK_ONFAULT`).
+    fn mlock(&self, addr: u64, length: u64, on_fault: bool) -> result::Result<(), Error> {
+        self.operate_on_memory_range(addr, length, |base, len| {
+            // SAFETY: [`base`, `base` + `len`) is guest memory
+            unsafe { libc::mlock2(base, len, if on_fault { libc::MLOCK_ONFAULT } else { 0 }) }
+        })
+    }
+
+    /// Unlocks the guest range with `munlock`.
+    fn munlock(&self, addr: u64, length: u64) -> result::Result<(), Error> {
+        self.operate_on_memory_range(addr, length, |base, len| {
+            // SAFETY: [`base`, `base` + `len`) is guest memory
+            unsafe { libc::munlock(base, len) }
+        })
+    }
+
+    /// Applies `mprotect(protection)` to the guest range.
+    fn mprotect(
+        &self,
+        addr: u64,
+        length: u64,
+        protection: libc::c_int,
+    ) -> result::Result<(), Error> {
+        self.operate_on_memory_range(addr, length, |base, len| {
+            // SAFETY: [`base`, `base` + `len`) is guest memory
+            unsafe { libc::mprotect(base, len, protection) }
+        })
+    }
+
+    /// Names the VMA backing the guest range `pvmemcontrol-<name>`; a `name`
+    /// of 0 passes a null pointer to the kernel instead.
+    fn set_vma_anon_name(&self, addr: u64, length: u64, name: u64) -> result::Result<(), Error> {
+        let name = (name != 0).then(|| {
+            CString::new(format!("pvmemcontrol-{}", name))
+                .expect("a formatted integer never contains a NUL byte")
+        });
+        // `name` outlives the syscall below, so the pointer stays valid.
+        let name_ptr = name.as_ref().map_or(std::ptr::null(), |name| name.as_ptr());
+        debug!("addr {:X} length {} name {:?}", addr, length, name);
+
+        // SAFETY: [`base`, `base` + `len`) is
guest memory
+        self.operate_on_memory_range(addr, length, |base, len| unsafe {
+            libc::prctl(
+                libc::PR_SET_VMA,
+                libc::PR_SET_VMA_ANON_NAME,
+                base,
+                len,
+                name_ptr,
+            )
+        })
+    }
+
+    /// Executes one decoded request.
+    ///
+    /// `Info` returns the host page size in `ret_value` and the interface
+    /// version in `arg0`/`arg1`. Every other function code runs the matching
+    /// memory syscall on [`addr`, `addr` + `length`) and yields an all-zero
+    /// response on success.
+    fn process_request(
+        &self,
+        func_code: FunctionCode,
+        addr: u64,
+        length: u64,
+        arg: u64,
+    ) -> Result<PvmemcontrolResp, Error> {
+        let result = match func_code {
+            FunctionCode::Info => {
+                return Ok(PvmemcontrolResp {
+                    ret_errno: 0.into(),
+                    ret_code: 0.into(),
+                    ret_value: get_page_size().into(),
+                    arg0: MAJOR_VERSION.into(),
+                    arg1: MINOR_VERSION.into(),
+                })
+            }
+            FunctionCode::Dontneed => self.madvise(addr, length, libc::MADV_DONTNEED),
+            FunctionCode::Remove => self.madvise(addr, length, libc::MADV_REMOVE),
+            FunctionCode::Free => self.madvise(addr, length, libc::MADV_FREE),
+            FunctionCode::Pageout => self.madvise(addr, length, libc::MADV_PAGEOUT),
+            FunctionCode::Dontdump => self.madvise(addr, length, libc::MADV_DONTDUMP),
+            FunctionCode::SetVMAAnonName => self.set_vma_anon_name(addr, length, arg),
+            FunctionCode::Mlock => self.mlock(addr, length, false),
+            FunctionCode::Munlock => self.munlock(addr, length),
+            FunctionCode::MprotectNone => self.mprotect(addr, length, libc::PROT_NONE),
+            FunctionCode::MprotectR => self.mprotect(addr, length, libc::PROT_READ),
+            FunctionCode::MprotectW => self.mprotect(addr, length, libc::PROT_WRITE),
+            FunctionCode::MprotectRW => {
+                self.mprotect(addr, length, libc::PROT_READ | libc::PROT_WRITE)
+            }
+            FunctionCode::Mergeable => self.madvise(addr, length, libc::MADV_MERGEABLE),
+            FunctionCode::Unmergeable => self.madvise(addr, length, libc::MADV_UNMERGEABLE),
+        };
+        result.map(|_| PvmemcontrolResp::default())
+    }
+
+    /// Decodes a raw request, executes it and converts request-level
+    /// failures into an errno-carrying response. Returns `Err` only for
+    /// device-level errors, for which no response is written.
+    fn handle_request(
+        &self,
+        PvmemcontrolReq {
+            func_code,
+            addr,
+            length,
+            arg,
+        }: PvmemcontrolReq,
+    ) -> Result<PvmemcontrolResp, Error> {
+        let (func_code, addr, length, arg) = (
+            func_code.to_native(),
+            addr.to_native(),
+            length.to_native(),
+            arg.to_native(),
+        );
+
+        let resp_or_err = FunctionCode::try_from(func_code)
+            .map_err(|_|
Error::UnknownFunctionCode(func_code))
+            .and_then(|func_code| self.process_request(func_code, addr, length, arg));
+
+        // Request-level failures become a response carrying an errno;
+        // anything else is a device error and is handed to the caller.
+        Ok(match resp_or_err {
+            Ok(resp) => resp,
+            Err(Error::InvalidArgument(arg)) => PvmemcontrolResp {
+                ret_errno: (libc::EINVAL as u32).into(),
+                ret_code: (arg as u32).into(),
+                ..Default::default()
+            },
+            Err(Error::LibcFail(err)) => PvmemcontrolResp {
+                ret_errno: (err.raw_os_error().unwrap_or(libc::EFAULT) as u32).into(),
+                ret_code: 0u32.into(),
+                ..Default::default()
+            },
+            Err(Error::UnknownFunctionCode(func_code)) => PvmemcontrolResp {
+                ret_errno: (libc::EOPNOTSUPP as u32).into(),
+                ret_code: (func_code as u32).into(),
+                ..Default::default()
+            },
+            Err(Error::GuestMemory(err)) => {
+                warn!("{}", err);
+                PvmemcontrolResp {
+                    ret_errno: (libc::EINVAL as u32).into(),
+                    ret_code: (func_code as u32).into(),
+                    ..Default::default()
+                }
+            }
+            // device error, stop responding
+            Err(other) => return Err(other),
+        })
+    }
+
+    /// Serves one request: reads it from the guest buffer at `guest_addr`,
+    /// executes it and writes the response over the same buffer. Failures
+    /// are logged; in that case nothing is written back.
+    fn handle_pvmemcontrol_request(&self, guest_addr: GuestAddress) {
+        let request: PvmemcontrolReq = match self.mem.memory().read_obj(guest_addr) {
+            Ok(req) => req,
+            Err(_) => {
+                warn!("cannot read from guest address {:#x}", guest_addr.0);
+                return;
+            }
+        };
+
+        let response: PvmemcontrolResp = match self.handle_request(request) {
+            Ok(resp) => resp,
+            Err(e) => {
+                warn!("cannot process request {:?} with error {}", request, e);
+                return;
+            }
+        };
+
+        let written = self.mem.memory().write_obj(response, guest_addr);
+        if written.is_err() {
+            warn!("cannot write to guest address {:#x}", guest_addr.0);
+        }
+    }
+
+    /// Dispatches a guest MMIO write. Writes anywhere but the command
+    /// register update the payload; a write to the command register must be
+    /// exactly 4 bytes and is decoded below as a command or connection id.
+    fn handle_guest_write(&self, offset: u64, data: &[u8]) {
+        let command_offset = std::mem::offset_of!(PvmemcontrolTransport, command);
+        if offset as usize != command_offset {
+            if !matches!(data.len(), 4 | 8) {
+                warn!("guest write is not 4 or 8 bytes long");
+                return;
+            }
+            self.dev.write().unwrap().write_transport(offset, data);
+            return;
+        }
+        let data = if data.len() == 4 {
+            let mut d = [0u8; 4];
+            // lengths match: `data` was just checked to be 4 bytes long
+            d.copy_from_slice(data);
+            d
+        } else {
+            warn!("guest write with non u32 at command register");
+            return;
+        };
+        let data_cmd = u32::from_le_bytes(data);
+        let command = PvmemcontrolTransportCommand::try_from(data_cmd);
+
+        match command {
+            Ok(command) => self.dev.write().unwrap().run_command(&self.mem, command),
+            Err(_) => {
+                // Not a transport command: treat the value as a connection id
+                // and serve the request stored in its registered buffer.
+                GuestConnection::try_from(data_cmd)
+                    .and_then(|conn| {
+                        self.dev
+                            .read()
+                            .unwrap()
+                            .find_connection(conn)
+                            .ok_or(Error::InvalidConnection(conn.command))
+                    })
+                    .map(|gpa| self.handle_pvmemcontrol_request(gpa))
+                    .unwrap_or_else(|err| warn!("{:?}", err));
+            }
+        }
+    }
+
+    /// Serves a guest MMIO read from the transport registers.
+    fn handle_guest_read(&self, offset: u64, data: &mut [u8]) {
+        self.dev.read().unwrap().read_transport(offset, data)
+    }
+}
+
+impl PvmemcontrolDevice {
+    /// Creates the PCI and MMIO halves of a pvmemcontrol device.
+    ///
+    /// The device starts out `Broken`, so the guest has to send `Reset`
+    /// before it can register buffers.
+    pub fn make_device(
+        id: String,
+        mem: GuestMemoryAtomic<GuestMemoryMmap<AtomicBitmap>>,
+    ) -> (PvmemcontrolPciDevice, PvmemcontrolBusDevice) {
+        let dev = RwLock::new(PvmemcontrolDevice::error());
+        let mut configuration = PciConfiguration::new(
+            PVMEMCONTROL_VENDOR_ID,
+            PVMEMCONTROL_DEVICE_ID,
+            0x1,
+            PciClassCode::BaseSystemPeripheral,
+            &PvmemcontrolSubclass::Other,
+            None,
+            PciHeaderType::Device,
+            PVMEMCONTROL_SUBSYSTEM_VENDOR_ID,
+            PVMEMCONTROL_SUBSYSTEM_ID,
+            None,
+            None,
+        );
+        let command: [u8; 2] = [0x03, 0x01]; // memory, io, SERR#
+
+        configuration.write_config_register(1, 0, &command);
+        (
+            PvmemcontrolPciDevice {
+                id,
+                configuration,
+                bar_regions: Vec::new(),
+            },
+            PvmemcontrolBusDevice { mem, dev },
+        )
+    }
+}
+
+impl PciDevice for PvmemcontrolPciDevice {
+    fn write_config_register(
+        &mut self,
+        reg_idx: usize,
+        offset: u64,
+        data: &[u8],
+    ) -> Option<Arc<Barrier>> {
+        self.configuration
+            .write_config_register(reg_idx, offset, data);
+        None
+    }
+
+    fn read_config_register(&mut self, reg_idx: usize) -> u32 {
+        self.configuration.read_config_register(reg_idx)
+    }
+
+    fn as_any(&mut self) -> &mut dyn std::any::Any {
+        self
+    }
+
+    fn id(&self) -> Option<String> {
+        Some(self.id.clone())
+    }
+
+    fn detect_bar_reprogramming(
+        &mut self,
+        reg_idx: usize,
+
data: &[u8],
+    ) -> Option<BarReprogrammingParams> {
+        self.configuration.detect_bar_reprogramming(reg_idx, data)
+    }
+
+    /// Allocates the single 32-bit, non-prefetchable memory BAR (index 0)
+    /// that covers the transport registers.
+    ///
+    /// When `resources` is `Some` (restore), the BAR is not added to the
+    /// configuration space again.
+    /// NOTE(review): the address is always freshly allocated; the contents
+    /// of `resources` are not consulted — confirm this is intended on
+    /// restore.
+    fn allocate_bars(
+        &mut self,
+        _allocator: &Arc<Mutex<SystemAllocator>>,
+        mmio32_allocator: &mut AddressAllocator,
+        _mmio64_allocator: &mut AddressAllocator,
+        resources: Option<Vec<Resource>>,
+    ) -> Result<Vec<PciBarConfiguration>, PciDeviceError> {
+        let mut bars = Vec::new();
+        let region_type = PciBarRegionType::Memory32BitRegion;
+        let bar_id = 0;
+        let region_size = PVMEMCONTROL_DEVICE_MMIO_SIZE;
+        let restoring = resources.is_some();
+        let bar_addr = mmio32_allocator
+            .allocate(None, region_size, Some(PVMEMCONTROL_DEVICE_MMIO_ALIGN))
+            .ok_or(PciDeviceError::IoAllocationFailed(region_size))?;
+
+        let bar = PciBarConfiguration::default()
+            .set_index(bar_id as usize)
+            .set_address(bar_addr.raw_value())
+            .set_size(region_size)
+            .set_region_type(region_type)
+            .set_prefetchable(PciBarPrefetchable::NotPrefetchable);
+
+        if !restoring {
+            self.configuration
+                .add_pci_bar(&bar)
+                .map_err(|e| PciDeviceError::IoRegistrationFailed(bar_addr.raw_value(), e))?;
+        }
+
+        bars.push(bar);
+        self.bar_regions.clone_from(&bars);
+        Ok(bars)
+    }
+
+    /// Returns every allocated BAR range to the 32-bit MMIO allocator.
+    fn free_bars(
+        &mut self,
+        _allocator: &mut SystemAllocator,
+        mmio32_allocator: &mut AddressAllocator,
+        _mmio64_allocator: &mut AddressAllocator,
+    ) -> Result<(), PciDeviceError> {
+        for bar in self.bar_regions.drain(..)
{
+            mmio32_allocator.free(GuestAddress(bar.addr()), bar.size())
+        }
+        Ok(())
+    }
+
+    /// Records the new base address of the BAR that used to live at
+    /// `old_base`.
+    fn move_bar(&mut self, old_base: u64, new_base: u64) -> result::Result<(), io::Error> {
+        for bar in self.bar_regions.iter_mut() {
+            if bar.addr() == old_base {
+                *bar = bar.set_address(new_base);
+            }
+        }
+        Ok(())
+    }
+}
+
+// Pausing and resuming are no-ops for the PCI half of the device.
+impl Pausable for PvmemcontrolPciDevice {
+    fn pause(&mut self) -> std::result::Result<(), MigratableError> {
+        Ok(())
+    }
+
+    fn resume(&mut self) -> std::result::Result<(), MigratableError> {
+        Ok(())
+    }
+}
+
+impl Snapshottable for PvmemcontrolPciDevice {
+    fn id(&self) -> String {
+        self.id.clone()
+    }
+
+    /// Snapshots the PCI configuration only; the device state itself is the
+    /// unit value.
+    fn snapshot(&mut self) -> std::result::Result<Snapshot, MigratableError> {
+        let mut snapshot = Snapshot::new_from_state(&())?;
+
+        // Snapshot PciConfiguration
+        snapshot.add_snapshot(self.configuration.id(), self.configuration.snapshot()?);
+
+        Ok(snapshot)
+    }
+}
+
+impl Transportable for PvmemcontrolPciDevice {}
+impl Migratable for PvmemcontrolPciDevice {}
+
+// MMIO accesses are forwarded to the transport handlers; `_base` is unused
+// and a write never returns a barrier.
+impl BusDeviceSync for PvmemcontrolBusDevice {
+    fn read(&self, _base: u64, offset: u64, data: &mut [u8]) {
+        self.handle_guest_read(offset, data)
+    }
+
+    fn write(&self, _base: u64, offset: u64, data: &[u8]) -> Option<Arc<Barrier>> {
+        self.handle_guest_write(offset, data);
+        None
+    }
+}
diff --git a/fuzz/Cargo.lock b/fuzz/Cargo.lock
index ba28f5718..9cf731ca1 100644
--- a/fuzz/Cargo.lock
+++ b/fuzz/Cargo.lock
@@ -283,6 +283,7 @@ dependencies = [
  "hypervisor",
  "libc",
  "log",
+ "num_enum",
  "pci",
  "serde",
  "thiserror",
@@ -324,6 +325,12 @@ dependencies = [
  "libc",
 ]
 
+[[package]]
+name = "equivalent"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
+
 [[package]]
 name = "event_monitor"
 version = "0.1.0"
@@ -408,6 +415,12 @@ dependencies = [
  "wasm-bindgen",
 ]
 
+[[package]]
+name = "hashbrown"
+version = "0.14.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum =
"e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hypervisor" version = "0.1.0" @@ -432,6 +445,16 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" +[[package]] +name = "indexmap" +version = "2.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +dependencies = [ + "equivalent", + "hashbrown", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -544,6 +567,12 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ca88d725a0a943b096803bd34e73a4437208b6077654cc4ecb2947a5f91618d" +[[package]] +name = "memchr" +version = "2.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" + [[package]] name = "micro_http" version = "0.1.0" @@ -597,6 +626,27 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_enum" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02339744ee7253741199f897151b38e72257d13802d4ee837285cc2990a90845" +dependencies = [ + "num_enum_derive", +] + +[[package]] +name = "num_enum_derive" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "681030a937600a36906c185595136d26abfebb4aa9c65701cefcaf8578bb982b" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "once_cell" version = "1.19.0" @@ -634,6 +684,15 @@ dependencies = [ "vmm-sys-util", ] +[[package]] +name = "proc-macro-crate" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d37c51ca738a55da99dc0c4a34860fd675453b8b36209178c2249bb13651284" +dependencies = [ + "toml_edit", +] + [[package]] name = "proc-macro2" 
version = "1.0.86" @@ -824,6 +883,23 @@ dependencies = [ "syn", ] +[[package]] +name = "toml_datetime" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4badfd56924ae69bcc9039335b2e017639ce3f9b001c393c1b2d1ef846ce2cbf" + +[[package]] +name = "toml_edit" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8534fd7f78b5405e860340ad6575217ce99f38d4d5c8f2442cb5ecb50090e1" +dependencies = [ + "indexmap", + "toml_datetime", + "winnow", +] + [[package]] name = "tpm" version = "0.1.0" @@ -1249,6 +1325,15 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "winnow" +version = "0.5.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f593a95398737aeed53e489c785df13f3618e41dbcd6718c6addbf1395aa6876" +dependencies = [ + "memchr", +] + [[package]] name = "zerocopy" version = "0.7.35" diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 80d50b06d..418f07cf9 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -10,6 +10,7 @@ cargo-fuzz = true [features] igvm = [] +pvmemcontrol = [] [dependencies] block = { path = "../block" } diff --git a/fuzz/fuzz_targets/http_api.rs b/fuzz/fuzz_targets/http_api.rs index 727fbf967..720facba1 100644 --- a/fuzz/fuzz_targets/http_api.rs +++ b/fuzz/fuzz_targets/http_api.rs @@ -180,6 +180,8 @@ impl RequestHandler for StubApiRequestHandler { vdpa: None, vsock: None, pvpanic: false, + #[cfg(feature = "pvmemcontrol")] + pvmemcontrol: None, iommu: false, #[cfg(target_arch = "x86_64")] sgx_epc: None, diff --git a/src/main.rs b/src/main.rs index 36e0da4c5..9568fbe87 100644 --- a/src/main.rs +++ b/src/main.rs @@ -522,6 +522,16 @@ fn create_app(default_vcpus: String, default_memory: String, default_rng: String .num_args(1) .group("vm-config"), ); + #[cfg(feature = "pvmemcontrol")] + let app = app.arg( 
+ Arg::new("pvmemcontrol") + .long("pvmemcontrol") + .help("Pvmemcontrol device") + .num_args(0) + .action(ArgAction::SetTrue) + .group("vm-config"), + ); + app.arg( Arg::new("version") .short('V') @@ -1054,6 +1064,8 @@ mod unit_tests { vdpa: None, vsock: None, pvpanic: false, + #[cfg(feature = "pvmemcontrol")] + pvmemcontrol: None, iommu: false, #[cfg(target_arch = "x86_64")] sgx_epc: None, diff --git a/vmm/Cargo.toml b/vmm/Cargo.toml index 54f2fa7a7..f6f426eaf 100644 --- a/vmm/Cargo.toml +++ b/vmm/Cargo.toml @@ -19,6 +19,7 @@ kvm = [ "vm-device/kvm", ] mshv = ["hypervisor/mshv", "pci/mshv", "vfio-ioctls/mshv", "vm-device/mshv"] +pvmemcontrol = ["devices/pvmemcontrol"] sev_snp = ["arch/sev_snp", "hypervisor/sev_snp", "virtio-devices/sev_snp"] tdx = ["arch/tdx", "hypervisor/tdx"] tracing = ["tracer/tracing"] diff --git a/vmm/src/config.rs b/vmm/src/config.rs index c788cb83a..204020c3f 100644 --- a/vmm/src/config.rs +++ b/vmm/src/config.rs @@ -482,6 +482,8 @@ pub struct VmParams<'a> { pub user_devices: Option>, pub vdpa: Option>, pub vsock: Option<&'a str>, + #[cfg(feature = "pvmemcontrol")] + pub pvmemcontrol: bool, pub pvpanic: bool, #[cfg(target_arch = "x86_64")] pub sgx_epc: Option>, @@ -543,6 +545,8 @@ impl<'a> VmParams<'a> { .get_many::("vdpa") .map(|x| x.map(|y| y as &str).collect()); let vsock: Option<&str> = args.get_one::("vsock").map(|x| x as &str); + #[cfg(feature = "pvmemcontrol")] + let pvmemcontrol = args.get_flag("pvmemcontrol"); let pvpanic = args.get_flag("pvpanic"); #[cfg(target_arch = "x86_64")] let sgx_epc: Option> = args @@ -591,6 +595,8 @@ impl<'a> VmParams<'a> { user_devices, vdpa, vsock, + #[cfg(feature = "pvmemcontrol")] + pvmemcontrol, pvpanic, #[cfg(target_arch = "x86_64")] sgx_epc, @@ -2772,6 +2778,11 @@ impl VmConfig { balloon = Some(BalloonConfig::parse(balloon_params)?); } + #[cfg(feature = "pvmemcontrol")] + let pvmemcontrol: Option = vm_params + .pvmemcontrol + .then_some(PvmemcontrolConfig::default()); + let mut fs: Option> = None; 
if let Some(fs_list) = &vm_params.fs { let mut fs_config_list = Vec::new(); @@ -2930,6 +2941,8 @@ impl VmConfig { user_devices, vdpa, vsock, + #[cfg(feature = "pvmemcontrol")] + pvmemcontrol, pvpanic: vm_params.pvpanic, iommu: false, // updated in VmConfig::validate() #[cfg(target_arch = "x86_64")] @@ -3049,6 +3062,8 @@ impl Clone for VmConfig { net: self.net.clone(), rng: self.rng.clone(), balloon: self.balloon.clone(), + #[cfg(feature = "pvmemcontrol")] + pvmemcontrol: self.pvmemcontrol.clone(), fs: self.fs.clone(), pmem: self.pmem.clone(), serial: self.serial.clone(), @@ -3838,6 +3853,8 @@ mod tests { user_devices: None, vdpa: None, vsock: None, + #[cfg(feature = "pvmemcontrol")] + pvmemcontrol: None, pvpanic: false, iommu: false, #[cfg(target_arch = "x86_64")] @@ -4047,6 +4064,8 @@ mod tests { user_devices: None, vdpa: None, vsock: None, + #[cfg(feature = "pvmemcontrol")] + pvmemcontrol: None, pvpanic: false, iommu: false, #[cfg(target_arch = "x86_64")] diff --git a/vmm/src/device_manager.rs b/vmm/src/device_manager.rs index 3b9b24e39..63061bbce 100644 --- a/vmm/src/device_manager.rs +++ b/vmm/src/device_manager.rs @@ -49,6 +49,8 @@ use devices::gic; use devices::ioapic; #[cfg(target_arch = "aarch64")] use devices::legacy::Pl011; +#[cfg(feature = "pvmemcontrol")] +use devices::pvmemcontrol::{PvmemcontrolBusDevice, PvmemcontrolPciDevice}; use devices::{ interrupt_controller, interrupt_controller::InterruptController, AcpiNotificationFlags, }; @@ -118,6 +120,8 @@ const DEBUGCON_DEVICE_NAME: &str = "__debug_console"; const GPIO_DEVICE_NAME: &str = "__gpio"; const RNG_DEVICE_NAME: &str = "__rng"; const IOMMU_DEVICE_NAME: &str = "__iommu"; +#[cfg(feature = "pvmemcontrol")] +const PVMEMCONTROL_DEVICE_NAME: &str = "__pvmemcontrol"; const BALLOON_DEVICE_NAME: &str = "__balloon"; const CONSOLE_DEVICE_NAME: &str = "__console"; const PVPANIC_DEVICE_NAME: &str = "__pvpanic"; @@ -195,6 +199,10 @@ pub enum DeviceManagerError { /// Cannot create virtio-balloon device 
CreateVirtioBalloon(io::Error), + /// Cannot create pvmemcontrol device + #[cfg(feature = "pvmemcontrol")] + CreatePvmemcontrol(io::Error), + /// Cannot create virtio-watchdog device CreateVirtioWatchdog(io::Error), @@ -886,6 +894,12 @@ pub struct DeviceManager { // GPIO device for AArch64 gpio_device: Option>>, + #[cfg(feature = "pvmemcontrol")] + pvmemcontrol_devices: Option<( + Arc, + Arc>, + )>, + // pvpanic device pvpanic_device: Option>>, @@ -1165,6 +1179,8 @@ impl DeviceManager { virtio_mem_devices: Vec::new(), #[cfg(target_arch = "aarch64")] gpio_device: None, + #[cfg(feature = "pvmemcontrol")] + pvmemcontrol_devices: None, pvpanic_device: None, force_iommu, io_uring_supported: None, @@ -1278,6 +1294,17 @@ impl DeviceManager { self.virtio_devices = virtio_devices; + // Add pvmemcontrol if required + #[cfg(feature = "pvmemcontrol")] + { + if self.config.lock().unwrap().pvmemcontrol.is_some() { + let (pvmemcontrol_bus_device, pvmemcontrol_pci_device) = + self.make_pvmemcontrol_device()?; + self.pvmemcontrol_devices = + Some((pvmemcontrol_bus_device, pvmemcontrol_pci_device)); + } + } + if self.config.clone().lock().unwrap().pvpanic { self.pvpanic_device = self.add_pvpanic_device()?; } @@ -3048,6 +3075,48 @@ impl DeviceManager { Ok(devices) } + #[cfg(feature = "pvmemcontrol")] + fn make_pvmemcontrol_device( + &mut self, + ) -> DeviceManagerResult<( + Arc, + Arc>, + )> { + let id = String::from(PVMEMCONTROL_DEVICE_NAME); + let pci_segment_id = 0x0_u16; + + let (pci_segment_id, pci_device_bdf, resources) = + self.pci_resources(&id, pci_segment_id)?; + + info!("Creating pvmemcontrol device: id = {}", id); + let (pvmemcontrol_pci_device, pvmemcontrol_bus_device) = + devices::pvmemcontrol::PvmemcontrolDevice::make_device( + id.clone(), + self.memory_manager.lock().unwrap().guest_memory(), + ); + + let pvmemcontrol_pci_device = Arc::new(Mutex::new(pvmemcontrol_pci_device)); + let pvmemcontrol_bus_device = Arc::new(pvmemcontrol_bus_device); + + let new_resources = 
self.add_pci_device( + pvmemcontrol_bus_device.clone(), + pvmemcontrol_pci_device.clone(), + pci_segment_id, + pci_device_bdf, + resources, + )?; + + let mut node = device_node!(id, pvmemcontrol_pci_device); + + node.resources = new_resources; + node.pci_bdf = Some(pci_device_bdf); + node.pci_device_handle = None; + + self.device_tree.lock().unwrap().insert(id, node); + + Ok((pvmemcontrol_bus_device, pvmemcontrol_pci_device)) + } + fn make_virtio_balloon_devices(&mut self) -> DeviceManagerResult> { let mut devices = Vec::new(); diff --git a/vmm/src/lib.rs b/vmm/src/lib.rs index 4f6cf40e5..973ed2056 100644 --- a/vmm/src/lib.rs +++ b/vmm/src/lib.rs @@ -2230,6 +2230,8 @@ mod unit_tests { user_devices: None, vdpa: None, vsock: None, + #[cfg(feature = "pvmemcontrol")] + pvmemcontrol: None, pvpanic: false, iommu: false, #[cfg(target_arch = "x86_64")] diff --git a/vmm/src/vm_config.rs b/vmm/src/vm_config.rs index 3a746f7fa..e5409496d 100644 --- a/vmm/src/vm_config.rs +++ b/vmm/src/vm_config.rs @@ -424,6 +424,10 @@ pub struct BalloonConfig { pub free_page_reporting: bool, } +#[cfg(feature = "pvmemcontrol")] +#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize, Default)] +pub struct PvmemcontrolConfig {} + #[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)] pub struct FsConfig { pub tag: String, @@ -775,6 +779,9 @@ pub struct VmConfig { pub user_devices: Option>, pub vdpa: Option>, pub vsock: Option, + #[cfg(feature = "pvmemcontrol")] + #[serde(default)] + pub pvmemcontrol: Option, #[serde(default)] pub pvpanic: bool, #[serde(default)]