From 6d7b606d027542df5228bc38ea93d1047c327c0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=AD=C3=B0=20Steinn=20Geirsson?= Date: Sun, 22 Mar 2026 11:19:57 +0000 Subject: [PATCH] feat: add cloud-hypervisor support to balloond and dbus-proxy - BalloonBackend trait abstracting over hypervisor-specific balloon control - CrosvmBackend wrapping existing crosvm control socket protocol - CloudHypervisorBackend using raw HTTP/1.1 over persistent Unix socket (GET /api/v1/vm.balloon-statistics, PUT /api/v1/vm.resize) - Watcher recognizes both crosvm-control.socket and cloud-hypervisor-control.socket for auto-discovery - dbus-proxy: CONNECT protocol support for cloud-hypervisor vsock, generic stream handling, --cid/--vsock-socket CLI args - NixOS module: enable dbus-proxy for all VMs, vary args by hypervisor --- modules/services.nix | 11 +- vmsilo-balloond/src/backend.rs | 27 +++ vmsilo-balloond/src/cloud_hypervisor.rs | 296 ++++++++++++++++++++++++ vmsilo-balloond/src/crosvm.rs | 41 ++-- vmsilo-balloond/src/lib.rs | 2 + vmsilo-balloond/src/main.rs | 55 +++-- vmsilo-balloond/src/watcher.rs | 34 ++- vmsilo-dbus-proxy/src/host/main.rs | 79 +++++-- vmsilo-dbus-proxy/src/host/mod.rs | 1 + vmsilo-dbus-proxy/src/host/vsock.rs | 112 +++++++++ 10 files changed, 589 insertions(+), 69 deletions(-) create mode 100644 vmsilo-balloond/src/backend.rs create mode 100644 vmsilo-balloond/src/cloud_hypervisor.rs create mode 100644 vmsilo-dbus-proxy/src/host/vsock.rs diff --git a/modules/services.nix b/modules/services.nix index 7a34e6c..1039a2e 100644 --- a/modules/services.nix +++ b/modules/services.nix @@ -670,7 +670,7 @@ in ) (lib.attrValues cfg.nixosVms) ++ [ (lib.nameValuePair "vmsilo-balloond" mkVmsiloBalloondService) ] ++ - # D-Bus proxy services (crosvm VMs only — kernel vsock required) + # D-Bus proxy services lib.concatMap ( vm: let @@ -679,7 +679,7 @@ in lib.optional vm.dbus.tray "tray" ++ lib.optional vm.dbus.notifications "notifications" ); in - lib.optional (dbusEnabled && vm.hypervisor == "crosvm") ( + lib.optional dbusEnabled ( lib.nameValuePair "vmsilo-${vm.name}-dbus-proxy" { description = "D-Bus proxy for VM ${vm.name}"; wantedBy = [ "vmsilo-${vm.name}-vm.service" ]; @@ -702,7 +702,12 @@ in esac exec ${dbusProxyHost} \ --vm-name ${lib.escapeShellArg vm.name} \ - --cid ${toString vm.id} \ + ${ + if vm.hypervisor == "crosvm" then + "--cid ${toString vm.id}" + else + "--vsock-socket /run/vmsilo/${vm.name}/vsock.socket" + } \ --icon-theme "$icon_theme" \ --color ${lib.escapeShellArg vmColor} \ --protocols ${lib.escapeShellArg protocols} \ diff --git a/vmsilo-balloond/src/backend.rs b/vmsilo-balloond/src/backend.rs new file mode 100644 index 0000000..65804df --- /dev/null +++ b/vmsilo-balloond/src/backend.rs @@ -0,0 +1,27 @@ +use anyhow::Result; +use serde::Deserialize; + +/// Balloon statistics reported by the guest via virtio-balloon. +#[derive(Deserialize, Debug, Clone)] +pub struct BalloonStats { + pub free_memory: Option, + pub available_memory: Option, + pub disk_caches: Option, + pub total_memory: Option, +} + +/// Stats result from a balloon query. +#[derive(Debug)] +pub struct StatsResult { + pub stats: BalloonStats, + pub balloon_actual: u64, +} + +/// Abstraction over hypervisor-specific balloon control. +pub trait BalloonBackend: Send + Sync { + /// Query balloon stats from the VM. Blocking — call from spawn_blocking. + fn balloon_stats(&self) -> Result; + + /// Set absolute balloon size in bytes. Blocking — call from spawn_blocking. + fn balloon_adjust(&self, num_bytes: u64) -> Result<()>; +} diff --git a/vmsilo-balloond/src/cloud_hypervisor.rs b/vmsilo-balloond/src/cloud_hypervisor.rs new file mode 100644 index 0000000..7336273 --- /dev/null +++ b/vmsilo-balloond/src/cloud_hypervisor.rs @@ -0,0 +1,296 @@ +use std::io::{Read, Write}; +use std::os::unix::net::UnixStream; +use std::path::PathBuf; +use std::sync::Mutex; +use std::time::Duration; + +use anyhow::{anyhow, Context, Result}; +use serde::Deserialize; + +use crate::backend::{BalloonBackend, BalloonStats, StatsResult}; + +/// Cloud-hypervisor balloon statistics JSON response. +#[derive(Deserialize, Debug)] +struct ChBalloonStatistics { + free_memory: Option, + available_memory: Option, + disk_caches: Option, + total_memory: Option, + balloon_actual_bytes: u64, +} + +/// Read timeout for the Unix socket connection. +const READ_TIMEOUT: Duration = Duration::from_secs(5); + +/// Find the position of `\r\n\r\n` in a byte slice. +fn find_header_end(data: &[u8]) -> Option { + data.windows(4) + .position(|w| w == b"\r\n\r\n") +} + +/// Parse an HTTP status line like "HTTP/1.1 200 OK\r\n" and return the status code. +fn parse_status_line(line: &str) -> Result { + let parts: Vec<&str> = line.splitn(3, ' ').collect(); + if parts.len() < 2 { + return Err(anyhow!("malformed status line: {}", line)); + } + parts[1] + .parse::() + .with_context(|| format!("invalid status code in: {}", line)) +} + +/// Read HTTP headers from a byte slice (everything before `\r\n\r\n`). +/// Returns the Content-Length value if present. +fn read_headers(header_bytes: &[u8]) -> Result> { + let header_str = std::str::from_utf8(header_bytes) + .context("headers are not valid UTF-8")?; + let mut content_length = None; + for line in header_str.split("\r\n") { + let lower = line.to_ascii_lowercase(); + if lower.starts_with("content-length:") { + let val = line["content-length:".len()..].trim(); + content_length = Some( + val.parse::() + .with_context(|| format!("invalid Content-Length: {}", val))?, + ); + } + } + Ok(content_length) +} + +/// Parse a complete HTTP response into a `StatsResult`. +fn parse_stats_response(data: &[u8]) -> Result { + let header_end = find_header_end(data) + .ok_or_else(|| anyhow!("no header terminator found"))?; + + let header_bytes = &data[..header_end]; + let body_start = header_end + 4; + + // Parse status line + let header_str = std::str::from_utf8(header_bytes) + .context("headers are not valid UTF-8")?; + let status_line = header_str + .split("\r\n") + .next() + .ok_or_else(|| anyhow!("empty headers"))?; + let status = parse_status_line(status_line)?; + + if status < 200 || status >= 300 { + return Err(anyhow!("HTTP {}", status)); + } + + let body = &data[body_start..]; + let stats: ChBalloonStatistics = + serde_json::from_slice(body).context("parsing balloon statistics JSON")?; + + Ok(StatsResult { + stats: BalloonStats { + free_memory: stats.free_memory, + available_memory: stats.available_memory, + disk_caches: stats.disk_caches, + total_memory: stats.total_memory, + }, + balloon_actual: stats.balloon_actual_bytes, + }) +} + +/// Format a GET request for balloon statistics. +fn format_stats_request() -> &'static str { + "GET /api/v1/vm.balloon-statistics HTTP/1.1\r\nHost: localhost\r\n\r\n" +} + +/// Format a PUT request to resize the balloon. +fn format_resize_request(desired_balloon: u64) -> String { + let body = format!(r#"{{"desired_balloon":{}}}"#, desired_balloon); + format!( + "PUT /api/v1/vm.resize HTTP/1.1\r\nHost: localhost\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}", + body.len(), + body, + ) +} + +/// Parse a resize HTTP response — succeeds if status is 200 or 204. +fn parse_resize_response(data: &[u8]) -> Result<()> { + // Find at least the status line + let header_str = std::str::from_utf8(data) + .context("response is not valid UTF-8")?; + let status_line = header_str + .split("\r\n") + .next() + .ok_or_else(|| anyhow!("empty response"))?; + let status = parse_status_line(status_line)?; + + if status == 200 || status == 204 { + Ok(()) + } else { + Err(anyhow!("resize failed: HTTP {}", status)) + } +} + +/// Balloon backend for cloud-hypervisor — talks to the CH HTTP API over a Unix socket. +pub struct CloudHypervisorBackend { + api_socket_path: PathBuf, + conn: Mutex>, +} + +impl CloudHypervisorBackend { + pub fn new(api_socket_path: PathBuf) -> Self { + Self { + api_socket_path, + conn: Mutex::new(None), + } + } + + /// Get or create a persistent connection, with a read timeout. + fn get_conn(&self) -> Result>> { + let mut guard = self.conn.lock().map_err(|e| anyhow!("mutex poisoned: {}", e))?; + if guard.is_none() { + let stream = UnixStream::connect(&self.api_socket_path) + .with_context(|| format!("connecting to {:?}", self.api_socket_path))?; + stream + .set_read_timeout(Some(READ_TIMEOUT)) + .context("set_read_timeout")?; + *guard = Some(stream); + } + Ok(guard) + } + + /// Send a request and read the full HTTP response. Uses Content-Length to + /// determine when the response is complete on a persistent connection. + fn try_request(&self, request: &str) -> Result> { + let mut guard = self.get_conn()?; + let stream = guard.as_mut().ok_or_else(|| anyhow!("no connection"))?; + + stream + .write_all(request.as_bytes()) + .context("writing request")?; + + let mut buf = Vec::with_capacity(4096); + let mut tmp = [0u8; 4096]; + + // Read until we have the full headers + loop { + let n = stream.read(&mut tmp).context("reading response")?; + if n == 0 { + return Err(anyhow!("connection closed")); + } + buf.extend_from_slice(&tmp[..n]); + + if find_header_end(&buf).is_some() { + break; + } + } + + let header_end = find_header_end(&buf).unwrap(); + let header_bytes = &buf[..header_end]; + let content_length = read_headers(header_bytes)?; + let body_start = header_end + 4; + + if let Some(cl) = content_length { + let total_needed = body_start + cl; + while buf.len() < total_needed { + let n = stream.read(&mut tmp).context("reading response body")?; + if n == 0 { + return Err(anyhow!("connection closed during body read")); + } + buf.extend_from_slice(&tmp[..n]); + } + // Trim to exact response size (in case of pipelining) + buf.truncate(total_needed); + } + // If no Content-Length (e.g. 204 No Content), what we have is enough. + + Ok(buf) + } + + /// Send a request, retrying once on failure by reconnecting. + fn request(&self, request: &str) -> Result> { + match self.try_request(request) { + Ok(data) => Ok(data), + Err(first_err) => { + tracing::debug!(error = %first_err, "request failed, reconnecting"); + // Drop the old connection + if let Ok(mut guard) = self.conn.lock() { + *guard = None; + } + // Retry once + self.try_request(request) + .with_context(|| format!("retry after: {}", first_err)) + } + } + } +} + +impl BalloonBackend for CloudHypervisorBackend { + fn balloon_stats(&self) -> Result { + let request = format_stats_request(); + let response = self.request(request)?; + parse_stats_response(&response) + } + + fn balloon_adjust(&self, num_bytes: u64) -> Result<()> { + let request = format_resize_request(num_bytes); + let response = self.request(&request)?; + parse_resize_response(&response) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_stats_response_full() { + let body = r#"{"free_memory":1073741824,"available_memory":2147483648,"disk_caches":536870912,"total_memory":4294967296,"balloon_actual_bytes":524288000,"balloon_target_bytes":524288000}"#; + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}", + body.len(), + body + ); + let result = parse_stats_response(response.as_bytes()).unwrap(); + assert_eq!(result.stats.free_memory, Some(1073741824)); + assert_eq!(result.stats.available_memory, Some(2147483648)); + assert_eq!(result.stats.disk_caches, Some(536870912)); + assert_eq!(result.stats.total_memory, Some(4294967296)); + assert_eq!(result.balloon_actual, 524288000); + } + + #[test] + fn parse_stats_response_nullable_fields() { + let body = r#"{"balloon_actual_bytes":0}"#; + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}", + body.len(), + body + ); + let result = parse_stats_response(response.as_bytes()).unwrap(); + assert_eq!(result.stats.free_memory, None); + assert_eq!(result.balloon_actual, 0); + } + + #[test] + fn parse_stats_response_404() { + let response = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n"; + assert!(parse_stats_response(&response[..]).is_err()); + } + + #[test] + fn format_resize_request_test() { + let req = format_resize_request(1048576); + assert!(req.starts_with("PUT /api/v1/vm.resize HTTP/1.1\r\n")); + assert!(req.contains("Content-Type: application/json\r\n")); + assert!(req.contains(r#""desired_balloon":1048576"#)); + } + + #[test] + fn parse_resize_response_204() { + let response = b"HTTP/1.1 204 No Content\r\n\r\n"; + assert!(parse_resize_response(&response[..]).is_ok()); + } + + #[test] + fn parse_resize_response_404() { + let response = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n"; + assert!(parse_resize_response(&response[..]).is_err()); + } +} diff --git a/vmsilo-balloond/src/crosvm.rs b/vmsilo-balloond/src/crosvm.rs index 8d51940..11a0bf7 100644 --- a/vmsilo-balloond/src/crosvm.rs +++ b/vmsilo-balloond/src/crosvm.rs @@ -1,5 +1,5 @@ use std::os::fd::{AsRawFd, OwnedFd}; -use std::path::Path; +use std::path::{Path, PathBuf}; use std::time::Duration; use anyhow::{anyhow, Context, Result}; @@ -8,6 +8,8 @@ use nix::sys::socket::{ }; use serde::{Deserialize, Serialize}; +use crate::backend::{BalloonBackend, BalloonStats, StatsResult}; + /// Request sent to crosvm control socket. #[derive(Serialize, Debug)] pub enum VmRequest { @@ -35,25 +37,9 @@ pub enum VmResponse { }, } -/// Balloon statistics reported by the guest via virtio-balloon. -#[derive(Deserialize, Debug, Clone)] -pub struct BalloonStats { - pub free_memory: Option, - pub available_memory: Option, - pub disk_caches: Option, - pub total_memory: Option, -} - /// Socket receive timeout. const RECV_TIMEOUT: Duration = Duration::from_secs(1); -/// Stats result from crosvm balloon query. -#[derive(Debug)] -pub struct StatsResult { - pub stats: BalloonStats, - pub balloon_actual: u64, -} - /// Query balloon stats from a VM's crosvm control socket. /// Blocking call -- use with tokio::task::spawn_blocking. pub fn balloon_stats(socket_path: &Path) -> Result { @@ -132,6 +118,27 @@ fn recv_response(fd: &OwnedFd) -> Result { Ok(response) } +/// Balloon backend for crosvm — talks to crosvm's control socket. +pub struct CrosvmBackend { + socket_path: PathBuf, +} + +impl CrosvmBackend { + pub fn new(socket_path: PathBuf) -> Self { + Self { socket_path } + } +} + +impl BalloonBackend for CrosvmBackend { + fn balloon_stats(&self) -> Result { + balloon_stats(&self.socket_path) + } + + fn balloon_adjust(&self, num_bytes: u64) -> Result<()> { + balloon_adjust(&self.socket_path, num_bytes) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/vmsilo-balloond/src/lib.rs b/vmsilo-balloond/src/lib.rs index 99d4436..cc289c3 100644 --- a/vmsilo-balloond/src/lib.rs +++ b/vmsilo-balloond/src/lib.rs @@ -1,4 +1,6 @@ pub mod args; +pub mod backend; +pub mod cloud_hypervisor; pub mod crosvm; pub mod host; pub mod policy; diff --git a/vmsilo-balloond/src/main.rs b/vmsilo-balloond/src/main.rs index 9e49fac..c4ecaa6 100644 --- a/vmsilo-balloond/src/main.rs +++ b/vmsilo-balloond/src/main.rs @@ -1,12 +1,14 @@ use std::collections::BTreeMap; -use std::path::PathBuf; +use std::sync::Arc; use std::time::Instant; use anyhow::Result; use clap::Parser; use tokio::signal; use tokio::signal::unix::{signal as unix_signal, SignalKind}; -use vmsilo_balloond::crosvm; +use vmsilo_balloond::backend::BalloonBackend; +use vmsilo_balloond::cloud_hypervisor::CloudHypervisorBackend; +use vmsilo_balloond::crosvm::CrosvmBackend; use vmsilo_balloond::host::read_host_memory; use vmsilo_balloond::policy::{BalloonPolicy, GuestStats}; use vmsilo_balloond::psi; @@ -16,7 +18,7 @@ use vmsilo_balloond::{init_logging, Args}; struct VmState { name: String, - socket_path: PathBuf, + backend: Arc, policy: BalloonPolicy, stall: StallDetector, consecutive_failures: u32, @@ -96,10 +98,25 @@ async fn async_main(args: Args) -> Result<(), Box> { event = rx.recv() => { match event { Some(WatchEvent::VmAdded { name, socket_path }) => { - tracing::info!(vm = %name, "VM discovered"); + tracing::info!(vm = %name, path = ?socket_path, "VM discovered"); + let backend: Arc = match socket_path + .file_name() + .and_then(|f| f.to_str()) + { + Some("crosvm-control.socket") => { + Arc::new(CrosvmBackend::new(socket_path)) + } + Some("cloud-hypervisor-control.socket") => { + Arc::new(CloudHypervisorBackend::new(socket_path)) + } + _ => { + tracing::warn!(vm = %name, path = ?socket_path, "unknown socket type"); + continue; + } + }; vms.insert(name.clone(), VmState { name, - socket_path, + backend, policy: BalloonPolicy::new(critical_host_available, args.guest_available_bias, args.critical_guest_available), stall: StallDetector::new(), consecutive_failures: 0, @@ -165,12 +182,11 @@ async fn poll_all_vms( } async fn poll_single_vm(vm: &mut VmState, host_available: u64) -> Result { - let socket_path = vm.socket_path.clone(); - let result = tokio::task::spawn_blocking(move || crosvm::balloon_stats(&socket_path)).await??; + let backend = vm.backend.clone(); + let result = tokio::task::spawn_blocking(move || backend.balloon_stats()).await??; let now = Instant::now(); - // Check for stall from previous inflate if vm.stall.is_stalled(result.balloon_actual, now) { let new_size = result.balloon_actual.saturating_sub(STALL_DEFLATE_BYTES); tracing::warn!( @@ -180,27 +196,19 @@ async fn poll_single_vm(vm: &mut VmState, host_available: u64) -> Result { cooldown_secs = STALL_COOLDOWN.as_secs(), "balloon stall detected, deflating" ); - let path = vm.socket_path.clone(); - tokio::task::spawn_blocking(move || crosvm::balloon_adjust(&path, new_size)).await??; + let backend = vm.backend.clone(); + tokio::task::spawn_blocking(move || backend.balloon_adjust(new_size)).await??; vm.stall.clear(); vm.stall.enter_cooldown(now); return Ok(-(STALL_DEFLATE_BYTES as i64)); } - // Required fields must be present - let free = result - .stats - .free_memory + let free = result.stats.free_memory .ok_or_else(|| anyhow::anyhow!("free_memory not available"))?; - let cached = result - .stats - .disk_caches + let cached = result.stats.disk_caches .ok_or_else(|| anyhow::anyhow!("disk_caches not available"))?; - let total = result - .stats - .total_memory + let total = result.stats.total_memory .ok_or_else(|| anyhow::anyhow!("total_memory not available"))?; - let available = result.stats.available_memory; let guest_stats = GuestStats { @@ -213,7 +221,6 @@ async fn poll_single_vm(vm: &mut VmState, host_available: u64) -> Result { let delta = vm.policy.compute_balloon_delta(&guest_stats, host_available); - // Suppress inflation during stall cooldown (deflation still allowed) if delta > 0 && vm.stall.is_in_cooldown(now) { return Ok(0); } @@ -228,8 +235,8 @@ async fn poll_single_vm(vm: &mut VmState, host_available: u64) -> Result { guest_available_mb = available.unwrap_or(0) / (1024 * 1024), "adjusting balloon" ); - let path = vm.socket_path.clone(); - tokio::task::spawn_blocking(move || crosvm::balloon_adjust(&path, new_size)).await??; + let backend = vm.backend.clone(); + tokio::task::spawn_blocking(move || backend.balloon_adjust(new_size)).await??; if delta > 0 { vm.stall.record_inflate(result.balloon_actual, now); diff --git a/vmsilo-balloond/src/watcher.rs b/vmsilo-balloond/src/watcher.rs index 549aa6a..d2dbf40 100644 --- a/vmsilo-balloond/src/watcher.rs +++ b/vmsilo-balloond/src/watcher.rs @@ -4,7 +4,9 @@ use anyhow::{Context, Result}; use notify::{EventKind, RecommendedWatcher, Watcher as NotifyWatcher}; use tokio::sync::mpsc; -const SOCKET_FILENAME: &str = "crosvm-control.socket"; +const CROSVM_SOCKET: &str = "crosvm-control.socket"; +const CLOUD_HYPERVISOR_SOCKET: &str = "cloud-hypervisor-control.socket"; +const SOCKET_FILENAMES: &[&str] = &[CROSVM_SOCKET, CLOUD_HYPERVISOR_SOCKET]; #[derive(Debug, Clone)] pub enum WatchEvent { @@ -30,13 +32,15 @@ impl Watcher { if let Ok(entries) = std::fs::read_dir(watch_dir) { for entry in entries.flatten() { if entry.file_type().map(|t| t.is_dir()).unwrap_or(false) { - let socket_path = entry.path().join(SOCKET_FILENAME); - if socket_path.exists() { - if let Some(name) = socket_vm_name(&socket_path) { - let _ = tx.send(WatchEvent::VmAdded { - name, - socket_path, - }); + for &socket_name in SOCKET_FILENAMES { + let socket_path = entry.path().join(socket_name); + if socket_path.exists() { + if let Some(name) = socket_vm_name(&socket_path) { + let _ = tx.send(WatchEvent::VmAdded { + name, + socket_path, + }); + } } } } @@ -73,7 +77,7 @@ impl Watcher { /// Returns None for paths that don't match this structure. fn socket_vm_name(path: &Path) -> Option { let filename = path.file_name()?.to_str()?; - if filename != SOCKET_FILENAME { + if !SOCKET_FILENAMES.contains(&filename) { return None; } // VM name is the name of the immediate parent directory @@ -123,4 +127,16 @@ mod tests { let path = Path::new("/run/vmsilo/banking/other.socket"); assert_eq!(socket_vm_name(path), None); } + + #[test] + fn cloud_hypervisor_socket_recognized() { + let path = Path::new("/run/vmsilo/banking/cloud-hypervisor-control.socket"); + assert_eq!(socket_vm_name(path), Some("banking".to_string())); + } + + #[test] + fn unknown_socket_not_recognized() { + let path = Path::new("/run/vmsilo/banking/unknown.socket"); + assert_eq!(socket_vm_name(path), None); + } } diff --git a/vmsilo-dbus-proxy/src/host/main.rs b/vmsilo-dbus-proxy/src/host/main.rs index 7bd4690..9d8ac87 100644 --- a/vmsilo-dbus-proxy/src/host/main.rs +++ b/vmsilo-dbus-proxy/src/host/main.rs @@ -11,6 +11,7 @@ use vmsilo_dbus_proxy::args::{init_logging, LogLevel}; use vmsilo_dbus_proxy::host::item::{SyntheticSniItem, TrayItemState}; use vmsilo_dbus_proxy::host::menu::SyntheticMenu; use vmsilo_dbus_proxy::host::notifications::HostNotifications; +use vmsilo_dbus_proxy::host::vsock::VsockTarget; use vmsilo_dbus_proxy::protocol::*; use vmsilo_dbus_proxy::sanitize; @@ -22,9 +23,14 @@ struct Args { #[arg(long)] vm_name: String, - /// VM CID for vsock connection - #[arg(long)] - cid: u32, + /// VM CID for kernel vsock connection (crosvm). Mutually exclusive with --vsock-socket. + #[arg(long, conflicts_with = "vsock_socket")] + cid: Option, + + /// Path to Unix vsock socket for CONNECT protocol (cloud-hypervisor). + /// Mutually exclusive with --cid. + #[arg(long, conflicts_with = "cid")] + vsock_socket: Option, /// Icon theme name to send to guest (e.g. "breeze", "breeze-dark") #[arg(long, default_value = "breeze")] @@ -42,6 +48,18 @@ struct Args { log_level: LogLevel, } +impl Args { + fn vsock_target(&self) -> anyhow::Result { + match (&self.cid, &self.vsock_socket) { + (Some(cid), None) => Ok(VsockTarget::Kernel { cid: *cid }), + (None, Some(path)) => Ok(VsockTarget::Socket { path: path.clone() }), + _ => Err(anyhow::anyhow!( + "exactly one of --cid or --vsock-socket must be provided" + )), + } + } +} + struct TrayItem { state: Arc>, /// Connection that owns this item's D-Bus names and serves the interfaces. @@ -53,9 +71,10 @@ async fn main() -> Result<()> { let args = Args::parse(); init_logging(args.log_level); + let target = args.vsock_target()?; info!( vm = args.vm_name.as_str(), - cid = args.cid, + target = ?target, protocols = args.protocols.as_str(), "Starting vmsilo-dbus-proxy-host" ); @@ -83,19 +102,42 @@ async fn run_connection(args: &Args) -> Result<()> { let tray_enabled = protocols.contains("tray"); let notifications_enabled = protocols.contains("notifications"); - let stream = - tokio_vsock::VsockStream::connect(VsockAddr::new(args.cid, TRAY_VSOCK_PORT)).await?; - info!("Connected to guest vsock"); - - let (reader, writer) = tokio::io::split(stream); - let mut reader = tokio::io::BufReader::new(reader); - let writer = Arc::new(tokio::sync::Mutex::new(tokio::io::BufWriter::new(writer))); - let tint_color = args.color.as_ref().map(|c| { vmsilo_dbus_proxy::tint::parse_hex_color(c) .unwrap_or_else(|e| panic!("invalid --color '{}': {}", c, e)) }); + let target = args.vsock_target()?; + match target { + VsockTarget::Kernel { cid } => { + let stream = + tokio_vsock::VsockStream::connect(VsockAddr::new(cid, TRAY_VSOCK_PORT)).await?; + info!("Connected to guest via kernel vsock"); + run_with_stream(stream, args, tray_enabled, notifications_enabled, tint_color).await + } + VsockTarget::Socket { ref path } => { + let stream = + vmsilo_dbus_proxy::host::vsock::connect_hybrid(path, TRAY_VSOCK_PORT).await?; + info!("Connected to guest via CONNECT protocol"); + run_with_stream(stream, args, tray_enabled, notifications_enabled, tint_color).await + } + } +} + +async fn run_with_stream( + stream: S, + args: &Args, + tray_enabled: bool, + notifications_enabled: bool, + tint_color: Option<[u8; 3]>, +) -> Result<()> +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, +{ + let (reader, writer) = tokio::io::split(stream); + let mut reader = tokio::io::BufReader::new(reader); + let writer = Arc::new(tokio::sync::Mutex::new(tokio::io::BufWriter::new(writer))); + // Send Init with icon theme before entering the message loop { let mut w = writer.lock().await; @@ -177,9 +219,11 @@ async fn run_connection(args: &Args) -> Result<()> { } #[allow(clippy::too_many_arguments)] -async fn run_event_loop( - reader: &mut tokio::io::BufReader>, - writer: &Arc>>>, +async fn run_event_loop( + reader: &mut tokio::io::BufReader>, + writer: &Arc< + tokio::sync::Mutex>>, + >, event_tx: &mpsc::UnboundedSender, event_rx: &mut mpsc::UnboundedReceiver, items: &mut HashMap, @@ -189,7 +233,10 @@ async fn run_event_loop( args: &Args, tray_enabled: bool, tint_color: Option<[u8; 3]>, -) -> Result<()> { +) -> Result<()> +where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, +{ use futures_util::StreamExt; loop { diff --git a/vmsilo-dbus-proxy/src/host/mod.rs b/vmsilo-dbus-proxy/src/host/mod.rs index d3e7d3d..031606b 100644 --- a/vmsilo-dbus-proxy/src/host/mod.rs +++ b/vmsilo-dbus-proxy/src/host/mod.rs @@ -1,3 +1,4 @@ pub mod item; pub mod menu; pub mod notifications; +pub mod vsock; diff --git a/vmsilo-dbus-proxy/src/host/vsock.rs b/vmsilo-dbus-proxy/src/host/vsock.rs new file mode 100644 index 0000000..b269f35 --- /dev/null +++ b/vmsilo-dbus-proxy/src/host/vsock.rs @@ -0,0 +1,112 @@ +use std::path::PathBuf; + +use anyhow::{anyhow, Context, Result}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::UnixStream; + +/// Vsock connection target — determines how to reach the guest. +#[derive(Debug, Clone)] +pub enum VsockTarget { + /// Direct kernel vsock (crosvm). + Kernel { cid: u32 }, + /// Unix socket with CONNECT handshake (cloud-hypervisor / firecracker). + Socket { path: PathBuf }, +} + +/// Connect to a guest's vsock port using the CONNECT protocol over a Unix socket. +/// +/// Sends `CONNECT \n`, expects `OK \n`, then returns the +/// stream for bidirectional communication. +pub async fn connect_hybrid(socket_path: &std::path::Path, port: u32) -> Result { + let mut stream = UnixStream::connect(socket_path) + .await + .with_context(|| format!("connecting to {}", socket_path.display()))?; + + // Send CONNECT handshake + let connect_msg = format!("CONNECT {port}\n"); + stream + .write_all(connect_msg.as_bytes()) + .await + .context("sending CONNECT")?; + + // Read response line byte-by-byte to avoid BufReader buffering issues. + // BufReader would borrow the stream and might buffer bytes past the + // response line, losing data when dropped. + let mut buf = Vec::with_capacity(64); + loop { + let mut byte = [0u8; 1]; + stream + .read_exact(&mut byte) + .await + .context("reading CONNECT response")?; + buf.push(byte[0]); + if byte[0] == b'\n' { + break; + } + if buf.len() > 256 { + return Err(anyhow!("CONNECT response too long")); + } + } + + let response = String::from_utf8_lossy(&buf); + let trimmed = response.trim(); + if !trimmed.starts_with("OK ") { + return Err(anyhow!( + "CONNECT handshake failed: expected 'OK ', got '{trimmed}'" + )); + } + + Ok(stream) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn connect_handshake_success() { + let dir = tempfile::tempdir().unwrap(); + let sock_path = dir.path().join("test.sock"); + let listener = tokio::net::UnixListener::bind(&sock_path).unwrap(); + + let client = tokio::spawn({ + let path = sock_path.clone(); + async move { connect_hybrid(&path, 5001).await } + }); + + let (mut server, _) = listener.accept().await.unwrap(); + let mut buf = vec![0u8; 128]; + let n = server.read(&mut buf).await.unwrap(); + let request = std::str::from_utf8(&buf[..n]).unwrap(); + assert_eq!(request, "CONNECT 5001\n"); + + server.write_all(b"OK 1073741824\n").await.unwrap(); + + let stream = client.await.unwrap().unwrap(); + drop(stream); + } + + #[tokio::test] + async fn connect_handshake_failure() { + let dir = tempfile::tempdir().unwrap(); + let sock_path = dir.path().join("test.sock"); + let listener = tokio::net::UnixListener::bind(&sock_path).unwrap(); + + let client = tokio::spawn({ + let path = sock_path.clone(); + async move { connect_hybrid(&path, 5001).await } + }); + + let (mut server, _) = listener.accept().await.unwrap(); + let mut buf = vec![0u8; 128]; + let _n = server.read(&mut buf).await.unwrap(); + server.write_all(b"ERR Connection refused\n").await.unwrap(); + + let result = client.await.unwrap(); + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("CONNECT handshake failed")); + } +}