feat(vm-switch): add process isolation with namespace sandbox and seccomp
Replace the thread-based vhost-user backend architecture with a fork-based process model where each VM gets its own child process. This enables strong isolation between VMs handling untrusted network traffic, with multiple layers of defense in depth. Process model: - Main process watches config directory and orchestrates child lifecycle - One child process forked per VM, running as vhost-user net backend - Children communicate via SOCK_SEQPACKET control channel with SCM_RIGHTS - Automatic child restart on crash/disconnect, with peer notification - Ping/pong heartbeat monitoring for worker health (1s interval, 100ms timeout) - SIGCHLD handling integrated into tokio event loop Inter-process packet forwarding: - Lock-free SPSC ring buffers in shared memory (memfd + mmap) - 64-slot rings (~598KB each) with atomic head/tail, no locks in datapath - Eventfd signaling for empty-to-non-empty transitions - Main orchestrates buffer exchange: GetBuffer -> BufferReady -> PutBuffer - Zero-copy path: producers write directly into consumer's shared memory Namespace sandbox (applied before tokio, single-threaded): - User namespace: unprivileged outside, UID 0 inside - PID namespace: main is PID 1, children invisible to host - Mount namespace: minimal tmpfs root with /config, /dev, /proc, /tmp - IPC namespace: isolated System V IPC - Network namespace: empty, communication only via inherited FDs - Controllable via --no-sandbox flag Seccomp BPF filtering (two-tier whitelist): - Main filter: allows fork, socket creation, inotify, openat - Child filter: strict subset - no fork, no socket, no file open - Child filter applied after vhost setup, before event loop - Modes: kill (default), trap (SIGSYS debug), log, disabled Also adds vm-switch service dependencies to VM units in the NixOS module so VMs wait for their network switch before starting. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
6722c0fbb4
commit
6941d2fe4c
29 changed files with 6275 additions and 2041 deletions
69
CLAUDE.md
69
CLAUDE.md
|
|
@ -138,6 +138,14 @@ Provides L2 switching for VM-to-VM networks:
|
|||
- **Purpose:** Handles vhost-user protocol for VM network interfaces
|
||||
- **Systemd:** One service per vmNetwork (`vm-switch-<netname>.service`)
|
||||
|
||||
**CLI flags:**
|
||||
```
|
||||
-d, --config-dir <PATH> Config/MAC file directory (default: /run/vm-switch)
|
||||
--log-level <LEVEL> error, warn, info, debug, trace (default: warn)
|
||||
--no-sandbox Disable namespace sandboxing
|
||||
--seccomp-mode <MODE> kill (default), trap, log, disabled
|
||||
```
|
||||
|
||||
**Testing locally:**
|
||||
```bash
|
||||
# Build and run manually
|
||||
|
|
@ -149,6 +157,67 @@ mkdir -p /tmp/test-switch/router
|
|||
echo "52:00:00:00:00:01" > /tmp/test-switch/router/router.mac
|
||||
```
|
||||
|
||||
**Process model:** Main process forks one child per VM. Children are vhost-user net backends that handle virtio TX/RX for their VM. Main orchestrates lifecycle, config watching, and buffer exchange between children. Children exit when the vhost-user client (crosvm) disconnects; main automatically restarts them so crosvm can reconnect.
|
||||
|
||||
**Startup sequence:**
|
||||
1. Parse args, apply namespace sandbox (single-threaded, before tokio)
|
||||
2. Apply main seccomp filter
|
||||
3. Start tokio runtime, create ConfigWatcher + BackendManager
|
||||
4. Start tokio runtime, enter async event loop (SIGCHLD via tokio select branch)
|
||||
|
||||
**Key source files:**
|
||||
- `src/main.rs` - Entry point, sandbox/seccomp setup, async event loop
|
||||
- `src/manager.rs` - BackendManager: fork children, buffer exchange, crash cleanup
|
||||
- `src/child/process.rs` - Child entry point: control channel, vhost daemon, child seccomp
|
||||
- `src/child/forwarder.rs` - PacketForwarder: L2 routing via ring buffers
|
||||
- `src/child/vhost.rs` - ChildVhostBackend: virtio TX/RX callbacks
|
||||
- `src/child/poll.rs` - Event polling for control channel + ingress buffers
|
||||
- `src/control.rs` - Main-child IPC over Unix seqpacket sockets + SCM_RIGHTS
|
||||
- `src/ring.rs` - Lock-free SPSC ring buffer in shared memory (memfd)
|
||||
- `src/sandbox.rs` - Namespace isolation (user, PID, mount, IPC, network)
|
||||
- `src/seccomp.rs` - BPF syscall filters (main and child whitelists)
|
||||
- `src/frame.rs` - Ethernet frame parsing, MAC validation
|
||||
- `src/main.rs` - SIGCHLD handling via tokio select branch
|
||||
|
||||
**Control protocol** (main <-> child IPC via `SOCK_SEQPACKET` + `SCM_RIGHTS`):
|
||||
|
||||
| Direction | Message | FDs | Purpose |
|
||||
|-----------|---------|-----|---------|
|
||||
| Main -> Child | `GetBuffer { peer_name, peer_mac }` | - | Ask child to create ingress buffer for a peer |
|
||||
| Child -> Main | `BufferReady { peer_name }` | memfd, eventfd | Ingress buffer created, here are the FDs |
|
||||
| Main -> Child | `PutBuffer { peer_name, peer_mac, broadcast }` | memfd, eventfd | Give child a peer's buffer as egress target |
|
||||
| Main -> Child | `RemovePeer { peer_name }` | - | Clean up buffers for disconnected/crashed peer |
|
||||
| Main -> Child | `Ping` | - | Heartbeat request (sent every 1s) |
|
||||
| Child -> Main | `Ready` | - | Child initialized and ready |
|
||||
| Child -> Main | `Pong` | - | Heartbeat response (must arrive within 100ms) |
|
||||
|
||||
Messages serialized with `postcard`. FDs passed via ancillary data.
|
||||
|
||||
**Buffer exchange flow:**
|
||||
1. Main sends `GetBuffer` to Child1 ("create ingress buffer for Child2")
|
||||
2. Child1 creates SPSC ring buffer (memfd + eventfd), becomes Consumer, replies `BufferReady`
|
||||
3. Main forwards those FDs to Child2 via `PutBuffer` -- Child2 becomes Producer
|
||||
4. Packets now flow: Child2 writes to Producer -> shared memfd -> Child1 reads from Consumer
|
||||
|
||||
**SPSC ring buffer** (`ring.rs`): Lock-free single-producer/single-consumer queue backed by `memfd_create()` + `mmap(MAP_SHARED)`. 64 slots, ~598KB total. Head/tail use atomic operations (no locks in datapath). Eventfd signals empty-to-non-empty transitions.
|
||||
|
||||
**Sandbox** (applied before tokio, requires single-threaded):
|
||||
1. **User namespace** - Maps real UID to 0 inside, enables unprivileged namespace creation
|
||||
2. **PID namespace** - Fork into new PID ns; main becomes PID 1
|
||||
3. **Mount namespace** - Minimal tmpfs root with `/config` (bind-mount of config dir), `/dev` (null, zero, urandom), `/proc`, `/tmp`. Pivot root, unmount old.
|
||||
4. **IPC namespace** - Isolates System V IPC
|
||||
5. **Network namespace** - Empty (no interfaces). Communication only via inherited FDs.
|
||||
|
||||
**Seccomp filtering** (BPF syscall whitelist):
|
||||
- `--seccomp-mode=kill` (default): Terminate on blocked syscall
|
||||
- `--seccomp-mode=trap`: Send SIGSYS (debug with strace)
|
||||
- `--seccomp-mode=log`: Log violations but allow
|
||||
- `--seccomp-mode=disabled`: Skip filtering
|
||||
|
||||
Two filter tiers (child is a strict subset of main):
|
||||
- **Main**: Allows fork, socket creation, inotify, openat (config watching + child management)
|
||||
- **Child**: No fork, no socket creation, no file open. Applied after vhost setup completes. Allows clone3 for vhost-user threads.
|
||||
|
||||
### Dependencies
|
||||
|
||||
- Custom crosvm fork: `git.dsg.is/davidlowsec/crosvm.git`
|
||||
|
|
|
|||
50
README.md
50
README.md
|
|
@ -499,3 +499,53 @@ The host provides:
|
|||
- Polkit rules for the configured user to manage VM services without sudo
|
||||
- CLI tools: `vm-run`, `vm-start`, `vm-stop`, `vm-start-debug`, `vm-shell`
|
||||
- Desktop integration with .desktop files for guest applications
|
||||
|
||||
### vm-switch
|
||||
|
||||
The `vm-switch` daemon (`vm-switch/` Rust crate) provides L2 switching for VM-to-VM networks. One instance runs per `vmNetwork`, managed by systemd (`vm-switch-<netname>.service`).
|
||||
|
||||
**Process model:** The main process watches a config directory for MAC files and forks one child process per VM. Each child is a vhost-user net backend serving a single VM's network interface.
|
||||
|
||||
```
|
||||
Main Process
|
||||
(config watch, orchestration)
|
||||
/ | \
|
||||
fork / fork | fork \
|
||||
v v v
|
||||
Child: router Child: banking Child: shopping
|
||||
(vhost-user) (vhost-user) (vhost-user)
|
||||
| | |
|
||||
[unix socket] [unix socket] [unix socket]
|
||||
| | |
|
||||
crosvm crosvm crosvm
|
||||
(router VM) (banking VM) (shopping VM)
|
||||
```
|
||||
|
||||
**Packet forwarding** uses lock-free SPSC ring buffers in shared memory (`memfd_create` + `mmap`). When a VM transmits a frame, its child process validates the source MAC address and routes the frame to the correct destination:
|
||||
- Unicast: pushed into the destination child's ingress ring buffer
|
||||
- Broadcast/multicast: pushed into all peers' ingress buffers
|
||||
|
||||
Ring buffers use atomic head/tail pointers (no locks in the datapath) with eventfd signaling for empty-to-non-empty transitions.
|
||||
|
||||
**Buffer exchange protocol:** The main process orchestrates buffer setup between children via a control channel (`SOCK_SEQPACKET` + `SCM_RIGHTS` for passing memfd/eventfd file descriptors):
|
||||
|
||||
1. Main tells Child A: "create an ingress buffer for Child B" (`GetBuffer`)
|
||||
2. Child A creates the ring buffer and returns the FDs (`BufferReady`)
|
||||
3. Main forwards those FDs to Child B as an egress target (`PutBuffer`)
|
||||
4. Child B can now write frames directly into Child A's memory -- no copies through the main process
|
||||
|
||||
**Sandboxing:** The daemon runs in a multi-layer sandbox applied at startup (before any async runtime or threads):
|
||||
|
||||
| Layer | Mechanism | Effect |
|
||||
|-------|-----------|--------|
|
||||
| User namespace | `CLONE_NEWUSER` | Unprivileged outside, appears as UID 0 inside |
|
||||
| PID namespace | `CLONE_NEWPID` | Main is PID 1; children invisible to host |
|
||||
| Mount namespace | `CLONE_NEWNS` + pivot_root | Minimal tmpfs root: `/config`, `/dev` (null/zero/urandom), `/proc`, `/tmp` |
|
||||
| IPC namespace | `CLONE_NEWIPC` | Isolated System V IPC |
|
||||
| Network namespace | `CLONE_NEWNET` | No interfaces; communication only via inherited FDs |
|
||||
| Seccomp (main) | BPF whitelist | Allows fork, socket creation, inotify for config watching |
|
||||
| Seccomp (child) | Tighter BPF whitelist | No fork, no socket creation, no file open; applied after vhost setup |
|
||||
|
||||
Seccomp modes: `--seccomp-mode=kill` (default), `trap` (SIGSYS for debugging), `log`, `disabled`.
|
||||
|
||||
Disable sandboxing for debugging with `--no-sandbox` and `--seccomp-mode=disabled`.
|
||||
|
|
|
|||
|
|
@ -1097,7 +1097,10 @@ in
|
|||
vm:
|
||||
lib.nameValuePair "qubes-lite-${vm.name}-vm" {
|
||||
description = "qubes-lite VM: ${vm.name}";
|
||||
after = [ "network.target" ];
|
||||
after =
|
||||
[ "network.target" ]
|
||||
++ map (netName: "vm-switch-${netName}.service") (lib.attrNames vm.vmNetwork);
|
||||
requires = map (netName: "vm-switch-${netName}.service") (lib.attrNames vm.vmNetwork);
|
||||
serviceConfig = {
|
||||
Type = "simple";
|
||||
ExecStart = "${mkVmScript vm}";
|
||||
|
|
|
|||
267
vm-switch/Cargo.lock
generated
267
vm-switch/Cargo.lock
generated
|
|
@ -61,6 +61,12 @@ dependencies = [
|
|||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.101"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5f0e0fee31ef5ed1ba1316088939cea399010ed7731dba877ed44aeb407a75ea"
|
||||
|
||||
[[package]]
|
||||
name = "arc-swap"
|
||||
version = "1.8.1"
|
||||
|
|
@ -70,6 +76,15 @@ dependencies = [
|
|||
"rustversion",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atomic-polyfill"
|
||||
version = "1.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4"
|
||||
dependencies = [
|
||||
"critical-section",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "1.3.2"
|
||||
|
|
@ -88,6 +103,12 @@ version = "3.19.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5dd9dc738b7a8311c7ade152424974d8115f2cdad61e8dab8dac9f2362298510"
|
||||
|
||||
[[package]]
|
||||
name = "byteorder"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
||||
|
||||
[[package]]
|
||||
name = "bytes"
|
||||
version = "1.11.1"
|
||||
|
|
@ -100,6 +121,12 @@ version = "1.0.4"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
|
||||
|
||||
[[package]]
|
||||
name = "cfg_aliases"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724"
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.5.57"
|
||||
|
|
@ -140,12 +167,39 @@ version = "0.7.7"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3e64b0cc0439b12df2fa678eae89a1c56a529fd067a9115f7827f1fffd22b32"
|
||||
|
||||
[[package]]
|
||||
name = "cobs"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0fa961b519f0b462e3a3b4a34b64d119eeaca1d59af726fe450bbba07a9fc0a1"
|
||||
dependencies = [
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colorchoice"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
|
||||
|
||||
[[package]]
|
||||
name = "critical-section"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b"
|
||||
|
||||
[[package]]
|
||||
name = "embedded-io"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ef1a6892d9eef45c8fa6b9e0086428a2cca8491aca8f787c534a3d6d0bcb3ced"
|
||||
|
||||
[[package]]
|
||||
name = "embedded-io"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "edd0f118536f44f5ccd48bcb8b111bdc3de888b58c74639dfb034a357d0f206d"
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.14"
|
||||
|
|
@ -191,6 +245,42 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-core"
|
||||
version = "0.3.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
|
||||
|
||||
[[package]]
|
||||
name = "futures-executor"
|
||||
version = "0.3.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-task",
|
||||
"futures-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-task"
|
||||
version = "0.3.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
|
||||
|
||||
[[package]]
|
||||
name = "futures-util"
|
||||
version = "0.3.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-task",
|
||||
"pin-project-lite",
|
||||
"pin-utils",
|
||||
"slab",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.3.4"
|
||||
|
|
@ -203,6 +293,29 @@ dependencies = [
|
|||
"wasip2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hash32"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b0c35f58762feb77d74ebe43bdbc3210f09be9fe6742234d573bacc26ed92b67"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "heapless"
|
||||
version = "0.7.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cdc6457c0eb62c71aac4bc17216026d8410337c4126773b9c5daba343f17964f"
|
||||
dependencies = [
|
||||
"atomic-polyfill",
|
||||
"hash32",
|
||||
"rustc_version",
|
||||
"serde",
|
||||
"spin",
|
||||
"stable_deref_trait",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.5.0"
|
||||
|
|
@ -345,6 +458,18 @@ dependencies = [
|
|||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nix"
|
||||
version = "0.29.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46"
|
||||
dependencies = [
|
||||
"bitflags 2.10.0",
|
||||
"cfg-if",
|
||||
"cfg_aliases",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "notify"
|
||||
version = "7.0.0"
|
||||
|
|
@ -436,6 +561,25 @@ version = "0.2.16"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b"
|
||||
|
||||
[[package]]
|
||||
name = "pin-utils"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
|
||||
|
||||
[[package]]
|
||||
name = "postcard"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6764c3b5dd454e283a30e6dfe78e9b31096d9e32036b5d1eaac7a6119ccb9a24"
|
||||
dependencies = [
|
||||
"cobs",
|
||||
"embedded-io 0.4.0",
|
||||
"embedded-io 0.6.1",
|
||||
"heapless",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.21"
|
||||
|
|
@ -533,6 +677,15 @@ version = "0.8.9"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a96887878f22d7bad8a3b6dc5b7440e0ada9a245242924394987b21cf2210a4c"
|
||||
|
||||
[[package]]
|
||||
name = "rustc_version"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92"
|
||||
dependencies = [
|
||||
"semver",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "1.1.3"
|
||||
|
|
@ -561,12 +714,98 @@ dependencies = [
|
|||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scc"
|
||||
version = "2.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "46e6f046b7fef48e2660c57ed794263155d713de679057f2d0c169bfc6e756cc"
|
||||
dependencies = [
|
||||
"sdd",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||
|
||||
[[package]]
|
||||
name = "sdd"
|
||||
version = "3.0.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "490dcfcbfef26be6800d11870ff2df8774fa6e86d047e3e8c8a76b25655e41ca"
|
||||
|
||||
[[package]]
|
||||
name = "seccompiler"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "345a3e4dddf721a478089d4697b83c6c0a8f5bf16086f6c13397e4534eb6e2e5"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
version = "1.0.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.228"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
|
||||
dependencies = [
|
||||
"serde_core",
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_core"
|
||||
version = "1.0.228"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.228"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serial_test"
|
||||
version = "3.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0d0b343e184fc3b7bb44dff0705fffcf4b3756ba6aff420dddd8b24ca145e555"
|
||||
dependencies = [
|
||||
"futures-executor",
|
||||
"futures-util",
|
||||
"log",
|
||||
"once_cell",
|
||||
"parking_lot",
|
||||
"scc",
|
||||
"serial_test_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serial_test_derive"
|
||||
version = "3.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6f50427f258fb77356e4cd4aa0e87e2bd2c66dbcee41dc405282cae2bfc26c83"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sharded-slab"
|
||||
version = "0.1.7"
|
||||
|
|
@ -586,6 +825,12 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "slab"
|
||||
version = "0.4.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5"
|
||||
|
||||
[[package]]
|
||||
name = "smallvec"
|
||||
version = "1.15.1"
|
||||
|
|
@ -602,6 +847,21 @@ dependencies = [
|
|||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spin"
|
||||
version = "0.9.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
|
||||
dependencies = [
|
||||
"lock_api",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "stable_deref_trait"
|
||||
version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596"
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.11.1"
|
||||
|
|
@ -841,9 +1101,16 @@ dependencies = [
|
|||
name = "vm-switch"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"clap",
|
||||
"libc",
|
||||
"nix",
|
||||
"notify",
|
||||
"notify-debouncer-full",
|
||||
"postcard",
|
||||
"seccompiler",
|
||||
"serde",
|
||||
"serial_test",
|
||||
"tempfile",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
|
|
|
|||
|
|
@ -27,9 +27,20 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
|||
|
||||
# Error handling
|
||||
thiserror = "2"
|
||||
anyhow = "1"
|
||||
|
||||
# CLI
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
|
||||
# Sandboxing
|
||||
nix = { version = "0.29", features = ["sched", "mount", "user", "signal", "process", "poll", "fs"] }
|
||||
seccompiler = "0.4"
|
||||
libc = "0.2"
|
||||
|
||||
# Serialization (for control channel)
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
postcard = { version = "1", features = ["alloc"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3"
|
||||
serial_test = "3"
|
||||
|
|
|
|||
|
|
@ -1,9 +1,46 @@
|
|||
//! Command-line argument parsing.
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
use std::fmt;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use crate::seccomp::SeccompMode;
|
||||
use clap::{Parser, ValueEnum};
|
||||
use tracing_subscriber::fmt::format::Writer;
|
||||
use tracing_subscriber::fmt::{FmtContext, FormatEvent, FormatFields};
|
||||
use tracing_subscriber::registry::LookupSpan;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
/// Process name for log prefixes ("main" or "worker-$vmname").
|
||||
static PROCESS_NAME: Mutex<String> = Mutex::new(String::new());
|
||||
|
||||
/// Set the process name for log prefixes.
|
||||
pub fn set_process_name(name: impl Into<String>) {
|
||||
*PROCESS_NAME.lock().unwrap() = name.into();
|
||||
}
|
||||
|
||||
/// Custom log formatter that outputs: LEVEL process-name: message fields
|
||||
struct PrefixedFormatter;
|
||||
|
||||
impl<S, N> FormatEvent<S, N> for PrefixedFormatter
|
||||
where
|
||||
S: tracing::Subscriber + for<'a> LookupSpan<'a>,
|
||||
N: for<'a> FormatFields<'a> + 'static,
|
||||
{
|
||||
fn format_event(
|
||||
&self,
|
||||
ctx: &FmtContext<'_, S, N>,
|
||||
mut writer: Writer<'_>,
|
||||
event: &tracing::Event<'_>,
|
||||
) -> fmt::Result {
|
||||
let level = *event.metadata().level();
|
||||
let name = PROCESS_NAME.lock().unwrap();
|
||||
write!(writer, "{level} {name}: ")?;
|
||||
ctx.field_format().format_fields(writer.by_ref(), event)?;
|
||||
writeln!(writer)
|
||||
}
|
||||
}
|
||||
|
||||
/// Log level for the application.
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, ValueEnum)]
|
||||
pub enum LogLevel {
|
||||
|
|
@ -39,16 +76,26 @@ pub struct Args {
|
|||
/// Log level (error, warn, info, debug, trace).
|
||||
#[arg(long, value_enum, default_value_t = LogLevel::Warn)]
|
||||
pub log_level: LogLevel,
|
||||
|
||||
/// Disable namespace sandboxing (for debugging).
|
||||
#[arg(long, default_value_t = false)]
|
||||
pub no_sandbox: bool,
|
||||
|
||||
/// Seccomp filter mode (kill, trap, log, disabled).
|
||||
#[arg(long, value_enum, default_value_t = SeccompMode::Kill)]
|
||||
pub seccomp_mode: SeccompMode,
|
||||
}
|
||||
|
||||
/// Initialize logging based on log level.
|
||||
pub fn init_logging(level: LogLevel) {
|
||||
set_process_name("main");
|
||||
|
||||
let filter = EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new(format!("vm_switch={}", level.as_str())));
|
||||
|
||||
let _ = tracing_subscriber::fmt()
|
||||
.with_env_filter(filter)
|
||||
.with_target(false)
|
||||
.event_format(PrefixedFormatter)
|
||||
.try_init();
|
||||
}
|
||||
|
||||
|
|
@ -106,4 +153,40 @@ mod tests {
|
|||
init_logging(LogLevel::Debug);
|
||||
init_logging(LogLevel::Trace);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_no_sandbox_flag() {
|
||||
let args = Args::try_parse_from(["vm-switch", "--no-sandbox"]).unwrap();
|
||||
assert!(args.no_sandbox);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_no_sandbox_default_false() {
|
||||
let args = Args::try_parse_from(["vm-switch"]).unwrap();
|
||||
assert!(!args.no_sandbox);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_seccomp_mode_default() {
|
||||
let args = Args::try_parse_from(["vm-switch"]).unwrap();
|
||||
assert_eq!(args.seccomp_mode, SeccompMode::Kill);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_seccomp_mode_trap() {
|
||||
let args = Args::try_parse_from(["vm-switch", "--seccomp-mode", "trap"]).unwrap();
|
||||
assert_eq!(args.seccomp_mode, SeccompMode::Trap);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_seccomp_mode_log() {
|
||||
let args = Args::try_parse_from(["vm-switch", "--seccomp-mode", "log"]).unwrap();
|
||||
assert_eq!(args.seccomp_mode, SeccompMode::Log);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_seccomp_mode_disabled() {
|
||||
let args = Args::try_parse_from(["vm-switch", "--seccomp-mode", "disabled"]).unwrap();
|
||||
assert_eq!(args.seccomp_mode, SeccompMode::Disabled);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,921 +0,0 @@
|
|||
//! Vhost-user network backend implementation.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex, RwLock};
|
||||
|
||||
use tracing::{debug, trace, warn};
|
||||
use vhost::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures};
|
||||
use vhost_user_backend::{VhostUserBackend, VringRwLock, VringT};
|
||||
use virtio_bindings::virtio_net::{
|
||||
VIRTIO_NET_F_CSUM, VIRTIO_NET_F_GUEST_CSUM, VIRTIO_NET_F_GUEST_TSO4,
|
||||
VIRTIO_NET_F_GUEST_TSO6, VIRTIO_NET_F_GUEST_UFO, VIRTIO_NET_F_HOST_TSO4,
|
||||
VIRTIO_NET_F_HOST_TSO6, VIRTIO_NET_F_HOST_UFO, VIRTIO_NET_F_MAC, VIRTIO_NET_F_STATUS,
|
||||
};
|
||||
use virtio_bindings::virtio_config::VIRTIO_F_VERSION_1;
|
||||
use virtio_bindings::virtio_ring::VIRTIO_RING_F_EVENT_IDX;
|
||||
use virtio_queue::QueueT;
|
||||
use vm_memory::{Bytes, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryMmap};
|
||||
use vmm_sys_util::epoll::EventSet;
|
||||
|
||||
use crate::config::VmRole;
|
||||
use crate::frame::EthernetFrame;
|
||||
use crate::mac::Mac;
|
||||
use crate::switch::{ConnectionId, ForwardDecision, Switch};
|
||||
|
||||
/// Registry mapping connection IDs to their RX vrings and associated memory.
|
||||
/// Shared between all backends for frame routing.
|
||||
pub type VringRegistry = Arc<RwLock<HashMap<ConnectionId, (VringRwLock, GuestMemoryAtomic<GuestMemoryMmap>)>>>;
|
||||
|
||||
/// RX queue index.
|
||||
pub const RX_QUEUE: u16 = 0;
|
||||
/// TX queue index.
|
||||
pub const TX_QUEUE: u16 = 1;
|
||||
/// Number of queues (RX + TX).
|
||||
pub const NUM_QUEUES: usize = 2;
|
||||
/// Maximum queue size (must be power of 2, 32768 is typical max).
|
||||
pub const MAX_QUEUE_SIZE: usize = 32768;
|
||||
/// Size of virtio-net header.
|
||||
pub const VIRTIO_NET_HDR_SIZE: usize = 12;
|
||||
|
||||
/// Result of processing a frame from the TX queue.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProcessedFrame {
|
||||
/// Raw frame data (Ethernet frame, no virtio header).
|
||||
pub data: Vec<u8>,
|
||||
/// Forwarding decision from the switch.
|
||||
pub decision: ForwardDecision,
|
||||
}
|
||||
|
||||
/// Virtio net features we support.
|
||||
const VIRTIO_NET_FEATURES: u64 = (1 << VIRTIO_NET_F_CSUM)
|
||||
| (1 << VIRTIO_NET_F_GUEST_CSUM)
|
||||
| (1 << VIRTIO_NET_F_GUEST_TSO4)
|
||||
| (1 << VIRTIO_NET_F_GUEST_TSO6)
|
||||
| (1 << VIRTIO_NET_F_GUEST_UFO)
|
||||
| (1 << VIRTIO_NET_F_HOST_TSO4)
|
||||
| (1 << VIRTIO_NET_F_HOST_TSO6)
|
||||
| (1 << VIRTIO_NET_F_HOST_UFO)
|
||||
| (1 << VIRTIO_NET_F_MAC)
|
||||
| (1 << VIRTIO_NET_F_STATUS);
|
||||
|
||||
/// Network backend for a single VM.
|
||||
pub struct NetBackend {
|
||||
/// VM name for logging.
|
||||
name: String,
|
||||
/// VM's role (router or client).
|
||||
role: VmRole,
|
||||
/// VM's MAC address.
|
||||
mac: Mac,
|
||||
/// Connection ID in the switch (set after registration).
|
||||
connection_id: Mutex<Option<ConnectionId>>,
|
||||
/// Shared switch for forwarding.
|
||||
switch: Arc<RwLock<Switch>>,
|
||||
/// Shared registry of all backends' RX vrings for frame routing.
|
||||
vring_registry: VringRegistry,
|
||||
/// Guest memory.
|
||||
mem: Mutex<Option<GuestMemoryAtomic<GuestMemoryMmap>>>,
|
||||
/// Whether EVENT_IDX is enabled.
|
||||
event_idx: Mutex<bool>,
|
||||
/// Acked features.
|
||||
acked_features: Mutex<u64>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for NetBackend {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("NetBackend")
|
||||
.field("name", &self.name)
|
||||
.field("role", &self.role)
|
||||
.field("mac", &self.mac)
|
||||
.field("connection_id", &self.connection_id)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl NetBackend {
|
||||
/// Create a new network backend.
|
||||
pub fn new(
|
||||
name: String,
|
||||
role: VmRole,
|
||||
mac: Mac,
|
||||
switch: Arc<RwLock<Switch>>,
|
||||
vring_registry: VringRegistry,
|
||||
) -> Self {
|
||||
Self {
|
||||
name,
|
||||
role,
|
||||
mac,
|
||||
connection_id: Mutex::new(None),
|
||||
switch,
|
||||
vring_registry,
|
||||
mem: Mutex::new(None),
|
||||
event_idx: Mutex::new(false),
|
||||
acked_features: Mutex::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the VM name.
|
||||
pub fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
/// Get the VM role.
|
||||
pub fn role(&self) -> VmRole {
|
||||
self.role
|
||||
}
|
||||
|
||||
/// Get the VM MAC.
|
||||
pub fn mac(&self) -> Mac {
|
||||
self.mac
|
||||
}
|
||||
|
||||
/// Get the connection ID (if registered).
|
||||
pub fn connection_id(&self) -> Option<ConnectionId> {
|
||||
*self.connection_id.lock().unwrap()
|
||||
}
|
||||
|
||||
/// Register this backend with the switch.
|
||||
pub fn register(&self) -> Option<ConnectionId> {
|
||||
let mut switch = self.switch.write().unwrap();
|
||||
let id = switch.register(self.name.clone(), self.role, self.mac)?;
|
||||
*self.connection_id.lock().unwrap() = Some(id);
|
||||
Some(id)
|
||||
}
|
||||
|
||||
/// Unregister this backend from the switch.
|
||||
pub fn unregister(&self) {
|
||||
let mut id_guard = self.connection_id.lock().unwrap();
|
||||
if let Some(id) = id_guard.take() {
|
||||
let mut switch = self.switch.write().unwrap();
|
||||
switch.unregister(id);
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear state for connection reset (called between vhost-user sessions).
|
||||
pub fn clear_state(&self) {
|
||||
// Clear guest memory
|
||||
*self.mem.lock().unwrap() = None;
|
||||
|
||||
// Reset event_idx
|
||||
*self.event_idx.lock().unwrap() = false;
|
||||
|
||||
// Reset acked features
|
||||
*self.acked_features.lock().unwrap() = 0;
|
||||
}
|
||||
|
||||
/// Process a frame and determine forwarding.
|
||||
///
|
||||
/// Takes raw Ethernet frame bytes (no virtio header).
|
||||
/// Returns None if not registered or frame is invalid.
|
||||
pub fn process_frame(&self, frame_data: &[u8]) -> Option<ProcessedFrame> {
|
||||
let conn_id = self.connection_id()?;
|
||||
|
||||
let frame = EthernetFrame::parse(frame_data)?;
|
||||
|
||||
let switch = self.switch.read().unwrap();
|
||||
let decision = switch.forward(conn_id, frame.source_mac(), frame.dest_mac());
|
||||
|
||||
Some(ProcessedFrame {
|
||||
data: frame_data.to_vec(),
|
||||
decision,
|
||||
})
|
||||
}
|
||||
|
||||
/// Strip the virtio-net header from frame data.
|
||||
/// Returns the Ethernet frame without the header.
|
||||
pub fn strip_virtio_header(data: &[u8]) -> Option<&[u8]> {
|
||||
if data.len() < VIRTIO_NET_HDR_SIZE {
|
||||
return None;
|
||||
}
|
||||
Some(&data[VIRTIO_NET_HDR_SIZE..])
|
||||
}
|
||||
|
||||
/// Prepend a virtio-net header to frame data.
|
||||
/// Returns the complete buffer for RX injection.
|
||||
pub fn prepend_virtio_header(frame: &[u8]) -> Vec<u8> {
|
||||
let mut result = vec![0u8; VIRTIO_NET_HDR_SIZE + frame.len()];
|
||||
// Header is all zeros (basic operation, no offloading)
|
||||
result[VIRTIO_NET_HDR_SIZE..].copy_from_slice(frame);
|
||||
result
|
||||
}
|
||||
|
||||
/// Process frames from the TX queue.
|
||||
///
|
||||
/// Reads frames, strips virtio header, gets forwarding decisions.
|
||||
/// Returns frames with their forwarding decisions.
|
||||
pub fn process_tx_queue(&self, vring: &VringRwLock) -> Vec<ProcessedFrame> {
|
||||
use vm_memory::GuestMemoryLoadGuard;
|
||||
|
||||
let mut results = Vec::new();
|
||||
|
||||
// Need connection ID and memory to process
|
||||
if self.connection_id().is_none() {
|
||||
trace!(vm = %self.name, "process_tx_queue: no connection_id");
|
||||
return results;
|
||||
}
|
||||
|
||||
let mem_guard = self.mem.lock().unwrap();
|
||||
let mem: GuestMemoryLoadGuard<GuestMemoryMmap> = match mem_guard.as_ref() {
|
||||
Some(m) => m.memory(),
|
||||
None => {
|
||||
warn!(vm = %self.name, "process_tx_queue: no guest memory set!");
|
||||
return results;
|
||||
}
|
||||
};
|
||||
|
||||
// Collect all frames and head indices
|
||||
let mut frames: Vec<(u16, Vec<u8>)> = Vec::new();
|
||||
{
|
||||
let mut vring_state = vring.get_mut();
|
||||
let queue = vring_state.get_queue_mut();
|
||||
|
||||
trace!(
|
||||
vm = %self.name,
|
||||
queue_ready = queue.ready(),
|
||||
queue_size = queue.size(),
|
||||
next_avail = queue.next_avail(),
|
||||
next_used = queue.next_used(),
|
||||
"process_tx_queue: checking queue"
|
||||
);
|
||||
|
||||
while let Some(desc_chain) = queue.pop_descriptor_chain(mem.clone()) {
|
||||
let head_index = desc_chain.head_index();
|
||||
let mut raw_data = Vec::new();
|
||||
|
||||
// Read all descriptors in the chain
|
||||
for desc in desc_chain {
|
||||
let addr = desc.addr();
|
||||
let len = desc.len() as usize;
|
||||
|
||||
let mut buf = vec![0u8; len];
|
||||
if let Err(e) = mem.read_slice(&mut buf, addr) {
|
||||
warn!(
|
||||
vm = %self.name,
|
||||
head_index,
|
||||
addr = ?addr,
|
||||
len,
|
||||
error = %e,
|
||||
"process_tx_queue: failed to read descriptor"
|
||||
);
|
||||
break;
|
||||
}
|
||||
raw_data.extend_from_slice(&buf);
|
||||
}
|
||||
|
||||
trace!(
|
||||
vm = %self.name,
|
||||
head_index,
|
||||
raw_len = raw_data.len(),
|
||||
"process_tx_queue: popped descriptor"
|
||||
);
|
||||
frames.push((head_index, raw_data));
|
||||
}
|
||||
}
|
||||
|
||||
if frames.is_empty() {
|
||||
trace!(vm = %self.name, "process_tx_queue: no frames in queue");
|
||||
}
|
||||
|
||||
// Process frames and mark descriptors as used
|
||||
for (head_index, raw_data) in &frames {
|
||||
if let Err(e) = vring.add_used(*head_index, 0) {
|
||||
warn!(
|
||||
vm = %self.name,
|
||||
head_index,
|
||||
error = ?e,
|
||||
"process_tx_queue: add_used failed"
|
||||
);
|
||||
} else {
|
||||
trace!(
|
||||
vm = %self.name,
|
||||
head_index,
|
||||
"process_tx_queue: add_used ok"
|
||||
);
|
||||
}
|
||||
|
||||
// Strip header and process frame
|
||||
if let Some(frame_data) = Self::strip_virtio_header(raw_data) {
|
||||
if let Some(processed) = self.process_frame(frame_data) {
|
||||
results.push(processed);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Re-enable notifications so guest will kick us when it adds more buffers.
|
||||
// This is critical for EVENT_IDX: after draining the queue, we must tell
|
||||
// the guest to notify us of new buffers, otherwise it will suppress kicks.
|
||||
match vring.enable_notification() {
|
||||
Ok(has_more) => {
|
||||
if has_more {
|
||||
trace!(vm = %self.name, "process_tx_queue: enable_notification returned has_more=true");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(vm = %self.name, error = ?e, "process_tx_queue: enable_notification failed");
|
||||
}
|
||||
}
|
||||
|
||||
// Signal guest that we've processed the queue
|
||||
if !frames.is_empty() {
|
||||
// Check if the call eventfd is set
|
||||
{
|
||||
let vring_state = vring.get_ref();
|
||||
let has_call = vring_state.get_call().is_some();
|
||||
if !has_call {
|
||||
warn!(
|
||||
vm = %self.name,
|
||||
num_frames = frames.len(),
|
||||
"process_tx_queue: no call eventfd set, cannot notify guest!"
|
||||
);
|
||||
} else {
|
||||
trace!(
|
||||
vm = %self.name,
|
||||
num_frames = frames.len(),
|
||||
"process_tx_queue: call eventfd is set"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
match vring.signal_used_queue() {
|
||||
Ok(()) => trace!(
|
||||
vm = %self.name,
|
||||
num_frames = frames.len(),
|
||||
"process_tx_queue: signal_used_queue ok"
|
||||
),
|
||||
Err(e) => warn!(
|
||||
vm = %self.name,
|
||||
num_frames = frames.len(),
|
||||
error = %e,
|
||||
"process_tx_queue: signal_used_queue FAILED"
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// Log final queue state
|
||||
{
|
||||
let vring_state = vring.get_ref();
|
||||
let queue = vring_state.get_queue();
|
||||
debug!(
|
||||
vm = %self.name,
|
||||
frames_processed = results.len(),
|
||||
next_avail = queue.next_avail(),
|
||||
next_used = queue.next_used(),
|
||||
"process_tx_queue complete"
|
||||
);
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
/// Inject a frame into the RX queue.
|
||||
///
|
||||
/// Prepends virtio header and writes using scatter-gather across descriptors.
|
||||
/// Returns true if successful, false if queue is full or insufficient space.
|
||||
///
|
||||
/// Note: This is a static method that takes the destination VM's memory mapping.
|
||||
/// This is important when injecting frames into a different VM's RX queue -
|
||||
/// we must use that VM's memory mapping, not our own.
|
||||
pub fn inject_rx_frame(vring: &VringRwLock, mem: &GuestMemoryAtomic<GuestMemoryMmap>, frame: &[u8]) -> bool {
|
||||
use vm_memory::GuestMemoryLoadGuard;
|
||||
|
||||
let mem: GuestMemoryLoadGuard<GuestMemoryMmap> = mem.memory();
|
||||
|
||||
let data_to_write = Self::prepend_virtio_header(frame);
|
||||
let total_len = data_to_write.len();
|
||||
|
||||
let head_index;
|
||||
let written;
|
||||
{
|
||||
let mut vring_state = vring.get_mut();
|
||||
let queue = vring_state.get_queue_mut();
|
||||
|
||||
trace!(
|
||||
queue_ready = queue.ready(),
|
||||
queue_size = queue.size(),
|
||||
frame_len = frame.len(),
|
||||
total_len,
|
||||
"inject_rx_frame: checking queue"
|
||||
);
|
||||
|
||||
let desc_chain = match queue.pop_descriptor_chain(mem.clone()) {
|
||||
Some(chain) => chain,
|
||||
None => {
|
||||
trace!("inject_rx_frame: no descriptor available (queue full)");
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
head_index = desc_chain.head_index();
|
||||
|
||||
// First pass: collect writable descriptors and calculate available space
|
||||
let mut writable_descs = Vec::new();
|
||||
for desc in desc_chain {
|
||||
if desc.is_write_only() {
|
||||
writable_descs.push((desc.addr(), desc.len() as usize));
|
||||
}
|
||||
}
|
||||
|
||||
let available_space: usize = writable_descs.iter().map(|(_, len)| *len).sum();
|
||||
|
||||
trace!(
|
||||
head_index,
|
||||
num_writable_descs = writable_descs.len(),
|
||||
available_space,
|
||||
total_len,
|
||||
"inject_rx_frame: descriptor chain"
|
||||
);
|
||||
|
||||
if available_space < total_len {
|
||||
// Insufficient space - don't write partial data
|
||||
warn!(
|
||||
head_index,
|
||||
available_space,
|
||||
total_len,
|
||||
"inject_rx_frame: insufficient space in descriptors"
|
||||
);
|
||||
written = 0;
|
||||
} else {
|
||||
// Second pass: scatter-gather write across all descriptors
|
||||
let mut bytes_written = 0;
|
||||
for (addr, len) in writable_descs {
|
||||
let remaining = total_len - bytes_written;
|
||||
if remaining == 0 {
|
||||
break;
|
||||
}
|
||||
let to_write = std::cmp::min(remaining, len);
|
||||
|
||||
if let Err(e) = mem.write_slice(&data_to_write[bytes_written..bytes_written + to_write], addr) {
|
||||
warn!(
|
||||
head_index,
|
||||
addr = ?addr,
|
||||
to_write,
|
||||
error = %e,
|
||||
"inject_rx_frame: write_slice failed"
|
||||
);
|
||||
break;
|
||||
}
|
||||
bytes_written += to_write;
|
||||
}
|
||||
written = bytes_written;
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(e) = vring.add_used(head_index, written as u32) {
|
||||
warn!(
|
||||
head_index,
|
||||
written,
|
||||
error = ?e,
|
||||
"inject_rx_frame: add_used failed"
|
||||
);
|
||||
}
|
||||
|
||||
// Re-enable notifications so guest knows to provide more RX buffers
|
||||
if let Err(e) = vring.enable_notification() {
|
||||
warn!(
|
||||
head_index,
|
||||
error = ?e,
|
||||
"inject_rx_frame: enable_notification failed"
|
||||
);
|
||||
}
|
||||
|
||||
if let Err(e) = vring.signal_used_queue() {
|
||||
warn!(
|
||||
head_index,
|
||||
error = %e,
|
||||
"inject_rx_frame: signal_used_queue failed"
|
||||
);
|
||||
}
|
||||
|
||||
let success = written >= total_len;
|
||||
trace!(
|
||||
head_index,
|
||||
written,
|
||||
total_len,
|
||||
success,
|
||||
"inject_rx_frame complete"
|
||||
);
|
||||
success
|
||||
}
|
||||
}
|
||||
|
||||
impl VhostUserBackend for NetBackend {
|
||||
type Bitmap = ();
|
||||
type Vring = VringRwLock;
|
||||
|
||||
fn num_queues(&self) -> usize {
|
||||
NUM_QUEUES
|
||||
}
|
||||
|
||||
fn max_queue_size(&self) -> usize {
|
||||
MAX_QUEUE_SIZE
|
||||
}
|
||||
|
||||
fn features(&self) -> u64 {
|
||||
let features = VIRTIO_NET_FEATURES
|
||||
| (1 << VIRTIO_F_VERSION_1)
|
||||
| (1 << VIRTIO_RING_F_EVENT_IDX)
|
||||
| VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
|
||||
trace!(vm = %self.name, features = format!("{:#x}", features), "features requested");
|
||||
features
|
||||
}
|
||||
|
||||
fn protocol_features(&self) -> VhostUserProtocolFeatures {
|
||||
let proto = VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::MQ;
|
||||
trace!(vm = %self.name, protocol_features = ?proto, "protocol_features requested");
|
||||
proto
|
||||
}
|
||||
|
||||
fn set_event_idx(&self, enabled: bool) {
|
||||
debug!(vm = %self.name, enabled, "set_event_idx");
|
||||
*self.event_idx.lock().unwrap() = enabled;
|
||||
}
|
||||
|
||||
fn update_memory(
|
||||
&self,
|
||||
mem: GuestMemoryAtomic<GuestMemoryMmap>,
|
||||
) -> std::io::Result<()> {
|
||||
debug!(vm = %self.name, "update_memory called");
|
||||
*self.mem.lock().unwrap() = Some(mem);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_event(
|
||||
&self,
|
||||
device_event: u16,
|
||||
_evset: EventSet,
|
||||
vrings: &[VringRwLock],
|
||||
_thread_id: usize,
|
||||
) -> std::io::Result<()> {
|
||||
// Note: read_kick() is already called by VringEpollHandler before invoking this method.
|
||||
// We do not call it again to avoid blocking on a drained eventfd.
|
||||
|
||||
trace!(
|
||||
vm = %self.name,
|
||||
device_event,
|
||||
queue_name = if device_event == RX_QUEUE { "RX" } else if device_event == TX_QUEUE { "TX" } else { "?" },
|
||||
"handle_event called"
|
||||
);
|
||||
|
||||
// Validate event index
|
||||
if (device_event as usize) >= vrings.len() {
|
||||
debug!(device_event, vrings_len = vrings.len(), "ignoring out-of-range device event");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Get our connection ID
|
||||
let conn_id = match self.connection_id() {
|
||||
Some(id) => id,
|
||||
None => {
|
||||
debug!(vm = %self.name, "handle_event: no connection_id, ignoring");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
// Register our RX vring in the shared registry (if not already done)
|
||||
// This allows other backends to inject frames into our RX queue
|
||||
if vrings.len() > RX_QUEUE as usize {
|
||||
let mut registry = self.vring_registry.write().unwrap();
|
||||
if !registry.contains_key(&conn_id) {
|
||||
// Clone the memory for registry storage
|
||||
let mem_guard = self.mem.lock().unwrap();
|
||||
if let Some(mem) = mem_guard.clone() {
|
||||
debug!(
|
||||
vm = %self.name,
|
||||
conn_id = ?conn_id,
|
||||
"registering RX vring in registry"
|
||||
);
|
||||
registry.insert(conn_id, (vrings[RX_QUEUE as usize].clone(), mem));
|
||||
} else {
|
||||
warn!(
|
||||
vm = %self.name,
|
||||
conn_id = ?conn_id,
|
||||
"cannot register RX vring: no guest memory set"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only process TX queue kicks
|
||||
if device_event != TX_QUEUE {
|
||||
trace!(vm = %self.name, device_event, "ignoring non-TX queue event");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Process frames from our TX queue
|
||||
let tx_vring = &vrings[TX_QUEUE as usize];
|
||||
let processed_frames = self.process_tx_queue(tx_vring);
|
||||
|
||||
if processed_frames.is_empty() {
|
||||
trace!(vm = %self.name, "handle_event: no frames processed from TX queue");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Route each frame to its destination(s)
|
||||
let registry = self.vring_registry.read().unwrap();
|
||||
let mut routed = 0;
|
||||
let mut dropped = 0;
|
||||
|
||||
trace!(
|
||||
vm = %self.name,
|
||||
num_frames = processed_frames.len(),
|
||||
registry_size = registry.len(),
|
||||
"routing frames"
|
||||
);
|
||||
|
||||
for processed in processed_frames {
|
||||
match &processed.decision {
|
||||
ForwardDecision::Unicast(dest_id) => {
|
||||
if let Some((rx_vring, dest_mem)) = registry.get(dest_id) {
|
||||
if Self::inject_rx_frame(rx_vring, dest_mem, &processed.data) {
|
||||
routed += 1;
|
||||
} else {
|
||||
dropped += 1;
|
||||
debug!(
|
||||
src = %self.name,
|
||||
dest_id = ?dest_id,
|
||||
"RX queue full, dropping frame"
|
||||
);
|
||||
}
|
||||
} else {
|
||||
dropped += 1;
|
||||
debug!(
|
||||
src = %self.name,
|
||||
dest_id = ?dest_id,
|
||||
"destination not in registry, dropping frame"
|
||||
);
|
||||
}
|
||||
}
|
||||
ForwardDecision::Multicast(dest_ids) => {
|
||||
for dest_id in dest_ids {
|
||||
if let Some((rx_vring, dest_mem)) = registry.get(dest_id) {
|
||||
if Self::inject_rx_frame(rx_vring, dest_mem, &processed.data) {
|
||||
routed += 1;
|
||||
} else {
|
||||
dropped += 1;
|
||||
debug!(
|
||||
src = %self.name,
|
||||
dest_id = ?dest_id,
|
||||
"RX queue full, dropping frame"
|
||||
);
|
||||
}
|
||||
} else {
|
||||
dropped += 1;
|
||||
debug!(
|
||||
src = %self.name,
|
||||
dest_id = ?dest_id,
|
||||
"destination not in registry (multicast), dropping frame"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
ForwardDecision::Drop(reason) => {
|
||||
dropped += 1;
|
||||
debug!(src = %self.name, ?reason, "Dropping frame per switch decision");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trace!(vm = %self.name, routed, dropped, "handle_event complete");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn acked_features(&self, features: u64) {
|
||||
debug!(vm = %self.name, features = format!("{:#x}", features), "acked_features");
|
||||
*self.acked_features.lock().unwrap() = features;
|
||||
}
|
||||
|
||||
fn get_config(&self, offset: u32, size: u32) -> Vec<u8> {
|
||||
// Virtio net config: MAC (6 bytes) + status (2) + max_virtqueue_pairs (2)
|
||||
let mut config = [0u8; 10];
|
||||
config[0..6].copy_from_slice(&self.mac.bytes());
|
||||
config[6] = 1; // VIRTIO_NET_S_LINK_UP
|
||||
config[8] = 1; // max_virtqueue_pairs = 1
|
||||
let config = config;
|
||||
|
||||
let offset = offset as usize;
|
||||
let size = size as usize;
|
||||
if offset < config.len() {
|
||||
let end = std::cmp::min(offset + size, config.len());
|
||||
let mut result = config[offset..end].to_vec();
|
||||
// Pad with zeros if requested more than available
|
||||
result.resize(size, 0);
|
||||
result
|
||||
} else {
|
||||
vec![0u8; size]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::switch::Switch;
|
||||
|
||||
fn make_switch() -> Arc<RwLock<Switch>> {
|
||||
Arc::new(RwLock::new(Switch::new()))
|
||||
}
|
||||
|
||||
fn make_vring_registry() -> VringRegistry {
|
||||
Arc::new(RwLock::new(HashMap::new()))
|
||||
}
|
||||
|
||||
fn make_backend() -> NetBackend {
|
||||
NetBackend::new(
|
||||
"test".to_string(),
|
||||
VmRole::Client,
|
||||
Mac::from_bytes([1, 2, 3, 4, 5, 6]),
|
||||
make_switch(),
|
||||
make_vring_registry(),
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn num_queues_returns_two() {
|
||||
let backend = make_backend();
|
||||
assert_eq!(backend.num_queues(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn max_queue_size_returns_max() {
|
||||
let backend = make_backend();
|
||||
assert_eq!(backend.max_queue_size(), MAX_QUEUE_SIZE);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn register_assigns_connection_id() {
|
||||
let backend = make_backend();
|
||||
|
||||
assert!(backend.connection_id().is_none());
|
||||
let id = backend.register();
|
||||
assert!(id.is_some());
|
||||
assert_eq!(backend.connection_id(), id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unregister_clears_connection_id() {
|
||||
let backend = make_backend();
|
||||
|
||||
backend.register();
|
||||
assert!(backend.connection_id().is_some());
|
||||
|
||||
backend.unregister();
|
||||
assert!(backend.connection_id().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn duplicate_router_returns_none() {
|
||||
let switch = make_switch();
|
||||
let registry = make_vring_registry();
|
||||
|
||||
let router1 = NetBackend::new(
|
||||
"router1".to_string(),
|
||||
VmRole::Router,
|
||||
Mac::from_bytes([1, 0, 0, 0, 0, 1]),
|
||||
Arc::clone(&switch),
|
||||
Arc::clone(®istry),
|
||||
);
|
||||
let router2 = NetBackend::new(
|
||||
"router2".to_string(),
|
||||
VmRole::Router,
|
||||
Mac::from_bytes([2, 0, 0, 0, 0, 2]),
|
||||
Arc::clone(&switch),
|
||||
Arc::clone(®istry),
|
||||
);
|
||||
|
||||
assert!(router1.register().is_some());
|
||||
assert!(router2.register().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_config_returns_mac_at_offset_zero() {
|
||||
let switch = make_switch();
|
||||
let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
|
||||
let backend = NetBackend::new("test".to_string(), VmRole::Client, mac, switch, make_vring_registry());
|
||||
|
||||
let config = backend.get_config(0, 6);
|
||||
assert_eq!(config, vec![0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn get_config_partial_read() {
|
||||
let switch = make_switch();
|
||||
let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
|
||||
let backend = NetBackend::new("test".to_string(), VmRole::Client, mac, switch, make_vring_registry());
|
||||
|
||||
// Read bytes 2-4 of MAC
|
||||
let config = backend.get_config(2, 3);
|
||||
assert_eq!(config, vec![0xcc, 0xdd, 0xee]);
|
||||
}
|
||||
|
||||
fn make_frame(dest: [u8; 6], src: [u8; 6]) -> Vec<u8> {
|
||||
let mut frame = vec![0u8; 14];
|
||||
frame[0..6].copy_from_slice(&dest);
|
||||
frame[6..12].copy_from_slice(&src);
|
||||
frame
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn process_frame_returns_none_when_unregistered() {
|
||||
let backend = make_backend();
|
||||
// Don't register
|
||||
let frame = make_frame([0xff; 6], [1, 2, 3, 4, 5, 6]);
|
||||
assert!(backend.process_frame(&frame).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn process_frame_returns_decision_when_registered() {
|
||||
let switch = make_switch();
|
||||
let registry = make_vring_registry();
|
||||
|
||||
// Register router first (clients need a router to forward to)
|
||||
let router = NetBackend::new(
|
||||
"router".to_string(),
|
||||
VmRole::Router,
|
||||
Mac::from_bytes([0xaa, 0, 0, 0, 0, 1]),
|
||||
Arc::clone(&switch),
|
||||
Arc::clone(®istry),
|
||||
);
|
||||
router.register();
|
||||
|
||||
// Register client
|
||||
let client = NetBackend::new(
|
||||
"client".to_string(),
|
||||
VmRole::Client,
|
||||
Mac::from_bytes([1, 2, 3, 4, 5, 6]),
|
||||
Arc::clone(&switch),
|
||||
Arc::clone(®istry),
|
||||
);
|
||||
client.register();
|
||||
|
||||
// Client sends to router
|
||||
let frame = make_frame([0xaa, 0, 0, 0, 0, 1], [1, 2, 3, 4, 5, 6]);
|
||||
let result = client.process_frame(&frame);
|
||||
|
||||
assert!(result.is_some());
|
||||
let processed = result.unwrap();
|
||||
assert_eq!(processed.data, frame);
|
||||
assert!(matches!(processed.decision, ForwardDecision::Unicast(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn process_frame_returns_none_for_invalid_frame() {
|
||||
let switch = make_switch();
|
||||
let backend = NetBackend::new(
|
||||
"test".to_string(),
|
||||
VmRole::Client,
|
||||
Mac::from_bytes([1, 2, 3, 4, 5, 6]),
|
||||
switch,
|
||||
make_vring_registry(),
|
||||
);
|
||||
backend.register();
|
||||
|
||||
// Frame too small (< 14 bytes)
|
||||
let frame = vec![0u8; 10];
|
||||
assert!(backend.process_frame(&frame).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strip_virtio_header_returns_frame() {
|
||||
// 12-byte header + 14-byte Ethernet frame
|
||||
let mut data = vec![0u8; 26];
|
||||
data[12..18].copy_from_slice(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]); // dest MAC
|
||||
|
||||
let frame = NetBackend::strip_virtio_header(&data).unwrap();
|
||||
assert_eq!(frame.len(), 14);
|
||||
assert_eq!(&frame[0..6], &[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strip_virtio_header_returns_none_if_too_small() {
|
||||
// Less than 12 bytes
|
||||
let data = vec![0u8; 10];
|
||||
assert!(NetBackend::strip_virtio_header(&data).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn strip_virtio_header_returns_empty_slice_if_exact() {
|
||||
// Exactly 12 bytes (header only, no frame)
|
||||
let data = vec![0u8; 12];
|
||||
let frame = NetBackend::strip_virtio_header(&data).unwrap();
|
||||
assert!(frame.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prepend_virtio_header_adds_12_bytes() {
|
||||
let frame = vec![0xaa, 0xbb, 0xcc, 0xdd];
|
||||
let result = NetBackend::prepend_virtio_header(&frame);
|
||||
|
||||
assert_eq!(result.len(), 12 + 4);
|
||||
assert_eq!(&result[0..12], &[0u8; 12]); // Header is zeros
|
||||
assert_eq!(&result[12..], &[0xaa, 0xbb, 0xcc, 0xdd]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prepend_virtio_header_empty_frame() {
|
||||
let frame: Vec<u8> = vec![];
|
||||
let result = NetBackend::prepend_virtio_header(&frame);
|
||||
|
||||
assert_eq!(result.len(), 12);
|
||||
assert_eq!(&result[..], &[0u8; 12]);
|
||||
}
|
||||
}
|
||||
370
vm-switch/src/child/forwarder.rs
Normal file
370
vm-switch/src/child/forwarder.rs
Normal file
|
|
@ -0,0 +1,370 @@
|
|||
//! Packet forwarding logic for child processes.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::os::fd::{AsRawFd, RawFd};
|
||||
|
||||
use tracing::{debug, info, trace};
|
||||
|
||||
use crate::frame::validate_source_mac;
|
||||
use crate::mac::Mac;
|
||||
use crate::ring::{Consumer, Producer};
|
||||
|
||||
/// Ingress buffer from a peer (they produce, we consume).
|
||||
struct IngressBuffer {
|
||||
peer_mac: [u8; 6],
|
||||
consumer: Consumer,
|
||||
}
|
||||
|
||||
/// Egress buffer to a peer (we produce, they consume).
|
||||
struct EgressBuffer {
|
||||
peer_mac: [u8; 6],
|
||||
producer: Producer,
|
||||
/// If true, this buffer accepts broadcast/multicast traffic.
|
||||
broadcast: bool,
|
||||
}
|
||||
|
||||
/// Manages packet forwarding for a child process.
|
||||
pub struct PacketForwarder {
|
||||
/// This child's MAC address.
|
||||
our_mac: Mac,
|
||||
/// Ingress buffers FROM peers (they produce, we consume). Keyed by peer name.
|
||||
ingress: HashMap<String, IngressBuffer>,
|
||||
/// Egress buffers TO peers (we produce, they consume). Keyed by peer name.
|
||||
egress: HashMap<String, EgressBuffer>,
|
||||
}
|
||||
|
||||
impl PacketForwarder {
|
||||
/// Create a new packet forwarder.
|
||||
pub fn new(our_mac: Mac) -> Self {
|
||||
Self {
|
||||
our_mac,
|
||||
ingress: HashMap::new(),
|
||||
egress: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add an ingress buffer from a peer (we consume packets they produce).
|
||||
pub fn add_ingress(&mut self, peer_name: String, peer_mac: [u8; 6], consumer: Consumer) {
|
||||
info!(peer = %peer_name, "added ingress buffer from peer");
|
||||
self.ingress.insert(
|
||||
peer_name,
|
||||
IngressBuffer { peer_mac, consumer },
|
||||
);
|
||||
}
|
||||
|
||||
/// Add an egress buffer to a peer (we produce packets they consume).
|
||||
pub fn add_egress(
|
||||
&mut self,
|
||||
peer_name: String,
|
||||
peer_mac: [u8; 6],
|
||||
producer: Producer,
|
||||
broadcast: bool,
|
||||
) {
|
||||
info!(peer = %peer_name, broadcast, "added egress buffer to peer");
|
||||
self.egress.insert(
|
||||
peer_name,
|
||||
EgressBuffer {
|
||||
peer_mac,
|
||||
producer,
|
||||
broadcast,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/// Remove all buffers for a peer.
|
||||
pub fn remove_peer(&mut self, peer_name: &str) {
|
||||
if self.ingress.remove(peer_name).is_some() {
|
||||
info!(peer = %peer_name, "removed ingress buffer");
|
||||
}
|
||||
if self.egress.remove(peer_name).is_some() {
|
||||
info!(peer = %peer_name, "removed egress buffer");
|
||||
}
|
||||
}
|
||||
|
||||
/// Get eventfds for all ingress consumers (for polling).
|
||||
pub fn ingress_eventfds(&self) -> Vec<(RawFd, [u8; 6])> {
|
||||
self.ingress
|
||||
.values()
|
||||
.map(|buf| (buf.consumer.eventfd().as_raw_fd(), buf.peer_mac))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Forward a TX frame to peers based on destination MAC.
|
||||
///
|
||||
/// - Broadcast/multicast: sent to all egress buffers with broadcast=true
|
||||
/// - Unicast: sent to the egress buffer matching the destination MAC
|
||||
pub fn forward_tx(&self, frame: &[u8]) -> bool {
|
||||
// Validate source MAC
|
||||
if !validate_source_mac(frame, self.our_mac) {
|
||||
debug!(
|
||||
reason = "source MAC rejected",
|
||||
our_mac = %self.our_mac,
|
||||
frame_src = %Mac::from_bytes(frame[6..12].try_into().unwrap()),
|
||||
size = frame.len(),
|
||||
"TX: dropped"
|
||||
);
|
||||
return false;
|
||||
}
|
||||
|
||||
let dest_mac: [u8; 6] = frame[0..6].try_into().unwrap();
|
||||
let dest = Mac::from_bytes(dest_mac);
|
||||
|
||||
// Broadcast/multicast: send to all egress buffers with broadcast=true
|
||||
if dest.is_broadcast() || dest.is_multicast() {
|
||||
let mut sent = false;
|
||||
for (peer_name, egress) in &self.egress {
|
||||
if egress.broadcast {
|
||||
if egress.producer.push(frame) {
|
||||
trace!(
|
||||
to = %peer_name,
|
||||
mac = %Mac::from_bytes(egress.peer_mac),
|
||||
size = frame.len(),
|
||||
"TX: broadcast"
|
||||
);
|
||||
sent = true;
|
||||
} else {
|
||||
debug!(reason = "buffer full", to = %peer_name, size = frame.len(), "TX: broadcast dropped");
|
||||
}
|
||||
}
|
||||
}
|
||||
if !sent {
|
||||
debug!(reason = "no broadcast peers", size = frame.len(), "TX: dropped");
|
||||
}
|
||||
return sent;
|
||||
}
|
||||
|
||||
// Unicast: find egress buffer by destination MAC
|
||||
for (peer_name, egress) in &self.egress {
|
||||
if egress.peer_mac == dest_mac {
|
||||
if egress.producer.push(frame) {
|
||||
trace!(
|
||||
to = %peer_name,
|
||||
mac = %Mac::from_bytes(egress.peer_mac),
|
||||
size = frame.len(),
|
||||
"TX: pushed to egress buffer"
|
||||
);
|
||||
return true;
|
||||
} else {
|
||||
debug!(reason = "buffer full", to = %peer_name, size = frame.len(), "TX: dropped");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unknown destination
|
||||
debug!(
|
||||
reason = "unknown destination",
|
||||
dest_mac = %dest,
|
||||
size = frame.len(),
|
||||
"TX: dropped"
|
||||
);
|
||||
false
|
||||
}
|
||||
|
||||
/// Poll all ingress buffers and return received frames.
|
||||
///
|
||||
/// Validates source MAC matches the expected peer for each buffer.
|
||||
pub fn poll_ingress(&self) -> Vec<Vec<u8>> {
|
||||
let mut frames = Vec::new();
|
||||
|
||||
for (peer_name, ingress) in &self.ingress {
|
||||
// Drain the eventfd
|
||||
ingress.consumer.drain_eventfd();
|
||||
|
||||
// Pop all available frames
|
||||
while let Some(frame) = ingress.consumer.pop() {
|
||||
// Validate source MAC matches expected peer
|
||||
if !validate_source_mac(&frame, Mac::from_bytes(ingress.peer_mac)) {
|
||||
debug!(
|
||||
reason = "source MAC mismatch",
|
||||
from = %peer_name,
|
||||
expected = %Mac::from_bytes(ingress.peer_mac),
|
||||
actual = %Mac::from_bytes(frame[6..12].try_into().unwrap_or([0; 6])),
|
||||
size = frame.len(),
|
||||
"RX: dropped"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
trace!(
|
||||
from = %peer_name,
|
||||
mac = %Mac::from_bytes(ingress.peer_mac),
|
||||
size = frame.len(),
|
||||
"RX: read from ingress buffer"
|
||||
);
|
||||
frames.push(frame);
|
||||
}
|
||||
}
|
||||
|
||||
frames
|
||||
}
|
||||
|
||||
/// Get number of configured ingress peers.
|
||||
pub fn ingress_count(&self) -> usize {
|
||||
self.ingress.len()
|
||||
}
|
||||
|
||||
/// Get number of configured egress peers.
|
||||
pub fn egress_count(&self) -> usize {
|
||||
self.egress.len()
|
||||
}
|
||||
|
||||
/// Get our MAC address.
|
||||
pub fn our_mac(&self) -> Mac {
|
||||
self.our_mac
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::os::fd::{FromRawFd, OwnedFd};
|
||||
|
||||
fn make_frame(dest: [u8; 6], src: [u8; 6]) -> Vec<u8> {
|
||||
let mut frame = vec![0u8; 14];
|
||||
frame[0..6].copy_from_slice(&dest);
|
||||
frame[6..12].copy_from_slice(&src);
|
||||
frame
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_tx_validates_source_mac() {
|
||||
let our_mac = Mac::from_bytes([1, 2, 3, 4, 5, 6]);
|
||||
let forwarder = PacketForwarder::new(our_mac);
|
||||
|
||||
// Frame with wrong source MAC - should be dropped
|
||||
let frame = make_frame([0xff; 6], [9, 9, 9, 9, 9, 9]);
|
||||
assert!(!forwarder.forward_tx(&frame));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_tx_drops_when_no_egress() {
|
||||
let our_mac = Mac::from_bytes([1, 2, 3, 4, 5, 6]);
|
||||
let forwarder = PacketForwarder::new(our_mac);
|
||||
|
||||
// Frame with correct source MAC but no egress peers
|
||||
let frame = make_frame([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff], our_mac.bytes());
|
||||
assert!(!forwarder.forward_tx(&frame));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_tx_unicast_to_matching_peer() {
|
||||
let our_mac = Mac::from_bytes([1, 2, 3, 4, 5, 6]);
|
||||
let mut forwarder = PacketForwarder::new(our_mac);
|
||||
|
||||
// Add egress to peer with specific MAC
|
||||
let peer_mac = [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff];
|
||||
let consumer = Consumer::new().expect("consumer");
|
||||
let producer = Producer::from_fds(
|
||||
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer.memfd().as_raw_fd())) },
|
||||
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer.eventfd().as_raw_fd())) },
|
||||
)
|
||||
.expect("producer");
|
||||
forwarder.add_egress("router".to_string(), peer_mac, producer, false);
|
||||
|
||||
// Unicast frame to that MAC should succeed
|
||||
let frame = make_frame(peer_mac, our_mac.bytes());
|
||||
assert!(forwarder.forward_tx(&frame));
|
||||
|
||||
// Consumer should receive the frame
|
||||
consumer.drain_eventfd();
|
||||
let received = consumer.pop().expect("should have frame");
|
||||
assert_eq!(received, frame);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_tx_broadcast_to_broadcast_peers() {
|
||||
let our_mac = Mac::from_bytes([1, 2, 3, 4, 5, 6]);
|
||||
let mut forwarder = PacketForwarder::new(our_mac);
|
||||
|
||||
// Add broadcast-enabled egress
|
||||
let consumer1 = Consumer::new().expect("consumer1");
|
||||
let producer1 = Producer::from_fds(
|
||||
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer1.memfd().as_raw_fd())) },
|
||||
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer1.eventfd().as_raw_fd())) },
|
||||
)
|
||||
.expect("producer1");
|
||||
forwarder.add_egress("router".to_string(), [0x11; 6], producer1, true);
|
||||
|
||||
// Add non-broadcast egress
|
||||
let consumer2 = Consumer::new().expect("consumer2");
|
||||
let producer2 = Producer::from_fds(
|
||||
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer2.memfd().as_raw_fd())) },
|
||||
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer2.eventfd().as_raw_fd())) },
|
||||
)
|
||||
.expect("producer2");
|
||||
forwarder.add_egress("client_a".to_string(), [0x22; 6], producer2, false);
|
||||
|
||||
// Broadcast frame
|
||||
let frame = make_frame([0xff; 6], our_mac.bytes());
|
||||
assert!(forwarder.forward_tx(&frame));
|
||||
|
||||
// Only broadcast-enabled peer should receive
|
||||
consumer1.drain_eventfd();
|
||||
assert!(consumer1.pop().is_some());
|
||||
|
||||
consumer2.drain_eventfd();
|
||||
assert!(consumer2.pop().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn poll_ingress_validates_source_mac() {
|
||||
let our_mac = Mac::from_bytes([1, 2, 3, 4, 5, 6]);
|
||||
let mut forwarder = PacketForwarder::new(our_mac);
|
||||
|
||||
// Create producer/consumer pair - producer simulates peer sending to us
|
||||
let producer = Producer::new().expect("producer");
|
||||
let consumer = Consumer::from_fds(
|
||||
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(producer.memfd().as_raw_fd())) },
|
||||
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(producer.eventfd().as_raw_fd())) },
|
||||
)
|
||||
.expect("consumer");
|
||||
|
||||
let peer_mac = [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff];
|
||||
forwarder.add_ingress("router".to_string(), peer_mac, consumer);
|
||||
|
||||
// Push frame with correct source MAC
|
||||
let good_frame = make_frame(our_mac.bytes(), peer_mac);
|
||||
producer.push(&good_frame);
|
||||
|
||||
// Push frame with wrong source MAC
|
||||
let bad_frame = make_frame(our_mac.bytes(), [0x99; 6]);
|
||||
producer.push(&bad_frame);
|
||||
|
||||
// Only good frame should be returned
|
||||
let frames = forwarder.poll_ingress();
|
||||
assert_eq!(frames.len(), 1);
|
||||
assert_eq!(frames[0], good_frame);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remove_peer_cleans_up_buffers() {
|
||||
let our_mac = Mac::from_bytes([1, 2, 3, 4, 5, 6]);
|
||||
let mut forwarder = PacketForwarder::new(our_mac);
|
||||
|
||||
// Add ingress
|
||||
let producer = Producer::new().expect("producer");
|
||||
let consumer = Consumer::from_fds(
|
||||
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(producer.memfd().as_raw_fd())) },
|
||||
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(producer.eventfd().as_raw_fd())) },
|
||||
)
|
||||
.expect("consumer");
|
||||
forwarder.add_ingress("router".to_string(), [0x11; 6], consumer);
|
||||
|
||||
// Add egress
|
||||
let consumer2 = Consumer::new().expect("consumer2");
|
||||
let producer2 = Producer::from_fds(
|
||||
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer2.memfd().as_raw_fd())) },
|
||||
unsafe { OwnedFd::from_raw_fd(nix::libc::dup(consumer2.eventfd().as_raw_fd())) },
|
||||
)
|
||||
.expect("producer2");
|
||||
forwarder.add_egress("router".to_string(), [0x11; 6], producer2, true);
|
||||
|
||||
assert_eq!(forwarder.ingress_count(), 1);
|
||||
assert_eq!(forwarder.egress_count(), 1);
|
||||
|
||||
forwarder.remove_peer("router");
|
||||
|
||||
assert_eq!(forwarder.ingress_count(), 0);
|
||||
assert_eq!(forwarder.egress_count(), 0);
|
||||
}
|
||||
}
|
||||
14
vm-switch/src/child/mod.rs
Normal file
14
vm-switch/src/child/mod.rs
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
//! Child process entry point for VM backends.
|
||||
//!
|
||||
//! Each VM runs in its own forked child process, communicating with
|
||||
//! the main process via a control channel.
|
||||
|
||||
pub mod forwarder;
|
||||
pub mod poll;
|
||||
pub mod process;
|
||||
pub mod vhost;
|
||||
|
||||
pub use forwarder::PacketForwarder;
|
||||
pub use poll::{poll_events, PollResult};
|
||||
pub use process::run_child_process;
|
||||
pub use vhost::ChildVhostBackend;
|
||||
115
vm-switch/src/child/poll.rs
Normal file
115
vm-switch/src/child/poll.rs
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
//! Event polling for child processes.
|
||||
|
||||
use std::os::fd::{BorrowedFd, RawFd};
|
||||
|
||||
use nix::poll::{poll, PollFd, PollFlags, PollTimeout};
|
||||
use nix::Error as NixError;
|
||||
|
||||
/// Result of polling for events.
|
||||
#[derive(Debug)]
|
||||
pub enum PollResult {
|
||||
/// Control channel has data.
|
||||
Control,
|
||||
/// One or more ingress buffers have data.
|
||||
Ingress(Vec<[u8; 6]>),
|
||||
/// Second FD slot has an event (POLLIN/POLLHUP/POLLERR).
|
||||
/// Used for daemon exit pipe detection.
|
||||
TxKick,
|
||||
/// Timeout expired with no events.
|
||||
Timeout,
|
||||
/// Error occurred.
|
||||
Error(NixError),
|
||||
}
|
||||
|
||||
/// Poll for events on control channel, TX kick, and ingress eventfds.
|
||||
///
|
||||
/// # Safety
|
||||
/// All raw file descriptors passed in must be valid and open for the duration of this call.
|
||||
pub fn poll_events(
|
||||
control_fd: RawFd,
|
||||
tx_kick_fd: Option<RawFd>,
|
||||
ingress_fds: &[(RawFd, [u8; 6])],
|
||||
timeout_ms: i32,
|
||||
) -> PollResult {
|
||||
// SAFETY: caller guarantees fds are valid for the duration of this call
|
||||
let control_borrowed = unsafe { BorrowedFd::borrow_raw(control_fd) };
|
||||
|
||||
let mut pollfds: Vec<PollFd> = Vec::with_capacity(2 + ingress_fds.len());
|
||||
|
||||
// Control channel is first
|
||||
pollfds.push(PollFd::new(control_borrowed, PollFlags::POLLIN));
|
||||
|
||||
// TX kick eventfd is second (if present)
|
||||
let tx_index = if let Some(fd) = tx_kick_fd {
|
||||
let tx_borrowed = unsafe { BorrowedFd::borrow_raw(fd) };
|
||||
pollfds.push(PollFd::new(tx_borrowed, PollFlags::POLLIN));
|
||||
Some(1)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Ingress eventfds follow
|
||||
let ingress_start = pollfds.len();
|
||||
for (fd, _mac) in ingress_fds {
|
||||
let ingress_borrowed = unsafe { BorrowedFd::borrow_raw(*fd) };
|
||||
pollfds.push(PollFd::new(ingress_borrowed, PollFlags::POLLIN));
|
||||
}
|
||||
|
||||
let timeout = if timeout_ms < 0 {
|
||||
PollTimeout::NONE
|
||||
} else if timeout_ms > u16::MAX as i32 {
|
||||
PollTimeout::MAX
|
||||
} else {
|
||||
PollTimeout::from(timeout_ms as u16)
|
||||
};
|
||||
|
||||
match poll(&mut pollfds, timeout) {
|
||||
Ok(0) => PollResult::Timeout,
|
||||
Ok(_) => {
|
||||
// Check control channel first (priority)
|
||||
if let Some(revents) = pollfds[0].revents() {
|
||||
if revents.contains(PollFlags::POLLIN)
|
||||
|| revents.contains(PollFlags::POLLHUP)
|
||||
|| revents.contains(PollFlags::POLLERR)
|
||||
{
|
||||
return PollResult::Control;
|
||||
}
|
||||
}
|
||||
|
||||
// Check TX kick / daemon exit pipe
|
||||
if let Some(idx) = tx_index {
|
||||
if let Some(revents) = pollfds[idx].revents() {
|
||||
if revents.contains(PollFlags::POLLIN)
|
||||
|| revents.contains(PollFlags::POLLHUP)
|
||||
|| revents.contains(PollFlags::POLLERR)
|
||||
{
|
||||
return PollResult::TxKick;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check ingress eventfds
|
||||
let mut ready_macs = Vec::new();
|
||||
for (i, (_, mac)) in ingress_fds.iter().enumerate() {
|
||||
if let Some(revents) = pollfds[ingress_start + i].revents() {
|
||||
if revents.contains(PollFlags::POLLIN) {
|
||||
ready_macs.push(*mac);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ready_macs.is_empty() {
|
||||
PollResult::Timeout
|
||||
} else {
|
||||
PollResult::Ingress(ready_macs)
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
if e == NixError::EINTR {
|
||||
PollResult::Timeout
|
||||
} else {
|
||||
PollResult::Error(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
239
vm-switch/src/child/process.rs
Normal file
239
vm-switch/src/child/process.rs
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
//! Child process main loop.
|
||||
|
||||
use std::os::fd::{AsRawFd, OwnedFd, RawFd};
|
||||
use std::os::unix::net::UnixListener;
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::thread;
|
||||
|
||||
use nix::unistd::pipe;
|
||||
use tracing::{debug, error, info, warn};
|
||||
use vhost_user_backend::VhostUserDaemon;
|
||||
use vm_memory::{GuestMemoryAtomic, GuestMemoryMmap};
|
||||
|
||||
use crate::control::{ChildToMain, ControlChannel, ControlError, MainToChild};
|
||||
use crate::mac::Mac;
|
||||
use crate::ring::{Consumer, Producer};
|
||||
use crate::seccomp::{apply_child_seccomp, SeccompMode};
|
||||
|
||||
use super::forwarder::PacketForwarder;
|
||||
use super::poll::{poll_events, PollResult};
|
||||
use super::vhost::ChildVhostBackend;
|
||||
|
||||
/// Run the child process.
|
||||
///
|
||||
/// This is the entry point after fork(). Does not return.
|
||||
pub fn run_child_process(
|
||||
vm_name: &str,
|
||||
mac: Mac,
|
||||
control_fd: OwnedFd,
|
||||
socket_path: &Path,
|
||||
seccomp_mode: SeccompMode,
|
||||
) -> ! {
|
||||
// Set process name for log prefix before any logging
|
||||
crate::args::set_process_name(format!("worker-{}", vm_name));
|
||||
|
||||
info!(vm = %vm_name, mac = %mac, socket = ?socket_path, "child starting");
|
||||
|
||||
// Reconstruct control channel from owned fd
|
||||
let control = ControlChannel::from_fd(control_fd);
|
||||
|
||||
// Send Ready to main
|
||||
let msg = ChildToMain::Ready;
|
||||
if let Err(e) = control.send(&msg) {
|
||||
error!(vm = %vm_name, error = %e, "failed to send Ready");
|
||||
std::process::exit(1)
|
||||
}
|
||||
debug!("control: worker-{} -> main Ready", vm_name);
|
||||
|
||||
// Create packet forwarder
|
||||
let forwarder = Arc::new(Mutex::new(PacketForwarder::new(mac)));
|
||||
|
||||
// Create vhost backend
|
||||
let backend = ChildVhostBackend::new(vm_name.to_string(), mac);
|
||||
|
||||
// Set TX callback
|
||||
let fwd = Arc::clone(&forwarder);
|
||||
backend.set_tx_callback(Box::new(move |frame| {
|
||||
fwd.lock().unwrap().forward_tx(frame);
|
||||
}));
|
||||
|
||||
// Create vhost socket
|
||||
if socket_path.exists() {
|
||||
let _ = std::fs::remove_file(socket_path);
|
||||
}
|
||||
let listener = match UnixListener::bind(socket_path) {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
error!(vm = %vm_name, error = %e, "failed to bind socket");
|
||||
std::process::exit(1)
|
||||
}
|
||||
};
|
||||
let _ = listener.set_nonblocking(true);
|
||||
|
||||
// Start vhost daemon thread
|
||||
let mem = GuestMemoryAtomic::new(GuestMemoryMmap::<()>::new());
|
||||
let mut daemon = match VhostUserDaemon::new(vm_name.to_string(), backend.clone(), mem) {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
error!(vm = %vm_name, error = %e, "failed to create daemon");
|
||||
std::process::exit(1)
|
||||
}
|
||||
};
|
||||
|
||||
// Create pipe to detect daemon thread exit. The write end is moved into
|
||||
// the daemon thread; when the thread exits for any reason, the write end
|
||||
// is dropped, causing POLLHUP on the read end.
|
||||
let (pipe_rd, pipe_wr) = match pipe() {
|
||||
Ok((rd, wr)) => (rd, wr),
|
||||
Err(e) => {
|
||||
error!(vm = %vm_name, error = %e, "failed to create pipe");
|
||||
std::process::exit(1)
|
||||
}
|
||||
};
|
||||
|
||||
let vhost_listener = vhost::vhost_user::Listener::from(listener);
|
||||
let name = vm_name.to_string();
|
||||
thread::spawn(move || {
|
||||
let _pipe_wr = pipe_wr; // dropped on thread exit → POLLHUP on read end
|
||||
let mut l = vhost_listener;
|
||||
if let Err(e) = daemon.start(&mut l) {
|
||||
warn!(vm = %name, error = %e, "daemon start failed");
|
||||
return;
|
||||
}
|
||||
if let Err(e) = daemon.wait() {
|
||||
debug!(vm = %name, error = %e, "daemon wait returned error");
|
||||
}
|
||||
});
|
||||
|
||||
// Apply seccomp filter now that setup is complete
|
||||
// (socket created, thread spawned, signals configured)
|
||||
if let Err(e) = apply_child_seccomp(seccomp_mode) {
|
||||
error!(vm = %vm_name, error = %e, "failed to apply seccomp");
|
||||
std::process::exit(1);
|
||||
}
|
||||
if seccomp_mode != SeccompMode::Disabled {
|
||||
debug!(vm = %vm_name, mode = ?seccomp_mode, "seccomp filter applied");
|
||||
}
|
||||
|
||||
// Main event loop
|
||||
let daemon_exit_fd = pipe_rd.as_raw_fd();
|
||||
match event_loop(vm_name, control, forwarder, backend, daemon_exit_fd) {
|
||||
Ok(()) => {
|
||||
info!(vm = %vm_name, "exiting normally");
|
||||
std::process::exit(0)
|
||||
}
|
||||
Err(e) => {
|
||||
error!(vm = %vm_name, error = %e, "exiting with error");
|
||||
std::process::exit(1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn event_loop(
|
||||
vm_name: &str,
|
||||
control: ControlChannel,
|
||||
forwarder: Arc<Mutex<PacketForwarder>>,
|
||||
backend: Arc<ChildVhostBackend>,
|
||||
daemon_exit_fd: RawFd,
|
||||
) -> Result<(), ControlError> {
|
||||
let control_fd = control.as_raw_fd();
|
||||
|
||||
loop {
|
||||
let ingress_fds = forwarder.lock().unwrap().ingress_eventfds();
|
||||
|
||||
match poll_events(control_fd, Some(daemon_exit_fd), &ingress_fds, 100) {
|
||||
PollResult::TxKick => {
|
||||
// Daemon thread exited (pipe write end closed → POLLHUP)
|
||||
info!(vm = %vm_name, "vhost daemon exited, shutting down");
|
||||
return Ok(());
|
||||
}
|
||||
PollResult::Control => {
|
||||
let (msg, fds) = match control.recv_with_fds_typed() {
|
||||
Ok(r) => r,
|
||||
Err(ControlError::Closed) => {
|
||||
debug!(vm = %vm_name, "control closed");
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
|
||||
match msg {
|
||||
MainToChild::GetBuffer { peer_name, peer_mac } => {
|
||||
debug!(
|
||||
"control: main -> worker-{} GetBuffer({}, {})",
|
||||
vm_name, peer_name, Mac::from_bytes(peer_mac)
|
||||
);
|
||||
|
||||
// Create ingress buffer (we are Consumer)
|
||||
match Consumer::new() {
|
||||
Ok(consumer) => {
|
||||
let response_fds = [
|
||||
consumer.memfd().as_raw_fd(),
|
||||
consumer.eventfd().as_raw_fd(),
|
||||
];
|
||||
let response = ChildToMain::BufferReady {
|
||||
peer_name: peer_name.clone(),
|
||||
};
|
||||
if let Err(e) = control.send_with_fds_typed(&response, &response_fds) {
|
||||
warn!(vm = %vm_name, error = %e, "failed to send BufferReady");
|
||||
} else {
|
||||
debug!(
|
||||
"control: worker-{} -> main BufferReady({})",
|
||||
vm_name, peer_name
|
||||
);
|
||||
forwarder.lock().unwrap().add_ingress(peer_name, peer_mac, consumer);
|
||||
}
|
||||
}
|
||||
Err(e) => warn!(vm = %vm_name, error = %e, "failed to create ingress buffer"),
|
||||
}
|
||||
}
|
||||
|
||||
MainToChild::PutBuffer { peer_name, peer_mac, broadcast } => {
|
||||
debug!(
|
||||
"control: main -> worker-{} PutBuffer({}, {}, broadcast={})",
|
||||
vm_name, peer_name, Mac::from_bytes(peer_mac), broadcast
|
||||
);
|
||||
|
||||
if fds.len() == 2 {
|
||||
let mut fds = fds.into_iter();
|
||||
match Producer::from_fds(fds.next().unwrap(), fds.next().unwrap()) {
|
||||
Ok(producer) => {
|
||||
forwarder.lock().unwrap().add_egress(
|
||||
peer_name,
|
||||
peer_mac,
|
||||
producer,
|
||||
broadcast,
|
||||
);
|
||||
}
|
||||
Err(e) => warn!(vm = %vm_name, error = %e, "failed to map egress buffer"),
|
||||
}
|
||||
} else {
|
||||
warn!(vm = %vm_name, "PutBuffer with wrong number of FDs: {}", fds.len());
|
||||
}
|
||||
}
|
||||
|
||||
MainToChild::RemovePeer { peer_name } => {
|
||||
debug!("control: main -> worker-{} RemovePeer({})", vm_name, peer_name);
|
||||
forwarder.lock().unwrap().remove_peer(&peer_name);
|
||||
}
|
||||
|
||||
MainToChild::Ping => {
|
||||
control.send(&ChildToMain::Pong)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
PollResult::Ingress(_) | PollResult::Timeout => {
|
||||
let frames = forwarder.lock().unwrap().poll_ingress();
|
||||
for frame in frames {
|
||||
if !backend.inject_rx_frame(&frame) {
|
||||
debug!(vm = %vm_name, "RX inject failed (queue full)");
|
||||
}
|
||||
}
|
||||
}
|
||||
PollResult::Error(e) => {
|
||||
warn!(vm = %vm_name, error = ?e, "poll error");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
283
vm-switch/src/child/vhost.rs
Normal file
283
vm-switch/src/child/vhost.rs
Normal file
|
|
@ -0,0 +1,283 @@
|
|||
//! Vhost-user backend for child processes.
|
||||
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use tracing::{debug, warn};
|
||||
use vhost::vhost_user::message::{VhostUserProtocolFeatures, VhostUserVirtioFeatures};
|
||||
use vhost_user_backend::{VhostUserBackend, VringRwLock, VringT};
|
||||
use virtio_bindings::virtio_config::VIRTIO_F_VERSION_1;
|
||||
use virtio_bindings::virtio_net::{
|
||||
VIRTIO_NET_F_CSUM, VIRTIO_NET_F_GUEST_CSUM, VIRTIO_NET_F_GUEST_TSO4,
|
||||
VIRTIO_NET_F_GUEST_TSO6, VIRTIO_NET_F_GUEST_UFO, VIRTIO_NET_F_HOST_TSO4,
|
||||
VIRTIO_NET_F_HOST_TSO6, VIRTIO_NET_F_HOST_UFO, VIRTIO_NET_F_MAC, VIRTIO_NET_F_STATUS,
|
||||
};
|
||||
use virtio_bindings::virtio_ring::VIRTIO_RING_F_EVENT_IDX;
|
||||
use virtio_queue::QueueT;
|
||||
use vm_memory::{Bytes, GuestAddressSpace, GuestMemoryAtomic, GuestMemoryMmap};
|
||||
use vmm_sys_util::epoll::EventSet;
|
||||
|
||||
use crate::mac::Mac;
|
||||
|
||||
/// RX queue index.
|
||||
pub const RX_QUEUE: u16 = 0;
|
||||
/// TX queue index.
|
||||
pub const TX_QUEUE: u16 = 1;
|
||||
/// Number of queues.
|
||||
pub const NUM_QUEUES: usize = 2;
|
||||
/// Maximum queue size.
|
||||
pub const MAX_QUEUE_SIZE: usize = 32768;
|
||||
/// Virtio-net header size.
|
||||
pub const VIRTIO_NET_HDR_SIZE: usize = 12;
|
||||
|
||||
/// Virtio net features.
|
||||
const VIRTIO_NET_FEATURES: u64 = (1 << VIRTIO_NET_F_CSUM)
|
||||
| (1 << VIRTIO_NET_F_GUEST_CSUM)
|
||||
| (1 << VIRTIO_NET_F_GUEST_TSO4)
|
||||
| (1 << VIRTIO_NET_F_GUEST_TSO6)
|
||||
| (1 << VIRTIO_NET_F_GUEST_UFO)
|
||||
| (1 << VIRTIO_NET_F_HOST_TSO4)
|
||||
| (1 << VIRTIO_NET_F_HOST_TSO6)
|
||||
| (1 << VIRTIO_NET_F_HOST_UFO)
|
||||
| (1 << VIRTIO_NET_F_MAC)
|
||||
| (1 << VIRTIO_NET_F_STATUS);
|
||||
|
||||
/// Callback type for TX frames.
|
||||
pub type TxCallback = Box<dyn Fn(&[u8]) + Send>;
|
||||
|
||||
/// Child's vhost-user backend.
|
||||
pub struct ChildVhostBackend {
|
||||
name: String,
|
||||
mac: Mac,
|
||||
mem: Mutex<Option<GuestMemoryAtomic<GuestMemoryMmap>>>,
|
||||
tx_callback: Mutex<Option<TxCallback>>,
|
||||
rx_vring: Mutex<Option<VringRwLock>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for ChildVhostBackend {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ChildVhostBackend")
|
||||
.field("name", &self.name)
|
||||
.field("mac", &self.mac)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
impl ChildVhostBackend {
|
||||
/// Create a new backend.
|
||||
pub fn new(name: String, mac: Mac) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
name,
|
||||
mac,
|
||||
mem: Mutex::new(None),
|
||||
tx_callback: Mutex::new(None),
|
||||
rx_vring: Mutex::new(None),
|
||||
})
|
||||
}
|
||||
|
||||
/// Set the TX callback.
|
||||
pub fn set_tx_callback(&self, callback: TxCallback) {
|
||||
*self.tx_callback.lock().unwrap() = Some(callback);
|
||||
}
|
||||
|
||||
/// Inject a frame into the RX queue for the guest.
|
||||
///
|
||||
/// Returns `true` if the frame was written to the virtio RX queue, `false` if
|
||||
/// dropped. Frames are silently dropped before the guest driver has initialized
|
||||
/// virtio queues (RX vring not yet available). This is expected during early
|
||||
/// startup -- the guest isn't ready to receive traffic until the driver
|
||||
/// negotiates features and sets up vrings.
|
||||
pub fn inject_rx_frame(&self, frame: &[u8]) -> bool {
|
||||
let vring = match self.rx_vring.lock().unwrap().as_ref() {
|
||||
Some(v) => v.clone(),
|
||||
None => return false,
|
||||
};
|
||||
|
||||
let mem_guard = self.mem.lock().unwrap();
|
||||
let mem = match mem_guard.as_ref() {
|
||||
Some(m) => m.memory(),
|
||||
None => return false,
|
||||
};
|
||||
|
||||
// Prepend virtio header
|
||||
let mut data = vec![0u8; VIRTIO_NET_HDR_SIZE + frame.len()];
|
||||
data[VIRTIO_NET_HDR_SIZE..].copy_from_slice(frame);
|
||||
|
||||
let head_index;
|
||||
let written;
|
||||
{
|
||||
let mut vring_state = vring.get_mut();
|
||||
let queue = vring_state.get_queue_mut();
|
||||
|
||||
let desc_chain = match queue.pop_descriptor_chain(mem.clone()) {
|
||||
Some(c) => c,
|
||||
None => return false,
|
||||
};
|
||||
|
||||
head_index = desc_chain.head_index();
|
||||
|
||||
let mut writable_descs = Vec::new();
|
||||
for desc in desc_chain {
|
||||
if desc.is_write_only() {
|
||||
writable_descs.push((desc.addr(), desc.len() as usize));
|
||||
}
|
||||
}
|
||||
|
||||
let available: usize = writable_descs.iter().map(|(_, l)| *l).sum();
|
||||
if available < data.len() {
|
||||
written = 0;
|
||||
} else {
|
||||
let mut bytes_written = 0;
|
||||
for (addr, len) in writable_descs {
|
||||
let remaining = data.len() - bytes_written;
|
||||
if remaining == 0 {
|
||||
break;
|
||||
}
|
||||
let to_write = std::cmp::min(remaining, len);
|
||||
if mem.write_slice(&data[bytes_written..bytes_written + to_write], addr).is_err() {
|
||||
break;
|
||||
}
|
||||
bytes_written += to_write;
|
||||
}
|
||||
written = bytes_written;
|
||||
}
|
||||
}
|
||||
|
||||
let _ = vring.add_used(head_index, written as u32);
|
||||
let _ = vring.enable_notification();
|
||||
let _ = vring.signal_used_queue();
|
||||
|
||||
written >= data.len()
|
||||
}
|
||||
|
||||
/// Process TX queue and call callback for each frame.
|
||||
fn process_tx(&self, vring: &VringRwLock) {
|
||||
let mem_guard = self.mem.lock().unwrap();
|
||||
let mem = match mem_guard.as_ref() {
|
||||
Some(m) => m.memory(),
|
||||
None => return,
|
||||
};
|
||||
|
||||
let callback = self.tx_callback.lock().unwrap();
|
||||
let callback = match callback.as_ref() {
|
||||
Some(c) => c,
|
||||
None => return,
|
||||
};
|
||||
|
||||
loop {
|
||||
let head_index;
|
||||
let raw_data;
|
||||
{
|
||||
let mut vring_state = vring.get_mut();
|
||||
let queue = vring_state.get_queue_mut();
|
||||
|
||||
let desc_chain = match queue.pop_descriptor_chain(mem.clone()) {
|
||||
Some(c) => c,
|
||||
None => break,
|
||||
};
|
||||
|
||||
head_index = desc_chain.head_index();
|
||||
let mut data = Vec::new();
|
||||
|
||||
for desc in desc_chain {
|
||||
let addr = desc.addr();
|
||||
let len = desc.len() as usize;
|
||||
let mut buf = vec![0u8; len];
|
||||
if let Err(e) = mem.read_slice(&mut buf, addr) {
|
||||
warn!(vm = %self.name, error = %e, "failed to read descriptor");
|
||||
break;
|
||||
}
|
||||
data.extend_from_slice(&buf);
|
||||
}
|
||||
raw_data = data;
|
||||
}
|
||||
|
||||
if let Err(e) = vring.add_used(head_index, 0) {
|
||||
warn!(vm = %self.name, error = ?e, "add_used failed");
|
||||
}
|
||||
|
||||
// Strip virtio header and call callback
|
||||
if raw_data.len() > VIRTIO_NET_HDR_SIZE {
|
||||
callback(&raw_data[VIRTIO_NET_HDR_SIZE..]);
|
||||
}
|
||||
}
|
||||
|
||||
let _ = vring.enable_notification();
|
||||
let _ = vring.signal_used_queue();
|
||||
}
|
||||
}
|
||||
|
||||
impl VhostUserBackend for ChildVhostBackend {
|
||||
type Bitmap = ();
|
||||
type Vring = VringRwLock;
|
||||
|
||||
fn num_queues(&self) -> usize {
|
||||
NUM_QUEUES
|
||||
}
|
||||
|
||||
fn max_queue_size(&self) -> usize {
|
||||
MAX_QUEUE_SIZE
|
||||
}
|
||||
|
||||
fn features(&self) -> u64 {
|
||||
VIRTIO_NET_FEATURES
|
||||
| (1 << VIRTIO_F_VERSION_1)
|
||||
| (1 << VIRTIO_RING_F_EVENT_IDX)
|
||||
| VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits()
|
||||
}
|
||||
|
||||
fn protocol_features(&self) -> VhostUserProtocolFeatures {
|
||||
VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::MQ
|
||||
}
|
||||
|
||||
fn set_event_idx(&self, _enabled: bool) {}
|
||||
|
||||
fn update_memory(&self, mem: GuestMemoryAtomic<GuestMemoryMmap>) -> std::io::Result<()> {
|
||||
debug!(vm = %self.name, "update_memory");
|
||||
*self.mem.lock().unwrap() = Some(mem);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_event(
|
||||
&self,
|
||||
device_event: u16,
|
||||
_evset: EventSet,
|
||||
vrings: &[VringRwLock],
|
||||
_thread_id: usize,
|
||||
) -> std::io::Result<()> {
|
||||
// Store RX vring for injection
|
||||
if vrings.len() > RX_QUEUE as usize {
|
||||
let mut rx = self.rx_vring.lock().unwrap();
|
||||
if rx.is_none() {
|
||||
*rx = Some(vrings[RX_QUEUE as usize].clone());
|
||||
debug!(vm = %self.name, "stored RX vring");
|
||||
}
|
||||
}
|
||||
|
||||
// Process TX queue
|
||||
if device_event == TX_QUEUE && vrings.len() > TX_QUEUE as usize {
|
||||
self.process_tx(&vrings[TX_QUEUE as usize]);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn acked_features(&self, _features: u64) {}
|
||||
|
||||
fn get_config(&self, offset: u32, size: u32) -> Vec<u8> {
|
||||
let mut config = [0u8; 10];
|
||||
config[0..6].copy_from_slice(&self.mac.bytes());
|
||||
config[6] = 1; // LINK_UP
|
||||
config[8] = 1; // max_virtqueue_pairs
|
||||
|
||||
let offset = offset as usize;
|
||||
let size = size as usize;
|
||||
if offset < config.len() {
|
||||
let end = std::cmp::min(offset + size, config.len());
|
||||
let mut result = config[offset..end].to_vec();
|
||||
result.resize(size, 0);
|
||||
result
|
||||
} else {
|
||||
vec![0u8; size]
|
||||
}
|
||||
}
|
||||
}
|
||||
721
vm-switch/src/control.rs
Normal file
721
vm-switch/src/control.rs
Normal file
|
|
@ -0,0 +1,721 @@
|
|||
//! Control channel for main↔child process communication.
|
||||
//!
|
||||
//! Messages are serialized with postcard and sent over Unix sockets.
|
||||
//! File descriptors are passed via SCM_RIGHTS ancillary data.
|
||||
|
||||
use crate::mac::Mac;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::os::fd::{AsRawFd, FromRawFd, OwnedFd, RawFd};
|
||||
use thiserror::Error;
|
||||
|
||||
/// Messages sent from main process to child.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub enum MainToChild {
|
||||
/// Request child create an ingress buffer for a peer.
|
||||
/// Child will be Consumer, peer will be Producer.
|
||||
GetBuffer {
|
||||
/// Name of the peer VM.
|
||||
peer_name: String,
|
||||
/// MAC address of the peer VM.
|
||||
peer_mac: [u8; 6],
|
||||
},
|
||||
/// Provide peer's ingress buffer for child to use as egress.
|
||||
/// Child becomes Producer for this buffer.
|
||||
/// FDs: [memfd, eventfd]
|
||||
PutBuffer {
|
||||
/// Name of the peer VM.
|
||||
peer_name: String,
|
||||
/// MAC address of the peer VM.
|
||||
peer_mac: [u8; 6],
|
||||
/// If true, buffer accepts broadcast/multicast traffic.
|
||||
broadcast: bool,
|
||||
},
|
||||
/// Peer disconnected, clean up all buffers for this peer.
|
||||
RemovePeer {
|
||||
/// Name of the peer VM.
|
||||
peer_name: String,
|
||||
},
|
||||
/// Heartbeat request.
|
||||
Ping,
|
||||
}
|
||||
|
||||
/// Messages sent from child process to main.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub enum ChildToMain {
|
||||
/// Child is ready to receive commands.
|
||||
Ready,
|
||||
/// Response to GetBuffer - here's my ingress buffer for the requested peer.
|
||||
/// FDs: [memfd, eventfd]
|
||||
BufferReady {
|
||||
/// Name of the peer this buffer is for.
|
||||
peer_name: String,
|
||||
},
|
||||
/// Heartbeat response.
|
||||
Pong,
|
||||
}
|
||||
|
||||
impl MainToChild {
|
||||
/// Get the peer name, if this message has one.
|
||||
pub fn peer_name(&self) -> Option<&str> {
|
||||
match self {
|
||||
MainToChild::GetBuffer { peer_name, .. } => Some(peer_name),
|
||||
MainToChild::PutBuffer { peer_name, .. } => Some(peer_name),
|
||||
MainToChild::RemovePeer { peer_name } => Some(peer_name),
|
||||
MainToChild::Ping => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the peer MAC address as a Mac type, if this message has one.
|
||||
pub fn peer_mac(&self) -> Option<Mac> {
|
||||
match self {
|
||||
MainToChild::GetBuffer { peer_mac, .. } => Some(Mac::from_bytes(*peer_mac)),
|
||||
MainToChild::PutBuffer { peer_mac, .. } => Some(Mac::from_bytes(*peer_mac)),
|
||||
MainToChild::RemovePeer { .. } | MainToChild::Ping => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ChildToMain {
|
||||
/// Get the peer name, if this message has one.
|
||||
pub fn peer_name(&self) -> Option<&str> {
|
||||
match self {
|
||||
ChildToMain::Ready | ChildToMain::Pong => None,
|
||||
ChildToMain::BufferReady { peer_name } => Some(peer_name),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Maximum message size in bytes.
|
||||
pub const MAX_MESSAGE_SIZE: usize = 256;
|
||||
|
||||
/// Maximum number of file descriptors per message.
|
||||
pub const MAX_FDS: usize = 4;
|
||||
|
||||
/// Errors that can occur with control channel operations.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ControlError {
|
||||
#[error("failed to create socketpair: {0}")]
|
||||
Socketpair(std::io::Error),
|
||||
#[error("failed to serialize message: {0}")]
|
||||
Serialize(postcard::Error),
|
||||
#[error("failed to deserialize message: {0}")]
|
||||
Deserialize(postcard::Error),
|
||||
#[error("failed to send message: {0}")]
|
||||
Send(std::io::Error),
|
||||
#[error("failed to receive message: {0}")]
|
||||
Recv(std::io::Error),
|
||||
#[error("connection closed")]
|
||||
Closed,
|
||||
#[error("message too large: {0} bytes")]
|
||||
MessageTooLarge(usize),
|
||||
}
|
||||
|
||||
/// A control channel endpoint for main↔child communication.
|
||||
pub struct ControlChannel {
|
||||
socket: OwnedFd,
|
||||
}
|
||||
|
||||
impl AsRawFd for ControlChannel {
|
||||
fn as_raw_fd(&self) -> RawFd {
|
||||
self.socket.as_raw_fd()
|
||||
}
|
||||
}
|
||||
|
||||
impl ControlChannel {
|
||||
/// Create a pair of connected control channels.
|
||||
/// Returns (main_end, child_end).
|
||||
pub fn pair() -> Result<(Self, Self), ControlError> {
|
||||
let mut fds = [0i32; 2];
|
||||
let ret = unsafe {
|
||||
libc::socketpair(
|
||||
libc::AF_UNIX,
|
||||
libc::SOCK_SEQPACKET | libc::SOCK_CLOEXEC,
|
||||
0,
|
||||
fds.as_mut_ptr(),
|
||||
)
|
||||
};
|
||||
if ret < 0 {
|
||||
return Err(ControlError::Socketpair(std::io::Error::last_os_error()));
|
||||
}
|
||||
|
||||
let main_end = Self {
|
||||
socket: unsafe { OwnedFd::from_raw_fd(fds[0]) },
|
||||
};
|
||||
let child_end = Self {
|
||||
socket: unsafe { OwnedFd::from_raw_fd(fds[1]) },
|
||||
};
|
||||
|
||||
Ok((main_end, child_end))
|
||||
}
|
||||
|
||||
/// Create a ControlChannel from a raw file descriptor.
|
||||
///
|
||||
/// # Safety
|
||||
/// The fd must be a valid, open socket file descriptor.
|
||||
pub unsafe fn from_raw_fd(fd: RawFd) -> Self {
|
||||
Self {
|
||||
socket: OwnedFd::from_raw_fd(fd),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a ControlChannel from an owned file descriptor.
|
||||
pub fn from_fd(fd: OwnedFd) -> Self {
|
||||
Self { socket: fd }
|
||||
}
|
||||
|
||||
/// Consume the ControlChannel and return the underlying file descriptor.
|
||||
pub fn into_fd(self) -> OwnedFd {
|
||||
self.socket
|
||||
}
|
||||
|
||||
/// Send a message without file descriptors.
|
||||
pub fn send<M: Serialize>(&self, msg: &M) -> Result<(), ControlError> {
|
||||
let bytes = postcard::to_allocvec(msg).map_err(ControlError::Serialize)?;
|
||||
if bytes.len() > MAX_MESSAGE_SIZE {
|
||||
return Err(ControlError::MessageTooLarge(bytes.len()));
|
||||
}
|
||||
send_with_fds(self.socket.as_raw_fd(), &bytes, &[])
|
||||
}
|
||||
|
||||
/// Send a message with file descriptors.
|
||||
pub fn send_with_fds_typed<M: Serialize>(
|
||||
&self,
|
||||
msg: &M,
|
||||
fds: &[RawFd],
|
||||
) -> Result<(), ControlError> {
|
||||
let bytes = postcard::to_allocvec(msg).map_err(ControlError::Serialize)?;
|
||||
if bytes.len() > MAX_MESSAGE_SIZE {
|
||||
return Err(ControlError::MessageTooLarge(bytes.len()));
|
||||
}
|
||||
send_with_fds(self.socket.as_raw_fd(), &bytes, fds)
|
||||
}
|
||||
|
||||
/// Receive a message without expecting file descriptors.
|
||||
pub fn recv<M: for<'de> Deserialize<'de>>(&self) -> Result<M, ControlError> {
|
||||
let mut buf = [0u8; MAX_MESSAGE_SIZE];
|
||||
let (n, _fds) = recv_with_fds(self.socket.as_raw_fd(), &mut buf)?;
|
||||
postcard::from_bytes(&buf[..n]).map_err(ControlError::Deserialize)
|
||||
}
|
||||
|
||||
/// Receive a message with file descriptors.
|
||||
/// Returns (message, file_descriptors).
|
||||
pub fn recv_with_fds_typed<M: for<'de> Deserialize<'de>>(
|
||||
&self,
|
||||
) -> Result<(M, Vec<OwnedFd>), ControlError> {
|
||||
let mut buf = [0u8; MAX_MESSAGE_SIZE];
|
||||
let (n, fds) = recv_with_fds(self.socket.as_raw_fd(), &mut buf)?;
|
||||
let msg = postcard::from_bytes(&buf[..n]).map_err(ControlError::Deserialize)?;
|
||||
Ok((msg, fds))
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a message with optional file descriptors via SCM_RIGHTS.
|
||||
fn send_with_fds(socket_fd: RawFd, data: &[u8], fds: &[RawFd]) -> Result<(), ControlError> {
|
||||
let mut iov = libc::iovec {
|
||||
iov_base: data.as_ptr() as *mut libc::c_void,
|
||||
iov_len: data.len(),
|
||||
};
|
||||
|
||||
// Calculate control message buffer size
|
||||
let cmsg_space = if fds.is_empty() {
|
||||
0
|
||||
} else {
|
||||
unsafe { libc::CMSG_SPACE(std::mem::size_of_val(fds) as u32) as usize }
|
||||
};
|
||||
|
||||
let mut cmsg_buf = vec![0u8; cmsg_space];
|
||||
|
||||
let mut msg: libc::msghdr = unsafe { std::mem::zeroed() };
|
||||
msg.msg_iov = &mut iov;
|
||||
msg.msg_iovlen = 1;
|
||||
|
||||
if !fds.is_empty() {
|
||||
msg.msg_control = cmsg_buf.as_mut_ptr() as *mut libc::c_void;
|
||||
msg.msg_controllen = cmsg_space;
|
||||
|
||||
// Fill in the control message header
|
||||
let cmsg = unsafe { libc::CMSG_FIRSTHDR(&msg) };
|
||||
unsafe {
|
||||
(*cmsg).cmsg_level = libc::SOL_SOCKET;
|
||||
(*cmsg).cmsg_type = libc::SCM_RIGHTS;
|
||||
(*cmsg).cmsg_len = libc::CMSG_LEN(std::mem::size_of_val(fds) as u32) as usize;
|
||||
|
||||
// Copy file descriptors into cmsg data
|
||||
let cmsg_data = libc::CMSG_DATA(cmsg) as *mut RawFd;
|
||||
for (i, &fd) in fds.iter().enumerate() {
|
||||
*cmsg_data.add(i) = fd;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let ret = unsafe { libc::sendmsg(socket_fd, &msg, 0) };
|
||||
if ret < 0 {
|
||||
return Err(ControlError::Send(std::io::Error::last_os_error()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Receive a message with optional file descriptors via SCM_RIGHTS.
|
||||
/// Returns (bytes_read, file_descriptors).
|
||||
fn recv_with_fds(
|
||||
socket_fd: RawFd,
|
||||
buf: &mut [u8],
|
||||
) -> Result<(usize, Vec<OwnedFd>), ControlError> {
|
||||
let mut iov = libc::iovec {
|
||||
iov_base: buf.as_mut_ptr() as *mut libc::c_void,
|
||||
iov_len: buf.len(),
|
||||
};
|
||||
|
||||
// Buffer for control messages (enough for MAX_FDS file descriptors)
|
||||
let cmsg_space =
|
||||
unsafe { libc::CMSG_SPACE((MAX_FDS * std::mem::size_of::<RawFd>()) as u32) as usize };
|
||||
let mut cmsg_buf = vec![0u8; cmsg_space];
|
||||
|
||||
let mut msg: libc::msghdr = unsafe { std::mem::zeroed() };
|
||||
msg.msg_iov = &mut iov;
|
||||
msg.msg_iovlen = 1;
|
||||
msg.msg_control = cmsg_buf.as_mut_ptr() as *mut libc::c_void;
|
||||
msg.msg_controllen = cmsg_space;
|
||||
|
||||
let n = unsafe { libc::recvmsg(socket_fd, &mut msg, 0) };
|
||||
if n < 0 {
|
||||
return Err(ControlError::Recv(std::io::Error::last_os_error()));
|
||||
}
|
||||
if n == 0 {
|
||||
return Err(ControlError::Closed);
|
||||
}
|
||||
|
||||
// Extract file descriptors from control message
|
||||
let mut fds = Vec::new();
|
||||
let mut cmsg = unsafe { libc::CMSG_FIRSTHDR(&msg) };
|
||||
while !cmsg.is_null() {
|
||||
unsafe {
|
||||
if (*cmsg).cmsg_level == libc::SOL_SOCKET && (*cmsg).cmsg_type == libc::SCM_RIGHTS {
|
||||
let data_len = (*cmsg).cmsg_len - libc::CMSG_LEN(0) as usize;
|
||||
let num_fds = data_len / std::mem::size_of::<RawFd>();
|
||||
let fd_ptr = libc::CMSG_DATA(cmsg) as *const RawFd;
|
||||
for i in 0..num_fds {
|
||||
let fd = *fd_ptr.add(i);
|
||||
fds.push(OwnedFd::from_raw_fd(fd));
|
||||
}
|
||||
}
|
||||
cmsg = libc::CMSG_NXTHDR(&msg, cmsg);
|
||||
}
|
||||
}
|
||||
|
||||
Ok((n as usize, fds))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn main_to_child_get_buffer_serializes() {
|
||||
let msg = MainToChild::GetBuffer {
|
||||
peer_name: "router".to_string(),
|
||||
peer_mac: [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff],
|
||||
};
|
||||
|
||||
let bytes = postcard::to_allocvec(&msg).expect("should serialize");
|
||||
let decoded: MainToChild = postcard::from_bytes(&bytes).expect("should deserialize");
|
||||
|
||||
assert_eq!(decoded, msg);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn main_to_child_put_buffer_serializes() {
|
||||
let msg = MainToChild::PutBuffer {
|
||||
peer_name: "client_a".to_string(),
|
||||
peer_mac: [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff],
|
||||
broadcast: true,
|
||||
};
|
||||
|
||||
let bytes = postcard::to_allocvec(&msg).expect("should serialize");
|
||||
let decoded: MainToChild = postcard::from_bytes(&bytes).expect("should deserialize");
|
||||
|
||||
assert_eq!(decoded, msg);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn main_to_child_remove_peer_serializes() {
|
||||
let msg = MainToChild::RemovePeer {
|
||||
peer_name: "client_b".to_string(),
|
||||
};
|
||||
|
||||
let bytes = postcard::to_allocvec(&msg).expect("should serialize");
|
||||
let decoded: MainToChild = postcard::from_bytes(&bytes).expect("should deserialize");
|
||||
|
||||
assert_eq!(decoded, msg);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn main_to_child_ping_serializes() {
|
||||
let msg = MainToChild::Ping;
|
||||
|
||||
let bytes = postcard::to_allocvec(&msg).expect("should serialize");
|
||||
let decoded: MainToChild = postcard::from_bytes(&bytes).expect("should deserialize");
|
||||
|
||||
assert_eq!(decoded, msg);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn child_to_main_pong_serializes() {
|
||||
let msg = ChildToMain::Pong;
|
||||
|
||||
let bytes = postcard::to_allocvec(&msg).expect("should serialize");
|
||||
let decoded: ChildToMain = postcard::from_bytes(&bytes).expect("should deserialize");
|
||||
|
||||
assert_eq!(decoded, msg);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn child_to_main_ready_serializes() {
|
||||
let msg = ChildToMain::Ready;
|
||||
|
||||
let bytes = postcard::to_allocvec(&msg).expect("should serialize");
|
||||
let decoded: ChildToMain = postcard::from_bytes(&bytes).expect("should deserialize");
|
||||
|
||||
assert_eq!(decoded, msg);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn child_to_main_buffer_ready_serializes() {
|
||||
let msg = ChildToMain::BufferReady {
|
||||
peer_name: "router".to_string(),
|
||||
};
|
||||
|
||||
let bytes = postcard::to_allocvec(&msg).expect("should serialize");
|
||||
let decoded: ChildToMain = postcard::from_bytes(&bytes).expect("should deserialize");
|
||||
|
||||
assert_eq!(decoded, msg);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn control_channel_pair_creates_connected_sockets() {
|
||||
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
|
||||
|
||||
assert!(main_end.as_raw_fd() >= 0);
|
||||
assert!(child_end.as_raw_fd() >= 0);
|
||||
assert_ne!(main_end.as_raw_fd(), child_end.as_raw_fd());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn send_with_fds_sends_data() {
|
||||
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
|
||||
let data = b"hello";
|
||||
|
||||
send_with_fds(main_end.as_raw_fd(), data, &[]).expect("should send");
|
||||
|
||||
// Read on child end to verify
|
||||
let mut buf = [0u8; 64];
|
||||
let n = unsafe {
|
||||
libc::recv(
|
||||
child_end.as_raw_fd(),
|
||||
buf.as_mut_ptr() as *mut libc::c_void,
|
||||
buf.len(),
|
||||
0,
|
||||
)
|
||||
};
|
||||
assert!(n > 0);
|
||||
assert_eq!(&buf[..n as usize], data);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn send_with_fds_passes_file_descriptors() {
|
||||
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
|
||||
|
||||
// Create a test eventfd to send
|
||||
let test_fd = unsafe { libc::eventfd(42, libc::EFD_CLOEXEC) };
|
||||
assert!(test_fd >= 0);
|
||||
|
||||
send_with_fds(main_end.as_raw_fd(), b"fd", &[test_fd]).expect("should send");
|
||||
|
||||
// Close our copy - child should have its own
|
||||
unsafe { libc::close(test_fd) };
|
||||
|
||||
// Receive on child end
|
||||
let mut buf = [0u8; 64];
|
||||
let mut cmsg_buf = [0u8; 64];
|
||||
let mut iov = libc::iovec {
|
||||
iov_base: buf.as_mut_ptr() as *mut libc::c_void,
|
||||
iov_len: buf.len(),
|
||||
};
|
||||
let mut msg: libc::msghdr = unsafe { std::mem::zeroed() };
|
||||
msg.msg_iov = &mut iov;
|
||||
msg.msg_iovlen = 1;
|
||||
msg.msg_control = cmsg_buf.as_mut_ptr() as *mut libc::c_void;
|
||||
msg.msg_controllen = cmsg_buf.len();
|
||||
|
||||
let n = unsafe { libc::recvmsg(child_end.as_raw_fd(), &mut msg, 0) };
|
||||
assert!(n > 0);
|
||||
|
||||
// Extract the file descriptor
|
||||
let cmsg = unsafe { libc::CMSG_FIRSTHDR(&msg) };
|
||||
assert!(!cmsg.is_null());
|
||||
let received_fd = unsafe { *(libc::CMSG_DATA(cmsg) as *const RawFd) };
|
||||
assert!(received_fd >= 0);
|
||||
|
||||
// Verify we can read the eventfd value (42)
|
||||
let mut val: u64 = 0;
|
||||
let ret = unsafe {
|
||||
libc::read(
|
||||
received_fd,
|
||||
&mut val as *mut u64 as *mut libc::c_void,
|
||||
std::mem::size_of::<u64>(),
|
||||
)
|
||||
};
|
||||
assert_eq!(ret, std::mem::size_of::<u64>() as isize);
|
||||
assert_eq!(val, 42);
|
||||
|
||||
unsafe { libc::close(received_fd) };
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn recv_with_fds_receives_data() {
|
||||
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
|
||||
let data = b"world";
|
||||
|
||||
// Send from main
|
||||
send_with_fds(main_end.as_raw_fd(), data, &[]).expect("should send");
|
||||
|
||||
// Receive on child
|
||||
let mut buf = [0u8; 64];
|
||||
let (n, fds) = recv_with_fds(child_end.as_raw_fd(), &mut buf).expect("should receive");
|
||||
|
||||
assert_eq!(n, data.len());
|
||||
assert_eq!(&buf[..n], data);
|
||||
assert!(fds.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn recv_with_fds_receives_file_descriptors() {
|
||||
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
|
||||
|
||||
// Create test eventfds
|
||||
let fd1 = unsafe { libc::eventfd(10, libc::EFD_CLOEXEC) };
|
||||
let fd2 = unsafe { libc::eventfd(20, libc::EFD_CLOEXEC) };
|
||||
assert!(fd1 >= 0);
|
||||
assert!(fd2 >= 0);
|
||||
|
||||
send_with_fds(main_end.as_raw_fd(), b"fds", &[fd1, fd2]).expect("should send");
|
||||
|
||||
// Close our copies
|
||||
unsafe {
|
||||
libc::close(fd1);
|
||||
libc::close(fd2);
|
||||
}
|
||||
|
||||
// Receive
|
||||
let mut buf = [0u8; 64];
|
||||
let (n, fds) = recv_with_fds(child_end.as_raw_fd(), &mut buf).expect("should receive");
|
||||
|
||||
assert_eq!(n, 3);
|
||||
assert_eq!(&buf[..n], b"fds");
|
||||
assert_eq!(fds.len(), 2);
|
||||
|
||||
// Verify eventfd values
|
||||
let mut val: u64 = 0;
|
||||
unsafe {
|
||||
libc::read(
|
||||
fds[0].as_raw_fd(),
|
||||
&mut val as *mut u64 as *mut libc::c_void,
|
||||
8,
|
||||
);
|
||||
}
|
||||
assert_eq!(val, 10);
|
||||
|
||||
val = 0;
|
||||
unsafe {
|
||||
libc::read(
|
||||
fds[1].as_raw_fd(),
|
||||
&mut val as *mut u64 as *mut libc::c_void,
|
||||
8,
|
||||
);
|
||||
}
|
||||
assert_eq!(val, 20);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn control_channel_send_delivers_message() {
|
||||
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
|
||||
|
||||
let msg = MainToChild::RemovePeer {
|
||||
peer_name: "client_a".to_string(),
|
||||
};
|
||||
main_end.send(&msg).expect("should send");
|
||||
|
||||
// Verify by receiving raw bytes
|
||||
let mut buf = [0u8; MAX_MESSAGE_SIZE];
|
||||
let (n, _) = recv_with_fds(child_end.as_raw_fd(), &mut buf).expect("should receive");
|
||||
|
||||
let decoded: MainToChild = postcard::from_bytes(&buf[..n]).expect("should decode");
|
||||
assert_eq!(decoded, msg);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn control_channel_send_with_fds_delivers_message_and_fds() {
|
||||
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
|
||||
|
||||
// Create test eventfd
|
||||
let test_fd = unsafe { libc::eventfd(99, libc::EFD_CLOEXEC) };
|
||||
assert!(test_fd >= 0);
|
||||
|
||||
let msg = ChildToMain::BufferReady {
|
||||
peer_name: "router".to_string(),
|
||||
};
|
||||
main_end
|
||||
.send_with_fds_typed(&msg, &[test_fd])
|
||||
.expect("should send");
|
||||
|
||||
unsafe { libc::close(test_fd) };
|
||||
|
||||
// Receive
|
||||
let mut buf = [0u8; MAX_MESSAGE_SIZE];
|
||||
let (n, fds) = recv_with_fds(child_end.as_raw_fd(), &mut buf).expect("should receive");
|
||||
|
||||
let decoded: ChildToMain = postcard::from_bytes(&buf[..n]).expect("should decode");
|
||||
assert_eq!(decoded, msg);
|
||||
assert_eq!(fds.len(), 1);
|
||||
|
||||
// Verify eventfd value
|
||||
let mut val: u64 = 0;
|
||||
unsafe {
|
||||
libc::read(
|
||||
fds[0].as_raw_fd(),
|
||||
&mut val as *mut u64 as *mut libc::c_void,
|
||||
8,
|
||||
);
|
||||
}
|
||||
assert_eq!(val, 99);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn control_channel_recv_returns_message() {
|
||||
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
|
||||
|
||||
let msg = MainToChild::RemovePeer {
|
||||
peer_name: "client_b".to_string(),
|
||||
};
|
||||
|
||||
// Send from main using typed method
|
||||
main_end.send(&msg).expect("should send");
|
||||
|
||||
// Receive on child using typed method
|
||||
let received: MainToChild = child_end.recv().expect("should receive");
|
||||
assert_eq!(received, msg);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn control_channel_recv_with_fds_returns_message_and_fds() {
|
||||
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
|
||||
|
||||
// Create test eventfds
|
||||
let fd1 = unsafe { libc::eventfd(111, libc::EFD_CLOEXEC) };
|
||||
let fd2 = unsafe { libc::eventfd(222, libc::EFD_CLOEXEC) };
|
||||
|
||||
let msg = MainToChild::PutBuffer {
|
||||
peer_name: "router".to_string(),
|
||||
peer_mac: [0x11, 0x22, 0x33, 0x44, 0x55, 0x66],
|
||||
broadcast: true,
|
||||
};
|
||||
|
||||
main_end.send_with_fds_typed(&msg, &[fd1, fd2]).expect("should send");
|
||||
|
||||
unsafe {
|
||||
libc::close(fd1);
|
||||
libc::close(fd2);
|
||||
}
|
||||
|
||||
// Receive with typed method
|
||||
let (received, fds): (MainToChild, _) = child_end.recv_with_fds_typed().expect("should receive");
|
||||
assert_eq!(received, msg);
|
||||
assert_eq!(fds.len(), 2);
|
||||
|
||||
// Verify eventfd values
|
||||
let mut val: u64 = 0;
|
||||
unsafe {
|
||||
libc::read(fds[0].as_raw_fd(), &mut val as *mut u64 as *mut libc::c_void, 8);
|
||||
}
|
||||
assert_eq!(val, 111);
|
||||
|
||||
val = 0;
|
||||
unsafe {
|
||||
libc::read(fds[1].as_raw_fd(), &mut val as *mut u64 as *mut libc::c_void, 8);
|
||||
}
|
||||
assert_eq!(val, 222);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn control_channel_recv_detects_closed_connection() {
|
||||
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
|
||||
|
||||
// Close the sender end
|
||||
drop(main_end);
|
||||
|
||||
// Receive should return Closed error
|
||||
let result: Result<MainToChild, _> = child_end.recv();
|
||||
match result {
|
||||
Err(ControlError::Closed) => (),
|
||||
other => panic!("expected Closed error, got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn main_to_child_peer_name_helper_returns_correct_name() {
|
||||
let msg = MainToChild::PutBuffer {
|
||||
peer_name: "router".to_string(),
|
||||
peer_mac: [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff],
|
||||
broadcast: true,
|
||||
};
|
||||
assert_eq!(msg.peer_name(), Some("router"));
|
||||
|
||||
let msg2 = MainToChild::GetBuffer {
|
||||
peer_name: "client_a".to_string(),
|
||||
peer_mac: [0x11, 0x22, 0x33, 0x44, 0x55, 0x66],
|
||||
};
|
||||
assert_eq!(msg2.peer_name(), Some("client_a"));
|
||||
|
||||
let msg3 = MainToChild::RemovePeer {
|
||||
peer_name: "client_b".to_string(),
|
||||
};
|
||||
assert_eq!(msg3.peer_name(), Some("client_b"));
|
||||
|
||||
assert_eq!(MainToChild::Ping.peer_name(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn main_to_child_peer_mac_helper_returns_correct_mac() {
|
||||
let msg = MainToChild::PutBuffer {
|
||||
peer_name: "router".to_string(),
|
||||
peer_mac: [0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff],
|
||||
broadcast: true,
|
||||
};
|
||||
assert_eq!(format!("{}", msg.peer_mac().unwrap()), "aa:bb:cc:dd:ee:ff");
|
||||
|
||||
let msg2 = MainToChild::GetBuffer {
|
||||
peer_name: "client_a".to_string(),
|
||||
peer_mac: [0x11, 0x22, 0x33, 0x44, 0x55, 0x66],
|
||||
};
|
||||
assert_eq!(format!("{}", msg2.peer_mac().unwrap()), "11:22:33:44:55:66");
|
||||
|
||||
let msg3 = MainToChild::RemovePeer {
|
||||
peer_name: "client_b".to_string(),
|
||||
};
|
||||
assert!(msg3.peer_mac().is_none());
|
||||
|
||||
assert!(MainToChild::Ping.peer_mac().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn child_to_main_peer_name_helper_returns_correct_name() {
|
||||
let msg = ChildToMain::BufferReady {
|
||||
peer_name: "router".to_string(),
|
||||
};
|
||||
assert_eq!(msg.peer_name(), Some("router"));
|
||||
|
||||
assert_eq!(ChildToMain::Ready.peer_name(), None);
|
||||
assert_eq!(ChildToMain::Pong.peer_name(), None);
|
||||
}
|
||||
}
|
||||
|
|
@ -45,6 +45,28 @@ impl<'a> EthernetFrame<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Validate that a frame's source MAC matches the expected MAC.
|
||||
/// Returns false if the frame is too short or MAC doesn't match.
|
||||
pub fn validate_source_mac(frame: &[u8], expected: Mac) -> bool {
|
||||
if frame.len() < MIN_FRAME_SIZE {
|
||||
return false;
|
||||
}
|
||||
let mut src = [0u8; 6];
|
||||
src.copy_from_slice(&frame[6..12]);
|
||||
Mac::from_bytes(src) == expected
|
||||
}
|
||||
|
||||
/// Extract destination MAC from a frame.
|
||||
/// Returns None if frame is too short.
|
||||
pub fn extract_dest_mac(frame: &[u8]) -> Option<Mac> {
|
||||
if frame.len() < 6 {
|
||||
return None;
|
||||
}
|
||||
let mut dest = [0u8; 6];
|
||||
dest.copy_from_slice(&frame[0..6]);
|
||||
Some(Mac::from_bytes(dest))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
@ -99,4 +121,44 @@ mod tests {
|
|||
Mac::from_bytes([0x11, 0x22, 0x33, 0x44, 0x55, 0x66])
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_source_mac_accepts_matching() {
|
||||
let mut frame = [0u8; 14];
|
||||
frame[6..12].copy_from_slice(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
|
||||
|
||||
let expected = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
|
||||
assert!(validate_source_mac(&frame, expected));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_source_mac_rejects_mismatch() {
|
||||
let mut frame = [0u8; 14];
|
||||
frame[6..12].copy_from_slice(&[0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
|
||||
|
||||
let expected = Mac::from_bytes([0x11, 0x22, 0x33, 0x44, 0x55, 0x66]);
|
||||
assert!(!validate_source_mac(&frame, expected));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn validate_source_mac_rejects_short_frame() {
|
||||
let frame = [0u8; 10]; // Too short
|
||||
let expected = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
|
||||
assert!(!validate_source_mac(&frame, expected));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_dest_mac_returns_mac() {
|
||||
let mut frame = [0u8; 14];
|
||||
frame[0..6].copy_from_slice(&[0x11, 0x22, 0x33, 0x44, 0x55, 0x66]);
|
||||
|
||||
let result = extract_dest_mac(&frame);
|
||||
assert_eq!(result, Some(Mac::from_bytes([0x11, 0x22, 0x33, 0x44, 0x55, 0x66])));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_dest_mac_returns_none_for_short_frame() {
|
||||
let frame = [0u8; 5];
|
||||
assert!(extract_dest_mac(&frame).is_none());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,14 +1,15 @@
|
|||
//! vm-switch: Virtual L2 switch for VM-to-VM networking.
|
||||
//!
|
||||
//! This library provides a vhost-user network backend that enables
|
||||
//! communication between VMs through a virtual L2 switch with MAC filtering.
|
||||
//! communication between VMs through a virtual L2 switch with ring buffer
|
||||
//! forwarding.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! - **ConfigWatcher**: Monitors `/run/vm-switch/` for `.mac` files
|
||||
//! - **BackendManager**: Coordinates multiple NetBackend instances
|
||||
//! - **Switch**: L2 switching logic with MAC filtering
|
||||
//! - **NetBackend**: vhost-user backend for a single VM
|
||||
//! - **BackendManager**: Coordinates forked child processes for each VM
|
||||
//! - **Child**: vhost-user backend running in forked process with sandboxing
|
||||
//! - **RingBuffer**: Inter-process packet transfer via shared memory
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
|
|
@ -18,23 +19,30 @@
|
|||
//!
|
||||
//! let args = Args::parse();
|
||||
//! let watcher = ConfigWatcher::new(&args.config_dir, 64)?;
|
||||
//! let manager = BackendManager::new(&args.config_dir);
|
||||
//! let manager = BackendManager::new(&args.config_dir, SeccompMode::Kill);
|
||||
//! ```
|
||||
|
||||
pub mod args;
|
||||
pub mod backend;
|
||||
pub mod child;
|
||||
pub mod config;
|
||||
pub mod control;
|
||||
pub mod frame;
|
||||
pub mod mac;
|
||||
pub mod manager;
|
||||
pub mod switch;
|
||||
pub mod ring;
|
||||
pub mod sandbox;
|
||||
pub mod seccomp;
|
||||
pub mod watcher;
|
||||
|
||||
// Re-export commonly used types
|
||||
pub use args::{init_logging, Args, LogLevel};
|
||||
pub use backend::NetBackend;
|
||||
pub use config::{ConfigEvent, VmConfig, VmRole};
|
||||
pub use control::{ChildToMain, ControlChannel, ControlError, MainToChild, MAX_FDS, MAX_MESSAGE_SIZE};
|
||||
pub use mac::Mac;
|
||||
pub use manager::BackendManager;
|
||||
pub use switch::{ConnectionId, ForwardDecision, Switch};
|
||||
pub use manager::{BackendManager, ChildMessage};
|
||||
pub use ring::{Consumer, Producer, RingError, RING_BUFFER_SIZE, RING_SIZE, SLOT_DATA_SIZE};
|
||||
pub use sandbox::{
|
||||
apply_sandbox, enter_user_namespace, setup_filesystem_isolation, SandboxError, SandboxResult,
|
||||
};
|
||||
pub use seccomp::{apply_child_seccomp, apply_main_seccomp, SeccompError, SeccompMode};
|
||||
pub use watcher::ConfigWatcher;
|
||||
|
|
|
|||
|
|
@ -1,30 +1,99 @@
|
|||
use std::path::PathBuf;
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use clap::Parser;
|
||||
use tokio::signal;
|
||||
use tokio::signal::unix::{signal as unix_signal, SignalKind};
|
||||
use tokio::sync::broadcast;
|
||||
use vm_switch::{Args, BackendManager, ConfigWatcher, init_logging};
|
||||
use vm_switch::{apply_sandbox, apply_main_seccomp, Args, BackendManager, ConfigWatcher, SandboxResult, SeccompMode, init_logging};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let args = Args::parse();
|
||||
init_logging(args.log_level);
|
||||
|
||||
tracing::info!(config_dir = ?args.config_dir, "Starting vm-switch");
|
||||
|
||||
// Ensure config directory exists
|
||||
/// Apply sandbox before tokio runtime starts (must be single-threaded).
|
||||
///
|
||||
/// This function may fork. If we are the parent wrapper, it exits with
|
||||
/// the child's exit code. The actual vm-switch logic runs in the child.
|
||||
fn setup_sandbox(args: &Args) -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
// Ensure config directory exists before sandboxing
|
||||
if !args.config_dir.exists() {
|
||||
std::fs::create_dir_all(&args.config_dir)?;
|
||||
tracing::info!(path = ?args.config_dir, "Created config directory");
|
||||
eprintln!("Created config directory: {:?}", args.config_dir);
|
||||
}
|
||||
|
||||
// Create watcher and manager
|
||||
let mut watcher = ConfigWatcher::new(&args.config_dir, 64)?;
|
||||
let mut manager = BackendManager::new(&args.config_dir);
|
||||
// Apply sandbox unless disabled
|
||||
let config_path = if args.no_sandbox {
|
||||
eprintln!("Sandboxing disabled via --no-sandbox");
|
||||
args.config_dir.clone()
|
||||
} else {
|
||||
match apply_sandbox(&args.config_dir) {
|
||||
Ok(SandboxResult::Parent(exit_code)) => {
|
||||
// We are the wrapper parent - propagate child's exit code
|
||||
std::process::exit(exit_code);
|
||||
}
|
||||
Ok(SandboxResult::Sandboxed(path)) => {
|
||||
eprintln!("Sandbox applied, config at {:?}", path);
|
||||
path
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Failed to apply sandbox: {}", e);
|
||||
return Err(e.into());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Apply seccomp filter
|
||||
if args.seccomp_mode != SeccompMode::Disabled {
|
||||
apply_main_seccomp(args.seccomp_mode)?;
|
||||
eprintln!("Seccomp applied (mode: {:?})", args.seccomp_mode);
|
||||
} else {
|
||||
eprintln!("Seccomp disabled via --seccomp-mode=disabled");
|
||||
}
|
||||
|
||||
Ok(config_path)
|
||||
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let args = Args::parse();
|
||||
|
||||
// Apply sandbox BEFORE tokio runtime starts (unshare requires single-threaded)
|
||||
let config_dir = setup_sandbox(&args)?;
|
||||
|
||||
// Now start tokio runtime and run async main
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()?
|
||||
.block_on(async_main(args, config_dir))
|
||||
}
|
||||
|
||||
async fn async_main(args: Args, config_dir: PathBuf) -> Result<(), Box<dyn std::error::Error>> {
|
||||
init_logging(args.log_level);
|
||||
|
||||
tracing::info!(config_dir = ?config_dir, "Starting vm-switch");
|
||||
|
||||
// Create watcher and manager using sandboxed path
|
||||
let mut watcher = ConfigWatcher::new(&config_dir, 64)?;
|
||||
let (mut manager, mut child_rx) = BackendManager::new(&config_dir, args.seccomp_mode);
|
||||
|
||||
// Create SIGCHLD signal stream for child process monitoring
|
||||
let mut sigchld = unix_signal(SignalKind::child())
|
||||
.map_err(|e| format!("failed to create SIGCHLD signal stream: {}", e))?;
|
||||
|
||||
// Create SIGTERM signal stream for graceful shutdown.
|
||||
// This is critical: as PID 1 in a PID namespace, the kernel only delivers
|
||||
// signals for which a handler is registered. Without this, SIGTERM is
|
||||
// silently dropped and systemd has to SIGKILL after timeout.
|
||||
let mut sigterm = unix_signal(SignalKind::terminate())
|
||||
.map_err(|e| format!("failed to create SIGTERM signal stream: {}", e))?;
|
||||
|
||||
// Get receiver for events (includes initial scan)
|
||||
let mut rx = watcher.take_receiver();
|
||||
|
||||
tracing::info!("Processing configuration events (Ctrl+C to stop)...");
|
||||
|
||||
// Heartbeat: ping workers every second, check for responses after 100ms
|
||||
let mut ping_interval = tokio::time::interval(Duration::from_secs(1));
|
||||
let ping_timeout = tokio::time::sleep(Duration::from_secs(86400));
|
||||
tokio::pin!(ping_timeout);
|
||||
|
||||
// Process events until shutdown signal
|
||||
loop {
|
||||
tokio::select! {
|
||||
|
|
@ -43,8 +112,26 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||
}
|
||||
}
|
||||
}
|
||||
Some(msg) = child_rx.recv() => {
|
||||
manager.handle_child_message(msg);
|
||||
}
|
||||
_ = sigchld.recv() => {
|
||||
manager.reap_children();
|
||||
}
|
||||
_ = ping_interval.tick() => {
|
||||
manager.send_pings();
|
||||
ping_timeout.as_mut().reset(tokio::time::Instant::now() + Duration::from_millis(100));
|
||||
}
|
||||
_ = &mut ping_timeout => {
|
||||
manager.check_ping_timeouts();
|
||||
ping_timeout.as_mut().reset(tokio::time::Instant::now() + Duration::from_secs(86400));
|
||||
}
|
||||
_ = sigterm.recv() => {
|
||||
tracing::info!("Received SIGTERM, shutting down");
|
||||
break;
|
||||
}
|
||||
_ = signal::ctrl_c() => {
|
||||
tracing::info!("Received shutdown signal");
|
||||
tracing::info!("Received SIGINT, shutting down");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
865
vm-switch/src/ring.rs
Normal file
865
vm-switch/src/ring.rs
Normal file
|
|
@ -0,0 +1,865 @@
|
|||
//! SPSC ring buffer for cross-process frame passing.
|
||||
|
||||
use std::os::fd::{AsRawFd, FromRawFd, OwnedFd};
|
||||
use std::ptr::NonNull;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
use nix::libc;
|
||||
|
||||
use crate::frame::MAX_FRAME_SIZE;
|
||||
|
||||
/// Errors that can occur with ring buffer operations.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum RingError {
|
||||
#[error("failed to create memfd: {0}")]
|
||||
MemfdCreate(std::io::Error),
|
||||
#[error("failed to set memfd size: {0}")]
|
||||
Ftruncate(std::io::Error),
|
||||
#[error("failed to mmap: {0}")]
|
||||
Mmap(std::io::Error),
|
||||
#[error("failed to create eventfd: {0}")]
|
||||
EventfdCreate(std::io::Error),
|
||||
}
|
||||
|
||||
/// Slot data size - accommodates jumbo frames with headroom.
|
||||
pub const SLOT_DATA_SIZE: usize = MAX_FRAME_SIZE + 256;
|
||||
|
||||
/// Number of slots in the ring buffer.
|
||||
pub const RING_SIZE: usize = 64;
|
||||
|
||||
/// Total size of the ring buffer in bytes.
|
||||
pub const RING_BUFFER_SIZE: usize = std::mem::size_of::<RingHeader>()
|
||||
+ RING_SIZE * std::mem::size_of::<Slot>();
|
||||
|
||||
/// Ring buffer header containing head and tail indices.
|
||||
/// Padded to cache line size to prevent false sharing.
|
||||
#[repr(C, align(64))]
|
||||
pub struct RingHeader {
|
||||
/// Next write position (only producer modifies).
|
||||
head: AtomicU64,
|
||||
/// Next read position (only consumer modifies).
|
||||
tail: AtomicU64,
|
||||
}
|
||||
|
||||
impl RingHeader {
|
||||
/// Create a new header with head and tail at 0.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
head: AtomicU64::new(0),
|
||||
tail: AtomicU64::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
/// Load the current head value.
|
||||
pub fn load_head(&self, order: Ordering) -> u64 {
|
||||
self.head.load(order)
|
||||
}
|
||||
|
||||
/// Load the current tail value.
|
||||
pub fn load_tail(&self, order: Ordering) -> u64 {
|
||||
self.tail.load(order)
|
||||
}
|
||||
|
||||
/// Store a new head value.
|
||||
pub fn store_head(&self, val: u64, order: Ordering) {
|
||||
self.head.store(val, order)
|
||||
}
|
||||
|
||||
/// Store a new tail value.
|
||||
pub fn store_tail(&self, val: u64, order: Ordering) {
|
||||
self.tail.store(val, order)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RingHeader {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// A single slot in the ring buffer.
|
||||
/// Contains frame length and data.
|
||||
#[repr(C)]
|
||||
pub struct Slot {
|
||||
/// Frame length in bytes. 0 means slot is empty/unused.
|
||||
len: u32,
|
||||
/// Padding for alignment.
|
||||
_padding: u32,
|
||||
/// Frame data buffer.
|
||||
data: [u8; SLOT_DATA_SIZE],
|
||||
}
|
||||
|
||||
impl Slot {
|
||||
/// Create a new empty slot.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
len: 0,
|
||||
_padding: 0,
|
||||
data: [0; SLOT_DATA_SIZE],
|
||||
}
|
||||
}
|
||||
|
||||
/// Write frame data to this slot.
|
||||
/// Returns false if frame is too large.
|
||||
pub fn write(&mut self, frame: &[u8]) -> bool {
|
||||
if frame.len() > SLOT_DATA_SIZE {
|
||||
return false;
|
||||
}
|
||||
self.data[..frame.len()].copy_from_slice(frame);
|
||||
self.len = frame.len() as u32;
|
||||
true
|
||||
}
|
||||
|
||||
/// Read frame data from this slot.
|
||||
/// Returns None if slot is empty.
|
||||
pub fn read(&self) -> Option<&[u8]> {
|
||||
if self.len == 0 {
|
||||
return None;
|
||||
}
|
||||
Some(&self.data[..self.len as usize])
|
||||
}
|
||||
|
||||
/// Clear the slot.
|
||||
pub fn clear(&mut self) {
|
||||
self.len = 0;
|
||||
}
|
||||
|
||||
/// Returns true if slot is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len == 0
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Slot {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// The in-memory layout of a ring buffer.
|
||||
/// This struct is mapped directly over shared memory.
|
||||
#[repr(C)]
|
||||
pub struct RingBuffer {
|
||||
/// Header with head/tail atomics.
|
||||
header: RingHeader,
|
||||
/// Fixed-size array of slots.
|
||||
slots: [Slot; RING_SIZE],
|
||||
}
|
||||
|
||||
impl RingBuffer {
|
||||
/// Returns the number of items currently in the buffer.
|
||||
pub fn len(&self) -> usize {
|
||||
let head = self.header.load_head(Ordering::Relaxed);
|
||||
let tail = self.header.load_tail(Ordering::Relaxed);
|
||||
((head + RING_SIZE as u64 - tail) % RING_SIZE as u64) as usize
|
||||
}
|
||||
|
||||
/// Returns true if the buffer is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.header.load_head(Ordering::Relaxed) == self.header.load_tail(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
/// Returns true if the buffer is full.
|
||||
pub fn is_full(&self) -> bool {
|
||||
let head = self.header.load_head(Ordering::Relaxed);
|
||||
let tail = self.header.load_tail(Ordering::Relaxed);
|
||||
(head + 1) % RING_SIZE as u64 == tail
|
||||
}
|
||||
|
||||
/// Get mutable reference to slot at index.
|
||||
pub fn slot_mut(&mut self, index: usize) -> &mut Slot {
|
||||
&mut self.slots[index]
|
||||
}
|
||||
|
||||
/// Get reference to slot at index.
|
||||
pub fn slot(&self, index: usize) -> &Slot {
|
||||
&self.slots[index]
|
||||
}
|
||||
|
||||
/// Get reference to the header.
|
||||
pub fn header(&self) -> &RingHeader {
|
||||
&self.header
|
||||
}
|
||||
}
|
||||
|
||||
/// Producer side of an SPSC ring buffer.
|
||||
/// Creates and owns the underlying shared memory.
|
||||
pub struct Producer {
|
||||
/// Pointer to the mapped ring buffer.
|
||||
ring: NonNull<RingBuffer>,
|
||||
/// The memfd backing the ring buffer.
|
||||
memfd: OwnedFd,
|
||||
/// Eventfd for signaling consumer.
|
||||
eventfd: OwnedFd,
|
||||
}
|
||||
|
||||
// SAFETY: The ring buffer uses proper atomic operations for cross-thread/process access.
|
||||
// The producer has exclusive write access in SPSC pattern.
|
||||
unsafe impl Send for Producer {}
|
||||
unsafe impl Sync for Producer {}
|
||||
|
||||
impl Producer {
|
||||
/// Create a new producer with its own shared memory region.
|
||||
pub fn new() -> Result<Self, RingError> {
|
||||
// Create memfd for shared memory
|
||||
let memfd = unsafe {
|
||||
let fd = libc::memfd_create(c"ring_buffer".as_ptr(), libc::MFD_CLOEXEC);
|
||||
if fd < 0 {
|
||||
return Err(RingError::MemfdCreate(std::io::Error::last_os_error()));
|
||||
}
|
||||
OwnedFd::from_raw_fd(fd)
|
||||
};
|
||||
|
||||
// Set size
|
||||
let ret = unsafe { libc::ftruncate(memfd.as_raw_fd(), RING_BUFFER_SIZE as libc::off_t) };
|
||||
if ret < 0 {
|
||||
return Err(RingError::Ftruncate(std::io::Error::last_os_error()));
|
||||
}
|
||||
|
||||
// Map the memory
|
||||
let ptr = unsafe {
|
||||
libc::mmap(
|
||||
std::ptr::null_mut(),
|
||||
RING_BUFFER_SIZE,
|
||||
libc::PROT_READ | libc::PROT_WRITE,
|
||||
libc::MAP_SHARED,
|
||||
memfd.as_raw_fd(),
|
||||
0,
|
||||
)
|
||||
};
|
||||
if ptr == libc::MAP_FAILED {
|
||||
return Err(RingError::Mmap(std::io::Error::last_os_error()));
|
||||
}
|
||||
|
||||
// Initialize the ring buffer
|
||||
let ring = ptr as *mut RingBuffer;
|
||||
unsafe {
|
||||
std::ptr::addr_of_mut!((*ring).header).write(RingHeader::new());
|
||||
for i in 0..RING_SIZE {
|
||||
std::ptr::addr_of_mut!((*ring).slots[i]).write(Slot::new());
|
||||
}
|
||||
}
|
||||
|
||||
// Create eventfd for signaling
|
||||
let eventfd = unsafe {
|
||||
let fd = libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK);
|
||||
if fd < 0 {
|
||||
libc::munmap(ptr, RING_BUFFER_SIZE);
|
||||
return Err(RingError::EventfdCreate(std::io::Error::last_os_error()));
|
||||
}
|
||||
OwnedFd::from_raw_fd(fd)
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
ring: NonNull::new(ring).unwrap(),
|
||||
memfd,
|
||||
eventfd,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the memfd for sharing with consumer.
|
||||
pub fn memfd(&self) -> &OwnedFd {
|
||||
&self.memfd
|
||||
}
|
||||
|
||||
/// Get the eventfd for sharing with consumer.
|
||||
pub fn eventfd(&self) -> &OwnedFd {
|
||||
&self.eventfd
|
||||
}
|
||||
|
||||
/// Create a producer by mapping an existing consumer's shared memory.
|
||||
/// Use this when you receive FDs from a remote consumer and want to produce into their buffer.
|
||||
pub fn from_fds(memfd: OwnedFd, eventfd: OwnedFd) -> Result<Self, RingError> {
|
||||
// Map the memory
|
||||
let ptr = unsafe {
|
||||
libc::mmap(
|
||||
std::ptr::null_mut(),
|
||||
RING_BUFFER_SIZE,
|
||||
libc::PROT_READ | libc::PROT_WRITE,
|
||||
libc::MAP_SHARED,
|
||||
memfd.as_raw_fd(),
|
||||
0,
|
||||
)
|
||||
};
|
||||
if ptr == libc::MAP_FAILED {
|
||||
return Err(RingError::Mmap(std::io::Error::last_os_error()));
|
||||
}
|
||||
|
||||
let ring = ptr as *mut RingBuffer;
|
||||
|
||||
Ok(Self {
|
||||
ring: NonNull::new(ring).unwrap(),
|
||||
memfd,
|
||||
eventfd,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get reference to the ring buffer.
|
||||
fn ring(&self) -> &RingBuffer {
|
||||
// SAFETY: Pointer is valid for lifetime of Producer.
|
||||
unsafe { self.ring.as_ref() }
|
||||
}
|
||||
|
||||
/// Write a frame to the slot at the given index using the raw pointer.
|
||||
///
|
||||
/// SAFETY: Producer has exclusive write access to slots in the SPSC pattern.
|
||||
/// Only the producer calls this, and only for the slot at `head`.
|
||||
/// Uses raw pointer writes to avoid creating `&mut RingBuffer` which would
|
||||
/// violate Rust's aliasing rules (since `&RingBuffer` references coexist).
|
||||
unsafe fn write_slot(&self, index: usize, frame: &[u8]) -> bool {
|
||||
if frame.len() > SLOT_DATA_SIZE {
|
||||
return false;
|
||||
}
|
||||
let ring_ptr = self.ring.as_ptr();
|
||||
let slot_ptr = std::ptr::addr_of_mut!((*ring_ptr).slots[index]);
|
||||
let data_ptr = std::ptr::addr_of_mut!((*slot_ptr).data) as *mut u8;
|
||||
std::ptr::copy_nonoverlapping(frame.as_ptr(), data_ptr, frame.len());
|
||||
std::ptr::addr_of_mut!((*slot_ptr).len).write(frame.len() as u32);
|
||||
true
|
||||
}
|
||||
|
||||
/// Push a frame into the ring buffer.
|
||||
/// Returns true if successful, false if buffer is full (frame dropped).
|
||||
pub fn push(&self, frame: &[u8]) -> bool {
|
||||
let ring = self.ring();
|
||||
let head = ring.header().load_head(Ordering::Relaxed);
|
||||
let tail = ring.header().load_tail(Ordering::Acquire);
|
||||
|
||||
// Check if full (one slot always empty to distinguish full from empty)
|
||||
let next_head = (head + 1) % RING_SIZE as u64;
|
||||
if next_head == tail {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Write to slot via raw pointer (avoids creating &mut RingBuffer alias)
|
||||
// SAFETY: Producer has exclusive write access to the slot at head
|
||||
if !unsafe { self.write_slot(head as usize, frame) } {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Memory fence to ensure data is visible before advancing head
|
||||
std::sync::atomic::fence(Ordering::Release);
|
||||
|
||||
// Advance head
|
||||
ring.header().store_head(next_head, Ordering::Relaxed);
|
||||
|
||||
// Signal consumer if buffer was empty
|
||||
if head == tail {
|
||||
let val: u64 = 1;
|
||||
unsafe {
|
||||
libc::write(
|
||||
self.eventfd.as_raw_fd(),
|
||||
&val as *const u64 as *const libc::c_void,
|
||||
std::mem::size_of::<u64>(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Producer {
|
||||
fn drop(&mut self) {
|
||||
// SAFETY: We own the mapping and it's the correct size.
|
||||
unsafe {
|
||||
libc::munmap(self.ring.as_ptr() as *mut libc::c_void, RING_BUFFER_SIZE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Consumer side of an SPSC ring buffer.
|
||||
/// Can either create its own shared memory (via `new()`) or map memory from a producer (via `from_fds()`).
|
||||
pub struct Consumer {
|
||||
/// Pointer to the mapped ring buffer.
|
||||
ring: NonNull<RingBuffer>,
|
||||
/// The memfd backing the ring buffer.
|
||||
memfd: OwnedFd,
|
||||
/// Eventfd for receiving signals from producer.
|
||||
eventfd: OwnedFd,
|
||||
}
|
||||
|
||||
// SAFETY: The ring buffer uses proper atomic operations for cross-thread/process access.
|
||||
unsafe impl Send for Consumer {}
|
||||
|
||||
impl Consumer {
|
||||
/// Create a new consumer with its own shared memory region.
|
||||
/// Use this when YOU will be the consumer and share FDs with a remote producer.
|
||||
pub fn new() -> Result<Self, RingError> {
|
||||
// Create memfd for shared memory
|
||||
let memfd = unsafe {
|
||||
let fd = libc::memfd_create(c"ring_buffer".as_ptr(), libc::MFD_CLOEXEC);
|
||||
if fd < 0 {
|
||||
return Err(RingError::MemfdCreate(std::io::Error::last_os_error()));
|
||||
}
|
||||
OwnedFd::from_raw_fd(fd)
|
||||
};
|
||||
|
||||
// Set size
|
||||
let ret = unsafe { libc::ftruncate(memfd.as_raw_fd(), RING_BUFFER_SIZE as libc::off_t) };
|
||||
if ret < 0 {
|
||||
return Err(RingError::Ftruncate(std::io::Error::last_os_error()));
|
||||
}
|
||||
|
||||
// Map the memory
|
||||
let ptr = unsafe {
|
||||
libc::mmap(
|
||||
std::ptr::null_mut(),
|
||||
RING_BUFFER_SIZE,
|
||||
libc::PROT_READ | libc::PROT_WRITE,
|
||||
libc::MAP_SHARED,
|
||||
memfd.as_raw_fd(),
|
||||
0,
|
||||
)
|
||||
};
|
||||
if ptr == libc::MAP_FAILED {
|
||||
return Err(RingError::Mmap(std::io::Error::last_os_error()));
|
||||
}
|
||||
|
||||
// Initialize the ring buffer
|
||||
let ring = ptr as *mut RingBuffer;
|
||||
unsafe {
|
||||
std::ptr::addr_of_mut!((*ring).header).write(RingHeader::new());
|
||||
for i in 0..RING_SIZE {
|
||||
std::ptr::addr_of_mut!((*ring).slots[i]).write(Slot::new());
|
||||
}
|
||||
}
|
||||
|
||||
// Create eventfd for signaling
|
||||
let eventfd = unsafe {
|
||||
let fd = libc::eventfd(0, libc::EFD_CLOEXEC | libc::EFD_NONBLOCK);
|
||||
if fd < 0 {
|
||||
libc::munmap(ptr, RING_BUFFER_SIZE);
|
||||
return Err(RingError::EventfdCreate(std::io::Error::last_os_error()));
|
||||
}
|
||||
OwnedFd::from_raw_fd(fd)
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
ring: NonNull::new(ring).unwrap(),
|
||||
memfd,
|
||||
eventfd,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get reference to the ring buffer.
|
||||
fn ring(&self) -> &RingBuffer {
|
||||
// SAFETY: Pointer is valid for lifetime of Consumer.
|
||||
unsafe { self.ring.as_ref() }
|
||||
}
|
||||
|
||||
/// Create a consumer by mapping an existing producer's shared memory.
|
||||
pub fn from_fds(memfd: OwnedFd, eventfd: OwnedFd) -> Result<Self, RingError> {
|
||||
// Map the memory
|
||||
let ptr = unsafe {
|
||||
libc::mmap(
|
||||
std::ptr::null_mut(),
|
||||
RING_BUFFER_SIZE,
|
||||
libc::PROT_READ | libc::PROT_WRITE,
|
||||
libc::MAP_SHARED,
|
||||
memfd.as_raw_fd(),
|
||||
0,
|
||||
)
|
||||
};
|
||||
if ptr == libc::MAP_FAILED {
|
||||
return Err(RingError::Mmap(std::io::Error::last_os_error()));
|
||||
}
|
||||
|
||||
let ring = ptr as *mut RingBuffer;
|
||||
|
||||
Ok(Self {
|
||||
ring: NonNull::new(ring).unwrap(),
|
||||
memfd,
|
||||
eventfd,
|
||||
})
|
||||
}
|
||||
|
||||
/// Pop a frame from the ring buffer.
|
||||
/// Returns None if buffer is empty.
|
||||
pub fn pop(&self) -> Option<Vec<u8>> {
|
||||
let ring = self.ring();
|
||||
let tail = ring.header().load_tail(Ordering::Relaxed);
|
||||
let head = ring.header().load_head(Ordering::Acquire);
|
||||
|
||||
// Check if empty
|
||||
if head == tail {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Read from slot
|
||||
let data = ring.slot(tail as usize).read()?.to_vec();
|
||||
|
||||
// Advance tail
|
||||
let next_tail = (tail + 1) % RING_SIZE as u64;
|
||||
ring.header().store_tail(next_tail, Ordering::Release);
|
||||
|
||||
Some(data)
|
||||
}
|
||||
|
||||
/// Get the memfd for sharing with producer.
|
||||
pub fn memfd(&self) -> &OwnedFd {
|
||||
&self.memfd
|
||||
}
|
||||
|
||||
/// Get the eventfd for polling.
|
||||
pub fn eventfd(&self) -> &OwnedFd {
|
||||
&self.eventfd
|
||||
}
|
||||
|
||||
/// Drain the eventfd, returning the notification count.
|
||||
/// Returns 0 if no notifications pending.
|
||||
pub fn drain_eventfd(&self) -> u64 {
|
||||
let mut val: u64 = 0;
|
||||
let ret = unsafe {
|
||||
libc::read(
|
||||
self.eventfd.as_raw_fd(),
|
||||
&mut val as *mut u64 as *mut libc::c_void,
|
||||
std::mem::size_of::<u64>(),
|
||||
)
|
||||
};
|
||||
if ret == std::mem::size_of::<u64>() as isize {
|
||||
val
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Consumer {
|
||||
fn drop(&mut self) {
|
||||
// SAFETY: We own the mapping and it's the correct size.
|
||||
unsafe {
|
||||
libc::munmap(self.ring.as_ptr() as *mut libc::c_void, RING_BUFFER_SIZE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn slot_write_and_read() {
|
||||
let mut slot = Slot::new();
|
||||
let frame = [1u8, 2, 3, 4, 5];
|
||||
|
||||
assert!(slot.write(&frame));
|
||||
let read_data = slot.read().expect("slot should have data");
|
||||
assert_eq!(read_data, &frame);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn slot_clear_resets_to_empty() {
|
||||
let mut slot = Slot::new();
|
||||
slot.write(&[1, 2, 3]);
|
||||
|
||||
slot.clear();
|
||||
assert!(slot.is_empty());
|
||||
assert!(slot.read().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn slot_rejects_oversized_frame() {
|
||||
let mut slot = Slot::new();
|
||||
let oversized = vec![0u8; SLOT_DATA_SIZE + 1];
|
||||
|
||||
assert!(!slot.write(&oversized));
|
||||
assert!(slot.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn header_stores_and_loads_head() {
|
||||
let header = RingHeader::new();
|
||||
|
||||
header.store_head(42, Ordering::Relaxed);
|
||||
assert_eq!(header.load_head(Ordering::Relaxed), 42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn header_stores_and_loads_tail() {
|
||||
let header = RingHeader::new();
|
||||
|
||||
header.store_tail(99, Ordering::Relaxed);
|
||||
assert_eq!(header.load_tail(Ordering::Relaxed), 99);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ring_buffer_len_reflects_items() {
|
||||
// We'll test via header manipulation since we don't have push/pop yet
|
||||
let mut buffer = std::mem::MaybeUninit::<RingBuffer>::uninit();
|
||||
// SAFETY: We're testing the layout, initializing header manually
|
||||
let buffer = unsafe {
|
||||
let ptr = buffer.as_mut_ptr();
|
||||
std::ptr::addr_of_mut!((*ptr).header).write(RingHeader::new());
|
||||
// Initialize a few slots
|
||||
for i in 0..RING_SIZE {
|
||||
std::ptr::addr_of_mut!((*ptr).slots[i]).write(Slot::new());
|
||||
}
|
||||
buffer.assume_init_mut()
|
||||
};
|
||||
|
||||
// Initially empty
|
||||
assert_eq!(buffer.len(), 0);
|
||||
assert!(buffer.is_empty());
|
||||
|
||||
// Simulate adding 3 items by advancing head
|
||||
buffer.header().store_head(3, Ordering::Relaxed);
|
||||
assert_eq!(buffer.len(), 3);
|
||||
assert!(!buffer.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ring_buffer_is_full_when_one_slot_remains() {
|
||||
let mut buffer = std::mem::MaybeUninit::<RingBuffer>::uninit();
|
||||
let buffer = unsafe {
|
||||
let ptr = buffer.as_mut_ptr();
|
||||
std::ptr::addr_of_mut!((*ptr).header).write(RingHeader::new());
|
||||
for i in 0..RING_SIZE {
|
||||
std::ptr::addr_of_mut!((*ptr).slots[i]).write(Slot::new());
|
||||
}
|
||||
buffer.assume_init_mut()
|
||||
};
|
||||
|
||||
// Head at RING_SIZE-1, tail at 0 means buffer is full
|
||||
// (we keep one slot empty to distinguish full from empty)
|
||||
buffer.header().store_head((RING_SIZE - 1) as u64, Ordering::Relaxed);
|
||||
buffer.header().store_tail(0, Ordering::Relaxed);
|
||||
|
||||
assert!(buffer.is_full());
|
||||
assert_eq!(buffer.len(), RING_SIZE - 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn producer_new_creates_valid_buffer() {
|
||||
let producer = Producer::new().expect("should create producer");
|
||||
|
||||
// memfd and eventfd should be valid file descriptors
|
||||
assert!(producer.memfd().as_raw_fd() >= 0);
|
||||
assert!(producer.eventfd().as_raw_fd() >= 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn consumer_maps_producer_memory() {
|
||||
let producer = Producer::new().expect("should create producer");
|
||||
|
||||
// Duplicate FDs (simulating what would happen with SCM_RIGHTS)
|
||||
let memfd_dup = unsafe {
|
||||
let fd = libc::dup(producer.memfd().as_raw_fd());
|
||||
assert!(fd >= 0);
|
||||
OwnedFd::from_raw_fd(fd)
|
||||
};
|
||||
let eventfd_dup = unsafe {
|
||||
let fd = libc::dup(producer.eventfd().as_raw_fd());
|
||||
assert!(fd >= 0);
|
||||
OwnedFd::from_raw_fd(fd)
|
||||
};
|
||||
|
||||
let consumer = Consumer::from_fds(memfd_dup, eventfd_dup).expect("should create consumer");
|
||||
assert!(consumer.eventfd().as_raw_fd() >= 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn producer_push_adds_frame() {
|
||||
let producer = Producer::new().expect("should create producer");
|
||||
let frame = [1u8, 2, 3, 4, 5];
|
||||
|
||||
let result = producer.push(&frame);
|
||||
|
||||
assert!(result);
|
||||
// Verify head advanced
|
||||
assert_eq!(producer.ring().header().load_head(Ordering::Relaxed), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn consumer_pop_returns_pushed_frame() {
|
||||
let producer = Producer::new().expect("should create producer");
|
||||
let frame = [1u8, 2, 3, 4, 5];
|
||||
producer.push(&frame);
|
||||
|
||||
// Create consumer with duplicated FDs
|
||||
let memfd_dup = unsafe {
|
||||
OwnedFd::from_raw_fd(libc::dup(producer.memfd().as_raw_fd()))
|
||||
};
|
||||
let eventfd_dup = unsafe {
|
||||
OwnedFd::from_raw_fd(libc::dup(producer.eventfd().as_raw_fd()))
|
||||
};
|
||||
let consumer = Consumer::from_fds(memfd_dup, eventfd_dup).expect("should create consumer");
|
||||
|
||||
let popped = consumer.pop();
|
||||
|
||||
assert!(popped.is_some());
|
||||
assert_eq!(popped.unwrap(), frame);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ring_buffer_maintains_fifo_order() {
|
||||
let producer = Producer::new().expect("should create producer");
|
||||
|
||||
producer.push(&[1]);
|
||||
producer.push(&[2]);
|
||||
producer.push(&[3]);
|
||||
|
||||
let memfd_dup = unsafe {
|
||||
OwnedFd::from_raw_fd(libc::dup(producer.memfd().as_raw_fd()))
|
||||
};
|
||||
let eventfd_dup = unsafe {
|
||||
OwnedFd::from_raw_fd(libc::dup(producer.eventfd().as_raw_fd()))
|
||||
};
|
||||
let consumer = Consumer::from_fds(memfd_dup, eventfd_dup).expect("should create consumer");
|
||||
|
||||
assert_eq!(consumer.pop().unwrap(), vec![1]);
|
||||
assert_eq!(consumer.pop().unwrap(), vec![2]);
|
||||
assert_eq!(consumer.pop().unwrap(), vec![3]);
|
||||
assert!(consumer.pop().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn producer_drops_frame_when_full() {
|
||||
let producer = Producer::new().expect("should create producer");
|
||||
|
||||
// Fill the buffer (RING_SIZE - 1 slots usable)
|
||||
for i in 0..(RING_SIZE - 1) {
|
||||
assert!(producer.push(&[i as u8]), "push {} should succeed", i);
|
||||
}
|
||||
|
||||
// This should fail - buffer full
|
||||
assert!(!producer.push(&[255]));
|
||||
|
||||
// Create consumer and verify we can pop all items
|
||||
let memfd_dup = unsafe {
|
||||
OwnedFd::from_raw_fd(libc::dup(producer.memfd().as_raw_fd()))
|
||||
};
|
||||
let eventfd_dup = unsafe {
|
||||
OwnedFd::from_raw_fd(libc::dup(producer.eventfd().as_raw_fd()))
|
||||
};
|
||||
let consumer = Consumer::from_fds(memfd_dup, eventfd_dup).expect("should create consumer");
|
||||
|
||||
for i in 0..(RING_SIZE - 1) {
|
||||
let data = consumer.pop().expect("should have data");
|
||||
assert_eq!(data, vec![i as u8]);
|
||||
}
|
||||
assert!(consumer.pop().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ring_buffer_wraps_around_correctly() {
|
||||
let producer = Producer::new().expect("should create producer");
|
||||
let memfd_dup = unsafe {
|
||||
OwnedFd::from_raw_fd(libc::dup(producer.memfd().as_raw_fd()))
|
||||
};
|
||||
let eventfd_dup = unsafe {
|
||||
OwnedFd::from_raw_fd(libc::dup(producer.eventfd().as_raw_fd()))
|
||||
};
|
||||
let consumer = Consumer::from_fds(memfd_dup, eventfd_dup).expect("should create consumer");
|
||||
|
||||
// Push and pop more items than RING_SIZE to test wraparound
|
||||
for round in 0..3 {
|
||||
for i in 0..(RING_SIZE - 1) {
|
||||
let val = (round * RING_SIZE + i) as u8;
|
||||
assert!(producer.push(&[val]));
|
||||
}
|
||||
for i in 0..(RING_SIZE - 1) {
|
||||
let val = (round * RING_SIZE + i) as u8;
|
||||
assert_eq!(consumer.pop().unwrap(), vec![val]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn consumer_drain_eventfd_clears_notifications() {
|
||||
let producer = Producer::new().expect("should create producer");
|
||||
let memfd_dup = unsafe {
|
||||
OwnedFd::from_raw_fd(libc::dup(producer.memfd().as_raw_fd()))
|
||||
};
|
||||
let eventfd_dup = unsafe {
|
||||
OwnedFd::from_raw_fd(libc::dup(producer.eventfd().as_raw_fd()))
|
||||
};
|
||||
let consumer = Consumer::from_fds(memfd_dup, eventfd_dup).expect("should create consumer");
|
||||
|
||||
// Push triggers eventfd write
|
||||
producer.push(&[1]);
|
||||
producer.push(&[2]);
|
||||
|
||||
// Drain should return the accumulated count
|
||||
let count = consumer.drain_eventfd();
|
||||
assert!(count > 0);
|
||||
|
||||
// Second drain should return 0 (nothing new)
|
||||
let count2 = consumer.drain_eventfd();
|
||||
assert_eq!(count2, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn consumer_new_creates_valid_buffer() {
|
||||
let consumer = Consumer::new().expect("should create consumer");
|
||||
|
||||
// memfd and eventfd should be valid file descriptors
|
||||
assert!(consumer.memfd().as_raw_fd() >= 0);
|
||||
assert!(consumer.eventfd().as_raw_fd() >= 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn producer_from_fds_maps_consumer_memory() {
|
||||
// Consumer creates buffer, producer connects to it
|
||||
let consumer = Consumer::new().expect("should create consumer");
|
||||
|
||||
// Duplicate FDs (simulating what would happen with SCM_RIGHTS)
|
||||
let memfd_dup = unsafe {
|
||||
let fd = libc::dup(consumer.memfd().as_raw_fd());
|
||||
assert!(fd >= 0);
|
||||
OwnedFd::from_raw_fd(fd)
|
||||
};
|
||||
let eventfd_dup = unsafe {
|
||||
let fd = libc::dup(consumer.eventfd().as_raw_fd());
|
||||
assert!(fd >= 0);
|
||||
OwnedFd::from_raw_fd(fd)
|
||||
};
|
||||
|
||||
let producer = Producer::from_fds(memfd_dup, eventfd_dup).expect("should create producer");
|
||||
assert!(producer.eventfd().as_raw_fd() >= 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn consumer_created_buffer_receives_from_producer() {
|
||||
// Consumer creates buffer, shares with remote producer
|
||||
let consumer = Consumer::new().expect("should create consumer");
|
||||
|
||||
// Create producer from consumer's FDs
|
||||
let memfd_dup = unsafe {
|
||||
OwnedFd::from_raw_fd(libc::dup(consumer.memfd().as_raw_fd()))
|
||||
};
|
||||
let eventfd_dup = unsafe {
|
||||
OwnedFd::from_raw_fd(libc::dup(consumer.eventfd().as_raw_fd()))
|
||||
};
|
||||
let producer = Producer::from_fds(memfd_dup, eventfd_dup).expect("should create producer");
|
||||
|
||||
// Producer pushes, consumer receives
|
||||
let frame = [1u8, 2, 3, 4, 5];
|
||||
assert!(producer.push(&frame));
|
||||
|
||||
let popped = consumer.pop();
|
||||
assert!(popped.is_some());
|
||||
assert_eq!(popped.unwrap(), frame);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn consumer_created_buffer_maintains_fifo_order() {
|
||||
let consumer = Consumer::new().expect("should create consumer");
|
||||
let memfd_dup = unsafe {
|
||||
OwnedFd::from_raw_fd(libc::dup(consumer.memfd().as_raw_fd()))
|
||||
};
|
||||
let eventfd_dup = unsafe {
|
||||
OwnedFd::from_raw_fd(libc::dup(consumer.eventfd().as_raw_fd()))
|
||||
};
|
||||
let producer = Producer::from_fds(memfd_dup, eventfd_dup).expect("should create producer");
|
||||
|
||||
producer.push(&[1]);
|
||||
producer.push(&[2]);
|
||||
producer.push(&[3]);
|
||||
|
||||
assert_eq!(consumer.pop().unwrap(), vec![1]);
|
||||
assert_eq!(consumer.pop().unwrap(), vec![2]);
|
||||
assert_eq!(consumer.pop().unwrap(), vec![3]);
|
||||
assert!(consumer.pop().is_none());
|
||||
}
|
||||
}
|
||||
564
vm-switch/src/sandbox.rs
Normal file
564
vm-switch/src/sandbox.rs
Normal file
|
|
@ -0,0 +1,564 @@
|
|||
//! Namespace sandboxing for process isolation.
|
||||
|
||||
use nix::mount::{mount, MsFlags};
|
||||
use nix::sched::{unshare, CloneFlags};
|
||||
use nix::sys::signal::{self, SaFlags, SigAction, SigHandler, SigSet, Signal};
|
||||
use nix::sys::wait::{waitpid, WaitStatus};
|
||||
use nix::unistd::{chdir, fork, pivot_root, ForkResult, Gid, Pid, Uid};
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::sync::atomic::{AtomicI32, Ordering};
|
||||
use thiserror::Error;
|
||||
|
||||
/// Errors that can occur during sandbox setup.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum SandboxError {
|
||||
#[error("failed to unshare namespace: {0}")]
|
||||
Unshare(#[source] nix::Error),
|
||||
|
||||
#[error("fork failed: {0}")]
|
||||
Fork(#[source] nix::Error),
|
||||
|
||||
#[error("failed to write {path}: {source}")]
|
||||
WriteFile {
|
||||
path: String,
|
||||
#[source]
|
||||
source: std::io::Error,
|
||||
},
|
||||
|
||||
#[error("failed to {operation} {target}: {source}")]
|
||||
Mount {
|
||||
operation: String,
|
||||
target: String,
|
||||
#[source]
|
||||
source: nix::Error,
|
||||
},
|
||||
|
||||
#[error("failed to create directory {path}: {source}")]
|
||||
Mkdir {
|
||||
path: String,
|
||||
#[source]
|
||||
source: std::io::Error,
|
||||
},
|
||||
|
||||
#[error("failed to pivot_root: {0}")]
|
||||
PivotRoot(#[source] nix::Error),
|
||||
|
||||
#[error("failed to change directory to {path}: {source}")]
|
||||
Chdir {
|
||||
path: String,
|
||||
#[source]
|
||||
source: nix::Error,
|
||||
},
|
||||
}
|
||||
|
||||
/// Result of applying sandbox.
|
||||
#[derive(Debug)]
|
||||
pub enum SandboxResult {
|
||||
/// We are the wrapper parent. Contains child's exit code.
|
||||
Parent(i32),
|
||||
/// Sandbox applied successfully. Contains new config path.
|
||||
Sandboxed(std::path::PathBuf),
|
||||
}
|
||||
|
||||
/// Generate the content for /proc/self/uid_map.
|
||||
///
|
||||
/// Maps the given outside UID to UID 0 inside the namespace.
|
||||
fn generate_uid_map(outside_uid: u32) -> String {
|
||||
format!("0 {} 1\n", outside_uid)
|
||||
}
|
||||
|
||||
/// Generate the content for /proc/self/gid_map.
|
||||
///
|
||||
/// Maps the given outside GID to GID 0 inside the namespace.
|
||||
fn generate_gid_map(outside_gid: u32) -> String {
|
||||
format!("0 {} 1\n", outside_gid)
|
||||
}
|
||||
|
||||
/// Write uid_map and gid_map files for the current process.
|
||||
///
|
||||
/// Must be called after unshare(CLONE_NEWUSER).
|
||||
/// Writes "deny" to setgroups before gid_map (required by kernel).
|
||||
fn write_uid_gid_maps(outside_uid: Uid, outside_gid: Gid) -> Result<(), SandboxError> {
|
||||
// Must deny setgroups before writing gid_map (kernel requirement)
|
||||
fs::write("/proc/self/setgroups", "deny").map_err(|e| SandboxError::WriteFile {
|
||||
path: "/proc/self/setgroups".to_string(),
|
||||
source: e,
|
||||
})?;
|
||||
|
||||
// Write uid_map
|
||||
let uid_content = generate_uid_map(outside_uid.as_raw());
|
||||
fs::write("/proc/self/uid_map", &uid_content).map_err(|e| SandboxError::WriteFile {
|
||||
path: "/proc/self/uid_map".to_string(),
|
||||
source: e,
|
||||
})?;
|
||||
|
||||
// Write gid_map
|
||||
let gid_content = generate_gid_map(outside_gid.as_raw());
|
||||
fs::write("/proc/self/gid_map", &gid_content).map_err(|e| SandboxError::WriteFile {
|
||||
path: "/proc/self/gid_map".to_string(),
|
||||
source: e,
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Global storing the child PID for the signal forwarding handler.
|
||||
static CHILD_PID: AtomicI32 = AtomicI32::new(0);
|
||||
|
||||
/// Signal handler that forwards SIGTERM to the child process.
|
||||
extern "C" fn forward_signal(_sig: libc::c_int) {
|
||||
let pid = CHILD_PID.load(Ordering::Relaxed);
|
||||
if pid > 0 {
|
||||
unsafe { libc::kill(pid, libc::SIGTERM) };
|
||||
}
|
||||
}
|
||||
|
||||
/// Install a SIGTERM handler on the parent that forwards to the child.
|
||||
fn install_forwarding_handler(child: Pid) {
|
||||
CHILD_PID.store(child.as_raw(), Ordering::Relaxed);
|
||||
let action = SigAction::new(
|
||||
SigHandler::Handler(forward_signal),
|
||||
SaFlags::empty(),
|
||||
SigSet::empty(),
|
||||
);
|
||||
// Safety: forward_signal is async-signal-safe (only calls kill)
|
||||
unsafe { signal::sigaction(Signal::SIGTERM, &action) }.ok();
|
||||
}
|
||||
|
||||
/// Fork into new PID and mount namespaces.
|
||||
///
|
||||
/// This function:
|
||||
/// 1. Calls `unshare(CLONE_NEWPID | CLONE_NEWNS)` to create new namespaces
|
||||
/// 2. Makes all mounts private (prevents propagation)
|
||||
/// 3. Forks - the child becomes PID 1 in the new PID namespace
|
||||
/// 4. Parent waits for child and returns its exit status
|
||||
/// 5. Child continues execution
|
||||
///
|
||||
/// Returns:
|
||||
/// - `Ok(Some(exit_code))` in the parent (should propagate and exit)
|
||||
/// - `Ok(None)` in the child (continue with sandbox setup)
|
||||
/// - `Err` on failure
|
||||
///
|
||||
/// Must be called AFTER entering user namespace and BEFORE
|
||||
/// starting any multi-threaded runtime.
|
||||
pub fn fork_into_pid_namespace() -> Result<Option<i32>, SandboxError> {
|
||||
// Create new PID namespace and mount namespace together
|
||||
// The mount namespace is needed so we can mount procfs for the new PID namespace
|
||||
unshare(CloneFlags::CLONE_NEWPID | CloneFlags::CLONE_NEWNS).map_err(SandboxError::Unshare)?;
|
||||
|
||||
// Make all mounts private to prevent propagation
|
||||
mount(
|
||||
None::<&str>,
|
||||
"/",
|
||||
None::<&str>,
|
||||
MsFlags::MS_REC | MsFlags::MS_PRIVATE,
|
||||
None::<&str>,
|
||||
)
|
||||
.map_err(|e| SandboxError::Mount {
|
||||
operation: "make private".to_string(),
|
||||
target: "/".to_string(),
|
||||
source: e,
|
||||
})?;
|
||||
|
||||
// Fork - child becomes PID 1 in the new namespace
|
||||
match unsafe { fork() }.map_err(SandboxError::Fork)? {
|
||||
ForkResult::Parent { child } => {
|
||||
// Install signal handler that forwards SIGTERM to the child.
|
||||
// Without this, SIGTERM kills the parent (default action) without
|
||||
// notifying the child, which as PID 1 in a namespace ignores
|
||||
// unhandled signals.
|
||||
install_forwarding_handler(child);
|
||||
|
||||
// Parent: wait for child and collect exit status
|
||||
loop {
|
||||
match waitpid(child, None) {
|
||||
Ok(WaitStatus::Exited(_, code)) => {
|
||||
return Ok(Some(code));
|
||||
}
|
||||
Ok(WaitStatus::Signaled(_, sig, _)) => {
|
||||
// Child killed by signal, propagate as exit code
|
||||
return Ok(Some(128 + sig as i32));
|
||||
}
|
||||
Ok(_) => continue, // Stopped/continued, keep waiting
|
||||
Err(nix::Error::EINTR) => continue, // Interrupted, retry
|
||||
Err(e) => {
|
||||
return Err(SandboxError::Fork(e));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
ForkResult::Child => {
|
||||
// Child: we are now PID 1 in the new namespace
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Enter a new user namespace.
|
||||
///
|
||||
/// After this call:
|
||||
/// - The process appears to run as UID 0 / GID 0 inside the namespace
|
||||
/// - The process can create other namespaces (mount, IPC, network, PID)
|
||||
/// - The process has no capabilities in the parent namespace
|
||||
///
|
||||
/// Must be called before any other namespace operations.
|
||||
pub fn enter_user_namespace() -> Result<(), SandboxError> {
|
||||
// Capture current UID/GID before unshare
|
||||
let outside_uid = Uid::current();
|
||||
let outside_gid = Gid::current();
|
||||
|
||||
// Create new user namespace
|
||||
unshare(CloneFlags::CLONE_NEWUSER).map_err(SandboxError::Unshare)?;
|
||||
|
||||
// Set up UID/GID mappings
|
||||
write_uid_gid_maps(outside_uid, outside_gid)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Enter a new mount namespace with private mounts.
|
||||
///
|
||||
/// After this call:
|
||||
/// - Mount changes are isolated from the parent namespace
|
||||
/// - All mounts are marked as private (no propagation)
|
||||
///
|
||||
/// Must be called after `enter_user_namespace()`.
|
||||
pub fn enter_mount_namespace() -> Result<(), SandboxError> {
|
||||
// Create new mount namespace
|
||||
unshare(CloneFlags::CLONE_NEWNS).map_err(SandboxError::Unshare)?;
|
||||
|
||||
// Make all mounts private to prevent propagation
|
||||
mount(
|
||||
None::<&str>,
|
||||
"/",
|
||||
None::<&str>,
|
||||
MsFlags::MS_REC | MsFlags::MS_PRIVATE,
|
||||
None::<&str>,
|
||||
)
|
||||
.map_err(|e| SandboxError::Mount {
|
||||
operation: "make private".to_string(),
|
||||
target: "/".to_string(),
|
||||
source: e,
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create minimal /dev with essential devices.
|
||||
///
|
||||
/// Bind-mounts /dev/null, /dev/zero, and /dev/urandom from the host.
|
||||
/// These are required for basic operation (logging, random numbers).
|
||||
fn create_minimal_dev(new_root: &Path) -> Result<(), SandboxError> {
|
||||
let dev_dir = new_root.join("dev");
|
||||
|
||||
// Device nodes to bind-mount from host
|
||||
let devices = ["null", "zero", "urandom"];
|
||||
|
||||
for device in devices {
|
||||
let target = dev_dir.join(device);
|
||||
|
||||
// Create empty file as mount point
|
||||
fs::write(&target, b"").map_err(|e| SandboxError::WriteFile {
|
||||
path: target.display().to_string(),
|
||||
source: e,
|
||||
})?;
|
||||
|
||||
// Bind-mount the device from host
|
||||
let source = Path::new("/dev").join(device);
|
||||
mount(
|
||||
Some(&source),
|
||||
&target,
|
||||
None::<&str>,
|
||||
MsFlags::MS_BIND,
|
||||
None::<&str>,
|
||||
)
|
||||
.map_err(|e| SandboxError::Mount {
|
||||
operation: "bind".to_string(),
|
||||
target: target.display().to_string(),
|
||||
source: e,
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a minimal root filesystem on tmpfs.
|
||||
///
|
||||
/// Creates:
|
||||
/// - tmpfs mounted at a temporary location
|
||||
/// - /proc (mount point)
|
||||
/// - /dev with null, zero, urandom
|
||||
/// - /tmp (empty tmpfs)
|
||||
/// - /config (bind-mount of config_dir)
|
||||
///
|
||||
/// Returns the path to the new root.
|
||||
fn setup_minimal_root(config_dir: &Path) -> Result<std::path::PathBuf, SandboxError> {
|
||||
// Create tmpfs for new root
|
||||
let new_root = Path::new("/tmp/sandbox-root");
|
||||
|
||||
fs::create_dir_all(new_root).map_err(|e| SandboxError::Mkdir {
|
||||
path: new_root.display().to_string(),
|
||||
source: e,
|
||||
})?;
|
||||
|
||||
// Mount tmpfs on new root
|
||||
mount(
|
||||
Some("tmpfs"),
|
||||
new_root,
|
||||
Some("tmpfs"),
|
||||
MsFlags::MS_NOSUID | MsFlags::MS_NODEV,
|
||||
Some("size=16M,mode=0755"),
|
||||
)
|
||||
.map_err(|e| SandboxError::Mount {
|
||||
operation: "mount tmpfs".to_string(),
|
||||
target: new_root.display().to_string(),
|
||||
source: e,
|
||||
})?;
|
||||
|
||||
// Create directory structure
|
||||
let dirs = ["proc", "dev", "tmp", "config", "old_root"];
|
||||
for dir in dirs {
|
||||
let path = new_root.join(dir);
|
||||
fs::create_dir_all(&path).map_err(|e| SandboxError::Mkdir {
|
||||
path: path.display().to_string(),
|
||||
source: e,
|
||||
})?;
|
||||
}
|
||||
|
||||
// Set up /dev
|
||||
create_minimal_dev(new_root)?;
|
||||
|
||||
// Bind-mount config directory
|
||||
let config_target = new_root.join("config");
|
||||
mount(
|
||||
Some(config_dir),
|
||||
&config_target,
|
||||
None::<&str>,
|
||||
MsFlags::MS_BIND | MsFlags::MS_REC,
|
||||
None::<&str>,
|
||||
)
|
||||
.map_err(|e| SandboxError::Mount {
|
||||
operation: "bind config".to_string(),
|
||||
target: config_target.display().to_string(),
|
||||
source: e,
|
||||
})?;
|
||||
|
||||
// Mount /proc BEFORE pivot_root (required for proc mount to work)
|
||||
// This may fail if we don't own a PID namespace (EPERM), which is fine
|
||||
// for standalone filesystem isolation without apply_sandbox()
|
||||
let proc_target = new_root.join("proc");
|
||||
match mount(
|
||||
Some("proc"),
|
||||
&proc_target,
|
||||
Some("proc"),
|
||||
MsFlags::MS_NOSUID | MsFlags::MS_NODEV | MsFlags::MS_NOEXEC,
|
||||
None::<&str>,
|
||||
) {
|
||||
Ok(()) => {}
|
||||
Err(nix::Error::EPERM) => {
|
||||
// Can't mount procfs without owning a PID namespace - this is expected
|
||||
// when setup_filesystem_isolation is called standalone
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(SandboxError::Mount {
|
||||
operation: "mount".to_string(),
|
||||
target: proc_target.display().to_string(),
|
||||
source: e,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(new_root.to_path_buf())
|
||||
}
|
||||
|
||||
/// Pivot root to the new filesystem and unmount the old root.
|
||||
///
|
||||
/// After this call, the process sees only the new root filesystem.
|
||||
/// The old root is unmounted and inaccessible.
|
||||
fn pivot_to_new_root(new_root: &Path) -> Result<(), SandboxError> {
|
||||
// Change to new root before pivot
|
||||
chdir(new_root).map_err(|e| SandboxError::Chdir {
|
||||
path: new_root.display().to_string(),
|
||||
source: e,
|
||||
})?;
|
||||
|
||||
// Pivot root: new_root becomes /, old root moves to old_root
|
||||
pivot_root(".", "old_root").map_err(SandboxError::PivotRoot)?;
|
||||
|
||||
// Change to new root
|
||||
chdir("/").map_err(|e| SandboxError::Chdir {
|
||||
path: "/".to_string(),
|
||||
source: e,
|
||||
})?;
|
||||
|
||||
// After pivot_root, old root is now at /old_root in the new namespace
|
||||
let old_root = Path::new("/old_root");
|
||||
|
||||
// Unmount old root
|
||||
nix::mount::umount2(old_root, nix::mount::MntFlags::MNT_DETACH).map_err(|e| {
|
||||
SandboxError::Mount {
|
||||
operation: "umount".to_string(),
|
||||
target: old_root.display().to_string(),
|
||||
source: e,
|
||||
}
|
||||
})?;
|
||||
|
||||
// Remove old_root directory
|
||||
let _ = fs::remove_dir(old_root);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Enter a new IPC namespace.
|
||||
///
|
||||
/// After this call:
|
||||
/// - System V IPC objects (semaphores, message queues, shared memory) are isolated
|
||||
/// - The process cannot access host IPC objects
|
||||
///
|
||||
/// Must be called after `enter_user_namespace()`.
|
||||
pub fn enter_ipc_namespace() -> Result<(), SandboxError> {
|
||||
unshare(CloneFlags::CLONE_NEWIPC).map_err(SandboxError::Unshare)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Enter a new network namespace.
|
||||
///
|
||||
/// After this call:
|
||||
/// - The process has no network interfaces (not even loopback)
|
||||
/// - Network communication is only possible via inherited file descriptors
|
||||
/// - Unix sockets created before entering the namespace remain usable
|
||||
///
|
||||
/// Must be called after `enter_user_namespace()`.
|
||||
pub fn enter_network_namespace() -> Result<(), SandboxError> {
|
||||
unshare(CloneFlags::CLONE_NEWNET).map_err(SandboxError::Unshare)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Apply full sandbox isolation to the current process.
|
||||
///
|
||||
/// This function:
|
||||
/// 1. Enters user namespace (appears as root inside, enables other namespaces)
|
||||
/// 2. Forks into a new PID namespace (becomes PID 1)
|
||||
/// 3. Sets up filesystem isolation (minimal root with config at /config)
|
||||
/// 4. Enters IPC namespace (isolates System V IPC)
|
||||
/// 5. Enters network namespace (no network interfaces)
|
||||
///
|
||||
/// IMPORTANT: This function forks. In the parent, it returns
|
||||
/// `Ok(SandboxResult::Parent(exit_code))`. The parent should propagate
|
||||
/// this exit code. In the child, it returns
|
||||
/// `Ok(SandboxResult::Sandboxed(config_path))`.
|
||||
///
|
||||
/// After sandboxing:
|
||||
/// - The process is PID 1 in its own PID namespace
|
||||
/// - The process appears to run as UID 0
|
||||
/// - Only /config, /dev (minimal), /proc, and /tmp are accessible
|
||||
/// - System V IPC is isolated from host
|
||||
/// - No network interfaces (communication via inherited FDs only)
|
||||
///
|
||||
/// Must be called BEFORE starting any multi-threaded runtime (tokio).
|
||||
pub fn apply_sandbox(config_dir: &Path) -> Result<SandboxResult, SandboxError> {
|
||||
// 1. Enter user namespace first (provides CAP_SYS_ADMIN for other namespaces)
|
||||
enter_user_namespace()?;
|
||||
|
||||
// 2. Fork into PID namespace (requires CAP_SYS_ADMIN from user namespace)
|
||||
if let Some(exit_code) = fork_into_pid_namespace()? {
|
||||
return Ok(SandboxResult::Parent(exit_code));
|
||||
}
|
||||
|
||||
// 3. Set up filesystem isolation (mount ns already created by fork_into_pid_namespace)
|
||||
setup_filesystem_isolation(config_dir, false)?;
|
||||
|
||||
// 4. Enter IPC namespace
|
||||
enter_ipc_namespace()?;
|
||||
|
||||
// 5. Enter network namespace
|
||||
enter_network_namespace()?;
|
||||
|
||||
// Config dir is now at /config
|
||||
Ok(SandboxResult::Sandboxed(std::path::PathBuf::from("/config")))
|
||||
}
|
||||
|
||||
/// Set up filesystem isolation with a minimal root.
|
||||
///
|
||||
/// This function:
|
||||
/// 1. Optionally creates a mount namespace (if `needs_mount_ns` is true)
|
||||
/// 2. Builds a minimal root on tmpfs
|
||||
/// 3. Mounts /proc (requires owning a PID namespace; skipped if not available)
|
||||
/// 4. Pivots to the new root
|
||||
///
|
||||
/// After this call, only /config (mapped to config_dir), /dev (minimal),
|
||||
/// /proc (if available), and /tmp are accessible.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config_dir` - Host path to bind-mount as /config
|
||||
/// * `needs_mount_ns` - If true, creates a new mount namespace before setup.
|
||||
/// Set to false when already in a mount namespace (e.g., after
|
||||
/// `fork_into_pid_namespace()` which creates one). Set to true when calling
|
||||
/// standalone after `enter_user_namespace()`.
|
||||
pub fn setup_filesystem_isolation(config_dir: &Path, needs_mount_ns: bool) -> Result<(), SandboxError> {
|
||||
if needs_mount_ns {
|
||||
enter_mount_namespace()?;
|
||||
}
|
||||
|
||||
// Create minimal root (also attempts to mount /proc)
|
||||
let new_root = setup_minimal_root(config_dir)?;
|
||||
|
||||
// Pivot to new root
|
||||
pivot_to_new_root(&new_root)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn generate_uid_map_maps_outside_to_zero() {
|
||||
let content = generate_uid_map(1000);
|
||||
// Format: inside_uid outside_uid count
|
||||
assert_eq!(content, "0 1000 1\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_uid_map_handles_root() {
|
||||
let content = generate_uid_map(0);
|
||||
assert_eq!(content, "0 0 1\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_gid_map_maps_outside_to_zero() {
|
||||
let content = generate_gid_map(1000);
|
||||
assert_eq!(content, "0 1000 1\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_gid_map_handles_root() {
|
||||
let content = generate_gid_map(0);
|
||||
assert_eq!(content, "0 0 1\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_error_displays_mount_error() {
|
||||
let err = SandboxError::Mount {
|
||||
operation: "mount".to_string(),
|
||||
target: "/proc".to_string(),
|
||||
source: nix::Error::EPERM,
|
||||
};
|
||||
let msg = err.to_string();
|
||||
assert!(msg.contains("mount"));
|
||||
assert!(msg.contains("/proc"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_error_displays_mkdir_error() {
|
||||
let err = SandboxError::Mkdir {
|
||||
path: "/newroot/proc".to_string(),
|
||||
source: std::io::Error::from_raw_os_error(libc::EACCES),
|
||||
};
|
||||
let msg = err.to_string();
|
||||
assert!(msg.contains("/newroot/proc"));
|
||||
}
|
||||
}
|
||||
443
vm-switch/src/seccomp.rs
Normal file
443
vm-switch/src/seccomp.rs
Normal file
|
|
@ -0,0 +1,443 @@
|
|||
//! Seccomp-bpf filtering for syscall restriction.
|
||||
|
||||
use clap::ValueEnum;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Errors that can occur during seccomp filter setup.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum SeccompError {
|
||||
#[error("failed to compile filter: {0}")]
|
||||
Compile(String),
|
||||
|
||||
#[error("failed to apply filter: {0}")]
|
||||
Apply(#[source] std::io::Error),
|
||||
|
||||
#[error("seccomp is disabled")]
|
||||
Disabled,
|
||||
}
|
||||
|
||||
/// Syscalls allowed for child (worker) processes.
|
||||
///
|
||||
/// This is the base allowlist. The main process filter is built as a
|
||||
/// superset of this list (see `MAIN_EXTRA_SYSCALLS`), which is important
|
||||
/// because children inherit the main filter via fork() and seccomp
|
||||
/// filters stack with AND semantics — both must allow a syscall.
|
||||
///
|
||||
/// This is a tight whitelist because the child's own filter is applied AFTER:
|
||||
/// - Socket creation and binding (vhost-user socket ready)
|
||||
/// - Thread spawning (vhost daemon thread running)
|
||||
/// - Signal handler setup
|
||||
///
|
||||
/// Children only need syscalls for:
|
||||
/// - Accepting connections on existing socket
|
||||
/// - Reading/writing on existing FDs
|
||||
/// - Ring buffer operations (memfd, mmap)
|
||||
/// - Polling and synchronization
|
||||
pub static CHILD_SYSCALLS: &[i64] = &[
|
||||
// Basic I/O
|
||||
libc::SYS_read,
|
||||
libc::SYS_write,
|
||||
libc::SYS_close,
|
||||
libc::SYS_lseek,
|
||||
libc::SYS_pread64,
|
||||
libc::SYS_pwrite64,
|
||||
libc::SYS_readv,
|
||||
libc::SYS_writev,
|
||||
|
||||
// Memory management
|
||||
libc::SYS_mmap,
|
||||
libc::SYS_mprotect,
|
||||
libc::SYS_munmap,
|
||||
libc::SYS_brk,
|
||||
libc::SYS_mremap,
|
||||
libc::SYS_madvise,
|
||||
|
||||
// Ring buffer operations
|
||||
libc::SYS_memfd_create,
|
||||
libc::SYS_ftruncate,
|
||||
|
||||
// File operations (limited - no creation)
|
||||
libc::SYS_fstat,
|
||||
libc::SYS_newfstatat,
|
||||
libc::SYS_fcntl,
|
||||
libc::SYS_dup,
|
||||
libc::SYS_dup2,
|
||||
libc::SYS_dup3,
|
||||
libc::SYS_unlink, // glibc may use this instead of unlinkat
|
||||
libc::SYS_unlinkat,
|
||||
|
||||
// Process/thread control
|
||||
libc::SYS_clone3, // glibc pthread_create uses clone3 (vhost-user spawns threads lazily)
|
||||
libc::SYS_exit,
|
||||
libc::SYS_exit_group,
|
||||
libc::SYS_getpid,
|
||||
libc::SYS_gettid,
|
||||
libc::SYS_getuid,
|
||||
libc::SYS_getgid,
|
||||
libc::SYS_geteuid,
|
||||
libc::SYS_getegid,
|
||||
libc::SYS_sched_yield,
|
||||
libc::SYS_sched_getaffinity,
|
||||
libc::SYS_set_robust_list,
|
||||
libc::SYS_rseq,
|
||||
|
||||
// Signal handling (handlers already installed)
|
||||
libc::SYS_rt_sigaction,
|
||||
libc::SYS_rt_sigprocmask,
|
||||
libc::SYS_rt_sigreturn,
|
||||
libc::SYS_sigaltstack,
|
||||
|
||||
// Polling
|
||||
libc::SYS_epoll_create1,
|
||||
libc::SYS_epoll_ctl,
|
||||
libc::SYS_epoll_wait,
|
||||
libc::SYS_epoll_pwait,
|
||||
libc::SYS_poll,
|
||||
libc::SYS_ppoll,
|
||||
|
||||
// Socket operations (on existing sockets only)
|
||||
// NO: socket, bind, listen, connect, setsockopt, getsockopt, socketpair
|
||||
libc::SYS_accept,
|
||||
libc::SYS_accept4,
|
||||
libc::SYS_sendto,
|
||||
libc::SYS_recvfrom,
|
||||
libc::SYS_sendmsg,
|
||||
libc::SYS_recvmsg,
|
||||
libc::SYS_shutdown,
|
||||
libc::SYS_getsockname,
|
||||
libc::SYS_getpeername,
|
||||
|
||||
// Time
|
||||
libc::SYS_clock_gettime,
|
||||
libc::SYS_clock_getres,
|
||||
libc::SYS_nanosleep,
|
||||
libc::SYS_gettimeofday,
|
||||
|
||||
// Thread synchronization (for existing threads)
|
||||
libc::SYS_futex,
|
||||
|
||||
// Misc
|
||||
libc::SYS_getrandom,
|
||||
libc::SYS_prctl,
|
||||
libc::SYS_arch_prctl,
|
||||
libc::SYS_ioctl,
|
||||
libc::SYS_pipe2,
|
||||
libc::SYS_eventfd2,
|
||||
];
|
||||
|
||||
/// Additional syscalls needed only by the main process.
|
||||
///
|
||||
/// The main filter is built from CHILD_SYSCALLS + these extras.
|
||||
/// This ensures children always inherit a superset of what they need.
|
||||
pub static MAIN_EXTRA_SYSCALLS: &[i64] = &[
|
||||
// File operations (main does config watching, directory traversal)
|
||||
libc::SYS_openat,
|
||||
libc::SYS_getdents64,
|
||||
libc::SYS_mkdirat,
|
||||
libc::SYS_readlinkat,
|
||||
libc::SYS_faccessat,
|
||||
libc::SYS_faccessat2,
|
||||
libc::SYS_statx,
|
||||
|
||||
// Process control (main forks children, manages lifecycle)
|
||||
libc::SYS_fork,
|
||||
libc::SYS_clone,
|
||||
libc::SYS_clone3,
|
||||
libc::SYS_wait4,
|
||||
libc::SYS_kill,
|
||||
libc::SYS_getppid,
|
||||
|
||||
// Polling (extra variants)
|
||||
libc::SYS_select,
|
||||
libc::SYS_pselect6,
|
||||
|
||||
// Socket operations (main creates sockets, children only accept)
|
||||
libc::SYS_socket,
|
||||
libc::SYS_bind,
|
||||
libc::SYS_listen,
|
||||
libc::SYS_connect,
|
||||
libc::SYS_getsockopt,
|
||||
libc::SYS_setsockopt,
|
||||
libc::SYS_socketpair,
|
||||
|
||||
// inotify (for config file watching)
|
||||
libc::SYS_inotify_init1,
|
||||
libc::SYS_inotify_add_watch,
|
||||
libc::SYS_inotify_rm_watch,
|
||||
|
||||
// Time (extra variants)
|
||||
libc::SYS_clock_nanosleep,
|
||||
|
||||
// Misc (main-only)
|
||||
libc::SYS_uname,
|
||||
libc::SYS_timerfd_create,
|
||||
libc::SYS_timerfd_settime,
|
||||
libc::SYS_timerfd_gettime,
|
||||
|
||||
// Seccomp (for applying child filters after fork)
|
||||
libc::SYS_seccomp,
|
||||
];
|
||||
|
||||
/// Seccomp filter mode.
|
||||
#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, ValueEnum)]
|
||||
pub enum SeccompMode {
|
||||
/// Kill process on blocked syscall (production default).
|
||||
#[default]
|
||||
Kill,
|
||||
/// Send SIGSYS on blocked syscall (for debugging).
|
||||
Trap,
|
||||
/// Log blocked syscalls but allow them.
|
||||
Log,
|
||||
/// Disable seccomp filtering entirely.
|
||||
Disabled,
|
||||
}
|
||||
|
||||
use seccompiler::{BpfMap, SeccompAction, SeccompFilter, TargetArch};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
/// Build a BPF filter for the given syscall whitelist.
|
||||
///
|
||||
/// Returns a BpfMap ready to be applied via `seccompiler::apply_filter_all_threads`.
|
||||
pub fn build_filter(syscalls: &[i64], mode: SeccompMode) -> Result<BpfMap, SeccompError> {
|
||||
if mode == SeccompMode::Disabled {
|
||||
return Err(SeccompError::Disabled);
|
||||
}
|
||||
|
||||
let default_action = match mode {
|
||||
SeccompMode::Kill => SeccompAction::KillProcess,
|
||||
SeccompMode::Trap => SeccompAction::Trap,
|
||||
SeccompMode::Log => SeccompAction::Log,
|
||||
SeccompMode::Disabled => unreachable!(),
|
||||
};
|
||||
|
||||
// Build allow rules for each syscall.
|
||||
// An empty rule vector means "match unconditionally on syscall number".
|
||||
let mut rules: BTreeMap<i64, Vec<_>> = BTreeMap::new();
|
||||
for &syscall in syscalls {
|
||||
rules.insert(syscall, vec![]);
|
||||
}
|
||||
|
||||
// Create filter: allow whitelisted syscalls, block others.
|
||||
// SeccompFilter::new(rules, mismatch_action, match_action, arch)
|
||||
let filter = SeccompFilter::new(
|
||||
rules,
|
||||
default_action, // mismatch_action: block non-whitelisted syscalls
|
||||
SeccompAction::Allow, // match_action: allow whitelisted syscalls
|
||||
TargetArch::x86_64,
|
||||
)
|
||||
.map_err(|e| SeccompError::Compile(e.to_string()))?;
|
||||
|
||||
// Compile to BPF
|
||||
let bpf_prog = filter
|
||||
.try_into()
|
||||
.map_err(|e: seccompiler::BackendError| SeccompError::Compile(e.to_string()))?;
|
||||
|
||||
let mut map = BpfMap::new();
|
||||
map.insert("main".to_string(), bpf_prog);
|
||||
|
||||
Ok(map)
|
||||
}
|
||||
|
||||
/// Apply a compiled BPF filter to all threads in the current process.
|
||||
///
|
||||
/// Uses `prctl(PR_SET_NO_NEW_PRIVS)` and `seccomp(SECCOMP_SET_MODE_FILTER)`.
|
||||
/// Once applied, the filter cannot be removed or made less restrictive.
|
||||
pub fn apply_filter(bpf_map: &BpfMap) -> Result<(), SeccompError> {
|
||||
let bpf_prog = bpf_map
|
||||
.get("main")
|
||||
.ok_or_else(|| SeccompError::Compile("no 'main' filter in map".to_string()))?;
|
||||
|
||||
seccompiler::apply_filter_all_threads(bpf_prog)
|
||||
.map_err(|e| SeccompError::Apply(std::io::Error::other(e.to_string())))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Collect the full main syscall list (CHILD_SYSCALLS + MAIN_EXTRA_SYSCALLS).
|
||||
fn main_syscalls() -> Vec<i64> {
|
||||
let mut syscalls = Vec::with_capacity(CHILD_SYSCALLS.len() + MAIN_EXTRA_SYSCALLS.len());
|
||||
syscalls.extend_from_slice(CHILD_SYSCALLS);
|
||||
syscalls.extend_from_slice(MAIN_EXTRA_SYSCALLS);
|
||||
syscalls
|
||||
}
|
||||
|
||||
/// Apply seccomp filter for the main process.
|
||||
///
|
||||
/// Call this after namespace setup, before starting the tokio runtime.
|
||||
/// The filter is built from CHILD_SYSCALLS + MAIN_EXTRA_SYSCALLS, ensuring
|
||||
/// children always inherit a permissive-enough base filter.
|
||||
pub fn apply_main_seccomp(mode: SeccompMode) -> Result<(), SeccompError> {
|
||||
if mode == SeccompMode::Disabled {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let syscalls = main_syscalls();
|
||||
let filter = build_filter(&syscalls, mode)?;
|
||||
apply_filter(&filter)
|
||||
}
|
||||
|
||||
/// Apply seccomp filter for a child (worker) process.
|
||||
///
|
||||
/// Call this after socket creation and thread spawning, just before
|
||||
/// entering the event loop. This allows the tightest possible filter.
|
||||
pub fn apply_child_seccomp(mode: SeccompMode) -> Result<(), SeccompError> {
|
||||
if mode == SeccompMode::Disabled {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let filter = build_filter(CHILD_SYSCALLS, mode)?;
|
||||
apply_filter(&filter)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn seccomp_error_displays_compile_error() {
|
||||
let err = SeccompError::Compile("invalid rule".to_string());
|
||||
assert!(err.to_string().contains("compile"));
|
||||
assert!(err.to_string().contains("invalid rule"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn seccomp_error_displays_apply_error() {
|
||||
let err = SeccompError::Apply(std::io::Error::from_raw_os_error(libc::EPERM));
|
||||
let msg = err.to_string();
|
||||
assert!(msg.contains("apply"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn seccomp_error_displays_disabled() {
|
||||
let err = SeccompError::Disabled;
|
||||
assert!(err.to_string().contains("disabled"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn main_syscalls_is_not_empty() {
|
||||
assert!(!main_syscalls().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn main_syscalls_contains_essential_syscalls() {
|
||||
let syscalls = main_syscalls();
|
||||
assert!(syscalls.contains(&libc::SYS_read));
|
||||
assert!(syscalls.contains(&libc::SYS_write));
|
||||
assert!(syscalls.contains(&libc::SYS_close));
|
||||
assert!(syscalls.contains(&libc::SYS_exit_group));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn main_syscalls_allows_fork() {
|
||||
// Main process needs fork to spawn children
|
||||
assert!(main_syscalls().contains(&libc::SYS_fork));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn main_syscalls_is_superset_of_child() {
|
||||
let main = main_syscalls();
|
||||
for &syscall in CHILD_SYSCALLS {
|
||||
assert!(
|
||||
main.contains(&syscall),
|
||||
"CHILD_SYSCALLS contains {} which is missing from main filter",
|
||||
syscall
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn child_syscalls_is_not_empty() {
|
||||
assert!(!CHILD_SYSCALLS.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn child_syscalls_does_not_allow_fork() {
|
||||
// Children don't spawn processes
|
||||
assert!(!CHILD_SYSCALLS.contains(&libc::SYS_fork));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn child_syscalls_allows_clone3_for_vhost_threads() {
|
||||
// vhost-user library spawns threads lazily on client connect
|
||||
// clone3 is used by glibc pthread_create; clone is NOT allowed
|
||||
// (fork uses clone, so blocking it prevents fork in children)
|
||||
assert!(CHILD_SYSCALLS.contains(&libc::SYS_clone3));
|
||||
assert!(!CHILD_SYSCALLS.contains(&libc::SYS_clone));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn child_syscalls_does_not_allow_socket_creation() {
|
||||
// Socket created before seccomp
|
||||
assert!(!CHILD_SYSCALLS.contains(&libc::SYS_socket));
|
||||
assert!(!CHILD_SYSCALLS.contains(&libc::SYS_bind));
|
||||
assert!(!CHILD_SYSCALLS.contains(&libc::SYS_listen));
|
||||
assert!(!CHILD_SYSCALLS.contains(&libc::SYS_setsockopt));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn child_syscalls_allows_accept() {
|
||||
// Need to accept vhost-user connections
|
||||
assert!(CHILD_SYSCALLS.contains(&libc::SYS_accept4));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn child_syscalls_allows_ring_buffer_ops() {
|
||||
// Children need memfd_create for ring buffers
|
||||
assert!(CHILD_SYSCALLS.contains(&libc::SYS_memfd_create));
|
||||
assert!(CHILD_SYSCALLS.contains(&libc::SYS_ftruncate));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_filter_creates_valid_bpf() {
|
||||
let syscalls = main_syscalls();
|
||||
let filter = build_filter(&syscalls, SeccompMode::Kill)
|
||||
.expect("filter should compile");
|
||||
assert_eq!(filter.len(), 1);
|
||||
assert!(filter.contains_key("main"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_filter_handles_all_modes() {
|
||||
let syscalls = main_syscalls();
|
||||
for mode in [SeccompMode::Kill, SeccompMode::Trap, SeccompMode::Log] {
|
||||
let result = build_filter(&syscalls, mode);
|
||||
assert!(result.is_ok(), "mode {:?} should compile", mode);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_filter_disabled_returns_error() {
|
||||
let syscalls = main_syscalls();
|
||||
let result = build_filter(&syscalls, SeccompMode::Disabled);
|
||||
assert!(matches!(result, Err(SeccompError::Disabled)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_filter_signature_check() {
|
||||
fn check_signature(_f: fn(&BpfMap) -> Result<(), SeccompError>) {}
|
||||
check_signature(apply_filter);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_main_seccomp_signature_check() {
|
||||
fn check(_f: fn(SeccompMode) -> Result<(), SeccompError>) {}
|
||||
check(apply_main_seccomp);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_child_seccomp_signature_check() {
|
||||
fn check(_f: fn(SeccompMode) -> Result<(), SeccompError>) {}
|
||||
check(apply_child_seccomp);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_main_seccomp_disabled_succeeds() {
|
||||
// Disabled mode should be a no-op, not an error
|
||||
assert!(apply_main_seccomp(SeccompMode::Disabled).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_child_seccomp_disabled_succeeds() {
|
||||
assert!(apply_child_seccomp(SeccompMode::Disabled).is_ok());
|
||||
}
|
||||
}
|
||||
|
|
@ -1,445 +0,0 @@
|
|||
//! L2 switch logic with MAC filtering.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use tracing::debug;
|
||||
|
||||
use crate::config::VmRole;
|
||||
use crate::mac::Mac;
|
||||
|
||||
/// Unique identifier for a connected VM.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub struct ConnectionId(u64);
|
||||
|
||||
impl ConnectionId {
|
||||
/// Create a new connection ID.
|
||||
pub fn new(id: u64) -> Self {
|
||||
Self(id)
|
||||
}
|
||||
}
|
||||
|
||||
/// Decision for how to forward a frame.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ForwardDecision {
|
||||
/// Forward to a single destination.
|
||||
Unicast(ConnectionId),
|
||||
/// Forward to multiple destinations (broadcast/multicast).
|
||||
Multicast(Vec<ConnectionId>),
|
||||
/// Drop the frame.
|
||||
Drop(DropReason),
|
||||
}
|
||||
|
||||
/// Reason why a frame was dropped.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum DropReason {
|
||||
/// Source MAC doesn't match the sender's configured MAC.
|
||||
SourceMacMismatch { expected: Mac, actual: Mac },
|
||||
/// Client tried to send to a MAC other than router/broadcast/multicast.
|
||||
ClientViolation { destination: Mac },
|
||||
/// Router tried to send to an unknown MAC.
|
||||
UnknownDestination { destination: Mac },
|
||||
/// No router is connected.
|
||||
NoRouter,
|
||||
/// Sender connection ID is not registered.
|
||||
UnknownSender,
|
||||
}
|
||||
|
||||
/// Information about a connected VM.
|
||||
#[derive(Debug, Clone)]
|
||||
struct Connection {
|
||||
/// VM name for logging.
|
||||
name: String,
|
||||
/// Role (router or client).
|
||||
role: VmRole,
|
||||
/// Configured MAC address.
|
||||
mac: Mac,
|
||||
}
|
||||
|
||||
/// L2 switch with MAC filtering.
|
||||
///
|
||||
/// Maintains a registry of connected VMs and applies filtering rules
|
||||
/// to determine how frames should be forwarded.
|
||||
pub struct Switch {
|
||||
/// Connected VMs by connection ID.
|
||||
connections: HashMap<ConnectionId, Connection>,
|
||||
/// MAC address to connection ID mapping for fast lookup.
|
||||
mac_to_conn: HashMap<Mac, ConnectionId>,
|
||||
/// The router's connection ID (if connected).
|
||||
router: Option<ConnectionId>,
|
||||
/// Next connection ID to assign.
|
||||
next_id: u64,
|
||||
}
|
||||
|
||||
impl Switch {
|
||||
/// Create a new empty switch.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
connections: HashMap::new(),
|
||||
mac_to_conn: HashMap::new(),
|
||||
router: None,
|
||||
next_id: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a new VM connection.
|
||||
///
|
||||
/// Returns the assigned connection ID, or None if a router is already
|
||||
/// connected and this is another router.
|
||||
pub fn register(&mut self, name: String, role: VmRole, mac: Mac) -> Option<ConnectionId> {
|
||||
// Reject second router
|
||||
if role == VmRole::Router && self.router.is_some() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Reject duplicate MAC address
|
||||
if self.mac_to_conn.contains_key(&mac) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let id = ConnectionId::new(self.next_id);
|
||||
self.next_id += 1;
|
||||
|
||||
self.connections.insert(id, Connection { name, role, mac });
|
||||
self.mac_to_conn.insert(mac, id);
|
||||
|
||||
if role == VmRole::Router {
|
||||
self.router = Some(id);
|
||||
}
|
||||
|
||||
Some(id)
|
||||
}
|
||||
|
||||
/// Unregister a VM connection.
|
||||
pub fn unregister(&mut self, id: ConnectionId) {
|
||||
if let Some(conn) = self.connections.remove(&id) {
|
||||
debug!(name = %conn.name, role = ?conn.role, mac = %conn.mac, "unregistered connection");
|
||||
self.mac_to_conn.remove(&conn.mac);
|
||||
|
||||
if conn.role == VmRole::Router {
|
||||
self.router = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine how to forward a frame.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `sender` - Connection ID of the sender
|
||||
/// * `source_mac` - Source MAC from the frame
|
||||
/// * `dest_mac` - Destination MAC from the frame
|
||||
pub fn forward(&self, sender: ConnectionId, source_mac: Mac, dest_mac: Mac) -> ForwardDecision {
|
||||
// Get sender info
|
||||
let sender_conn = match self.connections.get(&sender) {
|
||||
Some(c) => c,
|
||||
None => return ForwardDecision::Drop(DropReason::UnknownSender),
|
||||
};
|
||||
|
||||
// Validate source MAC
|
||||
if source_mac != sender_conn.mac {
|
||||
return ForwardDecision::Drop(DropReason::SourceMacMismatch {
|
||||
expected: sender_conn.mac,
|
||||
actual: source_mac,
|
||||
});
|
||||
}
|
||||
|
||||
match sender_conn.role {
|
||||
VmRole::Client => self.forward_from_client(dest_mac),
|
||||
VmRole::Router => self.forward_from_router(sender, dest_mac),
|
||||
}
|
||||
}
|
||||
|
||||
fn forward_from_client(&self, dest_mac: Mac) -> ForwardDecision {
|
||||
// Get router
|
||||
let router_id = match self.router {
|
||||
Some(id) => id,
|
||||
None => return ForwardDecision::Drop(DropReason::NoRouter),
|
||||
};
|
||||
|
||||
let router_conn = self.connections.get(&router_id).unwrap();
|
||||
|
||||
// Clients can only send to router, broadcast, or multicast
|
||||
if dest_mac == router_conn.mac || dest_mac.is_broadcast() || dest_mac.is_multicast() {
|
||||
ForwardDecision::Unicast(router_id)
|
||||
} else {
|
||||
ForwardDecision::Drop(DropReason::ClientViolation { destination: dest_mac })
|
||||
}
|
||||
}
|
||||
|
||||
fn forward_from_router(&self, sender: ConnectionId, dest_mac: Mac) -> ForwardDecision {
|
||||
// Broadcast or multicast goes to all clients
|
||||
if dest_mac.is_broadcast() || dest_mac.is_multicast() {
|
||||
let client_ids: Vec<ConnectionId> = self.connections
|
||||
.iter()
|
||||
.filter(|(id, conn)| **id != sender && conn.role == VmRole::Client)
|
||||
.map(|(id, _)| *id)
|
||||
.collect();
|
||||
|
||||
return ForwardDecision::Multicast(client_ids);
|
||||
}
|
||||
|
||||
// Unicast to specific client
|
||||
match self.mac_to_conn.get(&dest_mac) {
|
||||
Some(id) if *id != sender => {
|
||||
// Verify it's a client
|
||||
if let Some(conn) = self.connections.get(id) {
|
||||
if conn.role == VmRole::Client {
|
||||
return ForwardDecision::Unicast(*id);
|
||||
}
|
||||
}
|
||||
ForwardDecision::Drop(DropReason::UnknownDestination { destination: dest_mac })
|
||||
}
|
||||
_ => ForwardDecision::Drop(DropReason::UnknownDestination { destination: dest_mac }),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Switch {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// Helper to create test MACs
|
||||
fn mac(s: &str) -> Mac {
|
||||
Mac::parse(s).unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn register_client() {
|
||||
let mut switch = Switch::new();
|
||||
let id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff"));
|
||||
|
||||
assert!(id.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn register_router() {
|
||||
let mut switch = Switch::new();
|
||||
let id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66"));
|
||||
|
||||
assert!(id.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn register_second_router_fails() {
|
||||
let mut switch = Switch::new();
|
||||
let id1 = switch.register("gateway1".into(), VmRole::Router, mac("11:22:33:44:55:66"));
|
||||
let id2 = switch.register("gateway2".into(), VmRole::Router, mac("aa:bb:cc:dd:ee:ff"));
|
||||
|
||||
assert!(id1.is_some());
|
||||
assert!(id2.is_none(), "Second router should be rejected");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn register_duplicate_mac_fails() {
|
||||
let mut switch = Switch::new();
|
||||
let id1 = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff"));
|
||||
let id2 = switch.register("shopping".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff"));
|
||||
|
||||
assert!(id1.is_some());
|
||||
assert!(id2.is_none(), "Duplicate MAC should be rejected");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn register_multiple_clients() {
|
||||
let mut switch = Switch::new();
|
||||
let id1 = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:01"));
|
||||
let id2 = switch.register("shopping".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:02"));
|
||||
|
||||
assert!(id1.is_some());
|
||||
assert!(id2.is_some());
|
||||
assert_ne!(id1, id2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unregister_client() {
|
||||
let mut switch = Switch::new();
|
||||
let id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap();
|
||||
|
||||
switch.unregister(id);
|
||||
|
||||
// Should be able to register another client with same MAC
|
||||
let id2 = switch.register("banking2".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff"));
|
||||
assert!(id2.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unregister_router_allows_new_router() {
|
||||
let mut switch = Switch::new();
|
||||
let id = switch.register("gateway1".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
|
||||
|
||||
switch.unregister(id);
|
||||
|
||||
// Should now be able to register a new router
|
||||
let id2 = switch.register("gateway2".into(), VmRole::Router, mac("aa:bb:cc:dd:ee:ff"));
|
||||
assert!(id2.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_rejects_source_mac_mismatch() {
|
||||
let mut switch = Switch::new();
|
||||
let _router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
|
||||
let client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap();
|
||||
|
||||
// Client sends with wrong source MAC
|
||||
let result = switch.forward(client_id, mac("00:00:00:00:00:01"), mac("11:22:33:44:55:66"));
|
||||
|
||||
match result {
|
||||
ForwardDecision::Drop(DropReason::SourceMacMismatch { expected, actual }) => {
|
||||
assert_eq!(expected, mac("aa:bb:cc:dd:ee:ff"));
|
||||
assert_eq!(actual, mac("00:00:00:00:00:01"));
|
||||
}
|
||||
_ => panic!("Expected SourceMacMismatch, got {:?}", result),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_client_to_router_unicast() {
|
||||
let mut switch = Switch::new();
|
||||
let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
|
||||
let client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap();
|
||||
|
||||
let result = switch.forward(client_id, mac("aa:bb:cc:dd:ee:ff"), mac("11:22:33:44:55:66"));
|
||||
|
||||
assert_eq!(result, ForwardDecision::Unicast(router_id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_client_to_router_broadcast() {
|
||||
let mut switch = Switch::new();
|
||||
let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
|
||||
let client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap();
|
||||
|
||||
let result = switch.forward(client_id, mac("aa:bb:cc:dd:ee:ff"), mac("ff:ff:ff:ff:ff:ff"));
|
||||
|
||||
// Broadcast from client goes to router only
|
||||
assert_eq!(result, ForwardDecision::Unicast(router_id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_client_to_router_multicast() {
|
||||
let mut switch = Switch::new();
|
||||
let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
|
||||
let client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap();
|
||||
|
||||
// IPv4 multicast MAC
|
||||
let result = switch.forward(client_id, mac("aa:bb:cc:dd:ee:ff"), mac("01:00:5e:00:00:01"));
|
||||
|
||||
// Multicast from client goes to router only
|
||||
assert_eq!(result, ForwardDecision::Unicast(router_id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_client_violation_to_other_client() {
|
||||
let mut switch = Switch::new();
|
||||
let _router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
|
||||
let client1_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:01")).unwrap();
|
||||
let _client2_id = switch.register("shopping".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:02")).unwrap();
|
||||
|
||||
// Client tries to send directly to another client
|
||||
let result = switch.forward(client1_id, mac("aa:bb:cc:dd:ee:01"), mac("aa:bb:cc:dd:ee:02"));
|
||||
|
||||
match result {
|
||||
ForwardDecision::Drop(DropReason::ClientViolation { destination }) => {
|
||||
assert_eq!(destination, mac("aa:bb:cc:dd:ee:02"));
|
||||
}
|
||||
_ => panic!("Expected ClientViolation, got {:?}", result),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_client_no_router() {
|
||||
let mut switch = Switch::new();
|
||||
let client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap();
|
||||
|
||||
let result = switch.forward(client_id, mac("aa:bb:cc:dd:ee:ff"), mac("ff:ff:ff:ff:ff:ff"));
|
||||
|
||||
assert_eq!(result, ForwardDecision::Drop(DropReason::NoRouter));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_router_to_client_unicast() {
|
||||
let mut switch = Switch::new();
|
||||
let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
|
||||
let client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap();
|
||||
|
||||
let result = switch.forward(router_id, mac("11:22:33:44:55:66"), mac("aa:bb:cc:dd:ee:ff"));
|
||||
|
||||
assert_eq!(result, ForwardDecision::Unicast(client_id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_router_broadcast_to_all_clients() {
|
||||
let mut switch = Switch::new();
|
||||
let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
|
||||
let client1_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:01")).unwrap();
|
||||
let client2_id = switch.register("shopping".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:02")).unwrap();
|
||||
|
||||
let result = switch.forward(router_id, mac("11:22:33:44:55:66"), mac("ff:ff:ff:ff:ff:ff"));
|
||||
|
||||
match result {
|
||||
ForwardDecision::Multicast(ids) => {
|
||||
assert_eq!(ids.len(), 2);
|
||||
assert!(ids.contains(&client1_id));
|
||||
assert!(ids.contains(&client2_id));
|
||||
}
|
||||
_ => panic!("Expected Multicast, got {:?}", result),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_router_multicast_to_all_clients() {
|
||||
let mut switch = Switch::new();
|
||||
let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
|
||||
let client1_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:01")).unwrap();
|
||||
let client2_id = switch.register("shopping".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:02")).unwrap();
|
||||
|
||||
let result = switch.forward(router_id, mac("11:22:33:44:55:66"), mac("01:00:5e:00:00:01"));
|
||||
|
||||
match result {
|
||||
ForwardDecision::Multicast(ids) => {
|
||||
assert_eq!(ids.len(), 2);
|
||||
assert!(ids.contains(&client1_id));
|
||||
assert!(ids.contains(&client2_id));
|
||||
}
|
||||
_ => panic!("Expected Multicast, got {:?}", result),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_router_unknown_destination() {
|
||||
let mut switch = Switch::new();
|
||||
let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
|
||||
let _client_id = switch.register("banking".into(), VmRole::Client, mac("aa:bb:cc:dd:ee:ff")).unwrap();
|
||||
|
||||
// Router sends to unknown MAC
|
||||
let result = switch.forward(router_id, mac("11:22:33:44:55:66"), mac("00:00:00:00:00:01"));
|
||||
|
||||
match result {
|
||||
ForwardDecision::Drop(DropReason::UnknownDestination { destination }) => {
|
||||
assert_eq!(destination, mac("00:00:00:00:00:01"));
|
||||
}
|
||||
_ => panic!("Expected UnknownDestination, got {:?}", result),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forward_router_broadcast_no_clients() {
|
||||
let mut switch = Switch::new();
|
||||
let router_id = switch.register("gateway".into(), VmRole::Router, mac("11:22:33:44:55:66")).unwrap();
|
||||
|
||||
let result = switch.forward(router_id, mac("11:22:33:44:55:66"), mac("ff:ff:ff:ff:ff:ff"));
|
||||
|
||||
// Empty multicast is valid (no clients to send to)
|
||||
match result {
|
||||
ForwardDecision::Multicast(ids) => {
|
||||
assert!(ids.is_empty());
|
||||
}
|
||||
_ => panic!("Expected empty Multicast, got {:?}", result),
|
||||
}
|
||||
}
|
||||
}
|
||||
323
vm-switch/tests/buffer_exchange.rs
Normal file
323
vm-switch/tests/buffer_exchange.rs
Normal file
|
|
@ -0,0 +1,323 @@
|
|||
//! Integration tests for ring buffer exchange between processes.
|
||||
|
||||
use nix::sys::wait::{waitpid, WaitStatus};
|
||||
use nix::unistd::{fork, ForkResult};
|
||||
use serial_test::serial;
|
||||
use std::os::fd::AsRawFd;
|
||||
use vm_switch::control::{ChildToMain, ControlChannel, MainToChild};
|
||||
use vm_switch::mac::Mac;
|
||||
use vm_switch::ring::{Consumer, Producer};
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn child_sends_ready_on_startup() {
|
||||
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
|
||||
let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
|
||||
|
||||
let temp_dir = tempfile::tempdir().expect("tempdir");
|
||||
let socket_path = temp_dir.path().join("test.sock");
|
||||
|
||||
match unsafe { fork() } {
|
||||
Ok(ForkResult::Parent { child }) => {
|
||||
drop(child_end);
|
||||
|
||||
// Wait for Ready
|
||||
let (msg, fds): (ChildToMain, _) = main_end
|
||||
.recv_with_fds_typed()
|
||||
.expect("should receive");
|
||||
|
||||
match msg {
|
||||
ChildToMain::Ready => {}
|
||||
_ => panic!("expected Ready, got {:?}", msg),
|
||||
}
|
||||
|
||||
assert!(fds.is_empty(), "Ready should have no FDs");
|
||||
|
||||
// Close main end to trigger child exit
|
||||
drop(main_end);
|
||||
|
||||
// Wait for child
|
||||
let status = waitpid(child, None).expect("waitpid failed");
|
||||
match status {
|
||||
WaitStatus::Exited(_, code) => {
|
||||
assert_eq!(code, 0, "child should exit cleanly");
|
||||
}
|
||||
other => panic!("unexpected status: {:?}", other),
|
||||
}
|
||||
}
|
||||
Ok(ForkResult::Child) => {
|
||||
drop(main_end);
|
||||
let control_fd = child_end.into_fd();
|
||||
vm_switch::child::run_child_process("test-vm", mac, control_fd, &socket_path, vm_switch::SeccompMode::Disabled);
|
||||
}
|
||||
Err(e) => panic!("fork failed: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn child_creates_ingress_buffer_on_request() {
|
||||
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
|
||||
let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff]);
|
||||
|
||||
let temp_dir = tempfile::tempdir().expect("tempdir");
|
||||
let socket_path = temp_dir.path().join("test.sock");
|
||||
|
||||
match unsafe { fork() } {
|
||||
Ok(ForkResult::Parent { child }) => {
|
||||
drop(child_end);
|
||||
|
||||
// Wait for Ready
|
||||
let (msg, _): (ChildToMain, _) = main_end
|
||||
.recv_with_fds_typed()
|
||||
.expect("should receive");
|
||||
assert!(matches!(msg, ChildToMain::Ready));
|
||||
|
||||
// Request buffer for a peer
|
||||
let peer_name = "router".to_string();
|
||||
let peer_mac = [0x11, 0x22, 0x33, 0x44, 0x55, 0x66];
|
||||
let msg = MainToChild::GetBuffer {
|
||||
peer_name: peer_name.clone(),
|
||||
peer_mac,
|
||||
};
|
||||
main_end.send(&msg).expect("send GetBuffer");
|
||||
|
||||
// Wait for BufferReady
|
||||
let (msg, fds): (ChildToMain, _) = main_end
|
||||
.recv_with_fds_typed()
|
||||
.expect("should receive");
|
||||
|
||||
match msg {
|
||||
ChildToMain::BufferReady { peer_name: name } => {
|
||||
assert_eq!(name, peer_name);
|
||||
}
|
||||
_ => panic!("expected BufferReady"),
|
||||
}
|
||||
|
||||
assert_eq!(fds.len(), 2, "should receive memfd and eventfd");
|
||||
|
||||
// Create producer from received FDs and verify we can write
|
||||
let mut fds = fds.into_iter();
|
||||
let memfd = fds.next().unwrap();
|
||||
let eventfd = fds.next().unwrap();
|
||||
|
||||
let producer = Producer::from_fds(memfd, eventfd)
|
||||
.expect("should create producer");
|
||||
|
||||
// Push a frame
|
||||
assert!(producer.push(&[1, 2, 3, 4, 5]));
|
||||
|
||||
// Close main end to trigger child exit
|
||||
drop(main_end);
|
||||
|
||||
// Wait for child
|
||||
let status = waitpid(child, None).expect("waitpid failed");
|
||||
match status {
|
||||
WaitStatus::Exited(_, code) => {
|
||||
assert_eq!(code, 0, "child should exit cleanly");
|
||||
}
|
||||
other => panic!("unexpected status: {:?}", other),
|
||||
}
|
||||
}
|
||||
Ok(ForkResult::Child) => {
|
||||
drop(main_end);
|
||||
let control_fd = child_end.into_fd();
|
||||
vm_switch::child::run_child_process("test-vm", mac, control_fd, &socket_path, vm_switch::SeccompMode::Disabled);
|
||||
}
|
||||
Err(e) => panic!("fork failed: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn new_protocol_buffer_exchange() {
|
||||
// Test the new protocol where:
|
||||
// 1. Child sends Ready
|
||||
// 2. Main sends GetBuffer
|
||||
// 3. Child creates consumer and sends BufferReady with FDs
|
||||
// 4. Main forwards FDs to peer as PutBuffer
|
||||
// 5. Peer creates producer and can write
|
||||
|
||||
// Simulate two "children" with control channels
|
||||
let (main_a, child_a) = ControlChannel::pair().expect("pair A");
|
||||
let (main_b, child_b) = ControlChannel::pair().expect("pair B");
|
||||
|
||||
let name_a = "client_a".to_string();
|
||||
let name_b = "router".to_string();
|
||||
let mac_a = [0xaa, 0, 0, 0, 0, 1];
|
||||
let mac_b = [0xbb, 0, 0, 0, 0, 2];
|
||||
|
||||
// Step 1: Both children send Ready
|
||||
child_a.send(&ChildToMain::Ready).expect("A ready");
|
||||
child_b.send(&ChildToMain::Ready).expect("B ready");
|
||||
|
||||
// Main receives Ready from both
|
||||
let (msg_a, _): (ChildToMain, _) = main_a.recv_with_fds_typed().expect("recv A ready");
|
||||
let (msg_b, _): (ChildToMain, _) = main_b.recv_with_fds_typed().expect("recv B ready");
|
||||
assert!(matches!(msg_a, ChildToMain::Ready));
|
||||
assert!(matches!(msg_b, ChildToMain::Ready));
|
||||
|
||||
// Step 2: Main requests buffer from A for B (A will be consumer of data from B)
|
||||
let get_buffer_msg = MainToChild::GetBuffer {
|
||||
peer_name: name_b.clone(),
|
||||
peer_mac: mac_b,
|
||||
};
|
||||
main_a.send(&get_buffer_msg).expect("send GetBuffer to A");
|
||||
|
||||
// Step 3: A receives GetBuffer
|
||||
let (msg, _): (MainToChild, _) = child_a.recv_with_fds_typed().expect("A recv GetBuffer");
|
||||
match msg {
|
||||
MainToChild::GetBuffer { peer_name, peer_mac } => {
|
||||
assert_eq!(peer_name, name_b);
|
||||
assert_eq!(peer_mac, mac_b);
|
||||
}
|
||||
_ => panic!("expected GetBuffer"),
|
||||
}
|
||||
|
||||
// A creates consumer (ingress buffer) and sends BufferReady
|
||||
let consumer_a = Consumer::new().expect("consumer A");
|
||||
let buffer_ready = ChildToMain::BufferReady { peer_name: name_b.clone() };
|
||||
child_a.send_with_fds_typed(&buffer_ready, &[
|
||||
consumer_a.memfd().as_raw_fd(),
|
||||
consumer_a.eventfd().as_raw_fd(),
|
||||
]).expect("A send BufferReady");
|
||||
|
||||
// Main receives BufferReady with FDs
|
||||
let (msg, fds_from_a): (ChildToMain, _) = main_a.recv_with_fds_typed().expect("main recv BufferReady");
|
||||
match msg {
|
||||
ChildToMain::BufferReady { peer_name } => {
|
||||
assert_eq!(peer_name, name_b);
|
||||
}
|
||||
_ => panic!("expected BufferReady"),
|
||||
}
|
||||
assert_eq!(fds_from_a.len(), 2);
|
||||
|
||||
// Step 4: Main sends PutBuffer to B with A's ingress buffer
|
||||
// B will use this as egress to A
|
||||
// broadcast=true because A is a client and B is router
|
||||
let put_buffer_msg = MainToChild::PutBuffer {
|
||||
peer_name: name_a.clone(),
|
||||
peer_mac: mac_a,
|
||||
broadcast: false, // A is client, not router
|
||||
};
|
||||
main_b.send_with_fds_typed(&put_buffer_msg, &[
|
||||
fds_from_a[0].as_raw_fd(),
|
||||
fds_from_a[1].as_raw_fd(),
|
||||
]).expect("send PutBuffer to B");
|
||||
|
||||
// Step 5: B receives PutBuffer with FDs
|
||||
let (msg, fds_for_b): (MainToChild, _) = child_b.recv_with_fds_typed().expect("B recv PutBuffer");
|
||||
match msg {
|
||||
MainToChild::PutBuffer { peer_name, peer_mac, broadcast } => {
|
||||
assert_eq!(peer_name, name_a);
|
||||
assert_eq!(peer_mac, mac_a);
|
||||
assert!(!broadcast);
|
||||
}
|
||||
_ => panic!("expected PutBuffer"),
|
||||
}
|
||||
assert_eq!(fds_for_b.len(), 2);
|
||||
|
||||
// B creates producer (egress buffer) from received FDs
|
||||
let mut fds_b = fds_for_b.into_iter();
|
||||
let producer_b = Producer::from_fds(fds_b.next().unwrap(), fds_b.next().unwrap())
|
||||
.expect("producer B");
|
||||
|
||||
// Step 6: B writes to producer, A can read from consumer
|
||||
producer_b.push(&[10, 20, 30, 40, 50]);
|
||||
|
||||
let data = consumer_a.pop().expect("should pop");
|
||||
assert_eq!(data, vec![10, 20, 30, 40, 50]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn bidirectional_new_protocol_exchange() {
|
||||
// Two VMs exchange buffers and can communicate both ways using new protocol
|
||||
|
||||
let (main_a, child_a) = ControlChannel::pair().expect("pair A");
|
||||
let (main_b, child_b) = ControlChannel::pair().expect("pair B");
|
||||
|
||||
let name_a = "client_a".to_string();
|
||||
let name_b = "router".to_string();
|
||||
let mac_a = [0xaa, 0, 0, 0, 0, 1];
|
||||
let mac_b = [0xbb, 0, 0, 0, 0, 2];
|
||||
|
||||
// Both send Ready
|
||||
child_a.send(&ChildToMain::Ready).expect("A ready");
|
||||
child_b.send(&ChildToMain::Ready).expect("B ready");
|
||||
|
||||
let _: (ChildToMain, _) = main_a.recv_with_fds_typed().expect("recv A ready");
|
||||
let _: (ChildToMain, _) = main_b.recv_with_fds_typed().expect("recv B ready");
|
||||
|
||||
// Request ingress buffers from both sides
|
||||
// A creates ingress for B
|
||||
main_a.send(&MainToChild::GetBuffer {
|
||||
peer_name: name_b.clone(),
|
||||
peer_mac: mac_b,
|
||||
}).expect("GetBuffer A->B");
|
||||
|
||||
// B creates ingress for A
|
||||
main_b.send(&MainToChild::GetBuffer {
|
||||
peer_name: name_a.clone(),
|
||||
peer_mac: mac_a,
|
||||
}).expect("GetBuffer B->A");
|
||||
|
||||
// A receives GetBuffer, creates consumer, sends BufferReady
|
||||
let _: (MainToChild, _) = child_a.recv_with_fds_typed().expect("A recv GetBuffer");
|
||||
let consumer_a = Consumer::new().expect("consumer A");
|
||||
child_a.send_with_fds_typed(&ChildToMain::BufferReady { peer_name: name_b.clone() }, &[
|
||||
consumer_a.memfd().as_raw_fd(),
|
||||
consumer_a.eventfd().as_raw_fd(),
|
||||
]).expect("A BufferReady");
|
||||
|
||||
// B receives GetBuffer, creates consumer, sends BufferReady
|
||||
let _: (MainToChild, _) = child_b.recv_with_fds_typed().expect("B recv GetBuffer");
|
||||
let consumer_b = Consumer::new().expect("consumer B");
|
||||
child_b.send_with_fds_typed(&ChildToMain::BufferReady { peer_name: name_a.clone() }, &[
|
||||
consumer_b.memfd().as_raw_fd(),
|
||||
consumer_b.eventfd().as_raw_fd(),
|
||||
]).expect("B BufferReady");
|
||||
|
||||
// Main receives BufferReady from both
|
||||
let (_, fds_from_a): (ChildToMain, _) = main_a.recv_with_fds_typed().expect("main recv A BufferReady");
|
||||
let (_, fds_from_b): (ChildToMain, _) = main_b.recv_with_fds_typed().expect("main recv B BufferReady");
|
||||
|
||||
// Cross-send: A's ingress becomes B's egress to A, and vice versa
|
||||
// Send A's buffer to B (B is router, so broadcast=true when sending TO router)
|
||||
main_b.send_with_fds_typed(&MainToChild::PutBuffer {
|
||||
peer_name: name_a.clone(),
|
||||
peer_mac: mac_a,
|
||||
broadcast: false, // A is not router
|
||||
}, &[
|
||||
fds_from_a[0].as_raw_fd(),
|
||||
fds_from_a[1].as_raw_fd(),
|
||||
]).expect("PutBuffer A to B");
|
||||
|
||||
// Send B's buffer to A (A is client, broadcast=true because B is router)
|
||||
main_a.send_with_fds_typed(&MainToChild::PutBuffer {
|
||||
peer_name: name_b.clone(),
|
||||
peer_mac: mac_b,
|
||||
broadcast: true, // B is router
|
||||
}, &[
|
||||
fds_from_b[0].as_raw_fd(),
|
||||
fds_from_b[1].as_raw_fd(),
|
||||
]).expect("PutBuffer B to A");
|
||||
|
||||
// Each side receives PutBuffer and creates producer
|
||||
let (_, fds_b_egress): (MainToChild, _) = child_b.recv_with_fds_typed().expect("B recv PutBuffer");
|
||||
let (_, fds_a_egress): (MainToChild, _) = child_a.recv_with_fds_typed().expect("A recv PutBuffer");
|
||||
|
||||
let mut fds = fds_b_egress.into_iter();
|
||||
let producer_b = Producer::from_fds(fds.next().unwrap(), fds.next().unwrap()).expect("producer B");
|
||||
|
||||
let mut fds = fds_a_egress.into_iter();
|
||||
let producer_a = Producer::from_fds(fds.next().unwrap(), fds.next().unwrap()).expect("producer A");
|
||||
|
||||
// B sends to A
|
||||
producer_b.push(&[1, 2, 3]);
|
||||
assert_eq!(consumer_a.pop().unwrap(), vec![1, 2, 3]);
|
||||
|
||||
// A sends to B
|
||||
producer_a.push(&[4, 5, 6]);
|
||||
assert_eq!(consumer_b.pop().unwrap(), vec![4, 5, 6]);
|
||||
}
|
||||
104
vm-switch/tests/crash_handling.rs
Normal file
104
vm-switch/tests/crash_handling.rs
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
//! Integration tests for crash detection and cleanup.
|
||||
|
||||
use nix::sys::signal::{kill, Signal};
|
||||
use nix::sys::wait::{waitpid, WaitPidFlag, WaitStatus};
|
||||
use nix::unistd::{fork, ForkResult};
|
||||
use std::time::Duration;
|
||||
use vm_switch::control::{ControlChannel, MainToChild};
|
||||
|
||||
/// Test that main detects child crash via waitpid.
|
||||
#[test]
|
||||
fn main_detects_child_crash() {
|
||||
let (main_end, child_end) = ControlChannel::pair().expect("pair");
|
||||
|
||||
match unsafe { fork() } {
|
||||
Ok(ForkResult::Parent { child }) => {
|
||||
drop(child_end);
|
||||
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
|
||||
kill(child, Signal::SIGKILL).expect("kill");
|
||||
|
||||
let status = waitpid(child, None).expect("waitpid");
|
||||
assert!(matches!(status, WaitStatus::Signaled(_, Signal::SIGKILL, _)));
|
||||
|
||||
drop(main_end);
|
||||
}
|
||||
Ok(ForkResult::Child) => {
|
||||
drop(main_end);
|
||||
let control = ControlChannel::from_fd(child_end.into_fd());
|
||||
loop {
|
||||
match control.recv::<MainToChild>() {
|
||||
Err(_) => break,
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
std::process::exit(0);
|
||||
}
|
||||
Err(e) => panic!("fork failed: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Test that RemovePeer message is delivered correctly.
|
||||
#[test]
|
||||
fn remove_peer_message_delivery() {
|
||||
let (main_end, child_end) = ControlChannel::pair().expect("pair");
|
||||
|
||||
let crashed_vm = "crashed_client".to_string();
|
||||
let msg = MainToChild::RemovePeer { peer_name: crashed_vm.clone() };
|
||||
|
||||
main_end.send(&msg).expect("send");
|
||||
|
||||
let received: MainToChild = child_end.recv().expect("recv");
|
||||
match received {
|
||||
MainToChild::RemovePeer { peer_name } => {
|
||||
assert_eq!(peer_name, crashed_vm);
|
||||
}
|
||||
_ => panic!("expected RemovePeer"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Test that children exit cleanly when control channel closes.
|
||||
#[test]
|
||||
fn children_exit_on_control_close() {
|
||||
let (main_end, child_end) = ControlChannel::pair().expect("pair");
|
||||
|
||||
match unsafe { fork() } {
|
||||
Ok(ForkResult::Parent { child }) => {
|
||||
drop(child_end);
|
||||
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
|
||||
drop(main_end);
|
||||
|
||||
let deadline = std::time::Instant::now() + Duration::from_secs(2);
|
||||
loop {
|
||||
match waitpid(child, Some(WaitPidFlag::WNOHANG)) {
|
||||
Ok(WaitStatus::Exited(_, 0)) => return,
|
||||
Ok(WaitStatus::Exited(_, code)) => {
|
||||
panic!("Child exited with code {}", code);
|
||||
}
|
||||
Ok(_) => {
|
||||
if std::time::Instant::now() > deadline {
|
||||
kill(child, Signal::SIGKILL).ok();
|
||||
panic!("Child did not exit within timeout");
|
||||
}
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
}
|
||||
Err(e) => panic!("waitpid error: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(ForkResult::Child) => {
|
||||
drop(main_end);
|
||||
let control = ControlChannel::from_fd(child_end.into_fd());
|
||||
loop {
|
||||
match control.recv::<MainToChild>() {
|
||||
Err(_) => std::process::exit(0),
|
||||
_ => continue,
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => panic!("fork failed: {}", e),
|
||||
}
|
||||
}
|
||||
160
vm-switch/tests/fork_lifecycle.rs
Normal file
160
vm-switch/tests/fork_lifecycle.rs
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
//! Integration tests for fork-based child process lifecycle.
|
||||
//!
|
||||
//! Note: These tests must be run with `--test-threads=1` because
|
||||
//! fork tests cannot run in parallel within the same process.
|
||||
|
||||
use nix::sys::wait::{waitpid, WaitStatus};
|
||||
use nix::unistd::{fork, ForkResult};
|
||||
use serial_test::serial;
|
||||
use std::time::Duration;
|
||||
use vm_switch::control::ControlChannel;
|
||||
use vm_switch::mac::Mac;
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn child_exits_when_control_channel_closes() {
|
||||
// Create control channel
|
||||
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
|
||||
let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0x01]);
|
||||
|
||||
let temp_dir = tempfile::tempdir().expect("tempdir");
|
||||
let socket_path = temp_dir.path().join("test.sock");
|
||||
|
||||
match unsafe { fork() } {
|
||||
Ok(ForkResult::Parent { child }) => {
|
||||
// Parent: drop child's end, keep main's end
|
||||
drop(child_end);
|
||||
|
||||
// Give child time to start and send Ready
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
|
||||
// Drain any Ready message
|
||||
let _ = main_end.recv_with_fds_typed::<vm_switch::control::ChildToMain>();
|
||||
|
||||
// Close main's end - should cause child to exit
|
||||
drop(main_end);
|
||||
|
||||
// Wait for child to exit
|
||||
let status = waitpid(child, None).expect("waitpid failed");
|
||||
match status {
|
||||
WaitStatus::Exited(_, code) => {
|
||||
assert_eq!(code, 0, "child should exit with code 0");
|
||||
}
|
||||
other => panic!("unexpected wait status: {:?}", other),
|
||||
}
|
||||
}
|
||||
Ok(ForkResult::Child) => {
|
||||
// Child: drop main's end (we don't need parent's socket)
|
||||
drop(main_end);
|
||||
|
||||
// Run child entry point - this should exit when control closes
|
||||
let control_fd = child_end.into_fd();
|
||||
vm_switch::child::run_child_process("test-vm", mac, control_fd, &socket_path, vm_switch::SeccompMode::Disabled);
|
||||
}
|
||||
Err(e) => panic!("fork failed: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn child_processes_messages_before_exit() {
|
||||
use vm_switch::control::MainToChild;
|
||||
|
||||
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
|
||||
let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0x02]);
|
||||
|
||||
let temp_dir = tempfile::tempdir().expect("tempdir");
|
||||
let socket_path = temp_dir.path().join("test.sock");
|
||||
|
||||
match unsafe { fork() } {
|
||||
Ok(ForkResult::Parent { child }) => {
|
||||
// Parent: drop child's end
|
||||
drop(child_end);
|
||||
|
||||
// Wait for Ready from child
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
let _ = main_end.recv_with_fds_typed::<vm_switch::control::ChildToMain>();
|
||||
|
||||
// Send a RemovePeer message
|
||||
let msg = MainToChild::RemovePeer {
|
||||
peer_name: "some-peer".to_string(),
|
||||
};
|
||||
main_end.send(&msg).expect("should send");
|
||||
|
||||
// Close channel (drop main_end)
|
||||
drop(main_end);
|
||||
|
||||
// Child should exit cleanly after processing message
|
||||
let status = waitpid(child, None).expect("waitpid failed");
|
||||
match status {
|
||||
WaitStatus::Exited(_, code) => {
|
||||
assert_eq!(code, 0, "child should exit with code 0");
|
||||
}
|
||||
other => panic!("unexpected wait status: {:?}", other),
|
||||
}
|
||||
}
|
||||
Ok(ForkResult::Child) => {
|
||||
drop(main_end);
|
||||
let control_fd = child_end.into_fd();
|
||||
vm_switch::child::run_child_process("test-vm", mac, control_fd, &socket_path, vm_switch::SeccompMode::Disabled);
|
||||
}
|
||||
Err(e) => panic!("fork failed: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial]
|
||||
fn multiple_children_shut_down_gracefully() {
|
||||
// Fork 3 children
|
||||
let mut children = Vec::new();
|
||||
|
||||
for i in 0..3u8 {
|
||||
let (main_end, child_end) = ControlChannel::pair().expect("should create pair");
|
||||
let vm_name = format!("test-vm-{}", i);
|
||||
let mac = Mac::from_bytes([0xaa, 0xbb, 0xcc, 0xdd, 0xee, i]);
|
||||
|
||||
let temp_dir = tempfile::tempdir().expect("tempdir");
|
||||
let socket_path = temp_dir.path().join("test.sock");
|
||||
|
||||
match unsafe { fork() } {
|
||||
Ok(ForkResult::Parent { child }) => {
|
||||
drop(child_end);
|
||||
children.push((child, main_end, vm_name, temp_dir));
|
||||
}
|
||||
Ok(ForkResult::Child) => {
|
||||
drop(main_end);
|
||||
let control_fd = child_end.into_fd();
|
||||
vm_switch::child::run_child_process(&vm_name, mac, control_fd, &socket_path, vm_switch::SeccompMode::Disabled);
|
||||
}
|
||||
Err(e) => panic!("fork failed: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
// Give children time to start and send Ready
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
|
||||
// Drain Ready messages
|
||||
for (_, control, _, _) in &children {
|
||||
let _ = control.recv_with_fds_typed::<vm_switch::control::ChildToMain>();
|
||||
}
|
||||
|
||||
// Collect PIDs and drop control channels to signal shutdown
|
||||
let pids: Vec<_> = children
|
||||
.into_iter()
|
||||
.map(|(pid, control, _, _temp_dir)| {
|
||||
drop(control);
|
||||
pid
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Wait for all children
|
||||
for pid in pids {
|
||||
let status = waitpid(pid, None).expect("waitpid failed");
|
||||
match status {
|
||||
WaitStatus::Exited(_, code) => {
|
||||
assert_eq!(code, 0, "child should exit with code 0");
|
||||
}
|
||||
other => panic!("unexpected wait status: {:?}", other),
|
||||
}
|
||||
}
|
||||
}
|
||||
63
vm-switch/tests/packet_flow.rs
Normal file
63
vm-switch/tests/packet_flow.rs
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
//! Integration tests for packet forwarding.
|
||||
|
||||
use std::os::fd::{AsRawFd, FromRawFd, OwnedFd};
|
||||
use vm_switch::child::PacketForwarder;
|
||||
use vm_switch::mac::Mac;
|
||||
use vm_switch::ring::{Consumer, Producer};
|
||||
|
||||
fn make_frame(dest: [u8; 6], src: [u8; 6]) -> Vec<u8> {
|
||||
let mut frame = vec![0u8; 64];
|
||||
frame[0..6].copy_from_slice(&dest);
|
||||
frame[6..12].copy_from_slice(&src);
|
||||
frame
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forwarder_validates_source_mac() {
|
||||
let our_mac = Mac::from_bytes([1, 0, 0, 0, 0, 1]);
|
||||
let peer_mac = [2, 0, 0, 0, 0, 2];
|
||||
let mut forwarder = PacketForwarder::new(our_mac);
|
||||
|
||||
// Set up an egress buffer for broadcast
|
||||
let consumer = Consumer::new().expect("consumer");
|
||||
let memfd = unsafe { OwnedFd::from_raw_fd(libc::dup(consumer.memfd().as_raw_fd())) };
|
||||
let eventfd = unsafe { OwnedFd::from_raw_fd(libc::dup(consumer.eventfd().as_raw_fd())) };
|
||||
let producer = Producer::from_fds(memfd, eventfd).expect("producer");
|
||||
forwarder.add_egress("router".to_string(), peer_mac, producer, true);
|
||||
|
||||
// Correct source MAC - broadcast frame should be sent to egress with broadcast=true
|
||||
let good_frame = make_frame([0xff; 6], our_mac.bytes());
|
||||
assert!(forwarder.forward_tx(&good_frame));
|
||||
|
||||
// Wrong source MAC - should be dropped
|
||||
let bad_frame = make_frame([0xff; 6], [9, 9, 9, 9, 9, 9]);
|
||||
assert!(!forwarder.forward_tx(&bad_frame));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forwarder_ingress_validates_peer_mac() {
|
||||
let our_mac = Mac::from_bytes([1, 0, 0, 0, 0, 1]);
|
||||
let peer_mac = [2, 0, 0, 0, 0, 2];
|
||||
let spoofed = [3, 0, 0, 0, 0, 3];
|
||||
|
||||
let mut forwarder = PacketForwarder::new(our_mac);
|
||||
|
||||
// Set up ingress from peer
|
||||
let producer = Producer::new().expect("producer");
|
||||
let memfd = unsafe { OwnedFd::from_raw_fd(libc::dup(producer.memfd().as_raw_fd())) };
|
||||
let eventfd = unsafe { OwnedFd::from_raw_fd(libc::dup(producer.eventfd().as_raw_fd())) };
|
||||
let consumer = Consumer::from_fds(memfd, eventfd).expect("consumer");
|
||||
forwarder.add_ingress("router".to_string(), peer_mac, consumer);
|
||||
|
||||
// Good frame from peer
|
||||
let good = make_frame(our_mac.bytes(), peer_mac);
|
||||
producer.push(&good);
|
||||
|
||||
// Spoofed frame (wrong source)
|
||||
let bad = make_frame(our_mac.bytes(), spoofed);
|
||||
producer.push(&bad);
|
||||
|
||||
let received = forwarder.poll_ingress();
|
||||
assert_eq!(received.len(), 1);
|
||||
assert_eq!(&received[0][6..12], &peer_mac);
|
||||
}
|
||||
256
vm-switch/tests/sandbox_full.rs
Normal file
256
vm-switch/tests/sandbox_full.rs
Normal file
|
|
@ -0,0 +1,256 @@
|
|||
//! Integration tests for full sandbox isolation.
|
||||
|
||||
use nix::sys::wait::{waitpid, WaitStatus};
|
||||
use nix::unistd::{fork, ForkResult, Uid};
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use vm_switch::sandbox::{apply_sandbox, SandboxResult};
|
||||
|
||||
/// Helper to run test in forked child process.
|
||||
fn run_in_fork<F: FnOnce() + std::panic::UnwindSafe>(test_fn: F) {
|
||||
if Uid::current().is_root() {
|
||||
eprintln!("Skipping test: already running as root");
|
||||
return;
|
||||
}
|
||||
|
||||
match unsafe { fork() } {
|
||||
Ok(ForkResult::Parent { child }) => {
|
||||
let status = waitpid(child, None).unwrap();
|
||||
match status {
|
||||
WaitStatus::Exited(_, 0) => {}
|
||||
other => panic!("Child failed: {:?}", other),
|
||||
}
|
||||
}
|
||||
Ok(ForkResult::Child) => {
|
||||
let result = std::panic::catch_unwind(test_fn);
|
||||
match &result {
|
||||
Err(e) => {
|
||||
if let Some(s) = e.downcast_ref::<&str>() {
|
||||
eprintln!("Child panic: {}", s);
|
||||
} else if let Some(s) = e.downcast_ref::<String>() {
|
||||
eprintln!("Child panic: {}", s);
|
||||
} else {
|
||||
eprintln!("Child panic: unknown error");
|
||||
}
|
||||
}
|
||||
Ok(()) => {}
|
||||
}
|
||||
std::process::exit(if result.is_ok() { 0 } else { 1 });
|
||||
}
|
||||
Err(e) => panic!("Fork failed: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to handle SandboxResult in tests.
|
||||
/// In the inner wrapper parent, propagates exit code.
|
||||
/// In the sandboxed child, returns the config path.
|
||||
fn apply_and_unwrap(config_path: &Path) -> std::path::PathBuf {
|
||||
match apply_sandbox(config_path).expect("apply_sandbox failed") {
|
||||
SandboxResult::Parent(code) => {
|
||||
// Inner wrapper parent - propagate child's exit code
|
||||
std::process::exit(code);
|
||||
}
|
||||
SandboxResult::Sandboxed(path) => path,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_sandbox_returns_config_path() {
|
||||
let config_dir = tempfile::tempdir().unwrap();
|
||||
let config_path = config_dir.path().to_path_buf();
|
||||
|
||||
run_in_fork(move || {
|
||||
let result = apply_and_unwrap(&config_path);
|
||||
assert_eq!(result, Path::new("/config"));
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_sandbox_isolates_ipc_namespace() {
|
||||
let config_dir = tempfile::tempdir().unwrap();
|
||||
let config_path = config_dir.path().to_path_buf();
|
||||
|
||||
// Get parent IPC namespace before fork
|
||||
let parent_ipc = fs::read_link("/proc/self/ns/ipc").unwrap();
|
||||
|
||||
run_in_fork(move || {
|
||||
apply_and_unwrap(&config_path);
|
||||
|
||||
// The fact that apply_sandbox succeeded means IPC namespace was entered
|
||||
// We verify isolation indirectly through the other tests
|
||||
});
|
||||
|
||||
// Parent's namespace should be unchanged
|
||||
let parent_ipc_after = fs::read_link("/proc/self/ns/ipc").unwrap();
|
||||
assert_eq!(parent_ipc, parent_ipc_after);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_sandbox_isolates_network_namespace() {
|
||||
let config_dir = tempfile::tempdir().unwrap();
|
||||
let config_path = config_dir.path().to_path_buf();
|
||||
|
||||
run_in_fork(move || {
|
||||
apply_and_unwrap(&config_path);
|
||||
|
||||
// In empty network namespace, /sys/class/net should be empty or not exist
|
||||
// Since /sys is not mounted in our minimal root, network is effectively isolated
|
||||
assert!(!Path::new("/sys").exists(), "/sys should not exist in sandbox");
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_sandbox_creates_complete_isolation() {
|
||||
let config_dir = tempfile::tempdir().unwrap();
|
||||
let config_path = config_dir.path().to_path_buf();
|
||||
|
||||
// Create marker file
|
||||
fs::write(config_path.join("marker.txt"), "isolated").unwrap();
|
||||
|
||||
run_in_fork(move || {
|
||||
let new_config = apply_and_unwrap(&config_path);
|
||||
|
||||
// Verify config path is correct
|
||||
assert_eq!(new_config, Path::new("/config"));
|
||||
|
||||
// Verify filesystem isolation
|
||||
assert!(Path::new("/config/marker.txt").exists());
|
||||
assert!(!Path::new("/home").exists());
|
||||
assert!(!Path::new("/usr").exists());
|
||||
|
||||
// Verify /dev works
|
||||
assert!(Path::new("/dev/null").exists());
|
||||
fs::write("/dev/null", "test").expect("write to /dev/null");
|
||||
|
||||
// Verify /tmp is writable
|
||||
fs::write("/tmp/test.txt", "temp").expect("write to /tmp");
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn forked_children_inherit_sandbox() {
|
||||
let config_dir = tempfile::tempdir().unwrap();
|
||||
let config_path = config_dir.path().to_path_buf();
|
||||
fs::write(config_path.join("shared.txt"), "parent").unwrap();
|
||||
|
||||
run_in_fork(move || {
|
||||
apply_and_unwrap(&config_path);
|
||||
|
||||
// Fork a child from within the sandbox
|
||||
match unsafe { fork() } {
|
||||
Ok(ForkResult::Parent { child }) => {
|
||||
let status = waitpid(child, None).unwrap();
|
||||
match status {
|
||||
WaitStatus::Exited(_, code) => {
|
||||
assert_eq!(code, 0, "Nested child should exit with 0");
|
||||
}
|
||||
other => panic!("Nested child unexpected status: {:?}", other),
|
||||
}
|
||||
}
|
||||
Ok(ForkResult::Child) => {
|
||||
// Nested child inherits sandbox
|
||||
let exists = Path::new("/config/shared.txt").exists();
|
||||
let no_home = !Path::new("/home").exists();
|
||||
|
||||
std::process::exit(if exists && no_home { 0 } else { 1 });
|
||||
}
|
||||
Err(e) => panic!("Nested fork failed: {}", e),
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_sandbox_creates_pid_namespace() {
|
||||
let config_dir = tempfile::tempdir().unwrap();
|
||||
let config_path = config_dir.path().to_path_buf();
|
||||
|
||||
run_in_fork(move || {
|
||||
apply_and_unwrap(&config_path);
|
||||
|
||||
// Read /proc/self/stat - should be PID 1 in the new namespace
|
||||
let stat = fs::read_to_string("/proc/self/stat").expect("should read /proc/self/stat");
|
||||
assert!(
|
||||
stat.starts_with("1 "),
|
||||
"Should be PID 1 in namespace, got: {}",
|
||||
stat
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_sandbox_mounts_proc() {
|
||||
let config_dir = tempfile::tempdir().unwrap();
|
||||
let config_path = config_dir.path().to_path_buf();
|
||||
|
||||
run_in_fork(move || {
|
||||
apply_and_unwrap(&config_path);
|
||||
|
||||
// /proc should be mounted and functional
|
||||
assert!(Path::new("/proc").exists(), "/proc should exist");
|
||||
assert!(
|
||||
Path::new("/proc/self").exists(),
|
||||
"/proc/self should exist"
|
||||
);
|
||||
|
||||
// Should be able to read process info
|
||||
let cmdline = fs::read_to_string("/proc/self/cmdline");
|
||||
assert!(cmdline.is_ok(), "Should be able to read /proc/self/cmdline");
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_children_get_sequential_pids() {
|
||||
let config_dir = tempfile::tempdir().unwrap();
|
||||
let config_path = config_dir.path().to_path_buf();
|
||||
|
||||
run_in_fork(move || {
|
||||
apply_and_unwrap(&config_path);
|
||||
|
||||
// Fork a child and verify it gets PID 2
|
||||
match unsafe { fork() } {
|
||||
Ok(ForkResult::Parent { child }) => {
|
||||
let status = waitpid(child, None).unwrap();
|
||||
match status {
|
||||
WaitStatus::Exited(_, code) => {
|
||||
assert_eq!(code, 0, "Child should be PID 2");
|
||||
}
|
||||
other => panic!("Child unexpected status: {:?}", other),
|
||||
}
|
||||
}
|
||||
Ok(ForkResult::Child) => {
|
||||
// Read our PID from /proc
|
||||
let stat = fs::read_to_string("/proc/self/stat").unwrap();
|
||||
let pid_str = stat.split_whitespace().next().unwrap();
|
||||
let pid: i32 = pid_str.parse().unwrap();
|
||||
// Should be PID 2 (parent is 1)
|
||||
std::process::exit(if pid == 2 { 0 } else { 1 });
|
||||
}
|
||||
Err(e) => panic!("fork failed: {}", e),
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sandbox_proc_shows_only_sandboxed_processes() {
|
||||
let config_dir = tempfile::tempdir().unwrap();
|
||||
let config_path = config_dir.path().to_path_buf();
|
||||
|
||||
run_in_fork(move || {
|
||||
apply_and_unwrap(&config_path);
|
||||
|
||||
// Count processes visible in /proc
|
||||
let mut proc_pids = 0;
|
||||
for entry in fs::read_dir("/proc").unwrap() {
|
||||
let entry = entry.unwrap();
|
||||
let name = entry.file_name();
|
||||
let name_str = name.to_string_lossy();
|
||||
// Count numeric directories (PIDs)
|
||||
if name_str.chars().all(|c| c.is_ascii_digit()) {
|
||||
proc_pids += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Should only see PID 1 (ourselves)
|
||||
assert_eq!(proc_pids, 1, "Should only see 1 process in /proc, saw {}", proc_pids);
|
||||
});
|
||||
}
|
||||
140
vm-switch/tests/sandbox_mount_ns.rs
Normal file
140
vm-switch/tests/sandbox_mount_ns.rs
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
//! Integration tests for mount namespace and filesystem isolation.
|
||||
//!
|
||||
//! These tests require user namespace support (for unprivileged mount ns).
|
||||
|
||||
use nix::unistd::Uid;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use vm_switch::sandbox::{enter_user_namespace, setup_filesystem_isolation};
|
||||
|
||||
/// Helper to run test in forked child process.
|
||||
fn run_in_fork<F: FnOnce() + std::panic::UnwindSafe>(test_fn: F) {
|
||||
if Uid::current().is_root() {
|
||||
eprintln!("Skipping test: already running as root");
|
||||
return;
|
||||
}
|
||||
|
||||
match unsafe { nix::unistd::fork() } {
|
||||
Ok(nix::unistd::ForkResult::Parent { child }) => {
|
||||
let status = nix::sys::wait::waitpid(child, None).unwrap();
|
||||
match status {
|
||||
nix::sys::wait::WaitStatus::Exited(_, 0) => {}
|
||||
other => panic!("Child failed: {:?}", other),
|
||||
}
|
||||
}
|
||||
Ok(nix::unistd::ForkResult::Child) => {
|
||||
let result = std::panic::catch_unwind(test_fn);
|
||||
match &result {
|
||||
Err(e) => {
|
||||
if let Some(s) = e.downcast_ref::<&str>() {
|
||||
eprintln!("Child panic: {}", s);
|
||||
} else if let Some(s) = e.downcast_ref::<String>() {
|
||||
eprintln!("Child panic: {}", s);
|
||||
} else {
|
||||
eprintln!("Child panic: unknown error");
|
||||
}
|
||||
}
|
||||
Ok(()) => {}
|
||||
}
|
||||
std::process::exit(if result.is_ok() { 0 } else { 1 });
|
||||
}
|
||||
Err(e) => panic!("Fork failed: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filesystem_isolation_creates_minimal_root() {
|
||||
// Create a temp config dir that we'll bind-mount
|
||||
let config_dir = tempfile::tempdir().unwrap();
|
||||
let config_path = config_dir.path().to_path_buf();
|
||||
|
||||
// Create a marker file in config dir
|
||||
fs::write(config_path.join("marker.txt"), "test").unwrap();
|
||||
|
||||
run_in_fork(move || {
|
||||
enter_user_namespace().expect("enter_user_namespace failed");
|
||||
setup_filesystem_isolation(&config_path, true).expect("setup_filesystem_isolation failed");
|
||||
|
||||
// Verify /config exists and contains our marker
|
||||
assert!(Path::new("/config/marker.txt").exists(), "/config/marker.txt should exist");
|
||||
let content = fs::read_to_string("/config/marker.txt").unwrap();
|
||||
assert_eq!(content, "test");
|
||||
|
||||
// Verify /proc mount point exists (may not be mounted without PID namespace)
|
||||
assert!(Path::new("/proc").is_dir(), "/proc should be a directory");
|
||||
|
||||
// Verify /dev devices exist
|
||||
assert!(Path::new("/dev/null").exists(), "/dev/null should exist");
|
||||
assert!(Path::new("/dev/zero").exists(), "/dev/zero should exist");
|
||||
assert!(Path::new("/dev/urandom").exists(), "/dev/urandom should exist");
|
||||
|
||||
// Verify /tmp exists
|
||||
assert!(Path::new("/tmp").is_dir(), "/tmp should be a directory");
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filesystem_isolation_hides_host_filesystem() {
|
||||
let config_dir = tempfile::tempdir().unwrap();
|
||||
let config_path = config_dir.path().to_path_buf();
|
||||
|
||||
run_in_fork(move || {
|
||||
enter_user_namespace().expect("enter_user_namespace failed");
|
||||
setup_filesystem_isolation(&config_path, true).expect("setup_filesystem_isolation failed");
|
||||
|
||||
// These paths should NOT exist (host filesystem hidden)
|
||||
assert!(!Path::new("/home").exists(), "/home should not exist");
|
||||
assert!(!Path::new("/usr").exists(), "/usr should not exist");
|
||||
assert!(!Path::new("/etc").exists(), "/etc should not exist");
|
||||
assert!(!Path::new("/bin").exists(), "/bin should not exist");
|
||||
assert!(!Path::new("/nix").exists(), "/nix should not exist");
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filesystem_isolation_config_dir_is_writable() {
|
||||
let config_dir = tempfile::tempdir().unwrap();
|
||||
let config_path = config_dir.path().to_path_buf();
|
||||
|
||||
run_in_fork(move || {
|
||||
enter_user_namespace().expect("enter_user_namespace failed");
|
||||
setup_filesystem_isolation(&config_path, true).expect("setup_filesystem_isolation failed");
|
||||
|
||||
// Should be able to create files in /config (for socket files)
|
||||
let test_file = Path::new("/config/test-write.txt");
|
||||
fs::write(test_file, "writable").expect("should be able to write to /config");
|
||||
assert!(test_file.exists());
|
||||
|
||||
let content = fs::read_to_string(test_file).unwrap();
|
||||
assert_eq!(content, "writable");
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dev_devices_are_functional() {
|
||||
use std::io::Read;
|
||||
|
||||
let config_dir = tempfile::tempdir().unwrap();
|
||||
let config_path = config_dir.path().to_path_buf();
|
||||
|
||||
run_in_fork(move || {
|
||||
enter_user_namespace().expect("enter_user_namespace failed");
|
||||
setup_filesystem_isolation(&config_path, true).expect("setup_filesystem_isolation failed");
|
||||
|
||||
// /dev/null should accept writes and return nothing on read
|
||||
fs::write("/dev/null", "discard this").expect("write to /dev/null");
|
||||
let content = fs::read("/dev/null").unwrap();
|
||||
assert!(content.is_empty(), "/dev/null should return empty on read");
|
||||
|
||||
// /dev/zero should return zeros (read limited amount)
|
||||
let mut f = fs::File::open("/dev/zero").expect("open /dev/zero");
|
||||
let mut buf = [0xffu8; 16]; // Initialize with non-zero to verify it changes
|
||||
f.read_exact(&mut buf).expect("read from /dev/zero");
|
||||
assert!(buf.iter().all(|&b| b == 0), "/dev/zero should return zeros");
|
||||
|
||||
// /dev/urandom should return random bytes (just check it's readable)
|
||||
let mut f = fs::File::open("/dev/urandom").expect("open /dev/urandom");
|
||||
let mut buf = [0u8; 16];
|
||||
f.read_exact(&mut buf).expect("/dev/urandom should be readable");
|
||||
});
|
||||
}
|
||||
88
vm-switch/tests/sandbox_user_ns.rs
Normal file
88
vm-switch/tests/sandbox_user_ns.rs
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
//! Integration tests for user namespace isolation.
|
||||
//!
|
||||
//! These tests require the ability to create user namespaces,
|
||||
//! which is available to unprivileged users on most Linux systems.
|
||||
|
||||
use nix::unistd::{getgid, getuid, Uid};
|
||||
use std::fs;
|
||||
use vm_switch::sandbox::enter_user_namespace;
|
||||
|
||||
#[test]
|
||||
fn enter_user_namespace_maps_to_root() {
|
||||
// Skip if we're already root (CI environments sometimes run as root)
|
||||
if Uid::current().is_root() {
|
||||
eprintln!("Skipping test: already running as root");
|
||||
return;
|
||||
}
|
||||
|
||||
// Fork to avoid affecting the test process
|
||||
match unsafe { nix::unistd::fork() } {
|
||||
Ok(nix::unistd::ForkResult::Parent { child }) => {
|
||||
// Parent waits for child
|
||||
let status = nix::sys::wait::waitpid(child, None).unwrap();
|
||||
match status {
|
||||
nix::sys::wait::WaitStatus::Exited(_, 0) => {}
|
||||
other => panic!("Child failed: {:?}", other),
|
||||
}
|
||||
}
|
||||
Ok(nix::unistd::ForkResult::Child) => {
|
||||
// Child enters user namespace and checks UID
|
||||
let result = std::panic::catch_unwind(|| {
|
||||
enter_user_namespace().expect("enter_user_namespace failed");
|
||||
|
||||
// After entering user namespace, we should appear as root
|
||||
let uid = getuid();
|
||||
let gid = getgid();
|
||||
|
||||
assert!(uid.is_root(), "Expected UID 0, got {}", uid);
|
||||
assert_eq!(gid.as_raw(), 0, "Expected GID 0, got {}", gid);
|
||||
|
||||
// Verify we're in a different namespace
|
||||
let ns_path = fs::read_link("/proc/self/ns/user").unwrap();
|
||||
eprintln!("User namespace: {:?}", ns_path);
|
||||
});
|
||||
|
||||
std::process::exit(if result.is_ok() { 0 } else { 1 });
|
||||
}
|
||||
Err(e) => panic!("Fork failed: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enter_user_namespace_is_isolated() {
|
||||
// Skip if we're already root
|
||||
if Uid::current().is_root() {
|
||||
eprintln!("Skipping test: already running as root");
|
||||
return;
|
||||
}
|
||||
|
||||
// Get parent namespace inode before fork
|
||||
let parent_ns = fs::read_link("/proc/self/ns/user").unwrap();
|
||||
|
||||
match unsafe { nix::unistd::fork() } {
|
||||
Ok(nix::unistd::ForkResult::Parent { child }) => {
|
||||
let status = nix::sys::wait::waitpid(child, None).unwrap();
|
||||
match status {
|
||||
nix::sys::wait::WaitStatus::Exited(_, 0) => {}
|
||||
other => panic!("Child failed: {:?}", other),
|
||||
}
|
||||
}
|
||||
Ok(nix::unistd::ForkResult::Child) => {
|
||||
let result = std::panic::catch_unwind(|| {
|
||||
enter_user_namespace().expect("enter_user_namespace failed");
|
||||
|
||||
let child_ns = fs::read_link("/proc/self/ns/user").unwrap();
|
||||
|
||||
// Namespace should be different from parent
|
||||
assert_ne!(
|
||||
parent_ns, child_ns,
|
||||
"User namespace should differ: parent={:?}, child={:?}",
|
||||
parent_ns, child_ns
|
||||
);
|
||||
});
|
||||
|
||||
std::process::exit(if result.is_ok() { 0 } else { 1 });
|
||||
}
|
||||
Err(e) => panic!("Fork failed: {}", e),
|
||||
}
|
||||
}
|
||||
178
vm-switch/tests/seccomp_filter.rs
Normal file
178
vm-switch/tests/seccomp_filter.rs
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
//! Integration tests for seccomp filtering.
|
||||
|
||||
use nix::sys::signal::Signal;
|
||||
use nix::sys::wait::{waitpid, WaitStatus};
|
||||
use nix::unistd::{fork, ForkResult, Uid};
|
||||
use std::process;
|
||||
use vm_switch::seccomp::{apply_child_seccomp, apply_main_seccomp, SeccompMode};
|
||||
|
||||
/// Run test in forked child, return wait status.
|
||||
fn run_in_fork<F: FnOnce() -> i32 + std::panic::UnwindSafe>(test_fn: F) -> WaitStatus {
|
||||
if Uid::current().is_root() {
|
||||
eprintln!("Skipping: running as root");
|
||||
return WaitStatus::Exited(nix::unistd::Pid::from_raw(0), 0);
|
||||
}
|
||||
|
||||
match unsafe { fork() } {
|
||||
Ok(ForkResult::Parent { child }) => {
|
||||
waitpid(child, None).expect("waitpid failed")
|
||||
}
|
||||
Ok(ForkResult::Child) => {
|
||||
let result = std::panic::catch_unwind(test_fn);
|
||||
let code = match result {
|
||||
Ok(code) => code,
|
||||
Err(_) => 1,
|
||||
};
|
||||
process::exit(code);
|
||||
}
|
||||
Err(e) => panic!("Fork failed: {}", e),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn seccomp_kill_blocks_ptrace() {
|
||||
let status = run_in_fork(|| {
|
||||
apply_main_seccomp(SeccompMode::Kill).expect("apply failed");
|
||||
|
||||
// ptrace is not whitelisted
|
||||
unsafe { libc::ptrace(libc::PTRACE_TRACEME, 0, 0, 0) };
|
||||
0 // Should not reach here
|
||||
});
|
||||
|
||||
match status {
|
||||
WaitStatus::Signaled(_, sig, _) => {
|
||||
assert!(sig == Signal::SIGSYS || sig == Signal::SIGKILL);
|
||||
}
|
||||
WaitStatus::Exited(_, code) => {
|
||||
assert_ne!(code, 0, "should have been killed");
|
||||
}
|
||||
_ => panic!("unexpected: {:?}", status),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn seccomp_trap_sends_sigsys() {
|
||||
let status = run_in_fork(|| {
|
||||
apply_main_seccomp(SeccompMode::Trap).expect("apply failed");
|
||||
unsafe { libc::ptrace(libc::PTRACE_TRACEME, 0, 0, 0) };
|
||||
0
|
||||
});
|
||||
|
||||
match status {
|
||||
WaitStatus::Signaled(_, Signal::SIGSYS, _) => {}
|
||||
WaitStatus::Stopped(_, Signal::SIGSYS) => {}
|
||||
_ => panic!("expected SIGSYS, got {:?}", status),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn seccomp_disabled_allows_all() {
|
||||
let status = run_in_fork(|| {
|
||||
apply_main_seccomp(SeccompMode::Disabled).expect("apply failed");
|
||||
// Would be blocked, but disabled mode allows it
|
||||
let _ = unsafe { libc::ptrace(libc::PTRACE_TRACEME, 0, 0, 0) };
|
||||
0
|
||||
});
|
||||
|
||||
match status {
|
||||
WaitStatus::Exited(_, 0) => {}
|
||||
_ => panic!("expected clean exit, got {:?}", status),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn seccomp_allows_whitelisted() {
|
||||
let status = run_in_fork(|| {
|
||||
apply_main_seccomp(SeccompMode::Kill).expect("apply failed");
|
||||
|
||||
// All whitelisted
|
||||
let _ = unsafe { libc::getpid() };
|
||||
let _ = unsafe { libc::getuid() };
|
||||
|
||||
let msg = b"ok\n";
|
||||
let ret = unsafe { libc::write(2, msg.as_ptr() as *const _, msg.len()) };
|
||||
if ret < 0 { return 1; }
|
||||
|
||||
0
|
||||
});
|
||||
|
||||
match status {
|
||||
WaitStatus::Exited(_, 0) => {}
|
||||
_ => panic!("expected clean exit, got {:?}", status),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn child_filter_blocks_fork() {
|
||||
let status = run_in_fork(|| {
|
||||
apply_child_seccomp(SeccompMode::Kill).expect("apply failed");
|
||||
|
||||
// fork not in child whitelist
|
||||
let ret = unsafe { libc::fork() };
|
||||
if ret == 0 {
|
||||
process::exit(99); // grandchild
|
||||
}
|
||||
0 // Should not reach
|
||||
});
|
||||
|
||||
match status {
|
||||
WaitStatus::Signaled(_, sig, _) => {
|
||||
assert!(sig == Signal::SIGSYS || sig == Signal::SIGKILL);
|
||||
}
|
||||
WaitStatus::Exited(_, code) => {
|
||||
assert!(code != 0 && code != 99, "fork should be blocked");
|
||||
}
|
||||
_ => panic!("unexpected: {:?}", status),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn child_filter_blocks_socket() {
|
||||
let status = run_in_fork(|| {
|
||||
apply_child_seccomp(SeccompMode::Kill).expect("apply failed");
|
||||
|
||||
// socket() not in child whitelist
|
||||
let fd = unsafe { libc::socket(libc::AF_UNIX, libc::SOCK_STREAM, 0) };
|
||||
if fd >= 0 {
|
||||
unsafe { libc::close(fd) };
|
||||
return 1; // Should have been blocked
|
||||
}
|
||||
0
|
||||
});
|
||||
|
||||
match status {
|
||||
WaitStatus::Signaled(_, sig, _) => {
|
||||
assert!(sig == Signal::SIGSYS || sig == Signal::SIGKILL);
|
||||
}
|
||||
WaitStatus::Exited(_, code) => {
|
||||
assert_ne!(code, 0, "socket should be blocked");
|
||||
}
|
||||
_ => panic!("unexpected: {:?}", status),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn child_filter_allows_memfd() {
|
||||
let status = run_in_fork(|| {
|
||||
apply_child_seccomp(SeccompMode::Kill).expect("apply failed");
|
||||
|
||||
// memfd_create is in child whitelist (for ring buffers)
|
||||
let fd = unsafe {
|
||||
libc::syscall(
|
||||
libc::SYS_memfd_create,
|
||||
b"test\0".as_ptr(),
|
||||
0u32,
|
||||
)
|
||||
};
|
||||
if fd < 0 {
|
||||
return 1;
|
||||
}
|
||||
unsafe { libc::close(fd as i32) };
|
||||
0
|
||||
});
|
||||
|
||||
match status {
|
||||
WaitStatus::Exited(_, 0) => {}
|
||||
_ => panic!("expected clean exit, got {:?}", status),
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue