From 6941d2fe4cc73a79db5bf2831246dee78ff42985 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=AD=C3=B0=20Steinn=20Geirsson?= Date: Mon, 9 Feb 2026 20:19:26 +0000 Subject: [PATCH] feat(vm-switch): add process isolation with namespace sandbox and seccomp Replace the thread-based vhost-user backend architecture with a fork-based process model where each VM gets its own child process. This enables strong isolation between VMs handling untrusted network traffic, with multiple layers of defense in depth. Process model: - Main process watches config directory and orchestrates child lifecycle - One child process forked per VM, running as vhost-user net backend - Children communicate via SOCK_SEQPACKET control channel with SCM_RIGHTS - Automatic child restart on crash/disconnect, with peer notification - Ping/pong heartbeat monitoring for worker health (1s interval, 100ms timeout) - SIGCHLD handling integrated into tokio event loop Inter-process packet forwarding: - Lock-free SPSC ring buffers in shared memory (memfd + mmap) - 64-slot rings (~598KB each) with atomic head/tail, no locks in datapath - Eventfd signaling for empty-to-non-empty transitions - Main orchestrates buffer exchange: GetBuffer -> BufferReady -> PutBuffer - Zero-copy path: producers write directly into consumer's shared memory Namespace sandbox (applied before tokio, single-threaded): - User namespace: unprivileged outside, UID 0 inside - PID namespace: main is PID 1, children invisible to host - Mount namespace: minimal tmpfs root with /config, /dev, /proc, /tmp - IPC namespace: isolated System V IPC - Network namespace: empty, communication only via inherited FDs - Controllable via --no-sandbox flag Seccomp BPF filtering (two-tier whitelist): - Main filter: allows fork, socket creation, inotify, openat - Child filter: strict subset - no fork, no socket, no file open - Child filter applied after vhost setup, before event loop - Modes: kill (default), trap (SIGSYS debug), log, disabled Also adds 
vm-switch service dependencies to VM units in the NixOS module so VMs wait for their network switch before starting. Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 69 ++ README.md | 50 + modules/config.nix | 5 +- vm-switch/Cargo.lock | 267 ++++++ vm-switch/Cargo.toml | 11 + vm-switch/src/args.rs | 87 +- vm-switch/src/backend.rs | 921 ------------------- vm-switch/src/child/forwarder.rs | 370 ++++++++ vm-switch/src/child/mod.rs | 14 + vm-switch/src/child/poll.rs | 115 +++ vm-switch/src/child/process.rs | 239 +++++ vm-switch/src/child/vhost.rs | 283 ++++++ vm-switch/src/control.rs | 721 +++++++++++++++ vm-switch/src/frame.rs | 62 ++ vm-switch/src/lib.rs | 28 +- vm-switch/src/main.rs | 115 ++- vm-switch/src/manager.rs | 1330 ++++++++++++++------------- vm-switch/src/ring.rs | 865 +++++++++++++++++ vm-switch/src/sandbox.rs | 564 ++++++++++++ vm-switch/src/seccomp.rs | 443 +++++++++ vm-switch/src/switch.rs | 445 --------- vm-switch/tests/buffer_exchange.rs | 323 +++++++ vm-switch/tests/crash_handling.rs | 104 +++ vm-switch/tests/fork_lifecycle.rs | 160 ++++ vm-switch/tests/packet_flow.rs | 63 ++ vm-switch/tests/sandbox_full.rs | 256 ++++++ vm-switch/tests/sandbox_mount_ns.rs | 140 +++ vm-switch/tests/sandbox_user_ns.rs | 88 ++ vm-switch/tests/seccomp_filter.rs | 178 ++++ 29 files changed, 6275 insertions(+), 2041 deletions(-) delete mode 100644 vm-switch/src/backend.rs create mode 100644 vm-switch/src/child/forwarder.rs create mode 100644 vm-switch/src/child/mod.rs create mode 100644 vm-switch/src/child/poll.rs create mode 100644 vm-switch/src/child/process.rs create mode 100644 vm-switch/src/child/vhost.rs create mode 100644 vm-switch/src/control.rs create mode 100644 vm-switch/src/ring.rs create mode 100644 vm-switch/src/sandbox.rs create mode 100644 vm-switch/src/seccomp.rs delete mode 100644 vm-switch/src/switch.rs create mode 100644 vm-switch/tests/buffer_exchange.rs create mode 100644 vm-switch/tests/crash_handling.rs create mode 100644 
vm-switch/tests/fork_lifecycle.rs create mode 100644 vm-switch/tests/packet_flow.rs create mode 100644 vm-switch/tests/sandbox_full.rs create mode 100644 vm-switch/tests/sandbox_mount_ns.rs create mode 100644 vm-switch/tests/sandbox_user_ns.rs create mode 100644 vm-switch/tests/seccomp_filter.rs diff --git a/CLAUDE.md b/CLAUDE.md index dacb727..0d43c0e 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -138,6 +138,14 @@ Provides L2 switching for VM-to-VM networks: - **Purpose:** Handles vhost-user protocol for VM network interfaces - **Systemd:** One service per vmNetwork (`vm-switch-.service`) +**CLI flags:** +``` +-d, --config-dir Config/MAC file directory (default: /run/vm-switch) +--log-level error, warn, info, debug, trace (default: warn) +--no-sandbox Disable namespace sandboxing +--seccomp-mode kill (default), trap, log, disabled +``` + **Testing locally:** ```bash # Build and run manually @@ -149,6 +157,67 @@ mkdir -p /tmp/test-switch/router echo "52:00:00:00:00:01" > /tmp/test-switch/router/router.mac ``` +**Process model:** Main process forks one child per VM. Children are vhost-user net backends that handle virtio TX/RX for their VM. Main orchestrates lifecycle, config watching, and buffer exchange between children. Children exit when the vhost-user client (crosvm) disconnects; main automatically restarts them so crosvm can reconnect. + +**Startup sequence:** +1. Parse args, apply namespace sandbox (single-threaded, before tokio) +2. Apply main seccomp filter +3. Start tokio runtime, create ConfigWatcher + BackendManager +4. 
Enter async event loop (SIGCHLD via tokio select branch) + +**Key source files:** +- `src/main.rs` - Entry point, sandbox/seccomp setup, async event loop +- `src/manager.rs` - BackendManager: fork children, buffer exchange, crash cleanup +- `src/child/process.rs` - Child entry point: control channel, vhost daemon, child seccomp +- `src/child/forwarder.rs` - PacketForwarder: L2 routing via ring buffers +- `src/child/vhost.rs` - ChildVhostBackend: virtio TX/RX callbacks +- `src/child/poll.rs` - Event polling for control channel + ingress buffers +- `src/control.rs` - Main-child IPC over Unix seqpacket sockets + SCM_RIGHTS +- `src/ring.rs` - Lock-free SPSC ring buffer in shared memory (memfd) +- `src/sandbox.rs` - Namespace isolation (user, PID, mount, IPC, network) +- `src/seccomp.rs` - BPF syscall filters (main and child whitelists) +- `src/frame.rs` - Ethernet frame parsing, MAC validation +- `src/args.rs` - CLI parsing, prefixed log formatting + +**Control protocol** (main <-> child IPC via `SOCK_SEQPACKET` + `SCM_RIGHTS`): + +| Direction | Message | FDs | Purpose | +|-----------|---------|-----|---------| +| Main -> Child | `GetBuffer { peer_name, peer_mac }` | - | Ask child to create ingress buffer for a peer | +| Child -> Main | `BufferReady { peer_name }` | memfd, eventfd | Ingress buffer created, here are the FDs | +| Main -> Child | `PutBuffer { peer_name, peer_mac, broadcast }` | memfd, eventfd | Give child a peer's buffer as egress target | +| Main -> Child | `RemovePeer { peer_name }` | - | Clean up buffers for disconnected/crashed peer | +| Main -> Child | `Ping` | - | Heartbeat request (sent every 1s) | +| Child -> Main | `Ready` | - | Child initialized and ready | +| Child -> Main | `Pong` | - | Heartbeat response (must arrive within 100ms) | + +Messages serialized with `postcard`. FDs passed via ancillary data. + +**Buffer exchange flow:** +1. Main sends `GetBuffer` to Child1 ("create ingress buffer for Child2") +2. 
Child1 creates SPSC ring buffer (memfd + eventfd), becomes Consumer, replies `BufferReady` +3. Main forwards those FDs to Child2 via `PutBuffer` -- Child2 becomes Producer +4. Packets now flow: Child2 writes to Producer -> shared memfd -> Child1 reads from Consumer + +**SPSC ring buffer** (`ring.rs`): Lock-free single-producer/single-consumer queue backed by `memfd_create()` + `mmap(MAP_SHARED)`. 64 slots, ~598KB total. Head/tail use atomic operations (no locks in datapath). Eventfd signals empty-to-non-empty transitions. + +**Sandbox** (applied before tokio, requires single-threaded): +1. **User namespace** - Maps real UID to 0 inside, enables unprivileged namespace creation +2. **PID namespace** - Fork into new PID ns; main becomes PID 1 +3. **Mount namespace** - Minimal tmpfs root with `/config` (bind-mount of config dir), `/dev` (null, zero, urandom), `/proc`, `/tmp`. Pivot root, unmount old. +4. **IPC namespace** - Isolates System V IPC +5. **Network namespace** - Empty (no interfaces). Communication only via inherited FDs. + +**Seccomp filtering** (BPF syscall whitelist): +- `--seccomp-mode=kill` (default): Terminate on blocked syscall +- `--seccomp-mode=trap`: Send SIGSYS (debug with strace) +- `--seccomp-mode=log`: Log violations but allow +- `--seccomp-mode=disabled`: Skip filtering + +Two filter tiers (child is a strict subset of main): +- **Main**: Allows fork, socket creation, inotify, openat (config watching + child management) +- **Child**: No fork, no socket creation, no file open. Applied after vhost setup completes. Allows clone3 for vhost-user threads. 
+ ### Dependencies - Custom crosvm fork: `git.dsg.is/davidlowsec/crosvm.git` diff --git a/README.md b/README.md index 2c79686..216fc71 100644 --- a/README.md +++ b/README.md @@ -499,3 +499,53 @@ The host provides: - Polkit rules for the configured user to manage VM services without sudo - CLI tools: `vm-run`, `vm-start`, `vm-stop`, `vm-start-debug`, `vm-shell` - Desktop integration with .desktop files for guest applications + +### vm-switch + +The `vm-switch` daemon (`vm-switch/` Rust crate) provides L2 switching for VM-to-VM networks. One instance runs per `vmNetwork`, managed by systemd (`vm-switch-.service`). + +**Process model:** The main process watches a config directory for MAC files and forks one child process per VM. Each child is a vhost-user net backend serving a single VM's network interface. + +``` + Main Process + (config watch, orchestration) + / | \ + fork / fork | fork \ + v v v + Child: router Child: banking Child: shopping + (vhost-user) (vhost-user) (vhost-user) + | | | + [unix socket] [unix socket] [unix socket] + | | | + crosvm crosvm crosvm + (router VM) (banking VM) (shopping VM) +``` + +**Packet forwarding** uses lock-free SPSC ring buffers in shared memory (`memfd_create` + `mmap`). When a VM transmits a frame, its child process validates the source MAC address and routes the frame to the correct destination: +- Unicast: pushed into the destination child's ingress ring buffer +- Broadcast/multicast: pushed into all peers' ingress buffers + +Ring buffers use atomic head/tail pointers (no locks in the datapath) with eventfd signaling for empty-to-non-empty transitions. + +**Buffer exchange protocol:** The main process orchestrates buffer setup between children via a control channel (`SOCK_SEQPACKET` + `SCM_RIGHTS` for passing memfd/eventfd file descriptors): + +1. Main tells Child A: "create an ingress buffer for Child B" (`GetBuffer`) +2. Child A creates the ring buffer and returns the FDs (`BufferReady`) +3. 
Main forwards those FDs to Child B as an egress target (`PutBuffer`) +4. Child B can now write frames directly into Child A's memory -- no copies through the main process + +**Sandboxing:** The daemon runs in a multi-layer sandbox applied at startup (before any async runtime or threads): + +| Layer | Mechanism | Effect | +|-------|-----------|--------| +| User namespace | `CLONE_NEWUSER` | Unprivileged outside, appears as UID 0 inside | +| PID namespace | `CLONE_NEWPID` | Main is PID 1; children invisible to host | +| Mount namespace | `CLONE_NEWNS` + pivot_root | Minimal tmpfs root: `/config`, `/dev` (null/zero/urandom), `/proc`, `/tmp` | +| IPC namespace | `CLONE_NEWIPC` | Isolated System V IPC | +| Network namespace | `CLONE_NEWNET` | No interfaces; communication only via inherited FDs | +| Seccomp (main) | BPF whitelist | Allows fork, socket creation, inotify for config watching | +| Seccomp (child) | Tighter BPF whitelist | No fork, no socket creation, no file open; applied after vhost setup | + +Seccomp modes: `--seccomp-mode=kill` (default), `trap` (SIGSYS for debugging), `log`, `disabled`. + +Disable sandboxing for debugging with `--no-sandbox` and `--seccomp-mode=disabled`. 
diff --git a/modules/config.nix b/modules/config.nix index 0234a82..d7710f9 100644 --- a/modules/config.nix +++ b/modules/config.nix @@ -1097,7 +1097,10 @@ in vm: lib.nameValuePair "qubes-lite-${vm.name}-vm" { description = "qubes-lite VM: ${vm.name}"; - after = [ "network.target" ]; + after = + [ "network.target" ] + ++ map (netName: "vm-switch-${netName}.service") (lib.attrNames vm.vmNetwork); + requires = map (netName: "vm-switch-${netName}.service") (lib.attrNames vm.vmNetwork); serviceConfig = { Type = "simple"; ExecStart = "${mkVmScript vm}"; diff --git a/vm-switch/Cargo.lock b/vm-switch/Cargo.lock index d4f1877..e3f6f17 100644 --- a/vm-switch/Cargo.lock +++ b/vm-switch/Cargo.lock @@ -61,6 +61,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "anyhow" +version = "1.0.101" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f0e0fee31ef5ed1ba1316088939cea399010ed7731dba877ed44aeb407a75ea" + [[package]] name = "arc-swap" version = "1.8.1" @@ -70,6 +76,15 @@ dependencies = [ "rustversion", ] +[[package]] +name = "atomic-polyfill" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4" +dependencies = [ + "critical-section", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -88,6 +103,12 @@ version = "3.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5dd9dc738b7a8311c7ade152424974d8115f2cdad61e8dab8dac9f2362298510" +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.11.1" @@ -100,6 +121,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" +[[package]] +name = "cfg_aliases" 
+version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "clap" version = "4.5.57" @@ -140,12 +167,39 @@ version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3e64b0cc0439b12df2fa678eae89a1c56a529fd067a9115f7827f1fffd22b32" +[[package]] +name = "cobs" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa961b519f0b462e3a3b4a34b64d119eeaca1d59af726fe450bbba07a9fc0a1" +dependencies = [ + "thiserror", +] + [[package]] name = "colorchoice" version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +[[package]] +name = "critical-section" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" + +[[package]] +name = "embedded-io" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef1a6892d9eef45c8fa6b9e0086428a2cca8491aca8f787c534a3d6d0bcb3ced" + +[[package]] +name = "embedded-io" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d" + [[package]] name = "errno" version = "0.3.14" @@ -191,6 +245,42 @@ dependencies = [ "libc", ] +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", 
+] + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-core", + "futures-task", + "pin-project-lite", + "pin-utils", + "slab", +] + [[package]] name = "getrandom" version = "0.3.4" @@ -203,6 +293,29 @@ dependencies = [ "wasip2", ] +[[package]] +name = "hash32" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67" +dependencies = [ + "byteorder", +] + +[[package]] +name = "heapless" +version = "0.7.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f" +dependencies = [ + "atomic-polyfill", + "hash32", + "rustc_version", + "serde", + "spin", + "stable_deref_trait", +] + [[package]] name = "heck" version = "0.5.0" @@ -345,6 +458,18 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "nix" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46" +dependencies = [ + "bitflags 2.10.0", + "cfg-if", + "cfg_aliases", + "libc", +] + [[package]] name = "notify" version = "7.0.0" @@ -436,6 +561,25 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = 
"postcard" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6764c3b5dd454e283a30e6dfe78e9b31096d9e32036b5d1eaac7a6119ccb9a24" +dependencies = [ + "cobs", + "embedded-io 0.4.0", + "embedded-io 0.6.1", + "heapless", + "serde", +] + [[package]] name = "ppv-lite86" version = "0.2.21" @@ -533,6 +677,15 @@ version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a96887878f22d7bad8a3b6dc5b7440e0ada9a245242924394987b21cf2210a4c" +[[package]] +name = "rustc_version" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "1.1.3" @@ -561,12 +714,98 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "scc" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46e6f046b7fef48e2660c57ed794263155d713de679057f2d0c169bfc6e756cc" +dependencies = [ + "sdd", +] + [[package]] name = "scopeguard" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sdd" +version = "3.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490dcfcbfef26be6800d11870ff2df8774fa6e86d047e3e8c8a76b25655e41ca" + +[[package]] +name = "seccompiler" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "345a3e4dddf721a478089d4697b83c6c0a8f5bf16086f6c13397e4534eb6e2e5" +dependencies = [ + "libc", +] + +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + +[[package]] +name = "serde" +version = "1.0.228" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serial_test" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d0b343e184fc3b7bb44dff0705fffcf4b3756ba6aff420dddd8b24ca145e555" +dependencies = [ + "futures-executor", + "futures-util", + "log", + "once_cell", + "parking_lot", + "scc", + "serial_test_derive", +] + +[[package]] +name = "serial_test_derive" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f50427f258fb77356e4cd4aa0e87e2bd2c66dbcee41dc405282cae2bfc26c83" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -586,6 +825,12 @@ dependencies = [ "libc", ] +[[package]] +name = "slab" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" + [[package]] name = "smallvec" version = "1.15.1" @@ -602,6 +847,21 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +dependencies = [ + "lock_api", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.1" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + [[package]] name = "strsim" version = "0.11.1" @@ -841,9 +1101,16 @@ dependencies = [ name = "vm-switch" version = "0.1.0" dependencies = [ + "anyhow", "clap", + "libc", + "nix", "notify", "notify-debouncer-full", + "postcard", + "seccompiler", + "serde", + "serial_test", "tempfile", "thiserror", "tokio", diff --git a/vm-switch/Cargo.toml b/vm-switch/Cargo.toml index cfcdb98..61f92c1 100644 --- a/vm-switch/Cargo.toml +++ b/vm-switch/Cargo.toml @@ -27,9 +27,20 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } # Error handling thiserror = "2" +anyhow = "1" # CLI clap = { version = "4", features = ["derive"] } +# Sandboxing +nix = { version = "0.29", features = ["sched", "mount", "user", "signal", "process", "poll", "fs"] } +seccompiler = "0.4" +libc = "0.2" + +# Serialization (for control channel) +serde = { version = "1", features = ["derive"] } +postcard = { version = "1", features = ["alloc"] } + [dev-dependencies] tempfile = "3" +serial_test = "3" diff --git a/vm-switch/src/args.rs b/vm-switch/src/args.rs index 133b7c6..0827ccb 100644 --- a/vm-switch/src/args.rs +++ b/vm-switch/src/args.rs @@ -1,9 +1,46 @@ //! Command-line argument parsing. -use clap::{Parser, ValueEnum}; +use std::fmt; use std::path::PathBuf; +use std::sync::Mutex; + +use crate::seccomp::SeccompMode; +use clap::{Parser, ValueEnum}; +use tracing_subscriber::fmt::format::Writer; +use tracing_subscriber::fmt::{FmtContext, FormatEvent, FormatFields}; +use tracing_subscriber::registry::LookupSpan; use tracing_subscriber::EnvFilter; +/// Process name for log prefixes ("main" or "worker-$vmname"). +static PROCESS_NAME: Mutex = Mutex::new(String::new()); + +/// Set the process name for log prefixes. 
+pub fn set_process_name(name: impl Into) { + *PROCESS_NAME.lock().unwrap() = name.into(); +} + +/// Custom log formatter that outputs: LEVEL process-name: message fields +struct PrefixedFormatter; + +impl FormatEvent for PrefixedFormatter +where + S: tracing::Subscriber + for<'a> LookupSpan<'a>, + N: for<'a> FormatFields<'a> + 'static, +{ + fn format_event( + &self, + ctx: &FmtContext<'_, S, N>, + mut writer: Writer<'_>, + event: &tracing::Event<'_>, + ) -> fmt::Result { + let level = *event.metadata().level(); + let name = PROCESS_NAME.lock().unwrap(); + write!(writer, "{level} {name}: ")?; + ctx.field_format().format_fields(writer.by_ref(), event)?; + writeln!(writer) + } +} + /// Log level for the application. #[derive(Copy, Clone, Debug, PartialEq, Eq, ValueEnum)] pub enum LogLevel { @@ -39,16 +76,26 @@ pub struct Args { /// Log level (error, warn, info, debug, trace). #[arg(long, value_enum, default_value_t = LogLevel::Warn)] pub log_level: LogLevel, + + /// Disable namespace sandboxing (for debugging). + #[arg(long, default_value_t = false)] + pub no_sandbox: bool, + + /// Seccomp filter mode (kill, trap, log, disabled). + #[arg(long, value_enum, default_value_t = SeccompMode::Kill)] + pub seccomp_mode: SeccompMode, } /// Initialize logging based on log level. 
pub fn init_logging(level: LogLevel) { + set_process_name("main"); + let filter = EnvFilter::try_from_default_env() .unwrap_or_else(|_| EnvFilter::new(format!("vm_switch={}", level.as_str()))); let _ = tracing_subscriber::fmt() .with_env_filter(filter) - .with_target(false) + .event_format(PrefixedFormatter) .try_init(); } @@ -106,4 +153,40 @@ mod tests { init_logging(LogLevel::Debug); init_logging(LogLevel::Trace); } + + #[test] + fn parse_no_sandbox_flag() { + let args = Args::try_parse_from(["vm-switch", "--no-sandbox"]).unwrap(); + assert!(args.no_sandbox); + } + + #[test] + fn parse_no_sandbox_default_false() { + let args = Args::try_parse_from(["vm-switch"]).unwrap(); + assert!(!args.no_sandbox); + } + + #[test] + fn parse_seccomp_mode_default() { + let args = Args::try_parse_from(["vm-switch"]).unwrap(); + assert_eq!(args.seccomp_mode, SeccompMode::Kill); + } + + #[test] + fn parse_seccomp_mode_trap() { + let args = Args::try_parse_from(["vm-switch", "--seccomp-mode", "trap"]).unwrap(); + assert_eq!(args.seccomp_mode, SeccompMode::Trap); + } + + #[test] + fn parse_seccomp_mode_log() { + let args = Args::try_parse_from(["vm-switch", "--seccomp-mode", "log"]).unwrap(); + assert_eq!(args.seccomp_mode, SeccompMode::Log); + } + + #[test] + fn parse_seccomp_mode_disabled() { + let args = Args::try_parse_from(["vm-switch", "--seccomp-mode", "disabled"]).unwrap(); + assert_eq!(args.seccomp_mode, SeccompMode::Disabled); + } } diff --git a/vm-switch/src/backend.rs b/vm-switch/src/backend.rs deleted file mode 100644 index 9b73d01..0000000 --- a/vm-switch/src/backend.rs +++ /dev/null @@ -1,921 +0,0 @@ -//! Vhost-user network backend implementation. 
- -use std::collections::HashMap; -use std::sync::{Arc, Mutex, RwLock}; - -use tracing::{debug, trace, warn}; -use vhost::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures}; -use vhost_user_backend::{VhostUserBackend, VringRwLock, VringT}; -use virtio_bindings::virtio_net::{ - VIRTIO_NET_F_CSUM, VIRTIO_NET_F_GUEST_CSUM, VIRTIO_NET_F_GUEST_TSO4, - VIRTIO_NET_F_GUEST_TSO6, VIRTIO_NET_F_GUEST_UFO, VIRTIO_NET_F_HOST_TSO4, - VIRTIO_NET_F_HOST_TSO6, VIRTIO_NET_F_HOST_UFO, VIRTIO_NET_F_MAC, VIRTIO_NET_F_STATUS, -}; -use virtio_bindings::virtio_config::VIRTIO_F_VERSION_1; -use virtio_bindings::virtio_ring::VIRTIO_RING_F_EVENT_IDX; -use virtio_queue::QueueT; -use vm_memory::{Bytes, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryMmap}; -use vmm_sys_util::epoll::EventSet; - -use crate::config::VmRole; -use crate::frame::EthernetFrame; -use crate::mac::Mac; -use crate::switch::{ConnectionId, ForwardDecision, Switch}; - -/// Registry mapping connection IDs to their RX vrings and associated memory. -/// Shared between all backends for frame routing. -pub type VringRegistry = Arc)>>>; - -/// RX queue index. -pub const RX_QUEUE: u16 = 0; -/// TX queue index. -pub const TX_QUEUE: u16 = 1; -/// Number of queues (RX + TX). -pub const NUM_QUEUES: usize = 2; -/// Maximum queue size (must be power of 2, 32768 is typical max). -pub const MAX_QUEUE_SIZE: usize = 32768; -/// Size of virtio-net header. -pub const VIRTIO_NET_HDR_SIZE: usize = 12; - -/// Result of processing a frame from the TX queue. -#[derive(Debug, Clone)] -pub struct ProcessedFrame { - /// Raw frame data (Ethernet frame, no virtio header). - pub data: Vec, - /// Forwarding decision from the switch. - pub decision: ForwardDecision, -} - -/// Virtio net features we support. 
-const VIRTIO_NET_FEATURES: u64 = (1 << VIRTIO_NET_F_CSUM) - | (1 << VIRTIO_NET_F_GUEST_CSUM) - | (1 << VIRTIO_NET_F_GUEST_TSO4) - | (1 << VIRTIO_NET_F_GUEST_TSO6) - | (1 << VIRTIO_NET_F_GUEST_UFO) - | (1 << VIRTIO_NET_F_HOST_TSO4) - | (1 << VIRTIO_NET_F_HOST_TSO6) - | (1 << VIRTIO_NET_F_HOST_UFO) - | (1 << VIRTIO_NET_F_MAC) - | (1 << VIRTIO_NET_F_STATUS); - -/// Network backend for a single VM. -pub struct NetBackend { - /// VM name for logging. - name: String, - /// VM's role (router or client). - role: VmRole, - /// VM's MAC address. - mac: Mac, - /// Connection ID in the switch (set after registration). - connection_id: Mutex>, - /// Shared switch for forwarding. - switch: Arc>, - /// Shared registry of all backends' RX vrings for frame routing. - vring_registry: VringRegistry, - /// Guest memory. - mem: Mutex>>, - /// Whether EVENT_IDX is enabled. - event_idx: Mutex, - /// Acked features. - acked_features: Mutex, -} - -impl std::fmt::Debug for NetBackend { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("NetBackend") - .field("name", &self.name) - .field("role", &self.role) - .field("mac", &self.mac) - .field("connection_id", &self.connection_id) - .finish_non_exhaustive() - } -} - -impl NetBackend { - /// Create a new network backend. - pub fn new( - name: String, - role: VmRole, - mac: Mac, - switch: Arc>, - vring_registry: VringRegistry, - ) -> Self { - Self { - name, - role, - mac, - connection_id: Mutex::new(None), - switch, - vring_registry, - mem: Mutex::new(None), - event_idx: Mutex::new(false), - acked_features: Mutex::new(0), - } - } - - /// Get the VM name. - pub fn name(&self) -> &str { - &self.name - } - - /// Get the VM role. - pub fn role(&self) -> VmRole { - self.role - } - - /// Get the VM MAC. - pub fn mac(&self) -> Mac { - self.mac - } - - /// Get the connection ID (if registered). 
- pub fn connection_id(&self) -> Option { - *self.connection_id.lock().unwrap() - } - - /// Register this backend with the switch. - pub fn register(&self) -> Option { - let mut switch = self.switch.write().unwrap(); - let id = switch.register(self.name.clone(), self.role, self.mac)?; - *self.connection_id.lock().unwrap() = Some(id); - Some(id) - } - - /// Unregister this backend from the switch. - pub fn unregister(&self) { - let mut id_guard = self.connection_id.lock().unwrap(); - if let Some(id) = id_guard.take() { - let mut switch = self.switch.write().unwrap(); - switch.unregister(id); - } - } - - /// Clear state for connection reset (called between vhost-user sessions). - pub fn clear_state(&self) { - // Clear guest memory - *self.mem.lock().unwrap() = None; - - // Reset event_idx - *self.event_idx.lock().unwrap() = false; - - // Reset acked features - *self.acked_features.lock().unwrap() = 0; - } - - /// Process a frame and determine forwarding. - /// - /// Takes raw Ethernet frame bytes (no virtio header). - /// Returns None if not registered or frame is invalid. - pub fn process_frame(&self, frame_data: &[u8]) -> Option { - let conn_id = self.connection_id()?; - - let frame = EthernetFrame::parse(frame_data)?; - - let switch = self.switch.read().unwrap(); - let decision = switch.forward(conn_id, frame.source_mac(), frame.dest_mac()); - - Some(ProcessedFrame { - data: frame_data.to_vec(), - decision, - }) - } - - /// Strip the virtio-net header from frame data. - /// Returns the Ethernet frame without the header. - pub fn strip_virtio_header(data: &[u8]) -> Option<&[u8]> { - if data.len() < VIRTIO_NET_HDR_SIZE { - return None; - } - Some(&data[VIRTIO_NET_HDR_SIZE..]) - } - - /// Prepend a virtio-net header to frame data. - /// Returns the complete buffer for RX injection. 
- pub fn prepend_virtio_header(frame: &[u8]) -> Vec { - let mut result = vec![0u8; VIRTIO_NET_HDR_SIZE + frame.len()]; - // Header is all zeros (basic operation, no offloading) - result[VIRTIO_NET_HDR_SIZE..].copy_from_slice(frame); - result - } - - /// Process frames from the TX queue. - /// - /// Reads frames, strips virtio header, gets forwarding decisions. - /// Returns frames with their forwarding decisions. - pub fn process_tx_queue(&self, vring: &VringRwLock) -> Vec { - use vm_memory::GuestMemoryLoadGuard; - - let mut results = Vec::new(); - - // Need connection ID and memory to process - if self.connection_id().is_none() { - trace!(vm = %self.name, "process_tx_queue: no connection_id"); - return results; - } - - let mem_guard = self.mem.lock().unwrap(); - let mem: GuestMemoryLoadGuard = match mem_guard.as_ref() { - Some(m) => m.memory(), - None => { - warn!(vm = %self.name, "process_tx_queue: no guest memory set!"); - return results; - } - }; - - // Collect all frames and head indices - let mut frames: Vec<(u16, Vec)> = Vec::new(); - { - let mut vring_state = vring.get_mut(); - let queue = vring_state.get_queue_mut(); - - trace!( - vm = %self.name, - queue_ready = queue.ready(), - queue_size = queue.size(), - next_avail = queue.next_avail(), - next_used = queue.next_used(), - "process_tx_queue: checking queue" - ); - - while let Some(desc_chain) = queue.pop_descriptor_chain(mem.clone()) { - let head_index = desc_chain.head_index(); - let mut raw_data = Vec::new(); - - // Read all descriptors in the chain - for desc in desc_chain { - let addr = desc.addr(); - let len = desc.len() as usize; - - let mut buf = vec![0u8; len]; - if let Err(e) = mem.read_slice(&mut buf, addr) { - warn!( - vm = %self.name, - head_index, - addr = ?addr, - len, - error = %e, - "process_tx_queue: failed to read descriptor" - ); - break; - } - raw_data.extend_from_slice(&buf); - } - - trace!( - vm = %self.name, - head_index, - raw_len = raw_data.len(), - "process_tx_queue: popped 
descriptor" - ); - frames.push((head_index, raw_data)); - } - } - - if frames.is_empty() { - trace!(vm = %self.name, "process_tx_queue: no frames in queue"); - } - - // Process frames and mark descriptors as used - for (head_index, raw_data) in &frames { - if let Err(e) = vring.add_used(*head_index, 0) { - warn!( - vm = %self.name, - head_index, - error = ?e, - "process_tx_queue: add_used failed" - ); - } else { - trace!( - vm = %self.name, - head_index, - "process_tx_queue: add_used ok" - ); - } - - // Strip header and process frame - if let Some(frame_data) = Self::strip_virtio_header(raw_data) { - if let Some(processed) = self.process_frame(frame_data) { - results.push(processed); - } - } - } - - // Re-enable notifications so guest will kick us when it adds more buffers. - // This is critical for EVENT_IDX: after draining the queue, we must tell - // the guest to notify us of new buffers, otherwise it will suppress kicks. - match vring.enable_notification() { - Ok(has_more) => { - if has_more { - trace!(vm = %self.name, "process_tx_queue: enable_notification returned has_more=true"); - } - } - Err(e) => { - warn!(vm = %self.name, error = ?e, "process_tx_queue: enable_notification failed"); - } - } - - // Signal guest that we've processed the queue - if !frames.is_empty() { - // Check if the call eventfd is set - { - let vring_state = vring.get_ref(); - let has_call = vring_state.get_call().is_some(); - if !has_call { - warn!( - vm = %self.name, - num_frames = frames.len(), - "process_tx_queue: no call eventfd set, cannot notify guest!" 
- ); - } else { - trace!( - vm = %self.name, - num_frames = frames.len(), - "process_tx_queue: call eventfd is set" - ); - } - } - - match vring.signal_used_queue() { - Ok(()) => trace!( - vm = %self.name, - num_frames = frames.len(), - "process_tx_queue: signal_used_queue ok" - ), - Err(e) => warn!( - vm = %self.name, - num_frames = frames.len(), - error = %e, - "process_tx_queue: signal_used_queue FAILED" - ), - } - } - - // Log final queue state - { - let vring_state = vring.get_ref(); - let queue = vring_state.get_queue(); - debug!( - vm = %self.name, - frames_processed = results.len(), - next_avail = queue.next_avail(), - next_used = queue.next_used(), - "process_tx_queue complete" - ); - } - - results - } - - /// Inject a frame into the RX queue. - /// - /// Prepends virtio header and writes using scatter-gather across descriptors. - /// Returns true if successful, false if queue is full or insufficient space. - /// - /// Note: This is a static method that takes the destination VM's memory mapping. - /// This is important when injecting frames into a different VM's RX queue - - /// we must use that VM's memory mapping, not our own. 
- pub fn inject_rx_frame(vring: &VringRwLock, mem: &GuestMemoryAtomic, frame: &[u8]) -> bool { - use vm_memory::GuestMemoryLoadGuard; - - let mem: GuestMemoryLoadGuard = mem.memory(); - - let data_to_write = Self::prepend_virtio_header(frame); - let total_len = data_to_write.len(); - - let head_index; - let written; - { - let mut vring_state = vring.get_mut(); - let queue = vring_state.get_queue_mut(); - - trace!( - queue_ready = queue.ready(), - queue_size = queue.size(), - frame_len = frame.len(), - total_len, - "inject_rx_frame: checking queue" - ); - - let desc_chain = match queue.pop_descriptor_chain(mem.clone()) { - Some(chain) => chain, - None => { - trace!("inject_rx_frame: no descriptor available (queue full)"); - return false; - } - }; - - head_index = desc_chain.head_index(); - - // First pass: collect writable descriptors and calculate available space - let mut writable_descs = Vec::new(); - for desc in desc_chain { - if desc.is_write_only() { - writable_descs.push((desc.addr(), desc.len() as usize)); - } - } - - let available_space: usize = writable_descs.iter().map(|(_, len)| *len).sum(); - - trace!( - head_index, - num_writable_descs = writable_descs.len(), - available_space, - total_len, - "inject_rx_frame: descriptor chain" - ); - - if available_space < total_len { - // Insufficient space - don't write partial data - warn!( - head_index, - available_space, - total_len, - "inject_rx_frame: insufficient space in descriptors" - ); - written = 0; - } else { - // Second pass: scatter-gather write across all descriptors - let mut bytes_written = 0; - for (addr, len) in writable_descs { - let remaining = total_len - bytes_written; - if remaining == 0 { - break; - } - let to_write = std::cmp::min(remaining, len); - - if let Err(e) = mem.write_slice(&data_to_write[bytes_written..bytes_written + to_write], addr) { - warn!( - head_index, - addr = ?addr, - to_write, - error = %e, - "inject_rx_frame: write_slice failed" - ); - break; - } - bytes_written += 
to_write; - } - written = bytes_written; - } - } - - if let Err(e) = vring.add_used(head_index, written as u32) { - warn!( - head_index, - written, - error = ?e, - "inject_rx_frame: add_used failed" - ); - } - - // Re-enable notifications so guest knows to provide more RX buffers - if let Err(e) = vring.enable_notification() { - warn!( - head_index, - error = ?e, - "inject_rx_frame: enable_notification failed" - ); - } - - if let Err(e) = vring.signal_used_queue() { - warn!( - head_index, - error = %e, - "inject_rx_frame: signal_used_queue failed" - ); - } - - let success = written >= total_len; - trace!( - head_index, - written, - total_len, - success, - "inject_rx_frame complete" - ); - success - } -} - -impl VhostUserBackend for NetBackend { - type Bitmap = (); - type Vring = VringRwLock; - - fn num_queues(&self) -> usize { - NUM_QUEUES - } - - fn max_queue_size(&self) -> usize { - MAX_QUEUE_SIZE - } - - fn features(&self) -> u64 { - let features = VIRTIO_NET_FEATURES - | (1 << VIRTIO_F_VERSION_1) - | (1 << VIRTIO_RING_F_EVENT_IDX) - | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); - trace!(vm = %self.name, features = format!("{:#x}", features), "features requested"); - features - } - - fn protocol_features(&self) -> VhostUserProtocolFeatures { - let proto = VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::MQ; - trace!(vm = %self.name, protocol_features = ?proto, "protocol_features requested"); - proto - } - - fn set_event_idx(&self, enabled: bool) { - debug!(vm = %self.name, enabled, "set_event_idx"); - *self.event_idx.lock().unwrap() = enabled; - } - - fn update_memory( - &self, - mem: GuestMemoryAtomic, - ) -> std::io::Result<()> { - debug!(vm = %self.name, "update_memory called"); - *self.mem.lock().unwrap() = Some(mem); - Ok(()) - } - - fn handle_event( - &self, - device_event: u16, - _evset: EventSet, - vrings: &[VringRwLock], - _thread_id: usize, - ) -> std::io::Result<()> { - // Note: read_kick() is already called by VringEpollHandler 
before invoking this method. - // We do not call it again to avoid blocking on a drained eventfd. - - trace!( - vm = %self.name, - device_event, - queue_name = if device_event == RX_QUEUE { "RX" } else if device_event == TX_QUEUE { "TX" } else { "?" }, - "handle_event called" - ); - - // Validate event index - if (device_event as usize) >= vrings.len() { - debug!(device_event, vrings_len = vrings.len(), "ignoring out-of-range device event"); - return Ok(()); - } - - // Get our connection ID - let conn_id = match self.connection_id() { - Some(id) => id, - None => { - debug!(vm = %self.name, "handle_event: no connection_id, ignoring"); - return Ok(()); - } - }; - - // Register our RX vring in the shared registry (if not already done) - // This allows other backends to inject frames into our RX queue - if vrings.len() > RX_QUEUE as usize { - let mut registry = self.vring_registry.write().unwrap(); - if !registry.contains_key(&conn_id) { - // Clone the memory for registry storage - let mem_guard = self.mem.lock().unwrap(); - if let Some(mem) = mem_guard.clone() { - debug!( - vm = %self.name, - conn_id = ?conn_id, - "registering RX vring in registry" - ); - registry.insert(conn_id, (vrings[RX_QUEUE as usize].clone(), mem)); - } else { - warn!( - vm = %self.name, - conn_id = ?conn_id, - "cannot register RX vring: no guest memory set" - ); - } - } - } - - // Only process TX queue kicks - if device_event != TX_QUEUE { - trace!(vm = %self.name, device_event, "ignoring non-TX queue event"); - return Ok(()); - } - - // Process frames from our TX queue - let tx_vring = &vrings[TX_QUEUE as usize]; - let processed_frames = self.process_tx_queue(tx_vring); - - if processed_frames.is_empty() { - trace!(vm = %self.name, "handle_event: no frames processed from TX queue"); - return Ok(()); - } - - // Route each frame to its destination(s) - let registry = self.vring_registry.read().unwrap(); - let mut routed = 0; - let mut dropped = 0; - - trace!( - vm = %self.name, - num_frames = 
processed_frames.len(), - registry_size = registry.len(), - "routing frames" - ); - - for processed in processed_frames { - match &processed.decision { - ForwardDecision::Unicast(dest_id) => { - if let Some((rx_vring, dest_mem)) = registry.get(dest_id) { - if Self::inject_rx_frame(rx_vring, dest_mem, &processed.data) { - routed += 1; - } else { - dropped += 1; - debug!( - src = %self.name, - dest_id = ?dest_id, - "RX queue full, dropping frame" - ); - } - } else { - dropped += 1; - debug!( - src = %self.name, - dest_id = ?dest_id, - "destination not in registry, dropping frame" - ); - } - } - ForwardDecision::Multicast(dest_ids) => { - for dest_id in dest_ids { - if let Some((rx_vring, dest_mem)) = registry.get(dest_id) { - if Self::inject_rx_frame(rx_vring, dest_mem, &processed.data) { - routed += 1; - } else { - dropped += 1; - debug!( - src = %self.name, - dest_id = ?dest_id, - "RX queue full, dropping frame" - ); - } - } else { - dropped += 1; - debug!( - src = %self.name, - dest_id = ?dest_id, - "destination not in registry (multicast), dropping frame" - ); - } - } - } - ForwardDecision::Drop(reason) => { - dropped += 1; - debug!(src = %self.name, ?reason, "Dropping frame per switch decision"); - } - } - } - - trace!(vm = %self.name, routed, dropped, "handle_event complete"); - - Ok(()) - } - - fn acked_features(&self, features: u64) { - debug!(vm = %self.name, features = format!("{:#x}", features), "acked_features"); - *self.acked_features.lock().unwrap() = features; - } - - fn get_config(&self, offset: u32, size: u32) -> Vec { - // Virtio net config: MAC (6 bytes) + status (2) + max_virtqueue_pairs (2) - let mut config = [0u8; 10]; - config[0..6].copy_from_slice(&self.mac.bytes()); - config[6] = 1; // VIRTIO_NET_S_LINK_UP - config[8] = 1; // max_virtqueue_pairs = 1 - let config = config; - - let offset = offset as usize; - let size = size as usize; - if offset < config.len() { - let end = std::cmp::min(offset + size, config.len()); - let mut result = 
config[offset..end].to_vec(); - // Pad with zeros if requested more than available - result.resize(size, 0); - result - } else { - vec![0u8; size] - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::switch::Switch; - - fn make_switch() -> Arc> { - Arc::new(RwLock::new(Switch::new())) - } - - fn make_vring_registry() -> VringRegistry { - Arc::new(RwLock::new(HashMap::new())) - } - - fn make_backend() -> NetBackend { - NetBackend::new( - "test".to_string(), - VmRole::Client, - Mac::from_bytes([1, 2, 3, 4, 5, 6]), - make_switch(), - make_vring_registry(), - ) - } - - #[test] - fn num_queues_returns_two() { - let backend = make_backend(); - assert_eq!(backend.num_queues(), 2); - } - - #[test] - fn max_queue_size_returns_max() { - let backend = make_backend(); - assert_eq!(backend.max_queue_size(), MAX_QUEUE_SIZE); - } - - #[test] - fn register_assigns_connection_id() { - let backend = make_backend(); - - assert!(backend.connection_id().is_none()); - let id = backend.register(); - assert!(id.is_some()); - assert_eq!(backend.connection_id(), id); - } - - #[test] - fn unregister_clears_connection_id() { - let backend = make_backend(); - - backend.register(); - assert!(backend.connection_id().is_some()); - - backend.unregister(); - assert!(backend.connection_id().is_none()); - } - - #[test] - fn duplicate_router_returns_none() { - let switch = make_switch(); - let registry = make_vring_registry(); - - let router1 = NetBackend::new( - "router1".to_string(), - VmRole::Router, - Mac::from_bytes([1, 0, 0, 0, 0, 1]), - Arc::clone(&switch), - Arc::clone(®istry), - ); - let router2 = NetBackend::new( - "router2".to_string(), - VmRole::Router, - Mac::from_bytes([2, 0, 0, 0, 0, 2]), - Arc::clone(&switch), - Arc::clone(®istry), - ); - - assert!(router1.register().is_some()); - assert!(router2.register().is_none()); - } - - #[test] - fn get_config_returns_mac_at_offset_zero() { - let switch = make_switch(); - let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 
0xff]); - let backend = NetBackend::new("test".to_string(), VmRole::Client, mac, switch, make_vring_registry()); - - let config = backend.get_config(0, 6); - assert_eq!(config, vec![0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]); - } - - #[test] - fn get_config_partial_read() { - let switch = make_switch(); - let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]); - let backend = NetBackend::new("test".to_string(), VmRole::Client, mac, switch, make_vring_registry()); - - // Read bytes 2-4 of MAC - let config = backend.get_config(2, 3); - assert_eq!(config, vec![0xcc, 0xdd, 0xee]); - } - - fn make_frame(dest: [u8; 6], src: [u8; 6]) -> Vec { - let mut frame = vec![0u8; 14]; - frame[0..6].copy_from_slice(&dest); - frame[6..12].copy_from_slice(&src); - frame - } - - #[test] - fn process_frame_returns_none_when_unregistered() { - let backend = make_backend(); - // Don't register - let frame = make_frame([0xff; 6], [1, 2, 3, 4, 5, 6]); - assert!(backend.process_frame(&frame).is_none()); - } - - #[test] - fn process_frame_returns_decision_when_registered() { - let switch = make_switch(); - let registry = make_vring_registry(); - - // Register router first (clients need a router to forward to) - let router = NetBackend::new( - "router".to_string(), - VmRole::Router, - Mac::from_bytes([0xaa, 0, 0, 0, 0, 1]), - Arc::clone(&switch), - Arc::clone(®istry), - ); - router.register(); - - // Register client - let client = NetBackend::new( - "client".to_string(), - VmRole::Client, - Mac::from_bytes([1, 2, 3, 4, 5, 6]), - Arc::clone(&switch), - Arc::clone(®istry), - ); - client.register(); - - // Client sends to router - let frame = make_frame([0xaa, 0, 0, 0, 0, 1], [1, 2, 3, 4, 5, 6]); - let result = client.process_frame(&frame); - - assert!(result.is_some()); - let processed = result.unwrap(); - assert_eq!(processed.data, frame); - assert!(matches!(processed.decision, ForwardDecision::Unicast(_))); - } - - #[test] - fn process_frame_returns_none_for_invalid_frame() { - let switch = 
make_switch(); - let backend = NetBackend::new( - "test".to_string(), - VmRole::Client, - Mac::from_bytes([1, 2, 3, 4, 5, 6]), - switch, - make_vring_registry(), - ); - backend.register(); - - // Frame too small (< 14 bytes) - let frame = vec![0u8; 10]; - assert!(backend.process_frame(&frame).is_none()); - } - - #[test] - fn strip_virtio_header_returns_frame() { - // 12-byte header + 14-byte Ethernet frame - let mut data = vec![0u8; 26]; - data[12..18].copy_from_slice(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]); // dest MAC - - let frame = NetBackend::strip_virtio_header(&data).unwrap(); - assert_eq!(frame.len(), 14); - assert_eq!(&frame[0..6], &[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]); - } - - #[test] - fn strip_virtio_header_returns_none_if_too_small() { - // Less than 12 bytes - let data = vec![0u8; 10]; - assert!(NetBackend::strip_virtio_header(&data).is_none()); - } - - #[test] - fn strip_virtio_header_returns_empty_slice_if_exact() { - // Exactly 12 bytes (header only, no frame) - let data = vec![0u8; 12]; - let frame = NetBackend::strip_virtio_header(&data).unwrap(); - assert!(frame.is_empty()); - } - - #[test] - fn prepend_virtio_header_adds_12_bytes() { - let frame = vec![0xaa, 0xbb, 0xcc, 0xdd]; - let result = NetBackend::prepend_virtio_header(&frame); - - assert_eq!(result.len(), 12 + 4); - assert_eq!(&result[0..12], &[0u8; 12]); // Header is zeros - assert_eq!(&result[12..], &[0xaa, 0xbb, 0xcc, 0xdd]); - } - - #[test] - fn prepend_virtio_header_empty_frame() { - let frame: Vec = vec![]; - let result = NetBackend::prepend_virtio_header(&frame); - - assert_eq!(result.len(), 12); - assert_eq!(&result[..], &[0u8; 12]); - } -} diff --git a/vm-switch/src/child/forwarder.rs b/vm-switch/src/child/forwarder.rs new file mode 100644 index 0000000..028d9f5 --- /dev/null +++ b/vm-switch/src/child/forwarder.rs @@ -0,0 +1,370 @@ +//! Packet forwarding logic for child processes. 
+ +use std::collections::HashMap; +use std::os::fd::{AsRawFd, RawFd}; + +use tracing::{debug, info, trace}; + +use crate::frame::validate_source_mac; +use crate::mac::Mac; +use crate::ring::{Consumer, Producer}; + +/// Ingress buffer from a peer (they produce, we consume). +struct IngressBuffer { + peer_mac: [u8; 6], + consumer: Consumer, +} + +/// Egress buffer to a peer (we produce, they consume). +struct EgressBuffer { + peer_mac: [u8; 6], + producer: Producer, + /// If true, this buffer accepts broadcast/multicast traffic. + broadcast: bool, +} + +/// Manages packet forwarding for a child process. +pub struct PacketForwarder { + /// This child's MAC address. + our_mac: Mac, + /// Ingress buffers FROM peers (they produce, we consume). Keyed by peer name. + ingress: HashMap, + /// Egress buffers TO peers (we produce, they consume). Keyed by peer name. + egress: HashMap, +} + +impl PacketForwarder { + /// Create a new packet forwarder. + pub fn new(our_mac: Mac) -> Self { + Self { + our_mac, + ingress: HashMap::new(), + egress: HashMap::new(), + } + } + + /// Add an ingress buffer from a peer (we consume packets they produce). + pub fn add_ingress(&mut self, peer_name: String, peer_mac: [u8; 6], consumer: Consumer) { + info!(peer = %peer_name, "added ingress buffer from peer"); + self.ingress.insert( + peer_name, + IngressBuffer { peer_mac, consumer }, + ); + } + + /// Add an egress buffer to a peer (we produce packets they consume). + pub fn add_egress( + &mut self, + peer_name: String, + peer_mac: [u8; 6], + producer: Producer, + broadcast: bool, + ) { + info!(peer = %peer_name, broadcast, "added egress buffer to peer"); + self.egress.insert( + peer_name, + EgressBuffer { + peer_mac, + producer, + broadcast, + }, + ); + } + + /// Remove all buffers for a peer. 
+ pub fn remove_peer(&mut self, peer_name: &str) { + if self.ingress.remove(peer_name).is_some() { + info!(peer = %peer_name, "removed ingress buffer"); + } + if self.egress.remove(peer_name).is_some() { + info!(peer = %peer_name, "removed egress buffer"); + } + } + + /// Get eventfds for all ingress consumers (for polling). + pub fn ingress_eventfds(&self) -> Vec<(RawFd, [u8; 6])> { + self.ingress + .values() + .map(|buf| (buf.consumer.eventfd().as_raw_fd(), buf.peer_mac)) + .collect() + } + + /// Forward a TX frame to peers based on destination MAC. + /// + /// - Broadcast/multicast: sent to all egress buffers with broadcast=true + /// - Unicast: sent to the egress buffer matching the destination MAC + pub fn forward_tx(&self, frame: &[u8]) -> bool { + // Validate source MAC + if !validate_source_mac(frame, self.our_mac) { + debug!( + reason = "source MAC rejected", + our_mac = %self.our_mac, + frame_src = %Mac::from_bytes(frame[6..12].try_into().unwrap()), + size = frame.len(), + "TX: dropped" + ); + return false; + } + + let dest_mac: [u8; 6] = frame[0..6].try_into().unwrap(); + let dest = Mac::from_bytes(dest_mac); + + // Broadcast/multicast: send to all egress buffers with broadcast=true + if dest.is_broadcast() || dest.is_multicast() { + let mut sent = false; + for (peer_name, egress) in &self.egress { + if egress.broadcast { + if egress.producer.push(frame) { + trace!( + to = %peer_name, + mac = %Mac::from_bytes(egress.peer_mac), + size = frame.len(), + "TX: broadcast" + ); + sent = true; + } else { + debug!(reason = "buffer full", to = %peer_name, size = frame.len(), "TX: broadcast dropped"); + } + } + } + if !sent { + debug!(reason = "no broadcast peers", size = frame.len(), "TX: dropped"); + } + return sent; + } + + // Unicast: find egress buffer by destination MAC + for (peer_name, egress) in &self.egress { + if egress.peer_mac == dest_mac { + if egress.producer.push(frame) { + trace!( + to = %peer_name, + mac = %Mac::from_bytes(egress.peer_mac), + size 
= frame.len(), + "TX: pushed to egress buffer" + ); + return true; + } else { + debug!(reason = "buffer full", to = %peer_name, size = frame.len(), "TX: dropped"); + return false; + } + } + } + + // Unknown destination + debug!( + reason = "unknown destination", + dest_mac = %dest, + size = frame.len(), + "TX: dropped" + ); + false + } + + /// Poll all ingress buffers and return received frames. + /// + /// Validates source MAC matches the expected peer for each buffer. + pub fn poll_ingress(&self) -> Vec> { + let mut frames = Vec::new(); + + for (peer_name, ingress) in &self.ingress { + // Drain the eventfd + ingress.consumer.drain_eventfd(); + + // Pop all available frames + while let Some(frame) = ingress.consumer.pop() { + // Validate source MAC matches expected peer + if !validate_source_mac(&frame, Mac::from_bytes(ingress.peer_mac)) { + debug!( + reason = "source MAC mismatch", + from = %peer_name, + expected = %Mac::from_bytes(ingress.peer_mac), + actual = %Mac::from_bytes(frame[6..12].try_into().unwrap_or([0; 6])), + size = frame.len(), + "RX: dropped" + ); + continue; + } + trace!( + from = %peer_name, + mac = %Mac::from_bytes(ingress.peer_mac), + size = frame.len(), + "RX: read from ingress buffer" + ); + frames.push(frame); + } + } + + frames + } + + /// Get number of configured ingress peers. + pub fn ingress_count(&self) -> usize { + self.ingress.len() + } + + /// Get number of configured egress peers. + pub fn egress_count(&self) -> usize { + self.egress.len() + } + + /// Get our MAC address. 
+ pub fn our_mac(&self) -> Mac { + self.our_mac + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::os::fd::{FromRawFd, OwnedFd}; + + fn make_frame(dest: [u8; 6], src: [u8; 6]) -> Vec { + let mut frame = vec![0u8; 14]; + frame[0..6].copy_from_slice(&dest); + frame[6..12].copy_from_slice(&src); + frame + } + + #[test] + fn forward_tx_validates_source_mac() { + let our_mac = Mac::from_bytes([1, 2, 3, 4, 5, 6]); + let forwarder = PacketForwarder::new(our_mac); + + // Frame with wrong source MAC - should be dropped + let frame = make_frame([0xff; 6], [9, 9, 9, 9, 9, 9]); + assert!(!forwarder.forward_tx(&frame)); + } + + #[test] + fn forward_tx_drops_when_no_egress() { + let our_mac = Mac::from_bytes([1, 2, 3, 4, 5, 6]); + let forwarder = PacketForwarder::new(our_mac); + + // Frame with correct source MAC but no egress peers + let frame = make_frame([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff], our_mac.bytes()); + assert!(!forwarder.forward_tx(&frame)); + } + + #[test] + fn forward_tx_unicast_to_matching_peer() { + let our_mac = Mac::from_bytes([1, 2, 3, 4, 5, 6]); + let mut forwarder = PacketForwarder::new(our_mac); + + // Add egress to peer with specific MAC + let peer_mac = [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]; + let consumer = Consumer::new().expect("consumer"); + let producer = Producer::from_fds( + unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer.memfd().as_raw_fd())) }, + unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer.eventfd().as_raw_fd())) }, + ) + .expect("producer"); + forwarder.add_egress("router".to_string(), peer_mac, producer, false); + + // Unicast frame to that MAC should succeed + let frame = make_frame(peer_mac, our_mac.bytes()); + assert!(forwarder.forward_tx(&frame)); + + // Consumer should receive the frame + consumer.drain_eventfd(); + let received = consumer.pop().expect("should have frame"); + assert_eq!(received, frame); + } + + #[test] + fn forward_tx_broadcast_to_broadcast_peers() { + let our_mac = Mac::from_bytes([1, 2, 3, 4, 
5, 6]); + let mut forwarder = PacketForwarder::new(our_mac); + + // Add broadcast-enabled egress + let consumer1 = Consumer::new().expect("consumer1"); + let producer1 = Producer::from_fds( + unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer1.memfd().as_raw_fd())) }, + unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer1.eventfd().as_raw_fd())) }, + ) + .expect("producer1"); + forwarder.add_egress("router".to_string(), [0x11; 6], producer1, true); + + // Add non-broadcast egress + let consumer2 = Consumer::new().expect("consumer2"); + let producer2 = Producer::from_fds( + unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer2.memfd().as_raw_fd())) }, + unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer2.eventfd().as_raw_fd())) }, + ) + .expect("producer2"); + forwarder.add_egress("client_a".to_string(), [0x22; 6], producer2, false); + + // Broadcast frame + let frame = make_frame([0xff; 6], our_mac.bytes()); + assert!(forwarder.forward_tx(&frame)); + + // Only broadcast-enabled peer should receive + consumer1.drain_eventfd(); + assert!(consumer1.pop().is_some()); + + consumer2.drain_eventfd(); + assert!(consumer2.pop().is_none()); + } + + #[test] + fn poll_ingress_validates_source_mac() { + let our_mac = Mac::from_bytes([1, 2, 3, 4, 5, 6]); + let mut forwarder = PacketForwarder::new(our_mac); + + // Create producer/consumer pair - producer simulates peer sending to us + let producer = Producer::new().expect("producer"); + let consumer = Consumer::from_fds( + unsafe { OwnedFd::from_raw_fd(nix::libc::dup(producer.memfd().as_raw_fd())) }, + unsafe { OwnedFd::from_raw_fd(nix::libc::dup(producer.eventfd().as_raw_fd())) }, + ) + .expect("consumer"); + + let peer_mac = [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]; + forwarder.add_ingress("router".to_string(), peer_mac, consumer); + + // Push frame with correct source MAC + let good_frame = make_frame(our_mac.bytes(), peer_mac); + producer.push(&good_frame); + + // Push frame with wrong source MAC + let bad_frame = 
make_frame(our_mac.bytes(), [0x99; 6]); + producer.push(&bad_frame); + + // Only good frame should be returned + let frames = forwarder.poll_ingress(); + assert_eq!(frames.len(), 1); + assert_eq!(frames[0], good_frame); + } + + #[test] + fn remove_peer_cleans_up_buffers() { + let our_mac = Mac::from_bytes([1, 2, 3, 4, 5, 6]); + let mut forwarder = PacketForwarder::new(our_mac); + + // Add ingress + let producer = Producer::new().expect("producer"); + let consumer = Consumer::from_fds( + unsafe { OwnedFd::from_raw_fd(nix::libc::dup(producer.memfd().as_raw_fd())) }, + unsafe { OwnedFd::from_raw_fd(nix::libc::dup(producer.eventfd().as_raw_fd())) }, + ) + .expect("consumer"); + forwarder.add_ingress("router".to_string(), [0x11; 6], consumer); + + // Add egress + let consumer2 = Consumer::new().expect("consumer2"); + let producer2 = Producer::from_fds( + unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer2.memfd().as_raw_fd())) }, + unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer2.eventfd().as_raw_fd())) }, + ) + .expect("producer2"); + forwarder.add_egress("router".to_string(), [0x11; 6], producer2, true); + + assert_eq!(forwarder.ingress_count(), 1); + assert_eq!(forwarder.egress_count(), 1); + + forwarder.remove_peer("router"); + + assert_eq!(forwarder.ingress_count(), 0); + assert_eq!(forwarder.egress_count(), 0); + } +} diff --git a/vm-switch/src/child/mod.rs b/vm-switch/src/child/mod.rs new file mode 100644 index 0000000..7d7b7e3 --- /dev/null +++ b/vm-switch/src/child/mod.rs @@ -0,0 +1,14 @@ +//! Child process entry point for VM backends. +//! +//! Each VM runs in its own forked child process, communicating with +//! the main process via a control channel. 
+ +pub mod forwarder; +pub mod poll; +pub mod process; +pub mod vhost; + +pub use forwarder::PacketForwarder; +pub use poll::{poll_events, PollResult}; +pub use process::run_child_process; +pub use vhost::ChildVhostBackend; diff --git a/vm-switch/src/child/poll.rs b/vm-switch/src/child/poll.rs new file mode 100644 index 0000000..c60c230 --- /dev/null +++ b/vm-switch/src/child/poll.rs @@ -0,0 +1,115 @@ +//! Event polling for child processes. + +use std::os::fd::{BorrowedFd, RawFd}; + +use nix::poll::{poll, PollFd, PollFlags, PollTimeout}; +use nix::Error as NixError; + +/// Result of polling for events. +#[derive(Debug)] +pub enum PollResult { + /// Control channel has data. + Control, + /// One or more ingress buffers have data. + Ingress(Vec<[u8; 6]>), + /// Second FD slot has an event (POLLIN/POLLHUP/POLLERR). + /// Used for daemon exit pipe detection. + TxKick, + /// Timeout expired with no events. + Timeout, + /// Error occurred. + Error(NixError), +} + +/// Poll for events on control channel, TX kick, and ingress eventfds. +/// +/// # Safety +/// All raw file descriptors passed in must be valid and open for the duration of this call. 
+pub fn poll_events( + control_fd: RawFd, + tx_kick_fd: Option, + ingress_fds: &[(RawFd, [u8; 6])], + timeout_ms: i32, +) -> PollResult { + // SAFETY: caller guarantees fds are valid for the duration of this call + let control_borrowed = unsafe { BorrowedFd::borrow_raw(control_fd) }; + + let mut pollfds: Vec = Vec::with_capacity(2 + ingress_fds.len()); + + // Control channel is first + pollfds.push(PollFd::new(control_borrowed, PollFlags::POLLIN)); + + // TX kick eventfd is second (if present) + let tx_index = if let Some(fd) = tx_kick_fd { + let tx_borrowed = unsafe { BorrowedFd::borrow_raw(fd) }; + pollfds.push(PollFd::new(tx_borrowed, PollFlags::POLLIN)); + Some(1) + } else { + None + }; + + // Ingress eventfds follow + let ingress_start = pollfds.len(); + for (fd, _mac) in ingress_fds { + let ingress_borrowed = unsafe { BorrowedFd::borrow_raw(*fd) }; + pollfds.push(PollFd::new(ingress_borrowed, PollFlags::POLLIN)); + } + + let timeout = if timeout_ms < 0 { + PollTimeout::NONE + } else if timeout_ms > u16::MAX as i32 { + PollTimeout::MAX + } else { + PollTimeout::from(timeout_ms as u16) + }; + + match poll(&mut pollfds, timeout) { + Ok(0) => PollResult::Timeout, + Ok(_) => { + // Check control channel first (priority) + if let Some(revents) = pollfds[0].revents() { + if revents.contains(PollFlags::POLLIN) + || revents.contains(PollFlags::POLLHUP) + || revents.contains(PollFlags::POLLERR) + { + return PollResult::Control; + } + } + + // Check TX kick / daemon exit pipe + if let Some(idx) = tx_index { + if let Some(revents) = pollfds[idx].revents() { + if revents.contains(PollFlags::POLLIN) + || revents.contains(PollFlags::POLLHUP) + || revents.contains(PollFlags::POLLERR) + { + return PollResult::TxKick; + } + } + } + + // Check ingress eventfds + let mut ready_macs = Vec::new(); + for (i, (_, mac)) in ingress_fds.iter().enumerate() { + if let Some(revents) = pollfds[ingress_start + i].revents() { + if revents.contains(PollFlags::POLLIN) { + 
ready_macs.push(*mac); + } + } + } + + if ready_macs.is_empty() { + PollResult::Timeout + } else { + PollResult::Ingress(ready_macs) + } + } + Err(e) => { + if e == NixError::EINTR { + PollResult::Timeout + } else { + PollResult::Error(e) + } + } + } +} diff --git a/vm-switch/src/child/process.rs b/vm-switch/src/child/process.rs new file mode 100644 index 0000000..35d46ce --- /dev/null +++ b/vm-switch/src/child/process.rs @@ -0,0 +1,239 @@ +//! Child process main loop. + +use std::os::fd::{AsRawFd, OwnedFd, RawFd}; +use std::os::unix::net::UnixListener; +use std::path::Path; +use std::sync::{Arc, Mutex}; +use std::thread; + +use nix::unistd::pipe; +use tracing::{debug, error, info, warn}; +use vhost_user_backend::VhostUserDaemon; +use vm_memory::{GuestMemoryAtomic, GuestMemoryMmap}; + +use crate::control::{ChildToMain, ControlChannel, ControlError, MainToChild}; +use crate::mac::Mac; +use crate::ring::{Consumer, Producer}; +use crate::seccomp::{apply_child_seccomp, SeccompMode}; + +use super::forwarder::PacketForwarder; +use super::poll::{poll_events, PollResult}; +use super::vhost::ChildVhostBackend; + +/// Run the child process. +/// +/// This is the entry point after fork(). Does not return. +pub fn run_child_process( + vm_name: &str, + mac: Mac, + control_fd: OwnedFd, + socket_path: &Path, + seccomp_mode: SeccompMode, +) -> ! 
{ + // Set process name for log prefix before any logging + crate::args::set_process_name(format!("worker-{}", vm_name)); + + info!(vm = %vm_name, mac = %mac, socket = ?socket_path, "child starting"); + + // Reconstruct control channel from owned fd + let control = ControlChannel::from_fd(control_fd); + + // Send Ready to main + let msg = ChildToMain::Ready; + if let Err(e) = control.send(&msg) { + error!(vm = %vm_name, error = %e, "failed to send Ready"); + std::process::exit(1) + } + debug!("control: worker-{} -> main Ready", vm_name); + + // Create packet forwarder + let forwarder = Arc::new(Mutex::new(PacketForwarder::new(mac))); + + // Create vhost backend + let backend = ChildVhostBackend::new(vm_name.to_string(), mac); + + // Set TX callback + let fwd = Arc::clone(&forwarder); + backend.set_tx_callback(Box::new(move |frame| { + fwd.lock().unwrap().forward_tx(frame); + })); + + // Create vhost socket + if socket_path.exists() { + let _ = std::fs::remove_file(socket_path); + } + let listener = match UnixListener::bind(socket_path) { + Ok(l) => l, + Err(e) => { + error!(vm = %vm_name, error = %e, "failed to bind socket"); + std::process::exit(1) + } + }; + let _ = listener.set_nonblocking(true); + + // Start vhost daemon thread + let mem = GuestMemoryAtomic::new(GuestMemoryMmap::<()>::new()); + let mut daemon = match VhostUserDaemon::new(vm_name.to_string(), backend.clone(), mem) { + Ok(d) => d, + Err(e) => { + error!(vm = %vm_name, error = %e, "failed to create daemon"); + std::process::exit(1) + } + }; + + // Create pipe to detect daemon thread exit. The write end is moved into + // the daemon thread; when the thread exits for any reason, the write end + // is dropped, causing POLLHUP on the read end. 
+    let (pipe_rd, pipe_wr) = match pipe() {
+        Ok((rd, wr)) => (rd, wr),
+        Err(e) => {
+            error!(vm = %vm_name, error = %e, "failed to create pipe");
+            std::process::exit(1)
+        }
+    };
+
+    let vhost_listener = vhost::vhost_user::Listener::from(listener);
+    let name = vm_name.to_string();
+    thread::spawn(move || {
+        let _pipe_wr = pipe_wr; // dropped on thread exit → POLLHUP on read end
+        let mut l = vhost_listener;
+        if let Err(e) = daemon.start(&mut l) {
+            warn!(vm = %name, error = %e, "daemon start failed");
+            return;
+        }
+        if let Err(e) = daemon.wait() {
+            debug!(vm = %name, error = %e, "daemon wait returned error");
+        }
+    });
+
+    // Apply seccomp filter now that setup is complete
+    // (socket created, thread spawned, signals configured)
+    if let Err(e) = apply_child_seccomp(seccomp_mode) {
+        error!(vm = %vm_name, error = %e, "failed to apply seccomp");
+        std::process::exit(1);
+    }
+    if seccomp_mode != SeccompMode::Disabled {
+        debug!(vm = %vm_name, mode = ?seccomp_mode, "seccomp filter applied");
+    }
+
+    // Main event loop
+    let daemon_exit_fd = pipe_rd.as_raw_fd();
+    match event_loop(vm_name, control, forwarder, backend, daemon_exit_fd) {
+        Ok(()) => {
+            info!(vm = %vm_name, "exiting normally");
+            std::process::exit(0)
+        }
+        Err(e) => {
+            error!(vm = %vm_name, error = %e, "exiting with error");
+            std::process::exit(1)
+        }
+    }
+}
+
+fn event_loop(
+    vm_name: &str,
+    control: ControlChannel,
+    forwarder: Arc<Mutex<PacketForwarder>>,
+    backend: Arc<ChildVhostBackend>,
+    daemon_exit_fd: RawFd,
+) -> Result<(), ControlError> {
+    let control_fd = control.as_raw_fd();
+
+    loop {
+        let ingress_fds = forwarder.lock().unwrap().ingress_eventfds();
+
+        match poll_events(control_fd, Some(daemon_exit_fd), &ingress_fds, 100) {
+            PollResult::TxKick => {
+                // Daemon thread exited (pipe write end closed → POLLHUP)
+                info!(vm = %vm_name, "vhost daemon exited, shutting down");
+                return Ok(());
+            }
+            PollResult::Control => {
+                let (msg, fds) = match control.recv_with_fds_typed() {
+                    Ok(r) => r,
+                    Err(ControlError::Closed) => {
+ debug!(vm = %vm_name, "control closed"); + return Ok(()); + } + Err(e) => return Err(e), + }; + + match msg { + MainToChild::GetBuffer { peer_name, peer_mac } => { + debug!( + "control: main -> worker-{} GetBuffer({}, {})", + vm_name, peer_name, Mac::from_bytes(peer_mac) + ); + + // Create ingress buffer (we are Consumer) + match Consumer::new() { + Ok(consumer) => { + let response_fds = [ + consumer.memfd().as_raw_fd(), + consumer.eventfd().as_raw_fd(), + ]; + let response = ChildToMain::BufferReady { + peer_name: peer_name.clone(), + }; + if let Err(e) = control.send_with_fds_typed(&response, &response_fds) { + warn!(vm = %vm_name, error = %e, "failed to send BufferReady"); + } else { + debug!( + "control: worker-{} -> main BufferReady({})", + vm_name, peer_name + ); + forwarder.lock().unwrap().add_ingress(peer_name, peer_mac, consumer); + } + } + Err(e) => warn!(vm = %vm_name, error = %e, "failed to create ingress buffer"), + } + } + + MainToChild::PutBuffer { peer_name, peer_mac, broadcast } => { + debug!( + "control: main -> worker-{} PutBuffer({}, {}, broadcast={})", + vm_name, peer_name, Mac::from_bytes(peer_mac), broadcast + ); + + if fds.len() == 2 { + let mut fds = fds.into_iter(); + match Producer::from_fds(fds.next().unwrap(), fds.next().unwrap()) { + Ok(producer) => { + forwarder.lock().unwrap().add_egress( + peer_name, + peer_mac, + producer, + broadcast, + ); + } + Err(e) => warn!(vm = %vm_name, error = %e, "failed to map egress buffer"), + } + } else { + warn!(vm = %vm_name, "PutBuffer with wrong number of FDs: {}", fds.len()); + } + } + + MainToChild::RemovePeer { peer_name } => { + debug!("control: main -> worker-{} RemovePeer({})", vm_name, peer_name); + forwarder.lock().unwrap().remove_peer(&peer_name); + } + + MainToChild::Ping => { + control.send(&ChildToMain::Pong)?; + } + } + } + PollResult::Ingress(_) | PollResult::Timeout => { + let frames = forwarder.lock().unwrap().poll_ingress(); + for frame in frames { + if 
!backend.inject_rx_frame(&frame) { + debug!(vm = %vm_name, "RX inject failed (queue full)"); + } + } + } + PollResult::Error(e) => { + warn!(vm = %vm_name, error = ?e, "poll error"); + } + } + } +} diff --git a/vm-switch/src/child/vhost.rs b/vm-switch/src/child/vhost.rs new file mode 100644 index 0000000..86214ca --- /dev/null +++ b/vm-switch/src/child/vhost.rs @@ -0,0 +1,283 @@ +//! Vhost-user backend for child processes. + +use std::sync::{Arc, Mutex}; + +use tracing::{debug, warn}; +use vhost::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures}; +use vhost_user_backend::{VhostUserBackend, VringRwLock, VringT}; +use virtio_bindings::virtio_config::VIRTIO_F_VERSION_1; +use virtio_bindings::virtio_net::{ + VIRTIO_NET_F_CSUM, VIRTIO_NET_F_GUEST_CSUM, VIRTIO_NET_F_GUEST_TSO4, + VIRTIO_NET_F_GUEST_TSO6, VIRTIO_NET_F_GUEST_UFO, VIRTIO_NET_F_HOST_TSO4, + VIRTIO_NET_F_HOST_TSO6, VIRTIO_NET_F_HOST_UFO, VIRTIO_NET_F_MAC, VIRTIO_NET_F_STATUS, +}; +use virtio_bindings::virtio_ring::VIRTIO_RING_F_EVENT_IDX; +use virtio_queue::QueueT; +use vm_memory::{Bytes, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryMmap}; +use vmm_sys_util::epoll::EventSet; + +use crate::mac::Mac; + +/// RX queue index. +pub const RX_QUEUE: u16 = 0; +/// TX queue index. +pub const TX_QUEUE: u16 = 1; +/// Number of queues. +pub const NUM_QUEUES: usize = 2; +/// Maximum queue size. +pub const MAX_QUEUE_SIZE: usize = 32768; +/// Virtio-net header size. +pub const VIRTIO_NET_HDR_SIZE: usize = 12; + +/// Virtio net features. +const VIRTIO_NET_FEATURES: u64 = (1 << VIRTIO_NET_F_CSUM) + | (1 << VIRTIO_NET_F_GUEST_CSUM) + | (1 << VIRTIO_NET_F_GUEST_TSO4) + | (1 << VIRTIO_NET_F_GUEST_TSO6) + | (1 << VIRTIO_NET_F_GUEST_UFO) + | (1 << VIRTIO_NET_F_HOST_TSO4) + | (1 << VIRTIO_NET_F_HOST_TSO6) + | (1 << VIRTIO_NET_F_HOST_UFO) + | (1 << VIRTIO_NET_F_MAC) + | (1 << VIRTIO_NET_F_STATUS); + +/// Callback type for TX frames. +pub type TxCallback = Box; + +/// Child's vhost-user backend. 
+pub struct ChildVhostBackend {
+    name: String,
+    mac: Mac,
+    mem: Mutex<Option<GuestMemoryAtomic<GuestMemoryMmap<()>>>>,
+    tx_callback: Mutex<Option<TxCallback>>,
+    rx_vring: Mutex<Option<VringRwLock>>,
+}
+
+impl std::fmt::Debug for ChildVhostBackend {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("ChildVhostBackend")
+            .field("name", &self.name)
+            .field("mac", &self.mac)
+            .finish_non_exhaustive()
+    }
+}
+
+impl ChildVhostBackend {
+    /// Create a new backend.
+    pub fn new(name: String, mac: Mac) -> Arc<Self> {
+        Arc::new(Self {
+            name,
+            mac,
+            mem: Mutex::new(None),
+            tx_callback: Mutex::new(None),
+            rx_vring: Mutex::new(None),
+        })
+    }
+
+    /// Set the TX callback.
+    pub fn set_tx_callback(&self, callback: TxCallback) {
+        *self.tx_callback.lock().unwrap() = Some(callback);
+    }
+
+    /// Inject a frame into the RX queue for the guest.
+    ///
+    /// Returns `true` if the frame was written to the virtio RX queue, `false` if
+    /// dropped. Frames are silently dropped before the guest driver has initialized
+    /// virtio queues (RX vring not yet available). This is expected during early
+    /// startup -- the guest isn't ready to receive traffic until the driver
+    /// negotiates features and sets up vrings.
+ pub fn inject_rx_frame(&self, frame: &[u8]) -> bool { + let vring = match self.rx_vring.lock().unwrap().as_ref() { + Some(v) => v.clone(), + None => return false, + }; + + let mem_guard = self.mem.lock().unwrap(); + let mem = match mem_guard.as_ref() { + Some(m) => m.memory(), + None => return false, + }; + + // Prepend virtio header + let mut data = vec![0u8; VIRTIO_NET_HDR_SIZE + frame.len()]; + data[VIRTIO_NET_HDR_SIZE..].copy_from_slice(frame); + + let head_index; + let written; + { + let mut vring_state = vring.get_mut(); + let queue = vring_state.get_queue_mut(); + + let desc_chain = match queue.pop_descriptor_chain(mem.clone()) { + Some(c) => c, + None => return false, + }; + + head_index = desc_chain.head_index(); + + let mut writable_descs = Vec::new(); + for desc in desc_chain { + if desc.is_write_only() { + writable_descs.push((desc.addr(), desc.len() as usize)); + } + } + + let available: usize = writable_descs.iter().map(|(_, l)| *l).sum(); + if available < data.len() { + written = 0; + } else { + let mut bytes_written = 0; + for (addr, len) in writable_descs { + let remaining = data.len() - bytes_written; + if remaining == 0 { + break; + } + let to_write = std::cmp::min(remaining, len); + if mem.write_slice(&data[bytes_written..bytes_written + to_write], addr).is_err() { + break; + } + bytes_written += to_write; + } + written = bytes_written; + } + } + + let _ = vring.add_used(head_index, written as u32); + let _ = vring.enable_notification(); + let _ = vring.signal_used_queue(); + + written >= data.len() + } + + /// Process TX queue and call callback for each frame. 
+ fn process_tx(&self, vring: &VringRwLock) { + let mem_guard = self.mem.lock().unwrap(); + let mem = match mem_guard.as_ref() { + Some(m) => m.memory(), + None => return, + }; + + let callback = self.tx_callback.lock().unwrap(); + let callback = match callback.as_ref() { + Some(c) => c, + None => return, + }; + + loop { + let head_index; + let raw_data; + { + let mut vring_state = vring.get_mut(); + let queue = vring_state.get_queue_mut(); + + let desc_chain = match queue.pop_descriptor_chain(mem.clone()) { + Some(c) => c, + None => break, + }; + + head_index = desc_chain.head_index(); + let mut data = Vec::new(); + + for desc in desc_chain { + let addr = desc.addr(); + let len = desc.len() as usize; + let mut buf = vec![0u8; len]; + if let Err(e) = mem.read_slice(&mut buf, addr) { + warn!(vm = %self.name, error = %e, "failed to read descriptor"); + break; + } + data.extend_from_slice(&buf); + } + raw_data = data; + } + + if let Err(e) = vring.add_used(head_index, 0) { + warn!(vm = %self.name, error = ?e, "add_used failed"); + } + + // Strip virtio header and call callback + if raw_data.len() > VIRTIO_NET_HDR_SIZE { + callback(&raw_data[VIRTIO_NET_HDR_SIZE..]); + } + } + + let _ = vring.enable_notification(); + let _ = vring.signal_used_queue(); + } +} + +impl VhostUserBackend for ChildVhostBackend { + type Bitmap = (); + type Vring = VringRwLock; + + fn num_queues(&self) -> usize { + NUM_QUEUES + } + + fn max_queue_size(&self) -> usize { + MAX_QUEUE_SIZE + } + + fn features(&self) -> u64 { + VIRTIO_NET_FEATURES + | (1 << VIRTIO_F_VERSION_1) + | (1 << VIRTIO_RING_F_EVENT_IDX) + | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() + } + + fn protocol_features(&self) -> VhostUserProtocolFeatures { + VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::MQ + } + + fn set_event_idx(&self, _enabled: bool) {} + + fn update_memory(&self, mem: GuestMemoryAtomic) -> std::io::Result<()> { + debug!(vm = %self.name, "update_memory"); + *self.mem.lock().unwrap() = 
Some(mem); + Ok(()) + } + + fn handle_event( + &self, + device_event: u16, + _evset: EventSet, + vrings: &[VringRwLock], + _thread_id: usize, + ) -> std::io::Result<()> { + // Store RX vring for injection + if vrings.len() > RX_QUEUE as usize { + let mut rx = self.rx_vring.lock().unwrap(); + if rx.is_none() { + *rx = Some(vrings[RX_QUEUE as usize].clone()); + debug!(vm = %self.name, "stored RX vring"); + } + } + + // Process TX queue + if device_event == TX_QUEUE && vrings.len() > TX_QUEUE as usize { + self.process_tx(&vrings[TX_QUEUE as usize]); + } + + Ok(()) + } + + fn acked_features(&self, _features: u64) {} + + fn get_config(&self, offset: u32, size: u32) -> Vec { + let mut config = [0u8; 10]; + config[0..6].copy_from_slice(&self.mac.bytes()); + config[6] = 1; // LINK_UP + config[8] = 1; // max_virtqueue_pairs + + let offset = offset as usize; + let size = size as usize; + if offset < config.len() { + let end = std::cmp::min(offset + size, config.len()); + let mut result = config[offset..end].to_vec(); + result.resize(size, 0); + result + } else { + vec![0u8; size] + } + } +} diff --git a/vm-switch/src/control.rs b/vm-switch/src/control.rs new file mode 100644 index 0000000..4b6643d --- /dev/null +++ b/vm-switch/src/control.rs @@ -0,0 +1,721 @@ +//! Control channel for main↔child process communication. +//! +//! Messages are serialized with postcard and sent over Unix sockets. +//! File descriptors are passed via SCM_RIGHTS ancillary data. + +use crate::mac::Mac; +use serde::{Deserialize, Serialize}; +use std::os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd}; +use thiserror::Error; + +/// Messages sent from main process to child. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum MainToChild { + /// Request child create an ingress buffer for a peer. + /// Child will be Consumer, peer will be Producer. + GetBuffer { + /// Name of the peer VM. + peer_name: String, + /// MAC address of the peer VM. 
+ peer_mac: [u8; 6], + }, + /// Provide peer's ingress buffer for child to use as egress. + /// Child becomes Producer for this buffer. + /// FDs: [memfd, eventfd] + PutBuffer { + /// Name of the peer VM. + peer_name: String, + /// MAC address of the peer VM. + peer_mac: [u8; 6], + /// If true, buffer accepts broadcast/multicast traffic. + broadcast: bool, + }, + /// Peer disconnected, clean up all buffers for this peer. + RemovePeer { + /// Name of the peer VM. + peer_name: String, + }, + /// Heartbeat request. + Ping, +} + +/// Messages sent from child process to main. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum ChildToMain { + /// Child is ready to receive commands. + Ready, + /// Response to GetBuffer - here's my ingress buffer for the requested peer. + /// FDs: [memfd, eventfd] + BufferReady { + /// Name of the peer this buffer is for. + peer_name: String, + }, + /// Heartbeat response. + Pong, +} + +impl MainToChild { + /// Get the peer name, if this message has one. + pub fn peer_name(&self) -> Option<&str> { + match self { + MainToChild::GetBuffer { peer_name, .. } => Some(peer_name), + MainToChild::PutBuffer { peer_name, .. } => Some(peer_name), + MainToChild::RemovePeer { peer_name } => Some(peer_name), + MainToChild::Ping => None, + } + } + + /// Get the peer MAC address as a Mac type, if this message has one. + pub fn peer_mac(&self) -> Option { + match self { + MainToChild::GetBuffer { peer_mac, .. } => Some(Mac::from_bytes(*peer_mac)), + MainToChild::PutBuffer { peer_mac, .. } => Some(Mac::from_bytes(*peer_mac)), + MainToChild::RemovePeer { .. } | MainToChild::Ping => None, + } + } +} + +impl ChildToMain { + /// Get the peer name, if this message has one. + pub fn peer_name(&self) -> Option<&str> { + match self { + ChildToMain::Ready | ChildToMain::Pong => None, + ChildToMain::BufferReady { peer_name } => Some(peer_name), + } + } +} + +/// Maximum message size in bytes. 
+pub const MAX_MESSAGE_SIZE: usize = 256; + +/// Maximum number of file descriptors per message. +pub const MAX_FDS: usize = 4; + +/// Errors that can occur with control channel operations. +#[derive(Debug, Error)] +pub enum ControlError { + #[error("failed to create socketpair: {0}")] + Socketpair(std::io::Error), + #[error("failed to serialize message: {0}")] + Serialize(postcard::Error), + #[error("failed to deserialize message: {0}")] + Deserialize(postcard::Error), + #[error("failed to send message: {0}")] + Send(std::io::Error), + #[error("failed to receive message: {0}")] + Recv(std::io::Error), + #[error("connection closed")] + Closed, + #[error("message too large: {0} bytes")] + MessageTooLarge(usize), +} + +/// A control channel endpoint for main↔child communication. +pub struct ControlChannel { + socket: OwnedFd, +} + +impl AsRawFd for ControlChannel { + fn as_raw_fd(&self) -> RawFd { + self.socket.as_raw_fd() + } +} + +impl ControlChannel { + /// Create a pair of connected control channels. + /// Returns (main_end, child_end). + pub fn pair() -> Result<(Self, Self), ControlError> { + let mut fds = [0i32; 2]; + let ret = unsafe { + libc::socketpair( + libc::AF_UNIX, + libc::SOCK_SEQPACKET | libc::SOCK_CLOEXEC, + 0, + fds.as_mut_ptr(), + ) + }; + if ret < 0 { + return Err(ControlError::Socketpair(std::io::Error::last_os_error())); + } + + let main_end = Self { + socket: unsafe { OwnedFd::from_raw_fd(fds[0]) }, + }; + let child_end = Self { + socket: unsafe { OwnedFd::from_raw_fd(fds[1]) }, + }; + + Ok((main_end, child_end)) + } + + /// Create a ControlChannel from a raw file descriptor. + /// + /// # Safety + /// The fd must be a valid, open socket file descriptor. + pub unsafe fn from_raw_fd(fd: RawFd) -> Self { + Self { + socket: OwnedFd::from_raw_fd(fd), + } + } + + /// Create a ControlChannel from an owned file descriptor. 
+    pub fn from_fd(fd: OwnedFd) -> Self {
+        Self { socket: fd }
+    }
+
+    /// Consume the ControlChannel and return the underlying file descriptor.
+    pub fn into_fd(self) -> OwnedFd {
+        self.socket
+    }
+
+    /// Send a message without file descriptors.
+    pub fn send<M: Serialize>(&self, msg: &M) -> Result<(), ControlError> {
+        let bytes = postcard::to_allocvec(msg).map_err(ControlError::Serialize)?;
+        if bytes.len() > MAX_MESSAGE_SIZE {
+            return Err(ControlError::MessageTooLarge(bytes.len()));
+        }
+        send_with_fds(self.socket.as_raw_fd(), &bytes, &[])
+    }
+
+    /// Send a message with file descriptors.
+    pub fn send_with_fds_typed<M: Serialize>(
+        &self,
+        msg: &M,
+        fds: &[RawFd],
+    ) -> Result<(), ControlError> {
+        let bytes = postcard::to_allocvec(msg).map_err(ControlError::Serialize)?;
+        if bytes.len() > MAX_MESSAGE_SIZE {
+            return Err(ControlError::MessageTooLarge(bytes.len()));
+        }
+        send_with_fds(self.socket.as_raw_fd(), &bytes, fds)
+    }
+
+    /// Receive a message without expecting file descriptors.
+    pub fn recv<M: for<'de> Deserialize<'de>>(&self) -> Result<M, ControlError> {
+        let mut buf = [0u8; MAX_MESSAGE_SIZE];
+        let (n, _fds) = recv_with_fds(self.socket.as_raw_fd(), &mut buf)?;
+        postcard::from_bytes(&buf[..n]).map_err(ControlError::Deserialize)
+    }
+
+    /// Receive a message with file descriptors.
+    /// Returns (message, file_descriptors).
+    pub fn recv_with_fds_typed<M: for<'de> Deserialize<'de>>(
+        &self,
+    ) -> Result<(M, Vec<OwnedFd>), ControlError> {
+        let mut buf = [0u8; MAX_MESSAGE_SIZE];
+        let (n, fds) = recv_with_fds(self.socket.as_raw_fd(), &mut buf)?;
+        let msg = postcard::from_bytes(&buf[..n]).map_err(ControlError::Deserialize)?;
+        Ok((msg, fds))
+    }
+}
+
+/// Send a message with optional file descriptors via SCM_RIGHTS.
+fn send_with_fds(socket_fd: RawFd, data: &[u8], fds: &[RawFd]) -> Result<(), ControlError> { + let mut iov = libc::iovec { + iov_base: data.as_ptr() as *mut libc::c_void, + iov_len: data.len(), + }; + + // Calculate control message buffer size + let cmsg_space = if fds.is_empty() { + 0 + } else { + unsafe { libc::CMSG_SPACE(std::mem::size_of_val(fds) as u32) as usize } + }; + + let mut cmsg_buf = vec![0u8; cmsg_space]; + + let mut msg: libc::msghdr = unsafe { std::mem::zeroed() }; + msg.msg_iov = &mut iov; + msg.msg_iovlen = 1; + + if !fds.is_empty() { + msg.msg_control = cmsg_buf.as_mut_ptr() as *mut libc::c_void; + msg.msg_controllen = cmsg_space; + + // Fill in the control message header + let cmsg = unsafe { libc::CMSG_FIRSTHDR(&msg) }; + unsafe { + (*cmsg).cmsg_level = libc::SOL_SOCKET; + (*cmsg).cmsg_type = libc::SCM_RIGHTS; + (*cmsg).cmsg_len = libc::CMSG_LEN(std::mem::size_of_val(fds) as u32) as usize; + + // Copy file descriptors into cmsg data + let cmsg_data = libc::CMSG_DATA(cmsg) as *mut RawFd; + for (i, &fd) in fds.iter().enumerate() { + *cmsg_data.add(i) = fd; + } + } + } + + let ret = unsafe { libc::sendmsg(socket_fd, &msg, 0) }; + if ret < 0 { + return Err(ControlError::Send(std::io::Error::last_os_error())); + } + + Ok(()) +} + +/// Receive a message with optional file descriptors via SCM_RIGHTS. +/// Returns (bytes_read, file_descriptors). 
+fn recv_with_fds(
+    socket_fd: RawFd,
+    buf: &mut [u8],
+) -> Result<(usize, Vec<OwnedFd>), ControlError> {
+    let mut iov = libc::iovec {
+        iov_base: buf.as_mut_ptr() as *mut libc::c_void,
+        iov_len: buf.len(),
+    };
+
+    // Buffer for control messages (enough for MAX_FDS file descriptors)
+    let cmsg_space =
+        unsafe { libc::CMSG_SPACE((MAX_FDS * std::mem::size_of::<RawFd>()) as u32) as usize };
+    let mut cmsg_buf = vec![0u8; cmsg_space];
+
+    let mut msg: libc::msghdr = unsafe { std::mem::zeroed() };
+    msg.msg_iov = &mut iov;
+    msg.msg_iovlen = 1;
+    msg.msg_control = cmsg_buf.as_mut_ptr() as *mut libc::c_void;
+    msg.msg_controllen = cmsg_space;
+
+    let n = unsafe { libc::recvmsg(socket_fd, &mut msg, 0) };
+    if n < 0 {
+        return Err(ControlError::Recv(std::io::Error::last_os_error()));
+    }
+    if n == 0 {
+        return Err(ControlError::Closed);
+    }
+
+    // Extract file descriptors from control message
+    let mut fds = Vec::new();
+    let mut cmsg = unsafe { libc::CMSG_FIRSTHDR(&msg) };
+    while !cmsg.is_null() {
+        unsafe {
+            if (*cmsg).cmsg_level == libc::SOL_SOCKET && (*cmsg).cmsg_type == libc::SCM_RIGHTS {
+                let data_len = (*cmsg).cmsg_len - libc::CMSG_LEN(0) as usize;
+                let num_fds = data_len / std::mem::size_of::<RawFd>();
+                let fd_ptr = libc::CMSG_DATA(cmsg) as *const RawFd;
+                for i in 0..num_fds {
+                    let fd = *fd_ptr.add(i);
+                    fds.push(OwnedFd::from_raw_fd(fd));
+                }
+            }
+            cmsg = libc::CMSG_NXTHDR(&msg, cmsg);
+        }
+    }
+
+    Ok((n as usize, fds))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn main_to_child_get_buffer_serializes() {
+        let msg = MainToChild::GetBuffer {
+            peer_name: "router".to_string(),
+            peer_mac: [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff],
+        };
+
+        let bytes = postcard::to_allocvec(&msg).expect("should serialize");
+        let decoded: MainToChild = postcard::from_bytes(&bytes).expect("should deserialize");
+
+        assert_eq!(decoded, msg);
+    }
+
+    #[test]
+    fn main_to_child_put_buffer_serializes() {
+        let msg = MainToChild::PutBuffer {
+            peer_name: "client_a".to_string(),
peer_mac: [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff], + broadcast: true, + }; + + let bytes = postcard::to_allocvec(&msg).expect("should serialize"); + let decoded: MainToChild = postcard::from_bytes(&bytes).expect("should deserialize"); + + assert_eq!(decoded, msg); + } + + #[test] + fn main_to_child_remove_peer_serializes() { + let msg = MainToChild::RemovePeer { + peer_name: "client_b".to_string(), + }; + + let bytes = postcard::to_allocvec(&msg).expect("should serialize"); + let decoded: MainToChild = postcard::from_bytes(&bytes).expect("should deserialize"); + + assert_eq!(decoded, msg); + } + + #[test] + fn main_to_child_ping_serializes() { + let msg = MainToChild::Ping; + + let bytes = postcard::to_allocvec(&msg).expect("should serialize"); + let decoded: MainToChild = postcard::from_bytes(&bytes).expect("should deserialize"); + + assert_eq!(decoded, msg); + } + + #[test] + fn child_to_main_pong_serializes() { + let msg = ChildToMain::Pong; + + let bytes = postcard::to_allocvec(&msg).expect("should serialize"); + let decoded: ChildToMain = postcard::from_bytes(&bytes).expect("should deserialize"); + + assert_eq!(decoded, msg); + } + + #[test] + fn child_to_main_ready_serializes() { + let msg = ChildToMain::Ready; + + let bytes = postcard::to_allocvec(&msg).expect("should serialize"); + let decoded: ChildToMain = postcard::from_bytes(&bytes).expect("should deserialize"); + + assert_eq!(decoded, msg); + } + + #[test] + fn child_to_main_buffer_ready_serializes() { + let msg = ChildToMain::BufferReady { + peer_name: "router".to_string(), + }; + + let bytes = postcard::to_allocvec(&msg).expect("should serialize"); + let decoded: ChildToMain = postcard::from_bytes(&bytes).expect("should deserialize"); + + assert_eq!(decoded, msg); + } + + #[test] + fn control_channel_pair_creates_connected_sockets() { + let (main_end, child_end) = ControlChannel::pair().expect("should create pair"); + + assert!(main_end.as_raw_fd() >= 0); + assert!(child_end.as_raw_fd() >= 0); + 
assert_ne!(main_end.as_raw_fd(), child_end.as_raw_fd()); + } + + #[test] + fn send_with_fds_sends_data() { + let (main_end, child_end) = ControlChannel::pair().expect("should create pair"); + let data = b"hello"; + + send_with_fds(main_end.as_raw_fd(), data, &[]).expect("should send"); + + // Read on child end to verify + let mut buf = [0u8; 64]; + let n = unsafe { + libc::recv( + child_end.as_raw_fd(), + buf.as_mut_ptr() as *mut libc::c_void, + buf.len(), + 0, + ) + }; + assert!(n > 0); + assert_eq!(&buf[..n as usize], data); + } + + #[test] + fn send_with_fds_passes_file_descriptors() { + let (main_end, child_end) = ControlChannel::pair().expect("should create pair"); + + // Create a test eventfd to send + let test_fd = unsafe { libc::eventfd(42, libc::EFD_CLOEXEC) }; + assert!(test_fd >= 0); + + send_with_fds(main_end.as_raw_fd(), b"fd", &[test_fd]).expect("should send"); + + // Close our copy - child should have its own + unsafe { libc::close(test_fd) }; + + // Receive on child end + let mut buf = [0u8; 64]; + let mut cmsg_buf = [0u8; 64]; + let mut iov = libc::iovec { + iov_base: buf.as_mut_ptr() as *mut libc::c_void, + iov_len: buf.len(), + }; + let mut msg: libc::msghdr = unsafe { std::mem::zeroed() }; + msg.msg_iov = &mut iov; + msg.msg_iovlen = 1; + msg.msg_control = cmsg_buf.as_mut_ptr() as *mut libc::c_void; + msg.msg_controllen = cmsg_buf.len(); + + let n = unsafe { libc::recvmsg(child_end.as_raw_fd(), &mut msg, 0) }; + assert!(n > 0); + + // Extract the file descriptor + let cmsg = unsafe { libc::CMSG_FIRSTHDR(&msg) }; + assert!(!cmsg.is_null()); + let received_fd = unsafe { *(libc::CMSG_DATA(cmsg) as *const RawFd) }; + assert!(received_fd >= 0); + + // Verify we can read the eventfd value (42) + let mut val: u64 = 0; + let ret = unsafe { + libc::read( + received_fd, + &mut val as *mut u64 as *mut libc::c_void, + std::mem::size_of::(), + ) + }; + assert_eq!(ret, std::mem::size_of::() as isize); + assert_eq!(val, 42); + + unsafe { 
libc::close(received_fd) }; + } + + #[test] + fn recv_with_fds_receives_data() { + let (main_end, child_end) = ControlChannel::pair().expect("should create pair"); + let data = b"world"; + + // Send from main + send_with_fds(main_end.as_raw_fd(), data, &[]).expect("should send"); + + // Receive on child + let mut buf = [0u8; 64]; + let (n, fds) = recv_with_fds(child_end.as_raw_fd(), &mut buf).expect("should receive"); + + assert_eq!(n, data.len()); + assert_eq!(&buf[..n], data); + assert!(fds.is_empty()); + } + + #[test] + fn recv_with_fds_receives_file_descriptors() { + let (main_end, child_end) = ControlChannel::pair().expect("should create pair"); + + // Create test eventfds + let fd1 = unsafe { libc::eventfd(10, libc::EFD_CLOEXEC) }; + let fd2 = unsafe { libc::eventfd(20, libc::EFD_CLOEXEC) }; + assert!(fd1 >= 0); + assert!(fd2 >= 0); + + send_with_fds(main_end.as_raw_fd(), b"fds", &[fd1, fd2]).expect("should send"); + + // Close our copies + unsafe { + libc::close(fd1); + libc::close(fd2); + } + + // Receive + let mut buf = [0u8; 64]; + let (n, fds) = recv_with_fds(child_end.as_raw_fd(), &mut buf).expect("should receive"); + + assert_eq!(n, 3); + assert_eq!(&buf[..n], b"fds"); + assert_eq!(fds.len(), 2); + + // Verify eventfd values + let mut val: u64 = 0; + unsafe { + libc::read( + fds[0].as_raw_fd(), + &mut val as *mut u64 as *mut libc::c_void, + 8, + ); + } + assert_eq!(val, 10); + + val = 0; + unsafe { + libc::read( + fds[1].as_raw_fd(), + &mut val as *mut u64 as *mut libc::c_void, + 8, + ); + } + assert_eq!(val, 20); + } + + #[test] + fn control_channel_send_delivers_message() { + let (main_end, child_end) = ControlChannel::pair().expect("should create pair"); + + let msg = MainToChild::RemovePeer { + peer_name: "client_a".to_string(), + }; + main_end.send(&msg).expect("should send"); + + // Verify by receiving raw bytes + let mut buf = [0u8; MAX_MESSAGE_SIZE]; + let (n, _) = recv_with_fds(child_end.as_raw_fd(), &mut buf).expect("should receive"); + + let 
decoded: MainToChild = postcard::from_bytes(&buf[..n]).expect("should decode"); + assert_eq!(decoded, msg); + } + + #[test] + fn control_channel_send_with_fds_delivers_message_and_fds() { + let (main_end, child_end) = ControlChannel::pair().expect("should create pair"); + + // Create test eventfd + let test_fd = unsafe { libc::eventfd(99, libc::EFD_CLOEXEC) }; + assert!(test_fd >= 0); + + let msg = ChildToMain::BufferReady { + peer_name: "router".to_string(), + }; + main_end + .send_with_fds_typed(&msg, &[test_fd]) + .expect("should send"); + + unsafe { libc::close(test_fd) }; + + // Receive + let mut buf = [0u8; MAX_MESSAGE_SIZE]; + let (n, fds) = recv_with_fds(child_end.as_raw_fd(), &mut buf).expect("should receive"); + + let decoded: ChildToMain = postcard::from_bytes(&buf[..n]).expect("should decode"); + assert_eq!(decoded, msg); + assert_eq!(fds.len(), 1); + + // Verify eventfd value + let mut val: u64 = 0; + unsafe { + libc::read( + fds[0].as_raw_fd(), + &mut val as *mut u64 as *mut libc::c_void, + 8, + ); + } + assert_eq!(val, 99); + } + + #[test] + fn control_channel_recv_returns_message() { + let (main_end, child_end) = ControlChannel::pair().expect("should create pair"); + + let msg = MainToChild::RemovePeer { + peer_name: "client_b".to_string(), + }; + + // Send from main using typed method + main_end.send(&msg).expect("should send"); + + // Receive on child using typed method + let received: MainToChild = child_end.recv().expect("should receive"); + assert_eq!(received, msg); + } + + #[test] + fn control_channel_recv_with_fds_returns_message_and_fds() { + let (main_end, child_end) = ControlChannel::pair().expect("should create pair"); + + // Create test eventfds + let fd1 = unsafe { libc::eventfd(111, libc::EFD_CLOEXEC) }; + let fd2 = unsafe { libc::eventfd(222, libc::EFD_CLOEXEC) }; + + let msg = MainToChild::PutBuffer { + peer_name: "router".to_string(), + peer_mac: [0x11, 0x22, 0x33, 0x44, 0x55, 0x66], + broadcast: true, + }; + + 
main_end.send_with_fds_typed(&msg, &[fd1, fd2]).expect("should send"); + + unsafe { + libc::close(fd1); + libc::close(fd2); + } + + // Receive with typed method + let (received, fds): (MainToChild, _) = child_end.recv_with_fds_typed().expect("should receive"); + assert_eq!(received, msg); + assert_eq!(fds.len(), 2); + + // Verify eventfd values + let mut val: u64 = 0; + unsafe { + libc::read(fds[0].as_raw_fd(), &mut val as *mut u64 as *mut libc::c_void, 8); + } + assert_eq!(val, 111); + + val = 0; + unsafe { + libc::read(fds[1].as_raw_fd(), &mut val as *mut u64 as *mut libc::c_void, 8); + } + assert_eq!(val, 222); + } + + #[test] + fn control_channel_recv_detects_closed_connection() { + let (main_end, child_end) = ControlChannel::pair().expect("should create pair"); + + // Close the sender end + drop(main_end); + + // Receive should return Closed error + let result: Result = child_end.recv(); + match result { + Err(ControlError::Closed) => (), + other => panic!("expected Closed error, got {:?}", other), + } + } + + #[test] + fn main_to_child_peer_name_helper_returns_correct_name() { + let msg = MainToChild::PutBuffer { + peer_name: "router".to_string(), + peer_mac: [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff], + broadcast: true, + }; + assert_eq!(msg.peer_name(), Some("router")); + + let msg2 = MainToChild::GetBuffer { + peer_name: "client_a".to_string(), + peer_mac: [0x11, 0x22, 0x33, 0x44, 0x55, 0x66], + }; + assert_eq!(msg2.peer_name(), Some("client_a")); + + let msg3 = MainToChild::RemovePeer { + peer_name: "client_b".to_string(), + }; + assert_eq!(msg3.peer_name(), Some("client_b")); + + assert_eq!(MainToChild::Ping.peer_name(), None); + } + + #[test] + fn main_to_child_peer_mac_helper_returns_correct_mac() { + let msg = MainToChild::PutBuffer { + peer_name: "router".to_string(), + peer_mac: [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff], + broadcast: true, + }; + assert_eq!(format!("{}", msg.peer_mac().unwrap()), "aa:bb:cc:dd:ee:ff"); + + let msg2 = MainToChild::GetBuffer { + 
peer_name: "client_a".to_string(), + peer_mac: [0x11, 0x22, 0x33, 0x44, 0x55, 0x66], + }; + assert_eq!(format!("{}", msg2.peer_mac().unwrap()), "11:22:33:44:55:66"); + + let msg3 = MainToChild::RemovePeer { + peer_name: "client_b".to_string(), + }; + assert!(msg3.peer_mac().is_none()); + + assert!(MainToChild::Ping.peer_mac().is_none()); + } + + #[test] + fn child_to_main_peer_name_helper_returns_correct_name() { + let msg = ChildToMain::BufferReady { + peer_name: "router".to_string(), + }; + assert_eq!(msg.peer_name(), Some("router")); + + assert_eq!(ChildToMain::Ready.peer_name(), None); + assert_eq!(ChildToMain::Pong.peer_name(), None); + } +} diff --git a/vm-switch/src/frame.rs b/vm-switch/src/frame.rs index fc7dc29..7cdbdba 100644 --- a/vm-switch/src/frame.rs +++ b/vm-switch/src/frame.rs @@ -45,6 +45,28 @@ impl<'a> EthernetFrame<'a> { } } +/// Validate that a frame's source MAC matches the expected MAC. +/// Returns false if the frame is too short or MAC doesn't match. +pub fn validate_source_mac(frame: &[u8], expected: Mac) -> bool { + if frame.len() < MIN_FRAME_SIZE { + return false; + } + let mut src = [0u8; 6]; + src.copy_from_slice(&frame[6..12]); + Mac::from_bytes(src) == expected +} + +/// Extract destination MAC from a frame. +/// Returns None if frame is too short. 
+pub fn extract_dest_mac(frame: &[u8]) -> Option { + if frame.len() < 6 { + return None; + } + let mut dest = [0u8; 6]; + dest.copy_from_slice(&frame[0..6]); + Some(Mac::from_bytes(dest)) +} + #[cfg(test)] mod tests { use super::*; @@ -99,4 +121,44 @@ mod tests { Mac::from_bytes([0x11, 0x22, 0x33, 0x44, 0x55, 0x66]) ); } + + #[test] + fn validate_source_mac_accepts_matching() { + let mut frame = [0u8; 14]; + frame[6..12].copy_from_slice(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]); + + let expected = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]); + assert!(validate_source_mac(&frame, expected)); + } + + #[test] + fn validate_source_mac_rejects_mismatch() { + let mut frame = [0u8; 14]; + frame[6..12].copy_from_slice(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]); + + let expected = Mac::from_bytes([0x11, 0x22, 0x33, 0x44, 0x55, 0x66]); + assert!(!validate_source_mac(&frame, expected)); + } + + #[test] + fn validate_source_mac_rejects_short_frame() { + let frame = [0u8; 10]; // Too short + let expected = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]); + assert!(!validate_source_mac(&frame, expected)); + } + + #[test] + fn extract_dest_mac_returns_mac() { + let mut frame = [0u8; 14]; + frame[0..6].copy_from_slice(&[0x11, 0x22, 0x33, 0x44, 0x55, 0x66]); + + let result = extract_dest_mac(&frame); + assert_eq!(result, Some(Mac::from_bytes([0x11, 0x22, 0x33, 0x44, 0x55, 0x66]))); + } + + #[test] + fn extract_dest_mac_returns_none_for_short_frame() { + let frame = [0u8; 5]; + assert!(extract_dest_mac(&frame).is_none()); + } } diff --git a/vm-switch/src/lib.rs b/vm-switch/src/lib.rs index f1686db..c29cdab 100644 --- a/vm-switch/src/lib.rs +++ b/vm-switch/src/lib.rs @@ -1,14 +1,15 @@ //! vm-switch: Virtual L2 switch for VM-to-VM networking. //! //! This library provides a vhost-user network backend that enables -//! communication between VMs through a virtual L2 switch with MAC filtering. +//! communication between VMs through a virtual L2 switch with ring buffer +//! 
forwarding. //! //! # Architecture //! //! - **ConfigWatcher**: Monitors `/run/vm-switch/` for `.mac` files -//! - **BackendManager**: Coordinates multiple NetBackend instances -//! - **Switch**: L2 switching logic with MAC filtering -//! - **NetBackend**: vhost-user backend for a single VM +//! - **BackendManager**: Coordinates forked child processes for each VM +//! - **Child**: vhost-user backend running in forked process with sandboxing +//! - **RingBuffer**: Inter-process packet transfer via shared memory //! //! # Usage //! @@ -18,23 +19,30 @@ //! //! let args = Args::parse(); //! let watcher = ConfigWatcher::new(&args.config_dir, 64)?; -//! let manager = BackendManager::new(&args.config_dir); +//! let manager = BackendManager::new(&args.config_dir, SeccompMode::Kill); //! ``` pub mod args; -pub mod backend; +pub mod child; pub mod config; +pub mod control; pub mod frame; pub mod mac; pub mod manager; -pub mod switch; +pub mod ring; +pub mod sandbox; +pub mod seccomp; pub mod watcher; // Re-export commonly used types pub use args::{init_logging, Args, LogLevel}; -pub use backend::NetBackend; pub use config::{ConfigEvent, VmConfig, VmRole}; +pub use control::{ChildToMain, ControlChannel, ControlError, MainToChild, MAX_FDS, MAX_MESSAGE_SIZE}; pub use mac::Mac; -pub use manager::BackendManager; -pub use switch::{ConnectionId, ForwardDecision, Switch}; +pub use manager::{BackendManager, ChildMessage}; +pub use ring::{Consumer, Producer, RingError, RING_BUFFER_SIZE, RING_SIZE, SLOT_DATA_SIZE}; +pub use sandbox::{ + apply_sandbox, enter_user_namespace, setup_filesystem_isolation, SandboxError, SandboxResult, +}; +pub use seccomp::{apply_child_seccomp, apply_main_seccomp, SeccompError, SeccompMode}; pub use watcher::ConfigWatcher; diff --git a/vm-switch/src/main.rs b/vm-switch/src/main.rs index 30b6ed6..24a1771 100644 --- a/vm-switch/src/main.rs +++ b/vm-switch/src/main.rs @@ -1,30 +1,99 @@ +use std::path::PathBuf; + +use std::time::Duration; + use clap::Parser; use 
tokio::signal; +use tokio::signal::unix::{signal as unix_signal, SignalKind}; use tokio::sync::broadcast; -use vm_switch::{Args, BackendManager, ConfigWatcher, init_logging}; +use vm_switch::{apply_sandbox, apply_main_seccomp, Args, BackendManager, ConfigWatcher, SandboxResult, SeccompMode, init_logging}; -#[tokio::main] -async fn main() -> Result<(), Box> { - let args = Args::parse(); - init_logging(args.log_level); - - tracing::info!(config_dir = ?args.config_dir, "Starting vm-switch"); - - // Ensure config directory exists +/// Apply sandbox before tokio runtime starts (must be single-threaded). +/// +/// This function may fork. If we are the parent wrapper, it exits with +/// the child's exit code. The actual vm-switch logic runs in the child. +fn setup_sandbox(args: &Args) -> Result> { + // Ensure config directory exists before sandboxing if !args.config_dir.exists() { std::fs::create_dir_all(&args.config_dir)?; - tracing::info!(path = ?args.config_dir, "Created config directory"); + eprintln!("Created config directory: {:?}", args.config_dir); } - // Create watcher and manager - let mut watcher = ConfigWatcher::new(&args.config_dir, 64)?; - let mut manager = BackendManager::new(&args.config_dir); + // Apply sandbox unless disabled + let config_path = if args.no_sandbox { + eprintln!("Sandboxing disabled via --no-sandbox"); + args.config_dir.clone() + } else { + match apply_sandbox(&args.config_dir) { + Ok(SandboxResult::Parent(exit_code)) => { + // We are the wrapper parent - propagate child's exit code + std::process::exit(exit_code); + } + Ok(SandboxResult::Sandboxed(path)) => { + eprintln!("Sandbox applied, config at {:?}", path); + path + } + Err(e) => { + eprintln!("Failed to apply sandbox: {}", e); + return Err(e.into()); + } + } + }; + + // Apply seccomp filter + if args.seccomp_mode != SeccompMode::Disabled { + apply_main_seccomp(args.seccomp_mode)?; + eprintln!("Seccomp applied (mode: {:?})", args.seccomp_mode); + } else { + eprintln!("Seccomp 
disabled via --seccomp-mode=disabled"); + } + + Ok(config_path) +} + +fn main() -> Result<(), Box> { + let args = Args::parse(); + + // Apply sandbox BEFORE tokio runtime starts (unshare requires single-threaded) + let config_dir = setup_sandbox(&args)?; + + // Now start tokio runtime and run async main + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build()? + .block_on(async_main(args, config_dir)) +} + +async fn async_main(args: Args, config_dir: PathBuf) -> Result<(), Box> { + init_logging(args.log_level); + + tracing::info!(config_dir = ?config_dir, "Starting vm-switch"); + + // Create watcher and manager using sandboxed path + let mut watcher = ConfigWatcher::new(&config_dir, 64)?; + let (mut manager, mut child_rx) = BackendManager::new(&config_dir, args.seccomp_mode); + + // Create SIGCHLD signal stream for child process monitoring + let mut sigchld = unix_signal(SignalKind::child()) + .map_err(|e| format!("failed to create SIGCHLD signal stream: {}", e))?; + + // Create SIGTERM signal stream for graceful shutdown. + // This is critical: as PID 1 in a PID namespace, the kernel only delivers + // signals for which a handler is registered. Without this, SIGTERM is + // silently dropped and systemd has to SIGKILL after timeout. + let mut sigterm = unix_signal(SignalKind::terminate()) + .map_err(|e| format!("failed to create SIGTERM signal stream: {}", e))?; // Get receiver for events (includes initial scan) let mut rx = watcher.take_receiver(); tracing::info!("Processing configuration events (Ctrl+C to stop)..."); + // Heartbeat: ping workers every second, check for responses after 100ms + let mut ping_interval = tokio::time::interval(Duration::from_secs(1)); + let ping_timeout = tokio::time::sleep(Duration::from_secs(86400)); + tokio::pin!(ping_timeout); + // Process events until shutdown signal loop { tokio::select! 
{ @@ -43,8 +112,26 @@ async fn main() -> Result<(), Box> { } } } + Some(msg) = child_rx.recv() => { + manager.handle_child_message(msg); + } + _ = sigchld.recv() => { + manager.reap_children(); + } + _ = ping_interval.tick() => { + manager.send_pings(); + ping_timeout.as_mut().reset(tokio::time::Instant::now() + Duration::from_millis(100)); + } + _ = &mut ping_timeout => { + manager.check_ping_timeouts(); + ping_timeout.as_mut().reset(tokio::time::Instant::now() + Duration::from_secs(86400)); + } + _ = sigterm.recv() => { + tracing::info!("Received SIGTERM, shutting down"); + break; + } _ = signal::ctrl_c() => { - tracing::info!("Received shutdown signal"); + tracing::info!("Received SIGINT, shutting down"); break; } } diff --git a/vm-switch/src/manager.rs b/vm-switch/src/manager.rs index 7e44458..7a24b83 100644 --- a/vm-switch/src/manager.rs +++ b/vm-switch/src/manager.rs @@ -1,57 +1,136 @@ -//! Coordinates multiple NetBackend instances. +//! Backend manager - coordinates forked child processes for VMs. -use std::collections::HashMap; -use std::path::{Path, PathBuf}; -use std::sync::{Arc, RwLock}; -use std::thread::JoinHandle; +use std::collections::{HashMap, HashSet}; +use std::os::fd::{AsRawFd, OwnedFd, RawFd}; +use std::path::Path; +use std::time::Instant; -use tracing::{debug, error, info, warn}; -use vhost_user_backend::VringRwLock; -use vm_memory::{GuestMemoryAtomic, GuestMemoryMmap}; +use anyhow::Context; +use nix::sys::wait::{waitpid, WaitPidFlag, WaitStatus}; +use nix::unistd::{dup, fork, ForkResult, Pid}; +use tokio::io::unix::AsyncFd; +use tokio::sync::mpsc; +use tracing::{debug, info, warn}; -use crate::backend::{NetBackend, ProcessedFrame, VringRegistry}; use crate::config::{ConfigEvent, VmConfig, VmRole}; -use crate::switch::{ConnectionId, ForwardDecision, Switch}; +use crate::control::{ChildToMain, ControlChannel, ControlError}; +use crate::seccomp::SeccompMode; -/// Handle to an active backend. -struct BackendHandle { - /// The backend instance. 
- backend: Arc, - /// Path to the socket file. - socket_path: PathBuf, - /// Handle to the daemon thread (if running). - daemon_thread: Option>, +/// Information about a peer's ingress buffer (received from a child via BufferReady). +#[derive(Debug)] +struct PeerBufferInfo { + /// Memfd for the ring buffer shared memory. + memfd: OwnedFd, + /// Eventfd for signaling. + eventfd: OwnedFd, } -/// Map of connection IDs to their RX vrings and associated memory. -pub type VringMap = HashMap)>; +impl PeerBufferInfo { + fn new(memfd: OwnedFd, eventfd: OwnedFd) -> Self { + Self { memfd, eventfd } + } + + /// Get raw FDs for passing via SCM_RIGHTS. + fn as_raw_fds(&self) -> [RawFd; 2] { + [self.memfd.as_raw_fd(), self.eventfd.as_raw_fd()] + } +} + +/// Message forwarded from a child reader task to the main loop. +pub struct ChildMessage { + /// Name of the VM that sent this message. + pub vm_name: String, + /// The message from the child. + pub msg: ChildToMain, + /// File descriptors accompanying the message. + pub fds: Vec, +} + +/// State tracking for a forked child process. +struct ChildState { + /// Process ID of the child. + pid: Pid, + /// Control channel for communication with child. + control: ControlChannel, + /// VM's role (router or client). + role: VmRole, + /// Child's MAC address. + mac: [u8; 6], + /// Whether the child has sent Ready. + is_ready: bool, + /// Pending GetBuffer requests awaiting BufferReady responses (peer_name -> peer_mac). + pending_buffers: HashMap, + /// Buffers received from this child, waiting to be sent to peers (peer_name -> buffer info). + ready_buffers: HashMap, + /// Names of peers who have received buffers from this child (for crash notification). + buffer_sent_to: HashSet, + /// Path to the vhost-user socket (needed for restart). + socket_path: std::path::PathBuf, + /// When the last Ping was sent (None if no pending ping). 
+ ping_sent_at: Option, +} + +impl ChildState { + fn new(pid: Pid, control: ControlChannel, role: VmRole, mac: [u8; 6], socket_path: std::path::PathBuf) -> Self { + Self { + pid, + control, + role, + mac, + is_ready: false, + pending_buffers: HashMap::new(), + ready_buffers: HashMap::new(), + buffer_sent_to: HashSet::new(), + socket_path, + ping_sent_at: None, + } + } + + /// Check if child has exited without blocking. + /// Returns Some(exit_status) if exited, None if still running. + fn try_wait(&self) -> Option { + match waitpid(self.pid, Some(WaitPidFlag::WNOHANG)) { + Ok(WaitStatus::Exited(_, status)) => Some(status), + Ok(WaitStatus::Signaled(_, signal, _)) => Some(128 + signal as i32), + Ok(_) => None, // Still running or stopped + Err(_) => Some(-1), // Error checking, treat as exited + } + } + + /// Record that we sent this child's buffer to a peer. + fn mark_buffer_sent_to(&mut self, peer_name: &str) { + self.buffer_sent_to.insert(peer_name.to_string()); + } + + /// Get set of peer names who have this child's buffer. + fn peers_with_our_buffer(&self) -> &HashSet { + &self.buffer_sent_to + } +} /// Manages vhost-user backends for all VMs. pub struct BackendManager { - /// Configuration directory path (stored for future use). - _config_dir: PathBuf, - /// Shared switch instance. - switch: Arc>, - /// Shared vring registry for frame routing between backends. - vring_registry: VringRegistry, - /// Active backends by VM name. - backends: HashMap, + /// Forked child processes, keyed by VM name. + children: HashMap, + /// Channel for forwarding messages from child reader tasks. + child_msg_tx: mpsc::UnboundedSender, + /// Seccomp mode to pass to child processes. + seccomp_mode: SeccompMode, } impl BackendManager { /// Create a new backend manager. 
- pub fn new(config_dir: impl AsRef) -> Self { - Self { - _config_dir: config_dir.as_ref().to_path_buf(), - switch: Arc::new(RwLock::new(Switch::new())), - vring_registry: Arc::new(RwLock::new(HashMap::new())), - backends: HashMap::new(), - } - } - - /// Get reference to the switch. - pub fn switch(&self) -> &Arc> { - &self.switch + /// Returns the manager and a receiver for child messages. + pub fn new(_config_dir: impl AsRef, seccomp_mode: SeccompMode) -> (Self, mpsc::UnboundedReceiver) { + let (tx, rx) = mpsc::unbounded_channel(); + ( + Self { + children: HashMap::new(), + child_msg_tx: tx, + seccomp_mode, + }, + rx, + ) } /// Handle a configuration event. @@ -66,273 +145,502 @@ impl BackendManager { } } - /// Get backend by name. - pub fn get_backend(&self, name: &str) -> Option> { - self.backends.get(name).map(|h| Arc::clone(&h.backend)) - } - - /// Route a processed frame to its destination(s). - /// - /// Uses the forwarding decision to inject the frame into - /// the appropriate RX queue(s). - pub fn route_frame(&self, processed: &ProcessedFrame, rx_vrings: &VringMap) { - match &processed.decision { - ForwardDecision::Unicast(dest_id) => { - if let Some((vring, mem)) = rx_vrings.get(dest_id) { - if !NetBackend::inject_rx_frame(vring, mem, &processed.data) { - debug!(dest_id = ?dest_id, "RX queue full, dropping frame"); - } + /// Handle a message received from a child process. 
+ pub fn handle_child_message(&mut self, msg: ChildMessage) { + match msg.msg { + ChildToMain::Ready => { + self.handle_ready(&msg.vm_name); + } + ChildToMain::BufferReady { peer_name } => { + if msg.fds.len() == 2 { + let mut fds = msg.fds.into_iter(); + let memfd = fds.next().unwrap(); + let eventfd = fds.next().unwrap(); + self.handle_buffer_ready(&msg.vm_name, &peer_name, memfd, eventfd); + } else { + warn!( + vm = %msg.vm_name, + peer = %peer_name, + fd_count = msg.fds.len(), + "BufferReady with wrong number of FDs" + ); } } - ForwardDecision::Multicast(dest_ids) => { - for dest_id in dest_ids { - if let Some((vring, mem)) = rx_vrings.get(dest_id) { - if !NetBackend::inject_rx_frame(vring, mem, &processed.data) { - debug!(dest_id = ?dest_id, "RX queue full, dropping frame"); - } - } + ChildToMain::Pong => { + if let Some(child) = self.children.get_mut(&msg.vm_name) { + child.ping_sent_at = None; } } - ForwardDecision::Drop(reason) => { - debug!(?reason, "Dropping frame"); - } } } - /// Get socket path for a VM. - pub fn get_socket_path(&self, name: &str) -> Option { - self.backends.get(name).map(|h| h.socket_path.clone()) - } - - /// Get all backend names. - pub fn backend_names(&self) -> Vec { - self.backends.keys().cloned().collect() - } - - /// Find the name of the current router backend, if any. + /// Find the name of the current router child, if any. fn find_router(&self) -> Option { - for (name, handle) in &self.backends { - if handle.backend.role() == VmRole::Router { - return Some(name.clone()); - } - } - None + self.children.iter() + .find(|(_, state)| state.role == VmRole::Router) + .map(|(name, _)| name.clone()) } - /// Start the vhost-user daemon for a backend. - /// - /// Creates a Unix socket and spawns a thread to run the daemon. - /// Returns Ok(()) if successful, Err if socket creation fails. 
- pub fn start_daemon(&mut self, vm_name: &str) -> std::io::Result<()> { - use std::os::unix::net::UnixListener; - - let handle = self.backends.get_mut(vm_name).ok_or_else(|| { - std::io::Error::new(std::io::ErrorKind::NotFound, "VM not found") - })?; - - // Remove old socket if exists - if handle.socket_path.exists() { - std::fs::remove_file(&handle.socket_path)?; - } - - // Create Unix listener to bind the socket path - let listener = UnixListener::bind(&handle.socket_path)?; - listener.set_nonblocking(true)?; - - info!(vm = %vm_name, socket = ?handle.socket_path, "Created vhost-user socket"); - - // Clone what we need for the thread - let backend = Arc::clone(&handle.backend); - let vm_name_owned = vm_name.to_string(); - let socket_path = handle.socket_path.clone(); - let vring_registry = Arc::clone(&self.vring_registry); - - // Spawn daemon thread - let thread = std::thread::spawn(move || { - Self::run_daemon(vm_name_owned, backend, listener, socket_path, vring_registry); - }); - - handle.daemon_thread = Some(thread); - - Ok(()) - } - - /// Run the vhost-user daemon (called from spawned thread). - /// - /// Accepts connections in a loop, allowing crosvm to reconnect after disconnecting. - /// The loop exits when the socket file is removed (by stop_daemon). 
- fn run_daemon( - vm_name: String, - backend: Arc, - initial_listener: std::os::unix::net::UnixListener, - socket_path: PathBuf, - vring_registry: VringRegistry, - ) { - use std::os::unix::net::UnixListener; - use vm_memory::{GuestMemoryAtomic, GuestMemoryMmap}; - use vhost_user_backend::VhostUserDaemon; - - info!(vm = %vm_name, "Daemon thread started"); - - let mut current_listener = initial_listener; - - loop { - // Clear state from previous connection - backend.clear_state(); - if let Some(conn_id) = backend.connection_id() { - let mut registry = vring_registry.write().unwrap(); - registry.remove(&conn_id); - } - - // Create empty guest memory (will be populated by vhost protocol via SET_MEM_TABLE) - let mem = GuestMemoryAtomic::new(GuestMemoryMmap::<()>::new()); - - // Create the daemon - let mut daemon = match VhostUserDaemon::new( - vm_name.clone(), - Arc::clone(&backend), - mem, - ) { - Ok(d) => d, - Err(e) => { - warn!(vm = %vm_name, error = %e, "Failed to create daemon"); - break; - } - }; - - // Convert UnixListener to vhost Listener - let mut vhost_listener = vhost::vhost_user::Listener::from(current_listener); - - // Start serving - this spawns an internal thread and returns immediately - if let Err(e) = daemon.start(&mut vhost_listener) { - warn!(vm = %vm_name, error = %e, "Daemon failed to start"); - break; - } - - info!(vm = %vm_name, "Client connected"); - - // Wait for the internal daemon thread to finish (blocks until client disconnects) - if let Err(e) = daemon.wait() { - error!(vm = %vm_name, error = %e, "Daemon stopped with error"); - } - - info!(vm = %vm_name, "Client disconnected"); - - // Check if we should accept another connection (socket file removed = shutdown) - if !socket_path.exists() { - break; - } - - // Recreate listener for next connection - let _ = std::fs::remove_file(&socket_path); - current_listener = match UnixListener::bind(&socket_path) { - Ok(l) => { - if let Err(e) = l.set_nonblocking(true) { - warn!(vm = %vm_name, error 
= %e, "Failed to set non-blocking"); - break; - } - l - } - Err(e) => { - warn!(vm = %vm_name, error = %e, "Failed to rebind socket"); - break; - } - }; - - info!(vm = %vm_name, "Ready for next connection"); - } - - info!(vm = %vm_name, "Daemon thread exiting"); - } - - /// Stop the daemon for a backend. - /// - /// Removes the socket file to trigger daemon shutdown, - /// then waits for the thread to exit with a timeout. - pub fn stop_daemon(&mut self, vm_name: &str) { - let handle = match self.backends.get_mut(vm_name) { - Some(h) => h, - None => return, - }; - - // Remove socket to trigger shutdown - if handle.socket_path.exists() { - if let Err(e) = std::fs::remove_file(&handle.socket_path) { - warn!(vm = %vm_name, error = %e, "Failed to remove socket"); - } - } - - // Wait for thread to exit with timeout - if let Some(thread) = handle.daemon_thread.take() { - // Try to join with a timeout by polling is_finished - let deadline = std::time::Instant::now() + std::time::Duration::from_secs(2); - while std::time::Instant::now() < deadline { - if thread.is_finished() { - match thread.join() { - Ok(_) => info!(vm = %vm_name, "Daemon thread joined"), - Err(_) => warn!(vm = %vm_name, "Daemon thread panicked"), - } - return; - } - std::thread::sleep(std::time::Duration::from_millis(50)); - } - - // Timeout expired - thread is still running (crosvm likely connected) - // We can't forcibly kill the thread, so just detach it - warn!( - vm = %vm_name, - "Daemon thread did not exit within timeout (client may still be connected)" - ); - // Thread handle is dropped here, detaching the thread - } - } - - /// Check if a daemon is running for a backend. + /// Check if a child process is running for a VM. pub fn is_daemon_running(&self, vm_name: &str) -> bool { - self.backends - .get(vm_name) - .is_some_and(|h| h.daemon_thread.is_some()) + self.children.contains_key(vm_name) } /// Stop all daemons and clean up. 
pub fn shutdown_all(&mut self) { - let names: Vec = self.backends.keys().cloned().collect(); - for name in names { - self.stop_daemon(&name); - } - info!("All daemons stopped"); + self.shutdown_children(); + info!("All backends stopped"); } - /// Start daemons for all registered backends. - pub fn start_all_daemons(&mut self) { - let names: Vec = self.backends.keys().cloned().collect(); - for name in names { - if let Err(e) = self.start_daemon(&name) { - warn!(vm = %name, error = %e, "Failed to start daemon"); + /// Check for and handle any exited child processes. + /// Restarts exited children automatically. + /// Should be called when SIGCHLD is received. + pub fn reap_children(&mut self) { + let mut exited = Vec::new(); + + for (name, child) in &self.children { + if let Some(status) = child.try_wait() { + if status == 0 { + info!(vm = %name, "child process exited normally"); + } else { + warn!(vm = %name, status, "child process exited with error"); + } + let peers: Vec<_> = child.peers_with_our_buffer().iter().cloned().collect(); + let restart_info = ( + crate::mac::Mac::from_bytes(child.mac), + child.role, + child.socket_path.clone(), + ); + exited.push((name.clone(), peers, restart_info)); + } + } + + let mut to_restart = Vec::new(); + + for (name, peers, restart_info) in exited { + // Notify peers before removing state + self.notify_peers_of_crash(&name, &peers); + + // Clean stale buffer_sent_to references in peers + for peer_name in &peers { + if let Some(peer) = self.children.get_mut(peer_name) { + peer.buffer_sent_to.remove(&name); + } + } + + // Remove child state + if let Some(child) = self.children.remove(&name) { + info!( + vm = %name, + pid = %child.pid, + was_ready = child.is_ready, + "Cleaned up exited child" + ); + } + + to_restart.push((name, restart_info)); + } + + // Restart children + for (vm_name, (mac, role, socket_path)) in to_restart { + info!(vm = %vm_name, "Restarting child process"); + match self.fork_child(&vm_name, mac, role, 
&socket_path) { + Ok(child_state) => { + self.children.insert(vm_name, child_state); + } + Err(e) => { + warn!(vm = %vm_name, error = %e, "Failed to restart child"); + } } } } - /// Add a VM and create its backend. + /// Send Ping to all ready children. + pub fn send_pings(&mut self) { + use crate::control::MainToChild; + + let now = Instant::now(); + for (name, child) in &mut self.children { + if child.is_ready { + match child.control.send(&MainToChild::Ping) { + Ok(()) => { + child.ping_sent_at = Some(now); + } + Err(e) => { + warn!(vm = %name, error = %e, "failed to send Ping"); + } + } + } + } + } + + /// Check for workers that haven't responded to Ping within the timeout. + /// Kills unresponsive workers; SIGCHLD + reap_children handles restart. + pub fn check_ping_timeouts(&mut self) { + let timeout = std::time::Duration::from_millis(100); + for (name, child) in &self.children { + if let Some(sent_at) = child.ping_sent_at { + if sent_at.elapsed() > timeout { + warn!(vm = %name, "worker did not respond to ping, killing"); + let _ = nix::sys::signal::kill(child.pid, nix::sys::signal::Signal::SIGKILL); + } + } + } + } + + /// Notify peers that a child has crashed. + /// Sends RemovePeer to every peer who had the crashed child's buffer. 
+ fn notify_peers_of_crash(&self, crashed_vm: &str, peers_to_notify: &[String]) { + use crate::control::MainToChild; + + if peers_to_notify.is_empty() { + debug!(crashed_vm = %crashed_vm, "No peers to notify of crash"); + return; + } + + debug!( + crashed_vm = %crashed_vm, + peer_count = peers_to_notify.len(), + "Notifying peers of child crash" + ); + + let msg = MainToChild::RemovePeer { + peer_name: crashed_vm.to_string(), + }; + + for peer_name in peers_to_notify { + if let Some(peer) = self.children.get(peer_name) { + match peer.control.send(&msg) { + Ok(()) => { + debug!( + "control: main -> worker-{} RemovePeer({})", + peer_name, crashed_vm + ); + } + Err(e) => { + warn!( + crashed = %crashed_vm, + peer = %peer_name, + error = %e, + "Failed to send RemovePeer" + ); + } + } + } + } + } + + /// Handle Ready message from a child. + /// + /// When a child sends Ready, we request buffers from all existing ready peers + /// and request buffers from the new child for all existing ready peers. + fn handle_ready(&mut self, vm_name: &str) { + debug!("control: worker-{} -> main Ready", vm_name); + + // Mark child as ready + if let Some(child) = self.children.get_mut(vm_name) { + child.is_ready = true; + } else { + warn!(vm = %vm_name, "Ready from unknown child"); + return; + } + + // Collect info about all other ready peers + let our_mac = self.children.get(vm_name).map(|c| c.mac).unwrap(); + let peers: Vec<(String, [u8; 6])> = self.children.iter() + .filter(|(name, state)| *name != vm_name && state.is_ready) + .map(|(name, state)| (name.clone(), state.mac)) + .collect(); + + // Request buffers from both sides for each peer pair + for (peer_name, peer_mac) in peers { + // Ask the new child to create an ingress buffer for the peer + self.send_get_buffer(vm_name, &peer_name, peer_mac); + // Ask the peer to create an ingress buffer for the new child + self.send_get_buffer(&peer_name, vm_name, our_mac); + } + } + + /// Send GetBuffer request to a child. 
+ fn send_get_buffer(&mut self, to_vm: &str, peer_name: &str, peer_mac: [u8; 6]) { + use crate::control::MainToChild; + use crate::mac::Mac; + + let msg = MainToChild::GetBuffer { + peer_name: peer_name.to_string(), + peer_mac, + }; + + if let Some(child) = self.children.get_mut(to_vm) { + match child.control.send(&msg) { + Ok(()) => { + debug!( + "control: main -> worker-{} GetBuffer({}, {})", + to_vm, peer_name, Mac::from_bytes(peer_mac) + ); + // Track pending request + child.pending_buffers.insert(peer_name.to_string(), peer_mac); + } + Err(e) => { + warn!( + to = %to_vm, + peer = %peer_name, + error = %e, + "Failed to send GetBuffer" + ); + } + } + } + } + + /// Handle BufferReady message from a child. + /// + /// The child has created an ingress buffer for a peer. We store it and + /// forward it to the peer as a PutBuffer (the peer will use it as egress). + fn handle_buffer_ready( + &mut self, + vm_name: &str, + peer_name: &str, + memfd: OwnedFd, + eventfd: OwnedFd, + ) { + use crate::control::MainToChild; + use crate::mac::Mac; + + debug!("control: worker-{} -> main BufferReady({})", vm_name, peer_name); + + // Validate this BufferReady matches a pending GetBuffer request + { + let child = match self.children.get_mut(vm_name) { + Some(c) => c, + None => { + warn!(vm = %vm_name, "BufferReady from unknown child"); + return; + } + }; + + if child.pending_buffers.remove(peer_name).is_none() { + warn!( + vm = %vm_name, + peer = %peer_name, + "BufferReady for unexpected peer" + ); + return; + } + } + + // Store the buffer + { + let child = self.children.get_mut(vm_name).unwrap(); + child.ready_buffers.insert( + peer_name.to_string(), + PeerBufferInfo::new(memfd, eventfd), + ); + } + + // Get info needed to send PutBuffer to peer + let (our_mac, our_role, fds) = { + let child = self.children.get(vm_name).unwrap(); + let buf = child.ready_buffers.get(peer_name).unwrap(); + (child.mac, child.role, buf.as_raw_fds()) + }; + + // Peer uses this buffer as egress to us + 
// broadcast = true if WE are the router (peer sends to router) + let broadcast = our_role == VmRole::Router; + + let msg = MainToChild::PutBuffer { + peer_name: vm_name.to_string(), + peer_mac: our_mac, + broadcast, + }; + + if let Some(peer) = self.children.get(peer_name) { + match peer.control.send_with_fds_typed(&msg, &fds) { + Ok(()) => { + debug!( + "control: main -> worker-{} PutBuffer({}, {}, broadcast={})", + peer_name, vm_name, Mac::from_bytes(our_mac), broadcast + ); + } + Err(e) => { + warn!( + to = %peer_name, + from = %vm_name, + error = %e, + "Failed to send PutBuffer" + ); + return; + } + } + } else { + warn!(peer = %peer_name, "Peer not found for PutBuffer"); + return; + } + + // Record that we sent our buffer to the peer + if let Some(child) = self.children.get_mut(vm_name) { + child.mark_buffer_sent_to(peer_name); + } + } + + /// Gracefully shut down all child processes. + /// Closes control channels and waits for children to exit. + pub fn shutdown_children(&mut self) { + if self.children.is_empty() { + return; + } + + info!("shutting down {} child process(es)", self.children.len()); + + // Close all control channels - children will detect EOF and exit + let children: Vec<_> = self.children.drain().collect(); + + // Drop control channels to signal shutdown + let pids: Vec<_> = children + .into_iter() + .map(|(name, handle)| { + debug!(vm = %name, pid = %handle.pid, "closing control channel"); + drop(handle.control); + (name, handle.pid) + }) + .collect(); + + // Wait for children to exit (with timeout) + let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5); + + for (name, pid) in &pids { + loop { + match waitpid(*pid, Some(WaitPidFlag::WNOHANG)) { + Ok(WaitStatus::Exited(_, status)) => { + debug!(vm = %name, status, "child exited"); + break; + } + Ok(WaitStatus::Signaled(_, signal, _)) => { + debug!(vm = %name, ?signal, "child killed by signal"); + break; + } + Ok(_) => { + // Still running + if std::time::Instant::now() 
> deadline { + warn!(vm = %name, pid = %pid, "child did not exit in time, sending SIGKILL"); + let _ = nix::sys::signal::kill(*pid, nix::sys::signal::Signal::SIGKILL); + let _ = waitpid(*pid, None); + break; + } + std::thread::sleep(std::time::Duration::from_millis(50)); + } + Err(_) => break, // Process doesn't exist + } + } + } + + info!("all child processes shut down"); + } + + /// Fork a new child process for the given VM. + /// Returns the ChildState on success. + fn fork_child(&mut self, vm_name: &str, mac: crate::mac::Mac, role: VmRole, socket_path: &Path) -> anyhow::Result { + // Create control channel before fork + let (main_end, child_end) = ControlChannel::pair() + .context("failed to create control channel")?; + + let socket_path_owned = socket_path.to_path_buf(); + let vm_name_owned = vm_name.to_string(); + + // Fork + match unsafe { fork() } { + Ok(ForkResult::Parent { child }) => { + // Parent: drop child's end, keep main's end + drop(child_end); + + info!(vm = %vm_name, pid = %child, "forked child process"); + + // Dup the FD for the reader task (main_end keeps original for sending) + let recv_fd = dup(main_end.as_raw_fd()) + .context("failed to dup control channel fd")?; + let recv_channel = unsafe { ControlChannel::from_raw_fd(recv_fd) }; + + // Wrap in AsyncFd and spawn reader task + let async_channel = AsyncFd::new(recv_channel) + .context("failed to create AsyncFd for control channel")?; + + let tx = self.child_msg_tx.clone(); + let vm_name_clone = vm_name.to_string(); + tokio::spawn(async move { + loop { + // Wait for socket to be readable + let mut guard = match async_channel.readable().await { + Ok(g) => g, + Err(e) => { + warn!(vm = %vm_name_clone, error = %e, "AsyncFd error"); + break; + } + }; + + // Socket is readable - do recv via inner ControlChannel + match async_channel.get_ref().recv_with_fds_typed() { + Ok((msg, fds)) => { + if tx + .send(ChildMessage { + vm_name: vm_name_clone.clone(), + msg, + fds, + }) + .is_err() + { + break; 
// Channel closed, shutdown + } + } + Err(ControlError::Closed) => break, + Err(e) => { + warn!(vm = %vm_name_clone, error = %e, "control recv error"); + } + } + guard.clear_ready(); + } + }); + + Ok(ChildState::new( + child, + main_end, + role, + mac.bytes(), + socket_path.to_path_buf(), + )) + } + Ok(ForkResult::Child) => { + // Child: drop main's end, run child entry point + drop(main_end); + + // Run child - this never returns + let control_fd = child_end.into_fd(); + crate::child::run_child_process(&vm_name_owned, mac, control_fd, &socket_path_owned, self.seccomp_mode); + } + Err(e) => { + anyhow::bail!("fork failed: {}", e) + } + } + } + + /// Add a VM and fork a child process. fn add_vm(&mut self, mac_file_path: &Path, config: VmConfig) { - // Handle existing backend with same name - if let Some(existing) = self.backends.get(&config.name) { - if existing.backend.mac() == config.mac { + // Check for existing child with same name + if let Some(existing) = self.children.get(&config.name) { + if existing.mac == config.mac.bytes() { debug!(vm = %config.name, "MAC unchanged, ignoring update"); return; } info!( vm = %config.name, - old_mac = %existing.backend.mac(), new_mac = %config.mac, - "MAC changed, recreating backend" + "MAC changed, recreating child" ); - // Clone name to avoid borrow conflict let name = config.name.clone(); self.remove_vm(&name); } - // Handle router replacement (different VM name, same role) + // Handle router replacement if config.role == VmRole::Router { if let Some(old_router) = self.find_router() { warn!( @@ -363,60 +671,54 @@ impl BackendManager { "Adding VM" ); - let backend = Arc::new(NetBackend::new( - config.name.clone(), - config.role, - config.mac, - Arc::clone(&self.switch), - Arc::clone(&self.vring_registry), - )); - - if backend.register().is_none() { - warn!(vm = %config.name, "Failed to register VM (duplicate router?)"); - return; - } - let vm_name = config.name.clone(); - self.backends.insert( - vm_name.clone(), - 
BackendHandle { - backend, - socket_path, - daemon_thread: None, - }, - ); - - // Start the daemon for this VM - if let Err(e) = self.start_daemon(&vm_name) { - warn!(vm = %vm_name, error = %e, "Failed to start daemon"); + let mac = config.mac; + let role = config.role; + match self.fork_child(&vm_name, mac, role, &socket_path) { + Ok(child_state) => { + self.children.insert(vm_name, child_state); + } + Err(e) => { + warn!(vm = %vm_name, error = %e, "failed to fork child"); + } } } /// Remove a VM and clean up. + /// Notifies peers, closes control channel, and waits for the child to exit. fn remove_vm(&mut self, vm_name: &str) { - // Stop daemon first - self.stop_daemon(vm_name); + if let Some(child) = self.children.remove(vm_name) { + info!(vm = %vm_name, pid = %child.pid, "Removing VM, terminating child"); - if let Some(handle) = self.backends.remove(vm_name) { - info!(vm = %vm_name, "Removing VM"); + // Notify peers that this VM is gone + let peers: Vec<_> = child.peers_with_our_buffer().iter().cloned().collect(); + self.notify_peers_of_crash(vm_name, &peers); - // Remove from vring registry - if let Some(conn_id) = handle.backend.connection_id() { - let mut registry = self.vring_registry.write().unwrap(); - registry.remove(&conn_id); - } + // Close control channel to signal shutdown, then wait for exit + let pid = child.pid; + drop(child); - handle.backend.unregister(); - - // Clean up socket file (may already be removed by stop_daemon) - if handle.socket_path.exists() { - if let Err(e) = std::fs::remove_file(&handle.socket_path) { - warn!( - vm = %vm_name, - path = ?handle.socket_path, - error = %e, - "Failed to remove socket" - ); + let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5); + loop { + match waitpid(pid, Some(WaitPidFlag::WNOHANG)) { + Ok(WaitStatus::Exited(_, status)) => { + debug!(vm = %vm_name, status, "removed child exited"); + break; + } + Ok(WaitStatus::Signaled(_, signal, _)) => { + debug!(vm = %vm_name, ?signal, 
"removed child killed by signal"); + break; + } + Ok(_) => { + if std::time::Instant::now() > deadline { + warn!(vm = %vm_name, pid = %pid, "removed child did not exit in time, sending SIGKILL"); + let _ = nix::sys::signal::kill(pid, nix::sys::signal::Signal::SIGKILL); + let _ = waitpid(pid, None); + break; + } + std::thread::sleep(std::time::Duration::from_millis(50)); + } + Err(_) => break, } } } @@ -426,354 +728,86 @@ impl BackendManager { #[cfg(test)] mod tests { use super::*; - use crate::config::VmRole; - use crate::mac::Mac; + use crate::control::ControlChannel; + use nix::unistd::Pid; use tempfile::TempDir; - fn setup_vm_dir(dir: &Path, name: &str, role: VmRole) -> PathBuf { - let vm_dir = dir.join(name); - std::fs::create_dir_all(&vm_dir).unwrap(); - - let filename = match role { - VmRole::Router => "router.mac", - VmRole::Client => "client.mac", - }; - vm_dir.join(filename) + #[test] + fn new_manager_is_empty() { + let dir = TempDir::new().unwrap(); + let (manager, _rx) = BackendManager::new(dir.path(), SeccompMode::Disabled); + assert!(manager.children.is_empty()); } #[test] - fn handle_vm_added_creates_backend() { - let dir = TempDir::new().unwrap(); - let mut manager = BackendManager::new(dir.path()); + fn child_state_tracks_buffer_sent_to() { + let (main_end, _child_end) = ControlChannel::pair().unwrap(); + let mut child = ChildState::new( + Pid::from_raw(1234), + main_end, + VmRole::Client, + [1, 2, 3, 4, 5, 6], + std::path::PathBuf::from("/tmp/test.sock"), + ); - let mac_path = setup_vm_dir(dir.path(), "test-vm", VmRole::Client); - let event = ConfigEvent::VmAdded { - path: mac_path, - config: VmConfig::new("test-vm", VmRole::Client, Mac::from_bytes([1, 2, 3, 4, 5, 6])), - }; + assert!(!child.is_ready); + assert!(child.peers_with_our_buffer().is_empty()); - manager.handle_event(event); + child.mark_buffer_sent_to("router"); + assert_eq!(child.peers_with_our_buffer().len(), 1); - assert!(manager.get_backend("test-vm").is_some()); + // Duplicate is 
ignored + child.mark_buffer_sent_to("router"); + assert_eq!(child.peers_with_our_buffer().len(), 1); } #[test] - fn handle_vm_removed_cleans_up() { - let dir = TempDir::new().unwrap(); - let mut manager = BackendManager::new(dir.path()); + fn child_state_tracks_pending_buffers() { + let (main_end, _child_end) = ControlChannel::pair().unwrap(); + let mut child = ChildState::new( + Pid::from_raw(1234), + main_end, + VmRole::Client, + [1, 2, 3, 4, 5, 6], + std::path::PathBuf::from("/tmp/test.sock"), + ); - // Add VM first (this also starts the daemon and creates the socket) - let mac_path = setup_vm_dir(dir.path(), "test-vm", VmRole::Client); - let add_event = ConfigEvent::VmAdded { - path: mac_path.clone(), - config: VmConfig::new("test-vm", VmRole::Client, Mac::from_bytes([1, 2, 3, 4, 5, 6])), - }; - manager.handle_event(add_event); + let peer_mac = [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]; + assert!(!child.pending_buffers.contains_key("router")); - // Socket should be created automatically by start_daemon - let socket_path = manager.get_socket_path("test-vm").unwrap(); - assert!(socket_path.exists(), "Socket should be created by add_vm"); + child.pending_buffers.insert("router".to_string(), peer_mac); + assert!(child.pending_buffers.contains_key("router")); + assert_eq!(child.pending_buffers.get("router"), Some(&peer_mac)); - // Remove VM - let remove_event = ConfigEvent::VmRemoved { - path: mac_path, - vm_name: "test-vm".to_string(), - role: VmRole::Client, - }; - manager.handle_event(remove_event); - - assert!(manager.get_backend("test-vm").is_none()); - assert!(!socket_path.exists()); + // Remove when BufferReady is received + let removed_mac = child.pending_buffers.remove("router"); + assert_eq!(removed_mac, Some(peer_mac)); + assert!(!child.pending_buffers.contains_key("router")); } #[test] - fn backend_names_lists_all_vms() { - let dir = TempDir::new().unwrap(); - let mut manager = BackendManager::new(dir.path()); + fn child_state_stores_role() { + let (main_end, 
_child_end) = ControlChannel::pair().unwrap(); - // Add two VMs - let path1 = setup_vm_dir(dir.path(), "vm-a", VmRole::Client); - let path2 = setup_vm_dir(dir.path(), "vm-b", VmRole::Client); + let client = ChildState::new( + Pid::from_raw(1), + main_end, + VmRole::Client, + [1, 2, 3, 4, 5, 6], + std::path::PathBuf::from("/tmp/client.sock"), + ); + assert_eq!(client.role, VmRole::Client); - manager.handle_event(ConfigEvent::VmAdded { - path: path1, - config: VmConfig::new("vm-a", VmRole::Client, Mac::from_bytes([1, 0, 0, 0, 0, 1])), - }); - manager.handle_event(ConfigEvent::VmAdded { - path: path2, - config: VmConfig::new("vm-b", VmRole::Client, Mac::from_bytes([2, 0, 0, 0, 0, 2])), - }); - - let mut names = manager.backend_names(); - names.sort(); - assert_eq!(names, vec!["vm-a", "vm-b"]); + let (main_end2, _child_end2) = ControlChannel::pair().unwrap(); + let router = ChildState::new( + Pid::from_raw(2), + main_end2, + VmRole::Router, + [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff], + std::path::PathBuf::from("/tmp/router.sock"), + ); + assert_eq!(router.role, VmRole::Router); } - #[test] - fn router_replacement_succeeds() { - let dir = TempDir::new().unwrap(); - let mut manager = BackendManager::new(dir.path()); - - let path1 = setup_vm_dir(dir.path(), "router1", VmRole::Router); - let path2 = setup_vm_dir(dir.path(), "router2", VmRole::Router); - - manager.handle_event(ConfigEvent::VmAdded { - path: path1.clone(), - config: VmConfig::new("router1", VmRole::Router, Mac::from_bytes([1, 0, 0, 0, 0, 1])), - }); - - // Verify router1 socket exists - let socket1 = manager.get_socket_path("router1").unwrap(); - assert!(socket1.exists(), "router1 socket should exist"); - - manager.handle_event(ConfigEvent::VmAdded { - path: path2, - config: VmConfig::new("router2", VmRole::Router, Mac::from_bytes([2, 0, 0, 0, 0, 2])), - }); - - // router2 should replace router1 - assert!(manager.get_backend("router1").is_none(), "router1 should be gone"); - 
assert!(manager.get_backend("router2").is_some(), "router2 should exist"); - - // router1 socket should be cleaned up - assert!(!socket1.exists(), "router1 socket should be removed"); - } - - #[test] - fn mac_change_recreates_backend() { - let dir = TempDir::new().unwrap(); - let mut manager = BackendManager::new(dir.path()); - - let mac_path = setup_vm_dir(dir.path(), "test-vm", VmRole::Client); - let old_mac = Mac::from_bytes([1, 2, 3, 4, 5, 6]); - let new_mac = Mac::from_bytes([6, 5, 4, 3, 2, 1]); - - // Add VM with old MAC - manager.handle_event(ConfigEvent::VmAdded { - path: mac_path.clone(), - config: VmConfig::new("test-vm", VmRole::Client, old_mac), - }); - - let socket_path = manager.get_socket_path("test-vm").unwrap(); - assert!(socket_path.exists(), "Socket should exist after add"); - - let backend = manager.get_backend("test-vm").unwrap(); - assert_eq!(backend.mac(), old_mac); - - // Update with new MAC - manager.handle_event(ConfigEvent::VmAdded { - path: mac_path, - config: VmConfig::new("test-vm", VmRole::Client, new_mac), - }); - - // Backend should have new MAC - let backend = manager.get_backend("test-vm").unwrap(); - assert_eq!(backend.mac(), new_mac); - - // Socket should still exist (recreated) - assert!(socket_path.exists(), "Socket should be recreated"); - } - - #[test] - fn mac_unchanged_does_nothing() { - let dir = TempDir::new().unwrap(); - let mut manager = BackendManager::new(dir.path()); - - let mac_path = setup_vm_dir(dir.path(), "test-vm", VmRole::Client); - let mac = Mac::from_bytes([1, 2, 3, 4, 5, 6]); - - // Add VM - manager.handle_event(ConfigEvent::VmAdded { - path: mac_path.clone(), - config: VmConfig::new("test-vm", VmRole::Client, mac), - }); - - let socket_path = manager.get_socket_path("test-vm").unwrap(); - let socket_mtime = std::fs::metadata(&socket_path).unwrap().modified().unwrap(); - - // Send same MAC again - manager.handle_event(ConfigEvent::VmAdded { - path: mac_path, - config: VmConfig::new("test-vm", 
VmRole::Client, mac), - }); - - // Socket should be unchanged (not recreated) - let new_mtime = std::fs::metadata(&socket_path).unwrap().modified().unwrap(); - assert_eq!(socket_mtime, new_mtime, "Socket should not be recreated for same MAC"); - - // Backend should still exist - assert!(manager.get_backend("test-vm").is_some()); - } - - #[test] - fn start_daemon_creates_socket() { - let dir = TempDir::new().unwrap(); - let mut manager = BackendManager::new(dir.path()); - - let mac_path = setup_vm_dir(dir.path(), "test-vm", VmRole::Client); - manager.handle_event(ConfigEvent::VmAdded { - path: mac_path, - config: VmConfig::new("test-vm", VmRole::Client, Mac::from_bytes([1, 2, 3, 4, 5, 6])), - }); - - // Start daemon - manager.start_daemon("test-vm").unwrap(); - - // Socket file should exist - let socket_path = manager.get_socket_path("test-vm").unwrap(); - assert!(socket_path.exists(), "Socket file should be created"); - } - - #[test] - fn start_daemon_unknown_vm_returns_error() { - let dir = TempDir::new().unwrap(); - let mut manager = BackendManager::new(dir.path()); - - let result = manager.start_daemon("nonexistent"); - assert!(result.is_err()); - } - - #[test] - fn stop_daemon_removes_socket() { - let dir = TempDir::new().unwrap(); - let mut manager = BackendManager::new(dir.path()); - - let mac_path = setup_vm_dir(dir.path(), "test-vm", VmRole::Client); - manager.handle_event(ConfigEvent::VmAdded { - path: mac_path, - config: VmConfig::new("test-vm", VmRole::Client, Mac::from_bytes([1, 2, 3, 4, 5, 6])), - }); - - // Start daemon - manager.start_daemon("test-vm").unwrap(); - let socket_path = manager.get_socket_path("test-vm").unwrap(); - assert!(socket_path.exists(), "Socket should exist after start"); - - // Stop daemon - manager.stop_daemon("test-vm"); - - // Socket should be removed - assert!(!socket_path.exists(), "Socket should be removed after stop"); - } - - #[test] - fn is_daemon_running_true_after_add() { - let dir = TempDir::new().unwrap(); - let mut 
manager = BackendManager::new(dir.path()); - - let mac_path = setup_vm_dir(dir.path(), "test-vm", VmRole::Client); - manager.handle_event(ConfigEvent::VmAdded { - path: mac_path, - config: VmConfig::new("test-vm", VmRole::Client, Mac::from_bytes([1, 2, 3, 4, 5, 6])), - }); - - // Daemon is now started automatically when VM is added - assert!(manager.is_daemon_running("test-vm"), "Should be running after add"); - } - - #[test] - fn is_daemon_running_false_after_stop() { - let dir = TempDir::new().unwrap(); - let mut manager = BackendManager::new(dir.path()); - - let mac_path = setup_vm_dir(dir.path(), "test-vm", VmRole::Client); - manager.handle_event(ConfigEvent::VmAdded { - path: mac_path, - config: VmConfig::new("test-vm", VmRole::Client, Mac::from_bytes([1, 2, 3, 4, 5, 6])), - }); - - // Daemon is already running after add - assert!(manager.is_daemon_running("test-vm"), "Should be running after add"); - manager.stop_daemon("test-vm"); - - assert!(!manager.is_daemon_running("test-vm"), "Should not be running after stop"); - } - - #[test] - fn remove_vm_stops_daemon() { - let dir = TempDir::new().unwrap(); - let mut manager = BackendManager::new(dir.path()); - - let mac_path = setup_vm_dir(dir.path(), "test-vm", VmRole::Client); - manager.handle_event(ConfigEvent::VmAdded { - path: mac_path.clone(), - config: VmConfig::new("test-vm", VmRole::Client, Mac::from_bytes([1, 2, 3, 4, 5, 6])), - }); - - // Daemon is already running after add - let socket_path = manager.get_socket_path("test-vm").unwrap(); - assert!(socket_path.exists(), "Socket should exist after add"); - - // Remove the VM - manager.handle_event(ConfigEvent::VmRemoved { - path: mac_path, - vm_name: "test-vm".to_string(), - role: VmRole::Client, - }); - - // Socket should be cleaned up - assert!(!socket_path.exists(), "Socket should be removed after VM removal"); - } - - #[test] - fn shutdown_all_stops_all_daemons() { - let dir = TempDir::new().unwrap(); - let mut manager = 
BackendManager::new(dir.path()); - - // Add two VMs (daemons start automatically) - let path1 = setup_vm_dir(dir.path(), "vm-a", VmRole::Client); - let path2 = setup_vm_dir(dir.path(), "vm-b", VmRole::Client); - - manager.handle_event(ConfigEvent::VmAdded { - path: path1, - config: VmConfig::new("vm-a", VmRole::Client, Mac::from_bytes([1, 0, 0, 0, 0, 1])), - }); - manager.handle_event(ConfigEvent::VmAdded { - path: path2, - config: VmConfig::new("vm-b", VmRole::Client, Mac::from_bytes([2, 0, 0, 0, 0, 2])), - }); - - // Daemons are already running after add - let socket_a = manager.get_socket_path("vm-a").unwrap(); - let socket_b = manager.get_socket_path("vm-b").unwrap(); - assert!(socket_a.exists()); - assert!(socket_b.exists()); - - // Shutdown all - manager.shutdown_all(); - - // Both sockets should be removed - assert!(!socket_a.exists(), "Socket A should be removed"); - assert!(!socket_b.exists(), "Socket B should be removed"); - } - - #[test] - fn daemons_start_automatically_on_add() { - let dir = TempDir::new().unwrap(); - let mut manager = BackendManager::new(dir.path()); - - // Add two VMs - let path1 = setup_vm_dir(dir.path(), "vm-a", VmRole::Client); - let path2 = setup_vm_dir(dir.path(), "vm-b", VmRole::Client); - - manager.handle_event(ConfigEvent::VmAdded { - path: path1, - config: VmConfig::new("vm-a", VmRole::Client, Mac::from_bytes([1, 0, 0, 0, 0, 1])), - }); - manager.handle_event(ConfigEvent::VmAdded { - path: path2, - config: VmConfig::new("vm-b", VmRole::Client, Mac::from_bytes([2, 0, 0, 0, 0, 2])), - }); - - // Both should be running automatically after add - assert!(manager.is_daemon_running("vm-a")); - assert!(manager.is_daemon_running("vm-b")); - - // Both should be running - assert!(manager.is_daemon_running("vm-a"), "vm-a should be running"); - assert!(manager.is_daemon_running("vm-b"), "vm-b should be running"); - - // Cleanup - manager.shutdown_all(); - } + // Note: Fork-based tests are in tests/fork_lifecycle.rs } diff --git 
a/vm-switch/src/ring.rs b/vm-switch/src/ring.rs new file mode 100644 index 0000000..63cf6a7 --- /dev/null +++ b/vm-switch/src/ring.rs @@ -0,0 +1,865 @@ +//! SPSC ring buffer for cross-process frame passing. + +use std::os::fd::{AsRawFd, FromRawFd, OwnedFd}; +use std::ptr::NonNull; +use std::sync::atomic::{AtomicU64, Ordering}; + +use nix::libc; + +use crate::frame::MAX_FRAME_SIZE; + +/// Errors that can occur with ring buffer operations. +#[derive(Debug, thiserror::Error)] +pub enum RingError { + #[error("failed to create memfd: {0}")] + MemfdCreate(std::io::Error), + #[error("failed to set memfd size: {0}")] + Ftruncate(std::io::Error), + #[error("failed to mmap: {0}")] + Mmap(std::io::Error), + #[error("failed to create eventfd: {0}")] + EventfdCreate(std::io::Error), +} + +/// Slot data size - accommodates jumbo frames with headroom. +pub const SLOT_DATA_SIZE: usize = MAX_FRAME_SIZE + 256; + +/// Number of slots in the ring buffer. +pub const RING_SIZE: usize = 64; + +/// Total size of the ring buffer in bytes. +pub const RING_BUFFER_SIZE: usize = std::mem::size_of::() + + RING_SIZE * std::mem::size_of::(); + +/// Ring buffer header containing head and tail indices. +/// Padded to cache line size to prevent false sharing. +#[repr(C, align(64))] +pub struct RingHeader { + /// Next write position (only producer modifies). + head: AtomicU64, + /// Next read position (only consumer modifies). + tail: AtomicU64, +} + +impl RingHeader { + /// Create a new header with head and tail at 0. + pub fn new() -> Self { + Self { + head: AtomicU64::new(0), + tail: AtomicU64::new(0), + } + } + + /// Load the current head value. + pub fn load_head(&self, order: Ordering) -> u64 { + self.head.load(order) + } + + /// Load the current tail value. + pub fn load_tail(&self, order: Ordering) -> u64 { + self.tail.load(order) + } + + /// Store a new head value. + pub fn store_head(&self, val: u64, order: Ordering) { + self.head.store(val, order) + } + + /// Store a new tail value. 
+ pub fn store_tail(&self, val: u64, order: Ordering) { + self.tail.store(val, order) + } +} + +impl Default for RingHeader { + fn default() -> Self { + Self::new() + } +} + +/// A single slot in the ring buffer. +/// Contains frame length and data. +#[repr(C)] +pub struct Slot { + /// Frame length in bytes. 0 means slot is empty/unused. + len: u32, + /// Padding for alignment. + _padding: u32, + /// Frame data buffer. + data: [u8; SLOT_DATA_SIZE], +} + +impl Slot { + /// Create a new empty slot. + pub fn new() -> Self { + Self { + len: 0, + _padding: 0, + data: [0; SLOT_DATA_SIZE], + } + } + + /// Write frame data to this slot. + /// Returns false if frame is too large. + pub fn write(&mut self, frame: &[u8]) -> bool { + if frame.len() > SLOT_DATA_SIZE { + return false; + } + self.data[..frame.len()].copy_from_slice(frame); + self.len = frame.len() as u32; + true + } + + /// Read frame data from this slot. + /// Returns None if slot is empty. + pub fn read(&self) -> Option<&[u8]> { + if self.len == 0 { + return None; + } + Some(&self.data[..self.len as usize]) + } + + /// Clear the slot. + pub fn clear(&mut self) { + self.len = 0; + } + + /// Returns true if slot is empty. + pub fn is_empty(&self) -> bool { + self.len == 0 + } +} + +impl Default for Slot { + fn default() -> Self { + Self::new() + } +} + +/// The in-memory layout of a ring buffer. +/// This struct is mapped directly over shared memory. +#[repr(C)] +pub struct RingBuffer { + /// Header with head/tail atomics. + header: RingHeader, + /// Fixed-size array of slots. + slots: [Slot; RING_SIZE], +} + +impl RingBuffer { + /// Returns the number of items currently in the buffer. + pub fn len(&self) -> usize { + let head = self.header.load_head(Ordering::Relaxed); + let tail = self.header.load_tail(Ordering::Relaxed); + ((head + RING_SIZE as u64 - tail) % RING_SIZE as u64) as usize + } + + /// Returns true if the buffer is empty. 
+ pub fn is_empty(&self) -> bool { + self.header.load_head(Ordering::Relaxed) == self.header.load_tail(Ordering::Relaxed) + } + + /// Returns true if the buffer is full. + pub fn is_full(&self) -> bool { + let head = self.header.load_head(Ordering::Relaxed); + let tail = self.header.load_tail(Ordering::Relaxed); + (head + 1) % RING_SIZE as u64 == tail + } + + /// Get mutable reference to slot at index. + pub fn slot_mut(&mut self, index: usize) -> &mut Slot { + &mut self.slots[index] + } + + /// Get reference to slot at index. + pub fn slot(&self, index: usize) -> &Slot { + &self.slots[index] + } + + /// Get reference to the header. + pub fn header(&self) -> &RingHeader { + &self.header + } +} + +/// Producer side of an SPSC ring buffer. +/// Creates and owns the underlying shared memory. +pub struct Producer { + /// Pointer to the mapped ring buffer. + ring: NonNull, + /// The memfd backing the ring buffer. + memfd: OwnedFd, + /// Eventfd for signaling consumer. + eventfd: OwnedFd, +} + +// SAFETY: The ring buffer uses proper atomic operations for cross-thread/process access. +// The producer has exclusive write access in SPSC pattern. +unsafe impl Send for Producer {} +unsafe impl Sync for Producer {} + +impl Producer { + /// Create a new producer with its own shared memory region. 
+ pub fn new() -> Result { + // Create memfd for shared memory + let memfd = unsafe { + let fd = libc::memfd_create(c"ring_buffer".as_ptr(), libc::MFD_CLOEXEC); + if fd < 0 { + return Err(RingError::MemfdCreate(std::io::Error::last_os_error())); + } + OwnedFd::from_raw_fd(fd) + }; + + // Set size + let ret = unsafe { libc::ftruncate(memfd.as_raw_fd(), RING_BUFFER_SIZE as libc::off_t) }; + if ret < 0 { + return Err(RingError::Ftruncate(std::io::Error::last_os_error())); + } + + // Map the memory + let ptr = unsafe { + libc::mmap( + std::ptr::null_mut(), + RING_BUFFER_SIZE, + libc::PROT_READ | libc::PROT_WRITE, + libc::MAP_SHARED, + memfd.as_raw_fd(), + 0, + ) + }; + if ptr == libc::MAP_FAILED { + return Err(RingError::Mmap(std::io::Error::last_os_error())); + } + + // Initialize the ring buffer + let ring = ptr as *mut RingBuffer; + unsafe { + std::ptr::addr_of_mut!((*ring).header).write(RingHeader::new()); + for i in 0..RING_SIZE { + std::ptr::addr_of_mut!((*ring).slots[i]).write(Slot::new()); + } + } + + // Create eventfd for signaling + let eventfd = unsafe { + let fd = libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK); + if fd < 0 { + libc::munmap(ptr, RING_BUFFER_SIZE); + return Err(RingError::EventfdCreate(std::io::Error::last_os_error())); + } + OwnedFd::from_raw_fd(fd) + }; + + Ok(Self { + ring: NonNull::new(ring).unwrap(), + memfd, + eventfd, + }) + } + + /// Get the memfd for sharing with consumer. + pub fn memfd(&self) -> &OwnedFd { + &self.memfd + } + + /// Get the eventfd for sharing with consumer. + pub fn eventfd(&self) -> &OwnedFd { + &self.eventfd + } + + /// Create a producer by mapping an existing consumer's shared memory. + /// Use this when you receive FDs from a remote consumer and want to produce into their buffer. 
+ pub fn from_fds(memfd: OwnedFd, eventfd: OwnedFd) -> Result { + // Map the memory + let ptr = unsafe { + libc::mmap( + std::ptr::null_mut(), + RING_BUFFER_SIZE, + libc::PROT_READ | libc::PROT_WRITE, + libc::MAP_SHARED, + memfd.as_raw_fd(), + 0, + ) + }; + if ptr == libc::MAP_FAILED { + return Err(RingError::Mmap(std::io::Error::last_os_error())); + } + + let ring = ptr as *mut RingBuffer; + + Ok(Self { + ring: NonNull::new(ring).unwrap(), + memfd, + eventfd, + }) + } + + /// Get reference to the ring buffer. + fn ring(&self) -> &RingBuffer { + // SAFETY: Pointer is valid for lifetime of Producer. + unsafe { self.ring.as_ref() } + } + + /// Write a frame to the slot at the given index using the raw pointer. + /// + /// SAFETY: Producer has exclusive write access to slots in the SPSC pattern. + /// Only the producer calls this, and only for the slot at `head`. + /// Uses raw pointer writes to avoid creating `&mut RingBuffer` which would + /// violate Rust's aliasing rules (since `&RingBuffer` references coexist). + unsafe fn write_slot(&self, index: usize, frame: &[u8]) -> bool { + if frame.len() > SLOT_DATA_SIZE { + return false; + } + let ring_ptr = self.ring.as_ptr(); + let slot_ptr = std::ptr::addr_of_mut!((*ring_ptr).slots[index]); + let data_ptr = std::ptr::addr_of_mut!((*slot_ptr).data) as *mut u8; + std::ptr::copy_nonoverlapping(frame.as_ptr(), data_ptr, frame.len()); + std::ptr::addr_of_mut!((*slot_ptr).len).write(frame.len() as u32); + true + } + + /// Push a frame into the ring buffer. + /// Returns true if successful, false if buffer is full (frame dropped). 
+ pub fn push(&self, frame: &[u8]) -> bool { + let ring = self.ring(); + let head = ring.header().load_head(Ordering::Relaxed); + let tail = ring.header().load_tail(Ordering::Acquire); + + // Check if full (one slot always empty to distinguish full from empty) + let next_head = (head + 1) % RING_SIZE as u64; + if next_head == tail { + return false; + } + + // Write to slot via raw pointer (avoids creating &mut RingBuffer alias) + // SAFETY: Producer has exclusive write access to the slot at head + if !unsafe { self.write_slot(head as usize, frame) } { + return false; + } + + // Memory fence to ensure data is visible before advancing head + std::sync::atomic::fence(Ordering::Release); + + // Advance head + ring.header().store_head(next_head, Ordering::Relaxed); + + // Signal consumer if buffer was empty + if head == tail { + let val: u64 = 1; + unsafe { + libc::write( + self.eventfd.as_raw_fd(), + &val as *const u64 as *const libc::c_void, + std::mem::size_of::(), + ); + } + } + + true + } +} + +impl Drop for Producer { + fn drop(&mut self) { + // SAFETY: We own the mapping and it's the correct size. + unsafe { + libc::munmap(self.ring.as_ptr() as *mut libc::c_void, RING_BUFFER_SIZE); + } + } +} + +/// Consumer side of an SPSC ring buffer. +/// Can either create its own shared memory (via `new()`) or map memory from a producer (via `from_fds()`). +pub struct Consumer { + /// Pointer to the mapped ring buffer. + ring: NonNull, + /// The memfd backing the ring buffer. + memfd: OwnedFd, + /// Eventfd for receiving signals from producer. + eventfd: OwnedFd, +} + +// SAFETY: The ring buffer uses proper atomic operations for cross-thread/process access. +unsafe impl Send for Consumer {} + +impl Consumer { + /// Create a new consumer with its own shared memory region. + /// Use this when YOU will be the consumer and share FDs with a remote producer. 
+ pub fn new() -> Result { + // Create memfd for shared memory + let memfd = unsafe { + let fd = libc::memfd_create(c"ring_buffer".as_ptr(), libc::MFD_CLOEXEC); + if fd < 0 { + return Err(RingError::MemfdCreate(std::io::Error::last_os_error())); + } + OwnedFd::from_raw_fd(fd) + }; + + // Set size + let ret = unsafe { libc::ftruncate(memfd.as_raw_fd(), RING_BUFFER_SIZE as libc::off_t) }; + if ret < 0 { + return Err(RingError::Ftruncate(std::io::Error::last_os_error())); + } + + // Map the memory + let ptr = unsafe { + libc::mmap( + std::ptr::null_mut(), + RING_BUFFER_SIZE, + libc::PROT_READ | libc::PROT_WRITE, + libc::MAP_SHARED, + memfd.as_raw_fd(), + 0, + ) + }; + if ptr == libc::MAP_FAILED { + return Err(RingError::Mmap(std::io::Error::last_os_error())); + } + + // Initialize the ring buffer + let ring = ptr as *mut RingBuffer; + unsafe { + std::ptr::addr_of_mut!((*ring).header).write(RingHeader::new()); + for i in 0..RING_SIZE { + std::ptr::addr_of_mut!((*ring).slots[i]).write(Slot::new()); + } + } + + // Create eventfd for signaling + let eventfd = unsafe { + let fd = libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK); + if fd < 0 { + libc::munmap(ptr, RING_BUFFER_SIZE); + return Err(RingError::EventfdCreate(std::io::Error::last_os_error())); + } + OwnedFd::from_raw_fd(fd) + }; + + Ok(Self { + ring: NonNull::new(ring).unwrap(), + memfd, + eventfd, + }) + } + + /// Get reference to the ring buffer. + fn ring(&self) -> &RingBuffer { + // SAFETY: Pointer is valid for lifetime of Consumer. + unsafe { self.ring.as_ref() } + } + + /// Create a consumer by mapping an existing producer's shared memory. 
+ pub fn from_fds(memfd: OwnedFd, eventfd: OwnedFd) -> Result { + // Map the memory + let ptr = unsafe { + libc::mmap( + std::ptr::null_mut(), + RING_BUFFER_SIZE, + libc::PROT_READ | libc::PROT_WRITE, + libc::MAP_SHARED, + memfd.as_raw_fd(), + 0, + ) + }; + if ptr == libc::MAP_FAILED { + return Err(RingError::Mmap(std::io::Error::last_os_error())); + } + + let ring = ptr as *mut RingBuffer; + + Ok(Self { + ring: NonNull::new(ring).unwrap(), + memfd, + eventfd, + }) + } + + /// Pop a frame from the ring buffer. + /// Returns None if buffer is empty. + pub fn pop(&self) -> Option> { + let ring = self.ring(); + let tail = ring.header().load_tail(Ordering::Relaxed); + let head = ring.header().load_head(Ordering::Acquire); + + // Check if empty + if head == tail { + return None; + } + + // Read from slot + let data = ring.slot(tail as usize).read()?.to_vec(); + + // Advance tail + let next_tail = (tail + 1) % RING_SIZE as u64; + ring.header().store_tail(next_tail, Ordering::Release); + + Some(data) + } + + /// Get the memfd for sharing with producer. + pub fn memfd(&self) -> &OwnedFd { + &self.memfd + } + + /// Get the eventfd for polling. + pub fn eventfd(&self) -> &OwnedFd { + &self.eventfd + } + + /// Drain the eventfd, returning the notification count. + /// Returns 0 if no notifications pending. + pub fn drain_eventfd(&self) -> u64 { + let mut val: u64 = 0; + let ret = unsafe { + libc::read( + self.eventfd.as_raw_fd(), + &mut val as *mut u64 as *mut libc::c_void, + std::mem::size_of::(), + ) + }; + if ret == std::mem::size_of::() as isize { + val + } else { + 0 + } + } +} + +impl Drop for Consumer { + fn drop(&mut self) { + // SAFETY: We own the mapping and it's the correct size. 
+        unsafe {
+            libc::munmap(self.ring.as_ptr() as *mut libc::c_void, RING_BUFFER_SIZE);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn slot_write_and_read() {
+        let mut slot = Slot::new();
+        let frame = [1u8, 2, 3, 4, 5];
+
+        assert!(slot.write(&frame));
+        let read_data = slot.read().expect("slot should have data");
+        assert_eq!(read_data, &frame);
+    }
+
+    #[test]
+    fn slot_clear_resets_to_empty() {
+        let mut slot = Slot::new();
+        slot.write(&[1, 2, 3]);
+
+        slot.clear();
+        assert!(slot.is_empty());
+        assert!(slot.read().is_none());
+    }
+
+    #[test]
+    fn slot_rejects_oversized_frame() {
+        let mut slot = Slot::new();
+        let oversized = vec![0u8; SLOT_DATA_SIZE + 1];
+
+        assert!(!slot.write(&oversized));
+        assert!(slot.is_empty());
+    }
+
+    #[test]
+    fn header_stores_and_loads_head() {
+        let header = RingHeader::new();
+
+        header.store_head(42, Ordering::Relaxed);
+        assert_eq!(header.load_head(Ordering::Relaxed), 42);
+    }
+
+    #[test]
+    fn header_stores_and_loads_tail() {
+        let header = RingHeader::new();
+
+        header.store_tail(99, Ordering::Relaxed);
+        assert_eq!(header.load_tail(Ordering::Relaxed), 99);
+    }
+
+    #[test]
+    fn ring_buffer_len_reflects_items() {
+        // We'll test via header manipulation since we don't have push/pop yet
+        let mut buffer = std::mem::MaybeUninit::<RingBuffer>::uninit();
+        // SAFETY: We're testing the layout, initializing header manually
+        let buffer = unsafe {
+            let ptr = buffer.as_mut_ptr();
+            std::ptr::addr_of_mut!((*ptr).header).write(RingHeader::new());
+            // Initialize a few slots
+            for i in 0..RING_SIZE {
+                std::ptr::addr_of_mut!((*ptr).slots[i]).write(Slot::new());
+            }
+            buffer.assume_init_mut()
+        };
+
+        // Initially empty
+        assert_eq!(buffer.len(), 0);
+        assert!(buffer.is_empty());
+
+        // Simulate adding 3 items by advancing head
+        buffer.header().store_head(3, Ordering::Relaxed);
+        assert_eq!(buffer.len(), 3);
+        assert!(!buffer.is_empty());
+    }
+
+    #[test]
+    fn ring_buffer_is_full_when_one_slot_remains() {
+        let mut buffer = std::mem::MaybeUninit::<RingBuffer>::uninit();
+        let buffer = unsafe {
+            let ptr = buffer.as_mut_ptr();
+            std::ptr::addr_of_mut!((*ptr).header).write(RingHeader::new());
+            for i in 0..RING_SIZE {
+                std::ptr::addr_of_mut!((*ptr).slots[i]).write(Slot::new());
+            }
+            buffer.assume_init_mut()
+        };
+
+        // Head at RING_SIZE-1, tail at 0 means buffer is full
+        // (we keep one slot empty to distinguish full from empty)
+        buffer.header().store_head((RING_SIZE - 1) as u64, Ordering::Relaxed);
+        buffer.header().store_tail(0, Ordering::Relaxed);
+
+        assert!(buffer.is_full());
+        assert_eq!(buffer.len(), RING_SIZE - 1);
+    }
+
+    #[test]
+    fn producer_new_creates_valid_buffer() {
+        let producer = Producer::new().expect("should create producer");
+
+        // memfd and eventfd should be valid file descriptors
+        assert!(producer.memfd().as_raw_fd() >= 0);
+        assert!(producer.eventfd().as_raw_fd() >= 0);
+    }
+
+    #[test]
+    fn consumer_maps_producer_memory() {
+        let producer = Producer::new().expect("should create producer");
+
+        // Duplicate FDs (simulating what would happen with SCM_RIGHTS)
+        let memfd_dup = unsafe {
+            let fd = libc::dup(producer.memfd().as_raw_fd());
+            assert!(fd >= 0);
+            OwnedFd::from_raw_fd(fd)
+        };
+        let eventfd_dup = unsafe {
+            let fd = libc::dup(producer.eventfd().as_raw_fd());
+            assert!(fd >= 0);
+            OwnedFd::from_raw_fd(fd)
+        };
+
+        let consumer = Consumer::from_fds(memfd_dup, eventfd_dup).expect("should create consumer");
+        assert!(consumer.eventfd().as_raw_fd() >= 0);
+    }
+
+    #[test]
+    fn producer_push_adds_frame() {
+        let producer = Producer::new().expect("should create producer");
+        let frame = [1u8, 2, 3, 4, 5];
+
+        let result = producer.push(&frame);
+
+        assert!(result);
+        // Verify head advanced
+        assert_eq!(producer.ring().header().load_head(Ordering::Relaxed), 1);
+    }
+
+    #[test]
+    fn consumer_pop_returns_pushed_frame() {
+        let producer = Producer::new().expect("should create producer");
+        let frame = [1u8, 2, 3, 4, 5];
+        producer.push(&frame);
+
+        // Create consumer with
duplicated FDs
+        let memfd_dup = unsafe {
+            OwnedFd::from_raw_fd(libc::dup(producer.memfd().as_raw_fd()))
+        };
+        let eventfd_dup = unsafe {
+            OwnedFd::from_raw_fd(libc::dup(producer.eventfd().as_raw_fd()))
+        };
+        let consumer = Consumer::from_fds(memfd_dup, eventfd_dup).expect("should create consumer");
+
+        let popped = consumer.pop();
+
+        assert!(popped.is_some());
+        assert_eq!(popped.unwrap(), frame);
+    }
+
+    #[test]
+    fn ring_buffer_maintains_fifo_order() {
+        let producer = Producer::new().expect("should create producer");
+
+        producer.push(&[1]);
+        producer.push(&[2]);
+        producer.push(&[3]);
+
+        let memfd_dup = unsafe {
+            OwnedFd::from_raw_fd(libc::dup(producer.memfd().as_raw_fd()))
+        };
+        let eventfd_dup = unsafe {
+            OwnedFd::from_raw_fd(libc::dup(producer.eventfd().as_raw_fd()))
+        };
+        let consumer = Consumer::from_fds(memfd_dup, eventfd_dup).expect("should create consumer");
+
+        assert_eq!(consumer.pop().unwrap(), vec![1]);
+        assert_eq!(consumer.pop().unwrap(), vec![2]);
+        assert_eq!(consumer.pop().unwrap(), vec![3]);
+        assert!(consumer.pop().is_none());
+    }
+
+    #[test]
+    fn producer_drops_frame_when_full() {
+        let producer = Producer::new().expect("should create producer");
+
+        // Fill the buffer (RING_SIZE - 1 slots usable)
+        for i in 0..(RING_SIZE - 1) {
+            assert!(producer.push(&[i as u8]), "push {} should succeed", i);
+        }
+
+        // This should fail - buffer full
+        assert!(!producer.push(&[255]));
+
+        // Create consumer and verify we can pop all items
+        let memfd_dup = unsafe {
+            OwnedFd::from_raw_fd(libc::dup(producer.memfd().as_raw_fd()))
+        };
+        let eventfd_dup = unsafe {
+            OwnedFd::from_raw_fd(libc::dup(producer.eventfd().as_raw_fd()))
+        };
+        let consumer = Consumer::from_fds(memfd_dup, eventfd_dup).expect("should create consumer");
+
+        for i in 0..(RING_SIZE - 1) {
+            let data = consumer.pop().expect("should have data");
+            assert_eq!(data, vec![i as u8]);
+        }
+        assert!(consumer.pop().is_none());
+    }
+
+    #[test]
+    fn ring_buffer_wraps_around_correctly() {
+        let producer = Producer::new().expect("should create producer");
+        let memfd_dup = unsafe {
+            OwnedFd::from_raw_fd(libc::dup(producer.memfd().as_raw_fd()))
+        };
+        let eventfd_dup = unsafe {
+            OwnedFd::from_raw_fd(libc::dup(producer.eventfd().as_raw_fd()))
+        };
+        let consumer = Consumer::from_fds(memfd_dup, eventfd_dup).expect("should create consumer");
+
+        // Push and pop more items than RING_SIZE to test wraparound
+        for round in 0..3 {
+            for i in 0..(RING_SIZE - 1) {
+                let val = (round * RING_SIZE + i) as u8;
+                assert!(producer.push(&[val]));
+            }
+            for i in 0..(RING_SIZE - 1) {
+                let val = (round * RING_SIZE + i) as u8;
+                assert_eq!(consumer.pop().unwrap(), vec![val]);
+            }
+        }
+    }
+
+    #[test]
+    fn consumer_drain_eventfd_clears_notifications() {
+        let producer = Producer::new().expect("should create producer");
+        let memfd_dup = unsafe {
+            OwnedFd::from_raw_fd(libc::dup(producer.memfd().as_raw_fd()))
+        };
+        let eventfd_dup = unsafe {
+            OwnedFd::from_raw_fd(libc::dup(producer.eventfd().as_raw_fd()))
+        };
+        let consumer = Consumer::from_fds(memfd_dup, eventfd_dup).expect("should create consumer");
+
+        // Push triggers eventfd write
+        producer.push(&[1]);
+        producer.push(&[2]);
+
+        // Drain should return the accumulated count
+        let count = consumer.drain_eventfd();
+        assert!(count > 0);
+
+        // Second drain should return 0 (nothing new)
+        let count2 = consumer.drain_eventfd();
+        assert_eq!(count2, 0);
+    }
+
+    #[test]
+    fn consumer_new_creates_valid_buffer() {
+        let consumer = Consumer::new().expect("should create consumer");
+
+        // memfd and eventfd should be valid file descriptors
+        assert!(consumer.memfd().as_raw_fd() >= 0);
+        assert!(consumer.eventfd().as_raw_fd() >= 0);
+    }
+
+    #[test]
+    fn producer_from_fds_maps_consumer_memory() {
+        // Consumer creates buffer, producer connects to it
+        let consumer = Consumer::new().expect("should create consumer");
+
+        // Duplicate FDs (simulating what would happen with SCM_RIGHTS)
+        let memfd_dup = unsafe {
+            let fd = libc::dup(consumer.memfd().as_raw_fd());
+            assert!(fd >= 0);
+            OwnedFd::from_raw_fd(fd)
+        };
+        let eventfd_dup = unsafe {
+            let fd = libc::dup(consumer.eventfd().as_raw_fd());
+            assert!(fd >= 0);
+            OwnedFd::from_raw_fd(fd)
+        };
+
+        let producer = Producer::from_fds(memfd_dup, eventfd_dup).expect("should create producer");
+        assert!(producer.eventfd().as_raw_fd() >= 0);
+    }
+
+    #[test]
+    fn consumer_created_buffer_receives_from_producer() {
+        // Consumer creates buffer, shares with remote producer
+        let consumer = Consumer::new().expect("should create consumer");
+
+        // Create producer from consumer's FDs
+        let memfd_dup = unsafe {
+            OwnedFd::from_raw_fd(libc::dup(consumer.memfd().as_raw_fd()))
+        };
+        let eventfd_dup = unsafe {
+            OwnedFd::from_raw_fd(libc::dup(consumer.eventfd().as_raw_fd()))
+        };
+        let producer = Producer::from_fds(memfd_dup, eventfd_dup).expect("should create producer");
+
+        // Producer pushes, consumer receives
+        let frame = [1u8, 2, 3, 4, 5];
+        assert!(producer.push(&frame));
+
+        let popped = consumer.pop();
+        assert!(popped.is_some());
+        assert_eq!(popped.unwrap(), frame);
+    }
+
+    #[test]
+    fn consumer_created_buffer_maintains_fifo_order() {
+        let consumer = Consumer::new().expect("should create consumer");
+        let memfd_dup = unsafe {
+            OwnedFd::from_raw_fd(libc::dup(consumer.memfd().as_raw_fd()))
+        };
+        let eventfd_dup = unsafe {
+            OwnedFd::from_raw_fd(libc::dup(consumer.eventfd().as_raw_fd()))
+        };
+        let producer = Producer::from_fds(memfd_dup, eventfd_dup).expect("should create producer");
+
+        producer.push(&[1]);
+        producer.push(&[2]);
+        producer.push(&[3]);
+
+        assert_eq!(consumer.pop().unwrap(), vec![1]);
+        assert_eq!(consumer.pop().unwrap(), vec![2]);
+        assert_eq!(consumer.pop().unwrap(), vec![3]);
+        assert!(consumer.pop().is_none());
+    }
+}
diff --git a/vm-switch/src/sandbox.rs b/vm-switch/src/sandbox.rs
new file mode 100644
index 0000000..3886810
--- /dev/null
+++ b/vm-switch/src/sandbox.rs
@@ -0,0 +1,564 @@
+//!
Namespace sandboxing for process isolation.
+
+use nix::mount::{mount, MsFlags};
+use nix::sched::{unshare, CloneFlags};
+use nix::sys::signal::{self, SaFlags, SigAction, SigHandler, SigSet, Signal};
+use nix::sys::wait::{waitpid, WaitStatus};
+use nix::unistd::{chdir, fork, pivot_root, ForkResult, Gid, Pid, Uid};
+use std::fs;
+use std::path::Path;
+use std::sync::atomic::{AtomicI32, Ordering};
+use thiserror::Error;
+
+/// Errors that can occur during sandbox setup.
+#[derive(Debug, Error)]
+pub enum SandboxError {
+    #[error("failed to unshare namespace: {0}")]
+    Unshare(#[source] nix::Error),
+
+    #[error("fork failed: {0}")]
+    Fork(#[source] nix::Error),
+
+    #[error("failed to write {path}: {source}")]
+    WriteFile {
+        path: String,
+        #[source]
+        source: std::io::Error,
+    },
+
+    #[error("failed to {operation} {target}: {source}")]
+    Mount {
+        operation: String,
+        target: String,
+        #[source]
+        source: nix::Error,
+    },
+
+    #[error("failed to create directory {path}: {source}")]
+    Mkdir {
+        path: String,
+        #[source]
+        source: std::io::Error,
+    },
+
+    #[error("failed to pivot_root: {0}")]
+    PivotRoot(#[source] nix::Error),
+
+    #[error("failed to change directory to {path}: {source}")]
+    Chdir {
+        path: String,
+        #[source]
+        source: nix::Error,
+    },
+}
+
+/// Result of applying sandbox.
+#[derive(Debug)]
+pub enum SandboxResult {
+    /// We are the wrapper parent. Contains child's exit code.
+    Parent(i32),
+    /// Sandbox applied successfully. Contains new config path.
+    Sandboxed(std::path::PathBuf),
+}
+
+/// Generate the content for /proc/self/uid_map.
+///
+/// Maps the given outside UID to UID 0 inside the namespace.
+fn generate_uid_map(outside_uid: u32) -> String {
+    format!("0 {} 1\n", outside_uid)
+}
+
+/// Generate the content for /proc/self/gid_map.
+///
+/// Maps the given outside GID to GID 0 inside the namespace.
+fn generate_gid_map(outside_gid: u32) -> String {
+    format!("0 {} 1\n", outside_gid)
+}
+
+/// Write uid_map and gid_map files for the current process.
+///
+/// Must be called after unshare(CLONE_NEWUSER).
+/// Writes "deny" to setgroups before gid_map (required by kernel).
+fn write_uid_gid_maps(outside_uid: Uid, outside_gid: Gid) -> Result<(), SandboxError> {
+    // Must deny setgroups before writing gid_map (kernel requirement)
+    fs::write("/proc/self/setgroups", "deny").map_err(|e| SandboxError::WriteFile {
+        path: "/proc/self/setgroups".to_string(),
+        source: e,
+    })?;
+
+    // Write uid_map
+    let uid_content = generate_uid_map(outside_uid.as_raw());
+    fs::write("/proc/self/uid_map", &uid_content).map_err(|e| SandboxError::WriteFile {
+        path: "/proc/self/uid_map".to_string(),
+        source: e,
+    })?;
+
+    // Write gid_map
+    let gid_content = generate_gid_map(outside_gid.as_raw());
+    fs::write("/proc/self/gid_map", &gid_content).map_err(|e| SandboxError::WriteFile {
+        path: "/proc/self/gid_map".to_string(),
+        source: e,
+    })?;
+
+    Ok(())
+}
+
+/// Global storing the child PID for the signal forwarding handler.
+static CHILD_PID: AtomicI32 = AtomicI32::new(0);
+
+/// Signal handler that forwards SIGTERM to the child process.
+extern "C" fn forward_signal(_sig: libc::c_int) {
+    let pid = CHILD_PID.load(Ordering::Relaxed);
+    if pid > 0 {
+        unsafe { libc::kill(pid, libc::SIGTERM) };
+    }
+}
+
+/// Install a SIGTERM handler on the parent that forwards to the child.
+fn install_forwarding_handler(child: Pid) {
+    CHILD_PID.store(child.as_raw(), Ordering::Relaxed);
+    let action = SigAction::new(
+        SigHandler::Handler(forward_signal),
+        SaFlags::empty(),
+        SigSet::empty(),
+    );
+    // Safety: forward_signal is async-signal-safe (only calls kill)
+    unsafe { signal::sigaction(Signal::SIGTERM, &action) }.ok();
+}
+
+/// Fork into new PID and mount namespaces.
+///
+/// This function:
+/// 1. Calls `unshare(CLONE_NEWPID | CLONE_NEWNS)` to create new namespaces
+/// 2.
Makes all mounts private (prevents propagation)
+/// 3. Forks - the child becomes PID 1 in the new PID namespace
+/// 4. Parent waits for child and returns its exit status
+/// 5. Child continues execution
+///
+/// Returns:
+/// - `Ok(Some(exit_code))` in the parent (should propagate and exit)
+/// - `Ok(None)` in the child (continue with sandbox setup)
+/// - `Err` on failure
+///
+/// Must be called AFTER entering user namespace and BEFORE
+/// starting any multi-threaded runtime.
+pub fn fork_into_pid_namespace() -> Result<Option<i32>, SandboxError> {
+    // Create new PID namespace and mount namespace together
+    // The mount namespace is needed so we can mount procfs for the new PID namespace
+    unshare(CloneFlags::CLONE_NEWPID | CloneFlags::CLONE_NEWNS).map_err(SandboxError::Unshare)?;
+
+    // Make all mounts private to prevent propagation
+    mount(
+        None::<&str>,
+        "/",
+        None::<&str>,
+        MsFlags::MS_REC | MsFlags::MS_PRIVATE,
+        None::<&str>,
+    )
+    .map_err(|e| SandboxError::Mount {
+        operation: "make private".to_string(),
+        target: "/".to_string(),
+        source: e,
+    })?;
+
+    // Fork - child becomes PID 1 in the new namespace
+    match unsafe { fork() }.map_err(SandboxError::Fork)? {
+        ForkResult::Parent { child } => {
+            // Install signal handler that forwards SIGTERM to the child.
+            // Without this, SIGTERM kills the parent (default action) without
+            // notifying the child, which as PID 1 in a namespace ignores
+            // unhandled signals.
+            install_forwarding_handler(child);
+
+            // Parent: wait for child and collect exit status
+            loop {
+                match waitpid(child, None) {
+                    Ok(WaitStatus::Exited(_, code)) => {
+                        return Ok(Some(code));
+                    }
+                    Ok(WaitStatus::Signaled(_, sig, _)) => {
+                        // Child killed by signal, propagate as shell-style exit code
+                        return Ok(Some(128 + sig as i32));
+                    }
+                    Ok(_) => continue, // Stopped/continued, keep waiting
+                    Err(nix::Error::EINTR) => continue, // Interrupted, retry
+                    Err(e) => {
+                        // NOTE(review): the Fork variant is reused here for a
+                        // waitpid failure; a dedicated variant would read better
+                        // but would change the public error enum.
+                        return Err(SandboxError::Fork(e));
+                    }
+                }
+            }
+        }
+        ForkResult::Child => {
+            // Child: we are now PID 1 in the new namespace
+            Ok(None)
+        }
+    }
+}
+
+/// Enter a new user namespace.
+///
+/// After this call:
+/// - The process appears to run as UID 0 / GID 0 inside the namespace
+/// - The process can create other namespaces (mount, IPC, network, PID)
+/// - The process has no capabilities in the parent namespace
+///
+/// Must be called before any other namespace operations.
+pub fn enter_user_namespace() -> Result<(), SandboxError> {
+    // Capture current UID/GID before unshare
+    let outside_uid = Uid::current();
+    let outside_gid = Gid::current();
+
+    // Create new user namespace
+    unshare(CloneFlags::CLONE_NEWUSER).map_err(SandboxError::Unshare)?;
+
+    // Set up UID/GID mappings
+    write_uid_gid_maps(outside_uid, outside_gid)?;
+
+    Ok(())
+}
+
+/// Enter a new mount namespace with private mounts.
+///
+/// After this call:
+/// - Mount changes are isolated from the parent namespace
+/// - All mounts are marked as private (no propagation)
+///
+/// Must be called after `enter_user_namespace()`.
+pub fn enter_mount_namespace() -> Result<(), SandboxError> { + // Create new mount namespace + unshare(CloneFlags::CLONE_NEWNS).map_err(SandboxError::Unshare)?; + + // Make all mounts private to prevent propagation + mount( + None::<&str>, + "/", + None::<&str>, + MsFlags::MS_REC | MsFlags::MS_PRIVATE, + None::<&str>, + ) + .map_err(|e| SandboxError::Mount { + operation: "make private".to_string(), + target: "/".to_string(), + source: e, + })?; + + Ok(()) +} + +/// Create minimal /dev with essential devices. +/// +/// Bind-mounts /dev/null, /dev/zero, and /dev/urandom from the host. +/// These are required for basic operation (logging, random numbers). +fn create_minimal_dev(new_root: &Path) -> Result<(), SandboxError> { + let dev_dir = new_root.join("dev"); + + // Device nodes to bind-mount from host + let devices = ["null", "zero", "urandom"]; + + for device in devices { + let target = dev_dir.join(device); + + // Create empty file as mount point + fs::write(&target, b"").map_err(|e| SandboxError::WriteFile { + path: target.display().to_string(), + source: e, + })?; + + // Bind-mount the device from host + let source = Path::new("/dev").join(device); + mount( + Some(&source), + &target, + None::<&str>, + MsFlags::MS_BIND, + None::<&str>, + ) + .map_err(|e| SandboxError::Mount { + operation: "bind".to_string(), + target: target.display().to_string(), + source: e, + })?; + } + + Ok(()) +} + +/// Create a minimal root filesystem on tmpfs. +/// +/// Creates: +/// - tmpfs mounted at a temporary location +/// - /proc (mount point) +/// - /dev with null, zero, urandom +/// - /tmp (empty tmpfs) +/// - /config (bind-mount of config_dir) +/// +/// Returns the path to the new root. 
+fn setup_minimal_root(config_dir: &Path) -> Result { + // Create tmpfs for new root + let new_root = Path::new("/tmp/sandbox-root"); + + fs::create_dir_all(new_root).map_err(|e| SandboxError::Mkdir { + path: new_root.display().to_string(), + source: e, + })?; + + // Mount tmpfs on new root + mount( + Some("tmpfs"), + new_root, + Some("tmpfs"), + MsFlags::MS_NOSUID | MsFlags::MS_NODEV, + Some("size=16M,mode=0755"), + ) + .map_err(|e| SandboxError::Mount { + operation: "mount tmpfs".to_string(), + target: new_root.display().to_string(), + source: e, + })?; + + // Create directory structure + let dirs = ["proc", "dev", "tmp", "config", "old_root"]; + for dir in dirs { + let path = new_root.join(dir); + fs::create_dir_all(&path).map_err(|e| SandboxError::Mkdir { + path: path.display().to_string(), + source: e, + })?; + } + + // Set up /dev + create_minimal_dev(new_root)?; + + // Bind-mount config directory + let config_target = new_root.join("config"); + mount( + Some(config_dir), + &config_target, + None::<&str>, + MsFlags::MS_BIND | MsFlags::MS_REC, + None::<&str>, + ) + .map_err(|e| SandboxError::Mount { + operation: "bind config".to_string(), + target: config_target.display().to_string(), + source: e, + })?; + + // Mount /proc BEFORE pivot_root (required for proc mount to work) + // This may fail if we don't own a PID namespace (EPERM), which is fine + // for standalone filesystem isolation without apply_sandbox() + let proc_target = new_root.join("proc"); + match mount( + Some("proc"), + &proc_target, + Some("proc"), + MsFlags::MS_NOSUID | MsFlags::MS_NODEV | MsFlags::MS_NOEXEC, + None::<&str>, + ) { + Ok(()) => {} + Err(nix::Error::EPERM) => { + // Can't mount procfs without owning a PID namespace - this is expected + // when setup_filesystem_isolation is called standalone + } + Err(e) => { + return Err(SandboxError::Mount { + operation: "mount".to_string(), + target: proc_target.display().to_string(), + source: e, + }); + } + } + + Ok(new_root.to_path_buf()) +} 
+
+/// Pivot root to the new filesystem and unmount the old root.
+///
+/// After this call, the process sees only the new root filesystem.
+/// The old root is unmounted and inaccessible.
+fn pivot_to_new_root(new_root: &Path) -> Result<(), SandboxError> {
+    // Change to new root before pivot
+    chdir(new_root).map_err(|e| SandboxError::Chdir {
+        path: new_root.display().to_string(),
+        source: e,
+    })?;
+
+    // Pivot root: new_root becomes /, old root moves to old_root
+    pivot_root(".", "old_root").map_err(SandboxError::PivotRoot)?;
+
+    // Change to new root
+    chdir("/").map_err(|e| SandboxError::Chdir {
+        path: "/".to_string(),
+        source: e,
+    })?;
+
+    // After pivot_root, old root is now at /old_root in the new namespace
+    let old_root = Path::new("/old_root");
+
+    // Unmount old root (lazy detach so in-use mounts don't block us)
+    nix::mount::umount2(old_root, nix::mount::MntFlags::MNT_DETACH).map_err(|e| {
+        SandboxError::Mount {
+            operation: "umount".to_string(),
+            target: old_root.display().to_string(),
+            source: e,
+        }
+    })?;
+
+    // Remove old_root directory (best-effort; failure leaves a harmless
+    // empty directory on the private tmpfs)
+    let _ = fs::remove_dir(old_root);
+
+    Ok(())
+}
+
+/// Enter a new IPC namespace.
+///
+/// After this call:
+/// - System V IPC objects (semaphores, message queues, shared memory) are isolated
+/// - The process cannot access host IPC objects
+///
+/// Must be called after `enter_user_namespace()`.
+pub fn enter_ipc_namespace() -> Result<(), SandboxError> {
+    unshare(CloneFlags::CLONE_NEWIPC).map_err(SandboxError::Unshare)?;
+    Ok(())
+}
+
+/// Enter a new network namespace.
+///
+/// After this call:
+/// - The process has no network interfaces (not even loopback)
+/// - Network communication is only possible via inherited file descriptors
+/// - Unix sockets created before entering the namespace remain usable
+///
+/// Must be called after `enter_user_namespace()`.
+pub fn enter_network_namespace() -> Result<(), SandboxError> {
+    unshare(CloneFlags::CLONE_NEWNET).map_err(SandboxError::Unshare)?;
+    Ok(())
+}
+
+/// Apply full sandbox isolation to the current process.
+///
+/// This function:
+/// 1. Enters user namespace (appears as root inside, enables other namespaces)
+/// 2. Forks into a new PID namespace (becomes PID 1)
+/// 3. Sets up filesystem isolation (minimal root with config at /config)
+/// 4. Enters IPC namespace (isolates System V IPC)
+/// 5. Enters network namespace (no network interfaces)
+///
+/// IMPORTANT: This function forks. In the parent, it returns
+/// `Ok(SandboxResult::Parent(exit_code))`. The parent should propagate
+/// this exit code. In the child, it returns
+/// `Ok(SandboxResult::Sandboxed(config_path))`.
+///
+/// After sandboxing:
+/// - The process is PID 1 in its own PID namespace
+/// - The process appears to run as UID 0
+/// - Only /config, /dev (minimal), /proc, and /tmp are accessible
+/// - System V IPC is isolated from host
+/// - No network interfaces (communication via inherited FDs only)
+///
+/// Must be called BEFORE starting any multi-threaded runtime (tokio).
+pub fn apply_sandbox(config_dir: &Path) -> Result<SandboxResult, SandboxError> {
+    // 1. Enter user namespace first (provides CAP_SYS_ADMIN for other namespaces)
+    enter_user_namespace()?;
+
+    // 2. Fork into PID namespace (requires CAP_SYS_ADMIN from user namespace)
+    if let Some(exit_code) = fork_into_pid_namespace()? {
+        return Ok(SandboxResult::Parent(exit_code));
+    }
+
+    // 3. Set up filesystem isolation (mount ns already created by fork_into_pid_namespace)
+    setup_filesystem_isolation(config_dir, false)?;
+
+    // 4. Enter IPC namespace
+    enter_ipc_namespace()?;
+
+    // 5. Enter network namespace
+    enter_network_namespace()?;
+
+    // Config dir is now at /config
+    Ok(SandboxResult::Sandboxed(std::path::PathBuf::from("/config")))
+}
+
+/// Set up filesystem isolation with a minimal root.
+///
+/// This function:
+/// 1.
Optionally creates a mount namespace (if `needs_mount_ns` is true) +/// 2. Builds a minimal root on tmpfs +/// 3. Mounts /proc (requires owning a PID namespace; skipped if not available) +/// 4. Pivots to the new root +/// +/// After this call, only /config (mapped to config_dir), /dev (minimal), +/// /proc (if available), and /tmp are accessible. +/// +/// # Arguments +/// +/// * `config_dir` - Host path to bind-mount as /config +/// * `needs_mount_ns` - If true, creates a new mount namespace before setup. +/// Set to false when already in a mount namespace (e.g., after +/// `fork_into_pid_namespace()` which creates one). Set to true when calling +/// standalone after `enter_user_namespace()`. +pub fn setup_filesystem_isolation(config_dir: &Path, needs_mount_ns: bool) -> Result<(), SandboxError> { + if needs_mount_ns { + enter_mount_namespace()?; + } + + // Create minimal root (also attempts to mount /proc) + let new_root = setup_minimal_root(config_dir)?; + + // Pivot to new root + pivot_to_new_root(&new_root)?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn generate_uid_map_maps_outside_to_zero() { + let content = generate_uid_map(1000); + // Format: inside_uid outside_uid count + assert_eq!(content, "0 1000 1\n"); + } + + #[test] + fn generate_uid_map_handles_root() { + let content = generate_uid_map(0); + assert_eq!(content, "0 0 1\n"); + } + + #[test] + fn generate_gid_map_maps_outside_to_zero() { + let content = generate_gid_map(1000); + assert_eq!(content, "0 1000 1\n"); + } + + #[test] + fn generate_gid_map_handles_root() { + let content = generate_gid_map(0); + assert_eq!(content, "0 0 1\n"); + } + + #[test] + fn sandbox_error_displays_mount_error() { + let err = SandboxError::Mount { + operation: "mount".to_string(), + target: "/proc".to_string(), + source: nix::Error::EPERM, + }; + let msg = err.to_string(); + assert!(msg.contains("mount")); + assert!(msg.contains("/proc")); + } + + #[test] + fn 
sandbox_error_displays_mkdir_error() { + let err = SandboxError::Mkdir { + path: "/newroot/proc".to_string(), + source: std::io::Error::from_raw_os_error(libc::EACCES), + }; + let msg = err.to_string(); + assert!(msg.contains("/newroot/proc")); + } +} diff --git a/vm-switch/src/seccomp.rs b/vm-switch/src/seccomp.rs new file mode 100644 index 0000000..e988622 --- /dev/null +++ b/vm-switch/src/seccomp.rs @@ -0,0 +1,443 @@ +//! Seccomp-bpf filtering for syscall restriction. + +use clap::ValueEnum; +use thiserror::Error; + +/// Errors that can occur during seccomp filter setup. +#[derive(Debug, Error)] +pub enum SeccompError { + #[error("failed to compile filter: {0}")] + Compile(String), + + #[error("failed to apply filter: {0}")] + Apply(#[source] std::io::Error), + + #[error("seccomp is disabled")] + Disabled, +} + +/// Syscalls allowed for child (worker) processes. +/// +/// This is the base allowlist. The main process filter is built as a +/// superset of this list (see `MAIN_EXTRA_SYSCALLS`), which is important +/// because children inherit the main filter via fork() and seccomp +/// filters stack with AND semantics — both must allow a syscall. 
+/// +/// This is a tight whitelist because the child's own filter is applied AFTER: +/// - Socket creation and binding (vhost-user socket ready) +/// - Thread spawning (vhost daemon thread running) +/// - Signal handler setup +/// +/// Children only need syscalls for: +/// - Accepting connections on existing socket +/// - Reading/writing on existing FDs +/// - Ring buffer operations (memfd, mmap) +/// - Polling and synchronization +pub static CHILD_SYSCALLS: &[i64] = &[ + // Basic I/O + libc::SYS_read, + libc::SYS_write, + libc::SYS_close, + libc::SYS_lseek, + libc::SYS_pread64, + libc::SYS_pwrite64, + libc::SYS_readv, + libc::SYS_writev, + + // Memory management + libc::SYS_mmap, + libc::SYS_mprotect, + libc::SYS_munmap, + libc::SYS_brk, + libc::SYS_mremap, + libc::SYS_madvise, + + // Ring buffer operations + libc::SYS_memfd_create, + libc::SYS_ftruncate, + + // File operations (limited - no creation) + libc::SYS_fstat, + libc::SYS_newfstatat, + libc::SYS_fcntl, + libc::SYS_dup, + libc::SYS_dup2, + libc::SYS_dup3, + libc::SYS_unlink, // glibc may use this instead of unlinkat + libc::SYS_unlinkat, + + // Process/thread control + libc::SYS_clone3, // glibc pthread_create uses clone3 (vhost-user spawns threads lazily) + libc::SYS_exit, + libc::SYS_exit_group, + libc::SYS_getpid, + libc::SYS_gettid, + libc::SYS_getuid, + libc::SYS_getgid, + libc::SYS_geteuid, + libc::SYS_getegid, + libc::SYS_sched_yield, + libc::SYS_sched_getaffinity, + libc::SYS_set_robust_list, + libc::SYS_rseq, + + // Signal handling (handlers already installed) + libc::SYS_rt_sigaction, + libc::SYS_rt_sigprocmask, + libc::SYS_rt_sigreturn, + libc::SYS_sigaltstack, + + // Polling + libc::SYS_epoll_create1, + libc::SYS_epoll_ctl, + libc::SYS_epoll_wait, + libc::SYS_epoll_pwait, + libc::SYS_poll, + libc::SYS_ppoll, + + // Socket operations (on existing sockets only) + // NO: socket, bind, listen, connect, setsockopt, getsockopt, socketpair + libc::SYS_accept, + libc::SYS_accept4, + libc::SYS_sendto, 
+ libc::SYS_recvfrom, + libc::SYS_sendmsg, + libc::SYS_recvmsg, + libc::SYS_shutdown, + libc::SYS_getsockname, + libc::SYS_getpeername, + + // Time + libc::SYS_clock_gettime, + libc::SYS_clock_getres, + libc::SYS_nanosleep, + libc::SYS_gettimeofday, + + // Thread synchronization (for existing threads) + libc::SYS_futex, + + // Misc + libc::SYS_getrandom, + libc::SYS_prctl, + libc::SYS_arch_prctl, + libc::SYS_ioctl, + libc::SYS_pipe2, + libc::SYS_eventfd2, +]; + +/// Additional syscalls needed only by the main process. +/// +/// The main filter is built from CHILD_SYSCALLS + these extras. +/// This ensures children always inherit a superset of what they need. +pub static MAIN_EXTRA_SYSCALLS: &[i64] = &[ + // File operations (main does config watching, directory traversal) + libc::SYS_openat, + libc::SYS_getdents64, + libc::SYS_mkdirat, + libc::SYS_readlinkat, + libc::SYS_faccessat, + libc::SYS_faccessat2, + libc::SYS_statx, + + // Process control (main forks children, manages lifecycle) + libc::SYS_fork, + libc::SYS_clone, + libc::SYS_clone3, + libc::SYS_wait4, + libc::SYS_kill, + libc::SYS_getppid, + + // Polling (extra variants) + libc::SYS_select, + libc::SYS_pselect6, + + // Socket operations (main creates sockets, children only accept) + libc::SYS_socket, + libc::SYS_bind, + libc::SYS_listen, + libc::SYS_connect, + libc::SYS_getsockopt, + libc::SYS_setsockopt, + libc::SYS_socketpair, + + // inotify (for config file watching) + libc::SYS_inotify_init1, + libc::SYS_inotify_add_watch, + libc::SYS_inotify_rm_watch, + + // Time (extra variants) + libc::SYS_clock_nanosleep, + + // Misc (main-only) + libc::SYS_uname, + libc::SYS_timerfd_create, + libc::SYS_timerfd_settime, + libc::SYS_timerfd_gettime, + + // Seccomp (for applying child filters after fork) + libc::SYS_seccomp, +]; + +/// Seccomp filter mode. +#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, ValueEnum)] +pub enum SeccompMode { + /// Kill process on blocked syscall (production default). 
+    #[default]
+    Kill,
+    /// Send SIGSYS on blocked syscall (for debugging).
+    Trap,
+    /// Log blocked syscalls but allow them.
+    Log,
+    /// Disable seccomp filtering entirely.
+    Disabled,
+}
+
+use seccompiler::{BpfMap, SeccompAction, SeccompFilter, SeccompRule, TargetArch};
+use std::collections::BTreeMap;
+
+/// Build a BPF filter for the given syscall whitelist.
+///
+/// Returns a BpfMap ready to be applied via `seccompiler::apply_filter_all_threads`.
+pub fn build_filter(syscalls: &[i64], mode: SeccompMode) -> Result<BpfMap, SeccompError> {
+    if mode == SeccompMode::Disabled {
+        return Err(SeccompError::Disabled);
+    }
+
+    let default_action = match mode {
+        SeccompMode::Kill => SeccompAction::KillProcess,
+        SeccompMode::Trap => SeccompAction::Trap,
+        SeccompMode::Log => SeccompAction::Log,
+        SeccompMode::Disabled => unreachable!(),
+    };
+
+    // Build allow rules for each syscall.
+    // An empty rule vector means "match unconditionally on syscall number".
+    let mut rules: BTreeMap<i64, Vec<SeccompRule>> = BTreeMap::new();
+    for &syscall in syscalls {
+        rules.insert(syscall, vec![]);
+    }
+
+    // Create filter: allow whitelisted syscalls, block others.
+    // SeccompFilter::new(rules, mismatch_action, match_action, arch)
+    let filter = SeccompFilter::new(
+        rules,
+        default_action, // mismatch_action: block non-whitelisted syscalls
+        SeccompAction::Allow, // match_action: allow whitelisted syscalls
+        TargetArch::x86_64,
+    )
+    .map_err(|e| SeccompError::Compile(e.to_string()))?;
+
+    // Compile to BPF
+    let bpf_prog = filter
+        .try_into()
+        .map_err(|e: seccompiler::BackendError| SeccompError::Compile(e.to_string()))?;
+
+    let mut map = BpfMap::new();
+    map.insert("main".to_string(), bpf_prog);
+
+    Ok(map)
+}
+
+/// Apply a compiled BPF filter to all threads in the current process.
+///
+/// Uses `prctl(PR_SET_NO_NEW_PRIVS)` and `seccomp(SECCOMP_SET_MODE_FILTER)`.
+/// Once applied, the filter cannot be removed or made less restrictive.
+pub fn apply_filter(bpf_map: &BpfMap) -> Result<(), SeccompError> {
+    let bpf_prog = bpf_map
+        .get("main")
+        .ok_or_else(|| SeccompError::Compile("no 'main' filter in map".to_string()))?;
+
+    seccompiler::apply_filter_all_threads(bpf_prog)
+        .map_err(|e| SeccompError::Apply(std::io::Error::other(e.to_string())))?;
+
+    Ok(())
+}
+
+/// Collect the full main syscall list (CHILD_SYSCALLS + MAIN_EXTRA_SYSCALLS).
+fn main_syscalls() -> Vec<i64> {
+    let mut syscalls = Vec::with_capacity(CHILD_SYSCALLS.len() + MAIN_EXTRA_SYSCALLS.len());
+    syscalls.extend_from_slice(CHILD_SYSCALLS);
+    syscalls.extend_from_slice(MAIN_EXTRA_SYSCALLS);
+    syscalls
+}
+
+/// Apply seccomp filter for the main process.
+///
+/// Call this after namespace setup, before starting the tokio runtime.
+/// The filter is built from CHILD_SYSCALLS + MAIN_EXTRA_SYSCALLS, ensuring
+/// children always inherit a permissive-enough base filter.
+pub fn apply_main_seccomp(mode: SeccompMode) -> Result<(), SeccompError> {
+    if mode == SeccompMode::Disabled {
+        return Ok(());
+    }
+
+    let syscalls = main_syscalls();
+    let filter = build_filter(&syscalls, mode)?;
+    apply_filter(&filter)
+}
+
+/// Apply seccomp filter for a child (worker) process.
+///
+/// Call this after socket creation and thread spawning, just before
+/// entering the event loop. This allows the tightest possible filter.
+pub fn apply_child_seccomp(mode: SeccompMode) -> Result<(), SeccompError> { + if mode == SeccompMode::Disabled { + return Ok(()); + } + + let filter = build_filter(CHILD_SYSCALLS, mode)?; + apply_filter(&filter) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn seccomp_error_displays_compile_error() { + let err = SeccompError::Compile("invalid rule".to_string()); + assert!(err.to_string().contains("compile")); + assert!(err.to_string().contains("invalid rule")); + } + + #[test] + fn seccomp_error_displays_apply_error() { + let err = SeccompError::Apply(std::io::Error::from_raw_os_error(libc::EPERM)); + let msg = err.to_string(); + assert!(msg.contains("apply")); + } + + #[test] + fn seccomp_error_displays_disabled() { + let err = SeccompError::Disabled; + assert!(err.to_string().contains("disabled")); + } + + #[test] + fn main_syscalls_is_not_empty() { + assert!(!main_syscalls().is_empty()); + } + + #[test] + fn main_syscalls_contains_essential_syscalls() { + let syscalls = main_syscalls(); + assert!(syscalls.contains(&libc::SYS_read)); + assert!(syscalls.contains(&libc::SYS_write)); + assert!(syscalls.contains(&libc::SYS_close)); + assert!(syscalls.contains(&libc::SYS_exit_group)); + } + + #[test] + fn main_syscalls_allows_fork() { + // Main process needs fork to spawn children + assert!(main_syscalls().contains(&libc::SYS_fork)); + } + + #[test] + fn main_syscalls_is_superset_of_child() { + let main = main_syscalls(); + for &syscall in CHILD_SYSCALLS { + assert!( + main.contains(&syscall), + "CHILD_SYSCALLS contains {} which is missing from main filter", + syscall + ); + } + } + + #[test] + fn child_syscalls_is_not_empty() { + assert!(!CHILD_SYSCALLS.is_empty()); + } + + #[test] + fn child_syscalls_does_not_allow_fork() { + // Children don't spawn processes + assert!(!CHILD_SYSCALLS.contains(&libc::SYS_fork)); + } + + #[test] + fn child_syscalls_allows_clone3_for_vhost_threads() { + // vhost-user library spawns threads lazily on client connect + 
// clone3 is used by glibc pthread_create; clone is NOT allowed + // (fork uses clone, so blocking it prevents fork in children) + assert!(CHILD_SYSCALLS.contains(&libc::SYS_clone3)); + assert!(!CHILD_SYSCALLS.contains(&libc::SYS_clone)); + } + + #[test] + fn child_syscalls_does_not_allow_socket_creation() { + // Socket created before seccomp + assert!(!CHILD_SYSCALLS.contains(&libc::SYS_socket)); + assert!(!CHILD_SYSCALLS.contains(&libc::SYS_bind)); + assert!(!CHILD_SYSCALLS.contains(&libc::SYS_listen)); + assert!(!CHILD_SYSCALLS.contains(&libc::SYS_setsockopt)); + } + + #[test] + fn child_syscalls_allows_accept() { + // Need to accept vhost-user connections + assert!(CHILD_SYSCALLS.contains(&libc::SYS_accept4)); + } + + #[test] + fn child_syscalls_allows_ring_buffer_ops() { + // Children need memfd_create for ring buffers + assert!(CHILD_SYSCALLS.contains(&libc::SYS_memfd_create)); + assert!(CHILD_SYSCALLS.contains(&libc::SYS_ftruncate)); + } + + #[test] + fn build_filter_creates_valid_bpf() { + let syscalls = main_syscalls(); + let filter = build_filter(&syscalls, SeccompMode::Kill) + .expect("filter should compile"); + assert_eq!(filter.len(), 1); + assert!(filter.contains_key("main")); + } + + #[test] + fn build_filter_handles_all_modes() { + let syscalls = main_syscalls(); + for mode in [SeccompMode::Kill, SeccompMode::Trap, SeccompMode::Log] { + let result = build_filter(&syscalls, mode); + assert!(result.is_ok(), "mode {:?} should compile", mode); + } + } + + #[test] + fn build_filter_disabled_returns_error() { + let syscalls = main_syscalls(); + let result = build_filter(&syscalls, SeccompMode::Disabled); + assert!(matches!(result, Err(SeccompError::Disabled))); + } + + #[test] + fn apply_filter_signature_check() { + fn check_signature(_f: fn(&BpfMap) -> Result<(), SeccompError>) {} + check_signature(apply_filter); + } + + #[test] + fn apply_main_seccomp_signature_check() { + fn check(_f: fn(SeccompMode) -> Result<(), SeccompError>) {} + 
check(apply_main_seccomp); + } + + #[test] + fn apply_child_seccomp_signature_check() { + fn check(_f: fn(SeccompMode) -> Result<(), SeccompError>) {} + check(apply_child_seccomp); + } + + #[test] + fn apply_main_seccomp_disabled_succeeds() { + // Disabled mode should be a no-op, not an error + assert!(apply_main_seccomp(SeccompMode::Disabled).is_ok()); + } + + #[test] + fn apply_child_seccomp_disabled_succeeds() { + assert!(apply_child_seccomp(SeccompMode::Disabled).is_ok()); + } +} diff --git a/vm-switch/src/switch.rs b/vm-switch/src/switch.rs deleted file mode 100644 index 33b0dd4..0000000 --- a/vm-switch/src/switch.rs +++ /dev/null @@ -1,445 +0,0 @@ -//! L2 switch logic with MAC filtering. - -use std::collections::HashMap; - -use tracing::debug; - -use crate::config::VmRole; -use crate::mac::Mac; - -/// Unique identifier for a connected VM. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct ConnectionId(u64); - -impl ConnectionId { - /// Create a new connection ID. - pub fn new(id: u64) -> Self { - Self(id) - } -} - -/// Decision for how to forward a frame. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum ForwardDecision { - /// Forward to a single destination. - Unicast(ConnectionId), - /// Forward to multiple destinations (broadcast/multicast). - Multicast(Vec), - /// Drop the frame. - Drop(DropReason), -} - -/// Reason why a frame was dropped. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum DropReason { - /// Source MAC doesn't match the sender's configured MAC. - SourceMacMismatch { expected: Mac, actual: Mac }, - /// Client tried to send to a MAC other than router/broadcast/multicast. - ClientViolation { destination: Mac }, - /// Router tried to send to an unknown MAC. - UnknownDestination { destination: Mac }, - /// No router is connected. - NoRouter, - /// Sender connection ID is not registered. - UnknownSender, -} - -/// Information about a connected VM. -#[derive(Debug, Clone)] -struct Connection { - /// VM name for logging. 
-    name: String,
-    /// Role (router or client).
-    role: VmRole,
-    /// Configured MAC address.
-    mac: Mac,
-}
-
-/// L2 switch with MAC filtering.
-///
-/// Maintains a registry of connected VMs and applies filtering rules
-/// to determine how frames should be forwarded.
-pub struct Switch {
-    /// Connected VMs by connection ID.
-    connections: HashMap<ConnectionId, Connection>,
-    /// MAC address to connection ID mapping for fast lookup.
-    mac_to_conn: HashMap<Mac, ConnectionId>,
-    /// The router's connection ID (if connected).
-    router: Option<ConnectionId>,
-    /// Next connection ID to assign.
-    next_id: u64,
-}
-
-impl Switch {
-    /// Create a new empty switch.
-    pub fn new() -> Self {
-        Self {
-            connections: HashMap::new(),
-            mac_to_conn: HashMap::new(),
-            router: None,
-            next_id: 0,
-        }
-    }
-
-    /// Register a new VM connection.
-    ///
-    /// Returns the assigned connection ID, or None if a router is already
-    /// connected and this is another router.
-    pub fn register(&mut self, name: String, role: VmRole, mac: Mac) -> Option<ConnectionId> {
-        // Reject second router
-        if role == VmRole::Router && self.router.is_some() {
-            return None;
-        }
-
-        // Reject duplicate MAC address
-        if self.mac_to_conn.contains_key(&mac) {
-            return None;
-        }
-
-        let id = ConnectionId::new(self.next_id);
-        self.next_id += 1;
-
-        self.connections.insert(id, Connection { name, role, mac });
-        self.mac_to_conn.insert(mac, id);
-
-        if role == VmRole::Router {
-            self.router = Some(id);
-        }
-
-        Some(id)
-    }
-
-    /// Unregister a VM connection.
-    pub fn unregister(&mut self, id: ConnectionId) {
-        if let Some(conn) = self.connections.remove(&id) {
-            debug!(name = %conn.name, role = ?conn.role, mac = %conn.mac, "unregistered connection");
-            self.mac_to_conn.remove(&conn.mac);
-
-            if conn.role == VmRole::Router {
-                self.router = None;
-            }
-        }
-    }
-
-    /// Determine how to forward a frame.
-    ///
-    /// # Arguments
-    /// * `sender` - Connection ID of the sender
-    /// * `source_mac` - Source MAC from the frame
-    /// * `dest_mac` - Destination MAC from the frame
-    pub fn forward(&self, sender: ConnectionId, source_mac: Mac, dest_mac: Mac) -> ForwardDecision {
-        // Get sender info
-        let sender_conn = match self.connections.get(&sender) {
-            Some(c) => c,
-            None => return ForwardDecision::Drop(DropReason::UnknownSender),
-        };
-
-        // Validate source MAC
-        if source_mac != sender_conn.mac {
-            return ForwardDecision::Drop(DropReason::SourceMacMismatch {
-                expected: sender_conn.mac,
-                actual: source_mac,
-            });
-        }
-
-        match sender_conn.role {
-            VmRole::Client => self.forward_from_client(dest_mac),
-            VmRole::Router => self.forward_from_router(sender, dest_mac),
-        }
-    }
-
-    fn forward_from_client(&self, dest_mac: Mac) -> ForwardDecision {
-        // Get router
-        let router_id = match self.router {
-            Some(id) => id,
-            None => return ForwardDecision::Drop(DropReason::NoRouter),
-        };
-
-        let router_conn = self.connections.get(&router_id).unwrap();
-
-        // Clients can only send to router, broadcast, or multicast
-        if dest_mac == router_conn.mac || dest_mac.is_broadcast() || dest_mac.is_multicast() {
-            ForwardDecision::Unicast(router_id)
-        } else {
-            ForwardDecision::Drop(DropReason::ClientViolation { destination: dest_mac })
-        }
-    }
-
-    fn forward_from_router(&self, sender: ConnectionId, dest_mac: Mac) -> ForwardDecision {
-        // Broadcast or multicast goes to all clients
-        if dest_mac.is_broadcast() || dest_mac.is_multicast() {
-            let client_ids: Vec<ConnectionId> = self.connections
-                .iter()
-                .filter(|(id, conn)| **id != sender && conn.role == VmRole::Client)
-                .map(|(id, _)| *id)
-                .collect();
-
-            return ForwardDecision::Multicast(client_ids);
-        }
-
-        // Unicast to specific client
-        match self.mac_to_conn.get(&dest_mac) {
-            Some(id) if *id != sender => {
-                // Verify it's a client
-                if let Some(conn) = self.connections.get(id) {
-                    if conn.role == VmRole::Client {
-                        return
ForwardDecision::Unicast(*id); - } - } - ForwardDecision::Drop(DropReason::UnknownDestination { destination: dest_mac }) - } - _ => ForwardDecision::Drop(DropReason::UnknownDestination { destination: dest_mac }), - } - } -} - -impl Default for Switch { - fn default() -> Self { - Self::new() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - // Helper to create test MACs - fn mac(s: &str) -> Mac { - Mac::parse(s).unwrap() - } - - #[test] - fn register_client() { - let mut switch = Switch::new(); - let id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")); - - assert!(id.is_some()); - } - - #[test] - fn register_router() { - let mut switch = Switch::new(); - let id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")); - - assert!(id.is_some()); - } - - #[test] - fn register_second_router_fails() { - let mut switch = Switch::new(); - let id1 = switch.register("gateway1".into(), VmRole::Router, mac("11:22:33:44:55:66")); - let id2 = switch.register("gateway2".into(), VmRole::Router, mac("aa:bb:cc:dd:ee:ff")); - - assert!(id1.is_some()); - assert!(id2.is_none(), "Second router should be rejected"); - } - - #[test] - fn register_duplicate_mac_fails() { - let mut switch = Switch::new(); - let id1 = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")); - let id2 = switch.register("shopping".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")); - - assert!(id1.is_some()); - assert!(id2.is_none(), "Duplicate MAC should be rejected"); - } - - #[test] - fn register_multiple_clients() { - let mut switch = Switch::new(); - let id1 = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:01")); - let id2 = switch.register("shopping".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:02")); - - assert!(id1.is_some()); - assert!(id2.is_some()); - assert_ne!(id1, id2); - } - - #[test] - fn unregister_client() { - let mut switch = Switch::new(); - let id = switch.register("banking".into(), 
VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap(); - - switch.unregister(id); - - // Should be able to register another client with same MAC - let id2 = switch.register("banking2".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")); - assert!(id2.is_some()); - } - - #[test] - fn unregister_router_allows_new_router() { - let mut switch = Switch::new(); - let id = switch.register("gateway1".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap(); - - switch.unregister(id); - - // Should now be able to register a new router - let id2 = switch.register("gateway2".into(), VmRole::Router, mac("aa:bb:cc:dd:ee:ff")); - assert!(id2.is_some()); - } - - #[test] - fn forward_rejects_source_mac_mismatch() { - let mut switch = Switch::new(); - let _router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap(); - let client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap(); - - // Client sends with wrong source MAC - let result = switch.forward(client_id, mac("00:00:00:00:00:01"), mac("11:22:33:44:55:66")); - - match result { - ForwardDecision::Drop(DropReason::SourceMacMismatch { expected, actual }) => { - assert_eq!(expected, mac("aa:bb:cc:dd:ee:ff")); - assert_eq!(actual, mac("00:00:00:00:00:01")); - } - _ => panic!("Expected SourceMacMismatch, got {:?}", result), - } - } - - #[test] - fn forward_client_to_router_unicast() { - let mut switch = Switch::new(); - let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap(); - let client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap(); - - let result = switch.forward(client_id, mac("aa:bb:cc:dd:ee:ff"), mac("11:22:33:44:55:66")); - - assert_eq!(result, ForwardDecision::Unicast(router_id)); - } - - #[test] - fn forward_client_to_router_broadcast() { - let mut switch = Switch::new(); - let router_id = switch.register("gateway".into(), VmRole::Router, 
mac("11:22:33:44:55:66")).unwrap(); - let client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap(); - - let result = switch.forward(client_id, mac("aa:bb:cc:dd:ee:ff"), mac("ff:ff:ff:ff:ff:ff")); - - // Broadcast from client goes to router only - assert_eq!(result, ForwardDecision::Unicast(router_id)); - } - - #[test] - fn forward_client_to_router_multicast() { - let mut switch = Switch::new(); - let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap(); - let client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap(); - - // IPv4 multicast MAC - let result = switch.forward(client_id, mac("aa:bb:cc:dd:ee:ff"), mac("01:00:5e:00:00:01")); - - // Multicast from client goes to router only - assert_eq!(result, ForwardDecision::Unicast(router_id)); - } - - #[test] - fn forward_client_violation_to_other_client() { - let mut switch = Switch::new(); - let _router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap(); - let client1_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:01")).unwrap(); - let _client2_id = switch.register("shopping".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:02")).unwrap(); - - // Client tries to send directly to another client - let result = switch.forward(client1_id, mac("aa:bb:cc:dd:ee:01"), mac("aa:bb:cc:dd:ee:02")); - - match result { - ForwardDecision::Drop(DropReason::ClientViolation { destination }) => { - assert_eq!(destination, mac("aa:bb:cc:dd:ee:02")); - } - _ => panic!("Expected ClientViolation, got {:?}", result), - } - } - - #[test] - fn forward_client_no_router() { - let mut switch = Switch::new(); - let client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap(); - - let result = switch.forward(client_id, mac("aa:bb:cc:dd:ee:ff"), mac("ff:ff:ff:ff:ff:ff")); - - assert_eq!(result, 
ForwardDecision::Drop(DropReason::NoRouter)); - } - - #[test] - fn forward_router_to_client_unicast() { - let mut switch = Switch::new(); - let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap(); - let client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap(); - - let result = switch.forward(router_id, mac("11:22:33:44:55:66"), mac("aa:bb:cc:dd:ee:ff")); - - assert_eq!(result, ForwardDecision::Unicast(client_id)); - } - - #[test] - fn forward_router_broadcast_to_all_clients() { - let mut switch = Switch::new(); - let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap(); - let client1_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:01")).unwrap(); - let client2_id = switch.register("shopping".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:02")).unwrap(); - - let result = switch.forward(router_id, mac("11:22:33:44:55:66"), mac("ff:ff:ff:ff:ff:ff")); - - match result { - ForwardDecision::Multicast(ids) => { - assert_eq!(ids.len(), 2); - assert!(ids.contains(&client1_id)); - assert!(ids.contains(&client2_id)); - } - _ => panic!("Expected Multicast, got {:?}", result), - } - } - - #[test] - fn forward_router_multicast_to_all_clients() { - let mut switch = Switch::new(); - let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap(); - let client1_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:01")).unwrap(); - let client2_id = switch.register("shopping".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:02")).unwrap(); - - let result = switch.forward(router_id, mac("11:22:33:44:55:66"), mac("01:00:5e:00:00:01")); - - match result { - ForwardDecision::Multicast(ids) => { - assert_eq!(ids.len(), 2); - assert!(ids.contains(&client1_id)); - assert!(ids.contains(&client2_id)); - } - _ => panic!("Expected Multicast, got {:?}", result), - } - } - - #[test] - fn 
forward_router_unknown_destination() { - let mut switch = Switch::new(); - let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap(); - let _client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap(); - - // Router sends to unknown MAC - let result = switch.forward(router_id, mac("11:22:33:44:55:66"), mac("00:00:00:00:00:01")); - - match result { - ForwardDecision::Drop(DropReason::UnknownDestination { destination }) => { - assert_eq!(destination, mac("00:00:00:00:00:01")); - } - _ => panic!("Expected UnknownDestination, got {:?}", result), - } - } - - #[test] - fn forward_router_broadcast_no_clients() { - let mut switch = Switch::new(); - let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap(); - - let result = switch.forward(router_id, mac("11:22:33:44:55:66"), mac("ff:ff:ff:ff:ff:ff")); - - // Empty multicast is valid (no clients to send to) - match result { - ForwardDecision::Multicast(ids) => { - assert!(ids.is_empty()); - } - _ => panic!("Expected empty Multicast, got {:?}", result), - } - } -} diff --git a/vm-switch/tests/buffer_exchange.rs b/vm-switch/tests/buffer_exchange.rs new file mode 100644 index 0000000..ce649bb --- /dev/null +++ b/vm-switch/tests/buffer_exchange.rs @@ -0,0 +1,323 @@ +//! Integration tests for ring buffer exchange between processes. 
+ +use nix::sys::wait::{waitpid, WaitStatus}; +use nix::unistd::{fork, ForkResult}; +use serial_test::serial; +use std::os::fd::AsRawFd; +use vm_switch::control::{ChildToMain, ControlChannel, MainToChild}; +use vm_switch::mac::Mac; +use vm_switch::ring::{Consumer, Producer}; + +#[test] +#[serial] +fn child_sends_ready_on_startup() { + let (main_end, child_end) = ControlChannel::pair().expect("should create pair"); + let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]); + + let temp_dir = tempfile::tempdir().expect("tempdir"); + let socket_path = temp_dir.path().join("test.sock"); + + match unsafe { fork() } { + Ok(ForkResult::Parent { child }) => { + drop(child_end); + + // Wait for Ready + let (msg, fds): (ChildToMain, _) = main_end + .recv_with_fds_typed() + .expect("should receive"); + + match msg { + ChildToMain::Ready => {} + _ => panic!("expected Ready, got {:?}", msg), + } + + assert!(fds.is_empty(), "Ready should have no FDs"); + + // Close main end to trigger child exit + drop(main_end); + + // Wait for child + let status = waitpid(child, None).expect("waitpid failed"); + match status { + WaitStatus::Exited(_, code) => { + assert_eq!(code, 0, "child should exit cleanly"); + } + other => panic!("unexpected status: {:?}", other), + } + } + Ok(ForkResult::Child) => { + drop(main_end); + let control_fd = child_end.into_fd(); + vm_switch::child::run_child_process("test-vm", mac, control_fd, &socket_path, vm_switch::SeccompMode::Disabled); + } + Err(e) => panic!("fork failed: {}", e), + } +} + +#[test] +#[serial] +fn child_creates_ingress_buffer_on_request() { + let (main_end, child_end) = ControlChannel::pair().expect("should create pair"); + let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]); + + let temp_dir = tempfile::tempdir().expect("tempdir"); + let socket_path = temp_dir.path().join("test.sock"); + + match unsafe { fork() } { + Ok(ForkResult::Parent { child }) => { + drop(child_end); + + // Wait for Ready + let (msg, _): 
(ChildToMain, _) = main_end + .recv_with_fds_typed() + .expect("should receive"); + assert!(matches!(msg, ChildToMain::Ready)); + + // Request buffer for a peer + let peer_name = "router".to_string(); + let peer_mac = [0x11, 0x22, 0x33, 0x44, 0x55, 0x66]; + let msg = MainToChild::GetBuffer { + peer_name: peer_name.clone(), + peer_mac, + }; + main_end.send(&msg).expect("send GetBuffer"); + + // Wait for BufferReady + let (msg, fds): (ChildToMain, _) = main_end + .recv_with_fds_typed() + .expect("should receive"); + + match msg { + ChildToMain::BufferReady { peer_name: name } => { + assert_eq!(name, peer_name); + } + _ => panic!("expected BufferReady"), + } + + assert_eq!(fds.len(), 2, "should receive memfd and eventfd"); + + // Create producer from received FDs and verify we can write + let mut fds = fds.into_iter(); + let memfd = fds.next().unwrap(); + let eventfd = fds.next().unwrap(); + + let producer = Producer::from_fds(memfd, eventfd) + .expect("should create producer"); + + // Push a frame + assert!(producer.push(&[1, 2, 3, 4, 5])); + + // Close main end to trigger child exit + drop(main_end); + + // Wait for child + let status = waitpid(child, None).expect("waitpid failed"); + match status { + WaitStatus::Exited(_, code) => { + assert_eq!(code, 0, "child should exit cleanly"); + } + other => panic!("unexpected status: {:?}", other), + } + } + Ok(ForkResult::Child) => { + drop(main_end); + let control_fd = child_end.into_fd(); + vm_switch::child::run_child_process("test-vm", mac, control_fd, &socket_path, vm_switch::SeccompMode::Disabled); + } + Err(e) => panic!("fork failed: {}", e), + } +} + +#[test] +#[serial] +fn new_protocol_buffer_exchange() { + // Test the new protocol where: + // 1. Child sends Ready + // 2. Main sends GetBuffer + // 3. Child creates consumer and sends BufferReady with FDs + // 4. Main forwards FDs to peer as PutBuffer + // 5. 
Peer creates producer and can write + + // Simulate two "children" with control channels + let (main_a, child_a) = ControlChannel::pair().expect("pair A"); + let (main_b, child_b) = ControlChannel::pair().expect("pair B"); + + let name_a = "client_a".to_string(); + let name_b = "router".to_string(); + let mac_a = [0xaa, 0, 0, 0, 0, 1]; + let mac_b = [0xbb, 0, 0, 0, 0, 2]; + + // Step 1: Both children send Ready + child_a.send(&ChildToMain::Ready).expect("A ready"); + child_b.send(&ChildToMain::Ready).expect("B ready"); + + // Main receives Ready from both + let (msg_a, _): (ChildToMain, _) = main_a.recv_with_fds_typed().expect("recv A ready"); + let (msg_b, _): (ChildToMain, _) = main_b.recv_with_fds_typed().expect("recv B ready"); + assert!(matches!(msg_a, ChildToMain::Ready)); + assert!(matches!(msg_b, ChildToMain::Ready)); + + // Step 2: Main requests buffer from A for B (A will be consumer of data from B) + let get_buffer_msg = MainToChild::GetBuffer { + peer_name: name_b.clone(), + peer_mac: mac_b, + }; + main_a.send(&get_buffer_msg).expect("send GetBuffer to A"); + + // Step 3: A receives GetBuffer + let (msg, _): (MainToChild, _) = child_a.recv_with_fds_typed().expect("A recv GetBuffer"); + match msg { + MainToChild::GetBuffer { peer_name, peer_mac } => { + assert_eq!(peer_name, name_b); + assert_eq!(peer_mac, mac_b); + } + _ => panic!("expected GetBuffer"), + } + + // A creates consumer (ingress buffer) and sends BufferReady + let consumer_a = Consumer::new().expect("consumer A"); + let buffer_ready = ChildToMain::BufferReady { peer_name: name_b.clone() }; + child_a.send_with_fds_typed(&buffer_ready, &[ + consumer_a.memfd().as_raw_fd(), + consumer_a.eventfd().as_raw_fd(), + ]).expect("A send BufferReady"); + + // Main receives BufferReady with FDs + let (msg, fds_from_a): (ChildToMain, _) = main_a.recv_with_fds_typed().expect("main recv BufferReady"); + match msg { + ChildToMain::BufferReady { peer_name } => { + assert_eq!(peer_name, name_b); + } + _ => 
panic!("expected BufferReady"),
+    }
+    assert_eq!(fds_from_a.len(), 2);
+
+    // Step 4: Main sends PutBuffer to B with A's ingress buffer
+    // B will use this as egress to A
+    // broadcast=false because A is a client, not the router
+    let put_buffer_msg = MainToChild::PutBuffer {
+        peer_name: name_a.clone(),
+        peer_mac: mac_a,
+        broadcast: false, // A is client, not router
+    };
+    main_b.send_with_fds_typed(&put_buffer_msg, &[
+        fds_from_a[0].as_raw_fd(),
+        fds_from_a[1].as_raw_fd(),
+    ]).expect("send PutBuffer to B");
+
+    // Step 5: B receives PutBuffer with FDs
+    let (msg, fds_for_b): (MainToChild, _) = child_b.recv_with_fds_typed().expect("B recv PutBuffer");
+    match msg {
+        MainToChild::PutBuffer { peer_name, peer_mac, broadcast } => {
+            assert_eq!(peer_name, name_a);
+            assert_eq!(peer_mac, mac_a);
+            assert!(!broadcast);
+        }
+        _ => panic!("expected PutBuffer"),
+    }
+    assert_eq!(fds_for_b.len(), 2);
+
+    // B creates producer (egress buffer) from received FDs
+    let mut fds_b = fds_for_b.into_iter();
+    let producer_b = Producer::from_fds(fds_b.next().unwrap(), fds_b.next().unwrap())
+        .expect("producer B");
+
+    // Step 6: B writes to producer, A can read from consumer
+    producer_b.push(&[10, 20, 30, 40, 50]);
+
+    let data = consumer_a.pop().expect("should pop");
+    assert_eq!(data, vec![10, 20, 30, 40, 50]);
+}
+
+#[test]
+#[serial]
+fn bidirectional_new_protocol_exchange() {
+    // Two VMs exchange buffers and can communicate both ways using new protocol
+
+    let (main_a, child_a) = ControlChannel::pair().expect("pair A");
+    let (main_b, child_b) = ControlChannel::pair().expect("pair B");
+
+    let name_a = "client_a".to_string();
+    let name_b = "router".to_string();
+    let mac_a = [0xaa, 0, 0, 0, 0, 1];
+    let mac_b = [0xbb, 0, 0, 0, 0, 2];
+
+    // Both send Ready
+    child_a.send(&ChildToMain::Ready).expect("A ready");
+    child_b.send(&ChildToMain::Ready).expect("B ready");
+
+    let _: (ChildToMain, _) = main_a.recv_with_fds_typed().expect("recv A ready");
+    let _:
(ChildToMain, _) = main_b.recv_with_fds_typed().expect("recv B ready"); + + // Request ingress buffers from both sides + // A creates ingress for B + main_a.send(&MainToChild::GetBuffer { + peer_name: name_b.clone(), + peer_mac: mac_b, + }).expect("GetBuffer A->B"); + + // B creates ingress for A + main_b.send(&MainToChild::GetBuffer { + peer_name: name_a.clone(), + peer_mac: mac_a, + }).expect("GetBuffer B->A"); + + // A receives GetBuffer, creates consumer, sends BufferReady + let _: (MainToChild, _) = child_a.recv_with_fds_typed().expect("A recv GetBuffer"); + let consumer_a = Consumer::new().expect("consumer A"); + child_a.send_with_fds_typed(&ChildToMain::BufferReady { peer_name: name_b.clone() }, &[ + consumer_a.memfd().as_raw_fd(), + consumer_a.eventfd().as_raw_fd(), + ]).expect("A BufferReady"); + + // B receives GetBuffer, creates consumer, sends BufferReady + let _: (MainToChild, _) = child_b.recv_with_fds_typed().expect("B recv GetBuffer"); + let consumer_b = Consumer::new().expect("consumer B"); + child_b.send_with_fds_typed(&ChildToMain::BufferReady { peer_name: name_a.clone() }, &[ + consumer_b.memfd().as_raw_fd(), + consumer_b.eventfd().as_raw_fd(), + ]).expect("B BufferReady"); + + // Main receives BufferReady from both + let (_, fds_from_a): (ChildToMain, _) = main_a.recv_with_fds_typed().expect("main recv A BufferReady"); + let (_, fds_from_b): (ChildToMain, _) = main_b.recv_with_fds_typed().expect("main recv B BufferReady"); + + // Cross-send: A's ingress becomes B's egress to A, and vice versa + // Send A's buffer to B (B is router, so broadcast=true when sending TO router) + main_b.send_with_fds_typed(&MainToChild::PutBuffer { + peer_name: name_a.clone(), + peer_mac: mac_a, + broadcast: false, // A is not router + }, &[ + fds_from_a[0].as_raw_fd(), + fds_from_a[1].as_raw_fd(), + ]).expect("PutBuffer A to B"); + + // Send B's buffer to A (A is client, broadcast=true because B is router) + main_a.send_with_fds_typed(&MainToChild::PutBuffer { + 
peer_name: name_b.clone(), + peer_mac: mac_b, + broadcast: true, // B is router + }, &[ + fds_from_b[0].as_raw_fd(), + fds_from_b[1].as_raw_fd(), + ]).expect("PutBuffer B to A"); + + // Each side receives PutBuffer and creates producer + let (_, fds_b_egress): (MainToChild, _) = child_b.recv_with_fds_typed().expect("B recv PutBuffer"); + let (_, fds_a_egress): (MainToChild, _) = child_a.recv_with_fds_typed().expect("A recv PutBuffer"); + + let mut fds = fds_b_egress.into_iter(); + let producer_b = Producer::from_fds(fds.next().unwrap(), fds.next().unwrap()).expect("producer B"); + + let mut fds = fds_a_egress.into_iter(); + let producer_a = Producer::from_fds(fds.next().unwrap(), fds.next().unwrap()).expect("producer A"); + + // B sends to A + producer_b.push(&[1, 2, 3]); + assert_eq!(consumer_a.pop().unwrap(), vec![1, 2, 3]); + + // A sends to B + producer_a.push(&[4, 5, 6]); + assert_eq!(consumer_b.pop().unwrap(), vec![4, 5, 6]); +} diff --git a/vm-switch/tests/crash_handling.rs b/vm-switch/tests/crash_handling.rs new file mode 100644 index 0000000..86d3046 --- /dev/null +++ b/vm-switch/tests/crash_handling.rs @@ -0,0 +1,104 @@ +//! Integration tests for crash detection and cleanup. + +use nix::sys::signal::{kill, Signal}; +use nix::sys::wait::{waitpid, WaitPidFlag, WaitStatus}; +use nix::unistd::{fork, ForkResult}; +use std::time::Duration; +use vm_switch::control::{ControlChannel, MainToChild}; + +/// Test that main detects child crash via waitpid. 
/// Test that main detects child crash via waitpid.
#[test]
fn main_detects_child_crash() {
    let (main_end, child_end) = ControlChannel::pair().expect("pair");

    match unsafe { fork() } {
        Ok(ForkResult::Parent { child }) => {
            drop(child_end);

            // Give the child a moment to enter its receive loop.
            std::thread::sleep(Duration::from_millis(50));

            kill(child, Signal::SIGKILL).expect("kill");

            let status = waitpid(child, None).expect("waitpid");
            assert!(matches!(status, WaitStatus::Signaled(_, Signal::SIGKILL, _)));

            drop(main_end);
        }
        Ok(ForkResult::Child) => {
            drop(main_end);
            let control = ControlChannel::from_fd(child_end.into_fd());
            // Block on the control channel until the parent kills us.
            loop {
                match control.recv::<MainToChild>() {
                    Err(_) => break,
                    _ => continue,
                }
            }
            std::process::exit(0);
        }
        Err(e) => panic!("fork failed: {}", e),
    }
}

/// Test that RemovePeer message is delivered correctly.
#[test]
fn remove_peer_message_delivery() {
    let (main_end, child_end) = ControlChannel::pair().expect("pair");

    let crashed_vm = "crashed_client".to_string();
    let msg = MainToChild::RemovePeer { peer_name: crashed_vm.clone() };

    main_end.send(&msg).expect("send");

    let received: MainToChild = child_end.recv().expect("recv");
    match received {
        MainToChild::RemovePeer { peer_name } => {
            assert_eq!(peer_name, crashed_vm);
        }
        _ => panic!("expected RemovePeer"),
    }
}

/// Test that children exit cleanly when control channel closes.
#[test]
fn children_exit_on_control_close() {
    let (main_end, child_end) = ControlChannel::pair().expect("pair");

    match unsafe { fork() } {
        Ok(ForkResult::Parent { child }) => {
            drop(child_end);

            std::thread::sleep(Duration::from_millis(50));

            // Closing main's end is the shutdown signal.
            drop(main_end);

            // Poll with WNOHANG so a hung child can be killed after a timeout
            // instead of hanging the whole test run.
            let deadline = std::time::Instant::now() + Duration::from_secs(2);
            loop {
                match waitpid(child, Some(WaitPidFlag::WNOHANG)) {
                    Ok(WaitStatus::Exited(_, 0)) => return,
                    Ok(WaitStatus::Exited(_, code)) => {
                        panic!("Child exited with code {}", code);
                    }
                    Ok(_) => {
                        if std::time::Instant::now() > deadline {
                            kill(child, Signal::SIGKILL).ok();
                            panic!("Child did not exit within timeout");
                        }
                        std::thread::sleep(Duration::from_millis(50));
                    }
                    Err(e) => panic!("waitpid error: {}", e),
                }
            }
        }
        Ok(ForkResult::Child) => {
            drop(main_end);
            let control = ControlChannel::from_fd(child_end.into_fd());
            loop {
                match control.recv::<MainToChild>() {
                    Err(_) => std::process::exit(0),
                    _ => continue,
                }
            }
        }
        Err(e) => panic!("fork failed: {}", e),
    }
}
+ +use nix::sys::wait::{waitpid, WaitStatus}; +use nix::unistd::{fork, ForkResult}; +use serial_test::serial; +use std::time::Duration; +use vm_switch::control::ControlChannel; +use vm_switch::mac::Mac; + +#[test] +#[serial] +fn child_exits_when_control_channel_closes() { + // Create control channel + let (main_end, child_end) = ControlChannel::pair().expect("should create pair"); + let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0x01]); + + let temp_dir = tempfile::tempdir().expect("tempdir"); + let socket_path = temp_dir.path().join("test.sock"); + + match unsafe { fork() } { + Ok(ForkResult::Parent { child }) => { + // Parent: drop child's end, keep main's end + drop(child_end); + + // Give child time to start and send Ready + std::thread::sleep(Duration::from_millis(100)); + + // Drain any Ready message + let _ = main_end.recv_with_fds_typed::(); + + // Close main's end - should cause child to exit + drop(main_end); + + // Wait for child to exit + let status = waitpid(child, None).expect("waitpid failed"); + match status { + WaitStatus::Exited(_, code) => { + assert_eq!(code, 0, "child should exit with code 0"); + } + other => panic!("unexpected wait status: {:?}", other), + } + } + Ok(ForkResult::Child) => { + // Child: drop main's end (we don't need parent's socket) + drop(main_end); + + // Run child entry point - this should exit when control closes + let control_fd = child_end.into_fd(); + vm_switch::child::run_child_process("test-vm", mac, control_fd, &socket_path, vm_switch::SeccompMode::Disabled); + } + Err(e) => panic!("fork failed: {}", e), + } +} + +#[test] +#[serial] +fn child_processes_messages_before_exit() { + use vm_switch::control::MainToChild; + + let (main_end, child_end) = ControlChannel::pair().expect("should create pair"); + let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0x02]); + + let temp_dir = tempfile::tempdir().expect("tempdir"); + let socket_path = temp_dir.path().join("test.sock"); + + match unsafe { fork() } { + 
Ok(ForkResult::Parent { child }) => { + // Parent: drop child's end + drop(child_end); + + // Wait for Ready from child + std::thread::sleep(Duration::from_millis(100)); + let _ = main_end.recv_with_fds_typed::(); + + // Send a RemovePeer message + let msg = MainToChild::RemovePeer { + peer_name: "some-peer".to_string(), + }; + main_end.send(&msg).expect("should send"); + + // Close channel (drop main_end) + drop(main_end); + + // Child should exit cleanly after processing message + let status = waitpid(child, None).expect("waitpid failed"); + match status { + WaitStatus::Exited(_, code) => { + assert_eq!(code, 0, "child should exit with code 0"); + } + other => panic!("unexpected wait status: {:?}", other), + } + } + Ok(ForkResult::Child) => { + drop(main_end); + let control_fd = child_end.into_fd(); + vm_switch::child::run_child_process("test-vm", mac, control_fd, &socket_path, vm_switch::SeccompMode::Disabled); + } + Err(e) => panic!("fork failed: {}", e), + } +} + +#[test] +#[serial] +fn multiple_children_shut_down_gracefully() { + // Fork 3 children + let mut children = Vec::new(); + + for i in 0..3u8 { + let (main_end, child_end) = ControlChannel::pair().expect("should create pair"); + let vm_name = format!("test-vm-{}", i); + let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, i]); + + let temp_dir = tempfile::tempdir().expect("tempdir"); + let socket_path = temp_dir.path().join("test.sock"); + + match unsafe { fork() } { + Ok(ForkResult::Parent { child }) => { + drop(child_end); + children.push((child, main_end, vm_name, temp_dir)); + } + Ok(ForkResult::Child) => { + drop(main_end); + let control_fd = child_end.into_fd(); + vm_switch::child::run_child_process(&vm_name, mac, control_fd, &socket_path, vm_switch::SeccompMode::Disabled); + } + Err(e) => panic!("fork failed: {}", e), + } + } + + // Give children time to start and send Ready + std::thread::sleep(Duration::from_millis(100)); + + // Drain Ready messages + for (_, control, _, _) in &children { 
+ let _ = control.recv_with_fds_typed::(); + } + + // Collect PIDs and drop control channels to signal shutdown + let pids: Vec<_> = children + .into_iter() + .map(|(pid, control, _, _temp_dir)| { + drop(control); + pid + }) + .collect(); + + // Wait for all children + for pid in pids { + let status = waitpid(pid, None).expect("waitpid failed"); + match status { + WaitStatus::Exited(_, code) => { + assert_eq!(code, 0, "child should exit with code 0"); + } + other => panic!("unexpected wait status: {:?}", other), + } + } +} diff --git a/vm-switch/tests/packet_flow.rs b/vm-switch/tests/packet_flow.rs new file mode 100644 index 0000000..e000365 --- /dev/null +++ b/vm-switch/tests/packet_flow.rs @@ -0,0 +1,63 @@ +//! Integration tests for packet forwarding. + +use std::os::fd::{AsRawFd, FromRawFd, OwnedFd}; +use vm_switch::child::PacketForwarder; +use vm_switch::mac::Mac; +use vm_switch::ring::{Consumer, Producer}; + +fn make_frame(dest: [u8; 6], src: [u8; 6]) -> Vec { + let mut frame = vec![0u8; 64]; + frame[0..6].copy_from_slice(&dest); + frame[6..12].copy_from_slice(&src); + frame +} + +#[test] +fn forwarder_validates_source_mac() { + let our_mac = Mac::from_bytes([1, 0, 0, 0, 0, 1]); + let peer_mac = [2, 0, 0, 0, 0, 2]; + let mut forwarder = PacketForwarder::new(our_mac); + + // Set up an egress buffer for broadcast + let consumer = Consumer::new().expect("consumer"); + let memfd = unsafe { OwnedFd::from_raw_fd(libc::dup(consumer.memfd().as_raw_fd())) }; + let eventfd = unsafe { OwnedFd::from_raw_fd(libc::dup(consumer.eventfd().as_raw_fd())) }; + let producer = Producer::from_fds(memfd, eventfd).expect("producer"); + forwarder.add_egress("router".to_string(), peer_mac, producer, true); + + // Correct source MAC - broadcast frame should be sent to egress with broadcast=true + let good_frame = make_frame([0xff; 6], our_mac.bytes()); + assert!(forwarder.forward_tx(&good_frame)); + + // Wrong source MAC - should be dropped + let bad_frame = make_frame([0xff; 6], [9, 9, 
9, 9, 9, 9]); + assert!(!forwarder.forward_tx(&bad_frame)); +} + +#[test] +fn forwarder_ingress_validates_peer_mac() { + let our_mac = Mac::from_bytes([1, 0, 0, 0, 0, 1]); + let peer_mac = [2, 0, 0, 0, 0, 2]; + let spoofed = [3, 0, 0, 0, 0, 3]; + + let mut forwarder = PacketForwarder::new(our_mac); + + // Set up ingress from peer + let producer = Producer::new().expect("producer"); + let memfd = unsafe { OwnedFd::from_raw_fd(libc::dup(producer.memfd().as_raw_fd())) }; + let eventfd = unsafe { OwnedFd::from_raw_fd(libc::dup(producer.eventfd().as_raw_fd())) }; + let consumer = Consumer::from_fds(memfd, eventfd).expect("consumer"); + forwarder.add_ingress("router".to_string(), peer_mac, consumer); + + // Good frame from peer + let good = make_frame(our_mac.bytes(), peer_mac); + producer.push(&good); + + // Spoofed frame (wrong source) + let bad = make_frame(our_mac.bytes(), spoofed); + producer.push(&bad); + + let received = forwarder.poll_ingress(); + assert_eq!(received.len(), 1); + assert_eq!(&received[0][6..12], &peer_mac); +} diff --git a/vm-switch/tests/sandbox_full.rs b/vm-switch/tests/sandbox_full.rs new file mode 100644 index 0000000..de0ab5d --- /dev/null +++ b/vm-switch/tests/sandbox_full.rs @@ -0,0 +1,256 @@ +//! Integration tests for full sandbox isolation. + +use nix::sys::wait::{waitpid, WaitStatus}; +use nix::unistd::{fork, ForkResult, Uid}; +use std::fs; +use std::path::Path; +use vm_switch::sandbox::{apply_sandbox, SandboxResult}; + +/// Helper to run test in forked child process. 
+fn run_in_fork(test_fn: F) { + if Uid::current().is_root() { + eprintln!("Skipping test: already running as root"); + return; + } + + match unsafe { fork() } { + Ok(ForkResult::Parent { child }) => { + let status = waitpid(child, None).unwrap(); + match status { + WaitStatus::Exited(_, 0) => {} + other => panic!("Child failed: {:?}", other), + } + } + Ok(ForkResult::Child) => { + let result = std::panic::catch_unwind(test_fn); + match &result { + Err(e) => { + if let Some(s) = e.downcast_ref::<&str>() { + eprintln!("Child panic: {}", s); + } else if let Some(s) = e.downcast_ref::() { + eprintln!("Child panic: {}", s); + } else { + eprintln!("Child panic: unknown error"); + } + } + Ok(()) => {} + } + std::process::exit(if result.is_ok() { 0 } else { 1 }); + } + Err(e) => panic!("Fork failed: {}", e), + } +} + +/// Helper to handle SandboxResult in tests. +/// In the inner wrapper parent, propagates exit code. +/// In the sandboxed child, returns the config path. +fn apply_and_unwrap(config_path: &Path) -> std::path::PathBuf { + match apply_sandbox(config_path).expect("apply_sandbox failed") { + SandboxResult::Parent(code) => { + // Inner wrapper parent - propagate child's exit code + std::process::exit(code); + } + SandboxResult::Sandboxed(path) => path, + } +} + +#[test] +fn apply_sandbox_returns_config_path() { + let config_dir = tempfile::tempdir().unwrap(); + let config_path = config_dir.path().to_path_buf(); + + run_in_fork(move || { + let result = apply_and_unwrap(&config_path); + assert_eq!(result, Path::new("/config")); + }); +} + +#[test] +fn apply_sandbox_isolates_ipc_namespace() { + let config_dir = tempfile::tempdir().unwrap(); + let config_path = config_dir.path().to_path_buf(); + + // Get parent IPC namespace before fork + let parent_ipc = fs::read_link("/proc/self/ns/ipc").unwrap(); + + run_in_fork(move || { + apply_and_unwrap(&config_path); + + // The fact that apply_sandbox succeeded means IPC namespace was entered + // We verify isolation 
indirectly through the other tests + }); + + // Parent's namespace should be unchanged + let parent_ipc_after = fs::read_link("/proc/self/ns/ipc").unwrap(); + assert_eq!(parent_ipc, parent_ipc_after); +} + +#[test] +fn apply_sandbox_isolates_network_namespace() { + let config_dir = tempfile::tempdir().unwrap(); + let config_path = config_dir.path().to_path_buf(); + + run_in_fork(move || { + apply_and_unwrap(&config_path); + + // In empty network namespace, /sys/class/net should be empty or not exist + // Since /sys is not mounted in our minimal root, network is effectively isolated + assert!(!Path::new("/sys").exists(), "/sys should not exist in sandbox"); + }); +} + +#[test] +fn apply_sandbox_creates_complete_isolation() { + let config_dir = tempfile::tempdir().unwrap(); + let config_path = config_dir.path().to_path_buf(); + + // Create marker file + fs::write(config_path.join("marker.txt"), "isolated").unwrap(); + + run_in_fork(move || { + let new_config = apply_and_unwrap(&config_path); + + // Verify config path is correct + assert_eq!(new_config, Path::new("/config")); + + // Verify filesystem isolation + assert!(Path::new("/config/marker.txt").exists()); + assert!(!Path::new("/home").exists()); + assert!(!Path::new("/usr").exists()); + + // Verify /dev works + assert!(Path::new("/dev/null").exists()); + fs::write("/dev/null", "test").expect("write to /dev/null"); + + // Verify /tmp is writable + fs::write("/tmp/test.txt", "temp").expect("write to /tmp"); + }); +} + +#[test] +fn forked_children_inherit_sandbox() { + let config_dir = tempfile::tempdir().unwrap(); + let config_path = config_dir.path().to_path_buf(); + fs::write(config_path.join("shared.txt"), "parent").unwrap(); + + run_in_fork(move || { + apply_and_unwrap(&config_path); + + // Fork a child from within the sandbox + match unsafe { fork() } { + Ok(ForkResult::Parent { child }) => { + let status = waitpid(child, None).unwrap(); + match status { + WaitStatus::Exited(_, code) => { + assert_eq!(code, 
0, "Nested child should exit with 0"); + } + other => panic!("Nested child unexpected status: {:?}", other), + } + } + Ok(ForkResult::Child) => { + // Nested child inherits sandbox + let exists = Path::new("/config/shared.txt").exists(); + let no_home = !Path::new("/home").exists(); + + std::process::exit(if exists && no_home { 0 } else { 1 }); + } + Err(e) => panic!("Nested fork failed: {}", e), + } + }); +} + +#[test] +fn apply_sandbox_creates_pid_namespace() { + let config_dir = tempfile::tempdir().unwrap(); + let config_path = config_dir.path().to_path_buf(); + + run_in_fork(move || { + apply_and_unwrap(&config_path); + + // Read /proc/self/stat - should be PID 1 in the new namespace + let stat = fs::read_to_string("/proc/self/stat").expect("should read /proc/self/stat"); + assert!( + stat.starts_with("1 "), + "Should be PID 1 in namespace, got: {}", + stat + ); + }); +} + +#[test] +fn apply_sandbox_mounts_proc() { + let config_dir = tempfile::tempdir().unwrap(); + let config_path = config_dir.path().to_path_buf(); + + run_in_fork(move || { + apply_and_unwrap(&config_path); + + // /proc should be mounted and functional + assert!(Path::new("/proc").exists(), "/proc should exist"); + assert!( + Path::new("/proc/self").exists(), + "/proc/self should exist" + ); + + // Should be able to read process info + let cmdline = fs::read_to_string("/proc/self/cmdline"); + assert!(cmdline.is_ok(), "Should be able to read /proc/self/cmdline"); + }); +} + +#[test] +fn sandbox_children_get_sequential_pids() { + let config_dir = tempfile::tempdir().unwrap(); + let config_path = config_dir.path().to_path_buf(); + + run_in_fork(move || { + apply_and_unwrap(&config_path); + + // Fork a child and verify it gets PID 2 + match unsafe { fork() } { + Ok(ForkResult::Parent { child }) => { + let status = waitpid(child, None).unwrap(); + match status { + WaitStatus::Exited(_, code) => { + assert_eq!(code, 0, "Child should be PID 2"); + } + other => panic!("Child unexpected status: {:?}", 
other), + } + } + Ok(ForkResult::Child) => { + // Read our PID from /proc + let stat = fs::read_to_string("/proc/self/stat").unwrap(); + let pid_str = stat.split_whitespace().next().unwrap(); + let pid: i32 = pid_str.parse().unwrap(); + // Should be PID 2 (parent is 1) + std::process::exit(if pid == 2 { 0 } else { 1 }); + } + Err(e) => panic!("fork failed: {}", e), + } + }); +} + +#[test] +fn sandbox_proc_shows_only_sandboxed_processes() { + let config_dir = tempfile::tempdir().unwrap(); + let config_path = config_dir.path().to_path_buf(); + + run_in_fork(move || { + apply_and_unwrap(&config_path); + + // Count processes visible in /proc + let mut proc_pids = 0; + for entry in fs::read_dir("/proc").unwrap() { + let entry = entry.unwrap(); + let name = entry.file_name(); + let name_str = name.to_string_lossy(); + // Count numeric directories (PIDs) + if name_str.chars().all(|c| c.is_ascii_digit()) { + proc_pids += 1; + } + } + + // Should only see PID 1 (ourselves) + assert_eq!(proc_pids, 1, "Should only see 1 process in /proc, saw {}", proc_pids); + }); +} diff --git a/vm-switch/tests/sandbox_mount_ns.rs b/vm-switch/tests/sandbox_mount_ns.rs new file mode 100644 index 0000000..3905d43 --- /dev/null +++ b/vm-switch/tests/sandbox_mount_ns.rs @@ -0,0 +1,140 @@ +//! Integration tests for mount namespace and filesystem isolation. +//! +//! These tests require user namespace support (for unprivileged mount ns). + +use nix::unistd::Uid; +use std::fs; +use std::path::Path; +use vm_switch::sandbox::{enter_user_namespace, setup_filesystem_isolation}; + +/// Helper to run test in forked child process. 
+fn run_in_fork(test_fn: F) { + if Uid::current().is_root() { + eprintln!("Skipping test: already running as root"); + return; + } + + match unsafe { nix::unistd::fork() } { + Ok(nix::unistd::ForkResult::Parent { child }) => { + let status = nix::sys::wait::waitpid(child, None).unwrap(); + match status { + nix::sys::wait::WaitStatus::Exited(_, 0) => {} + other => panic!("Child failed: {:?}", other), + } + } + Ok(nix::unistd::ForkResult::Child) => { + let result = std::panic::catch_unwind(test_fn); + match &result { + Err(e) => { + if let Some(s) = e.downcast_ref::<&str>() { + eprintln!("Child panic: {}", s); + } else if let Some(s) = e.downcast_ref::() { + eprintln!("Child panic: {}", s); + } else { + eprintln!("Child panic: unknown error"); + } + } + Ok(()) => {} + } + std::process::exit(if result.is_ok() { 0 } else { 1 }); + } + Err(e) => panic!("Fork failed: {}", e), + } +} + +#[test] +fn filesystem_isolation_creates_minimal_root() { + // Create a temp config dir that we'll bind-mount + let config_dir = tempfile::tempdir().unwrap(); + let config_path = config_dir.path().to_path_buf(); + + // Create a marker file in config dir + fs::write(config_path.join("marker.txt"), "test").unwrap(); + + run_in_fork(move || { + enter_user_namespace().expect("enter_user_namespace failed"); + setup_filesystem_isolation(&config_path, true).expect("setup_filesystem_isolation failed"); + + // Verify /config exists and contains our marker + assert!(Path::new("/config/marker.txt").exists(), "/config/marker.txt should exist"); + let content = fs::read_to_string("/config/marker.txt").unwrap(); + assert_eq!(content, "test"); + + // Verify /proc mount point exists (may not be mounted without PID namespace) + assert!(Path::new("/proc").is_dir(), "/proc should be a directory"); + + // Verify /dev devices exist + assert!(Path::new("/dev/null").exists(), "/dev/null should exist"); + assert!(Path::new("/dev/zero").exists(), "/dev/zero should exist"); + 
assert!(Path::new("/dev/urandom").exists(), "/dev/urandom should exist"); + + // Verify /tmp exists + assert!(Path::new("/tmp").is_dir(), "/tmp should be a directory"); + }); +} + +#[test] +fn filesystem_isolation_hides_host_filesystem() { + let config_dir = tempfile::tempdir().unwrap(); + let config_path = config_dir.path().to_path_buf(); + + run_in_fork(move || { + enter_user_namespace().expect("enter_user_namespace failed"); + setup_filesystem_isolation(&config_path, true).expect("setup_filesystem_isolation failed"); + + // These paths should NOT exist (host filesystem hidden) + assert!(!Path::new("/home").exists(), "/home should not exist"); + assert!(!Path::new("/usr").exists(), "/usr should not exist"); + assert!(!Path::new("/etc").exists(), "/etc should not exist"); + assert!(!Path::new("/bin").exists(), "/bin should not exist"); + assert!(!Path::new("/nix").exists(), "/nix should not exist"); + }); +} + +#[test] +fn filesystem_isolation_config_dir_is_writable() { + let config_dir = tempfile::tempdir().unwrap(); + let config_path = config_dir.path().to_path_buf(); + + run_in_fork(move || { + enter_user_namespace().expect("enter_user_namespace failed"); + setup_filesystem_isolation(&config_path, true).expect("setup_filesystem_isolation failed"); + + // Should be able to create files in /config (for socket files) + let test_file = Path::new("/config/test-write.txt"); + fs::write(test_file, "writable").expect("should be able to write to /config"); + assert!(test_file.exists()); + + let content = fs::read_to_string(test_file).unwrap(); + assert_eq!(content, "writable"); + }); +} + +#[test] +fn dev_devices_are_functional() { + use std::io::Read; + + let config_dir = tempfile::tempdir().unwrap(); + let config_path = config_dir.path().to_path_buf(); + + run_in_fork(move || { + enter_user_namespace().expect("enter_user_namespace failed"); + setup_filesystem_isolation(&config_path, true).expect("setup_filesystem_isolation failed"); + + // /dev/null should accept 
writes and return nothing on read + fs::write("/dev/null", "discard this").expect("write to /dev/null"); + let content = fs::read("/dev/null").unwrap(); + assert!(content.is_empty(), "/dev/null should return empty on read"); + + // /dev/zero should return zeros (read limited amount) + let mut f = fs::File::open("/dev/zero").expect("open /dev/zero"); + let mut buf = [0xffu8; 16]; // Initialize with non-zero to verify it changes + f.read_exact(&mut buf).expect("read from /dev/zero"); + assert!(buf.iter().all(|&b| b == 0), "/dev/zero should return zeros"); + + // /dev/urandom should return random bytes (just check it's readable) + let mut f = fs::File::open("/dev/urandom").expect("open /dev/urandom"); + let mut buf = [0u8; 16]; + f.read_exact(&mut buf).expect("/dev/urandom should be readable"); + }); +} diff --git a/vm-switch/tests/sandbox_user_ns.rs b/vm-switch/tests/sandbox_user_ns.rs new file mode 100644 index 0000000..13b08f7 --- /dev/null +++ b/vm-switch/tests/sandbox_user_ns.rs @@ -0,0 +1,88 @@ +//! Integration tests for user namespace isolation. +//! +//! These tests require the ability to create user namespaces, +//! which is available to unprivileged users on most Linux systems. 
+ +use nix::unistd::{getgid, getuid, Uid}; +use std::fs; +use vm_switch::sandbox::enter_user_namespace; + +#[test] +fn enter_user_namespace_maps_to_root() { + // Skip if we're already root (CI environments sometimes run as root) + if Uid::current().is_root() { + eprintln!("Skipping test: already running as root"); + return; + } + + // Fork to avoid affecting the test process + match unsafe { nix::unistd::fork() } { + Ok(nix::unistd::ForkResult::Parent { child }) => { + // Parent waits for child + let status = nix::sys::wait::waitpid(child, None).unwrap(); + match status { + nix::sys::wait::WaitStatus::Exited(_, 0) => {} + other => panic!("Child failed: {:?}", other), + } + } + Ok(nix::unistd::ForkResult::Child) => { + // Child enters user namespace and checks UID + let result = std::panic::catch_unwind(|| { + enter_user_namespace().expect("enter_user_namespace failed"); + + // After entering user namespace, we should appear as root + let uid = getuid(); + let gid = getgid(); + + assert!(uid.is_root(), "Expected UID 0, got {}", uid); + assert_eq!(gid.as_raw(), 0, "Expected GID 0, got {}", gid); + + // Verify we're in a different namespace + let ns_path = fs::read_link("/proc/self/ns/user").unwrap(); + eprintln!("User namespace: {:?}", ns_path); + }); + + std::process::exit(if result.is_ok() { 0 } else { 1 }); + } + Err(e) => panic!("Fork failed: {}", e), + } +} + +#[test] +fn enter_user_namespace_is_isolated() { + // Skip if we're already root + if Uid::current().is_root() { + eprintln!("Skipping test: already running as root"); + return; + } + + // Get parent namespace inode before fork + let parent_ns = fs::read_link("/proc/self/ns/user").unwrap(); + + match unsafe { nix::unistd::fork() } { + Ok(nix::unistd::ForkResult::Parent { child }) => { + let status = nix::sys::wait::waitpid(child, None).unwrap(); + match status { + nix::sys::wait::WaitStatus::Exited(_, 0) => {} + other => panic!("Child failed: {:?}", other), + } + } + Ok(nix::unistd::ForkResult::Child) => { 
+ let result = std::panic::catch_unwind(|| { + enter_user_namespace().expect("enter_user_namespace failed"); + + let child_ns = fs::read_link("/proc/self/ns/user").unwrap(); + + // Namespace should be different from parent + assert_ne!( + parent_ns, child_ns, + "User namespace should differ: parent={:?}, child={:?}", + parent_ns, child_ns + ); + }); + + std::process::exit(if result.is_ok() { 0 } else { 1 }); + } + Err(e) => panic!("Fork failed: {}", e), + } +} diff --git a/vm-switch/tests/seccomp_filter.rs b/vm-switch/tests/seccomp_filter.rs new file mode 100644 index 0000000..b60bf95 --- /dev/null +++ b/vm-switch/tests/seccomp_filter.rs @@ -0,0 +1,178 @@ +//! Integration tests for seccomp filtering. + +use nix::sys::signal::Signal; +use nix::sys::wait::{waitpid, WaitStatus}; +use nix::unistd::{fork, ForkResult, Uid}; +use std::process; +use vm_switch::seccomp::{apply_child_seccomp, apply_main_seccomp, SeccompMode}; + +/// Run test in forked child, return wait status. +fn run_in_fork i32 + std::panic::UnwindSafe>(test_fn: F) -> WaitStatus { + if Uid::current().is_root() { + eprintln!("Skipping: running as root"); + return WaitStatus::Exited(nix::unistd::Pid::from_raw(0), 0); + } + + match unsafe { fork() } { + Ok(ForkResult::Parent { child }) => { + waitpid(child, None).expect("waitpid failed") + } + Ok(ForkResult::Child) => { + let result = std::panic::catch_unwind(test_fn); + let code = match result { + Ok(code) => code, + Err(_) => 1, + }; + process::exit(code); + } + Err(e) => panic!("Fork failed: {}", e), + } +} + +#[test] +fn seccomp_kill_blocks_ptrace() { + let status = run_in_fork(|| { + apply_main_seccomp(SeccompMode::Kill).expect("apply failed"); + + // ptrace is not whitelisted + unsafe { libc::ptrace(libc::PTRACE_TRACEME, 0, 0, 0) }; + 0 // Should not reach here + }); + + match status { + WaitStatus::Signaled(_, sig, _) => { + assert!(sig == Signal::SIGSYS || sig == Signal::SIGKILL); + } + WaitStatus::Exited(_, code) => { + assert_ne!(code, 0, "should 
/// Trap mode delivers SIGSYS instead of killing outright.
#[test]
fn seccomp_trap_sends_sigsys() {
    let status = run_in_fork(|| {
        apply_main_seccomp(SeccompMode::Trap).expect("apply failed");
        unsafe { libc::ptrace(libc::PTRACE_TRACEME, 0, 0, 0) };
        0
    });

    match status {
        WaitStatus::Signaled(_, Signal::SIGSYS, _) => {}
        WaitStatus::Stopped(_, Signal::SIGSYS) => {}
        _ => panic!("expected SIGSYS, got {:?}", status),
    }
}

/// Disabled mode installs no filter, so blocked syscalls succeed.
#[test]
fn seccomp_disabled_allows_all() {
    let status = run_in_fork(|| {
        apply_main_seccomp(SeccompMode::Disabled).expect("apply failed");
        // Would be blocked, but disabled mode allows it
        let _ = unsafe { libc::ptrace(libc::PTRACE_TRACEME, 0, 0, 0) };
        0
    });

    match status {
        WaitStatus::Exited(_, 0) => {}
        _ => panic!("expected clean exit, got {:?}", status),
    }
}

/// Whitelisted syscalls keep working under the main filter.
#[test]
fn seccomp_allows_whitelisted() {
    let status = run_in_fork(|| {
        apply_main_seccomp(SeccompMode::Kill).expect("apply failed");

        // All whitelisted
        let _ = unsafe { libc::getpid() };
        let _ = unsafe { libc::getuid() };

        let msg = b"ok\n";
        let ret = unsafe { libc::write(2, msg.as_ptr() as *const _, msg.len()) };
        if ret < 0 { return 1; }

        0
    });

    match status {
        WaitStatus::Exited(_, 0) => {}
        _ => panic!("expected clean exit, got {:?}", status),
    }
}

/// The stricter child filter rejects fork().
#[test]
fn child_filter_blocks_fork() {
    let status = run_in_fork(|| {
        apply_child_seccomp(SeccompMode::Kill).expect("apply failed");

        // fork not in child whitelist
        let ret = unsafe { libc::fork() };
        if ret == 0 {
            process::exit(99); // grandchild
        }
        0 // Should not reach
    });

    match status {
        WaitStatus::Signaled(_, sig, _) => {
            assert!(sig == Signal::SIGSYS || sig == Signal::SIGKILL);
        }
        WaitStatus::Exited(_, code) => {
            assert!(code != 0 && code != 99, "fork should be blocked");
        }
        _ => panic!("unexpected: {:?}", status),
    }
}

/// The child filter rejects socket creation.
#[test]
fn child_filter_blocks_socket() {
    let status = run_in_fork(|| {
        apply_child_seccomp(SeccompMode::Kill).expect("apply failed");

        // socket() not in child whitelist
        let fd = unsafe { libc::socket(libc::AF_UNIX, libc::SOCK_STREAM, 0) };
        if fd >= 0 {
            unsafe { libc::close(fd) };
            return 1; // Should have been blocked
        }
        0
    });

    match status {
        WaitStatus::Signaled(_, sig, _) => {
            assert!(sig == Signal::SIGSYS || sig == Signal::SIGKILL);
        }
        WaitStatus::Exited(_, code) => {
            assert_ne!(code, 0, "socket should be blocked");
        }
        _ => panic!("unexpected: {:?}", status),
    }
}

/// memfd_create must stay allowed in the child filter (ring buffers need it).
#[test]
fn child_filter_allows_memfd() {
    let status = run_in_fork(|| {
        apply_child_seccomp(SeccompMode::Kill).expect("apply failed");

        // memfd_create is in child whitelist (for ring buffers)
        let fd = unsafe {
            libc::syscall(
                libc::SYS_memfd_create,
                b"test\0".as_ptr(),
                0u32,
            )
        };
        if fd < 0 {
            return 1;
        }
        unsafe { libc::close(fd as i32) };
        0
    });

    match status {
        WaitStatus::Exited(_, 0) => {}
        _ => panic!("expected clean exit, got {:?}", status),
    }
}