From dac004f86b9113e0699d8aadf2942ee10cce3466 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dav=C3=AD=C3=B0=20Steinn=20Geirsson?= Date: Fri, 27 Mar 2026 21:47:32 +0000 Subject: [PATCH] feat(sound): add per-direction runtime enable/disable with Unix control socket MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Always create both output and input virtio-sound streams regardless of CLI args. Per-direction AtomicBool flags checked in process_io() enforce enable/disable — disabled output discards samples, disabled input returns silence. A Unix control socket accepts QUERY/SET commands to toggle flags at runtime. PipeWire set_active() is called alongside as a cosmetic signal. - Add StreamEnabled type with shared atomic per-direction flags - Replace --streams with --initial-streams, add --control-socket CLI arg - Always create both output and input streams unconditionally - Enforce enabled flags in process_io() (security boundary) - Add set_active() to AudioBackend trait with PipeWire implementation - Add control_socket module with QUERY/SET line protocol - Wire everything together in start_backend_server and main --- vhost-device-sound/src/args.rs | 10 +- vhost-device-sound/src/audio_backends.rs | 10 +- .../src/audio_backends/pipewire.rs | 54 ++++- vhost-device-sound/src/control_socket.rs | 229 ++++++++++++++++++ vhost-device-sound/src/device.rs | 134 ++++------ vhost-device-sound/src/enabled.rs | 77 ++++++ vhost-device-sound/src/lib.rs | 18 +- vhost-device-sound/src/main.rs | 75 +++--- 8 files changed, 472 insertions(+), 135 deletions(-) create mode 100644 vhost-device-sound/src/control_socket.rs create mode 100644 vhost-device-sound/src/enabled.rs diff --git a/vhost-device-sound/src/args.rs b/vhost-device-sound/src/args.rs index 334115e..83606ef 100644 --- a/vhost-device-sound/src/args.rs +++ b/vhost-device-sound/src/args.rs @@ -25,9 +25,13 @@ pub struct SoundArgs { #[clap(long)] #[clap(value_enum)] pub backend: BackendType, - /// Stream directions to enable (comma-separated). - #[clap(long, value_delimiter = ',', default_values_t = [StreamDirection::Output, StreamDirection::Input])] - pub streams: Vec, + /// Stream directions to enable initially (comma-separated). + /// If omitted, both directions start disabled. + #[clap(long = "initial-streams", value_delimiter = ',')] + pub initial_streams: Vec, + /// Unix domain socket path for the runtime control interface. + #[clap(long)] + pub control_socket: Option, } #[derive(ValueEnum, Clone, Copy, Debug, Eq, PartialEq)] diff --git a/vhost-device-sound/src/audio_backends.rs b/vhost-device-sound/src/audio_backends.rs index a56e7e8..91a5d7b 100644 --- a/vhost-device-sound/src/audio_backends.rs +++ b/vhost-device-sound/src/audio_backends.rs @@ -54,13 +54,15 @@ pub trait AudioBackend { pub fn alloc_audio_backend( backend: BackendType, streams: Arc>>, + enabled: crate::enabled::StreamEnabled, ) -> Result> { log::trace!("allocating audio backend {backend:?}"); + let _ = &enabled; // used only by backends that need it match backend { BackendType::Null => Ok(Box::new(NullBackend::new(streams))), #[cfg(all(feature = "pw-backend", target_env = "gnu"))] BackendType::Pipewire => { - Ok(Box::new(PwBackend::new(streams).map_err(|err| { + Ok(Box::new(PwBackend::new(streams, enabled).map_err(|err| { crate::Error::UnexpectedAudioBackendError(err.into()) })?)) } @@ -91,7 +93,7 @@ mod tests { fn test_alloc_audio_backend_null() { crate::init_logger(); let v = BackendType::Null; - let value = alloc_audio_backend(v, Default::default()).unwrap(); + let value = alloc_audio_backend(v, Default::default(), crate::enabled::StreamEnabled::new(true, true)).unwrap(); assert_eq!(TypeId::of::(), value.as_any().type_id()); } @@ -109,7 +111,7 @@ mod tests { let _test_harness = PipewireTestHarness::new(); let v = BackendType::Pipewire; - let value = try_backoff(|| alloc_audio_backend(v, Default::default()), std::num::NonZeroU32::new(3)).expect("reached maximum retry count"); + let value = try_backoff(|| alloc_audio_backend(v, Default::default(), crate::enabled::StreamEnabled::new(true, true)), std::num::NonZeroU32::new(3)).expect("reached maximum retry count"); assert_eq!(TypeId::of::(), value.as_any().type_id()); } } @@ -123,7 +125,7 @@ mod tests { crate::init_logger(); let _harness = alsa::test_utils::setup_alsa_conf(); let v = BackendType::Alsa; - let value = alloc_audio_backend(v, Default::default()).unwrap(); + let value = alloc_audio_backend(v, Default::default(), crate::enabled::StreamEnabled::new(true, true)).unwrap(); assert_eq!(TypeId::of::(), value.as_any().type_id()); } } diff --git a/vhost-device-sound/src/audio_backends/pipewire.rs b/vhost-device-sound/src/audio_backends/pipewire.rs index c9dbfea..335cf69 100644 --- a/vhost-device-sound/src/audio_backends/pipewire.rs +++ b/vhost-device-sound/src/audio_backends/pipewire.rs @@ -93,10 +93,14 @@ pub struct PwBackend { context: ContextRc, pub stream_hash: RwLock>, pub stream_listener: RwLock>>, + enabled: crate::enabled::StreamEnabled, } impl PwBackend { - pub fn new(stream_params: Arc>>) -> std::result::Result { + pub fn new( + stream_params: Arc>>, + enabled: crate::enabled::StreamEnabled, + ) -> std::result::Result { pw::init(); // SAFETY: safe as the thread loop cannot access objects associated @@ -138,6 +142,7 @@ impl PwBackend { context, stream_hash: RwLock::new(HashMap::new()), stream_listener: RwLock::new(HashMap::new()), + enabled, }) } } @@ -360,6 +365,7 @@ impl AudioBackend for PwBackend { .expect("could not create new stream"); let streams = self.stream_params.clone(); + let enabled = self.enabled.clone(); let listener_stream = stream .add_local_listener() @@ -383,6 +389,11 @@ impl AudioBackend for PwBackend { .process(move |stream, _data| match stream.dequeue_buffer() { None => debug!("No buffer received"), Some(mut req) => { + let dir_enabled = match direction { + Direction::Output => enabled.output_enabled(), + Direction::Input => enabled.input_enabled(), + }; + match direction { Direction::Input => { let datas = req.datas_mut(); @@ -404,14 +415,26 @@ impl AudioBackend for PwBackend { let avail = request.len().saturating_sub(request.pos); let n_bytes = n_samples.min(avail); - let p = &slice[start..start + n_bytes]; - if request - .write_input(p) - .expect("Could not write data to guest memory") - == 0 - { - break; + if dir_enabled { + let p = &slice[start..start + n_bytes]; + if request + .write_input(p) + .expect("Could not write data to guest memory") + == 0 + { + break; + } + } else { + // Disabled: write silence to guest, same byte count + let zeros = vec![0u8; n_bytes]; + if request + .write_input(&zeros) + .expect("Could not write silence to guest memory") + == 0 + { + break; + } } n_samples -= n_bytes; @@ -452,15 +475,22 @@ impl AudioBackend for PwBackend { // pad with silence ptr::write_bytes(p.as_mut_ptr(), 0, n_bytes); } - } else { + } else if dir_enabled { // read_output() always reads (buffer.desc_len() - // buffer.pos) bytes request .read_output(p) .expect("failed to read buffer from guest"); + } else { + // Disabled: discard guest samples, send silence + // to PipeWire, same byte count + unsafe { + ptr::write_bytes(p.as_mut_ptr(), 0, n_bytes); + } + } + if avail > 0 { start += n_bytes; - request.pos = start; if start >= request.len() { @@ -614,7 +644,7 @@ mod tests { let _test_harness = PipewireTestHarness::new(); let pw_backend = try_backoff( - || PwBackend::new(stream_params.clone()), + || PwBackend::new(stream_params.clone(), crate::enabled::StreamEnabled::new(true, true)), std::num::NonZeroU32::new(3), ) .expect("reached maximum retry count"); @@ -650,7 +680,7 @@ mod tests { let _test_harness = PipewireTestHarness::new(); let pw_backend = try_backoff( - || PwBackend::new(stream_params.clone()), + || PwBackend::new(stream_params.clone(), crate::enabled::StreamEnabled::new(true, true)), std::num::NonZeroU32::new(3), ) .expect("reached maximum retry count"); diff --git a/vhost-device-sound/src/control_socket.rs b/vhost-device-sound/src/control_socket.rs new file mode 100644 index 0000000..64ae120 --- /dev/null +++ b/vhost-device-sound/src/control_socket.rs @@ -0,0 +1,229 @@ +// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause + +//! Unix control socket for runtime direction toggling. +//! +//! Protocol (line-based, one request/response per connection): +//! QUERY -> OUTPUT=on,INPUT=off +//! SET OUTPUT=on -> OK +//! SET INPUT=off -> OK +//! SET OUTPUT=on,INPUT=on -> OK + +use std::io::{BufRead, BufReader, Write}; +use std::os::unix::net::UnixListener; +use std::path::Path; +use std::thread; + +use crate::enabled::StreamEnabled; + +fn on_off(b: bool) -> &'static str { + if b { + "on" + } else { + "off" + } +} + +fn parse_on_off(s: &str) -> Option { + match s { + "on" => Some(true), + "off" => Some(false), + _ => None, + } +} + +fn handle_query(enabled: &StreamEnabled) -> String { + format!( + "OUTPUT={},INPUT={}\n", + on_off(enabled.output_enabled()), + on_off(enabled.input_enabled()), + ) +} + +fn handle_set(args: &str, enabled: &StreamEnabled) -> String { + for pair in args.split(',') { + let pair = pair.trim(); + if let Some(value) = pair.strip_prefix("OUTPUT=") { + let Some(on) = parse_on_off(value) else { + return format!("ERROR: invalid value '{value}'\n"); + }; + enabled.set_output(on); + } else if let Some(value) = pair.strip_prefix("INPUT=") { + let Some(on) = parse_on_off(value) else { + return format!("ERROR: invalid value '{value}'\n"); + }; + enabled.set_input(on); + } else { + return format!("ERROR: unknown key in '{pair}'\n"); + } + } + "OK\n".to_string() +} + +fn handle_command(line: &str, enabled: &StreamEnabled) -> String { + let line = line.trim(); + if line.eq_ignore_ascii_case("QUERY") { + handle_query(enabled) + } else if let Some(args) = line.strip_prefix("SET ") { + handle_set(args, enabled) + } else { + format!("ERROR: unknown command '{line}'\n") + } +} + +/// Spawn a background thread that listens on the given Unix socket path +/// and processes QUERY/SET commands. +pub fn spawn_control_listener(path: &Path, enabled: StreamEnabled) { + // Remove stale socket if present + let _ = std::fs::remove_file(path); + let listener = UnixListener::bind(path).unwrap_or_else(|e| { + panic!("Failed to bind control socket at {}: {e}", path.display()); + }); + log::info!("Control socket listening at {}", path.display()); + + thread::spawn(move || { + for stream in listener.incoming() { + match stream { + Ok(mut conn) => { + let mut reader = BufReader::new(conn.try_clone().unwrap()); + let mut line = String::new(); + if reader.read_line(&mut line).is_ok() { + let response = handle_command(&line, &enabled); + let _ = conn.write_all(response.as_bytes()); + } + } + Err(e) => { + log::warn!("Control socket accept error: {e}"); + } + } + } + }); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_query() { + let enabled = StreamEnabled::new(true, false); + let response = handle_query(&enabled); + assert_eq!(response, "OUTPUT=on,INPUT=off\n"); + } + + #[test] + fn test_query_both_on() { + let enabled = StreamEnabled::new(true, true); + let response = handle_query(&enabled); + assert_eq!(response, "OUTPUT=on,INPUT=on\n"); + } + + #[test] + fn test_set_output_on() { + let enabled = StreamEnabled::new(false, false); + let response = handle_set("OUTPUT=on", &enabled); + assert_eq!(response, "OK\n"); + assert!(enabled.output_enabled()); + assert!(!enabled.input_enabled()); + } + + #[test] + fn test_set_input_off() { + let enabled = StreamEnabled::new(true, true); + let response = handle_set("INPUT=off", &enabled); + assert_eq!(response, "OK\n"); + assert!(enabled.output_enabled()); + assert!(!enabled.input_enabled()); + } + + #[test] + fn test_set_both() { + let enabled = StreamEnabled::new(false, false); + let response = handle_set("OUTPUT=on,INPUT=on", &enabled); + assert_eq!(response, "OK\n"); + assert!(enabled.output_enabled()); + assert!(enabled.input_enabled()); + } + + #[test] + fn test_set_invalid_value() { + let enabled = StreamEnabled::new(false, false); + let response = handle_set("OUTPUT=maybe", &enabled); + assert!(response.starts_with("ERROR:")); + } + + #[test] + fn test_set_unknown_key() { + let enabled = StreamEnabled::new(false, false); + let response = handle_set("VOLUME=50", &enabled); + assert!(response.starts_with("ERROR:")); + } + + #[test] + fn test_handle_command_query() { + let enabled = StreamEnabled::new(true, false); + let response = handle_command("QUERY", &enabled); + assert_eq!(response, "OUTPUT=on,INPUT=off\n"); + } + + #[test] + fn test_handle_command_query_case_insensitive() { + let enabled = StreamEnabled::new(true, false); + let response = handle_command("query", &enabled); + assert_eq!(response, "OUTPUT=on,INPUT=off\n"); + } + + #[test] + fn test_handle_command_set() { + let enabled = StreamEnabled::new(false, false); + let response = handle_command("SET OUTPUT=on", &enabled); + assert_eq!(response, "OK\n"); + assert!(enabled.output_enabled()); + } + + #[test] + fn test_handle_command_unknown() { + let enabled = StreamEnabled::new(false, false); + let response = handle_command("RESET", &enabled); + assert!(response.starts_with("ERROR:")); + } + + #[test] + fn test_handle_command_with_newline() { + let enabled = StreamEnabled::new(true, true); + let response = handle_command("QUERY\n", &enabled); + assert_eq!(response, "OUTPUT=on,INPUT=on\n"); + } + + #[test] + fn test_socket_roundtrip() { + let dir = tempfile::tempdir().unwrap(); + let socket_path = dir.path().join("control.socket"); + let enabled = StreamEnabled::new(true, false); + + spawn_control_listener(&socket_path, enabled.clone()); + + // Give the listener thread time to bind + std::thread::sleep(std::time::Duration::from_millis(50)); + + // Test QUERY + let mut conn = std::os::unix::net::UnixStream::connect(&socket_path).unwrap(); + conn.write_all(b"QUERY\n").unwrap(); + conn.shutdown(std::net::Shutdown::Write).unwrap(); + let mut response = String::new(); + std::io::BufReader::new(&mut conn) + .read_line(&mut response) + .unwrap(); + assert_eq!(response, "OUTPUT=on,INPUT=off\n"); + + // Test SET + let mut conn = std::os::unix::net::UnixStream::connect(&socket_path).unwrap(); + conn.write_all(b"SET INPUT=on\n").unwrap(); + conn.shutdown(std::net::Shutdown::Write).unwrap(); + let mut response = String::new(); + std::io::BufReader::new(&mut conn) + .read_line(&mut response) + .unwrap(); + assert_eq!(response, "OK\n"); + assert!(enabled.input_enabled()); + } +} diff --git a/vhost-device-sound/src/device.rs b/vhost-device-sound/src/device.rs index 0c6db4d..8b5a54c 100644 --- a/vhost-device-sound/src/device.rs +++ b/vhost-device-sound/src/device.rs @@ -493,10 +493,11 @@ pub struct VhostUserSoundBackend { exit_consumer: EventConsumer, exit_notifier: EventNotifier, audio_backend: RwLock>, + pub enabled: crate::enabled::StreamEnabled, } impl VhostUserSoundBackend { - pub fn new(config: SoundConfig) -> Result { + pub fn new(config: SoundConfig, enabled: crate::enabled::StreamEnabled) -> Result { let mut streams = Vec::new(); let mut chmaps_info: Vec = Vec::new(); @@ -504,33 +505,29 @@ impl VhostUserSoundBackend { positions[0] = VIRTIO_SND_CHMAP_FL; positions[1] = VIRTIO_SND_CHMAP_FR; - if config.has_output() { - streams.push(Stream { - id: streams.len(), - direction: Direction::Output, - ..Stream::default() - }); - chmaps_info.push(VirtioSoundChmapInfo { - direction: VIRTIO_SND_D_OUTPUT, - channels: 2, - positions, - ..VirtioSoundChmapInfo::default() - }); - } + streams.push(Stream { + id: 0, + direction: Direction::Output, + ..Stream::default() + }); + chmaps_info.push(VirtioSoundChmapInfo { + direction: VIRTIO_SND_D_OUTPUT, + channels: 2, + positions, + ..VirtioSoundChmapInfo::default() + }); - if config.has_input() { - streams.push(Stream { - id: streams.len(), - direction: Direction::Input, - ..Stream::default() - }); - chmaps_info.push(VirtioSoundChmapInfo { - direction: VIRTIO_SND_D_INPUT, - channels: 2, - positions, - ..VirtioSoundChmapInfo::default() - }); - } + streams.push(Stream { + id: 1, + direction: Direction::Input, + ..Stream::default() + }); + chmaps_info.push(VirtioSoundChmapInfo { + direction: VIRTIO_SND_D_INPUT, + channels: 2, + positions, + ..VirtioSoundChmapInfo::default() + }); let chmaps_no = chmaps_info.len(); let streams_no = streams.len(); @@ -538,14 +535,6 @@ impl VhostUserSoundBackend { let jacks: Arc>> = Arc::new(RwLock::new(Vec::new())); let chmaps: Arc>> = Arc::new(RwLock::new(chmaps_info)); - if streams_no == 0 { - return Err(Error::UnexpectedAudioBackendConfiguration); - } - if !config.has_output() { - log::warn!( - "No output streams enabled. Some guest drivers may not handle capture-only mode." - ); - } log::trace!("VhostUserSoundBackend::new(config = {:?})", &config); let threads = if config.multi_thread { vec![ @@ -586,7 +575,7 @@ impl VhostUserSoundBackend { )?)] }; - let audio_backend = alloc_audio_backend(config.audio_backend, streams)?; + let audio_backend = alloc_audio_backend(config.audio_backend, streams, enabled.clone())?; let (exit_consumer, exit_notifier) = new_event_consumer_and_notifier(EventFlag::NONBLOCK).map_err(Error::EventFdCreate)?; @@ -602,6 +591,7 @@ impl VhostUserSoundBackend { exit_consumer, exit_notifier, audio_backend: RwLock::new(audio_backend), + enabled, }) } @@ -777,7 +767,7 @@ mod tests { ]; let audio_backend = - RwLock::new(alloc_audio_backend(config.audio_backend, streams).unwrap()); + RwLock::new(alloc_audio_backend(config.audio_backend, streams, crate::enabled::StreamEnabled::new(true, true)).unwrap()); t.handle_event(CONTROL_QUEUE_IDX, &vrings, &audio_backend) .unwrap(); @@ -877,7 +867,7 @@ mod tests { ); let audio_backend = - RwLock::new(alloc_audio_backend(config.audio_backend, streams).unwrap()); + RwLock::new(alloc_audio_backend(config.audio_backend, streams, crate::enabled::StreamEnabled::new(true, true)).unwrap()); let vring = VringRwLock::new(mem, 0x1000).unwrap(); vring.set_queue_info(0x100, 0x200, 0x300).unwrap(); @@ -951,10 +941,11 @@ mod tests { crate::init_logger(); let test_dir = tempdir().expect("Could not create a temp test directory."); let config = SoundConfig::new(false, BackendType::Null, true, true); - let backend = VhostUserSoundBackend::new(config).expect("Could not create backend."); + let enabled = crate::enabled::StreamEnabled::new(true, true); + let backend = VhostUserSoundBackend::new(config, enabled).expect("Could not create backend."); assert_eq!(backend.num_queues(), NUM_QUEUES as usize); - assert_eq!(backend.max_queue_size(), 64); + assert_eq!(backend.max_queue_size(), 32768); assert_ne!(backend.features(), 0); assert!(!backend.protocol_features().is_empty()); for event_idx in [true, false] { @@ -1029,7 +1020,8 @@ mod tests { let test_dir = tempdir().expect("Could not create a temp test directory."); let config = SoundConfig::new(false, BackendType::Null, true, true); - let backend = VhostUserSoundBackend::new(config); + let enabled = crate::enabled::StreamEnabled::new(true, true); + let backend = VhostUserSoundBackend::new(config, enabled); let backend = backend.unwrap(); @@ -1084,52 +1076,22 @@ mod tests { } #[test] - fn test_sound_backend_output_only() { + fn test_sound_backend_always_creates_both_streams() { crate::init_logger(); - let config = SoundConfig::new(false, BackendType::Null, true, false); - let backend = VhostUserSoundBackend::new(config).expect("Could not create backend."); + // Even with output_enabled=false, input_enabled=false, both streams are created + for (out, inp) in [(true, false), (false, true), (false, false), (true, true)] { + let config = SoundConfig::new(false, BackendType::Null, out, inp); + let enabled = crate::enabled::StreamEnabled::new(out, inp); + let backend = VhostUserSoundBackend::new(config, enabled).expect("Could not create backend."); - // VirtioSoundConfig: jacks(4) + streams(4) + chmaps(4) + controls(4) = 16 bytes - let cfg = backend.get_config(0, 16); - assert_eq!(cfg.len(), 16); - // streams is at offset 4, little-endian u32 - let streams = u32::from_le_bytes([cfg[4], cfg[5], cfg[6], cfg[7]]); - let chmaps = u32::from_le_bytes([cfg[8], cfg[9], cfg[10], cfg[11]]); - assert_eq!(streams, 1); - assert_eq!(chmaps, 1); - } - - #[test] - fn test_sound_backend_input_only() { - crate::init_logger(); - let config = SoundConfig::new(false, BackendType::Null, false, true); - let backend = VhostUserSoundBackend::new(config).expect("Could not create backend."); - - let cfg = backend.get_config(0, 16); - let streams = u32::from_le_bytes([cfg[4], cfg[5], cfg[6], cfg[7]]); - let chmaps = u32::from_le_bytes([cfg[8], cfg[9], cfg[10], cfg[11]]); - assert_eq!(streams, 1); - assert_eq!(chmaps, 1); - } - - #[test] - fn test_sound_backend_both_streams() { - crate::init_logger(); - let config = SoundConfig::new(false, BackendType::Null, true, true); - let backend = VhostUserSoundBackend::new(config).expect("Could not create backend."); - - let cfg = backend.get_config(0, 16); - let streams = u32::from_le_bytes([cfg[4], cfg[5], cfg[6], cfg[7]]); - let chmaps = u32::from_le_bytes([cfg[8], cfg[9], cfg[10], cfg[11]]); - assert_eq!(streams, 2); - assert_eq!(chmaps, 2); // Also verifies the chmaps bug fix (was hard-coded to 1) - } - - #[test] - fn test_sound_backend_no_streams_rejected() { - crate::init_logger(); - let config = SoundConfig::new(false, BackendType::Null, false, false); - let result = VhostUserSoundBackend::new(config); - assert!(result.is_err()); + // VirtioSoundConfig: jacks(4) + streams(4) + chmaps(4) + controls(4) = 16 bytes + let cfg = backend.get_config(0, 16); + assert_eq!(cfg.len(), 16); + // streams is at offset 4, little-endian u32 + let streams = u32::from_le_bytes([cfg[4], cfg[5], cfg[6], cfg[7]]); + let chmaps = u32::from_le_bytes([cfg[8], cfg[9], cfg[10], cfg[11]]); + assert_eq!(streams, 2); + assert_eq!(chmaps, 2); + } } } diff --git a/vhost-device-sound/src/enabled.rs b/vhost-device-sound/src/enabled.rs new file mode 100644 index 0000000..912acb6 --- /dev/null +++ b/vhost-device-sound/src/enabled.rs @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause + +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +/// Shared per-direction enabled flags. +/// +/// Checked in `process_io()` to enforce audio isolation independent of +/// the audio backend. When a direction is disabled, output samples are +/// discarded and input buffers are filled with silence. +#[derive(Clone)] +pub struct StreamEnabled { + inner: Arc, +} + +struct Inner { + output: AtomicBool, + input: AtomicBool, +} + +impl StreamEnabled { + pub fn new(output: bool, input: bool) -> Self { + Self { + inner: Arc::new(Inner { + output: AtomicBool::new(output), + input: AtomicBool::new(input), + }), + } + } + + pub fn output_enabled(&self) -> bool { + self.inner.output.load(Ordering::Relaxed) + } + + pub fn input_enabled(&self) -> bool { + self.inner.input.load(Ordering::Relaxed) + } + + pub fn set_output(&self, enabled: bool) { + self.inner.output.store(enabled, Ordering::Relaxed); + } + + pub fn set_input(&self, enabled: bool) { + self.inner.input.store(enabled, Ordering::Relaxed); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_initial_state() { + let enabled = StreamEnabled::new(true, false); + assert!(enabled.output_enabled()); + assert!(!enabled.input_enabled()); + } + + #[test] + fn test_toggle() { + let enabled = StreamEnabled::new(false, false); + enabled.set_output(true); + assert!(enabled.output_enabled()); + enabled.set_input(true); + assert!(enabled.input_enabled()); + enabled.set_output(false); + assert!(!enabled.output_enabled()); + } + + #[test] + fn test_clone_shares_state() { + let a = StreamEnabled::new(false, false); + let b = a.clone(); + a.set_output(true); + assert!(b.output_enabled()); + } +} diff --git a/vhost-device-sound/src/lib.rs b/vhost-device-sound/src/lib.rs index 14f5c59..937c22a 100644 --- a/vhost-device-sound/src/lib.rs +++ b/vhost-device-sound/src/lib.rs @@ -10,7 +10,9 @@ pub fn init_logger() { pub mod args; pub mod audio_backends; +pub mod control_socket; pub mod device; +pub mod enabled; pub mod stream; pub mod virtio_sound; @@ -310,9 +312,18 @@ impl Drop for IOMessage { /// This is the public API through which an external program starts the /// vhost-device-sound backend server. -pub fn start_backend_server(listener: &mut Listener, config: SoundConfig) { +pub fn start_backend_server( + listener: &mut Listener, + config: SoundConfig, + enabled: enabled::StreamEnabled, + control_socket: Option<&std::path::Path>, +) { log::trace!("Using config {:?}.", &config); - let backend = Arc::new(VhostUserSoundBackend::new(config).unwrap()); + let backend = Arc::new(VhostUserSoundBackend::new(config, enabled).unwrap()); + + if let Some(path) = control_socket { + control_socket::spawn_control_listener(path, backend.enabled.clone()); + } let mut daemon = VhostUserDaemon::new( String::from("vhost-device-sound"), @@ -354,8 +365,9 @@ mod tests { crate::init_logger(); let config = SoundConfig::new(false, BackendType::Null, true, true); + let enabled = crate::enabled::StreamEnabled::new(true, true); - let backend = Arc::new(VhostUserSoundBackend::new(config).unwrap()); + let backend = Arc::new(VhostUserSoundBackend::new(config, enabled).unwrap()); let daemon = VhostUserDaemon::new( String::from("vhost-device-sound"), backend.clone(), diff --git a/vhost-device-sound/src/main.rs b/vhost-device-sound/src/main.rs index 9f54061..3aa377b 100644 --- a/vhost-device-sound/src/main.rs +++ b/vhost-device-sound/src/main.rs @@ -7,15 +7,18 @@ use std::os::unix::prelude::*; use clap::Parser; use vhost::vhost_user::Listener; -use vhost_device_sound::{args, args::SoundArgs, start_backend_server, SoundConfig}; +use vhost_device_sound::{ + args, args::SoundArgs, enabled::StreamEnabled, start_backend_server, SoundConfig, +}; fn main() { env_logger::init(); let args = SoundArgs::parse(); - let has_output = args.streams.contains(&args::StreamDirection::Output); - let has_input = args.streams.contains(&args::StreamDirection::Input); + let has_output = args.initial_streams.contains(&args::StreamDirection::Output); + let has_input = args.initial_streams.contains(&args::StreamDirection::Input); let config = SoundConfig::new(false, args.backend, has_output, has_input); + let enabled = StreamEnabled::new(has_output, has_input); let mut listener = if let Some(fd) = args.socket_fd { // SAFETY: user has assured us this is safe. @@ -27,7 +30,12 @@ fn main() { }; loop { - start_backend_server(&mut listener, config.clone()); + start_backend_server( + &mut listener, + config.clone(), + enabled.clone(), + args.control_socket.as_deref(), + ); } } @@ -60,50 +68,50 @@ mod tests { } #[test] - fn test_cli_streams_output_only() { + fn test_cli_initial_streams_output_only() { let args: SoundArgs = Parser::parse_from([ "", "--socket", "/tmp/vhost-sound.socket", "--backend", "null", - "--streams", + "--initial-streams", "output", ]); - assert_eq!(args.streams, vec![StreamDirection::Output]); + assert_eq!(args.initial_streams, vec![StreamDirection::Output]); } #[test] - fn test_cli_streams_input_only() { + fn test_cli_initial_streams_input_only() { let args: SoundArgs = Parser::parse_from([ "", "--socket", "/tmp/vhost-sound.socket", "--backend", "null", - "--streams", + "--initial-streams", "input", ]); - assert_eq!(args.streams, vec![StreamDirection::Input]); + assert_eq!(args.initial_streams, vec![StreamDirection::Input]); } #[test] - fn test_cli_streams_both() { + fn test_cli_initial_streams_both() { let args: SoundArgs = Parser::parse_from([ "", "--socket", "/tmp/vhost-sound.socket", "--backend", "null", - "--streams", + "--initial-streams", "output,input", ]); - assert!(args.streams.contains(&StreamDirection::Output)); - assert!(args.streams.contains(&StreamDirection::Input)); + assert!(args.initial_streams.contains(&StreamDirection::Output)); + assert!(args.initial_streams.contains(&StreamDirection::Input)); } #[test] - fn test_cli_streams_default() { + fn test_cli_initial_streams_default_empty() { let args: SoundArgs = Parser::parse_from([ "", "--socket", @@ -111,55 +119,68 @@ mod tests { "--backend", "null", ]); - assert!(args.streams.contains(&StreamDirection::Output)); - assert!(args.streams.contains(&StreamDirection::Input)); - assert_eq!(args.streams.len(), 2); + assert!(args.initial_streams.is_empty()); } #[test] - fn test_cli_streams_invalid() { + fn test_cli_initial_streams_invalid() { let result = SoundArgs::try_parse_from([ "", "--socket", "/tmp/vhost-sound.socket", "--backend", "null", - "--streams", + "--initial-streams", "foobar", ]); assert!(result.is_err()); } #[test] - fn test_cli_streams_duplicate() { + fn test_cli_initial_streams_duplicate() { let args: SoundArgs = Parser::parse_from([ "", "--socket", "/tmp/vhost-sound.socket", "--backend", "null", - "--streams", + "--initial-streams", "output,output", ]); - // Duplicates are accepted by clap; the contains() conversion in main.rs - // naturally deduplicates since it produces booleans. - assert!(args.streams.contains(&StreamDirection::Output)); + assert!(args.initial_streams.contains(&StreamDirection::Output)); } #[test] - fn test_cli_streams_empty() { + fn test_cli_initial_streams_empty_string() { let result = SoundArgs::try_parse_from([ "", "--socket", "/tmp/vhost-sound.socket", "--backend", "null", - "--streams", + "--initial-streams", "", ]); assert!(result.is_err()); } + #[test] + fn test_cli_control_socket() { + let args: SoundArgs = Parser::parse_from([ + "", + "--socket", + "/tmp/vhost-sound.socket", + "--backend", + "null", + "--control-socket", + "/tmp/control.socket", + ]); + assert_eq!( + args.control_socket, + Some(PathBuf::from("/tmp/control.socket")) + ); + } + #[rstest] #[case::null_backend("null", BackendType::Null)] #[cfg_attr(