feat: add cloud-hypervisor support to balloond and dbus-proxy

- BalloonBackend trait abstracting over hypervisor-specific balloon control
- CrosvmBackend wrapping existing crosvm control socket protocol
- CloudHypervisorBackend using raw HTTP/1.1 over persistent Unix socket
  (GET /api/v1/vm.balloon-statistics, PUT /api/v1/vm.resize)
- Watcher recognizes both crosvm-control.socket and
  cloud-hypervisor-control.socket for auto-discovery
- dbus-proxy: CONNECT protocol support for cloud-hypervisor vsock,
  generic stream handling, --cid/--vsock-socket CLI args
- NixOS module: enable dbus-proxy for all VMs, vary args by hypervisor
This commit is contained in:
Davíð Steinn Geirsson 2026-03-22 11:19:57 +00:00
parent 4f97ccb28c
commit 6d7b606d02
10 changed files with 589 additions and 69 deletions

View file

@ -670,7 +670,7 @@ in
) (lib.attrValues cfg.nixosVms)
++ [ (lib.nameValuePair "vmsilo-balloond" mkVmsiloBalloondService) ]
++
# D-Bus proxy services (crosvm VMs only — kernel vsock required)
# D-Bus proxy services
lib.concatMap (
vm:
let
@ -679,7 +679,7 @@ in
lib.optional vm.dbus.tray "tray" ++ lib.optional vm.dbus.notifications "notifications"
);
in
lib.optional (dbusEnabled && vm.hypervisor == "crosvm") (
lib.optional dbusEnabled (
lib.nameValuePair "vmsilo-${vm.name}-dbus-proxy" {
description = "D-Bus proxy for VM ${vm.name}";
wantedBy = [ "vmsilo-${vm.name}-vm.service" ];
@ -702,7 +702,12 @@ in
esac
exec ${dbusProxyHost} \
--vm-name ${lib.escapeShellArg vm.name} \
--cid ${toString vm.id} \
${
if vm.hypervisor == "crosvm" then
"--cid ${toString vm.id}"
else
"--vsock-socket /run/vmsilo/${vm.name}/vsock.socket"
} \
--icon-theme "$icon_theme" \
--color ${lib.escapeShellArg vmColor} \
--protocols ${lib.escapeShellArg protocols} \

View file

@ -0,0 +1,27 @@
use anyhow::Result;
use serde::Deserialize;
/// Balloon statistics reported by the guest via virtio-balloon.
#[derive(Deserialize, Debug, Clone)]
pub struct BalloonStats {
pub free_memory: Option<u64>,
pub available_memory: Option<u64>,
pub disk_caches: Option<u64>,
pub total_memory: Option<u64>,
}
/// Stats result from a balloon query.
#[derive(Debug)]
pub struct StatsResult {
pub stats: BalloonStats,
pub balloon_actual: u64,
}
/// Abstraction over hypervisor-specific balloon control.
pub trait BalloonBackend: Send + Sync {
/// Query balloon stats from the VM. Blocking — call from spawn_blocking.
fn balloon_stats(&self) -> Result<StatsResult>;
/// Set absolute balloon size in bytes. Blocking — call from spawn_blocking.
fn balloon_adjust(&self, num_bytes: u64) -> Result<()>;
}

View file

@ -0,0 +1,296 @@
use std::io::{Read, Write};
use std::os::unix::net::UnixStream;
use std::path::PathBuf;
use std::sync::Mutex;
use std::time::Duration;
use anyhow::{anyhow, Context, Result};
use serde::Deserialize;
use crate::backend::{BalloonBackend, BalloonStats, StatsResult};
/// Cloud-hypervisor balloon statistics JSON response.
#[derive(Deserialize, Debug)]
struct ChBalloonStatistics {
free_memory: Option<u64>,
available_memory: Option<u64>,
disk_caches: Option<u64>,
total_memory: Option<u64>,
balloon_actual_bytes: u64,
}
/// Read timeout for the Unix socket connection.
const READ_TIMEOUT: Duration = Duration::from_secs(5);
/// Find the position of `\r\n\r\n` in a byte slice.
fn find_header_end(data: &[u8]) -> Option<usize> {
data.windows(4)
.position(|w| w == b"\r\n\r\n")
}
/// Parse an HTTP status line like "HTTP/1.1 200 OK\r\n" and return the status code.
fn parse_status_line(line: &str) -> Result<u16> {
let parts: Vec<&str> = line.splitn(3, ' ').collect();
if parts.len() < 2 {
return Err(anyhow!("malformed status line: {}", line));
}
parts[1]
.parse::<u16>()
.with_context(|| format!("invalid status code in: {}", line))
}
/// Read HTTP headers from a byte slice (everything before `\r\n\r\n`).
/// Returns the Content-Length value if present.
fn read_headers(header_bytes: &[u8]) -> Result<Option<usize>> {
let header_str = std::str::from_utf8(header_bytes)
.context("headers are not valid UTF-8")?;
let mut content_length = None;
for line in header_str.split("\r\n") {
let lower = line.to_ascii_lowercase();
if lower.starts_with("content-length:") {
let val = line["content-length:".len()..].trim();
content_length = Some(
val.parse::<usize>()
.with_context(|| format!("invalid Content-Length: {}", val))?,
);
}
}
Ok(content_length)
}
/// Parse a complete HTTP response into a `StatsResult`.
fn parse_stats_response(data: &[u8]) -> Result<StatsResult> {
let header_end = find_header_end(data)
.ok_or_else(|| anyhow!("no header terminator found"))?;
let header_bytes = &data[..header_end];
let body_start = header_end + 4;
// Parse status line
let header_str = std::str::from_utf8(header_bytes)
.context("headers are not valid UTF-8")?;
let status_line = header_str
.split("\r\n")
.next()
.ok_or_else(|| anyhow!("empty headers"))?;
let status = parse_status_line(status_line)?;
if status < 200 || status >= 300 {
return Err(anyhow!("HTTP {}", status));
}
let body = &data[body_start..];
let stats: ChBalloonStatistics =
serde_json::from_slice(body).context("parsing balloon statistics JSON")?;
Ok(StatsResult {
stats: BalloonStats {
free_memory: stats.free_memory,
available_memory: stats.available_memory,
disk_caches: stats.disk_caches,
total_memory: stats.total_memory,
},
balloon_actual: stats.balloon_actual_bytes,
})
}
/// Format a GET request for balloon statistics.
fn format_stats_request() -> &'static str {
"GET /api/v1/vm.balloon-statistics HTTP/1.1\r\nHost: localhost\r\n\r\n"
}
/// Format a PUT request to resize the balloon.
fn format_resize_request(desired_balloon: u64) -> String {
let body = format!(r#"{{"desired_balloon":{}}}"#, desired_balloon);
format!(
"PUT /api/v1/vm.resize HTTP/1.1\r\nHost: localhost\r\nContent-Type: application/json\r\nContent-Length: {}\r\n\r\n{}",
body.len(),
body,
)
}
/// Parse a resize HTTP response — succeeds if status is 200 or 204.
fn parse_resize_response(data: &[u8]) -> Result<()> {
// Find at least the status line
let header_str = std::str::from_utf8(data)
.context("response is not valid UTF-8")?;
let status_line = header_str
.split("\r\n")
.next()
.ok_or_else(|| anyhow!("empty response"))?;
let status = parse_status_line(status_line)?;
if status == 200 || status == 204 {
Ok(())
} else {
Err(anyhow!("resize failed: HTTP {}", status))
}
}
/// Balloon backend for cloud-hypervisor — talks to the CH HTTP API over a Unix socket.
pub struct CloudHypervisorBackend {
api_socket_path: PathBuf,
conn: Mutex<Option<UnixStream>>,
}
impl CloudHypervisorBackend {
pub fn new(api_socket_path: PathBuf) -> Self {
Self {
api_socket_path,
conn: Mutex::new(None),
}
}
/// Get or create a persistent connection, with a read timeout.
fn get_conn(&self) -> Result<std::sync::MutexGuard<'_, Option<UnixStream>>> {
let mut guard = self.conn.lock().map_err(|e| anyhow!("mutex poisoned: {}", e))?;
if guard.is_none() {
let stream = UnixStream::connect(&self.api_socket_path)
.with_context(|| format!("connecting to {:?}", self.api_socket_path))?;
stream
.set_read_timeout(Some(READ_TIMEOUT))
.context("set_read_timeout")?;
*guard = Some(stream);
}
Ok(guard)
}
/// Send a request and read the full HTTP response. Uses Content-Length to
/// determine when the response is complete on a persistent connection.
fn try_request(&self, request: &str) -> Result<Vec<u8>> {
let mut guard = self.get_conn()?;
let stream = guard.as_mut().ok_or_else(|| anyhow!("no connection"))?;
stream
.write_all(request.as_bytes())
.context("writing request")?;
let mut buf = Vec::with_capacity(4096);
let mut tmp = [0u8; 4096];
// Read until we have the full headers
loop {
let n = stream.read(&mut tmp).context("reading response")?;
if n == 0 {
return Err(anyhow!("connection closed"));
}
buf.extend_from_slice(&tmp[..n]);
if find_header_end(&buf).is_some() {
break;
}
}
let header_end = find_header_end(&buf).unwrap();
let header_bytes = &buf[..header_end];
let content_length = read_headers(header_bytes)?;
let body_start = header_end + 4;
if let Some(cl) = content_length {
let total_needed = body_start + cl;
while buf.len() < total_needed {
let n = stream.read(&mut tmp).context("reading response body")?;
if n == 0 {
return Err(anyhow!("connection closed during body read"));
}
buf.extend_from_slice(&tmp[..n]);
}
// Trim to exact response size (in case of pipelining)
buf.truncate(total_needed);
}
// If no Content-Length (e.g. 204 No Content), what we have is enough.
Ok(buf)
}
/// Send a request, retrying once on failure by reconnecting.
fn request(&self, request: &str) -> Result<Vec<u8>> {
match self.try_request(request) {
Ok(data) => Ok(data),
Err(first_err) => {
tracing::debug!(error = %first_err, "request failed, reconnecting");
// Drop the old connection
if let Ok(mut guard) = self.conn.lock() {
*guard = None;
}
// Retry once
self.try_request(request)
.with_context(|| format!("retry after: {}", first_err))
}
}
}
}
impl BalloonBackend for CloudHypervisorBackend {
fn balloon_stats(&self) -> Result<StatsResult> {
let request = format_stats_request();
let response = self.request(request)?;
parse_stats_response(&response)
}
fn balloon_adjust(&self, num_bytes: u64) -> Result<()> {
let request = format_resize_request(num_bytes);
let response = self.request(&request)?;
parse_resize_response(&response)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_stats_response_full() {
let body = r#"{"free_memory":1073741824,"available_memory":2147483648,"disk_caches":536870912,"total_memory":4294967296,"balloon_actual_bytes":524288000,"balloon_target_bytes":524288000}"#;
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}",
body.len(),
body
);
let result = parse_stats_response(response.as_bytes()).unwrap();
assert_eq!(result.stats.free_memory, Some(1073741824));
assert_eq!(result.stats.available_memory, Some(2147483648));
assert_eq!(result.stats.disk_caches, Some(536870912));
assert_eq!(result.stats.total_memory, Some(4294967296));
assert_eq!(result.balloon_actual, 524288000);
}
#[test]
fn parse_stats_response_nullable_fields() {
let body = r#"{"balloon_actual_bytes":0}"#;
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Length: {}\r\n\r\n{}",
body.len(),
body
);
let result = parse_stats_response(response.as_bytes()).unwrap();
assert_eq!(result.stats.free_memory, None);
assert_eq!(result.balloon_actual, 0);
}
#[test]
fn parse_stats_response_404() {
let response = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n";
assert!(parse_stats_response(&response[..]).is_err());
}
#[test]
fn format_resize_request_test() {
let req = format_resize_request(1048576);
assert!(req.starts_with("PUT /api/v1/vm.resize HTTP/1.1\r\n"));
assert!(req.contains("Content-Type: application/json\r\n"));
assert!(req.contains(r#""desired_balloon":1048576"#));
}
#[test]
fn parse_resize_response_204() {
let response = b"HTTP/1.1 204 No Content\r\n\r\n";
assert!(parse_resize_response(&response[..]).is_ok());
}
#[test]
fn parse_resize_response_404() {
let response = b"HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\n\r\n";
assert!(parse_resize_response(&response[..]).is_err());
}
}

View file

@ -1,5 +1,5 @@
use std::os::fd::{AsRawFd, OwnedFd};
use std::path::Path;
use std::path::{Path, PathBuf};
use std::time::Duration;
use anyhow::{anyhow, Context, Result};
@ -8,6 +8,8 @@ use nix::sys::socket::{
};
use serde::{Deserialize, Serialize};
use crate::backend::{BalloonBackend, BalloonStats, StatsResult};
/// Request sent to crosvm control socket.
#[derive(Serialize, Debug)]
pub enum VmRequest {
@ -35,25 +37,9 @@ pub enum VmResponse {
},
}
/// Balloon statistics reported by the guest via virtio-balloon.
#[derive(Deserialize, Debug, Clone)]
pub struct BalloonStats {
pub free_memory: Option<u64>,
pub available_memory: Option<u64>,
pub disk_caches: Option<u64>,
pub total_memory: Option<u64>,
}
/// Socket receive timeout.
const RECV_TIMEOUT: Duration = Duration::from_secs(1);
/// Stats result from crosvm balloon query.
#[derive(Debug)]
pub struct StatsResult {
pub stats: BalloonStats,
pub balloon_actual: u64,
}
/// Query balloon stats from a VM's crosvm control socket.
/// Blocking call -- use with tokio::task::spawn_blocking.
pub fn balloon_stats(socket_path: &Path) -> Result<StatsResult> {
@ -132,6 +118,27 @@ fn recv_response(fd: &OwnedFd) -> Result<VmResponse> {
Ok(response)
}
/// Balloon backend for crosvm — talks to crosvm's control socket.
pub struct CrosvmBackend {
socket_path: PathBuf,
}
impl CrosvmBackend {
pub fn new(socket_path: PathBuf) -> Self {
Self { socket_path }
}
}
impl BalloonBackend for CrosvmBackend {
fn balloon_stats(&self) -> Result<StatsResult> {
balloon_stats(&self.socket_path)
}
fn balloon_adjust(&self, num_bytes: u64) -> Result<()> {
balloon_adjust(&self.socket_path, num_bytes)
}
}
#[cfg(test)]
mod tests {
use super::*;

View file

@ -1,4 +1,6 @@
pub mod args;
pub mod backend;
pub mod cloud_hypervisor;
pub mod crosvm;
pub mod host;
pub mod policy;

View file

@ -1,12 +1,14 @@
use std::collections::BTreeMap;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Instant;
use anyhow::Result;
use clap::Parser;
use tokio::signal;
use tokio::signal::unix::{signal as unix_signal, SignalKind};
use vmsilo_balloond::crosvm;
use vmsilo_balloond::backend::BalloonBackend;
use vmsilo_balloond::cloud_hypervisor::CloudHypervisorBackend;
use vmsilo_balloond::crosvm::CrosvmBackend;
use vmsilo_balloond::host::read_host_memory;
use vmsilo_balloond::policy::{BalloonPolicy, GuestStats};
use vmsilo_balloond::psi;
@ -16,7 +18,7 @@ use vmsilo_balloond::{init_logging, Args};
struct VmState {
name: String,
socket_path: PathBuf,
backend: Arc<dyn BalloonBackend>,
policy: BalloonPolicy,
stall: StallDetector,
consecutive_failures: u32,
@ -96,10 +98,25 @@ async fn async_main(args: Args) -> Result<(), Box<dyn std::error::Error>> {
event = rx.recv() => {
match event {
Some(WatchEvent::VmAdded { name, socket_path }) => {
tracing::info!(vm = %name, "VM discovered");
tracing::info!(vm = %name, path = ?socket_path, "VM discovered");
let backend: Arc<dyn BalloonBackend> = match socket_path
.file_name()
.and_then(|f| f.to_str())
{
Some("crosvm-control.socket") => {
Arc::new(CrosvmBackend::new(socket_path))
}
Some("cloud-hypervisor-control.socket") => {
Arc::new(CloudHypervisorBackend::new(socket_path))
}
_ => {
tracing::warn!(vm = %name, path = ?socket_path, "unknown socket type");
continue;
}
};
vms.insert(name.clone(), VmState {
name,
socket_path,
backend,
policy: BalloonPolicy::new(critical_host_available, args.guest_available_bias, args.critical_guest_available),
stall: StallDetector::new(),
consecutive_failures: 0,
@ -165,12 +182,11 @@ async fn poll_all_vms(
}
async fn poll_single_vm(vm: &mut VmState, host_available: u64) -> Result<i64> {
let socket_path = vm.socket_path.clone();
let result = tokio::task::spawn_blocking(move || crosvm::balloon_stats(&socket_path)).await??;
let backend = vm.backend.clone();
let result = tokio::task::spawn_blocking(move || backend.balloon_stats()).await??;
let now = Instant::now();
// Check for stall from previous inflate
if vm.stall.is_stalled(result.balloon_actual, now) {
let new_size = result.balloon_actual.saturating_sub(STALL_DEFLATE_BYTES);
tracing::warn!(
@ -180,27 +196,19 @@ async fn poll_single_vm(vm: &mut VmState, host_available: u64) -> Result<i64> {
cooldown_secs = STALL_COOLDOWN.as_secs(),
"balloon stall detected, deflating"
);
let path = vm.socket_path.clone();
tokio::task::spawn_blocking(move || crosvm::balloon_adjust(&path, new_size)).await??;
let backend = vm.backend.clone();
tokio::task::spawn_blocking(move || backend.balloon_adjust(new_size)).await??;
vm.stall.clear();
vm.stall.enter_cooldown(now);
return Ok(-(STALL_DEFLATE_BYTES as i64));
}
// Required fields must be present
let free = result
.stats
.free_memory
let free = result.stats.free_memory
.ok_or_else(|| anyhow::anyhow!("free_memory not available"))?;
let cached = result
.stats
.disk_caches
let cached = result.stats.disk_caches
.ok_or_else(|| anyhow::anyhow!("disk_caches not available"))?;
let total = result
.stats
.total_memory
let total = result.stats.total_memory
.ok_or_else(|| anyhow::anyhow!("total_memory not available"))?;
let available = result.stats.available_memory;
let guest_stats = GuestStats {
@ -213,7 +221,6 @@ async fn poll_single_vm(vm: &mut VmState, host_available: u64) -> Result<i64> {
let delta = vm.policy.compute_balloon_delta(&guest_stats, host_available);
// Suppress inflation during stall cooldown (deflation still allowed)
if delta > 0 && vm.stall.is_in_cooldown(now) {
return Ok(0);
}
@ -228,8 +235,8 @@ async fn poll_single_vm(vm: &mut VmState, host_available: u64) -> Result<i64> {
guest_available_mb = available.unwrap_or(0) / (1024 * 1024),
"adjusting balloon"
);
let path = vm.socket_path.clone();
tokio::task::spawn_blocking(move || crosvm::balloon_adjust(&path, new_size)).await??;
let backend = vm.backend.clone();
tokio::task::spawn_blocking(move || backend.balloon_adjust(new_size)).await??;
if delta > 0 {
vm.stall.record_inflate(result.balloon_actual, now);

View file

@ -4,7 +4,9 @@ use anyhow::{Context, Result};
use notify::{EventKind, RecommendedWatcher, Watcher as NotifyWatcher};
use tokio::sync::mpsc;
const SOCKET_FILENAME: &str = "crosvm-control.socket";
const CROSVM_SOCKET: &str = "crosvm-control.socket";
const CLOUD_HYPERVISOR_SOCKET: &str = "cloud-hypervisor-control.socket";
const SOCKET_FILENAMES: &[&str] = &[CROSVM_SOCKET, CLOUD_HYPERVISOR_SOCKET];
#[derive(Debug, Clone)]
pub enum WatchEvent {
@ -30,13 +32,15 @@ impl Watcher {
if let Ok(entries) = std::fs::read_dir(watch_dir) {
for entry in entries.flatten() {
if entry.file_type().map(|t| t.is_dir()).unwrap_or(false) {
let socket_path = entry.path().join(SOCKET_FILENAME);
if socket_path.exists() {
if let Some(name) = socket_vm_name(&socket_path) {
let _ = tx.send(WatchEvent::VmAdded {
name,
socket_path,
});
for &socket_name in SOCKET_FILENAMES {
let socket_path = entry.path().join(socket_name);
if socket_path.exists() {
if let Some(name) = socket_vm_name(&socket_path) {
let _ = tx.send(WatchEvent::VmAdded {
name,
socket_path,
});
}
}
}
}
@ -73,7 +77,7 @@ impl Watcher {
/// Returns None for paths that don't match this structure.
fn socket_vm_name(path: &Path) -> Option<String> {
let filename = path.file_name()?.to_str()?;
if filename != SOCKET_FILENAME {
if !SOCKET_FILENAMES.contains(&filename) {
return None;
}
// VM name is the name of the immediate parent directory
@ -123,4 +127,16 @@ mod tests {
let path = Path::new("/run/vmsilo/banking/other.socket");
assert_eq!(socket_vm_name(path), None);
}
#[test]
fn cloud_hypervisor_socket_recognized() {
let path = Path::new("/run/vmsilo/banking/cloud-hypervisor-control.socket");
assert_eq!(socket_vm_name(path), Some("banking".to_string()));
}
#[test]
fn unknown_socket_not_recognized() {
let path = Path::new("/run/vmsilo/banking/unknown.socket");
assert_eq!(socket_vm_name(path), None);
}
}

View file

@ -11,6 +11,7 @@ use vmsilo_dbus_proxy::args::{init_logging, LogLevel};
use vmsilo_dbus_proxy::host::item::{SyntheticSniItem, TrayItemState};
use vmsilo_dbus_proxy::host::menu::SyntheticMenu;
use vmsilo_dbus_proxy::host::notifications::HostNotifications;
use vmsilo_dbus_proxy::host::vsock::VsockTarget;
use vmsilo_dbus_proxy::protocol::*;
use vmsilo_dbus_proxy::sanitize;
@ -22,9 +23,14 @@ struct Args {
#[arg(long)]
vm_name: String,
/// VM CID for vsock connection
#[arg(long)]
cid: u32,
/// VM CID for kernel vsock connection (crosvm). Mutually exclusive with --vsock-socket.
#[arg(long, conflicts_with = "vsock_socket")]
cid: Option<u32>,
/// Path to Unix vsock socket for CONNECT protocol (cloud-hypervisor).
/// Mutually exclusive with --cid.
#[arg(long, conflicts_with = "cid")]
vsock_socket: Option<std::path::PathBuf>,
/// Icon theme name to send to guest (e.g. "breeze", "breeze-dark")
#[arg(long, default_value = "breeze")]
@ -42,6 +48,18 @@ struct Args {
log_level: LogLevel,
}
impl Args {
fn vsock_target(&self) -> anyhow::Result<VsockTarget> {
match (&self.cid, &self.vsock_socket) {
(Some(cid), None) => Ok(VsockTarget::Kernel { cid: *cid }),
(None, Some(path)) => Ok(VsockTarget::Socket { path: path.clone() }),
_ => Err(anyhow::anyhow!(
"exactly one of --cid or --vsock-socket must be provided"
)),
}
}
}
struct TrayItem {
state: Arc<RwLock<TrayItemState>>,
/// Connection that owns this item's D-Bus names and serves the interfaces.
@ -53,9 +71,10 @@ async fn main() -> Result<()> {
let args = Args::parse();
init_logging(args.log_level);
let target = args.vsock_target()?;
info!(
vm = args.vm_name.as_str(),
cid = args.cid,
target = ?target,
protocols = args.protocols.as_str(),
"Starting vmsilo-dbus-proxy-host"
);
@ -83,19 +102,42 @@ async fn run_connection(args: &Args) -> Result<()> {
let tray_enabled = protocols.contains("tray");
let notifications_enabled = protocols.contains("notifications");
let stream =
tokio_vsock::VsockStream::connect(VsockAddr::new(args.cid, TRAY_VSOCK_PORT)).await?;
info!("Connected to guest vsock");
let (reader, writer) = tokio::io::split(stream);
let mut reader = tokio::io::BufReader::new(reader);
let writer = Arc::new(tokio::sync::Mutex::new(tokio::io::BufWriter::new(writer)));
let tint_color = args.color.as_ref().map(|c| {
vmsilo_dbus_proxy::tint::parse_hex_color(c)
.unwrap_or_else(|e| panic!("invalid --color '{}': {}", c, e))
});
let target = args.vsock_target()?;
match target {
VsockTarget::Kernel { cid } => {
let stream =
tokio_vsock::VsockStream::connect(VsockAddr::new(cid, TRAY_VSOCK_PORT)).await?;
info!("Connected to guest via kernel vsock");
run_with_stream(stream, args, tray_enabled, notifications_enabled, tint_color).await
}
VsockTarget::Socket { ref path } => {
let stream =
vmsilo_dbus_proxy::host::vsock::connect_hybrid(path, TRAY_VSOCK_PORT).await?;
info!("Connected to guest via CONNECT protocol");
run_with_stream(stream, args, tray_enabled, notifications_enabled, tint_color).await
}
}
}
async fn run_with_stream<S>(
stream: S,
args: &Args,
tray_enabled: bool,
notifications_enabled: bool,
tint_color: Option<[u8; 3]>,
) -> Result<()>
where
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
{
let (reader, writer) = tokio::io::split(stream);
let mut reader = tokio::io::BufReader::new(reader);
let writer = Arc::new(tokio::sync::Mutex::new(tokio::io::BufWriter::new(writer)));
// Send Init with icon theme before entering the message loop
{
let mut w = writer.lock().await;
@ -177,9 +219,11 @@ async fn run_connection(args: &Args) -> Result<()> {
}
#[allow(clippy::too_many_arguments)]
async fn run_event_loop(
reader: &mut tokio::io::BufReader<tokio::io::ReadHalf<tokio_vsock::VsockStream>>,
writer: &Arc<tokio::sync::Mutex<tokio::io::BufWriter<tokio::io::WriteHalf<tokio_vsock::VsockStream>>>>,
async fn run_event_loop<S>(
reader: &mut tokio::io::BufReader<tokio::io::ReadHalf<S>>,
writer: &Arc<
tokio::sync::Mutex<tokio::io::BufWriter<tokio::io::WriteHalf<S>>>,
>,
event_tx: &mpsc::UnboundedSender<HostToGuest>,
event_rx: &mut mpsc::UnboundedReceiver<HostToGuest>,
items: &mut HashMap<String, TrayItem>,
@ -189,7 +233,10 @@ async fn run_event_loop(
args: &Args,
tray_enabled: bool,
tint_color: Option<[u8; 3]>,
) -> Result<()> {
) -> Result<()>
where
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
use futures_util::StreamExt;
loop {

View file

@ -1,3 +1,4 @@
pub mod item;
pub mod menu;
pub mod notifications;
pub mod vsock;

View file

@ -0,0 +1,112 @@
use std::path::PathBuf;
use anyhow::{anyhow, Context, Result};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::UnixStream;
/// Vsock connection target — determines how to reach the guest.
#[derive(Debug, Clone)]
pub enum VsockTarget {
/// Direct kernel vsock (crosvm).
Kernel { cid: u32 },
/// Unix socket with CONNECT handshake (cloud-hypervisor / firecracker).
Socket { path: PathBuf },
}
/// Connect to a guest's vsock port using the CONNECT protocol over a Unix socket.
///
/// Sends `CONNECT <port>\n`, expects `OK <assigned_port>\n`, then returns the
/// stream for bidirectional communication.
pub async fn connect_hybrid(socket_path: &std::path::Path, port: u32) -> Result<UnixStream> {
let mut stream = UnixStream::connect(socket_path)
.await
.with_context(|| format!("connecting to {}", socket_path.display()))?;
// Send CONNECT handshake
let connect_msg = format!("CONNECT {port}\n");
stream
.write_all(connect_msg.as_bytes())
.await
.context("sending CONNECT")?;
// Read response line byte-by-byte to avoid BufReader buffering issues.
// BufReader would borrow the stream and might buffer bytes past the
// response line, losing data when dropped.
let mut buf = Vec::with_capacity(64);
loop {
let mut byte = [0u8; 1];
stream
.read_exact(&mut byte)
.await
.context("reading CONNECT response")?;
buf.push(byte[0]);
if byte[0] == b'\n' {
break;
}
if buf.len() > 256 {
return Err(anyhow!("CONNECT response too long"));
}
}
let response = String::from_utf8_lossy(&buf);
let trimmed = response.trim();
if !trimmed.starts_with("OK ") {
return Err(anyhow!(
"CONNECT handshake failed: expected 'OK <port>', got '{trimmed}'"
));
}
Ok(stream)
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn connect_handshake_success() {
let dir = tempfile::tempdir().unwrap();
let sock_path = dir.path().join("test.sock");
let listener = tokio::net::UnixListener::bind(&sock_path).unwrap();
let client = tokio::spawn({
let path = sock_path.clone();
async move { connect_hybrid(&path, 5001).await }
});
let (mut server, _) = listener.accept().await.unwrap();
let mut buf = vec![0u8; 128];
let n = server.read(&mut buf).await.unwrap();
let request = std::str::from_utf8(&buf[..n]).unwrap();
assert_eq!(request, "CONNECT 5001\n");
server.write_all(b"OK 1073741824\n").await.unwrap();
let stream = client.await.unwrap().unwrap();
drop(stream);
}
#[tokio::test]
async fn connect_handshake_failure() {
let dir = tempfile::tempdir().unwrap();
let sock_path = dir.path().join("test.sock");
let listener = tokio::net::UnixListener::bind(&sock_path).unwrap();
let client = tokio::spawn({
let path = sock_path.clone();
async move { connect_hybrid(&path, 5001).await }
});
let (mut server, _) = listener.accept().await.unwrap();
let mut buf = vec![0u8; 128];
let _n = server.read(&mut buf).await.unwrap();
server.write_all(b"ERR Connection refused\n").await.unwrap();
let result = client.await.unwrap();
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("CONNECT handshake failed"));
}
}