feat(vm-switch): add process isolation with namespace sandbox and seccomp

Replace the thread-based vhost-user backend architecture with a
fork-based process model where each VM gets its own child process.
This enables strong isolation between VMs handling untrusted network
traffic, with multiple layers of defense in depth.

Process model:
- Main process watches config directory and orchestrates child lifecycle
- One child process forked per VM, running as vhost-user net backend
- Children communicate via SOCK_SEQPACKET control channel with SCM_RIGHTS
- Automatic child restart on crash/disconnect, with peer notification
- Ping/pong heartbeat monitoring for worker health (1s interval, 100ms timeout)
- SIGCHLD handling integrated into tokio event loop

Inter-process packet forwarding:
- Lock-free SPSC ring buffers in shared memory (memfd + mmap)
- 64-slot rings (~598KB each) with atomic head/tail, no locks in datapath
- Eventfd signaling for empty-to-non-empty transitions
- Main orchestrates buffer exchange: GetBuffer -> BufferReady -> PutBuffer
- Zero-copy path: producers write directly into consumer's shared memory

Namespace sandbox (applied before tokio, single-threaded):
- User namespace: unprivileged outside, UID 0 inside
- PID namespace: main is PID 1, children invisible to host
- Mount namespace: minimal tmpfs root with /config, /dev, /proc, /tmp
- IPC namespace: isolated System V IPC
- Network namespace: empty, communication only via inherited FDs
- Controllable via --no-sandbox flag

Seccomp BPF filtering (two-tier whitelist):
- Main filter: allows fork, socket creation, inotify, openat
- Child filter: strict subset - no fork, no socket, no file open
- Child filter applied after vhost setup, before event loop
- Modes: kill (default), trap (SIGSYS debug), log, disabled

Also adds vm-switch service dependencies to VM units in the NixOS
module so VMs wait for their network switch before starting.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Davíð Steinn Geirsson 2026-02-09 20:19:26 +00:00
parent 6722c0fbb4
commit 6941d2fe4c
29 changed files with 6275 additions and 2041 deletions


@@ -138,6 +138,14 @@ Provides L2 switching for VM-to-VM networks:
- **Purpose:** Handles vhost-user protocol for VM network interfaces
- **Systemd:** One service per vmNetwork (`vm-switch-<netname>.service`)
**CLI flags:**
```
-d, --config-dir <PATH> Config/MAC file directory (default: /run/vm-switch)
--log-level <LEVEL> error, warn, info, debug, trace (default: warn)
--no-sandbox Disable namespace sandboxing
--seccomp-mode <MODE> kill (default), trap, log, disabled
```
**Testing locally:**
```bash
# Build and run manually
@@ -149,6 +157,67 @@ mkdir -p /tmp/test-switch/router
echo "52:00:00:00:00:01" > /tmp/test-switch/router/router.mac
```
**Process model:** Main process forks one child per VM. Children are vhost-user net backends that handle virtio TX/RX for their VM. Main orchestrates lifecycle, config watching, and buffer exchange between children. Children exit when the vhost-user client (crosvm) disconnects; main automatically restarts them so crosvm can reconnect.
**Startup sequence:**
1. Parse args, apply namespace sandbox (single-threaded, before tokio)
2. Apply main seccomp filter
3. Start tokio runtime, create ConfigWatcher + BackendManager
4. Enter async event loop (SIGCHLD handled via a tokio select branch)
**Key source files:**
- `src/main.rs` - Entry point, sandbox/seccomp setup, async event loop, SIGCHLD handling
- `src/manager.rs` - BackendManager: fork children, buffer exchange, crash cleanup
- `src/child/process.rs` - Child entry point: control channel, vhost daemon, child seccomp
- `src/child/forwarder.rs` - PacketForwarder: L2 routing via ring buffers
- `src/child/vhost.rs` - ChildVhostBackend: virtio TX/RX callbacks
- `src/child/poll.rs` - Event polling for control channel + ingress buffers
- `src/control.rs` - Main-child IPC over Unix seqpacket sockets + SCM_RIGHTS
- `src/ring.rs` - Lock-free SPSC ring buffer in shared memory (memfd)
- `src/sandbox.rs` - Namespace isolation (user, PID, mount, IPC, network)
- `src/seccomp.rs` - BPF syscall filters (main and child whitelists)
- `src/frame.rs` - Ethernet frame parsing, MAC validation
**Control protocol** (main <-> child IPC via `SOCK_SEQPACKET` + `SCM_RIGHTS`):
| Direction | Message | FDs | Purpose |
|-----------|---------|-----|---------|
| Main -> Child | `GetBuffer { peer_name, peer_mac }` | - | Ask child to create ingress buffer for a peer |
| Child -> Main | `BufferReady { peer_name }` | memfd, eventfd | Ingress buffer created, here are the FDs |
| Main -> Child | `PutBuffer { peer_name, peer_mac, broadcast }` | memfd, eventfd | Give child a peer's buffer as egress target |
| Main -> Child | `RemovePeer { peer_name }` | - | Clean up buffers for disconnected/crashed peer |
| Main -> Child | `Ping` | - | Heartbeat request (sent every 1s) |
| Child -> Main | `Ready` | - | Child initialized and ready |
| Child -> Main | `Pong` | - | Heartbeat response (must arrive within 100ms) |
Messages serialized with `postcard`. FDs passed via ancillary data.
**Buffer exchange flow:**
1. Main sends `GetBuffer` to Child1 ("create ingress buffer for Child2")
2. Child1 creates SPSC ring buffer (memfd + eventfd), becomes Consumer, replies `BufferReady`
3. Main forwards those FDs to Child2 via `PutBuffer` -- Child2 becomes Producer
4. Packets now flow: Child2 writes to Producer -> shared memfd -> Child1 reads from Consumer
**SPSC ring buffer** (`ring.rs`): Lock-free single-producer/single-consumer queue backed by `memfd_create()` + `mmap(MAP_SHARED)`. 64 slots, ~598KB total. Head/tail use atomic operations (no locks in datapath). Eventfd signals empty-to-non-empty transitions.
**Sandbox** (applied before tokio, requires single-threaded):
1. **User namespace** - Maps real UID to 0 inside, enables unprivileged namespace creation
2. **PID namespace** - Fork into new PID ns; main becomes PID 1
3. **Mount namespace** - Minimal tmpfs root with `/config` (bind-mount of config dir), `/dev` (null, zero, urandom), `/proc`, `/tmp`. Pivot root, unmount old.
4. **IPC namespace** - Isolates System V IPC
5. **Network namespace** - Empty (no interfaces). Communication only via inherited FDs.
**Seccomp filtering** (BPF syscall whitelist):
- `--seccomp-mode=kill` (default): Terminate on blocked syscall
- `--seccomp-mode=trap`: Send SIGSYS (debug with strace)
- `--seccomp-mode=log`: Log violations but allow
- `--seccomp-mode=disabled`: Skip filtering
Two filter tiers (child is a strict subset of main):
- **Main**: Allows fork, socket creation, inotify, openat (config watching + child management)
- **Child**: No fork, no socket creation, no file open. Applied after vhost setup completes. Allows clone3 for vhost-user threads.
### Dependencies
- Custom crosvm fork: `git.dsg.is/davidlowsec/crosvm.git`


@@ -499,3 +499,53 @@ The host provides:
- Polkit rules for the configured user to manage VM services without sudo
- CLI tools: `vm-run`, `vm-start`, `vm-stop`, `vm-start-debug`, `vm-shell`
- Desktop integration with .desktop files for guest applications
### vm-switch
The `vm-switch` daemon (`vm-switch/` Rust crate) provides L2 switching for VM-to-VM networks. One instance runs per `vmNetwork`, managed by systemd (`vm-switch-<netname>.service`).
**Process model:** The main process watches a config directory for MAC files and forks one child process per VM. Each child is a vhost-user net backend serving a single VM's network interface.
```
                 Main Process
         (config watch, orchestration)
            /          |          \
      fork /      fork |      fork \
          v            v            v
  Child: router   Child: banking   Child: shopping
   (vhost-user)    (vhost-user)     (vhost-user)
        |               |                |
  [unix socket]   [unix socket]    [unix socket]
        |               |                |
      crosvm          crosvm           crosvm
   (router VM)     (banking VM)     (shopping VM)
```
**Packet forwarding** uses lock-free SPSC ring buffers in shared memory (`memfd_create` + `mmap`). When a VM transmits a frame, its child process validates the source MAC address and routes the frame to the correct destination:
- Unicast: pushed into the destination child's ingress ring buffer
- Broadcast/multicast: pushed into all peers' ingress buffers
Ring buffers use atomic head/tail pointers (no locks in the datapath) with eventfd signaling for empty-to-non-empty transitions.
**Buffer exchange protocol:** The main process orchestrates buffer setup between children via a control channel (`SOCK_SEQPACKET` + `SCM_RIGHTS` for passing memfd/eventfd file descriptors):
1. Main tells Child A: "create an ingress buffer for Child B" (`GetBuffer`)
2. Child A creates the ring buffer and returns the FDs (`BufferReady`)
3. Main forwards those FDs to Child B as an egress target (`PutBuffer`)
4. Child B can now write frames directly into Child A's memory -- no copies through the main process
**Sandboxing:** The daemon runs in a multi-layer sandbox applied at startup (before any async runtime or threads):
| Layer | Mechanism | Effect |
|-------|-----------|--------|
| User namespace | `CLONE_NEWUSER` | Unprivileged outside, appears as UID 0 inside |
| PID namespace | `CLONE_NEWPID` | Main is PID 1; children invisible to host |
| Mount namespace | `CLONE_NEWNS` + pivot_root | Minimal tmpfs root: `/config`, `/dev` (null/zero/urandom), `/proc`, `/tmp` |
| IPC namespace | `CLONE_NEWIPC` | Isolated System V IPC |
| Network namespace | `CLONE_NEWNET` | No interfaces; communication only via inherited FDs |
| Seccomp (main) | BPF whitelist | Allows fork, socket creation, inotify for config watching |
| Seccomp (child) | Tighter BPF whitelist | No fork, no socket creation, no file open; applied after vhost setup |
Seccomp modes: `--seccomp-mode=kill` (default), `trap` (SIGSYS for debugging), `log`, `disabled`.
Disable sandboxing for debugging with `--no-sandbox` and `--seccomp-mode=disabled`.


@@ -1097,7 +1097,10 @@ in
vm:
lib.nameValuePair "qubes-lite-${vm.name}-vm" {
description = "qubes-lite VM: ${vm.name}";
after =
[ "network.target" ]
++ map (netName: "vm-switch-${netName}.service") (lib.attrNames vm.vmNetwork);
requires = map (netName: "vm-switch-${netName}.service") (lib.attrNames vm.vmNetwork);
serviceConfig = {
Type = "simple";
ExecStart = "${mkVmScript vm}";

vm-switch/Cargo.lock generated

@@ -61,6 +61,12 @@ dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "anyhow"
version = "1.0.101"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f0e0fee31ef5ed1ba1316088939cea399010ed7731dba877ed44aeb407a75ea"
[[package]]
name = "arc-swap"
version = "1.8.1"
@@ -70,6 +76,15 @@ dependencies = [
"rustversion",
]
[[package]]
name = "atomic-polyfill"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4"
dependencies = [
"critical-section",
]
[[package]]
name = "bitflags"
version = "1.3.2"
@@ -88,6 +103,12 @@ version = "3.19.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5dd9dc738b7a8311c7ade152424974d8115f2cdad61e8dab8dac9f2362298510"
[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "bytes"
version = "1.11.1"
@@ -100,6 +121,12 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
[[package]]
name = "cfg_aliases"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
[[package]]
name = "clap"
version = "4.5.57"
@@ -140,12 +167,39 @@ version = "0.7.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3e64b0cc0439b12df2fa678eae89a1c56a529fd067a9115f7827f1fffd22b32"
[[package]]
name = "cobs"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fa961b519f0b462e3a3b4a34b64d119eeaca1d59af726fe450bbba07a9fc0a1"
dependencies = [
"thiserror",
]
[[package]]
name = "colorchoice"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
[[package]]
name = "critical-section"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b"
[[package]]
name = "embedded-io"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef1a6892d9eef45c8fa6b9e0086428a2cca8491aca8f787c534a3d6d0bcb3ced"
[[package]]
name = "embedded-io"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d"
[[package]]
name = "errno"
version = "0.3.14"
@@ -191,6 +245,42 @@ dependencies = [
"libc",
]
[[package]]
name = "futures-core"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
[[package]]
name = "futures-executor"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-task"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-util"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
"futures-core",
"futures-task",
"pin-project-lite",
"pin-utils",
"slab",
]
[[package]]
name = "getrandom"
version = "0.3.4"
@@ -203,6 +293,29 @@ dependencies = [
"wasip2",
]
[[package]]
name = "hash32"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67"
dependencies = [
"byteorder",
]
[[package]]
name = "heapless"
version = "0.7.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f"
dependencies = [
"atomic-polyfill",
"hash32",
"rustc_version",
"serde",
"spin",
"stable_deref_trait",
]
[[package]]
name = "heck"
version = "0.5.0"
@@ -345,6 +458,18 @@ dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "nix"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46"
dependencies = [
"bitflags 2.10.0",
"cfg-if",
"cfg_aliases",
"libc",
]
[[package]]
name = "notify"
version = "7.0.0"
@@ -436,6 +561,25 @@ version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b"
[[package]]
name = "pin-utils"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "postcard"
version = "1.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6764c3b5dd454e283a30e6dfe78e9b31096d9e32036b5d1eaac7a6119ccb9a24"
dependencies = [
"cobs",
"embedded-io 0.4.0",
"embedded-io 0.6.1",
"heapless",
"serde",
]
[[package]]
name = "ppv-lite86"
version = "0.2.21"
@@ -533,6 +677,15 @@ version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a96887878f22d7bad8a3b6dc5b7440e0ada9a245242924394987b21cf2210a4c"
[[package]]
name = "rustc_version"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92"
dependencies = [
"semver",
]
[[package]]
name = "rustix"
version = "1.1.3"
@@ -561,12 +714,98 @@ dependencies = [
"winapi-util",
]
[[package]]
name = "scc"
version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46e6f046b7fef48e2660c57ed794263155d713de679057f2d0c169bfc6e756cc"
dependencies = [
"sdd",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "sdd"
version = "3.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "490dcfcbfef26be6800d11870ff2df8774fa6e86d047e3e8c8a76b25655e41ca"
[[package]]
name = "seccompiler"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "345a3e4dddf721a478089d4697b83c6c0a8f5bf16086f6c13397e4534eb6e2e5"
dependencies = [
"libc",
]
[[package]]
name = "semver"
version = "1.0.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2"
[[package]]
name = "serde"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
dependencies = [
"serde_core",
"serde_derive",
]
[[package]]
name = "serde_core"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serial_test"
version = "3.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d0b343e184fc3b7bb44dff0705fffcf4b3756ba6aff420dddd8b24ca145e555"
dependencies = [
"futures-executor",
"futures-util",
"log",
"once_cell",
"parking_lot",
"scc",
"serial_test_derive",
]
[[package]]
name = "serial_test_derive"
version = "3.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f50427f258fb77356e4cd4aa0e87e2bd2c66dbcee41dc405282cae2bfc26c83"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "sharded-slab"
version = "0.1.7"
@@ -586,6 +825,12 @@ dependencies = [
"libc",
]
[[package]]
name = "slab"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5"
[[package]]
name = "smallvec"
version = "1.15.1"
@@ -602,6 +847,21 @@ dependencies = [
"windows-sys 0.60.2",
]
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
dependencies = [
"lock_api",
]
[[package]]
name = "stable_deref_trait"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596"
[[package]]
name = "strsim"
version = "0.11.1"
@@ -841,9 +1101,16 @@ dependencies = [
name = "vm-switch"
version = "0.1.0"
dependencies = [
"anyhow",
"clap",
"libc",
"nix",
"notify",
"notify-debouncer-full",
"postcard",
"seccompiler",
"serde",
"serial_test",
"tempfile",
"thiserror",
"tokio",


@@ -27,9 +27,20 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
# Error handling
thiserror = "2"
anyhow = "1"
# CLI
clap = { version = "4", features = ["derive"] }
# Sandboxing
nix = { version = "0.29", features = ["sched", "mount", "user", "signal", "process", "poll", "fs"] }
seccompiler = "0.4"
libc = "0.2"
# Serialization (for control channel)
serde = { version = "1", features = ["derive"] }
postcard = { version = "1", features = ["alloc"] }
[dev-dependencies]
tempfile = "3"
serial_test = "3"


@@ -1,9 +1,46 @@
//! Command-line argument parsing.
use std::fmt;
use std::path::PathBuf;
use std::sync::Mutex;
use crate::seccomp::SeccompMode;
use clap::{Parser, ValueEnum};
use tracing_subscriber::fmt::format::Writer;
use tracing_subscriber::fmt::{FmtContext, FormatEvent, FormatFields};
use tracing_subscriber::registry::LookupSpan;
use tracing_subscriber::EnvFilter;
/// Process name for log prefixes ("main" or "worker-$vmname").
static PROCESS_NAME: Mutex<String> = Mutex::new(String::new());
/// Set the process name for log prefixes.
pub fn set_process_name(name: impl Into<String>) {
*PROCESS_NAME.lock().unwrap() = name.into();
}
/// Custom log formatter that outputs: LEVEL process-name: message fields
struct PrefixedFormatter;
impl<S, N> FormatEvent<S, N> for PrefixedFormatter
where
S: tracing::Subscriber + for<'a> LookupSpan<'a>,
N: for<'a> FormatFields<'a> + 'static,
{
fn format_event(
&self,
ctx: &FmtContext<'_, S, N>,
mut writer: Writer<'_>,
event: &tracing::Event<'_>,
) -> fmt::Result {
let level = *event.metadata().level();
let name = PROCESS_NAME.lock().unwrap();
write!(writer, "{level} {name}: ")?;
ctx.field_format().format_fields(writer.by_ref(), event)?;
writeln!(writer)
}
}
/// Log level for the application.
#[derive(Copy, Clone, Debug, PartialEq, Eq, ValueEnum)]
pub enum LogLevel {
@@ -39,16 +76,26 @@ pub struct Args {
/// Log level (error, warn, info, debug, trace).
#[arg(long, value_enum, default_value_t = LogLevel::Warn)]
pub log_level: LogLevel,
/// Disable namespace sandboxing (for debugging).
#[arg(long, default_value_t = false)]
pub no_sandbox: bool,
/// Seccomp filter mode (kill, trap, log, disabled).
#[arg(long, value_enum, default_value_t = SeccompMode::Kill)]
pub seccomp_mode: SeccompMode,
}
/// Initialize logging based on log level.
pub fn init_logging(level: LogLevel) {
set_process_name("main");
let filter = EnvFilter::try_from_default_env()
.unwrap_or_else(|_| EnvFilter::new(format!("vm_switch={}", level.as_str())));
let _ = tracing_subscriber::fmt()
.with_env_filter(filter)
.with_target(false)
.event_format(PrefixedFormatter)
.try_init();
}
@@ -106,4 +153,40 @@ mod tests {
init_logging(LogLevel::Debug);
init_logging(LogLevel::Trace);
}
#[test]
fn parse_no_sandbox_flag() {
let args = Args::try_parse_from(["vm-switch", "--no-sandbox"]).unwrap();
assert!(args.no_sandbox);
}
#[test]
fn parse_no_sandbox_default_false() {
let args = Args::try_parse_from(["vm-switch"]).unwrap();
assert!(!args.no_sandbox);
}
#[test]
fn parse_seccomp_mode_default() {
let args = Args::try_parse_from(["vm-switch"]).unwrap();
assert_eq!(args.seccomp_mode, SeccompMode::Kill);
}
#[test]
fn parse_seccomp_mode_trap() {
let args = Args::try_parse_from(["vm-switch", "--seccomp-mode", "trap"]).unwrap();
assert_eq!(args.seccomp_mode, SeccompMode::Trap);
}
#[test]
fn parse_seccomp_mode_log() {
let args = Args::try_parse_from(["vm-switch", "--seccomp-mode", "log"]).unwrap();
assert_eq!(args.seccomp_mode, SeccompMode::Log);
}
#[test]
fn parse_seccomp_mode_disabled() {
let args = Args::try_parse_from(["vm-switch", "--seccomp-mode", "disabled"]).unwrap();
assert_eq!(args.seccomp_mode, SeccompMode::Disabled);
}
}


@@ -1,921 +0,0 @@
//! Vhost-user network backend implementation.
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
use tracing::{debug, trace, warn};
use vhost::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures};
use vhost_user_backend::{VhostUserBackend, VringRwLock, VringT};
use virtio_bindings::virtio_net::{
VIRTIO_NET_F_CSUM, VIRTIO_NET_F_GUEST_CSUM, VIRTIO_NET_F_GUEST_TSO4,
VIRTIO_NET_F_GUEST_TSO6, VIRTIO_NET_F_GUEST_UFO, VIRTIO_NET_F_HOST_TSO4,
VIRTIO_NET_F_HOST_TSO6, VIRTIO_NET_F_HOST_UFO, VIRTIO_NET_F_MAC, VIRTIO_NET_F_STATUS,
};
use virtio_bindings::virtio_config::VIRTIO_F_VERSION_1;
use virtio_bindings::virtio_ring::VIRTIO_RING_F_EVENT_IDX;
use virtio_queue::QueueT;
use vm_memory::{Bytes, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryMmap};
use vmm_sys_util::epoll::EventSet;
use crate::config::VmRole;
use crate::frame::EthernetFrame;
use crate::mac::Mac;
use crate::switch::{ConnectionId, ForwardDecision, Switch};
/// Registry mapping connection IDs to their RX vrings and associated memory.
/// Shared between all backends for frame routing.
pub type VringRegistry = Arc<RwLock<HashMap<ConnectionId, (VringRwLock, GuestMemoryAtomic<GuestMemoryMmap>)>>>;
/// RX queue index.
pub const RX_QUEUE: u16 = 0;
/// TX queue index.
pub const TX_QUEUE: u16 = 1;
/// Number of queues (RX + TX).
pub const NUM_QUEUES: usize = 2;
/// Maximum queue size (must be power of 2, 32768 is typical max).
pub const MAX_QUEUE_SIZE: usize = 32768;
/// Size of virtio-net header.
pub const VIRTIO_NET_HDR_SIZE: usize = 12;
/// Result of processing a frame from the TX queue.
#[derive(Debug, Clone)]
pub struct ProcessedFrame {
/// Raw frame data (Ethernet frame, no virtio header).
pub data: Vec<u8>,
/// Forwarding decision from the switch.
pub decision: ForwardDecision,
}
/// Virtio net features we support.
const VIRTIO_NET_FEATURES: u64 = (1 << VIRTIO_NET_F_CSUM)
| (1 << VIRTIO_NET_F_GUEST_CSUM)
| (1 << VIRTIO_NET_F_GUEST_TSO4)
| (1 << VIRTIO_NET_F_GUEST_TSO6)
| (1 << VIRTIO_NET_F_GUEST_UFO)
| (1 << VIRTIO_NET_F_HOST_TSO4)
| (1 << VIRTIO_NET_F_HOST_TSO6)
| (1 << VIRTIO_NET_F_HOST_UFO)
| (1 << VIRTIO_NET_F_MAC)
| (1 << VIRTIO_NET_F_STATUS);
/// Network backend for a single VM.
pub struct NetBackend {
/// VM name for logging.
name: String,
/// VM's role (router or client).
role: VmRole,
/// VM's MAC address.
mac: Mac,
/// Connection ID in the switch (set after registration).
connection_id: Mutex<Option<ConnectionId>>,
/// Shared switch for forwarding.
switch: Arc<RwLock<Switch>>,
/// Shared registry of all backends' RX vrings for frame routing.
vring_registry: VringRegistry,
/// Guest memory.
mem: Mutex<Option<GuestMemoryAtomic<GuestMemoryMmap>>>,
/// Whether EVENT_IDX is enabled.
event_idx: Mutex<bool>,
/// Acked features.
acked_features: Mutex<u64>,
}
impl std::fmt::Debug for NetBackend {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("NetBackend")
.field("name", &self.name)
.field("role", &self.role)
.field("mac", &self.mac)
.field("connection_id", &self.connection_id)
.finish_non_exhaustive()
}
}
impl NetBackend {
/// Create a new network backend.
pub fn new(
name: String,
role: VmRole,
mac: Mac,
switch: Arc<RwLock<Switch>>,
vring_registry: VringRegistry,
) -> Self {
Self {
name,
role,
mac,
connection_id: Mutex::new(None),
switch,
vring_registry,
mem: Mutex::new(None),
event_idx: Mutex::new(false),
acked_features: Mutex::new(0),
}
}
/// Get the VM name.
pub fn name(&self) -> &str {
&self.name
}
/// Get the VM role.
pub fn role(&self) -> VmRole {
self.role
}
/// Get the VM MAC.
pub fn mac(&self) -> Mac {
self.mac
}
/// Get the connection ID (if registered).
pub fn connection_id(&self) -> Option<ConnectionId> {
*self.connection_id.lock().unwrap()
}
/// Register this backend with the switch.
pub fn register(&self) -> Option<ConnectionId> {
let mut switch = self.switch.write().unwrap();
let id = switch.register(self.name.clone(), self.role, self.mac)?;
*self.connection_id.lock().unwrap() = Some(id);
Some(id)
}
/// Unregister this backend from the switch.
pub fn unregister(&self) {
let mut id_guard = self.connection_id.lock().unwrap();
if let Some(id) = id_guard.take() {
let mut switch = self.switch.write().unwrap();
switch.unregister(id);
}
}
/// Clear state for connection reset (called between vhost-user sessions).
pub fn clear_state(&self) {
// Clear guest memory
*self.mem.lock().unwrap() = None;
// Reset event_idx
*self.event_idx.lock().unwrap() = false;
// Reset acked features
*self.acked_features.lock().unwrap() = 0;
}
/// Process a frame and determine forwarding.
///
/// Takes raw Ethernet frame bytes (no virtio header).
/// Returns None if not registered or frame is invalid.
pub fn process_frame(&self, frame_data: &[u8]) -> Option<ProcessedFrame> {
let conn_id = self.connection_id()?;
let frame = EthernetFrame::parse(frame_data)?;
let switch = self.switch.read().unwrap();
let decision = switch.forward(conn_id, frame.source_mac(), frame.dest_mac());
Some(ProcessedFrame {
data: frame_data.to_vec(),
decision,
})
}
/// Strip the virtio-net header from frame data.
/// Returns the Ethernet frame without the header.
pub fn strip_virtio_header(data: &[u8]) -> Option<&[u8]> {
if data.len() < VIRTIO_NET_HDR_SIZE {
return None;
}
Some(&data[VIRTIO_NET_HDR_SIZE..])
}
/// Prepend a virtio-net header to frame data.
/// Returns the complete buffer for RX injection.
pub fn prepend_virtio_header(frame: &[u8]) -> Vec<u8> {
let mut result = vec![0u8; VIRTIO_NET_HDR_SIZE + frame.len()];
// Header is all zeros (basic operation, no offloading)
result[VIRTIO_NET_HDR_SIZE..].copy_from_slice(frame);
result
}
/// Process frames from the TX queue.
///
/// Reads frames, strips virtio header, gets forwarding decisions.
/// Returns frames with their forwarding decisions.
pub fn process_tx_queue(&self, vring: &VringRwLock) -> Vec<ProcessedFrame> {
use vm_memory::GuestMemoryLoadGuard;
let mut results = Vec::new();
// Need connection ID and memory to process
if self.connection_id().is_none() {
trace!(vm = %self.name, "process_tx_queue: no connection_id");
return results;
}
let mem_guard = self.mem.lock().unwrap();
let mem: GuestMemoryLoadGuard<GuestMemoryMmap> = match mem_guard.as_ref() {
Some(m) => m.memory(),
None => {
warn!(vm = %self.name, "process_tx_queue: no guest memory set!");
return results;
}
};
// Collect all frames and head indices
let mut frames: Vec<(u16, Vec<u8>)> = Vec::new();
{
let mut vring_state = vring.get_mut();
let queue = vring_state.get_queue_mut();
trace!(
vm = %self.name,
queue_ready = queue.ready(),
queue_size = queue.size(),
next_avail = queue.next_avail(),
next_used = queue.next_used(),
"process_tx_queue: checking queue"
);
while let Some(desc_chain) = queue.pop_descriptor_chain(mem.clone()) {
let head_index = desc_chain.head_index();
let mut raw_data = Vec::new();
// Read all descriptors in the chain
for desc in desc_chain {
let addr = desc.addr();
let len = desc.len() as usize;
let mut buf = vec![0u8; len];
if let Err(e) = mem.read_slice(&mut buf, addr) {
warn!(
vm = %self.name,
head_index,
addr = ?addr,
len,
error = %e,
"process_tx_queue: failed to read descriptor"
);
break;
}
raw_data.extend_from_slice(&buf);
}
trace!(
vm = %self.name,
head_index,
raw_len = raw_data.len(),
"process_tx_queue: popped descriptor"
);
frames.push((head_index, raw_data));
}
}
if frames.is_empty() {
trace!(vm = %self.name, "process_tx_queue: no frames in queue");
}
// Process frames and mark descriptors as used
for (head_index, raw_data) in &frames {
if let Err(e) = vring.add_used(*head_index, 0) {
warn!(
vm = %self.name,
head_index,
error = ?e,
"process_tx_queue: add_used failed"
);
} else {
trace!(
vm = %self.name,
head_index,
"process_tx_queue: add_used ok"
);
}
// Strip header and process frame
if let Some(frame_data) = Self::strip_virtio_header(raw_data) {
if let Some(processed) = self.process_frame(frame_data) {
results.push(processed);
}
}
}
// Re-enable notifications so guest will kick us when it adds more buffers.
// This is critical for EVENT_IDX: after draining the queue, we must tell
// the guest to notify us of new buffers, otherwise it will suppress kicks.
match vring.enable_notification() {
Ok(has_more) => {
if has_more {
trace!(vm = %self.name, "process_tx_queue: enable_notification returned has_more=true");
}
}
Err(e) => {
warn!(vm = %self.name, error = ?e, "process_tx_queue: enable_notification failed");
}
}
// Signal guest that we've processed the queue
if !frames.is_empty() {
// Check if the call eventfd is set
{
let vring_state = vring.get_ref();
let has_call = vring_state.get_call().is_some();
if !has_call {
warn!(
vm = %self.name,
num_frames = frames.len(),
"process_tx_queue: no call eventfd set, cannot notify guest!"
);
} else {
trace!(
vm = %self.name,
num_frames = frames.len(),
"process_tx_queue: call eventfd is set"
);
}
}
match vring.signal_used_queue() {
Ok(()) => trace!(
vm = %self.name,
num_frames = frames.len(),
"process_tx_queue: signal_used_queue ok"
),
Err(e) => warn!(
vm = %self.name,
num_frames = frames.len(),
error = %e,
"process_tx_queue: signal_used_queue FAILED"
),
}
}
// Log final queue state
{
let vring_state = vring.get_ref();
let queue = vring_state.get_queue();
debug!(
vm = %self.name,
frames_processed = results.len(),
next_avail = queue.next_avail(),
next_used = queue.next_used(),
"process_tx_queue complete"
);
}
results
}
/// Inject a frame into the RX queue.
///
/// Prepends virtio header and writes using scatter-gather across descriptors.
/// Returns true if successful, false if queue is full or insufficient space.
///
/// Note: This is a static method that takes the destination VM's memory mapping.
/// This is important when injecting frames into a different VM's RX queue -
/// we must use that VM's memory mapping, not our own.
pub fn inject_rx_frame(vring: &VringRwLock, mem: &GuestMemoryAtomic<GuestMemoryMmap>, frame: &[u8]) -> bool {
use vm_memory::GuestMemoryLoadGuard;
let mem: GuestMemoryLoadGuard<GuestMemoryMmap> = mem.memory();
let data_to_write = Self::prepend_virtio_header(frame);
let total_len = data_to_write.len();
let head_index;
let written;
{
let mut vring_state = vring.get_mut();
let queue = vring_state.get_queue_mut();
trace!(
queue_ready = queue.ready(),
queue_size = queue.size(),
frame_len = frame.len(),
total_len,
"inject_rx_frame: checking queue"
);
let desc_chain = match queue.pop_descriptor_chain(mem.clone()) {
Some(chain) => chain,
None => {
trace!("inject_rx_frame: no descriptor available (queue full)");
return false;
}
};
head_index = desc_chain.head_index();
// First pass: collect writable descriptors and calculate available space
let mut writable_descs = Vec::new();
for desc in desc_chain {
if desc.is_write_only() {
writable_descs.push((desc.addr(), desc.len() as usize));
}
}
let available_space: usize = writable_descs.iter().map(|(_, len)| *len).sum();
trace!(
head_index,
num_writable_descs = writable_descs.len(),
available_space,
total_len,
"inject_rx_frame: descriptor chain"
);
if available_space < total_len {
// Insufficient space - don't write partial data
warn!(
head_index,
available_space,
total_len,
"inject_rx_frame: insufficient space in descriptors"
);
written = 0;
} else {
// Second pass: scatter-gather write across all descriptors
let mut bytes_written = 0;
for (addr, len) in writable_descs {
let remaining = total_len - bytes_written;
if remaining == 0 {
break;
}
let to_write = std::cmp::min(remaining, len);
if let Err(e) = mem.write_slice(&data_to_write[bytes_written..bytes_written + to_write], addr) {
warn!(
head_index,
addr = ?addr,
to_write,
error = %e,
"inject_rx_frame: write_slice failed"
);
break;
}
bytes_written += to_write;
}
written = bytes_written;
}
}
if let Err(e) = vring.add_used(head_index, written as u32) {
warn!(
head_index,
written,
error = ?e,
"inject_rx_frame: add_used failed"
);
}
// Re-enable notifications so guest knows to provide more RX buffers
if let Err(e) = vring.enable_notification() {
warn!(
head_index,
error = ?e,
"inject_rx_frame: enable_notification failed"
);
}
if let Err(e) = vring.signal_used_queue() {
warn!(
head_index,
error = %e,
"inject_rx_frame: signal_used_queue failed"
);
}
let success = written >= total_len;
trace!(
head_index,
written,
total_len,
success,
"inject_rx_frame complete"
);
success
}
}
impl VhostUserBackend for NetBackend {
type Bitmap = ();
type Vring = VringRwLock;
fn num_queues(&self) -> usize {
NUM_QUEUES
}
fn max_queue_size(&self) -> usize {
MAX_QUEUE_SIZE
}
fn features(&self) -> u64 {
let features = VIRTIO_NET_FEATURES
| (1 << VIRTIO_F_VERSION_1)
| (1 << VIRTIO_RING_F_EVENT_IDX)
| VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
trace!(vm = %self.name, features = format!("{:#x}", features), "features requested");
features
}
fn protocol_features(&self) -> VhostUserProtocolFeatures {
let proto = VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::MQ;
trace!(vm = %self.name, protocol_features = ?proto, "protocol_features requested");
proto
}
fn set_event_idx(&self, enabled: bool) {
debug!(vm = %self.name, enabled, "set_event_idx");
*self.event_idx.lock().unwrap() = enabled;
}
fn update_memory(
&self,
mem: GuestMemoryAtomic<GuestMemoryMmap>,
) -> std::io::Result<()> {
debug!(vm = %self.name, "update_memory called");
*self.mem.lock().unwrap() = Some(mem);
Ok(())
}
fn handle_event(
&self,
device_event: u16,
_evset: EventSet,
vrings: &[VringRwLock],
_thread_id: usize,
) -> std::io::Result<()> {
// Note: read_kick() is already called by VringEpollHandler before invoking this method.
// We do not call it again to avoid blocking on a drained eventfd.
trace!(
vm = %self.name,
device_event,
queue_name = if device_event == RX_QUEUE { "RX" } else if device_event == TX_QUEUE { "TX" } else { "?" },
"handle_event called"
);
// Validate event index
if (device_event as usize) >= vrings.len() {
debug!(device_event, vrings_len = vrings.len(), "ignoring out-of-range device event");
return Ok(());
}
// Get our connection ID
let conn_id = match self.connection_id() {
Some(id) => id,
None => {
debug!(vm = %self.name, "handle_event: no connection_id, ignoring");
return Ok(());
}
};
// Register our RX vring in the shared registry (if not already done)
// This allows other backends to inject frames into our RX queue
if vrings.len() > RX_QUEUE as usize {
let mut registry = self.vring_registry.write().unwrap();
if !registry.contains_key(&conn_id) {
// Clone the memory for registry storage
let mem_guard = self.mem.lock().unwrap();
if let Some(mem) = mem_guard.clone() {
debug!(
vm = %self.name,
conn_id = ?conn_id,
"registering RX vring in registry"
);
registry.insert(conn_id, (vrings[RX_QUEUE as usize].clone(), mem));
} else {
warn!(
vm = %self.name,
conn_id = ?conn_id,
"cannot register RX vring: no guest memory set"
);
}
}
}
// Only process TX queue kicks
if device_event != TX_QUEUE {
trace!(vm = %self.name, device_event, "ignoring non-TX queue event");
return Ok(());
}
// Process frames from our TX queue
let tx_vring = &vrings[TX_QUEUE as usize];
let processed_frames = self.process_tx_queue(tx_vring);
if processed_frames.is_empty() {
trace!(vm = %self.name, "handle_event: no frames processed from TX queue");
return Ok(());
}
// Route each frame to its destination(s)
let registry = self.vring_registry.read().unwrap();
let mut routed = 0;
let mut dropped = 0;
trace!(
vm = %self.name,
num_frames = processed_frames.len(),
registry_size = registry.len(),
"routing frames"
);
for processed in processed_frames {
match &processed.decision {
ForwardDecision::Unicast(dest_id) => {
if let Some((rx_vring, dest_mem)) = registry.get(dest_id) {
if Self::inject_rx_frame(rx_vring, dest_mem, &processed.data) {
routed += 1;
} else {
dropped += 1;
debug!(
src = %self.name,
dest_id = ?dest_id,
"RX queue full, dropping frame"
);
}
} else {
dropped += 1;
debug!(
src = %self.name,
dest_id = ?dest_id,
"destination not in registry, dropping frame"
);
}
}
ForwardDecision::Multicast(dest_ids) => {
for dest_id in dest_ids {
if let Some((rx_vring, dest_mem)) = registry.get(dest_id) {
if Self::inject_rx_frame(rx_vring, dest_mem, &processed.data) {
routed += 1;
} else {
dropped += 1;
debug!(
src = %self.name,
dest_id = ?dest_id,
"RX queue full, dropping frame"
);
}
} else {
dropped += 1;
debug!(
src = %self.name,
dest_id = ?dest_id,
"destination not in registry (multicast), dropping frame"
);
}
}
}
ForwardDecision::Drop(reason) => {
dropped += 1;
debug!(src = %self.name, ?reason, "Dropping frame per switch decision");
}
}
}
trace!(vm = %self.name, routed, dropped, "handle_event complete");
Ok(())
}
fn acked_features(&self, features: u64) {
debug!(vm = %self.name, features = format!("{:#x}", features), "acked_features");
*self.acked_features.lock().unwrap() = features;
}
fn get_config(&self, offset: u32, size: u32) -> Vec<u8> {
// Virtio net config: MAC (6 bytes) + status (2) + max_virtqueue_pairs (2)
let mut config = [0u8; 10];
config[0..6].copy_from_slice(&self.mac.bytes());
config[6] = 1; // VIRTIO_NET_S_LINK_UP
config[8] = 1; // max_virtqueue_pairs = 1
let config = config;
let offset = offset as usize;
let size = size as usize;
if offset < config.len() {
let end = std::cmp::min(offset + size, config.len());
let mut result = config[offset..end].to_vec();
// Pad with zeros if requested more than available
result.resize(size, 0);
result
} else {
vec![0u8; size]
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::switch::Switch;
fn make_switch() -> Arc<RwLock<Switch>> {
Arc::new(RwLock::new(Switch::new()))
}
fn make_vring_registry() -> VringRegistry {
Arc::new(RwLock::new(HashMap::new()))
}
fn make_backend() -> NetBackend {
NetBackend::new(
"test".to_string(),
VmRole::Client,
Mac::from_bytes([1, 2, 3, 4, 5, 6]),
make_switch(),
make_vring_registry(),
)
}
#[test]
fn num_queues_returns_two() {
let backend = make_backend();
assert_eq!(backend.num_queues(), 2);
}
#[test]
fn max_queue_size_returns_max() {
let backend = make_backend();
assert_eq!(backend.max_queue_size(), MAX_QUEUE_SIZE);
}
#[test]
fn register_assigns_connection_id() {
let backend = make_backend();
assert!(backend.connection_id().is_none());
let id = backend.register();
assert!(id.is_some());
assert_eq!(backend.connection_id(), id);
}
#[test]
fn unregister_clears_connection_id() {
let backend = make_backend();
backend.register();
assert!(backend.connection_id().is_some());
backend.unregister();
assert!(backend.connection_id().is_none());
}
#[test]
fn duplicate_router_returns_none() {
let switch = make_switch();
let registry = make_vring_registry();
let router1 = NetBackend::new(
"router1".to_string(),
VmRole::Router,
Mac::from_bytes([1, 0, 0, 0, 0, 1]),
Arc::clone(&switch),
Arc::clone(&registry),
);
let router2 = NetBackend::new(
"router2".to_string(),
VmRole::Router,
Mac::from_bytes([2, 0, 0, 0, 0, 2]),
Arc::clone(&switch),
Arc::clone(&registry),
);
assert!(router1.register().is_some());
assert!(router2.register().is_none());
}
#[test]
fn get_config_returns_mac_at_offset_zero() {
let switch = make_switch();
let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
let backend = NetBackend::new("test".to_string(), VmRole::Client, mac, switch, make_vring_registry());
let config = backend.get_config(0, 6);
assert_eq!(config, vec![0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
}
#[test]
fn get_config_partial_read() {
let switch = make_switch();
let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
let backend = NetBackend::new("test".to_string(), VmRole::Client, mac, switch, make_vring_registry());
// Read bytes 2-4 of MAC
let config = backend.get_config(2, 3);
assert_eq!(config, vec![0xcc, 0xdd, 0xee]);
}
fn make_frame(dest: [u8; 6], src: [u8; 6]) -> Vec<u8> {
let mut frame = vec![0u8; 14];
frame[0..6].copy_from_slice(&dest);
frame[6..12].copy_from_slice(&src);
frame
}
#[test]
fn process_frame_returns_none_when_unregistered() {
let backend = make_backend();
// Don't register
let frame = make_frame([0xff; 6], [1, 2, 3, 4, 5, 6]);
assert!(backend.process_frame(&frame).is_none());
}
#[test]
fn process_frame_returns_decision_when_registered() {
let switch = make_switch();
let registry = make_vring_registry();
// Register router first (clients need a router to forward to)
let router = NetBackend::new(
"router".to_string(),
VmRole::Router,
Mac::from_bytes([0xaa, 0, 0, 0, 0, 1]),
Arc::clone(&switch),
Arc::clone(&registry),
);
router.register();
// Register client
let client = NetBackend::new(
"client".to_string(),
VmRole::Client,
Mac::from_bytes([1, 2, 3, 4, 5, 6]),
Arc::clone(&switch),
Arc::clone(&registry),
);
client.register();
// Client sends to router
let frame = make_frame([0xaa, 0, 0, 0, 0, 1], [1, 2, 3, 4, 5, 6]);
let result = client.process_frame(&frame);
assert!(result.is_some());
let processed = result.unwrap();
assert_eq!(processed.data, frame);
assert!(matches!(processed.decision, ForwardDecision::Unicast(_)));
}
#[test]
fn process_frame_returns_none_for_invalid_frame() {
let switch = make_switch();
let backend = NetBackend::new(
"test".to_string(),
VmRole::Client,
Mac::from_bytes([1, 2, 3, 4, 5, 6]),
switch,
make_vring_registry(),
);
backend.register();
// Frame too small (< 14 bytes)
let frame = vec![0u8; 10];
assert!(backend.process_frame(&frame).is_none());
}
#[test]
fn strip_virtio_header_returns_frame() {
// 12-byte header + 14-byte Ethernet frame
let mut data = vec![0u8; 26];
data[12..18].copy_from_slice(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]); // dest MAC
let frame = NetBackend::strip_virtio_header(&data).unwrap();
assert_eq!(frame.len(), 14);
assert_eq!(&frame[0..6], &[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
}
#[test]
fn strip_virtio_header_returns_none_if_too_small() {
// Less than 12 bytes
let data = vec![0u8; 10];
assert!(NetBackend::strip_virtio_header(&data).is_none());
}
#[test]
fn strip_virtio_header_returns_empty_slice_if_exact() {
// Exactly 12 bytes (header only, no frame)
let data = vec![0u8; 12];
let frame = NetBackend::strip_virtio_header(&data).unwrap();
assert!(frame.is_empty());
}
#[test]
fn prepend_virtio_header_adds_12_bytes() {
let frame = vec![0xaa, 0xbb, 0xcc, 0xdd];
let result = NetBackend::prepend_virtio_header(&frame);
assert_eq!(result.len(), 12 + 4);
assert_eq!(&result[0..12], &[0u8; 12]); // Header is zeros
assert_eq!(&result[12..], &[0xaa, 0xbb, 0xcc, 0xdd]);
}
#[test]
fn prepend_virtio_header_empty_frame() {
let frame: Vec<u8> = vec![];
let result = NetBackend::prepend_virtio_header(&frame);
assert_eq!(result.len(), 12);
assert_eq!(&result[..], &[0u8; 12]);
}
}
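The header tests above pin down the backend's 12-byte virtio-net header convention (with `VIRTIO_F_VERSION_1` negotiated, `struct virtio_net_hdr` always includes `num_buffers`, giving 12 bytes). A standalone sketch of the strip/prepend semantics; the function and constant names here are illustrative, not the crate's:

```rust
// Sketch of the 12-byte virtio-net header handling mirrored by
// strip_virtio_header / prepend_virtio_header. Names are illustrative.
const VNET_HDR_LEN: usize = 12;

fn strip_header(data: &[u8]) -> Option<&[u8]> {
    // Buffers shorter than the header are malformed (None); a buffer of
    // exactly 12 bytes yields an empty frame, matching the tests above.
    data.get(VNET_HDR_LEN..)
}

fn prepend_header(frame: &[u8]) -> Vec<u8> {
    // An all-zero header is valid for a plain frame: no checksum offload,
    // no GSO, num_buffers left for the device to fill in.
    let mut out = vec![0u8; VNET_HDR_LEN + frame.len()];
    out[VNET_HDR_LEN..].copy_from_slice(frame);
    out
}

fn main() {
    let frame = [0xaa, 0xbb, 0xcc, 0xdd];
    let wire = prepend_header(&frame);
    assert_eq!(wire.len(), 16);
    assert_eq!(strip_header(&wire), Some(&frame[..]));
    assert!(strip_header(&wire[..8]).is_none()); // shorter than header
}
```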


@@ -0,0 +1,370 @@
//! Packet forwarding logic for child processes.
use std::collections::HashMap;
use std::os::fd::{AsRawFd, RawFd};
use tracing::{debug, info, trace};
use crate::frame::validate_source_mac;
use crate::mac::Mac;
use crate::ring::{Consumer, Producer};
/// Ingress buffer from a peer (they produce, we consume).
struct IngressBuffer {
peer_mac: [u8; 6],
consumer: Consumer,
}
/// Egress buffer to a peer (we produce, they consume).
struct EgressBuffer {
peer_mac: [u8; 6],
producer: Producer,
/// If true, this buffer accepts broadcast/multicast traffic.
broadcast: bool,
}
/// Manages packet forwarding for a child process.
pub struct PacketForwarder {
/// This child's MAC address.
our_mac: Mac,
/// Ingress buffers FROM peers (they produce, we consume). Keyed by peer name.
ingress: HashMap<String, IngressBuffer>,
/// Egress buffers TO peers (we produce, they consume). Keyed by peer name.
egress: HashMap<String, EgressBuffer>,
}
impl PacketForwarder {
/// Create a new packet forwarder.
pub fn new(our_mac: Mac) -> Self {
Self {
our_mac,
ingress: HashMap::new(),
egress: HashMap::new(),
}
}
/// Add an ingress buffer from a peer (we consume packets they produce).
pub fn add_ingress(&mut self, peer_name: String, peer_mac: [u8; 6], consumer: Consumer) {
info!(peer = %peer_name, "added ingress buffer from peer");
self.ingress.insert(
peer_name,
IngressBuffer { peer_mac, consumer },
);
}
/// Add an egress buffer to a peer (we produce packets they consume).
pub fn add_egress(
&mut self,
peer_name: String,
peer_mac: [u8; 6],
producer: Producer,
broadcast: bool,
) {
info!(peer = %peer_name, broadcast, "added egress buffer to peer");
self.egress.insert(
peer_name,
EgressBuffer {
peer_mac,
producer,
broadcast,
},
);
}
/// Remove all buffers for a peer.
pub fn remove_peer(&mut self, peer_name: &str) {
if self.ingress.remove(peer_name).is_some() {
info!(peer = %peer_name, "removed ingress buffer");
}
if self.egress.remove(peer_name).is_some() {
info!(peer = %peer_name, "removed egress buffer");
}
}
/// Get eventfds for all ingress consumers (for polling).
pub fn ingress_eventfds(&self) -> Vec<(RawFd, [u8; 6])> {
self.ingress
.values()
.map(|buf| (buf.consumer.eventfd().as_raw_fd(), buf.peer_mac))
.collect()
}
/// Forward a TX frame to peers based on destination MAC.
///
/// - Broadcast/multicast: sent to all egress buffers with broadcast=true
/// - Unicast: sent to the egress buffer matching the destination MAC
pub fn forward_tx(&self, frame: &[u8]) -> bool {
// Validate source MAC
if !validate_source_mac(frame, self.our_mac) {
debug!(
reason = "source MAC rejected",
our_mac = %self.our_mac,
frame_src = %Mac::from_bytes(frame[6..12].try_into().unwrap()),
size = frame.len(),
"TX: dropped"
);
return false;
}
let dest_mac: [u8; 6] = frame[0..6].try_into().unwrap();
let dest = Mac::from_bytes(dest_mac);
// Broadcast/multicast: send to all egress buffers with broadcast=true
if dest.is_broadcast() || dest.is_multicast() {
let mut sent = false;
for (peer_name, egress) in &self.egress {
if egress.broadcast {
if egress.producer.push(frame) {
trace!(
to = %peer_name,
mac = %Mac::from_bytes(egress.peer_mac),
size = frame.len(),
"TX: broadcast"
);
sent = true;
} else {
debug!(reason = "buffer full", to = %peer_name, size = frame.len(), "TX: broadcast dropped");
}
}
}
if !sent {
debug!(reason = "no broadcast peers", size = frame.len(), "TX: dropped");
}
return sent;
}
// Unicast: find egress buffer by destination MAC
for (peer_name, egress) in &self.egress {
if egress.peer_mac == dest_mac {
if egress.producer.push(frame) {
trace!(
to = %peer_name,
mac = %Mac::from_bytes(egress.peer_mac),
size = frame.len(),
"TX: pushed to egress buffer"
);
return true;
} else {
debug!(reason = "buffer full", to = %peer_name, size = frame.len(), "TX: dropped");
return false;
}
}
}
// Unknown destination
debug!(
reason = "unknown destination",
dest_mac = %dest,
size = frame.len(),
"TX: dropped"
);
false
}
/// Poll all ingress buffers and return received frames.
///
/// Validates source MAC matches the expected peer for each buffer.
pub fn poll_ingress(&self) -> Vec<Vec<u8>> {
let mut frames = Vec::new();
for (peer_name, ingress) in &self.ingress {
// Drain the eventfd
ingress.consumer.drain_eventfd();
// Pop all available frames
while let Some(frame) = ingress.consumer.pop() {
// Validate source MAC matches expected peer
if !validate_source_mac(&frame, Mac::from_bytes(ingress.peer_mac)) {
debug!(
reason = "source MAC mismatch",
from = %peer_name,
expected = %Mac::from_bytes(ingress.peer_mac),
actual = %Mac::from_bytes(frame[6..12].try_into().unwrap_or([0; 6])),
size = frame.len(),
"RX: dropped"
);
continue;
}
trace!(
from = %peer_name,
mac = %Mac::from_bytes(ingress.peer_mac),
size = frame.len(),
"RX: read from ingress buffer"
);
frames.push(frame);
}
}
frames
}
/// Get number of configured ingress peers.
pub fn ingress_count(&self) -> usize {
self.ingress.len()
}
/// Get number of configured egress peers.
pub fn egress_count(&self) -> usize {
self.egress.len()
}
/// Get our MAC address.
pub fn our_mac(&self) -> Mac {
self.our_mac
}
}
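`forward_tx` branches on `dest.is_broadcast()`/`dest.is_multicast()`; the underlying bit tests are standard Ethernet addressing. A minimal sketch (the free functions here are illustrative stand-ins for the crate's `Mac` methods):

```rust
// Ethernet destination classification as used by the forwarding path.
fn is_broadcast(mac: &[u8; 6]) -> bool {
    *mac == [0xff; 6]
}

fn is_multicast(mac: &[u8; 6]) -> bool {
    // I/G bit: least-significant bit of the first octet. Broadcast is a
    // special case of multicast under this test.
    mac[0] & 0x01 != 0
}

fn main() {
    assert!(is_broadcast(&[0xff; 6]));
    assert!(is_multicast(&[0xff; 6])); // broadcast is also multicast
    assert!(is_multicast(&[0x01, 0x00, 0x5e, 0x00, 0x00, 0x01])); // IPv4 mcast
    assert!(!is_multicast(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff])); // unicast
}
```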
#[cfg(test)]
mod tests {
use super::*;
use std::os::fd::{FromRawFd, OwnedFd};
fn make_frame(dest: [u8; 6], src: [u8; 6]) -> Vec<u8> {
let mut frame = vec![0u8; 14];
frame[0..6].copy_from_slice(&dest);
frame[6..12].copy_from_slice(&src);
frame
}
#[test]
fn forward_tx_validates_source_mac() {
let our_mac = Mac::from_bytes([1, 2, 3, 4, 5, 6]);
let forwarder = PacketForwarder::new(our_mac);
// Frame with wrong source MAC - should be dropped
let frame = make_frame([0xff; 6], [9, 9, 9, 9, 9, 9]);
assert!(!forwarder.forward_tx(&frame));
}
#[test]
fn forward_tx_drops_when_no_egress() {
let our_mac = Mac::from_bytes([1, 2, 3, 4, 5, 6]);
let forwarder = PacketForwarder::new(our_mac);
// Frame with correct source MAC but no egress peers
let frame = make_frame([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff], our_mac.bytes());
assert!(!forwarder.forward_tx(&frame));
}
#[test]
fn forward_tx_unicast_to_matching_peer() {
let our_mac = Mac::from_bytes([1, 2, 3, 4, 5, 6]);
let mut forwarder = PacketForwarder::new(our_mac);
// Add egress to peer with specific MAC
let peer_mac = [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff];
let consumer = Consumer::new().expect("consumer");
let producer = Producer::from_fds(
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer.memfd().as_raw_fd())) },
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer.eventfd().as_raw_fd())) },
)
.expect("producer");
forwarder.add_egress("router".to_string(), peer_mac, producer, false);
// Unicast frame to that MAC should succeed
let frame = make_frame(peer_mac, our_mac.bytes());
assert!(forwarder.forward_tx(&frame));
// Consumer should receive the frame
consumer.drain_eventfd();
let received = consumer.pop().expect("should have frame");
assert_eq!(received, frame);
}
#[test]
fn forward_tx_broadcast_to_broadcast_peers() {
let our_mac = Mac::from_bytes([1, 2, 3, 4, 5, 6]);
let mut forwarder = PacketForwarder::new(our_mac);
// Add broadcast-enabled egress
let consumer1 = Consumer::new().expect("consumer1");
let producer1 = Producer::from_fds(
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer1.memfd().as_raw_fd())) },
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer1.eventfd().as_raw_fd())) },
)
.expect("producer1");
forwarder.add_egress("router".to_string(), [0x11; 6], producer1, true);
// Add non-broadcast egress
let consumer2 = Consumer::new().expect("consumer2");
let producer2 = Producer::from_fds(
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer2.memfd().as_raw_fd())) },
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer2.eventfd().as_raw_fd())) },
)
.expect("producer2");
forwarder.add_egress("client_a".to_string(), [0x22; 6], producer2, false);
// Broadcast frame
let frame = make_frame([0xff; 6], our_mac.bytes());
assert!(forwarder.forward_tx(&frame));
// Only broadcast-enabled peer should receive
consumer1.drain_eventfd();
assert!(consumer1.pop().is_some());
consumer2.drain_eventfd();
assert!(consumer2.pop().is_none());
}
#[test]
fn poll_ingress_validates_source_mac() {
let our_mac = Mac::from_bytes([1, 2, 3, 4, 5, 6]);
let mut forwarder = PacketForwarder::new(our_mac);
// Create producer/consumer pair - producer simulates peer sending to us
let producer = Producer::new().expect("producer");
let consumer = Consumer::from_fds(
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(producer.memfd().as_raw_fd())) },
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(producer.eventfd().as_raw_fd())) },
)
.expect("consumer");
let peer_mac = [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff];
forwarder.add_ingress("router".to_string(), peer_mac, consumer);
// Push frame with correct source MAC
let good_frame = make_frame(our_mac.bytes(), peer_mac);
producer.push(&good_frame);
// Push frame with wrong source MAC
let bad_frame = make_frame(our_mac.bytes(), [0x99; 6]);
producer.push(&bad_frame);
// Only good frame should be returned
let frames = forwarder.poll_ingress();
assert_eq!(frames.len(), 1);
assert_eq!(frames[0], good_frame);
}
#[test]
fn remove_peer_cleans_up_buffers() {
let our_mac = Mac::from_bytes([1, 2, 3, 4, 5, 6]);
let mut forwarder = PacketForwarder::new(our_mac);
// Add ingress
let producer = Producer::new().expect("producer");
let consumer = Consumer::from_fds(
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(producer.memfd().as_raw_fd())) },
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(producer.eventfd().as_raw_fd())) },
)
.expect("consumer");
forwarder.add_ingress("router".to_string(), [0x11; 6], consumer);
// Add egress
let consumer2 = Consumer::new().expect("consumer2");
let producer2 = Producer::from_fds(
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer2.memfd().as_raw_fd())) },
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer2.eventfd().as_raw_fd())) },
)
.expect("producer2");
forwarder.add_egress("router".to_string(), [0x11; 6], producer2, true);
assert_eq!(forwarder.ingress_count(), 1);
assert_eq!(forwarder.egress_count(), 1);
forwarder.remove_peer("router");
assert_eq!(forwarder.ingress_count(), 0);
assert_eq!(forwarder.egress_count(), 0);
}
}
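The `Producer`/`Consumer` pair exercised in these tests wraps the lock-free SPSC rings described in the commit message (shared memfd, eventfd signaling on the empty-to-non-empty transition). The core index discipline can be sketched with std atomics alone; the slot count and in-memory layout here are illustrative, not the crate's actual 64-slot shared-memory format:

```rust
use std::cell::UnsafeCell;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::thread;

const SLOTS: usize = 8; // illustrative; the real rings use 64 slots

// One producer thread, one consumer thread; head/tail are the only
// shared mutable state, so the datapath needs no locks.
struct SpscRing {
    head: AtomicUsize, // consumer position (unmasked, wraps freely)
    tail: AtomicUsize, // producer position
    slots: [UnsafeCell<Option<Vec<u8>>>; SLOTS],
}

// SAFETY: each slot is written only by the producer before it advances
// `tail`, and read only by the consumer before it advances `head`.
unsafe impl Sync for SpscRing {}

impl SpscRing {
    fn new() -> Self {
        Self {
            head: AtomicUsize::new(0),
            tail: AtomicUsize::new(0),
            slots: std::array::from_fn(|_| UnsafeCell::new(None)),
        }
    }

    fn push(&self, frame: Vec<u8>) -> bool {
        let tail = self.tail.load(Ordering::Relaxed);
        let head = self.head.load(Ordering::Acquire);
        if tail.wrapping_sub(head) == SLOTS {
            return false; // ring full: drop, as forward_tx does
        }
        unsafe { *self.slots[tail % SLOTS].get() = Some(frame) };
        // Release pairs with the consumer's Acquire load of `tail`.
        self.tail.store(tail.wrapping_add(1), Ordering::Release);
        true // the real producer signals the eventfd here if ring was empty
    }

    fn pop(&self) -> Option<Vec<u8>> {
        let head = self.head.load(Ordering::Relaxed);
        let tail = self.tail.load(Ordering::Acquire);
        if head == tail {
            return None; // empty
        }
        let frame = unsafe { (*self.slots[head % SLOTS].get()).take() };
        self.head.store(head.wrapping_add(1), Ordering::Release);
        frame
    }
}

fn main() {
    let ring = Arc::new(SpscRing::new());
    let prod = Arc::clone(&ring);
    let t = thread::spawn(move || {
        for i in 0..100u8 {
            while !prod.push(vec![i]) {} // spin when full
        }
    });
    let mut got = 0u32;
    while got < 100 {
        if let Some(f) = ring.pop() {
            assert_eq!(f[0] as u32, got); // FIFO order preserved
            got += 1;
        }
    }
    t.join().unwrap();
}
```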


@@ -0,0 +1,14 @@
//! Child process entry point for VM backends.
//!
//! Each VM runs in its own forked child process, communicating with
//! the main process via a control channel.
pub mod forwarder;
pub mod poll;
pub mod process;
pub mod vhost;
pub use forwarder::PacketForwarder;
pub use poll::{poll_events, PollResult};
pub use process::run_child_process;
pub use vhost::ChildVhostBackend;

vm-switch/src/child/poll.rs

@@ -0,0 +1,115 @@
//! Event polling for child processes.
use std::os::fd::{BorrowedFd, RawFd};
use nix::poll::{poll, PollFd, PollFlags, PollTimeout};
use nix::Error as NixError;
/// Result of polling for events.
#[derive(Debug)]
pub enum PollResult {
/// Control channel has data.
Control,
/// One or more ingress buffers have data.
Ingress(Vec<[u8; 6]>),
/// Second FD slot has an event (POLLIN/POLLHUP/POLLERR).
/// Used for daemon exit pipe detection.
TxKick,
/// Timeout expired with no events.
Timeout,
/// Error occurred.
Error(NixError),
}
/// Poll for events on control channel, TX kick, and ingress eventfds.
///
/// # Safety
/// All raw file descriptors passed in must be valid and open for the duration of this call.
pub fn poll_events(
control_fd: RawFd,
tx_kick_fd: Option<RawFd>,
ingress_fds: &[(RawFd, [u8; 6])],
timeout_ms: i32,
) -> PollResult {
// SAFETY: caller guarantees fds are valid for the duration of this call
let control_borrowed = unsafe { BorrowedFd::borrow_raw(control_fd) };
let mut pollfds: Vec<PollFd> = Vec::with_capacity(2 + ingress_fds.len());
// Control channel is first
pollfds.push(PollFd::new(control_borrowed, PollFlags::POLLIN));
// TX kick eventfd is second (if present)
let tx_index = if let Some(fd) = tx_kick_fd {
let tx_borrowed = unsafe { BorrowedFd::borrow_raw(fd) };
pollfds.push(PollFd::new(tx_borrowed, PollFlags::POLLIN));
Some(1)
} else {
None
};
// Ingress eventfds follow
let ingress_start = pollfds.len();
for (fd, _mac) in ingress_fds {
let ingress_borrowed = unsafe { BorrowedFd::borrow_raw(*fd) };
pollfds.push(PollFd::new(ingress_borrowed, PollFlags::POLLIN));
}
let timeout = if timeout_ms < 0 {
PollTimeout::NONE
} else if timeout_ms > u16::MAX as i32 {
PollTimeout::MAX
} else {
PollTimeout::from(timeout_ms as u16)
};
match poll(&mut pollfds, timeout) {
Ok(0) => PollResult::Timeout,
Ok(_) => {
// Check control channel first (priority)
if let Some(revents) = pollfds[0].revents() {
if revents.contains(PollFlags::POLLIN)
|| revents.contains(PollFlags::POLLHUP)
|| revents.contains(PollFlags::POLLERR)
{
return PollResult::Control;
}
}
// Check TX kick / daemon exit pipe
if let Some(idx) = tx_index {
if let Some(revents) = pollfds[idx].revents() {
if revents.contains(PollFlags::POLLIN)
|| revents.contains(PollFlags::POLLHUP)
|| revents.contains(PollFlags::POLLERR)
{
return PollResult::TxKick;
}
}
}
// Check ingress eventfds
let mut ready_macs = Vec::new();
for (i, (_, mac)) in ingress_fds.iter().enumerate() {
if let Some(revents) = pollfds[ingress_start + i].revents() {
if revents.contains(PollFlags::POLLIN) {
ready_macs.push(*mac);
}
}
}
if ready_macs.is_empty() {
PollResult::Timeout
} else {
PollResult::Ingress(ready_macs)
}
}
Err(e) => {
if e == NixError::EINTR {
PollResult::Timeout
} else {
PollResult::Error(e)
}
}
}
}


@@ -0,0 +1,239 @@
//! Child process main loop.
use std::os::fd::{AsRawFd, OwnedFd, RawFd};
use std::os::unix::net::UnixListener;
use std::path::Path;
use std::sync::{Arc, Mutex};
use std::thread;
use nix::unistd::pipe;
use tracing::{debug, error, info, warn};
use vhost_user_backend::VhostUserDaemon;
use vm_memory::{GuestMemoryAtomic, GuestMemoryMmap};
use crate::control::{ChildToMain, ControlChannel, ControlError, MainToChild};
use crate::mac::Mac;
use crate::ring::{Consumer, Producer};
use crate::seccomp::{apply_child_seccomp, SeccompMode};
use super::forwarder::PacketForwarder;
use super::poll::{poll_events, PollResult};
use super::vhost::ChildVhostBackend;
/// Run the child process.
///
/// This is the entry point after fork(). Does not return.
pub fn run_child_process(
vm_name: &str,
mac: Mac,
control_fd: OwnedFd,
socket_path: &Path,
seccomp_mode: SeccompMode,
) -> ! {
// Set process name for log prefix before any logging
crate::args::set_process_name(format!("worker-{}", vm_name));
info!(vm = %vm_name, mac = %mac, socket = ?socket_path, "child starting");
// Reconstruct control channel from owned fd
let control = ControlChannel::from_fd(control_fd);
// Send Ready to main
let msg = ChildToMain::Ready;
if let Err(e) = control.send(&msg) {
error!(vm = %vm_name, error = %e, "failed to send Ready");
std::process::exit(1)
}
debug!("control: worker-{} -> main Ready", vm_name);
// Create packet forwarder
let forwarder = Arc::new(Mutex::new(PacketForwarder::new(mac)));
// Create vhost backend
let backend = ChildVhostBackend::new(vm_name.to_string(), mac);
// Set TX callback
let fwd = Arc::clone(&forwarder);
backend.set_tx_callback(Box::new(move |frame| {
fwd.lock().unwrap().forward_tx(frame);
}));
// Create vhost socket
if socket_path.exists() {
let _ = std::fs::remove_file(socket_path);
}
let listener = match UnixListener::bind(socket_path) {
Ok(l) => l,
Err(e) => {
error!(vm = %vm_name, error = %e, "failed to bind socket");
std::process::exit(1)
}
};
let _ = listener.set_nonblocking(true);
// Start vhost daemon thread
let mem = GuestMemoryAtomic::new(GuestMemoryMmap::<()>::new());
let mut daemon = match VhostUserDaemon::new(vm_name.to_string(), backend.clone(), mem) {
Ok(d) => d,
Err(e) => {
error!(vm = %vm_name, error = %e, "failed to create daemon");
std::process::exit(1)
}
};
// Create pipe to detect daemon thread exit. The write end is moved into
// the daemon thread; when the thread exits for any reason, the write end
// is dropped, causing POLLHUP on the read end.
let (pipe_rd, pipe_wr) = match pipe() {
Ok((rd, wr)) => (rd, wr),
Err(e) => {
error!(vm = %vm_name, error = %e, "failed to create pipe");
std::process::exit(1)
}
};
let vhost_listener = vhost::vhost_user::Listener::from(listener);
let name = vm_name.to_string();
thread::spawn(move || {
let _pipe_wr = pipe_wr; // dropped on thread exit → POLLHUP on read end
let mut l = vhost_listener;
if let Err(e) = daemon.start(&mut l) {
warn!(vm = %name, error = %e, "daemon start failed");
return;
}
if let Err(e) = daemon.wait() {
debug!(vm = %name, error = %e, "daemon wait returned error");
}
});
// Apply seccomp filter now that setup is complete
// (socket created, thread spawned, signals configured)
if let Err(e) = apply_child_seccomp(seccomp_mode) {
error!(vm = %vm_name, error = %e, "failed to apply seccomp");
std::process::exit(1);
}
if seccomp_mode != SeccompMode::Disabled {
debug!(vm = %vm_name, mode = ?seccomp_mode, "seccomp filter applied");
}
// Main event loop
let daemon_exit_fd = pipe_rd.as_raw_fd();
match event_loop(vm_name, control, forwarder, backend, daemon_exit_fd) {
Ok(()) => {
info!(vm = %vm_name, "exiting normally");
std::process::exit(0)
}
Err(e) => {
error!(vm = %vm_name, error = %e, "exiting with error");
std::process::exit(1)
}
}
}
fn event_loop(
vm_name: &str,
control: ControlChannel,
forwarder: Arc<Mutex<PacketForwarder>>,
backend: Arc<ChildVhostBackend>,
daemon_exit_fd: RawFd,
) -> Result<(), ControlError> {
let control_fd = control.as_raw_fd();
loop {
let ingress_fds = forwarder.lock().unwrap().ingress_eventfds();
match poll_events(control_fd, Some(daemon_exit_fd), &ingress_fds, 100) {
PollResult::TxKick => {
// Daemon thread exited (pipe write end closed → POLLHUP)
info!(vm = %vm_name, "vhost daemon exited, shutting down");
return Ok(());
}
PollResult::Control => {
let (msg, fds) = match control.recv_with_fds_typed() {
Ok(r) => r,
Err(ControlError::Closed) => {
debug!(vm = %vm_name, "control closed");
return Ok(());
}
Err(e) => return Err(e),
};
match msg {
MainToChild::GetBuffer { peer_name, peer_mac } => {
debug!(
"control: main -> worker-{} GetBuffer({}, {})",
vm_name, peer_name, Mac::from_bytes(peer_mac)
);
// Create ingress buffer (we are Consumer)
match Consumer::new() {
Ok(consumer) => {
let response_fds = [
consumer.memfd().as_raw_fd(),
consumer.eventfd().as_raw_fd(),
];
let response = ChildToMain::BufferReady {
peer_name: peer_name.clone(),
};
if let Err(e) = control.send_with_fds_typed(&response, &response_fds) {
warn!(vm = %vm_name, error = %e, "failed to send BufferReady");
} else {
debug!(
"control: worker-{} -> main BufferReady({})",
vm_name, peer_name
);
forwarder.lock().unwrap().add_ingress(peer_name, peer_mac, consumer);
}
}
Err(e) => warn!(vm = %vm_name, error = %e, "failed to create ingress buffer"),
}
}
MainToChild::PutBuffer { peer_name, peer_mac, broadcast } => {
debug!(
"control: main -> worker-{} PutBuffer({}, {}, broadcast={})",
vm_name, peer_name, Mac::from_bytes(peer_mac), broadcast
);
if fds.len() == 2 {
let mut fds = fds.into_iter();
match Producer::from_fds(fds.next().unwrap(), fds.next().unwrap()) {
Ok(producer) => {
forwarder.lock().unwrap().add_egress(
peer_name,
peer_mac,
producer,
broadcast,
);
}
Err(e) => warn!(vm = %vm_name, error = %e, "failed to map egress buffer"),
}
} else {
warn!(vm = %vm_name, "PutBuffer with wrong number of FDs: {}", fds.len());
}
}
MainToChild::RemovePeer { peer_name } => {
debug!("control: main -> worker-{} RemovePeer({})", vm_name, peer_name);
forwarder.lock().unwrap().remove_peer(&peer_name);
}
MainToChild::Ping => {
control.send(&ChildToMain::Pong)?;
}
}
}
PollResult::Ingress(_) | PollResult::Timeout => {
let frames = forwarder.lock().unwrap().poll_ingress();
for frame in frames {
if !backend.inject_rx_frame(&frame) {
debug!(vm = %vm_name, "RX inject failed (queue full or not ready)");
}
}
}
PollResult::Error(e) => {
warn!(vm = %vm_name, error = ?e, "poll error");
}
}
}
}


@ -0,0 +1,283 @@
//! Vhost-user backend for child processes.
use std::sync::{Arc, Mutex};
use tracing::{debug, warn};
use vhost::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures};
use vhost_user_backend::{VhostUserBackend, VringRwLock, VringT};
use virtio_bindings::virtio_config::VIRTIO_F_VERSION_1;
use virtio_bindings::virtio_net::{
VIRTIO_NET_F_CSUM, VIRTIO_NET_F_GUEST_CSUM, VIRTIO_NET_F_GUEST_TSO4,
VIRTIO_NET_F_GUEST_TSO6, VIRTIO_NET_F_GUEST_UFO, VIRTIO_NET_F_HOST_TSO4,
VIRTIO_NET_F_HOST_TSO6, VIRTIO_NET_F_HOST_UFO, VIRTIO_NET_F_MAC, VIRTIO_NET_F_STATUS,
};
use virtio_bindings::virtio_ring::VIRTIO_RING_F_EVENT_IDX;
use virtio_queue::QueueT;
use vm_memory::{Bytes, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryMmap};
use vmm_sys_util::epoll::EventSet;
use crate::mac::Mac;
/// RX queue index.
pub const RX_QUEUE: u16 = 0;
/// TX queue index.
pub const TX_QUEUE: u16 = 1;
/// Number of queues.
pub const NUM_QUEUES: usize = 2;
/// Maximum queue size.
pub const MAX_QUEUE_SIZE: usize = 32768;
/// Virtio-net header size.
pub const VIRTIO_NET_HDR_SIZE: usize = 12;
/// Virtio net features.
const VIRTIO_NET_FEATURES: u64 = (1 << VIRTIO_NET_F_CSUM)
| (1 << VIRTIO_NET_F_GUEST_CSUM)
| (1 << VIRTIO_NET_F_GUEST_TSO4)
| (1 << VIRTIO_NET_F_GUEST_TSO6)
| (1 << VIRTIO_NET_F_GUEST_UFO)
| (1 << VIRTIO_NET_F_HOST_TSO4)
| (1 << VIRTIO_NET_F_HOST_TSO6)
| (1 << VIRTIO_NET_F_HOST_UFO)
| (1 << VIRTIO_NET_F_MAC)
| (1 << VIRTIO_NET_F_STATUS);
/// Callback type for TX frames.
pub type TxCallback = Box<dyn Fn(&[u8]) + Send>;
/// Child's vhost-user backend.
pub struct ChildVhostBackend {
name: String,
mac: Mac,
mem: Mutex<Option<GuestMemoryAtomic<GuestMemoryMmap>>>,
tx_callback: Mutex<Option<TxCallback>>,
rx_vring: Mutex<Option<VringRwLock>>,
}
impl std::fmt::Debug for ChildVhostBackend {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ChildVhostBackend")
.field("name", &self.name)
.field("mac", &self.mac)
.finish_non_exhaustive()
}
}
impl ChildVhostBackend {
/// Create a new backend.
pub fn new(name: String, mac: Mac) -> Arc<Self> {
Arc::new(Self {
name,
mac,
mem: Mutex::new(None),
tx_callback: Mutex::new(None),
rx_vring: Mutex::new(None),
})
}
/// Set the TX callback.
pub fn set_tx_callback(&self, callback: TxCallback) {
*self.tx_callback.lock().unwrap() = Some(callback);
}
/// Inject a frame into the RX queue for the guest.
///
/// Returns `true` if the frame was written to the virtio RX queue, `false` if
/// dropped. Frames are silently dropped before the guest driver has initialized
/// virtio queues (RX vring not yet available). This is expected during early
/// startup -- the guest isn't ready to receive traffic until the driver
/// negotiates features and sets up vrings.
pub fn inject_rx_frame(&self, frame: &[u8]) -> bool {
let vring = match self.rx_vring.lock().unwrap().as_ref() {
Some(v) => v.clone(),
None => return false,
};
let mem_guard = self.mem.lock().unwrap();
let mem = match mem_guard.as_ref() {
Some(m) => m.memory(),
None => return false,
};
// Prepend virtio header
let mut data = vec![0u8; VIRTIO_NET_HDR_SIZE + frame.len()];
data[VIRTIO_NET_HDR_SIZE..].copy_from_slice(frame);
let head_index;
let written;
{
let mut vring_state = vring.get_mut();
let queue = vring_state.get_queue_mut();
let desc_chain = match queue.pop_descriptor_chain(mem.clone()) {
Some(c) => c,
None => return false,
};
head_index = desc_chain.head_index();
let mut writable_descs = Vec::new();
for desc in desc_chain {
if desc.is_write_only() {
writable_descs.push((desc.addr(), desc.len() as usize));
}
}
let available: usize = writable_descs.iter().map(|(_, l)| *l).sum();
if available < data.len() {
written = 0;
} else {
let mut bytes_written = 0;
for (addr, len) in writable_descs {
let remaining = data.len() - bytes_written;
if remaining == 0 {
break;
}
let to_write = std::cmp::min(remaining, len);
if mem.write_slice(&data[bytes_written..bytes_written + to_write], addr).is_err() {
break;
}
bytes_written += to_write;
}
written = bytes_written;
}
}
let _ = vring.add_used(head_index, written as u32);
let _ = vring.enable_notification();
let _ = vring.signal_used_queue();
written >= data.len()
}
/// Process TX queue and call callback for each frame.
fn process_tx(&self, vring: &VringRwLock) {
let mem_guard = self.mem.lock().unwrap();
let mem = match mem_guard.as_ref() {
Some(m) => m.memory(),
None => return,
};
let callback = self.tx_callback.lock().unwrap();
let callback = match callback.as_ref() {
Some(c) => c,
None => return,
};
loop {
let head_index;
let raw_data;
{
let mut vring_state = vring.get_mut();
let queue = vring_state.get_queue_mut();
let desc_chain = match queue.pop_descriptor_chain(mem.clone()) {
Some(c) => c,
None => break,
};
head_index = desc_chain.head_index();
let mut data = Vec::new();
for desc in desc_chain {
let addr = desc.addr();
let len = desc.len() as usize;
let mut buf = vec![0u8; len];
if let Err(e) = mem.read_slice(&mut buf, addr) {
warn!(vm = %self.name, error = %e, "failed to read descriptor");
break;
}
data.extend_from_slice(&buf);
}
raw_data = data;
}
if let Err(e) = vring.add_used(head_index, 0) {
warn!(vm = %self.name, error = ?e, "add_used failed");
}
// Strip virtio header and call callback
if raw_data.len() > VIRTIO_NET_HDR_SIZE {
callback(&raw_data[VIRTIO_NET_HDR_SIZE..]);
}
}
let _ = vring.enable_notification();
let _ = vring.signal_used_queue();
}
}
impl VhostUserBackend for ChildVhostBackend {
type Bitmap = ();
type Vring = VringRwLock;
fn num_queues(&self) -> usize {
NUM_QUEUES
}
fn max_queue_size(&self) -> usize {
MAX_QUEUE_SIZE
}
fn features(&self) -> u64 {
VIRTIO_NET_FEATURES
| (1 << VIRTIO_F_VERSION_1)
| (1 << VIRTIO_RING_F_EVENT_IDX)
| VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits()
}
fn protocol_features(&self) -> VhostUserProtocolFeatures {
VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::MQ
}
fn set_event_idx(&self, _enabled: bool) {}
fn update_memory(&self, mem: GuestMemoryAtomic<GuestMemoryMmap>) -> std::io::Result<()> {
debug!(vm = %self.name, "update_memory");
*self.mem.lock().unwrap() = Some(mem);
Ok(())
}
fn handle_event(
&self,
device_event: u16,
_evset: EventSet,
vrings: &[VringRwLock],
_thread_id: usize,
) -> std::io::Result<()> {
// Store RX vring for injection
if vrings.len() > RX_QUEUE as usize {
let mut rx = self.rx_vring.lock().unwrap();
if rx.is_none() {
*rx = Some(vrings[RX_QUEUE as usize].clone());
debug!(vm = %self.name, "stored RX vring");
}
}
// Process TX queue
if device_event == TX_QUEUE && vrings.len() > TX_QUEUE as usize {
self.process_tx(&vrings[TX_QUEUE as usize]);
}
Ok(())
}
fn acked_features(&self, _features: u64) {}
fn get_config(&self, offset: u32, size: u32) -> Vec<u8> {
let mut config = [0u8; 10];
config[0..6].copy_from_slice(&self.mac.bytes());
config[6] = 1; // LINK_UP
config[8] = 1; // max_virtqueue_pairs
let offset = offset as usize;
let size = size as usize;
if offset < config.len() {
let end = std::cmp::min(offset + size, config.len());
let mut result = config[offset..end].to_vec();
result.resize(size, 0);
result
} else {
vec![0u8; size]
}
}
}
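The all-or-nothing scatter across writable descriptors in `inject_rx_frame` can be illustrated standalone. This is a simplified sketch with no guest memory: `scatter_write` and its slice of segment capacities are hypothetical stand-ins for the descriptor chain, but the drop-if-too-small rule mirrors the code above.

```rust
/// Distribute a payload across writable descriptor capacities, mirroring
/// the check in inject_rx_frame: if the chain cannot hold the whole frame,
/// nothing is written and the frame is dropped.
/// (Illustrative sketch; `scatter_write` is not part of the backend API.)
fn scatter_write(payload_len: usize, seg_caps: &[usize]) -> Option<Vec<usize>> {
    let available: usize = seg_caps.iter().sum();
    if available < payload_len {
        return None; // chain too small: drop, write no partial frame
    }
    let mut remaining = payload_len;
    let mut per_seg = Vec::with_capacity(seg_caps.len());
    for &cap in seg_caps {
        if remaining == 0 {
            break;
        }
        let n = remaining.min(cap);
        per_seg.push(n);
        remaining -= n;
    }
    Some(per_seg)
}

fn main() {
    // 12-byte virtio-net header + 60-byte frame over two 64-byte descriptors
    assert_eq!(scatter_write(72, &[64, 64]), Some(vec![64, 8]));
    // total capacity 64 < 72: frame dropped
    assert_eq!(scatter_write(72, &[32, 32]), None);
}
```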

vm-switch/src/control.rs Normal file

@ -0,0 +1,721 @@
//! Control channel for main↔child process communication.
//!
//! Messages are serialized with postcard and sent over Unix sockets.
//! File descriptors are passed via SCM_RIGHTS ancillary data.
use crate::mac::Mac;
use serde::{Deserialize, Serialize};
use std::os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd};
use thiserror::Error;
/// Messages sent from main process to child.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum MainToChild {
/// Request child create an ingress buffer for a peer.
/// Child will be Consumer, peer will be Producer.
GetBuffer {
/// Name of the peer VM.
peer_name: String,
/// MAC address of the peer VM.
peer_mac: [u8; 6],
},
/// Provide peer's ingress buffer for child to use as egress.
/// Child becomes Producer for this buffer.
/// FDs: [memfd, eventfd]
PutBuffer {
/// Name of the peer VM.
peer_name: String,
/// MAC address of the peer VM.
peer_mac: [u8; 6],
/// If true, buffer accepts broadcast/multicast traffic.
broadcast: bool,
},
/// Peer disconnected, clean up all buffers for this peer.
RemovePeer {
/// Name of the peer VM.
peer_name: String,
},
/// Heartbeat request.
Ping,
}
/// Messages sent from child process to main.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ChildToMain {
/// Child is ready to receive commands.
Ready,
/// Response to GetBuffer: carries the child's ingress buffer for the requested peer.
/// FDs: [memfd, eventfd]
BufferReady {
/// Name of the peer this buffer is for.
peer_name: String,
},
/// Heartbeat response.
Pong,
}
impl MainToChild {
/// Get the peer name, if this message has one.
pub fn peer_name(&self) -> Option<&str> {
match self {
MainToChild::GetBuffer { peer_name, .. } => Some(peer_name),
MainToChild::PutBuffer { peer_name, .. } => Some(peer_name),
MainToChild::RemovePeer { peer_name } => Some(peer_name),
MainToChild::Ping => None,
}
}
/// Get the peer MAC address as a Mac type, if this message has one.
pub fn peer_mac(&self) -> Option<Mac> {
match self {
MainToChild::GetBuffer { peer_mac, .. } => Some(Mac::from_bytes(*peer_mac)),
MainToChild::PutBuffer { peer_mac, .. } => Some(Mac::from_bytes(*peer_mac)),
MainToChild::RemovePeer { .. } | MainToChild::Ping => None,
}
}
}
impl ChildToMain {
/// Get the peer name, if this message has one.
pub fn peer_name(&self) -> Option<&str> {
match self {
ChildToMain::Ready | ChildToMain::Pong => None,
ChildToMain::BufferReady { peer_name } => Some(peer_name),
}
}
}
/// Maximum message size in bytes.
pub const MAX_MESSAGE_SIZE: usize = 256;
/// Maximum number of file descriptors per message.
pub const MAX_FDS: usize = 4;
/// Errors that can occur with control channel operations.
#[derive(Debug, Error)]
pub enum ControlError {
#[error("failed to create socketpair: {0}")]
Socketpair(std::io::Error),
#[error("failed to serialize message: {0}")]
Serialize(postcard::Error),
#[error("failed to deserialize message: {0}")]
Deserialize(postcard::Error),
#[error("failed to send message: {0}")]
Send(std::io::Error),
#[error("failed to receive message: {0}")]
Recv(std::io::Error),
#[error("connection closed")]
Closed,
#[error("message too large: {0} bytes")]
MessageTooLarge(usize),
}
/// A control channel endpoint for main↔child communication.
pub struct ControlChannel {
socket: OwnedFd,
}
impl AsRawFd for ControlChannel {
fn as_raw_fd(&self) -> RawFd {
self.socket.as_raw_fd()
}
}
impl ControlChannel {
/// Create a pair of connected control channels.
/// Returns (main_end, child_end).
pub fn pair() -> Result<(Self, Self), ControlError> {
let mut fds = [0i32; 2];
let ret = unsafe {
libc::socketpair(
libc::AF_UNIX,
libc::SOCK_SEQPACKET | libc::SOCK_CLOEXEC,
0,
fds.as_mut_ptr(),
)
};
if ret < 0 {
return Err(ControlError::Socketpair(std::io::Error::last_os_error()));
}
let main_end = Self {
socket: unsafe { OwnedFd::from_raw_fd(fds[0]) },
};
let child_end = Self {
socket: unsafe { OwnedFd::from_raw_fd(fds[1]) },
};
Ok((main_end, child_end))
}
/// Create a ControlChannel from a raw file descriptor.
///
/// # Safety
/// The fd must be a valid, open socket file descriptor.
pub unsafe fn from_raw_fd(fd: RawFd) -> Self {
Self {
socket: OwnedFd::from_raw_fd(fd),
}
}
/// Create a ControlChannel from an owned file descriptor.
pub fn from_fd(fd: OwnedFd) -> Self {
Self { socket: fd }
}
/// Consume the ControlChannel and return the underlying file descriptor.
pub fn into_fd(self) -> OwnedFd {
self.socket
}
/// Send a message without file descriptors.
pub fn send<M: Serialize>(&self, msg: &M) -> Result<(), ControlError> {
let bytes = postcard::to_allocvec(msg).map_err(ControlError::Serialize)?;
if bytes.len() > MAX_MESSAGE_SIZE {
return Err(ControlError::MessageTooLarge(bytes.len()));
}
send_with_fds(self.socket.as_raw_fd(), &bytes, &[])
}
/// Send a message with file descriptors.
pub fn send_with_fds_typed<M: Serialize>(
&self,
msg: &M,
fds: &[RawFd],
) -> Result<(), ControlError> {
let bytes = postcard::to_allocvec(msg).map_err(ControlError::Serialize)?;
if bytes.len() > MAX_MESSAGE_SIZE {
return Err(ControlError::MessageTooLarge(bytes.len()));
}
send_with_fds(self.socket.as_raw_fd(), &bytes, fds)
}
/// Receive a message without expecting file descriptors.
pub fn recv<M: for<'de> Deserialize<'de>>(&self) -> Result<M, ControlError> {
let mut buf = [0u8; MAX_MESSAGE_SIZE];
let (n, _fds) = recv_with_fds(self.socket.as_raw_fd(), &mut buf)?;
postcard::from_bytes(&buf[..n]).map_err(ControlError::Deserialize)
}
/// Receive a message with file descriptors.
/// Returns (message, file_descriptors).
pub fn recv_with_fds_typed<M: for<'de> Deserialize<'de>>(
&self,
) -> Result<(M, Vec<OwnedFd>), ControlError> {
let mut buf = [0u8; MAX_MESSAGE_SIZE];
let (n, fds) = recv_with_fds(self.socket.as_raw_fd(), &mut buf)?;
let msg = postcard::from_bytes(&buf[..n]).map_err(ControlError::Deserialize)?;
Ok((msg, fds))
}
}
/// Send a message with optional file descriptors via SCM_RIGHTS.
fn send_with_fds(socket_fd: RawFd, data: &[u8], fds: &[RawFd]) -> Result<(), ControlError> {
let mut iov = libc::iovec {
iov_base: data.as_ptr() as *mut libc::c_void,
iov_len: data.len(),
};
// Calculate control message buffer size
let cmsg_space = if fds.is_empty() {
0
} else {
unsafe { libc::CMSG_SPACE(std::mem::size_of_val(fds) as u32) as usize }
};
let mut cmsg_buf = vec![0u8; cmsg_space];
let mut msg: libc::msghdr = unsafe { std::mem::zeroed() };
msg.msg_iov = &mut iov;
msg.msg_iovlen = 1;
if !fds.is_empty() {
msg.msg_control = cmsg_buf.as_mut_ptr() as *mut libc::c_void;
msg.msg_controllen = cmsg_space;
// Fill in the control message header
let cmsg = unsafe { libc::CMSG_FIRSTHDR(&msg) };
unsafe {
(*cmsg).cmsg_level = libc::SOL_SOCKET;
(*cmsg).cmsg_type = libc::SCM_RIGHTS;
(*cmsg).cmsg_len = libc::CMSG_LEN(std::mem::size_of_val(fds) as u32) as usize;
// Copy file descriptors into cmsg data
let cmsg_data = libc::CMSG_DATA(cmsg) as *mut RawFd;
for (i, &fd) in fds.iter().enumerate() {
*cmsg_data.add(i) = fd;
}
}
}
let ret = unsafe { libc::sendmsg(socket_fd, &msg, 0) };
if ret < 0 {
return Err(ControlError::Send(std::io::Error::last_os_error()));
}
Ok(())
}
/// Receive a message with optional file descriptors via SCM_RIGHTS.
/// Returns (bytes_read, file_descriptors).
fn recv_with_fds(
socket_fd: RawFd,
buf: &mut [u8],
) -> Result<(usize, Vec<OwnedFd>), ControlError> {
let mut iov = libc::iovec {
iov_base: buf.as_mut_ptr() as *mut libc::c_void,
iov_len: buf.len(),
};
// Buffer for control messages (enough for MAX_FDS file descriptors)
let cmsg_space =
unsafe { libc::CMSG_SPACE((MAX_FDS * std::mem::size_of::<RawFd>()) as u32) as usize };
let mut cmsg_buf = vec![0u8; cmsg_space];
let mut msg: libc::msghdr = unsafe { std::mem::zeroed() };
msg.msg_iov = &mut iov;
msg.msg_iovlen = 1;
msg.msg_control = cmsg_buf.as_mut_ptr() as *mut libc::c_void;
msg.msg_controllen = cmsg_space;
let n = unsafe { libc::recvmsg(socket_fd, &mut msg, 0) };
if n < 0 {
return Err(ControlError::Recv(std::io::Error::last_os_error()));
}
if n == 0 {
return Err(ControlError::Closed);
}
// Extract file descriptors from control message
let mut fds = Vec::new();
let mut cmsg = unsafe { libc::CMSG_FIRSTHDR(&msg) };
while !cmsg.is_null() {
unsafe {
if (*cmsg).cmsg_level == libc::SOL_SOCKET && (*cmsg).cmsg_type == libc::SCM_RIGHTS {
let data_len = (*cmsg).cmsg_len - libc::CMSG_LEN(0) as usize;
let num_fds = data_len / std::mem::size_of::<RawFd>();
let fd_ptr = libc::CMSG_DATA(cmsg) as *const RawFd;
for i in 0..num_fds {
let fd = *fd_ptr.add(i);
fds.push(OwnedFd::from_raw_fd(fd));
}
}
cmsg = libc::CMSG_NXTHDR(&msg, cmsg);
}
}
Ok((n as usize, fds))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn main_to_child_get_buffer_serializes() {
let msg = MainToChild::GetBuffer {
peer_name: "router".to_string(),
peer_mac: [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff],
};
let bytes = postcard::to_allocvec(&msg).expect("should serialize");
let decoded: MainToChild = postcard::from_bytes(&bytes).expect("should deserialize");
assert_eq!(decoded, msg);
}
#[test]
fn main_to_child_put_buffer_serializes() {
let msg = MainToChild::PutBuffer {
peer_name: "client_a".to_string(),
peer_mac: [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff],
broadcast: true,
};
let bytes = postcard::to_allocvec(&msg).expect("should serialize");
let decoded: MainToChild = postcard::from_bytes(&bytes).expect("should deserialize");
assert_eq!(decoded, msg);
}
#[test]
fn main_to_child_remove_peer_serializes() {
let msg = MainToChild::RemovePeer {
peer_name: "client_b".to_string(),
};
let bytes = postcard::to_allocvec(&msg).expect("should serialize");
let decoded: MainToChild = postcard::from_bytes(&bytes).expect("should deserialize");
assert_eq!(decoded, msg);
}
#[test]
fn main_to_child_ping_serializes() {
let msg = MainToChild::Ping;
let bytes = postcard::to_allocvec(&msg).expect("should serialize");
let decoded: MainToChild = postcard::from_bytes(&bytes).expect("should deserialize");
assert_eq!(decoded, msg);
}
#[test]
fn child_to_main_pong_serializes() {
let msg = ChildToMain::Pong;
let bytes = postcard::to_allocvec(&msg).expect("should serialize");
let decoded: ChildToMain = postcard::from_bytes(&bytes).expect("should deserialize");
assert_eq!(decoded, msg);
}
#[test]
fn child_to_main_ready_serializes() {
let msg = ChildToMain::Ready;
let bytes = postcard::to_allocvec(&msg).expect("should serialize");
let decoded: ChildToMain = postcard::from_bytes(&bytes).expect("should deserialize");
assert_eq!(decoded, msg);
}
#[test]
fn child_to_main_buffer_ready_serializes() {
let msg = ChildToMain::BufferReady {
peer_name: "router".to_string(),
};
let bytes = postcard::to_allocvec(&msg).expect("should serialize");
let decoded: ChildToMain = postcard::from_bytes(&bytes).expect("should deserialize");
assert_eq!(decoded, msg);
}
#[test]
fn control_channel_pair_creates_connected_sockets() {
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
assert!(main_end.as_raw_fd() >= 0);
assert!(child_end.as_raw_fd() >= 0);
assert_ne!(main_end.as_raw_fd(), child_end.as_raw_fd());
}
#[test]
fn send_with_fds_sends_data() {
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
let data = b"hello";
send_with_fds(main_end.as_raw_fd(), data, &[]).expect("should send");
// Read on child end to verify
let mut buf = [0u8; 64];
let n = unsafe {
libc::recv(
child_end.as_raw_fd(),
buf.as_mut_ptr() as *mut libc::c_void,
buf.len(),
0,
)
};
assert!(n > 0);
assert_eq!(&buf[..n as usize], data);
}
#[test]
fn send_with_fds_passes_file_descriptors() {
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
// Create a test eventfd to send
let test_fd = unsafe { libc::eventfd(42, libc::EFD_CLOEXEC) };
assert!(test_fd >= 0);
send_with_fds(main_end.as_raw_fd(), b"fd", &[test_fd]).expect("should send");
// Close our copy - child should have its own
unsafe { libc::close(test_fd) };
// Receive on child end
let mut buf = [0u8; 64];
let mut cmsg_buf = [0u8; 64];
let mut iov = libc::iovec {
iov_base: buf.as_mut_ptr() as *mut libc::c_void,
iov_len: buf.len(),
};
let mut msg: libc::msghdr = unsafe { std::mem::zeroed() };
msg.msg_iov = &mut iov;
msg.msg_iovlen = 1;
msg.msg_control = cmsg_buf.as_mut_ptr() as *mut libc::c_void;
msg.msg_controllen = cmsg_buf.len();
let n = unsafe { libc::recvmsg(child_end.as_raw_fd(), &mut msg, 0) };
assert!(n > 0);
// Extract the file descriptor
let cmsg = unsafe { libc::CMSG_FIRSTHDR(&msg) };
assert!(!cmsg.is_null());
let received_fd = unsafe { *(libc::CMSG_DATA(cmsg) as *const RawFd) };
assert!(received_fd >= 0);
// Verify we can read the eventfd value (42)
let mut val: u64 = 0;
let ret = unsafe {
libc::read(
received_fd,
&mut val as *mut u64 as *mut libc::c_void,
std::mem::size_of::<u64>(),
)
};
assert_eq!(ret, std::mem::size_of::<u64>() as isize);
assert_eq!(val, 42);
unsafe { libc::close(received_fd) };
}
#[test]
fn recv_with_fds_receives_data() {
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
let data = b"world";
// Send from main
send_with_fds(main_end.as_raw_fd(), data, &[]).expect("should send");
// Receive on child
let mut buf = [0u8; 64];
let (n, fds) = recv_with_fds(child_end.as_raw_fd(), &mut buf).expect("should receive");
assert_eq!(n, data.len());
assert_eq!(&buf[..n], data);
assert!(fds.is_empty());
}
#[test]
fn recv_with_fds_receives_file_descriptors() {
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
// Create test eventfds
let fd1 = unsafe { libc::eventfd(10, libc::EFD_CLOEXEC) };
let fd2 = unsafe { libc::eventfd(20, libc::EFD_CLOEXEC) };
assert!(fd1 >= 0);
assert!(fd2 >= 0);
send_with_fds(main_end.as_raw_fd(), b"fds", &[fd1, fd2]).expect("should send");
// Close our copies
unsafe {
libc::close(fd1);
libc::close(fd2);
}
// Receive
let mut buf = [0u8; 64];
let (n, fds) = recv_with_fds(child_end.as_raw_fd(), &mut buf).expect("should receive");
assert_eq!(n, 3);
assert_eq!(&buf[..n], b"fds");
assert_eq!(fds.len(), 2);
// Verify eventfd values
let mut val: u64 = 0;
unsafe {
libc::read(
fds[0].as_raw_fd(),
&mut val as *mut u64 as *mut libc::c_void,
8,
);
}
assert_eq!(val, 10);
val = 0;
unsafe {
libc::read(
fds[1].as_raw_fd(),
&mut val as *mut u64 as *mut libc::c_void,
8,
);
}
assert_eq!(val, 20);
}
#[test]
fn control_channel_send_delivers_message() {
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
let msg = MainToChild::RemovePeer {
peer_name: "client_a".to_string(),
};
main_end.send(&msg).expect("should send");
// Verify by receiving raw bytes
let mut buf = [0u8; MAX_MESSAGE_SIZE];
let (n, _) = recv_with_fds(child_end.as_raw_fd(), &mut buf).expect("should receive");
let decoded: MainToChild = postcard::from_bytes(&buf[..n]).expect("should decode");
assert_eq!(decoded, msg);
}
#[test]
fn control_channel_send_with_fds_delivers_message_and_fds() {
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
// Create test eventfd
let test_fd = unsafe { libc::eventfd(99, libc::EFD_CLOEXEC) };
assert!(test_fd >= 0);
let msg = ChildToMain::BufferReady {
peer_name: "router".to_string(),
};
main_end
.send_with_fds_typed(&msg, &[test_fd])
.expect("should send");
unsafe { libc::close(test_fd) };
// Receive
let mut buf = [0u8; MAX_MESSAGE_SIZE];
let (n, fds) = recv_with_fds(child_end.as_raw_fd(), &mut buf).expect("should receive");
let decoded: ChildToMain = postcard::from_bytes(&buf[..n]).expect("should decode");
assert_eq!(decoded, msg);
assert_eq!(fds.len(), 1);
// Verify eventfd value
let mut val: u64 = 0;
unsafe {
libc::read(
fds[0].as_raw_fd(),
&mut val as *mut u64 as *mut libc::c_void,
8,
);
}
assert_eq!(val, 99);
}
#[test]
fn control_channel_recv_returns_message() {
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
let msg = MainToChild::RemovePeer {
peer_name: "client_b".to_string(),
};
// Send from main using typed method
main_end.send(&msg).expect("should send");
// Receive on child using typed method
let received: MainToChild = child_end.recv().expect("should receive");
assert_eq!(received, msg);
}
#[test]
fn control_channel_recv_with_fds_returns_message_and_fds() {
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
// Create test eventfds
let fd1 = unsafe { libc::eventfd(111, libc::EFD_CLOEXEC) };
let fd2 = unsafe { libc::eventfd(222, libc::EFD_CLOEXEC) };
let msg = MainToChild::PutBuffer {
peer_name: "router".to_string(),
peer_mac: [0x11, 0x22, 0x33, 0x44, 0x55, 0x66],
broadcast: true,
};
main_end.send_with_fds_typed(&msg, &[fd1, fd2]).expect("should send");
unsafe {
libc::close(fd1);
libc::close(fd2);
}
// Receive with typed method
let (received, fds): (MainToChild, _) = child_end.recv_with_fds_typed().expect("should receive");
assert_eq!(received, msg);
assert_eq!(fds.len(), 2);
// Verify eventfd values
let mut val: u64 = 0;
unsafe {
libc::read(fds[0].as_raw_fd(), &mut val as *mut u64 as *mut libc::c_void, 8);
}
assert_eq!(val, 111);
val = 0;
unsafe {
libc::read(fds[1].as_raw_fd(), &mut val as *mut u64 as *mut libc::c_void, 8);
}
assert_eq!(val, 222);
}
#[test]
fn control_channel_recv_detects_closed_connection() {
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
// Close the sender end
drop(main_end);
// Receive should return Closed error
let result: Result<MainToChild, _> = child_end.recv();
match result {
Err(ControlError::Closed) => (),
other => panic!("expected Closed error, got {:?}", other),
}
}
#[test]
fn main_to_child_peer_name_helper_returns_correct_name() {
let msg = MainToChild::PutBuffer {
peer_name: "router".to_string(),
peer_mac: [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff],
broadcast: true,
};
assert_eq!(msg.peer_name(), Some("router"));
let msg2 = MainToChild::GetBuffer {
peer_name: "client_a".to_string(),
peer_mac: [0x11, 0x22, 0x33, 0x44, 0x55, 0x66],
};
assert_eq!(msg2.peer_name(), Some("client_a"));
let msg3 = MainToChild::RemovePeer {
peer_name: "client_b".to_string(),
};
assert_eq!(msg3.peer_name(), Some("client_b"));
assert_eq!(MainToChild::Ping.peer_name(), None);
}
#[test]
fn main_to_child_peer_mac_helper_returns_correct_mac() {
let msg = MainToChild::PutBuffer {
peer_name: "router".to_string(),
peer_mac: [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff],
broadcast: true,
};
assert_eq!(format!("{}", msg.peer_mac().unwrap()), "aa:bb:cc:dd:ee:ff");
let msg2 = MainToChild::GetBuffer {
peer_name: "client_a".to_string(),
peer_mac: [0x11, 0x22, 0x33, 0x44, 0x55, 0x66],
};
assert_eq!(format!("{}", msg2.peer_mac().unwrap()), "11:22:33:44:55:66");
let msg3 = MainToChild::RemovePeer {
peer_name: "client_b".to_string(),
};
assert!(msg3.peer_mac().is_none());
assert!(MainToChild::Ping.peer_mac().is_none());
}
#[test]
fn child_to_main_peer_name_helper_returns_correct_name() {
let msg = ChildToMain::BufferReady {
peer_name: "router".to_string(),
};
assert_eq!(msg.peer_name(), Some("router"));
assert_eq!(ChildToMain::Ready.peer_name(), None);
assert_eq!(ChildToMain::Pong.peer_name(), None);
}
}
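The FD-count recovery in `recv_with_fds` is plain arithmetic on `cmsg_len`. A sketch of that arithmetic, assuming the 64-bit Linux layout (16-byte `cmsghdr`, 4-byte `RawFd`); the constants here are illustrative, while the real code derives them via `libc::CMSG_LEN`.

```rust
// Assumed 64-bit Linux layout: struct cmsghdr is 16 bytes, an fd is 4 bytes.
// The real code computes these via libc::CMSG_LEN; this sketch only shows
// the arithmetic recv_with_fds performs on a received SCM_RIGHTS message.
const CMSG_HDR_LEN: usize = 16; // CMSG_LEN(0) on x86_64 Linux (assumption)
const FD_SIZE: usize = 4;       // size_of::<RawFd>()

/// Number of FDs carried by an SCM_RIGHTS control message of `cmsg_len` bytes.
fn fds_in_cmsg(cmsg_len: usize) -> usize {
    (cmsg_len - CMSG_HDR_LEN) / FD_SIZE
}

fn main() {
    // PutBuffer carries [memfd, eventfd]: cmsg_len = CMSG_LEN(2 * 4)
    assert_eq!(fds_in_cmsg(CMSG_HDR_LEN + 2 * FD_SIZE), 2);
    // no ancillary payload
    assert_eq!(fds_in_cmsg(CMSG_HDR_LEN), 0);
}
```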


@ -45,6 +45,28 @@ impl<'a> EthernetFrame<'a> {
}
}
/// Validate that a frame's source MAC matches the expected MAC.
/// Returns `false` if the frame is too short or the source MAC doesn't match.
pub fn validate_source_mac(frame: &[u8], expected: Mac) -> bool {
if frame.len() < MIN_FRAME_SIZE {
return false;
}
let mut src = [0u8; 6];
src.copy_from_slice(&frame[6..12]);
Mac::from_bytes(src) == expected
}
/// Extract destination MAC from a frame.
/// Returns `None` if the frame is too short.
pub fn extract_dest_mac(frame: &[u8]) -> Option<Mac> {
if frame.len() < 6 {
return None;
}
let mut dest = [0u8; 6];
dest.copy_from_slice(&frame[0..6]);
Some(Mac::from_bytes(dest))
}
#[cfg(test)]
mod tests {
use super::*;
@ -99,4 +121,44 @@ mod tests {
Mac::from_bytes([0x11, 0x22, 0x33, 0x44, 0x55, 0x66])
);
}
#[test]
fn validate_source_mac_accepts_matching() {
let mut frame = [0u8; 14];
frame[6..12].copy_from_slice(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
let expected = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
assert!(validate_source_mac(&frame, expected));
}
#[test]
fn validate_source_mac_rejects_mismatch() {
let mut frame = [0u8; 14];
frame[6..12].copy_from_slice(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
let expected = Mac::from_bytes([0x11, 0x22, 0x33, 0x44, 0x55, 0x66]);
assert!(!validate_source_mac(&frame, expected));
}
#[test]
fn validate_source_mac_rejects_short_frame() {
let frame = [0u8; 10]; // Too short
let expected = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
assert!(!validate_source_mac(&frame, expected));
}
#[test]
fn extract_dest_mac_returns_mac() {
let mut frame = [0u8; 14];
frame[0..6].copy_from_slice(&[0x11, 0x22, 0x33, 0x44, 0x55, 0x66]);
let result = extract_dest_mac(&frame);
assert_eq!(result, Some(Mac::from_bytes([0x11, 0x22, 0x33, 0x44, 0x55, 0x66])));
}
#[test]
fn extract_dest_mac_returns_none_for_short_frame() {
let frame = [0u8; 5];
assert!(extract_dest_mac(&frame).is_none());
}
}
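A natural companion to `extract_dest_mac` is the group-address test implied by the `broadcast` flag in `PutBuffer`: Ethernet broadcast/multicast destinations are marked by the I/G bit, the least-significant bit of the first octet. A hypothetical helper, not part of this module:

```rust
/// True for broadcast/multicast destinations: the I/G bit (bit 0 of the
/// first octet) is set for Ethernet group addresses.
/// (Illustrative helper; not part of frame.rs.)
fn is_group_address(dest: &[u8; 6]) -> bool {
    dest[0] & 0x01 != 0
}

fn main() {
    assert!(is_group_address(&[0xff; 6])); // broadcast ff:ff:ff:ff:ff:ff
    assert!(is_group_address(&[0x01, 0x00, 0x5e, 0x00, 0x00, 0x01])); // IPv4 multicast
    assert!(!is_group_address(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff])); // unicast
}
```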


@ -1,14 +1,15 @@
//! vm-switch: Virtual L2 switch for VM-to-VM networking.
//!
//! This library provides a vhost-user network backend that enables
//! communication between VMs through a virtual L2 switch with MAC filtering.
//! communication between VMs through a virtual L2 switch with ring buffer
//! forwarding.
//!
//! # Architecture
//!
//! - **ConfigWatcher**: Monitors `/run/vm-switch/` for `.mac` files
//! - **BackendManager**: Coordinates multiple NetBackend instances
//! - **Switch**: L2 switching logic with MAC filtering
//! - **NetBackend**: vhost-user backend for a single VM
//! - **BackendManager**: Coordinates forked child processes for each VM
//! - **Child**: vhost-user backend running in forked process with sandboxing
//! - **RingBuffer**: Inter-process packet transfer via shared memory
//!
//! # Usage
//!
@ -18,23 +19,30 @@
//!
//! let args = Args::parse();
//! let watcher = ConfigWatcher::new(&args.config_dir, 64)?;
//! let manager = BackendManager::new(&args.config_dir);
//! let manager = BackendManager::new(&args.config_dir, SeccompMode::Kill);
//! ```
pub mod args;
pub mod backend;
pub mod child;
pub mod config;
pub mod control;
pub mod frame;
pub mod mac;
pub mod manager;
pub mod switch;
pub mod ring;
pub mod sandbox;
pub mod seccomp;
pub mod watcher;
// Re-export commonly used types
pub use args::{init_logging, Args, LogLevel};
pub use backend::NetBackend;
pub use config::{ConfigEvent, VmConfig, VmRole};
pub use control::{ChildToMain, ControlChannel, ControlError, MainToChild, MAX_FDS, MAX_MESSAGE_SIZE};
pub use mac::Mac;
pub use manager::BackendManager;
pub use switch::{ConnectionId, ForwardDecision, Switch};
pub use manager::{BackendManager, ChildMessage};
pub use ring::{Consumer, Producer, RingError, RING_BUFFER_SIZE, RING_SIZE, SLOT_DATA_SIZE};
pub use sandbox::{
apply_sandbox, enter_user_namespace, setup_filesystem_isolation, SandboxError, SandboxResult,
};
pub use seccomp::{apply_child_seccomp, apply_main_seccomp, SeccompError, SeccompMode};
pub use watcher::ConfigWatcher;


@ -1,30 +1,99 @@
use std::path::PathBuf;
use std::time::Duration;
use clap::Parser;
use tokio::signal;
use tokio::signal::unix::{signal as unix_signal, SignalKind};
use tokio::sync::broadcast;
use vm_switch::{Args, BackendManager, ConfigWatcher, init_logging};
use vm_switch::{apply_sandbox, apply_main_seccomp, Args, BackendManager, ConfigWatcher, SandboxResult, SeccompMode, init_logging};
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();
init_logging(args.log_level);
tracing::info!(config_dir = ?args.config_dir, "Starting vm-switch");
// Ensure config directory exists
/// Apply the sandbox before the tokio runtime starts (the process must
/// still be single-threaded when this runs).
///
/// This function may fork. In the wrapper parent it never returns: it
/// waits for the child and exits with the child's exit code. The actual
/// vm-switch logic continues in the child.
fn setup_sandbox(args: &Args) -> Result<PathBuf, Box<dyn std::error::Error>> {
// Ensure config directory exists before sandboxing
if !args.config_dir.exists() {
std::fs::create_dir_all(&args.config_dir)?;
tracing::info!(path = ?args.config_dir, "Created config directory");
eprintln!("Created config directory: {:?}", args.config_dir);
}
// Create watcher and manager
let mut watcher = ConfigWatcher::new(&args.config_dir, 64)?;
let mut manager = BackendManager::new(&args.config_dir);
// Apply sandbox unless disabled
let config_path = if args.no_sandbox {
eprintln!("Sandboxing disabled via --no-sandbox");
args.config_dir.clone()
} else {
match apply_sandbox(&args.config_dir) {
Ok(SandboxResult::Parent(exit_code)) => {
// We are the wrapper parent - propagate child's exit code
std::process::exit(exit_code);
}
Ok(SandboxResult::Sandboxed(path)) => {
eprintln!("Sandbox applied, config at {:?}", path);
path
}
Err(e) => {
eprintln!("Failed to apply sandbox: {}", e);
return Err(e.into());
}
}
};
// Apply seccomp filter
if args.seccomp_mode != SeccompMode::Disabled {
apply_main_seccomp(args.seccomp_mode)?;
eprintln!("Seccomp applied (mode: {:?})", args.seccomp_mode);
} else {
eprintln!("Seccomp disabled via --seccomp-mode=disabled");
}
Ok(config_path)
}
fn main() -> Result<(), Box<dyn std::error::Error>> {
let args = Args::parse();
// Apply sandbox BEFORE tokio runtime starts (unshare requires single-threaded)
let config_dir = setup_sandbox(&args)?;
// Now start tokio runtime and run async main
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?
.block_on(async_main(args, config_dir))
}
async fn async_main(args: Args, config_dir: PathBuf) -> Result<(), Box<dyn std::error::Error>> {
init_logging(args.log_level);
tracing::info!(config_dir = ?config_dir, "Starting vm-switch");
// Create watcher and manager using sandboxed path
let mut watcher = ConfigWatcher::new(&config_dir, 64)?;
let (mut manager, mut child_rx) = BackendManager::new(&config_dir, args.seccomp_mode);
// Create SIGCHLD signal stream for child process monitoring
let mut sigchld = unix_signal(SignalKind::child())
.map_err(|e| format!("failed to create SIGCHLD signal stream: {}", e))?;
// Create SIGTERM signal stream for graceful shutdown.
// This is critical: this process is PID 1 in a PID namespace, and for
// PID 1 the kernel delivers only signals for which a handler is
// registered. Without this, SIGTERM is silently dropped and systemd
// has to SIGKILL after the stop timeout.
let mut sigterm = unix_signal(SignalKind::terminate())
.map_err(|e| format!("failed to create SIGTERM signal stream: {}", e))?;
// Get receiver for events (includes initial scan)
let mut rx = watcher.take_receiver();
tracing::info!("Processing configuration events (Ctrl+C to stop)...");
// Heartbeat: ping workers every second, check for responses after 100ms
let mut ping_interval = tokio::time::interval(Duration::from_secs(1));
let ping_timeout = tokio::time::sleep(Duration::from_secs(86400));
tokio::pin!(ping_timeout);
// Process events until shutdown signal
loop {
tokio::select! {
@@ -43,8 +112,26 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
}
}
Some(msg) = child_rx.recv() => {
manager.handle_child_message(msg);
}
_ = sigchld.recv() => {
manager.reap_children();
}
_ = ping_interval.tick() => {
manager.send_pings();
ping_timeout.as_mut().reset(tokio::time::Instant::now() + Duration::from_millis(100));
}
_ = &mut ping_timeout => {
manager.check_ping_timeouts();
ping_timeout.as_mut().reset(tokio::time::Instant::now() + Duration::from_secs(86400));
}
_ = sigterm.recv() => {
tracing::info!("Received SIGTERM, shutting down");
break;
}
_ = signal::ctrl_c() => {
tracing::info!("Received shutdown signal");
tracing::info!("Received SIGINT, shutting down");
break;
}
}

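The two-phase heartbeat in the `select!` loop above (pings every second, a one-shot timeout check 100 ms later, deadline parked a day out while no ping is outstanding) can be sketched with simulated time. This is an illustrative std-only model, not the tokio code itself; the timings mirror the loop, the function name is made up:

```rust
use std::time::Duration;

const FAR_FUTURE: Duration = Duration::from_secs(86_400);

/// Simulated-time model of the heartbeat timer arms:
/// returns (pings sent, timeout checks run) over `horizon`.
fn simulate(horizon: Duration) -> (u32, u32) {
    let tick = Duration::from_millis(50); // simulated clock granularity
    let mut now = Duration::ZERO;
    let mut next_ping = Duration::ZERO;
    let mut pong_deadline = FAR_FUTURE; // parked while no ping is outstanding
    let (mut pings, mut checks) = (0, 0);
    while now < horizon {
        if now >= next_ping {
            pings += 1; // manager.send_pings()
            pong_deadline = now + Duration::from_millis(100); // arm one-shot check
            next_ping = now + Duration::from_secs(1);
        }
        if now >= pong_deadline {
            checks += 1; // manager.check_ping_timeouts()
            pong_deadline = now + FAR_FUTURE; // park the deadline again
        }
        now += tick;
    }
    (pings, checks)
}

fn main() {
    // Over three simulated seconds: three pings, each followed by one check.
    assert_eq!(simulate(Duration::from_secs(3)), (3, 3));
    println!("heartbeat schedule ok");
}
```

Parking the deadline far in the future keeps the `&mut ping_timeout` arm quiescent between pings, so the select loop only wakes for the 100 ms check when one is actually armed.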
File diff suppressed because it is too large

vm-switch/src/ring.rs (new file, 865 lines)

@@ -0,0 +1,865 @@
//! SPSC ring buffer for cross-process frame passing.
use std::os::fd::{AsRawFd, FromRawFd, OwnedFd};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicU64, Ordering};
use nix::libc;
use crate::frame::MAX_FRAME_SIZE;
/// Errors that can occur with ring buffer operations.
#[derive(Debug, thiserror::Error)]
pub enum RingError {
#[error("failed to create memfd: {0}")]
MemfdCreate(std::io::Error),
#[error("failed to set memfd size: {0}")]
Ftruncate(std::io::Error),
#[error("failed to mmap: {0}")]
Mmap(std::io::Error),
#[error("failed to create eventfd: {0}")]
EventfdCreate(std::io::Error),
}
/// Slot data size - accommodates jumbo frames with headroom.
pub const SLOT_DATA_SIZE: usize = MAX_FRAME_SIZE + 256;
/// Number of slots in the ring buffer.
pub const RING_SIZE: usize = 64;
/// Total size of the ring buffer in bytes.
pub const RING_BUFFER_SIZE: usize = std::mem::size_of::<RingHeader>()
+ RING_SIZE * std::mem::size_of::<Slot>();
/// Ring buffer header containing head and tail indices.
/// Padded to cache line size to prevent false sharing.
#[repr(C, align(64))]
pub struct RingHeader {
/// Next write position (only producer modifies).
head: AtomicU64,
/// Next read position (only consumer modifies).
tail: AtomicU64,
}
impl RingHeader {
/// Create a new header with head and tail at 0.
pub fn new() -> Self {
Self {
head: AtomicU64::new(0),
tail: AtomicU64::new(0),
}
}
/// Load the current head value.
pub fn load_head(&self, order: Ordering) -> u64 {
self.head.load(order)
}
/// Load the current tail value.
pub fn load_tail(&self, order: Ordering) -> u64 {
self.tail.load(order)
}
/// Store a new head value.
pub fn store_head(&self, val: u64, order: Ordering) {
self.head.store(val, order)
}
/// Store a new tail value.
pub fn store_tail(&self, val: u64, order: Ordering) {
self.tail.store(val, order)
}
}
impl Default for RingHeader {
fn default() -> Self {
Self::new()
}
}
/// A single slot in the ring buffer.
/// Contains frame length and data.
#[repr(C)]
pub struct Slot {
/// Frame length in bytes. 0 means slot is empty/unused.
len: u32,
/// Padding for alignment.
_padding: u32,
/// Frame data buffer.
data: [u8; SLOT_DATA_SIZE],
}
impl Slot {
/// Create a new empty slot.
pub fn new() -> Self {
Self {
len: 0,
_padding: 0,
data: [0; SLOT_DATA_SIZE],
}
}
/// Write frame data to this slot.
/// Returns false if frame is too large.
pub fn write(&mut self, frame: &[u8]) -> bool {
if frame.len() > SLOT_DATA_SIZE {
return false;
}
self.data[..frame.len()].copy_from_slice(frame);
self.len = frame.len() as u32;
true
}
/// Read frame data from this slot.
/// Returns None if slot is empty.
pub fn read(&self) -> Option<&[u8]> {
if self.len == 0 {
return None;
}
Some(&self.data[..self.len as usize])
}
/// Clear the slot.
pub fn clear(&mut self) {
self.len = 0;
}
/// Returns true if slot is empty.
pub fn is_empty(&self) -> bool {
self.len == 0
}
}
impl Default for Slot {
fn default() -> Self {
Self::new()
}
}
/// The in-memory layout of a ring buffer.
/// This struct is mapped directly over shared memory.
#[repr(C)]
pub struct RingBuffer {
/// Header with head/tail atomics.
header: RingHeader,
/// Fixed-size array of slots.
slots: [Slot; RING_SIZE],
}
impl RingBuffer {
/// Returns the number of items currently in the buffer.
pub fn len(&self) -> usize {
let head = self.header.load_head(Ordering::Relaxed);
let tail = self.header.load_tail(Ordering::Relaxed);
((head + RING_SIZE as u64 - tail) % RING_SIZE as u64) as usize
}
/// Returns true if the buffer is empty.
pub fn is_empty(&self) -> bool {
self.header.load_head(Ordering::Relaxed) == self.header.load_tail(Ordering::Relaxed)
}
/// Returns true if the buffer is full.
pub fn is_full(&self) -> bool {
let head = self.header.load_head(Ordering::Relaxed);
let tail = self.header.load_tail(Ordering::Relaxed);
(head + 1) % RING_SIZE as u64 == tail
}
/// Get mutable reference to slot at index.
pub fn slot_mut(&mut self, index: usize) -> &mut Slot {
&mut self.slots[index]
}
/// Get reference to slot at index.
pub fn slot(&self, index: usize) -> &Slot {
&self.slots[index]
}
/// Get reference to the header.
pub fn header(&self) -> &RingHeader {
&self.header
}
}
/// Producer side of an SPSC ring buffer.
/// Can either create its own shared memory (via `new()`) or map memory from a remote consumer (via `from_fds()`).
pub struct Producer {
/// Pointer to the mapped ring buffer.
ring: NonNull<RingBuffer>,
/// The memfd backing the ring buffer.
memfd: OwnedFd,
/// Eventfd for signaling consumer.
eventfd: OwnedFd,
}
// SAFETY: The ring buffer uses proper atomic operations for cross-thread/process access.
// The producer has exclusive write access in SPSC pattern.
unsafe impl Send for Producer {}
unsafe impl Sync for Producer {}
impl Producer {
/// Create a new producer with its own shared memory region.
pub fn new() -> Result<Self, RingError> {
// Create memfd for shared memory
let memfd = unsafe {
let fd = libc::memfd_create(c"ring_buffer".as_ptr(), libc::MFD_CLOEXEC);
if fd < 0 {
return Err(RingError::MemfdCreate(std::io::Error::last_os_error()));
}
OwnedFd::from_raw_fd(fd)
};
// Set size
let ret = unsafe { libc::ftruncate(memfd.as_raw_fd(), RING_BUFFER_SIZE as libc::off_t) };
if ret < 0 {
return Err(RingError::Ftruncate(std::io::Error::last_os_error()));
}
// Map the memory
let ptr = unsafe {
libc::mmap(
std::ptr::null_mut(),
RING_BUFFER_SIZE,
libc::PROT_READ | libc::PROT_WRITE,
libc::MAP_SHARED,
memfd.as_raw_fd(),
0,
)
};
if ptr == libc::MAP_FAILED {
return Err(RingError::Mmap(std::io::Error::last_os_error()));
}
// Initialize the ring buffer
let ring = ptr as *mut RingBuffer;
unsafe {
std::ptr::addr_of_mut!((*ring).header).write(RingHeader::new());
for i in 0..RING_SIZE {
std::ptr::addr_of_mut!((*ring).slots[i]).write(Slot::new());
}
}
// Create eventfd for signaling
let eventfd = unsafe {
let fd = libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK);
if fd < 0 {
libc::munmap(ptr, RING_BUFFER_SIZE);
return Err(RingError::EventfdCreate(std::io::Error::last_os_error()));
}
OwnedFd::from_raw_fd(fd)
};
Ok(Self {
ring: NonNull::new(ring).unwrap(),
memfd,
eventfd,
})
}
/// Get the memfd for sharing with consumer.
pub fn memfd(&self) -> &OwnedFd {
&self.memfd
}
/// Get the eventfd for sharing with consumer.
pub fn eventfd(&self) -> &OwnedFd {
&self.eventfd
}
/// Create a producer by mapping an existing consumer's shared memory.
/// Use this when you receive FDs from a remote consumer and want to produce into their buffer.
pub fn from_fds(memfd: OwnedFd, eventfd: OwnedFd) -> Result<Self, RingError> {
// Map the memory
let ptr = unsafe {
libc::mmap(
std::ptr::null_mut(),
RING_BUFFER_SIZE,
libc::PROT_READ | libc::PROT_WRITE,
libc::MAP_SHARED,
memfd.as_raw_fd(),
0,
)
};
if ptr == libc::MAP_FAILED {
return Err(RingError::Mmap(std::io::Error::last_os_error()));
}
let ring = ptr as *mut RingBuffer;
Ok(Self {
ring: NonNull::new(ring).unwrap(),
memfd,
eventfd,
})
}
/// Get reference to the ring buffer.
fn ring(&self) -> &RingBuffer {
// SAFETY: Pointer is valid for lifetime of Producer.
unsafe { self.ring.as_ref() }
}
/// Write a frame to the slot at the given index using the raw pointer.
///
/// SAFETY: Producer has exclusive write access to slots in the SPSC pattern.
/// Only the producer calls this, and only for the slot at `head`.
/// Uses raw pointer writes to avoid creating `&mut RingBuffer` which would
/// violate Rust's aliasing rules (since `&RingBuffer` references coexist).
unsafe fn write_slot(&self, index: usize, frame: &[u8]) -> bool {
if frame.len() > SLOT_DATA_SIZE {
return false;
}
let ring_ptr = self.ring.as_ptr();
let slot_ptr = std::ptr::addr_of_mut!((*ring_ptr).slots[index]);
let data_ptr = std::ptr::addr_of_mut!((*slot_ptr).data) as *mut u8;
std::ptr::copy_nonoverlapping(frame.as_ptr(), data_ptr, frame.len());
std::ptr::addr_of_mut!((*slot_ptr).len).write(frame.len() as u32);
true
}
/// Push a frame into the ring buffer.
/// Returns true if successful, false if buffer is full (frame dropped).
pub fn push(&self, frame: &[u8]) -> bool {
let ring = self.ring();
let head = ring.header().load_head(Ordering::Relaxed);
let tail = ring.header().load_tail(Ordering::Acquire);
// Check if full (one slot always empty to distinguish full from empty)
let next_head = (head + 1) % RING_SIZE as u64;
if next_head == tail {
return false;
}
// Write to slot via raw pointer (avoids creating &mut RingBuffer alias)
// SAFETY: Producer has exclusive write access to the slot at head
if !unsafe { self.write_slot(head as usize, frame) } {
return false;
}
// Memory fence to ensure data is visible before advancing head
std::sync::atomic::fence(Ordering::Release);
// Advance head
ring.header().store_head(next_head, Ordering::Relaxed);
// Signal consumer if buffer was empty
if head == tail {
let val: u64 = 1;
unsafe {
libc::write(
self.eventfd.as_raw_fd(),
&val as *const u64 as *const libc::c_void,
std::mem::size_of::<u64>(),
);
}
}
true
}
}
impl Drop for Producer {
fn drop(&mut self) {
// SAFETY: We own the mapping and it's the correct size.
unsafe {
libc::munmap(self.ring.as_ptr() as *mut libc::c_void, RING_BUFFER_SIZE);
}
}
}
/// Consumer side of an SPSC ring buffer.
/// Can either create its own shared memory (via `new()`) or map memory from a producer (via `from_fds()`).
pub struct Consumer {
/// Pointer to the mapped ring buffer.
ring: NonNull<RingBuffer>,
/// The memfd backing the ring buffer.
memfd: OwnedFd,
/// Eventfd for receiving signals from producer.
eventfd: OwnedFd,
}
// SAFETY: The ring buffer uses proper atomic operations for cross-thread/process access.
unsafe impl Send for Consumer {}
impl Consumer {
/// Create a new consumer with its own shared memory region.
/// Use this when YOU will be the consumer and share FDs with a remote producer.
pub fn new() -> Result<Self, RingError> {
// Create memfd for shared memory
let memfd = unsafe {
let fd = libc::memfd_create(c"ring_buffer".as_ptr(), libc::MFD_CLOEXEC);
if fd < 0 {
return Err(RingError::MemfdCreate(std::io::Error::last_os_error()));
}
OwnedFd::from_raw_fd(fd)
};
// Set size
let ret = unsafe { libc::ftruncate(memfd.as_raw_fd(), RING_BUFFER_SIZE as libc::off_t) };
if ret < 0 {
return Err(RingError::Ftruncate(std::io::Error::last_os_error()));
}
// Map the memory
let ptr = unsafe {
libc::mmap(
std::ptr::null_mut(),
RING_BUFFER_SIZE,
libc::PROT_READ | libc::PROT_WRITE,
libc::MAP_SHARED,
memfd.as_raw_fd(),
0,
)
};
if ptr == libc::MAP_FAILED {
return Err(RingError::Mmap(std::io::Error::last_os_error()));
}
// Initialize the ring buffer
let ring = ptr as *mut RingBuffer;
unsafe {
std::ptr::addr_of_mut!((*ring).header).write(RingHeader::new());
for i in 0..RING_SIZE {
std::ptr::addr_of_mut!((*ring).slots[i]).write(Slot::new());
}
}
// Create eventfd for signaling
let eventfd = unsafe {
let fd = libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK);
if fd < 0 {
libc::munmap(ptr, RING_BUFFER_SIZE);
return Err(RingError::EventfdCreate(std::io::Error::last_os_error()));
}
OwnedFd::from_raw_fd(fd)
};
Ok(Self {
ring: NonNull::new(ring).unwrap(),
memfd,
eventfd,
})
}
/// Get reference to the ring buffer.
fn ring(&self) -> &RingBuffer {
// SAFETY: Pointer is valid for lifetime of Consumer.
unsafe { self.ring.as_ref() }
}
/// Create a consumer by mapping an existing producer's shared memory.
pub fn from_fds(memfd: OwnedFd, eventfd: OwnedFd) -> Result<Self, RingError> {
// Map the memory
let ptr = unsafe {
libc::mmap(
std::ptr::null_mut(),
RING_BUFFER_SIZE,
libc::PROT_READ | libc::PROT_WRITE,
libc::MAP_SHARED,
memfd.as_raw_fd(),
0,
)
};
if ptr == libc::MAP_FAILED {
return Err(RingError::Mmap(std::io::Error::last_os_error()));
}
let ring = ptr as *mut RingBuffer;
Ok(Self {
ring: NonNull::new(ring).unwrap(),
memfd,
eventfd,
})
}
/// Pop a frame from the ring buffer.
/// Returns None if buffer is empty.
pub fn pop(&self) -> Option<Vec<u8>> {
let ring = self.ring();
let tail = ring.header().load_tail(Ordering::Relaxed);
let head = ring.header().load_head(Ordering::Acquire);
// Check if empty
if head == tail {
return None;
}
// Read from slot
let data = ring.slot(tail as usize).read()?.to_vec();
// Advance tail
let next_tail = (tail + 1) % RING_SIZE as u64;
ring.header().store_tail(next_tail, Ordering::Release);
Some(data)
}
/// Get the memfd for sharing with producer.
pub fn memfd(&self) -> &OwnedFd {
&self.memfd
}
/// Get the eventfd for polling.
pub fn eventfd(&self) -> &OwnedFd {
&self.eventfd
}
/// Drain the eventfd, returning the notification count.
/// Returns 0 if no notifications pending.
pub fn drain_eventfd(&self) -> u64 {
let mut val: u64 = 0;
let ret = unsafe {
libc::read(
self.eventfd.as_raw_fd(),
&mut val as *mut u64 as *mut libc::c_void,
std::mem::size_of::<u64>(),
)
};
if ret == std::mem::size_of::<u64>() as isize {
val
} else {
0
}
}
}
impl Drop for Consumer {
fn drop(&mut self) {
// SAFETY: We own the mapping and it's the correct size.
unsafe {
libc::munmap(self.ring.as_ptr() as *mut libc::c_void, RING_BUFFER_SIZE);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn slot_write_and_read() {
let mut slot = Slot::new();
let frame = [1u8, 2, 3, 4, 5];
assert!(slot.write(&frame));
let read_data = slot.read().expect("slot should have data");
assert_eq!(read_data, &frame);
}
#[test]
fn slot_clear_resets_to_empty() {
let mut slot = Slot::new();
slot.write(&[1, 2, 3]);
slot.clear();
assert!(slot.is_empty());
assert!(slot.read().is_none());
}
#[test]
fn slot_rejects_oversized_frame() {
let mut slot = Slot::new();
let oversized = vec![0u8; SLOT_DATA_SIZE + 1];
assert!(!slot.write(&oversized));
assert!(slot.is_empty());
}
#[test]
fn header_stores_and_loads_head() {
let header = RingHeader::new();
header.store_head(42, Ordering::Relaxed);
assert_eq!(header.load_head(Ordering::Relaxed), 42);
}
#[test]
fn header_stores_and_loads_tail() {
let header = RingHeader::new();
header.store_tail(99, Ordering::Relaxed);
assert_eq!(header.load_tail(Ordering::Relaxed), 99);
}
#[test]
fn ring_buffer_len_reflects_items() {
// Exercise len()/is_empty() via direct header manipulation on a manually initialized buffer
let mut buffer = std::mem::MaybeUninit::<RingBuffer>::uninit();
// SAFETY: We're testing the layout, initializing header manually
let buffer = unsafe {
let ptr = buffer.as_mut_ptr();
std::ptr::addr_of_mut!((*ptr).header).write(RingHeader::new());
// Initialize a few slots
for i in 0..RING_SIZE {
std::ptr::addr_of_mut!((*ptr).slots[i]).write(Slot::new());
}
buffer.assume_init_mut()
};
// Initially empty
assert_eq!(buffer.len(), 0);
assert!(buffer.is_empty());
// Simulate adding 3 items by advancing head
buffer.header().store_head(3, Ordering::Relaxed);
assert_eq!(buffer.len(), 3);
assert!(!buffer.is_empty());
}
#[test]
fn ring_buffer_is_full_when_one_slot_remains() {
let mut buffer = std::mem::MaybeUninit::<RingBuffer>::uninit();
let buffer = unsafe {
let ptr = buffer.as_mut_ptr();
std::ptr::addr_of_mut!((*ptr).header).write(RingHeader::new());
for i in 0..RING_SIZE {
std::ptr::addr_of_mut!((*ptr).slots[i]).write(Slot::new());
}
buffer.assume_init_mut()
};
// Head at RING_SIZE-1, tail at 0 means buffer is full
// (we keep one slot empty to distinguish full from empty)
buffer.header().store_head((RING_SIZE - 1) as u64, Ordering::Relaxed);
buffer.header().store_tail(0, Ordering::Relaxed);
assert!(buffer.is_full());
assert_eq!(buffer.len(), RING_SIZE - 1);
}
#[test]
fn producer_new_creates_valid_buffer() {
let producer = Producer::new().expect("should create producer");
// memfd and eventfd should be valid file descriptors
assert!(producer.memfd().as_raw_fd() >= 0);
assert!(producer.eventfd().as_raw_fd() >= 0);
}
#[test]
fn consumer_maps_producer_memory() {
let producer = Producer::new().expect("should create producer");
// Duplicate FDs (simulating what would happen with SCM_RIGHTS)
let memfd_dup = unsafe {
let fd = libc::dup(producer.memfd().as_raw_fd());
assert!(fd >= 0);
OwnedFd::from_raw_fd(fd)
};
let eventfd_dup = unsafe {
let fd = libc::dup(producer.eventfd().as_raw_fd());
assert!(fd >= 0);
OwnedFd::from_raw_fd(fd)
};
let consumer = Consumer::from_fds(memfd_dup, eventfd_dup).expect("should create consumer");
assert!(consumer.eventfd().as_raw_fd() >= 0);
}
#[test]
fn producer_push_adds_frame() {
let producer = Producer::new().expect("should create producer");
let frame = [1u8, 2, 3, 4, 5];
let result = producer.push(&frame);
assert!(result);
// Verify head advanced
assert_eq!(producer.ring().header().load_head(Ordering::Relaxed), 1);
}
#[test]
fn consumer_pop_returns_pushed_frame() {
let producer = Producer::new().expect("should create producer");
let frame = [1u8, 2, 3, 4, 5];
producer.push(&frame);
// Create consumer with duplicated FDs
let memfd_dup = unsafe {
OwnedFd::from_raw_fd(libc::dup(producer.memfd().as_raw_fd()))
};
let eventfd_dup = unsafe {
OwnedFd::from_raw_fd(libc::dup(producer.eventfd().as_raw_fd()))
};
let consumer = Consumer::from_fds(memfd_dup, eventfd_dup).expect("should create consumer");
let popped = consumer.pop();
assert!(popped.is_some());
assert_eq!(popped.unwrap(), frame);
}
#[test]
fn ring_buffer_maintains_fifo_order() {
let producer = Producer::new().expect("should create producer");
producer.push(&[1]);
producer.push(&[2]);
producer.push(&[3]);
let memfd_dup = unsafe {
OwnedFd::from_raw_fd(libc::dup(producer.memfd().as_raw_fd()))
};
let eventfd_dup = unsafe {
OwnedFd::from_raw_fd(libc::dup(producer.eventfd().as_raw_fd()))
};
let consumer = Consumer::from_fds(memfd_dup, eventfd_dup).expect("should create consumer");
assert_eq!(consumer.pop().unwrap(), vec![1]);
assert_eq!(consumer.pop().unwrap(), vec![2]);
assert_eq!(consumer.pop().unwrap(), vec![3]);
assert!(consumer.pop().is_none());
}
#[test]
fn producer_drops_frame_when_full() {
let producer = Producer::new().expect("should create producer");
// Fill the buffer (RING_SIZE - 1 slots usable)
for i in 0..(RING_SIZE - 1) {
assert!(producer.push(&[i as u8]), "push {} should succeed", i);
}
// This should fail - buffer full
assert!(!producer.push(&[255]));
// Create consumer and verify we can pop all items
let memfd_dup = unsafe {
OwnedFd::from_raw_fd(libc::dup(producer.memfd().as_raw_fd()))
};
let eventfd_dup = unsafe {
OwnedFd::from_raw_fd(libc::dup(producer.eventfd().as_raw_fd()))
};
let consumer = Consumer::from_fds(memfd_dup, eventfd_dup).expect("should create consumer");
for i in 0..(RING_SIZE - 1) {
let data = consumer.pop().expect("should have data");
assert_eq!(data, vec![i as u8]);
}
assert!(consumer.pop().is_none());
}
#[test]
fn ring_buffer_wraps_around_correctly() {
let producer = Producer::new().expect("should create producer");
let memfd_dup = unsafe {
OwnedFd::from_raw_fd(libc::dup(producer.memfd().as_raw_fd()))
};
let eventfd_dup = unsafe {
OwnedFd::from_raw_fd(libc::dup(producer.eventfd().as_raw_fd()))
};
let consumer = Consumer::from_fds(memfd_dup, eventfd_dup).expect("should create consumer");
// Push and pop more items than RING_SIZE to test wraparound
for round in 0..3 {
for i in 0..(RING_SIZE - 1) {
let val = (round * RING_SIZE + i) as u8;
assert!(producer.push(&[val]));
}
for i in 0..(RING_SIZE - 1) {
let val = (round * RING_SIZE + i) as u8;
assert_eq!(consumer.pop().unwrap(), vec![val]);
}
}
}
#[test]
fn consumer_drain_eventfd_clears_notifications() {
let producer = Producer::new().expect("should create producer");
let memfd_dup = unsafe {
OwnedFd::from_raw_fd(libc::dup(producer.memfd().as_raw_fd()))
};
let eventfd_dup = unsafe {
OwnedFd::from_raw_fd(libc::dup(producer.eventfd().as_raw_fd()))
};
let consumer = Consumer::from_fds(memfd_dup, eventfd_dup).expect("should create consumer");
// Push triggers eventfd write
producer.push(&[1]);
producer.push(&[2]);
// Drain should return the accumulated count
let count = consumer.drain_eventfd();
assert!(count > 0);
// Second drain should return 0 (nothing new)
let count2 = consumer.drain_eventfd();
assert_eq!(count2, 0);
}
#[test]
fn consumer_new_creates_valid_buffer() {
let consumer = Consumer::new().expect("should create consumer");
// memfd and eventfd should be valid file descriptors
assert!(consumer.memfd().as_raw_fd() >= 0);
assert!(consumer.eventfd().as_raw_fd() >= 0);
}
#[test]
fn producer_from_fds_maps_consumer_memory() {
// Consumer creates buffer, producer connects to it
let consumer = Consumer::new().expect("should create consumer");
// Duplicate FDs (simulating what would happen with SCM_RIGHTS)
let memfd_dup = unsafe {
let fd = libc::dup(consumer.memfd().as_raw_fd());
assert!(fd >= 0);
OwnedFd::from_raw_fd(fd)
};
let eventfd_dup = unsafe {
let fd = libc::dup(consumer.eventfd().as_raw_fd());
assert!(fd >= 0);
OwnedFd::from_raw_fd(fd)
};
let producer = Producer::from_fds(memfd_dup, eventfd_dup).expect("should create producer");
assert!(producer.eventfd().as_raw_fd() >= 0);
}
#[test]
fn consumer_created_buffer_receives_from_producer() {
// Consumer creates buffer, shares with remote producer
let consumer = Consumer::new().expect("should create consumer");
// Create producer from consumer's FDs
let memfd_dup = unsafe {
OwnedFd::from_raw_fd(libc::dup(consumer.memfd().as_raw_fd()))
};
let eventfd_dup = unsafe {
OwnedFd::from_raw_fd(libc::dup(consumer.eventfd().as_raw_fd()))
};
let producer = Producer::from_fds(memfd_dup, eventfd_dup).expect("should create producer");
// Producer pushes, consumer receives
let frame = [1u8, 2, 3, 4, 5];
assert!(producer.push(&frame));
let popped = consumer.pop();
assert!(popped.is_some());
assert_eq!(popped.unwrap(), frame);
}
#[test]
fn consumer_created_buffer_maintains_fifo_order() {
let consumer = Consumer::new().expect("should create consumer");
let memfd_dup = unsafe {
OwnedFd::from_raw_fd(libc::dup(consumer.memfd().as_raw_fd()))
};
let eventfd_dup = unsafe {
OwnedFd::from_raw_fd(libc::dup(consumer.eventfd().as_raw_fd()))
};
let producer = Producer::from_fds(memfd_dup, eventfd_dup).expect("should create producer");
producer.push(&[1]);
producer.push(&[2]);
producer.push(&[3]);
assert_eq!(consumer.pop().unwrap(), vec![1]);
assert_eq!(consumer.pop().unwrap(), vec![2]);
assert_eq!(consumer.pop().unwrap(), vec![3]);
assert!(consumer.pop().is_none());
}
}
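The head/tail protocol above works identically between threads over plain process memory, which makes the memory-ordering discipline easy to exercise in isolation. A minimal std-only sketch (illustrative names, `u64` payloads instead of frame slots) showing the `Release` fence before the head store pairing with the `Acquire` load in `pop`:

```rust
use std::cell::UnsafeCell;
use std::sync::atomic::{fence, AtomicU64, Ordering};
use std::sync::Arc;
use std::thread;

const N: u64 = 8; // capacity; one slot always stays empty

struct MiniRing {
    head: AtomicU64, // next write position (producer only)
    tail: AtomicU64, // next read position (consumer only)
    slots: [UnsafeCell<u64>; N as usize],
}

// SAFETY: SPSC discipline - the producer writes only the slot at head,
// the consumer reads only the slot at tail, coordinated by the atomics.
unsafe impl Sync for MiniRing {}

impl MiniRing {
    fn new() -> Self {
        Self {
            head: AtomicU64::new(0),
            tail: AtomicU64::new(0),
            slots: std::array::from_fn(|_| UnsafeCell::new(0)),
        }
    }

    fn push(&self, val: u64) -> bool {
        let head = self.head.load(Ordering::Relaxed);
        let tail = self.tail.load(Ordering::Acquire);
        if (head + 1) % N == tail {
            return false; // full: one slot kept empty to distinguish from empty
        }
        unsafe { *self.slots[head as usize].get() = val }
        fence(Ordering::Release); // publish slot data before advancing head
        self.head.store((head + 1) % N, Ordering::Relaxed);
        true
    }

    fn pop(&self) -> Option<u64> {
        let tail = self.tail.load(Ordering::Relaxed);
        let head = self.head.load(Ordering::Acquire);
        if head == tail {
            return None; // empty
        }
        let val = unsafe { *self.slots[tail as usize].get() };
        self.tail.store((tail + 1) % N, Ordering::Release);
        Some(val)
    }
}

fn main() {
    let ring = Arc::new(MiniRing::new());
    let r = Arc::clone(&ring);
    let producer = thread::spawn(move || {
        for i in 0..1_000u64 {
            while !r.push(i) {
                thread::yield_now(); // spin when full
            }
        }
    });
    let mut expected = 0u64;
    while expected < 1_000 {
        if let Some(v) = ring.pop() {
            assert_eq!(v, expected); // FIFO order preserved
            expected += 1;
        } else {
            thread::yield_now();
        }
    }
    producer.join().unwrap();
    println!("transferred 1000 values in order");
}
```

The cross-process version adds only the memfd/mmap plumbing and the eventfd wakeup; the synchronization protocol is the same.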

vm-switch/src/sandbox.rs (new file, 564 lines)

@@ -0,0 +1,564 @@
//! Namespace sandboxing for process isolation.
use nix::mount::{mount, MsFlags};
use nix::sched::{unshare, CloneFlags};
use nix::sys::signal::{self, SaFlags, SigAction, SigHandler, SigSet, Signal};
use nix::sys::wait::{waitpid, WaitStatus};
use nix::unistd::{chdir, fork, pivot_root, ForkResult, Gid, Pid, Uid};
use std::fs;
use std::path::Path;
use std::sync::atomic::{AtomicI32, Ordering};
use thiserror::Error;
/// Errors that can occur during sandbox setup.
#[derive(Debug, Error)]
pub enum SandboxError {
#[error("failed to unshare namespace: {0}")]
Unshare(#[source] nix::Error),
#[error("fork failed: {0}")]
Fork(#[source] nix::Error),
#[error("failed to write {path}: {source}")]
WriteFile {
path: String,
#[source]
source: std::io::Error,
},
#[error("failed to {operation} {target}: {source}")]
Mount {
operation: String,
target: String,
#[source]
source: nix::Error,
},
#[error("failed to create directory {path}: {source}")]
Mkdir {
path: String,
#[source]
source: std::io::Error,
},
#[error("failed to pivot_root: {0}")]
PivotRoot(#[source] nix::Error),
#[error("failed to change directory to {path}: {source}")]
Chdir {
path: String,
#[source]
source: nix::Error,
},
}
/// Result of applying sandbox.
#[derive(Debug)]
pub enum SandboxResult {
/// We are the wrapper parent. Contains child's exit code.
Parent(i32),
/// Sandbox applied successfully. Contains new config path.
Sandboxed(std::path::PathBuf),
}
/// Generate the content for /proc/self/uid_map.
///
/// Maps the given outside UID to UID 0 inside the namespace.
fn generate_uid_map(outside_uid: u32) -> String {
format!("0 {} 1\n", outside_uid)
}
/// Generate the content for /proc/self/gid_map.
///
/// Maps the given outside GID to GID 0 inside the namespace.
fn generate_gid_map(outside_gid: u32) -> String {
format!("0 {} 1\n", outside_gid)
}
/// Write uid_map and gid_map files for the current process.
///
/// Must be called after unshare(CLONE_NEWUSER).
/// Writes "deny" to setgroups before gid_map (required by kernel).
fn write_uid_gid_maps(outside_uid: Uid, outside_gid: Gid) -> Result<(), SandboxError> {
// Must deny setgroups before writing gid_map (kernel requirement)
fs::write("/proc/self/setgroups", "deny").map_err(|e| SandboxError::WriteFile {
path: "/proc/self/setgroups".to_string(),
source: e,
})?;
// Write uid_map
let uid_content = generate_uid_map(outside_uid.as_raw());
fs::write("/proc/self/uid_map", &uid_content).map_err(|e| SandboxError::WriteFile {
path: "/proc/self/uid_map".to_string(),
source: e,
})?;
// Write gid_map
let gid_content = generate_gid_map(outside_gid.as_raw());
fs::write("/proc/self/gid_map", &gid_content).map_err(|e| SandboxError::WriteFile {
path: "/proc/self/gid_map".to_string(),
source: e,
})?;
Ok(())
}
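The map files use the kernel's three-field line format, `<start inside namespace> <start outside> <count>`. A small sketch of the single-ID mapping the helpers above emit (the helper name and the outside UID 1000 are illustrative, not from the code):

```rust
// Mirrors generate_uid_map/generate_gid_map: one line mapping a
// single ID range of length 1.
fn map_line(inside_start: u32, outside_start: u32, count: u32) -> String {
    format!("{} {} {}\n", inside_start, outside_start, count)
}

fn main() {
    // An unprivileged user (e.g. UID 1000 outside) becomes root inside.
    assert_eq!(map_line(0, 1000, 1), "0 1000 1\n");
    // Remember the required write order for the real files:
    // /proc/self/setgroups ("deny") first, then gid_map.
    println!("{}", map_line(0, 1000, 1).trim_end());
}
```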
/// Global storing the child PID for the signal forwarding handler.
static CHILD_PID: AtomicI32 = AtomicI32::new(0);
/// Signal handler that forwards SIGTERM to the child process.
extern "C" fn forward_signal(_sig: libc::c_int) {
let pid = CHILD_PID.load(Ordering::Relaxed);
if pid > 0 {
unsafe { libc::kill(pid, libc::SIGTERM) };
}
}
/// Install a SIGTERM handler on the parent that forwards to the child.
fn install_forwarding_handler(child: Pid) {
CHILD_PID.store(child.as_raw(), Ordering::Relaxed);
let action = SigAction::new(
SigHandler::Handler(forward_signal),
SaFlags::empty(),
SigSet::empty(),
);
// Safety: forward_signal is async-signal-safe (only calls kill)
unsafe { signal::sigaction(Signal::SIGTERM, &action) }.ok();
}
/// Fork into new PID and mount namespaces.
///
/// This function:
/// 1. Calls `unshare(CLONE_NEWPID | CLONE_NEWNS)` to create new namespaces
/// 2. Makes all mounts private (prevents propagation)
/// 3. Forks - the child becomes PID 1 in the new PID namespace
/// 4. Parent waits for child and returns its exit status
/// 5. Child continues execution
///
/// Returns:
/// - `Ok(Some(exit_code))` in the parent (should propagate and exit)
/// - `Ok(None)` in the child (continue with sandbox setup)
/// - `Err` on failure
///
/// Must be called AFTER entering user namespace and BEFORE
/// starting any multi-threaded runtime.
pub fn fork_into_pid_namespace() -> Result<Option<i32>, SandboxError> {
// Create new PID namespace and mount namespace together
// The mount namespace is needed so we can mount procfs for the new PID namespace
unshare(CloneFlags::CLONE_NEWPID | CloneFlags::CLONE_NEWNS).map_err(SandboxError::Unshare)?;
// Make all mounts private to prevent propagation
mount(
None::<&str>,
"/",
None::<&str>,
MsFlags::MS_REC | MsFlags::MS_PRIVATE,
None::<&str>,
)
.map_err(|e| SandboxError::Mount {
operation: "make private".to_string(),
target: "/".to_string(),
source: e,
})?;
// Fork - child becomes PID 1 in the new namespace
match unsafe { fork() }.map_err(SandboxError::Fork)? {
ForkResult::Parent { child } => {
// Install a signal handler that forwards SIGTERM to the child.
// Without this, SIGTERM kills the parent (default action) while the
// child keeps running: as PID 1 in its namespace it ignores signals
// for which no handler is registered, and nothing tells it to stop.
install_forwarding_handler(child);
// Parent: wait for child and collect exit status
loop {
match waitpid(child, None) {
Ok(WaitStatus::Exited(_, code)) => {
return Ok(Some(code));
}
Ok(WaitStatus::Signaled(_, sig, _)) => {
// Child killed by signal, propagate as exit code
return Ok(Some(128 + sig as i32));
}
Ok(_) => continue, // Stopped/continued, keep waiting
Err(nix::Error::EINTR) => continue, // Interrupted, retry
Err(e) => {
return Err(SandboxError::Fork(e));
}
}
}
}
ForkResult::Child => {
// Child: we are now PID 1 in the new namespace
Ok(None)
}
}
}
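The `128 + signal` value returned for a signaled child is the conventional shell encoding of death-by-signal as an exit code. A quick std-only illustration (spawning `sh` is an assumption about the environment, not part of the sandbox code):

```rust
use std::os::unix::process::ExitStatusExt;
use std::process::Command;

fn main() {
    // The child terminates itself with SIGTERM (15) instead of exiting.
    let status = Command::new("sh")
        .arg("-c")
        .arg("kill -TERM $$")
        .status()
        .expect("failed to spawn sh");
    let sig = status.signal().expect("child should die by signal");
    // The wrapper parent would report this as exit code 128 + 15 = 143.
    assert_eq!(128 + sig, 143);
    println!("signal {} -> exit code {}", sig, 128 + sig);
}
```

This matches what systemd and shells expect: a service killed by SIGTERM surfaces as status 143.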
/// Enter a new user namespace.
///
/// After this call:
/// - The process appears to run as UID 0 / GID 0 inside the namespace
/// - The process can create other namespaces (mount, IPC, network, PID)
/// - The process has no capabilities in the parent namespace
///
/// Must be called before any other namespace operations.
pub fn enter_user_namespace() -> Result<(), SandboxError> {
// Capture current UID/GID before unshare
let outside_uid = Uid::current();
let outside_gid = Gid::current();
// Create new user namespace
unshare(CloneFlags::CLONE_NEWUSER).map_err(SandboxError::Unshare)?;
// Set up UID/GID mappings
write_uid_gid_maps(outside_uid, outside_gid)?;
Ok(())
}
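The map written by `write_uid_gid_maps` uses the single-line `/proc/<pid>/uid_map` format `inside outside count`, mapping exactly one ID: 0 inside the namespace to the caller's ID outside it. A std-only sketch of the content (mirrors the `generate_uid_map` helper exercised by this module's tests; the file write itself is omitted):

```rust
/// Build the single-entry map content in "inside outside count" format:
/// UID/GID 0 inside the namespace maps to `outside_id` on the host.
fn map_line(outside_id: u32) -> String {
    format!("0 {} 1\n", outside_id)
}

fn main() {
    // Written to /proc/self/uid_map (and gid_map, after denying setgroups).
    assert_eq!(map_line(1000), "0 1000 1\n");
    assert_eq!(map_line(0), "0 0 1\n");
    println!("ok");
}
```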
/// Enter a new mount namespace with private mounts.
///
/// After this call:
/// - Mount changes are isolated from the parent namespace
/// - All mounts are marked as private (no propagation)
///
/// Must be called after `enter_user_namespace()`.
pub fn enter_mount_namespace() -> Result<(), SandboxError> {
// Create new mount namespace
unshare(CloneFlags::CLONE_NEWNS).map_err(SandboxError::Unshare)?;
// Make all mounts private to prevent propagation
mount(
None::<&str>,
"/",
None::<&str>,
MsFlags::MS_REC | MsFlags::MS_PRIVATE,
None::<&str>,
)
.map_err(|e| SandboxError::Mount {
operation: "make private".to_string(),
target: "/".to_string(),
source: e,
})?;
Ok(())
}
/// Create minimal /dev with essential devices.
///
/// Bind-mounts /dev/null, /dev/zero, and /dev/urandom from the host.
/// These are required for basic operation (logging, random numbers).
fn create_minimal_dev(new_root: &Path) -> Result<(), SandboxError> {
let dev_dir = new_root.join("dev");
// Device nodes to bind-mount from host
let devices = ["null", "zero", "urandom"];
for device in devices {
let target = dev_dir.join(device);
// Create empty file as mount point
fs::write(&target, b"").map_err(|e| SandboxError::WriteFile {
path: target.display().to_string(),
source: e,
})?;
// Bind-mount the device from host
let source = Path::new("/dev").join(device);
mount(
Some(&source),
&target,
None::<&str>,
MsFlags::MS_BIND,
None::<&str>,
)
.map_err(|e| SandboxError::Mount {
operation: "bind".to_string(),
target: target.display().to_string(),
source: e,
})?;
}
Ok(())
}
/// Create a minimal root filesystem on tmpfs.
///
/// Creates:
/// - tmpfs mounted at a temporary location
/// - /proc (mount point)
/// - /dev with null, zero, urandom
/// - /tmp (empty tmpfs)
/// - /config (bind-mount of config_dir)
///
/// Returns the path to the new root.
fn setup_minimal_root(config_dir: &Path) -> Result<std::path::PathBuf, SandboxError> {
// Create tmpfs for new root
let new_root = Path::new("/tmp/sandbox-root");
fs::create_dir_all(new_root).map_err(|e| SandboxError::Mkdir {
path: new_root.display().to_string(),
source: e,
})?;
// Mount tmpfs on new root
mount(
Some("tmpfs"),
new_root,
Some("tmpfs"),
MsFlags::MS_NOSUID | MsFlags::MS_NODEV,
Some("size=16M,mode=0755"),
)
.map_err(|e| SandboxError::Mount {
operation: "mount tmpfs".to_string(),
target: new_root.display().to_string(),
source: e,
})?;
// Create directory structure
let dirs = ["proc", "dev", "tmp", "config", "old_root"];
for dir in dirs {
let path = new_root.join(dir);
fs::create_dir_all(&path).map_err(|e| SandboxError::Mkdir {
path: path.display().to_string(),
source: e,
})?;
}
// Set up /dev
create_minimal_dev(new_root)?;
// Bind-mount config directory
let config_target = new_root.join("config");
mount(
Some(config_dir),
&config_target,
None::<&str>,
MsFlags::MS_BIND | MsFlags::MS_REC,
None::<&str>,
)
.map_err(|e| SandboxError::Mount {
operation: "bind config".to_string(),
target: config_target.display().to_string(),
source: e,
})?;
// Mount /proc BEFORE pivot_root: once the old root (and its visible
// procfs) is unmounted, the kernel refuses fresh procfs mounts in a
// user namespace.
// This may fail if we don't own a PID namespace (EPERM), which is fine
// for standalone filesystem isolation without apply_sandbox()
let proc_target = new_root.join("proc");
match mount(
Some("proc"),
&proc_target,
Some("proc"),
MsFlags::MS_NOSUID | MsFlags::MS_NODEV | MsFlags::MS_NOEXEC,
None::<&str>,
) {
Ok(()) => {}
Err(nix::Error::EPERM) => {
// Can't mount procfs without owning a PID namespace - this is expected
// when setup_filesystem_isolation is called standalone
}
Err(e) => {
return Err(SandboxError::Mount {
operation: "mount".to_string(),
target: proc_target.display().to_string(),
source: e,
});
}
}
Ok(new_root.to_path_buf())
}
/// Pivot root to the new filesystem and unmount the old root.
///
/// After this call, the process sees only the new root filesystem.
/// The old root is unmounted and inaccessible.
fn pivot_to_new_root(new_root: &Path) -> Result<(), SandboxError> {
// Change to new root before pivot
chdir(new_root).map_err(|e| SandboxError::Chdir {
path: new_root.display().to_string(),
source: e,
})?;
// Pivot root: new_root becomes /, old root moves to old_root
pivot_root(".", "old_root").map_err(SandboxError::PivotRoot)?;
// Change to new root
chdir("/").map_err(|e| SandboxError::Chdir {
path: "/".to_string(),
source: e,
})?;
// After pivot_root, old root is now at /old_root in the new namespace
let old_root = Path::new("/old_root");
// Unmount old root
nix::mount::umount2(old_root, nix::mount::MntFlags::MNT_DETACH).map_err(|e| {
SandboxError::Mount {
operation: "umount".to_string(),
target: old_root.display().to_string(),
source: e,
}
})?;
// Remove old_root directory
let _ = fs::remove_dir(old_root);
Ok(())
}
/// Enter a new IPC namespace.
///
/// After this call:
/// - System V IPC objects (semaphores, message queues, shared memory) are isolated
/// - The process cannot access host IPC objects
///
/// Must be called after `enter_user_namespace()`.
pub fn enter_ipc_namespace() -> Result<(), SandboxError> {
unshare(CloneFlags::CLONE_NEWIPC).map_err(SandboxError::Unshare)?;
Ok(())
}
/// Enter a new network namespace.
///
/// After this call:
/// - The process has no network interfaces (not even loopback)
/// - Network communication is only possible via inherited file descriptors
/// - Unix sockets created before entering the namespace remain usable
///
/// Must be called after `enter_user_namespace()`.
pub fn enter_network_namespace() -> Result<(), SandboxError> {
unshare(CloneFlags::CLONE_NEWNET).map_err(SandboxError::Unshare)?;
Ok(())
}
/// Apply full sandbox isolation to the current process.
///
/// This function:
/// 1. Enters user namespace (appears as root inside, enables other namespaces)
/// 2. Forks into a new PID namespace (becomes PID 1)
/// 3. Sets up filesystem isolation (minimal root with config at /config)
/// 4. Enters IPC namespace (isolates System V IPC)
/// 5. Enters network namespace (no network interfaces)
///
/// IMPORTANT: This function forks. In the parent, it returns
/// `Ok(SandboxResult::Parent(exit_code))`. The parent should propagate
/// this exit code. In the child, it returns
/// `Ok(SandboxResult::Sandboxed(config_path))`.
///
/// After sandboxing:
/// - The process is PID 1 in its own PID namespace
/// - The process appears to run as UID 0
/// - Only /config, /dev (minimal), /proc, and /tmp are accessible
/// - System V IPC is isolated from host
/// - No network interfaces (communication via inherited FDs only)
///
/// Must be called BEFORE starting any multi-threaded runtime (tokio).
pub fn apply_sandbox(config_dir: &Path) -> Result<SandboxResult, SandboxError> {
// 1. Enter user namespace first (provides CAP_SYS_ADMIN for other namespaces)
enter_user_namespace()?;
// 2. Fork into PID namespace (requires CAP_SYS_ADMIN from user namespace)
if let Some(exit_code) = fork_into_pid_namespace()? {
return Ok(SandboxResult::Parent(exit_code));
}
// 3. Set up filesystem isolation (mount ns already created by fork_into_pid_namespace)
setup_filesystem_isolation(config_dir, false)?;
// 4. Enter IPC namespace
enter_ipc_namespace()?;
// 5. Enter network namespace
enter_network_namespace()?;
// Config dir is now at /config
Ok(SandboxResult::Sandboxed(std::path::PathBuf::from("/config")))
}
/// Set up filesystem isolation with a minimal root.
///
/// This function:
/// 1. Optionally creates a mount namespace (if `needs_mount_ns` is true)
/// 2. Builds a minimal root on tmpfs
/// 3. Mounts /proc (requires owning a PID namespace; skipped if not available)
/// 4. Pivots to the new root
///
/// After this call, only /config (mapped to config_dir), /dev (minimal),
/// /proc (if available), and /tmp are accessible.
///
/// # Arguments
///
/// * `config_dir` - Host path to bind-mount as /config
/// * `needs_mount_ns` - If true, creates a new mount namespace before setup.
/// Set to false when already in a mount namespace (e.g., after
/// `fork_into_pid_namespace()` which creates one). Set to true when calling
/// standalone after `enter_user_namespace()`.
pub fn setup_filesystem_isolation(config_dir: &Path, needs_mount_ns: bool) -> Result<(), SandboxError> {
if needs_mount_ns {
enter_mount_namespace()?;
}
// Create minimal root (also attempts to mount /proc)
let new_root = setup_minimal_root(config_dir)?;
// Pivot to new root
pivot_to_new_root(&new_root)?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn generate_uid_map_maps_outside_to_zero() {
let content = generate_uid_map(1000);
// Format: inside_uid outside_uid count
assert_eq!(content, "0 1000 1\n");
}
#[test]
fn generate_uid_map_handles_root() {
let content = generate_uid_map(0);
assert_eq!(content, "0 0 1\n");
}
#[test]
fn generate_gid_map_maps_outside_to_zero() {
let content = generate_gid_map(1000);
assert_eq!(content, "0 1000 1\n");
}
#[test]
fn generate_gid_map_handles_root() {
let content = generate_gid_map(0);
assert_eq!(content, "0 0 1\n");
}
#[test]
fn sandbox_error_displays_mount_error() {
let err = SandboxError::Mount {
operation: "mount".to_string(),
target: "/proc".to_string(),
source: nix::Error::EPERM,
};
let msg = err.to_string();
assert!(msg.contains("mount"));
assert!(msg.contains("/proc"));
}
#[test]
fn sandbox_error_displays_mkdir_error() {
let err = SandboxError::Mkdir {
path: "/newroot/proc".to_string(),
source: std::io::Error::from_raw_os_error(libc::EACCES),
};
let msg = err.to_string();
assert!(msg.contains("/newroot/proc"));
}
}

vm-switch/src/seccomp.rs Normal file

@ -0,0 +1,443 @@
//! Seccomp-bpf filtering for syscall restriction.
use clap::ValueEnum;
use thiserror::Error;
/// Errors that can occur during seccomp filter setup.
#[derive(Debug, Error)]
pub enum SeccompError {
#[error("failed to compile filter: {0}")]
Compile(String),
#[error("failed to apply filter: {0}")]
Apply(#[source] std::io::Error),
#[error("seccomp is disabled")]
Disabled,
}
/// Syscalls allowed for child (worker) processes.
///
/// This is the base allowlist. The main process filter is built as a
/// superset of this list (see `MAIN_EXTRA_SYSCALLS`), which is important
/// because children inherit the main filter via fork() and seccomp
/// filters stack with AND semantics — both must allow a syscall.
///
/// This is a tight whitelist because the child's own filter is applied AFTER:
/// - Socket creation and binding (vhost-user socket ready)
/// - Thread spawning (vhost daemon thread running)
/// - Signal handler setup
///
/// Children only need syscalls for:
/// - Accepting connections on existing socket
/// - Reading/writing on existing FDs
/// - Ring buffer operations (memfd, mmap)
/// - Polling and synchronization
pub static CHILD_SYSCALLS: &[i64] = &[
// Basic I/O
libc::SYS_read,
libc::SYS_write,
libc::SYS_close,
libc::SYS_lseek,
libc::SYS_pread64,
libc::SYS_pwrite64,
libc::SYS_readv,
libc::SYS_writev,
// Memory management
libc::SYS_mmap,
libc::SYS_mprotect,
libc::SYS_munmap,
libc::SYS_brk,
libc::SYS_mremap,
libc::SYS_madvise,
// Ring buffer operations
libc::SYS_memfd_create,
libc::SYS_ftruncate,
// File operations (limited - no creation)
libc::SYS_fstat,
libc::SYS_newfstatat,
libc::SYS_fcntl,
libc::SYS_dup,
libc::SYS_dup2,
libc::SYS_dup3,
libc::SYS_unlink, // glibc may use this instead of unlinkat
libc::SYS_unlinkat,
// Process/thread control
libc::SYS_clone3, // glibc pthread_create uses clone3 (vhost-user spawns threads lazily)
libc::SYS_exit,
libc::SYS_exit_group,
libc::SYS_getpid,
libc::SYS_gettid,
libc::SYS_getuid,
libc::SYS_getgid,
libc::SYS_geteuid,
libc::SYS_getegid,
libc::SYS_sched_yield,
libc::SYS_sched_getaffinity,
libc::SYS_set_robust_list,
libc::SYS_rseq,
// Signal handling (handlers already installed)
libc::SYS_rt_sigaction,
libc::SYS_rt_sigprocmask,
libc::SYS_rt_sigreturn,
libc::SYS_sigaltstack,
// Polling
libc::SYS_epoll_create1,
libc::SYS_epoll_ctl,
libc::SYS_epoll_wait,
libc::SYS_epoll_pwait,
libc::SYS_poll,
libc::SYS_ppoll,
// Socket operations (on existing sockets only)
// NO: socket, bind, listen, connect, setsockopt, getsockopt, socketpair
libc::SYS_accept,
libc::SYS_accept4,
libc::SYS_sendto,
libc::SYS_recvfrom,
libc::SYS_sendmsg,
libc::SYS_recvmsg,
libc::SYS_shutdown,
libc::SYS_getsockname,
libc::SYS_getpeername,
// Time
libc::SYS_clock_gettime,
libc::SYS_clock_getres,
libc::SYS_nanosleep,
libc::SYS_gettimeofday,
// Thread synchronization (for existing threads)
libc::SYS_futex,
// Misc
libc::SYS_getrandom,
libc::SYS_prctl,
libc::SYS_arch_prctl,
libc::SYS_ioctl,
libc::SYS_pipe2,
libc::SYS_eventfd2,
];
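Because stacked seccomp filters combine with AND semantics, a child forked under the main filter can only use syscalls that *both* filters allow; the effective allowlist is the intersection. Keeping the main list a superset of the child list is what preserves the child filter's intent. A std-only sketch of that invariant (illustrative numbers, not real syscall IDs):

```rust
use std::collections::HashSet;

/// With stacked seccomp filters, the effective allowlist is the
/// intersection: a syscall passes only if every installed filter allows it.
fn effective(main_list: &[i64], child_list: &[i64]) -> HashSet<i64> {
    let m: HashSet<i64> = main_list.iter().copied().collect();
    child_list.iter().copied().filter(|s| m.contains(s)).collect()
}

fn main() {
    let main_filter = [0i64, 1, 3, 57, 59]; // illustrative syscall numbers
    let child_filter = [0i64, 1, 3];
    // Main is a superset of child, so the child keeps its full allowlist.
    let eff = effective(&main_filter, &child_filter);
    assert_eq!(eff.len(), 3);
    // Something only the main filter allows is still blocked for the child.
    assert!(!eff.contains(&57));
    println!("ok");
}
```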
/// Additional syscalls needed only by the main process.
///
/// The main filter is built from CHILD_SYSCALLS + these extras.
/// This ensures children always inherit a superset of what they need.
pub static MAIN_EXTRA_SYSCALLS: &[i64] = &[
// File operations (main does config watching, directory traversal)
libc::SYS_openat,
libc::SYS_getdents64,
libc::SYS_mkdirat,
libc::SYS_readlinkat,
libc::SYS_faccessat,
libc::SYS_faccessat2,
libc::SYS_statx,
// Process control (main forks children, manages lifecycle)
libc::SYS_fork,
libc::SYS_clone,
libc::SYS_clone3, // already in CHILD_SYSCALLS; harmless duplicate (rule map dedups)
libc::SYS_wait4,
libc::SYS_kill,
libc::SYS_getppid,
// Polling (extra variants)
libc::SYS_select,
libc::SYS_pselect6,
// Socket operations (main creates sockets, children only accept)
libc::SYS_socket,
libc::SYS_bind,
libc::SYS_listen,
libc::SYS_connect,
libc::SYS_getsockopt,
libc::SYS_setsockopt,
libc::SYS_socketpair,
// inotify (for config file watching)
libc::SYS_inotify_init1,
libc::SYS_inotify_add_watch,
libc::SYS_inotify_rm_watch,
// Time (extra variants)
libc::SYS_clock_nanosleep,
// Misc (main-only)
libc::SYS_uname,
libc::SYS_timerfd_create,
libc::SYS_timerfd_settime,
libc::SYS_timerfd_gettime,
// Seccomp (for applying child filters after fork)
libc::SYS_seccomp,
];
/// Seccomp filter mode.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, ValueEnum)]
pub enum SeccompMode {
/// Kill process on blocked syscall (production default).
#[default]
Kill,
/// Send SIGSYS on blocked syscall (for debugging).
Trap,
/// Log blocked syscalls but allow them.
Log,
/// Disable seccomp filtering entirely.
Disabled,
}
use seccompiler::{BpfMap, SeccompAction, SeccompFilter, TargetArch};
use std::collections::BTreeMap;
/// Build a BPF filter for the given syscall whitelist.
///
/// Returns a BpfMap ready to be applied via `seccompiler::apply_filter_all_threads`.
pub fn build_filter(syscalls: &[i64], mode: SeccompMode) -> Result<BpfMap, SeccompError> {
if mode == SeccompMode::Disabled {
return Err(SeccompError::Disabled);
}
let default_action = match mode {
SeccompMode::Kill => SeccompAction::KillProcess,
SeccompMode::Trap => SeccompAction::Trap,
SeccompMode::Log => SeccompAction::Log,
SeccompMode::Disabled => unreachable!(),
};
// Build allow rules for each syscall.
// An empty rule vector means "match unconditionally on syscall number".
let mut rules: BTreeMap<i64, Vec<_>> = BTreeMap::new();
for &syscall in syscalls {
rules.insert(syscall, vec![]);
}
// Create filter: allow whitelisted syscalls, block others.
// SeccompFilter::new(rules, mismatch_action, match_action, arch)
let filter = SeccompFilter::new(
rules,
default_action, // mismatch_action: block non-whitelisted syscalls
SeccompAction::Allow, // match_action: allow whitelisted syscalls
TargetArch::x86_64,
)
.map_err(|e| SeccompError::Compile(e.to_string()))?;
// Compile to BPF
let bpf_prog = filter
.try_into()
.map_err(|e: seccompiler::BackendError| SeccompError::Compile(e.to_string()))?;
let mut map = BpfMap::new();
map.insert("main".to_string(), bpf_prog);
Ok(map)
}
/// Apply a compiled BPF filter to all threads in the current process.
///
/// Uses `prctl(PR_SET_NO_NEW_PRIVS)` and `seccomp(SECCOMP_SET_MODE_FILTER)`.
/// Once applied, the filter cannot be removed or made less restrictive.
pub fn apply_filter(bpf_map: &BpfMap) -> Result<(), SeccompError> {
let bpf_prog = bpf_map
.get("main")
.ok_or_else(|| SeccompError::Compile("no 'main' filter in map".to_string()))?;
seccompiler::apply_filter_all_threads(bpf_prog)
.map_err(|e| SeccompError::Apply(std::io::Error::other(e.to_string())))?;
Ok(())
}
/// Collect the full main syscall list (CHILD_SYSCALLS + MAIN_EXTRA_SYSCALLS).
fn main_syscalls() -> Vec<i64> {
let mut syscalls = Vec::with_capacity(CHILD_SYSCALLS.len() + MAIN_EXTRA_SYSCALLS.len());
syscalls.extend_from_slice(CHILD_SYSCALLS);
syscalls.extend_from_slice(MAIN_EXTRA_SYSCALLS);
syscalls
}
/// Apply seccomp filter for the main process.
///
/// Call this after namespace setup, before starting the tokio runtime.
/// The filter is built from CHILD_SYSCALLS + MAIN_EXTRA_SYSCALLS, ensuring
/// children always inherit a permissive-enough base filter.
pub fn apply_main_seccomp(mode: SeccompMode) -> Result<(), SeccompError> {
if mode == SeccompMode::Disabled {
return Ok(());
}
let syscalls = main_syscalls();
let filter = build_filter(&syscalls, mode)?;
apply_filter(&filter)
}
/// Apply seccomp filter for a child (worker) process.
///
/// Call this after socket creation and thread spawning, just before
/// entering the event loop. This allows the tightest possible filter.
pub fn apply_child_seccomp(mode: SeccompMode) -> Result<(), SeccompError> {
if mode == SeccompMode::Disabled {
return Ok(());
}
let filter = build_filter(CHILD_SYSCALLS, mode)?;
apply_filter(&filter)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn seccomp_error_displays_compile_error() {
let err = SeccompError::Compile("invalid rule".to_string());
assert!(err.to_string().contains("compile"));
assert!(err.to_string().contains("invalid rule"));
}
#[test]
fn seccomp_error_displays_apply_error() {
let err = SeccompError::Apply(std::io::Error::from_raw_os_error(libc::EPERM));
let msg = err.to_string();
assert!(msg.contains("apply"));
}
#[test]
fn seccomp_error_displays_disabled() {
let err = SeccompError::Disabled;
assert!(err.to_string().contains("disabled"));
}
#[test]
fn main_syscalls_is_not_empty() {
assert!(!main_syscalls().is_empty());
}
#[test]
fn main_syscalls_contains_essential_syscalls() {
let syscalls = main_syscalls();
assert!(syscalls.contains(&libc::SYS_read));
assert!(syscalls.contains(&libc::SYS_write));
assert!(syscalls.contains(&libc::SYS_close));
assert!(syscalls.contains(&libc::SYS_exit_group));
}
#[test]
fn main_syscalls_allows_fork() {
// Main process needs fork to spawn children
assert!(main_syscalls().contains(&libc::SYS_fork));
}
#[test]
fn main_syscalls_is_superset_of_child() {
let main = main_syscalls();
for &syscall in CHILD_SYSCALLS {
assert!(
main.contains(&syscall),
"CHILD_SYSCALLS contains {} which is missing from main filter",
syscall
);
}
}
#[test]
fn child_syscalls_is_not_empty() {
assert!(!CHILD_SYSCALLS.is_empty());
}
#[test]
fn child_syscalls_does_not_allow_fork() {
// Children don't spawn processes
assert!(!CHILD_SYSCALLS.contains(&libc::SYS_fork));
}
#[test]
fn child_syscalls_allows_clone3_for_vhost_threads() {
// vhost-user library spawns threads lazily on client connect
// clone3 is used by glibc pthread_create; clone is NOT allowed
// (fork uses clone, so blocking it prevents fork in children)
assert!(CHILD_SYSCALLS.contains(&libc::SYS_clone3));
assert!(!CHILD_SYSCALLS.contains(&libc::SYS_clone));
}
#[test]
fn child_syscalls_does_not_allow_socket_creation() {
// Socket created before seccomp
assert!(!CHILD_SYSCALLS.contains(&libc::SYS_socket));
assert!(!CHILD_SYSCALLS.contains(&libc::SYS_bind));
assert!(!CHILD_SYSCALLS.contains(&libc::SYS_listen));
assert!(!CHILD_SYSCALLS.contains(&libc::SYS_setsockopt));
}
#[test]
fn child_syscalls_allows_accept() {
// Need to accept vhost-user connections
assert!(CHILD_SYSCALLS.contains(&libc::SYS_accept4));
}
#[test]
fn child_syscalls_allows_ring_buffer_ops() {
// Children need memfd_create for ring buffers
assert!(CHILD_SYSCALLS.contains(&libc::SYS_memfd_create));
assert!(CHILD_SYSCALLS.contains(&libc::SYS_ftruncate));
}
#[test]
fn build_filter_creates_valid_bpf() {
let syscalls = main_syscalls();
let filter = build_filter(&syscalls, SeccompMode::Kill)
.expect("filter should compile");
assert_eq!(filter.len(), 1);
assert!(filter.contains_key("main"));
}
#[test]
fn build_filter_handles_all_modes() {
let syscalls = main_syscalls();
for mode in [SeccompMode::Kill, SeccompMode::Trap, SeccompMode::Log] {
let result = build_filter(&syscalls, mode);
assert!(result.is_ok(), "mode {:?} should compile", mode);
}
}
#[test]
fn build_filter_disabled_returns_error() {
let syscalls = main_syscalls();
let result = build_filter(&syscalls, SeccompMode::Disabled);
assert!(matches!(result, Err(SeccompError::Disabled)));
}
#[test]
fn apply_filter_signature_check() {
fn check_signature(_f: fn(&BpfMap) -> Result<(), SeccompError>) {}
check_signature(apply_filter);
}
#[test]
fn apply_main_seccomp_signature_check() {
fn check(_f: fn(SeccompMode) -> Result<(), SeccompError>) {}
check(apply_main_seccomp);
}
#[test]
fn apply_child_seccomp_signature_check() {
fn check(_f: fn(SeccompMode) -> Result<(), SeccompError>) {}
check(apply_child_seccomp);
}
#[test]
fn apply_main_seccomp_disabled_succeeds() {
// Disabled mode should be a no-op, not an error
assert!(apply_main_seccomp(SeccompMode::Disabled).is_ok());
}
#[test]
fn apply_child_seccomp_disabled_succeeds() {
assert!(apply_child_seccomp(SeccompMode::Disabled).is_ok());
}
}


@ -1,445 +0,0 @@
//! L2 switch logic with MAC filtering.
use std::collections::HashMap;
use tracing::debug;
use crate::config::VmRole;
use crate::mac::Mac;
/// Unique identifier for a connected VM.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(u64);
impl ConnectionId {
/// Create a new connection ID.
pub fn new(id: u64) -> Self {
Self(id)
}
}
/// Decision for how to forward a frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ForwardDecision {
/// Forward to a single destination.
Unicast(ConnectionId),
/// Forward to multiple destinations (broadcast/multicast).
Multicast(Vec<ConnectionId>),
/// Drop the frame.
Drop(DropReason),
}
/// Reason why a frame was dropped.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DropReason {
/// Source MAC doesn't match the sender's configured MAC.
SourceMacMismatch { expected: Mac, actual: Mac },
/// Client tried to send to a MAC other than router/broadcast/multicast.
ClientViolation { destination: Mac },
/// Router tried to send to an unknown MAC.
UnknownDestination { destination: Mac },
/// No router is connected.
NoRouter,
/// Sender connection ID is not registered.
UnknownSender,
}
/// Information about a connected VM.
#[derive(Debug, Clone)]
struct Connection {
/// VM name for logging.
name: String,
/// Role (router or client).
role: VmRole,
/// Configured MAC address.
mac: Mac,
}
/// L2 switch with MAC filtering.
///
/// Maintains a registry of connected VMs and applies filtering rules
/// to determine how frames should be forwarded.
pub struct Switch {
/// Connected VMs by connection ID.
connections: HashMap<ConnectionId, Connection>,
/// MAC address to connection ID mapping for fast lookup.
mac_to_conn: HashMap<Mac, ConnectionId>,
/// The router's connection ID (if connected).
router: Option<ConnectionId>,
/// Next connection ID to assign.
next_id: u64,
}
impl Switch {
/// Create a new empty switch.
pub fn new() -> Self {
Self {
connections: HashMap::new(),
mac_to_conn: HashMap::new(),
router: None,
next_id: 0,
}
}
/// Register a new VM connection.
///
/// Returns the assigned connection ID, or `None` if the registration is
/// rejected: a second router, or a duplicate MAC address.
pub fn register(&mut self, name: String, role: VmRole, mac: Mac) -> Option<ConnectionId> {
// Reject second router
if role == VmRole::Router && self.router.is_some() {
return None;
}
// Reject duplicate MAC address
if self.mac_to_conn.contains_key(&mac) {
return None;
}
let id = ConnectionId::new(self.next_id);
self.next_id += 1;
self.connections.insert(id, Connection { name, role, mac });
self.mac_to_conn.insert(mac, id);
if role == VmRole::Router {
self.router = Some(id);
}
Some(id)
}
/// Unregister a VM connection.
pub fn unregister(&mut self, id: ConnectionId) {
if let Some(conn) = self.connections.remove(&id) {
debug!(name = %conn.name, role = ?conn.role, mac = %conn.mac, "unregistered connection");
self.mac_to_conn.remove(&conn.mac);
if conn.role == VmRole::Router {
self.router = None;
}
}
}
/// Determine how to forward a frame.
///
/// # Arguments
/// * `sender` - Connection ID of the sender
/// * `source_mac` - Source MAC from the frame
/// * `dest_mac` - Destination MAC from the frame
pub fn forward(&self, sender: ConnectionId, source_mac: Mac, dest_mac: Mac) -> ForwardDecision {
// Get sender info
let sender_conn = match self.connections.get(&sender) {
Some(c) => c,
None => return ForwardDecision::Drop(DropReason::UnknownSender),
};
// Validate source MAC
if source_mac != sender_conn.mac {
return ForwardDecision::Drop(DropReason::SourceMacMismatch {
expected: sender_conn.mac,
actual: source_mac,
});
}
match sender_conn.role {
VmRole::Client => self.forward_from_client(dest_mac),
VmRole::Router => self.forward_from_router(sender, dest_mac),
}
}
fn forward_from_client(&self, dest_mac: Mac) -> ForwardDecision {
// Get router
let router_id = match self.router {
Some(id) => id,
None => return ForwardDecision::Drop(DropReason::NoRouter),
};
let router_conn = self.connections.get(&router_id).unwrap();
// Clients can only send to router, broadcast, or multicast
if dest_mac == router_conn.mac || dest_mac.is_broadcast() || dest_mac.is_multicast() {
ForwardDecision::Unicast(router_id)
} else {
ForwardDecision::Drop(DropReason::ClientViolation { destination: dest_mac })
}
}
fn forward_from_router(&self, sender: ConnectionId, dest_mac: Mac) -> ForwardDecision {
// Broadcast or multicast goes to all clients
if dest_mac.is_broadcast() || dest_mac.is_multicast() {
let client_ids: Vec<ConnectionId> = self.connections
.iter()
.filter(|(id, conn)| **id != sender && conn.role == VmRole::Client)
.map(|(id, _)| *id)
.collect();
return ForwardDecision::Multicast(client_ids);
}
// Unicast to specific client
match self.mac_to_conn.get(&dest_mac) {
Some(id) if *id != sender => {
// Verify it's a client
if let Some(conn) = self.connections.get(id) {
if conn.role == VmRole::Client {
return ForwardDecision::Unicast(*id);
}
}
ForwardDecision::Drop(DropReason::UnknownDestination { destination: dest_mac })
}
_ => ForwardDecision::Drop(DropReason::UnknownDestination { destination: dest_mac }),
}
}
}
impl Default for Switch {
fn default() -> Self {
Self::new()
}
}
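The hub-and-spoke policy enforced above (clients may only reach the router, while the router may reach any client) can be condensed into a std-only decision sketch. This is a simplification for illustration, ignoring registration, source-MAC validation, and the drop reasons:

```rust
// Simplified roles and destination kinds; not the module's real types.
enum Role { Router, Client }
enum Dest { Router, Client, BroadcastOrMulticast }

/// Condensed version of the forwarding policy: is the frame forwarded at all?
fn allowed(sender: Role, dest: Dest) -> bool {
    match (sender, dest) {
        // Clients may only talk to the router; broadcast/multicast from a
        // client also terminates at the router.
        (Role::Client, Dest::Router) | (Role::Client, Dest::BroadcastOrMulticast) => true,
        (Role::Client, Dest::Client) => false, // ClientViolation
        // The router may reach any client, including via broadcast/multicast.
        (Role::Router, Dest::Client) | (Role::Router, Dest::BroadcastOrMulticast) => true,
        (Role::Router, Dest::Router) => false, // only one router; self is dropped
    }
}

fn main() {
    assert!(allowed(Role::Client, Dest::Router));
    assert!(!allowed(Role::Client, Dest::Client));
    assert!(allowed(Role::Router, Dest::Client));
    println!("ok");
}
```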
#[cfg(test)]
mod tests {
use super::*;
// Helper to create test MACs
fn mac(s: &str) -> Mac {
Mac::parse(s).unwrap()
}
#[test]
fn register_client() {
let mut switch = Switch::new();
let id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff"));
assert!(id.is_some());
}
#[test]
fn register_router() {
let mut switch = Switch::new();
let id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66"));
assert!(id.is_some());
}
#[test]
fn register_second_router_fails() {
let mut switch = Switch::new();
let id1 = switch.register("gateway1".into(), VmRole::Router, mac("11:22:33:44:55:66"));
let id2 = switch.register("gateway2".into(), VmRole::Router, mac("aa:bb:cc:dd:ee:ff"));
assert!(id1.is_some());
assert!(id2.is_none(), "Second router should be rejected");
}
#[test]
fn register_duplicate_mac_fails() {
let mut switch = Switch::new();
let id1 = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff"));
let id2 = switch.register("shopping".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff"));
assert!(id1.is_some());
assert!(id2.is_none(), "Duplicate MAC should be rejected");
}
#[test]
fn register_multiple_clients() {
let mut switch = Switch::new();
let id1 = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:01"));
let id2 = switch.register("shopping".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:02"));
assert!(id1.is_some());
assert!(id2.is_some());
assert_ne!(id1, id2);
}
#[test]
fn unregister_client() {
let mut switch = Switch::new();
let id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap();
switch.unregister(id);
// Should be able to register another client with same MAC
let id2 = switch.register("banking2".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff"));
assert!(id2.is_some());
}
#[test]
fn unregister_router_allows_new_router() {
let mut switch = Switch::new();
let id = switch.register("gateway1".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
switch.unregister(id);
// Should now be able to register a new router
let id2 = switch.register("gateway2".into(), VmRole::Router, mac("aa:bb:cc:dd:ee:ff"));
assert!(id2.is_some());
}
#[test]
fn forward_rejects_source_mac_mismatch() {
let mut switch = Switch::new();
let _router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
let client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap();
// Client sends with wrong source MAC
let result = switch.forward(client_id, mac("00:00:00:00:00:01"), mac("11:22:33:44:55:66"));
match result {
ForwardDecision::Drop(DropReason::SourceMacMismatch { expected, actual }) => {
assert_eq!(expected, mac("aa:bb:cc:dd:ee:ff"));
assert_eq!(actual, mac("00:00:00:00:00:01"));
}
_ => panic!("Expected SourceMacMismatch, got {:?}", result),
}
}
#[test]
fn forward_client_to_router_unicast() {
let mut switch = Switch::new();
let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
let client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap();
let result = switch.forward(client_id, mac("aa:bb:cc:dd:ee:ff"), mac("11:22:33:44:55:66"));
assert_eq!(result, ForwardDecision::Unicast(router_id));
}
#[test]
fn forward_client_to_router_broadcast() {
let mut switch = Switch::new();
let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
let client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap();
let result = switch.forward(client_id, mac("aa:bb:cc:dd:ee:ff"), mac("ff:ff:ff:ff:ff:ff"));
// Broadcast from client goes to router only
assert_eq!(result, ForwardDecision::Unicast(router_id));
}
#[test]
fn forward_client_to_router_multicast() {
let mut switch = Switch::new();
let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
let client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap();
// IPv4 multicast MAC
let result = switch.forward(client_id, mac("aa:bb:cc:dd:ee:ff"), mac("01:00:5e:00:00:01"));
// Multicast from client goes to router only
assert_eq!(result, ForwardDecision::Unicast(router_id));
}
#[test]
fn forward_client_violation_to_other_client() {
let mut switch = Switch::new();
let _router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
let client1_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:01")).unwrap();
let _client2_id = switch.register("shopping".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:02")).unwrap();
// Client tries to send directly to another client
let result = switch.forward(client1_id, mac("aa:bb:cc:dd:ee:01"), mac("aa:bb:cc:dd:ee:02"));
match result {
ForwardDecision::Drop(DropReason::ClientViolation { destination }) => {
assert_eq!(destination, mac("aa:bb:cc:dd:ee:02"));
}
_ => panic!("Expected ClientViolation, got {:?}", result),
}
}
#[test]
fn forward_client_no_router() {
let mut switch = Switch::new();
let client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap();
let result = switch.forward(client_id, mac("aa:bb:cc:dd:ee:ff"), mac("ff:ff:ff:ff:ff:ff"));
assert_eq!(result, ForwardDecision::Drop(DropReason::NoRouter));
}
#[test]
fn forward_router_to_client_unicast() {
let mut switch = Switch::new();
let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
let client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap();
let result = switch.forward(router_id, mac("11:22:33:44:55:66"), mac("aa:bb:cc:dd:ee:ff"));
assert_eq!(result, ForwardDecision::Unicast(client_id));
}
#[test]
fn forward_router_broadcast_to_all_clients() {
let mut switch = Switch::new();
let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
let client1_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:01")).unwrap();
let client2_id = switch.register("shopping".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:02")).unwrap();
let result = switch.forward(router_id, mac("11:22:33:44:55:66"), mac("ff:ff:ff:ff:ff:ff"));
match result {
ForwardDecision::Multicast(ids) => {
assert_eq!(ids.len(), 2);
assert!(ids.contains(&client1_id));
assert!(ids.contains(&client2_id));
}
_ => panic!("Expected Multicast, got {:?}", result),
}
}
#[test]
fn forward_router_multicast_to_all_clients() {
let mut switch = Switch::new();
let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
let client1_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:01")).unwrap();
let client2_id = switch.register("shopping".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:02")).unwrap();
let result = switch.forward(router_id, mac("11:22:33:44:55:66"), mac("01:00:5e:00:00:01"));
match result {
ForwardDecision::Multicast(ids) => {
assert_eq!(ids.len(), 2);
assert!(ids.contains(&client1_id));
assert!(ids.contains(&client2_id));
}
_ => panic!("Expected Multicast, got {:?}", result),
}
}
#[test]
fn forward_router_unknown_destination() {
let mut switch = Switch::new();
let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
let _client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap();
// Router sends to unknown MAC
let result = switch.forward(router_id, mac("11:22:33:44:55:66"), mac("00:00:00:00:00:01"));
match result {
ForwardDecision::Drop(DropReason::UnknownDestination { destination }) => {
assert_eq!(destination, mac("00:00:00:00:00:01"));
}
_ => panic!("Expected UnknownDestination, got {:?}", result),
}
}
#[test]
fn forward_router_broadcast_no_clients() {
let mut switch = Switch::new();
let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
let result = switch.forward(router_id, mac("11:22:33:44:55:66"), mac("ff:ff:ff:ff:ff:ff"));
// Empty multicast is valid (no clients to send to)
match result {
ForwardDecision::Multicast(ids) => {
assert!(ids.is_empty());
}
_ => panic!("Expected empty Multicast, got {:?}", result),
}
}
}
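
The hub-and-spoke policy these tests exercise can be sketched as a standalone function. This is a hedged, self-contained illustration; the names (`Role`, `Decision`, `forward`) are hypothetical and deliberately simpler than the crate's `Switch`, which also tracks port registration:

```rust
#[derive(Clone, Copy, PartialEq, Debug)]
enum Role { Router, Client }

#[derive(PartialEq, Debug)]
enum Decision {
    ToRouter,        // all client traffic funnels to the router
    ToClient(usize), // router unicast to a known client port
    ToAllClients,    // router broadcast/multicast fans out
    Drop,            // NoRouter, ClientViolation, or UnknownDestination
}

fn is_group(dest: &[u8; 6]) -> bool {
    // Broadcast (ff:..) and multicast (01:..) both set the I/G bit.
    dest[0] & 0x01 != 0
}

fn forward(
    src_role: Role,
    dest: &[u8; 6],
    clients: &[(usize, [u8; 6])],
    router_present: bool,
) -> Decision {
    match src_role {
        Role::Client => {
            if !router_present {
                return Decision::Drop; // NoRouter
            }
            if clients.iter().any(|(_, m)| m == dest) {
                Decision::Drop // ClientViolation: clients never talk directly
            } else {
                Decision::ToRouter // unicast-to-router, broadcast, multicast
            }
        }
        Role::Router => {
            if is_group(dest) {
                Decision::ToAllClients
            } else if let Some((id, _)) = clients.iter().find(|(_, m)| m == dest) {
                Decision::ToClient(*id)
            } else {
                Decision::Drop // UnknownDestination
            }
        }
    }
}
```

Collapsing client broadcast and multicast into `ToRouter` mirrors the tests above, where both resolve to `ForwardDecision::Unicast(router_id)`.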


@@ -0,0 +1,323 @@
//! Integration tests for ring buffer exchange between processes.
use nix::sys::wait::{waitpid, WaitStatus};
use nix::unistd::{fork, ForkResult};
use serial_test::serial;
use std::os::fd::AsRawFd;
use vm_switch::control::{ChildToMain, ControlChannel, MainToChild};
use vm_switch::mac::Mac;
use vm_switch::ring::{Consumer, Producer};
#[test]
#[serial]
fn child_sends_ready_on_startup() {
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
let temp_dir = tempfile::tempdir().expect("tempdir");
let socket_path = temp_dir.path().join("test.sock");
match unsafe { fork() } {
Ok(ForkResult::Parent { child }) => {
drop(child_end);
// Wait for Ready
let (msg, fds): (ChildToMain, _) = main_end
.recv_with_fds_typed()
.expect("should receive");
match msg {
ChildToMain::Ready => {}
_ => panic!("expected Ready, got {:?}", msg),
}
assert!(fds.is_empty(), "Ready should have no FDs");
// Close main end to trigger child exit
drop(main_end);
// Wait for child
let status = waitpid(child, None).expect("waitpid failed");
match status {
WaitStatus::Exited(_, code) => {
assert_eq!(code, 0, "child should exit cleanly");
}
other => panic!("unexpected status: {:?}", other),
}
}
Ok(ForkResult::Child) => {
drop(main_end);
let control_fd = child_end.into_fd();
vm_switch::child::run_child_process("test-vm", mac, control_fd, &socket_path, vm_switch::SeccompMode::Disabled);
}
Err(e) => panic!("fork failed: {}", e),
}
}
#[test]
#[serial]
fn child_creates_ingress_buffer_on_request() {
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
let temp_dir = tempfile::tempdir().expect("tempdir");
let socket_path = temp_dir.path().join("test.sock");
match unsafe { fork() } {
Ok(ForkResult::Parent { child }) => {
drop(child_end);
// Wait for Ready
let (msg, _): (ChildToMain, _) = main_end
.recv_with_fds_typed()
.expect("should receive");
assert!(matches!(msg, ChildToMain::Ready));
// Request buffer for a peer
let peer_name = "router".to_string();
let peer_mac = [0x11, 0x22, 0x33, 0x44, 0x55, 0x66];
let msg = MainToChild::GetBuffer {
peer_name: peer_name.clone(),
peer_mac,
};
main_end.send(&msg).expect("send GetBuffer");
// Wait for BufferReady
let (msg, fds): (ChildToMain, _) = main_end
.recv_with_fds_typed()
.expect("should receive");
match msg {
ChildToMain::BufferReady { peer_name: name } => {
assert_eq!(name, peer_name);
}
_ => panic!("expected BufferReady"),
}
assert_eq!(fds.len(), 2, "should receive memfd and eventfd");
// Create producer from received FDs and verify we can write
let mut fds = fds.into_iter();
let memfd = fds.next().unwrap();
let eventfd = fds.next().unwrap();
let producer = Producer::from_fds(memfd, eventfd)
.expect("should create producer");
// Push a frame
assert!(producer.push(&[1, 2, 3, 4, 5]));
// Close main end to trigger child exit
drop(main_end);
// Wait for child
let status = waitpid(child, None).expect("waitpid failed");
match status {
WaitStatus::Exited(_, code) => {
assert_eq!(code, 0, "child should exit cleanly");
}
other => panic!("unexpected status: {:?}", other),
}
}
Ok(ForkResult::Child) => {
drop(main_end);
let control_fd = child_end.into_fd();
vm_switch::child::run_child_process("test-vm", mac, control_fd, &socket_path, vm_switch::SeccompMode::Disabled);
}
Err(e) => panic!("fork failed: {}", e),
}
}
#[test]
#[serial]
fn new_protocol_buffer_exchange() {
// Test the new protocol where:
// 1. Child sends Ready
// 2. Main sends GetBuffer
// 3. Child creates consumer and sends BufferReady with FDs
// 4. Main forwards FDs to peer as PutBuffer
// 5. Peer creates producer and can write
// Simulate two "children" with control channels
let (main_a, child_a) = ControlChannel::pair().expect("pair A");
let (main_b, child_b) = ControlChannel::pair().expect("pair B");
let name_a = "client_a".to_string();
let name_b = "router".to_string();
let mac_a = [0xaa, 0, 0, 0, 0, 1];
let mac_b = [0xbb, 0, 0, 0, 0, 2];
// Step 1: Both children send Ready
child_a.send(&ChildToMain::Ready).expect("A ready");
child_b.send(&ChildToMain::Ready).expect("B ready");
// Main receives Ready from both
let (msg_a, _): (ChildToMain, _) = main_a.recv_with_fds_typed().expect("recv A ready");
let (msg_b, _): (ChildToMain, _) = main_b.recv_with_fds_typed().expect("recv B ready");
assert!(matches!(msg_a, ChildToMain::Ready));
assert!(matches!(msg_b, ChildToMain::Ready));
// Step 2: Main requests buffer from A for B (A will be consumer of data from B)
let get_buffer_msg = MainToChild::GetBuffer {
peer_name: name_b.clone(),
peer_mac: mac_b,
};
main_a.send(&get_buffer_msg).expect("send GetBuffer to A");
// Step 3: A receives GetBuffer
let (msg, _): (MainToChild, _) = child_a.recv_with_fds_typed().expect("A recv GetBuffer");
match msg {
MainToChild::GetBuffer { peer_name, peer_mac } => {
assert_eq!(peer_name, name_b);
assert_eq!(peer_mac, mac_b);
}
_ => panic!("expected GetBuffer"),
}
// A creates consumer (ingress buffer) and sends BufferReady
let consumer_a = Consumer::new().expect("consumer A");
let buffer_ready = ChildToMain::BufferReady { peer_name: name_b.clone() };
child_a.send_with_fds_typed(&buffer_ready, &[
consumer_a.memfd().as_raw_fd(),
consumer_a.eventfd().as_raw_fd(),
]).expect("A send BufferReady");
// Main receives BufferReady with FDs
let (msg, fds_from_a): (ChildToMain, _) = main_a.recv_with_fds_typed().expect("main recv BufferReady");
match msg {
ChildToMain::BufferReady { peer_name } => {
assert_eq!(peer_name, name_b);
}
_ => panic!("expected BufferReady"),
}
assert_eq!(fds_from_a.len(), 2);
// Step 4: Main sends PutBuffer to B with A's ingress buffer
// B will use this as egress to A
// broadcast=false because the peer (A) is a client, not the router
let put_buffer_msg = MainToChild::PutBuffer {
peer_name: name_a.clone(),
peer_mac: mac_a,
broadcast: false, // A is client, not router
};
main_b.send_with_fds_typed(&put_buffer_msg, &[
fds_from_a[0].as_raw_fd(),
fds_from_a[1].as_raw_fd(),
]).expect("send PutBuffer to B");
// Step 5: B receives PutBuffer with FDs
let (msg, fds_for_b): (MainToChild, _) = child_b.recv_with_fds_typed().expect("B recv PutBuffer");
match msg {
MainToChild::PutBuffer { peer_name, peer_mac, broadcast } => {
assert_eq!(peer_name, name_a);
assert_eq!(peer_mac, mac_a);
assert!(!broadcast);
}
_ => panic!("expected PutBuffer"),
}
assert_eq!(fds_for_b.len(), 2);
// B creates producer (egress buffer) from received FDs
let mut fds_b = fds_for_b.into_iter();
let producer_b = Producer::from_fds(fds_b.next().unwrap(), fds_b.next().unwrap())
.expect("producer B");
// Step 6: B writes to producer, A can read from consumer
assert!(producer_b.push(&[10, 20, 30, 40, 50]));
let data = consumer_a.pop().expect("should pop");
assert_eq!(data, vec![10, 20, 30, 40, 50]);
}
#[test]
#[serial]
fn bidirectional_new_protocol_exchange() {
// Two VMs exchange buffers and can communicate both ways using new protocol
let (main_a, child_a) = ControlChannel::pair().expect("pair A");
let (main_b, child_b) = ControlChannel::pair().expect("pair B");
let name_a = "client_a".to_string();
let name_b = "router".to_string();
let mac_a = [0xaa, 0, 0, 0, 0, 1];
let mac_b = [0xbb, 0, 0, 0, 0, 2];
// Both send Ready
child_a.send(&ChildToMain::Ready).expect("A ready");
child_b.send(&ChildToMain::Ready).expect("B ready");
let _: (ChildToMain, _) = main_a.recv_with_fds_typed().expect("recv A ready");
let _: (ChildToMain, _) = main_b.recv_with_fds_typed().expect("recv B ready");
// Request ingress buffers from both sides
// A creates ingress for B
main_a.send(&MainToChild::GetBuffer {
peer_name: name_b.clone(),
peer_mac: mac_b,
}).expect("GetBuffer A->B");
// B creates ingress for A
main_b.send(&MainToChild::GetBuffer {
peer_name: name_a.clone(),
peer_mac: mac_a,
}).expect("GetBuffer B->A");
// A receives GetBuffer, creates consumer, sends BufferReady
let _: (MainToChild, _) = child_a.recv_with_fds_typed().expect("A recv GetBuffer");
let consumer_a = Consumer::new().expect("consumer A");
child_a.send_with_fds_typed(&ChildToMain::BufferReady { peer_name: name_b.clone() }, &[
consumer_a.memfd().as_raw_fd(),
consumer_a.eventfd().as_raw_fd(),
]).expect("A BufferReady");
// B receives GetBuffer, creates consumer, sends BufferReady
let _: (MainToChild, _) = child_b.recv_with_fds_typed().expect("B recv GetBuffer");
let consumer_b = Consumer::new().expect("consumer B");
child_b.send_with_fds_typed(&ChildToMain::BufferReady { peer_name: name_a.clone() }, &[
consumer_b.memfd().as_raw_fd(),
consumer_b.eventfd().as_raw_fd(),
]).expect("B BufferReady");
// Main receives BufferReady from both
let (_, fds_from_a): (ChildToMain, _) = main_a.recv_with_fds_typed().expect("main recv A BufferReady");
let (_, fds_from_b): (ChildToMain, _) = main_b.recv_with_fds_typed().expect("main recv B BufferReady");
// Cross-send: A's ingress becomes B's egress to A, and vice versa
// Send A's buffer to B; broadcast=false because the peer (A) is a client
main_b.send_with_fds_typed(&MainToChild::PutBuffer {
peer_name: name_a.clone(),
peer_mac: mac_a,
broadcast: false, // A is not router
}, &[
fds_from_a[0].as_raw_fd(),
fds_from_a[1].as_raw_fd(),
]).expect("PutBuffer A to B");
// Send B's buffer to A; broadcast=true because the peer (B) is the router
main_a.send_with_fds_typed(&MainToChild::PutBuffer {
peer_name: name_b.clone(),
peer_mac: mac_b,
broadcast: true, // B is router
}, &[
fds_from_b[0].as_raw_fd(),
fds_from_b[1].as_raw_fd(),
]).expect("PutBuffer B to A");
// Each side receives PutBuffer and creates producer
let (_, fds_b_egress): (MainToChild, _) = child_b.recv_with_fds_typed().expect("B recv PutBuffer");
let (_, fds_a_egress): (MainToChild, _) = child_a.recv_with_fds_typed().expect("A recv PutBuffer");
let mut fds = fds_b_egress.into_iter();
let producer_b = Producer::from_fds(fds.next().unwrap(), fds.next().unwrap()).expect("producer B");
let mut fds = fds_a_egress.into_iter();
let producer_a = Producer::from_fds(fds.next().unwrap(), fds.next().unwrap()).expect("producer A");
// B sends to A
assert!(producer_b.push(&[1, 2, 3]));
assert_eq!(consumer_a.pop().unwrap(), vec![1, 2, 3]);
// A sends to B
assert!(producer_a.push(&[4, 5, 6]));
assert_eq!(consumer_b.pop().unwrap(), vec![4, 5, 6]);
}
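
The `Producer`/`Consumer` pair used above sits on a lock-free SPSC ring with atomic head/tail indices. A minimal std-only sketch of that index scheme follows; it keeps slots in a `Vec` rather than a memfd-backed mmap and omits the eventfd wakeup, so the names and sizes are illustrative only (and with `&mut self` it is not actually shared between processes):

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

const SLOTS: usize = 64; // matches the 64-slot rings described in the commit

struct Ring {
    head: AtomicUsize,   // next slot the consumer will read
    tail: AtomicUsize,   // next slot the producer will write
    slots: Vec<Vec<u8>>, // one frame per slot
}

impl Ring {
    fn new() -> Self {
        Ring {
            head: AtomicUsize::new(0),
            tail: AtomicUsize::new(0),
            slots: (0..SLOTS).map(|_| Vec::new()).collect(),
        }
    }

    /// Producer side: returns false when the ring is full (frame dropped).
    fn push(&mut self, frame: &[u8]) -> bool {
        let tail = self.tail.load(Ordering::Relaxed);
        let head = self.head.load(Ordering::Acquire);
        if tail.wrapping_sub(head) == SLOTS {
            return false; // full
        }
        self.slots[tail % SLOTS] = frame.to_vec();
        // Release publishes the slot contents before the new tail is visible.
        self.tail.store(tail.wrapping_add(1), Ordering::Release);
        true
    }

    /// Consumer side: returns None when the ring is empty.
    fn pop(&mut self) -> Option<Vec<u8>> {
        let head = self.head.load(Ordering::Relaxed);
        let tail = self.tail.load(Ordering::Acquire);
        if head == tail {
            return None; // empty; the real ring signals via eventfd here
        }
        let frame = std::mem::take(&mut self.slots[head % SLOTS]);
        self.head.store(head.wrapping_add(1), Ordering::Release);
        Some(frame)
    }
}
```

Because each index only ever advances on its own side, neither end needs a lock: the Acquire/Release pair is what makes a slot's contents visible before its index moves.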


@@ -0,0 +1,104 @@
//! Integration tests for crash detection and cleanup.
use nix::sys::signal::{kill, Signal};
use nix::sys::wait::{waitpid, WaitPidFlag, WaitStatus};
use nix::unistd::{fork, ForkResult};
use std::time::Duration;
use vm_switch::control::{ControlChannel, MainToChild};
/// Test that main detects child crash via waitpid.
#[test]
fn main_detects_child_crash() {
let (main_end, child_end) = ControlChannel::pair().expect("pair");
match unsafe { fork() } {
Ok(ForkResult::Parent { child }) => {
drop(child_end);
std::thread::sleep(Duration::from_millis(50));
kill(child, Signal::SIGKILL).expect("kill");
let status = waitpid(child, None).expect("waitpid");
assert!(matches!(status, WaitStatus::Signaled(_, Signal::SIGKILL, _)));
drop(main_end);
}
Ok(ForkResult::Child) => {
drop(main_end);
let control = ControlChannel::from_fd(child_end.into_fd());
loop {
match control.recv::<MainToChild>() {
Err(_) => break,
_ => continue,
}
}
std::process::exit(0);
}
Err(e) => panic!("fork failed: {}", e),
}
}
/// Test that RemovePeer message is delivered correctly.
#[test]
fn remove_peer_message_delivery() {
let (main_end, child_end) = ControlChannel::pair().expect("pair");
let crashed_vm = "crashed_client".to_string();
let msg = MainToChild::RemovePeer { peer_name: crashed_vm.clone() };
main_end.send(&msg).expect("send");
let received: MainToChild = child_end.recv().expect("recv");
match received {
MainToChild::RemovePeer { peer_name } => {
assert_eq!(peer_name, crashed_vm);
}
_ => panic!("expected RemovePeer"),
}
}
/// Test that children exit cleanly when control channel closes.
#[test]
fn children_exit_on_control_close() {
let (main_end, child_end) = ControlChannel::pair().expect("pair");
match unsafe { fork() } {
Ok(ForkResult::Parent { child }) => {
drop(child_end);
std::thread::sleep(Duration::from_millis(50));
drop(main_end);
let deadline = std::time::Instant::now() + Duration::from_secs(2);
loop {
match waitpid(child, Some(WaitPidFlag::WNOHANG)) {
Ok(WaitStatus::Exited(_, 0)) => return,
Ok(WaitStatus::Exited(_, code)) => {
panic!("Child exited with code {}", code);
}
Ok(_) => {
if std::time::Instant::now() > deadline {
kill(child, Signal::SIGKILL).ok();
panic!("Child did not exit within timeout");
}
std::thread::sleep(Duration::from_millis(50));
}
Err(e) => panic!("waitpid error: {}", e),
}
}
}
Ok(ForkResult::Child) => {
drop(main_end);
let control = ControlChannel::from_fd(child_end.into_fd());
loop {
match control.recv::<MainToChild>() {
Err(_) => std::process::exit(0),
_ => continue,
}
}
}
Err(e) => panic!("fork failed: {}", e),
}
}
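
The deadline-bounded reap loop in `children_exit_on_control_close` can be shown with `std::process` alone. This sketch substitutes `Command`-spawned children for `fork()` and `try_wait` for `waitpid(WNOHANG)`, so it is an illustration of the pattern rather than the crate's actual supervision code:

```rust
use std::process::{Child, Command};
use std::time::{Duration, Instant};

/// Poll for child exit until the deadline, then escalate to kill —
/// the same shape as the WNOHANG loop in the test above.
fn reap_with_deadline(mut child: Child, timeout: Duration) -> bool {
    let deadline = Instant::now() + timeout;
    loop {
        match child.try_wait().expect("try_wait") {
            Some(status) => return status.success(),
            None => {
                if Instant::now() > deadline {
                    let _ = child.kill(); // escalate, like SIGKILL in the test
                    let _ = child.wait(); // reap so no zombie is left behind
                    return false;
                }
                std::thread::sleep(Duration::from_millis(10));
            }
        }
    }
}
```

A well-behaved child returns true quickly; a hung one is killed at the deadline and reported as a failure.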


@@ -0,0 +1,160 @@
//! Integration tests for fork-based child process lifecycle.
//!
//! Note: These tests are marked `#[serial]` because fork-based tests
//! cannot safely run in parallel within the same process.
use nix::sys::wait::{waitpid, WaitStatus};
use nix::unistd::{fork, ForkResult};
use serial_test::serial;
use std::time::Duration;
use vm_switch::control::ControlChannel;
use vm_switch::mac::Mac;
#[test]
#[serial]
fn child_exits_when_control_channel_closes() {
// Create control channel
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0x01]);
let temp_dir = tempfile::tempdir().expect("tempdir");
let socket_path = temp_dir.path().join("test.sock");
match unsafe { fork() } {
Ok(ForkResult::Parent { child }) => {
// Parent: drop child's end, keep main's end
drop(child_end);
// Give child time to start and send Ready
std::thread::sleep(Duration::from_millis(100));
// Drain any Ready message
let _ = main_end.recv_with_fds_typed::<vm_switch::control::ChildToMain>();
// Close main's end - should cause child to exit
drop(main_end);
// Wait for child to exit
let status = waitpid(child, None).expect("waitpid failed");
match status {
WaitStatus::Exited(_, code) => {
assert_eq!(code, 0, "child should exit with code 0");
}
other => panic!("unexpected wait status: {:?}", other),
}
}
Ok(ForkResult::Child) => {
// Child: drop main's end (we don't need parent's socket)
drop(main_end);
// Run child entry point - this should exit when control closes
let control_fd = child_end.into_fd();
vm_switch::child::run_child_process("test-vm", mac, control_fd, &socket_path, vm_switch::SeccompMode::Disabled);
}
Err(e) => panic!("fork failed: {}", e),
}
}
#[test]
#[serial]
fn child_processes_messages_before_exit() {
use vm_switch::control::MainToChild;
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0x02]);
let temp_dir = tempfile::tempdir().expect("tempdir");
let socket_path = temp_dir.path().join("test.sock");
match unsafe { fork() } {
Ok(ForkResult::Parent { child }) => {
// Parent: drop child's end
drop(child_end);
// Wait for Ready from child
std::thread::sleep(Duration::from_millis(100));
let _ = main_end.recv_with_fds_typed::<vm_switch::control::ChildToMain>();
// Send a RemovePeer message
let msg = MainToChild::RemovePeer {
peer_name: "some-peer".to_string(),
};
main_end.send(&msg).expect("should send");
// Close channel (drop main_end)
drop(main_end);
// Child should exit cleanly after processing message
let status = waitpid(child, None).expect("waitpid failed");
match status {
WaitStatus::Exited(_, code) => {
assert_eq!(code, 0, "child should exit with code 0");
}
other => panic!("unexpected wait status: {:?}", other),
}
}
Ok(ForkResult::Child) => {
drop(main_end);
let control_fd = child_end.into_fd();
vm_switch::child::run_child_process("test-vm", mac, control_fd, &socket_path, vm_switch::SeccompMode::Disabled);
}
Err(e) => panic!("fork failed: {}", e),
}
}
#[test]
#[serial]
fn multiple_children_shut_down_gracefully() {
// Fork 3 children
let mut children = Vec::new();
for i in 0..3u8 {
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
let vm_name = format!("test-vm-{}", i);
let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, i]);
let temp_dir = tempfile::tempdir().expect("tempdir");
let socket_path = temp_dir.path().join("test.sock");
match unsafe { fork() } {
Ok(ForkResult::Parent { child }) => {
drop(child_end);
children.push((child, main_end, vm_name, temp_dir));
}
Ok(ForkResult::Child) => {
drop(main_end);
let control_fd = child_end.into_fd();
vm_switch::child::run_child_process(&vm_name, mac, control_fd, &socket_path, vm_switch::SeccompMode::Disabled);
}
Err(e) => panic!("fork failed: {}", e),
}
}
// Give children time to start and send Ready
std::thread::sleep(Duration::from_millis(100));
// Drain Ready messages
for (_, control, _, _) in &children {
let _ = control.recv_with_fds_typed::<vm_switch::control::ChildToMain>();
}
// Collect PIDs and drop control channels to signal shutdown
let pids: Vec<_> = children
.into_iter()
.map(|(pid, control, _, _temp_dir)| {
drop(control);
pid
})
.collect();
// Wait for all children
for pid in pids {
let status = waitpid(pid, None).expect("waitpid failed");
match status {
WaitStatus::Exited(_, code) => {
assert_eq!(code, 0, "child should exit with code 0");
}
other => panic!("unexpected wait status: {:?}", other),
}
}
}


@@ -0,0 +1,63 @@
//! Integration tests for packet forwarding.
use std::os::fd::{AsRawFd, FromRawFd, OwnedFd};
use vm_switch::child::PacketForwarder;
use vm_switch::mac::Mac;
use vm_switch::ring::{Consumer, Producer};
fn make_frame(dest: [u8; 6], src: [u8; 6]) -> Vec<u8> {
let mut frame = vec![0u8; 64];
frame[0..6].copy_from_slice(&dest);
frame[6..12].copy_from_slice(&src);
frame
}
#[test]
fn forwarder_validates_source_mac() {
let our_mac = Mac::from_bytes([1, 0, 0, 0, 0, 1]);
let peer_mac = [2, 0, 0, 0, 0, 2];
let mut forwarder = PacketForwarder::new(our_mac);
// Set up an egress buffer for broadcast
let consumer = Consumer::new().expect("consumer");
let memfd = unsafe { OwnedFd::from_raw_fd(libc::dup(consumer.memfd().as_raw_fd())) };
let eventfd = unsafe { OwnedFd::from_raw_fd(libc::dup(consumer.eventfd().as_raw_fd())) };
let producer = Producer::from_fds(memfd, eventfd).expect("producer");
forwarder.add_egress("router".to_string(), peer_mac, producer, true);
// Correct source MAC - broadcast frame should be sent to egress with broadcast=true
let good_frame = make_frame([0xff; 6], our_mac.bytes());
assert!(forwarder.forward_tx(&good_frame));
// Wrong source MAC - should be dropped
let bad_frame = make_frame([0xff; 6], [9, 9, 9, 9, 9, 9]);
assert!(!forwarder.forward_tx(&bad_frame));
}
#[test]
fn forwarder_ingress_validates_peer_mac() {
let our_mac = Mac::from_bytes([1, 0, 0, 0, 0, 1]);
let peer_mac = [2, 0, 0, 0, 0, 2];
let spoofed = [3, 0, 0, 0, 0, 3];
let mut forwarder = PacketForwarder::new(our_mac);
// Set up ingress from peer
let producer = Producer::new().expect("producer");
let memfd = unsafe { OwnedFd::from_raw_fd(libc::dup(producer.memfd().as_raw_fd())) };
let eventfd = unsafe { OwnedFd::from_raw_fd(libc::dup(producer.eventfd().as_raw_fd())) };
let consumer = Consumer::from_fds(memfd, eventfd).expect("consumer");
forwarder.add_ingress("router".to_string(), peer_mac, consumer);
// Good frame from peer
let good = make_frame(our_mac.bytes(), peer_mac);
assert!(producer.push(&good));
// Spoofed frame (wrong source)
let bad = make_frame(our_mac.bytes(), spoofed);
assert!(producer.push(&bad)); // accepted by the ring; filtered at poll_ingress
let received = forwarder.poll_ingress();
assert_eq!(received.len(), 1);
assert_eq!(&received[0][6..12], &peer_mac);
}
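
The forwarder tests above build raw Ethernet frames by hand: bytes 0..6 are the destination MAC and bytes 6..12 the source. A small sketch of that header handling (function names here are illustrative, not the crate's API):

```rust
fn dest_mac(frame: &[u8]) -> Option<[u8; 6]> {
    frame.get(0..6)?.try_into().ok()
}

fn src_mac(frame: &[u8]) -> Option<[u8; 6]> {
    frame.get(6..12)?.try_into().ok()
}

/// Broadcast and multicast MACs both set the I/G bit (LSB of the first
/// octet), which is why ff:ff:ff:ff:ff:ff and 01:00:5e:xx:xx:xx take the
/// same fan-out path in the switch tests.
fn is_group_address(mac: &[u8; 6]) -> bool {
    mac[0] & 0x01 != 0
}
```

Source-MAC validation in `forward_tx`/`poll_ingress` then reduces to comparing `src_mac(frame)` against the port's registered MAC and dropping on mismatch.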


@@ -0,0 +1,256 @@
//! Integration tests for full sandbox isolation.
use nix::sys::wait::{waitpid, WaitStatus};
use nix::unistd::{fork, ForkResult, Uid};
use std::fs;
use std::path::Path;
use vm_switch::sandbox::{apply_sandbox, SandboxResult};
/// Helper to run test in forked child process.
fn run_in_fork<F: FnOnce() + std::panic::UnwindSafe>(test_fn: F) {
if Uid::current().is_root() {
eprintln!("Skipping test: already running as root");
return;
}
match unsafe { fork() } {
Ok(ForkResult::Parent { child }) => {
let status = waitpid(child, None).unwrap();
match status {
WaitStatus::Exited(_, 0) => {}
other => panic!("Child failed: {:?}", other),
}
}
Ok(ForkResult::Child) => {
let result = std::panic::catch_unwind(test_fn);
match &result {
Err(e) => {
if let Some(s) = e.downcast_ref::<&str>() {
eprintln!("Child panic: {}", s);
} else if let Some(s) = e.downcast_ref::<String>() {
eprintln!("Child panic: {}", s);
} else {
eprintln!("Child panic: unknown error");
}
}
Ok(()) => {}
}
std::process::exit(if result.is_ok() { 0 } else { 1 });
}
Err(e) => panic!("Fork failed: {}", e),
}
}
/// Helper to handle SandboxResult in tests.
/// In the inner wrapper parent, propagates exit code.
/// In the sandboxed child, returns the config path.
fn apply_and_unwrap(config_path: &Path) -> std::path::PathBuf {
match apply_sandbox(config_path).expect("apply_sandbox failed") {
SandboxResult::Parent(code) => {
// Inner wrapper parent - propagate child's exit code
std::process::exit(code);
}
SandboxResult::Sandboxed(path) => path,
}
}
#[test]
fn apply_sandbox_returns_config_path() {
let config_dir = tempfile::tempdir().unwrap();
let config_path = config_dir.path().to_path_buf();
run_in_fork(move || {
let result = apply_and_unwrap(&config_path);
assert_eq!(result, Path::new("/config"));
});
}
#[test]
fn apply_sandbox_isolates_ipc_namespace() {
let config_dir = tempfile::tempdir().unwrap();
let config_path = config_dir.path().to_path_buf();
// Get parent IPC namespace before fork
let parent_ipc = fs::read_link("/proc/self/ns/ipc").unwrap();
run_in_fork(move || {
apply_and_unwrap(&config_path);
// The fact that apply_sandbox succeeded means IPC namespace was entered
// We verify isolation indirectly through the other tests
});
// Parent's namespace should be unchanged
let parent_ipc_after = fs::read_link("/proc/self/ns/ipc").unwrap();
assert_eq!(parent_ipc, parent_ipc_after);
}
#[test]
fn apply_sandbox_isolates_network_namespace() {
let config_dir = tempfile::tempdir().unwrap();
let config_path = config_dir.path().to_path_buf();
run_in_fork(move || {
apply_and_unwrap(&config_path);
// In empty network namespace, /sys/class/net should be empty or not exist
// Since /sys is not mounted in our minimal root, network is effectively isolated
assert!(!Path::new("/sys").exists(), "/sys should not exist in sandbox");
});
}
#[test]
fn apply_sandbox_creates_complete_isolation() {
let config_dir = tempfile::tempdir().unwrap();
let config_path = config_dir.path().to_path_buf();
// Create marker file
fs::write(config_path.join("marker.txt"), "isolated").unwrap();
run_in_fork(move || {
let new_config = apply_and_unwrap(&config_path);
// Verify config path is correct
assert_eq!(new_config, Path::new("/config"));
// Verify filesystem isolation
assert!(Path::new("/config/marker.txt").exists());
assert!(!Path::new("/home").exists());
assert!(!Path::new("/usr").exists());
// Verify /dev works
assert!(Path::new("/dev/null").exists());
fs::write("/dev/null", "test").expect("write to /dev/null");
// Verify /tmp is writable
fs::write("/tmp/test.txt", "temp").expect("write to /tmp");
});
}
#[test]
fn forked_children_inherit_sandbox() {
let config_dir = tempfile::tempdir().unwrap();
let config_path = config_dir.path().to_path_buf();
fs::write(config_path.join("shared.txt"), "parent").unwrap();
run_in_fork(move || {
apply_and_unwrap(&config_path);
// Fork a child from within the sandbox
match unsafe { fork() } {
Ok(ForkResult::Parent { child }) => {
let status = waitpid(child, None).unwrap();
match status {
WaitStatus::Exited(_, code) => {
assert_eq!(code, 0, "Nested child should exit with 0");
}
other => panic!("Nested child unexpected status: {:?}", other),
}
}
Ok(ForkResult::Child) => {
// Nested child inherits sandbox
let exists = Path::new("/config/shared.txt").exists();
let no_home = !Path::new("/home").exists();
std::process::exit(if exists && no_home { 0 } else { 1 });
}
Err(e) => panic!("Nested fork failed: {}", e),
}
});
}
#[test]
fn apply_sandbox_creates_pid_namespace() {
let config_dir = tempfile::tempdir().unwrap();
let config_path = config_dir.path().to_path_buf();
run_in_fork(move || {
apply_and_unwrap(&config_path);
// Read /proc/self/stat - should be PID 1 in the new namespace
let stat = fs::read_to_string("/proc/self/stat").expect("should read /proc/self/stat");
assert!(
stat.starts_with("1 "),
"Should be PID 1 in namespace, got: {}",
stat
);
});
}
#[test]
fn apply_sandbox_mounts_proc() {
let config_dir = tempfile::tempdir().unwrap();
let config_path = config_dir.path().to_path_buf();
run_in_fork(move || {
apply_and_unwrap(&config_path);
// /proc should be mounted and functional
assert!(Path::new("/proc").exists(), "/proc should exist");
assert!(
Path::new("/proc/self").exists(),
"/proc/self should exist"
);
// Should be able to read process info
let cmdline = fs::read_to_string("/proc/self/cmdline");
assert!(cmdline.is_ok(), "Should be able to read /proc/self/cmdline");
});
}
#[test]
fn sandbox_children_get_sequential_pids() {
let config_dir = tempfile::tempdir().unwrap();
let config_path = config_dir.path().to_path_buf();
run_in_fork(move || {
apply_and_unwrap(&config_path);
// Fork a child and verify it gets PID 2
match unsafe { fork() } {
Ok(ForkResult::Parent { child }) => {
let status = waitpid(child, None).unwrap();
match status {
WaitStatus::Exited(_, code) => {
assert_eq!(code, 0, "Child should be PID 2");
}
other => panic!("Child unexpected status: {:?}", other),
}
}
Ok(ForkResult::Child) => {
// Read our PID from /proc
let stat = fs::read_to_string("/proc/self/stat").unwrap();
let pid_str = stat.split_whitespace().next().unwrap();
let pid: i32 = pid_str.parse().unwrap();
// Should be PID 2 (parent is 1)
std::process::exit(if pid == 2 { 0 } else { 1 });
}
Err(e) => panic!("fork failed: {}", e),
}
});
}
#[test]
fn sandbox_proc_shows_only_sandboxed_processes() {
let config_dir = tempfile::tempdir().unwrap();
let config_path = config_dir.path().to_path_buf();
run_in_fork(move || {
apply_and_unwrap(&config_path);
// Count processes visible in /proc
let mut proc_pids = 0;
for entry in fs::read_dir("/proc").unwrap() {
let entry = entry.unwrap();
let name = entry.file_name();
let name_str = name.to_string_lossy();
// Count numeric directories (PIDs)
if name_str.chars().all(|c| c.is_ascii_digit()) {
proc_pids += 1;
}
}
// Should only see PID 1 (ourselves)
assert_eq!(proc_pids, 1, "Should only see 1 process in /proc, saw {}", proc_pids);
});
}
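
The IPC-isolation test works because Linux exposes each namespace a process belongs to as a symlink under `/proc/self/ns`, whose target (e.g. `ipc:[4026531839]`) names the namespace inode: two processes share a namespace exactly when those targets match. A minimal Linux-only sketch of that check (the helper name is hypothetical):

```rust
use std::fs;

/// Return the namespace identity string for this process, e.g.
/// "ipc:[4026531839]" for kind = "ipc". Comparing these strings across
/// processes is how the tests above detect whether unshare() took effect.
fn namespace_id(kind: &str) -> std::io::Result<String> {
    let link = fs::read_link(format!("/proc/self/ns/{}", kind))?;
    Ok(link.to_string_lossy().into_owned())
}
```

Reading the link before forking and again in the parent afterwards, as `apply_sandbox_isolates_ipc_namespace` does, confirms the sandbox only moved the child.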


@@ -0,0 +1,140 @@
//! Integration tests for mount namespace and filesystem isolation.
//!
//! These tests require user namespace support (for unprivileged mount ns).
use nix::unistd::Uid;
use std::fs;
use std::path::Path;
use vm_switch::sandbox::{enter_user_namespace, setup_filesystem_isolation};
/// Helper to run test in forked child process.
fn run_in_fork<F: FnOnce() + std::panic::UnwindSafe>(test_fn: F) {
if Uid::current().is_root() {
eprintln!("Skipping test: already running as root");
return;
}
match unsafe { nix::unistd::fork() } {
Ok(nix::unistd::ForkResult::Parent { child }) => {
let status = nix::sys::wait::waitpid(child, None).unwrap();
match status {
nix::sys::wait::WaitStatus::Exited(_, 0) => {}
other => panic!("Child failed: {:?}", other),
}
}
Ok(nix::unistd::ForkResult::Child) => {
let result = std::panic::catch_unwind(test_fn);
match &result {
Err(e) => {
if let Some(s) = e.downcast_ref::<&str>() {
eprintln!("Child panic: {}", s);
} else if let Some(s) = e.downcast_ref::<String>() {
eprintln!("Child panic: {}", s);
} else {
eprintln!("Child panic: unknown error");
}
}
Ok(()) => {}
}
std::process::exit(if result.is_ok() { 0 } else { 1 });
}
Err(e) => panic!("Fork failed: {}", e),
}
}
#[test]
fn filesystem_isolation_creates_minimal_root() {
// Create a temp config dir that we'll bind-mount
let config_dir = tempfile::tempdir().unwrap();
let config_path = config_dir.path().to_path_buf();
// Create a marker file in config dir
fs::write(config_path.join("marker.txt"), "test").unwrap();
run_in_fork(move || {
enter_user_namespace().expect("enter_user_namespace failed");
setup_filesystem_isolation(&config_path, true).expect("setup_filesystem_isolation failed");
// Verify /config exists and contains our marker
assert!(Path::new("/config/marker.txt").exists(), "/config/marker.txt should exist");
let content = fs::read_to_string("/config/marker.txt").unwrap();
assert_eq!(content, "test");
// Verify /proc mount point exists (may not be mounted without PID namespace)
assert!(Path::new("/proc").is_dir(), "/proc should be a directory");
// Verify /dev devices exist
assert!(Path::new("/dev/null").exists(), "/dev/null should exist");
assert!(Path::new("/dev/zero").exists(), "/dev/zero should exist");
assert!(Path::new("/dev/urandom").exists(), "/dev/urandom should exist");
// Verify /tmp exists
assert!(Path::new("/tmp").is_dir(), "/tmp should be a directory");
});
}
#[test]
fn filesystem_isolation_hides_host_filesystem() {
let config_dir = tempfile::tempdir().unwrap();
let config_path = config_dir.path().to_path_buf();
run_in_fork(move || {
enter_user_namespace().expect("enter_user_namespace failed");
setup_filesystem_isolation(&config_path, true).expect("setup_filesystem_isolation failed");
// These paths should NOT exist (host filesystem hidden)
assert!(!Path::new("/home").exists(), "/home should not exist");
assert!(!Path::new("/usr").exists(), "/usr should not exist");
assert!(!Path::new("/etc").exists(), "/etc should not exist");
assert!(!Path::new("/bin").exists(), "/bin should not exist");
assert!(!Path::new("/nix").exists(), "/nix should not exist");
});
}
#[test]
fn filesystem_isolation_config_dir_is_writable() {
let config_dir = tempfile::tempdir().unwrap();
let config_path = config_dir.path().to_path_buf();
run_in_fork(move || {
enter_user_namespace().expect("enter_user_namespace failed");
setup_filesystem_isolation(&config_path, true).expect("setup_filesystem_isolation failed");
// Should be able to create files in /config (for socket files)
let test_file = Path::new("/config/test-write.txt");
fs::write(test_file, "writable").expect("should be able to write to /config");
assert!(test_file.exists());
let content = fs::read_to_string(test_file).unwrap();
assert_eq!(content, "writable");
});
}
#[test]
fn dev_devices_are_functional() {
use std::io::Read;
let config_dir = tempfile::tempdir().unwrap();
let config_path = config_dir.path().to_path_buf();
run_in_fork(move || {
enter_user_namespace().expect("enter_user_namespace failed");
setup_filesystem_isolation(&config_path, true).expect("setup_filesystem_isolation failed");
// /dev/null should accept writes and return nothing on read
fs::write("/dev/null", "discard this").expect("write to /dev/null");
let content = fs::read("/dev/null").unwrap();
assert!(content.is_empty(), "/dev/null should return empty on read");
// /dev/zero should return zeros (read limited amount)
let mut f = fs::File::open("/dev/zero").expect("open /dev/zero");
let mut buf = [0xffu8; 16]; // Initialize with non-zero to verify it changes
f.read_exact(&mut buf).expect("read from /dev/zero");
assert!(buf.iter().all(|&b| b == 0), "/dev/zero should return zeros");
// /dev/urandom should return random bytes (just check it's readable)
let mut f = fs::File::open("/dev/urandom").expect("open /dev/urandom");
let mut buf = [0u8; 16];
f.read_exact(&mut buf).expect("/dev/urandom should be readable");
});
}


@ -0,0 +1,88 @@
//! Integration tests for user namespace isolation.
//!
//! These tests require the ability to create user namespaces,
//! which is available to unprivileged users on most Linux systems.
use nix::unistd::{getgid, getuid, Uid};
use std::fs;
use vm_switch::sandbox::enter_user_namespace;
#[test]
fn enter_user_namespace_maps_to_root() {
// Skip if we're already root (CI environments sometimes run as root)
if Uid::current().is_root() {
eprintln!("Skipping test: already running as root");
return;
}
// Fork to avoid affecting the test process
match unsafe { nix::unistd::fork() } {
Ok(nix::unistd::ForkResult::Parent { child }) => {
// Parent waits for child
let status = nix::sys::wait::waitpid(child, None).unwrap();
match status {
nix::sys::wait::WaitStatus::Exited(_, 0) => {}
other => panic!("Child failed: {:?}", other),
}
}
Ok(nix::unistd::ForkResult::Child) => {
// Child enters user namespace and checks UID
let result = std::panic::catch_unwind(|| {
enter_user_namespace().expect("enter_user_namespace failed");
// After entering user namespace, we should appear as root
let uid = getuid();
let gid = getgid();
assert!(uid.is_root(), "Expected UID 0, got {}", uid);
assert_eq!(gid.as_raw(), 0, "Expected GID 0, got {}", gid);
// Verify we're in a different namespace
let ns_path = fs::read_link("/proc/self/ns/user").unwrap();
eprintln!("User namespace: {:?}", ns_path);
});
std::process::exit(if result.is_ok() { 0 } else { 1 });
}
Err(e) => panic!("Fork failed: {}", e),
}
}
#[test]
fn enter_user_namespace_is_isolated() {
// Skip if we're already root
if Uid::current().is_root() {
eprintln!("Skipping test: already running as root");
return;
}
// Get parent namespace inode before fork
let parent_ns = fs::read_link("/proc/self/ns/user").unwrap();
match unsafe { nix::unistd::fork() } {
Ok(nix::unistd::ForkResult::Parent { child }) => {
let status = nix::sys::wait::waitpid(child, None).unwrap();
match status {
nix::sys::wait::WaitStatus::Exited(_, 0) => {}
other => panic!("Child failed: {:?}", other),
}
}
Ok(nix::unistd::ForkResult::Child) => {
let result = std::panic::catch_unwind(|| {
enter_user_namespace().expect("enter_user_namespace failed");
let child_ns = fs::read_link("/proc/self/ns/user").unwrap();
// Namespace should be different from parent
assert_ne!(
parent_ns, child_ns,
"User namespace should differ: parent={:?}, child={:?}",
parent_ns, child_ns
);
});
std::process::exit(if result.is_ok() { 0 } else { 1 });
}
Err(e) => panic!("Fork failed: {}", e),
}
}


@ -0,0 +1,178 @@
//! Integration tests for seccomp filtering.
use nix::sys::signal::Signal;
use nix::sys::wait::{waitpid, WaitStatus};
use nix::unistd::{fork, ForkResult};
use std::process;
use vm_switch::seccomp::{apply_child_seccomp, apply_main_seccomp, SeccompMode};
/// Run test in forked child, return wait status.
///
/// No root skip here: unlike user-namespace creation, installing a seccomp
/// filter works with or without privilege, and a sentinel status could not
/// satisfy both the tests that expect a kill and those that expect exit 0.
fn run_in_fork<F: FnOnce() -> i32 + std::panic::UnwindSafe>(test_fn: F) -> WaitStatus {
match unsafe { fork() } {
Ok(ForkResult::Parent { child }) => {
waitpid(child, None).expect("waitpid failed")
}
Ok(ForkResult::Child) => {
let result = std::panic::catch_unwind(test_fn);
let code = match result {
Ok(code) => code,
Err(_) => 1,
};
process::exit(code);
}
Err(e) => panic!("Fork failed: {}", e),
}
}
#[test]
fn seccomp_kill_blocks_ptrace() {
let status = run_in_fork(|| {
apply_main_seccomp(SeccompMode::Kill).expect("apply failed");
// ptrace is not whitelisted
unsafe { libc::ptrace(libc::PTRACE_TRACEME, 0, 0, 0) };
0 // Should not reach here
});
match status {
WaitStatus::Signaled(_, sig, _) => {
assert!(sig == Signal::SIGSYS || sig == Signal::SIGKILL);
}
WaitStatus::Exited(_, code) => {
assert_ne!(code, 0, "should have been killed");
}
_ => panic!("unexpected: {:?}", status),
}
}
#[test]
fn seccomp_trap_sends_sigsys() {
let status = run_in_fork(|| {
apply_main_seccomp(SeccompMode::Trap).expect("apply failed");
unsafe { libc::ptrace(libc::PTRACE_TRACEME, 0, 0, 0) };
0
});
match status {
WaitStatus::Signaled(_, Signal::SIGSYS, _) => {}
WaitStatus::Stopped(_, Signal::SIGSYS) => {}
_ => panic!("expected SIGSYS, got {:?}", status),
}
}
#[test]
fn seccomp_disabled_allows_all() {
let status = run_in_fork(|| {
apply_main_seccomp(SeccompMode::Disabled).expect("apply failed");
// Would be blocked, but disabled mode allows it
let _ = unsafe { libc::ptrace(libc::PTRACE_TRACEME, 0, 0, 0) };
0
});
match status {
WaitStatus::Exited(_, 0) => {}
_ => panic!("expected clean exit, got {:?}", status),
}
}
#[test]
fn seccomp_allows_whitelisted() {
let status = run_in_fork(|| {
apply_main_seccomp(SeccompMode::Kill).expect("apply failed");
// All whitelisted
let _ = unsafe { libc::getpid() };
let _ = unsafe { libc::getuid() };
let msg = b"ok\n";
let ret = unsafe { libc::write(2, msg.as_ptr() as *const _, msg.len()) };
if ret < 0 { return 1; }
0
});
match status {
WaitStatus::Exited(_, 0) => {}
_ => panic!("expected clean exit, got {:?}", status),
}
}
#[test]
fn child_filter_blocks_fork() {
let status = run_in_fork(|| {
apply_child_seccomp(SeccompMode::Kill).expect("apply failed");
// fork not in child whitelist
let ret = unsafe { libc::fork() };
if ret == 0 {
process::exit(99); // grandchild
}
0 // Should not reach
});
match status {
WaitStatus::Signaled(_, sig, _) => {
assert!(sig == Signal::SIGSYS || sig == Signal::SIGKILL);
}
WaitStatus::Exited(_, code) => {
assert!(code != 0 && code != 99, "fork should be blocked");
}
_ => panic!("unexpected: {:?}", status),
}
}
#[test]
fn child_filter_blocks_socket() {
let status = run_in_fork(|| {
apply_child_seccomp(SeccompMode::Kill).expect("apply failed");
// socket() not in child whitelist
let fd = unsafe { libc::socket(libc::AF_UNIX, libc::SOCK_STREAM, 0) };
if fd >= 0 {
unsafe { libc::close(fd) };
return 1; // Should have been blocked
}
0
});
match status {
WaitStatus::Signaled(_, sig, _) => {
assert!(sig == Signal::SIGSYS || sig == Signal::SIGKILL);
}
WaitStatus::Exited(_, code) => {
assert_ne!(code, 0, "socket should be blocked");
}
_ => panic!("unexpected: {:?}", status),
}
}
#[test]
fn child_filter_allows_memfd() {
let status = run_in_fork(|| {
apply_child_seccomp(SeccompMode::Kill).expect("apply failed");
// memfd_create is in child whitelist (for ring buffers)
let fd = unsafe {
libc::syscall(
libc::SYS_memfd_create,
b"test\0".as_ptr(),
0u32,
)
};
if fd < 0 {
return 1;
}
unsafe { libc::close(fd as i32) };
0
});
match status {
WaitStatus::Exited(_, 0) => {}
_ => panic!("expected clean exit, got {:?}", status),
}
}