From b42c76579de3da55a6f1577e9da0277d4a390512 Mon Sep 17 00:00:00 2001 From: Manos Pitsidianakis Date: Mon, 4 Dec 2023 11:45:47 +0200 Subject: [PATCH] rng: use PathBuf for socket path clap can parse a PathBuf directly from the command line arguments, and paths are not always UTF-8. Use PathBuf instead of a String to allow for all valid filesystem paths. Signed-off-by: Manos Pitsidianakis --- vhost-device-rng/src/main.rs | 100 +++++++++++++++++++++++++---------- 1 file changed, 73 insertions(+), 27 deletions(-) diff --git a/vhost-device-rng/src/main.rs b/vhost-device-rng/src/main.rs index a0e09a2..37edbd1 100644 --- a/vhost-device-rng/src/main.rs +++ b/vhost-device-rng/src/main.rs @@ -7,6 +7,7 @@ mod vhu_rng; use log::error; use std::fs::File; +use std::path::PathBuf; use std::process::exit; use std::sync::{Arc, Mutex, RwLock}; use std::thread::{self, JoinHandle}; @@ -57,12 +58,12 @@ struct RngArgs { socket_count: u32, // Location of vhost-user Unix domain socket. This is suffixed by 0,1,2..socket_count-1. - #[clap(short, long)] - socket_path: String, + #[clap(short, long, value_name = "SOCKET")] + socket_path: PathBuf, // Where to get the RNG data from. Defaults to /dev/urandom. #[clap(short = 'f', long, default_value = "/dev/urandom")] - rng_source: String, + rng_source: PathBuf, } #[derive(Clone, Debug, Eq, PartialEq)] @@ -70,45 +71,69 @@ pub(crate) struct VuRngConfig { pub period_ms: u128, pub max_bytes: usize, pub count: u32, - pub socket_path: String, - pub rng_source: String, + pub socket_path: PathBuf, + pub rng_source: PathBuf, } impl TryFrom for VuRngConfig { type Error = Error; fn try_from(args: RngArgs) -> Result { - if args.period == 0 || args.period > VHU_RNG_MAX_PERIOD_MS { - return Err(Error::InvalidPeriodInput(args.period)); + let RngArgs { + period, + max_bytes, + socket_count, + socket_path, + rng_source, + } = args; + + if period == 0 || period > VHU_RNG_MAX_PERIOD_MS { + return Err(Error::InvalidPeriodInput(period)); } - if args.socket_count == 0 { - return Err(Error::InvalidSocketCount(args.socket_count)); + if socket_count == 0 { + return Err(Error::InvalidSocketCount(socket_count)); } // Divide available bandwidth by the number of threads in order // to avoid overwhelming the HW. - let max_bytes = args.max_bytes / args.socket_count as usize; - let socket_path = args.socket_path.trim().to_string(); - let rng_source = args.rng_source.trim().to_string(); + let max_bytes = max_bytes / socket_count as usize; Ok(VuRngConfig { - period_ms: args.period, + period_ms: period, max_bytes, - count: args.socket_count, + count: socket_count, socket_path, rng_source, }) } } +impl VuRngConfig { + pub fn generate_socket_paths(&self) -> Vec { + let socket_file_name = self + .socket_path + .file_name() + .expect("socket_path has no filename."); + let socket_file_parent = self + .socket_path + .parent() + .expect("socket_path has no parent directory."); + + let make_socket_path = |i: u32| -> PathBuf { + let mut file_name = socket_file_name.to_os_string(); + file_name.push(std::ffi::OsStr::new(&i.to_string())); + socket_file_parent.join(&file_name) + }; + (0..self.count).map(make_socket_path).collect() + } +} + pub(crate) fn start_backend(config: VuRngConfig) -> Result<()> { let mut handles = Vec::new(); let file = File::open(&config.rng_source).map_err(|_| Error::AccessRngSourceFile)?; let random_file = Arc::new(Mutex::new(file)); - - for i in 0..config.count { - let socket = format!("{}{}", config.socket_path.to_owned(), i); + for socket in config.generate_socket_paths() { let period_ms = config.period_ms; let max_bytes = config.max_bytes; let random = Arc::clone(&random_file); @@ -155,32 +180,53 @@ fn main() { #[cfg(test)] mod tests { use assert_matches::assert_matches; + use std::path::Path; use tempfile::tempdir; use super::*; #[test] fn verify_cmd_line_arguments() { - // All parameters have default values, except for the socket path. White spaces are - // introduced on purpose to make sure Strings are trimmed properly. - let default_args: RngArgs = Parser::parse_from(["", "-s /some/socket_path "]); + // All parameters have default values, except for the socket path. + let default_args: RngArgs = Parser::parse_from(["", "-s", "/some/socket_path"]); // A valid configuration that should be equal to the above default configuration. let args = RngArgs { period: VHU_RNG_MAX_PERIOD_MS, max_bytes: usize::MAX, socket_count: 1, - socket_path: "/some/socket_path".to_string(), - rng_source: "/dev/urandom".to_string(), + socket_path: "/some/socket_path".into(), + rng_source: "/dev/urandom".into(), }; - // All configuration elements should be what we expect them to be. Using - // VuRngConfig::try_from() ensures that strings have been properly trimmed. + // All configuration elements should be what we expect them to be. assert_eq!( VuRngConfig::try_from(default_args).unwrap(), VuRngConfig::try_from(args.clone()).unwrap() ); + // Socket paths are what we expect them to be. + assert_eq!( + VuRngConfig::try_from(args.clone()) + .unwrap() + .generate_socket_paths(), + vec![Path::new("/some/socket_path0").to_path_buf()] + ); + + let mut many_socket_count_args = args.clone(); + many_socket_count_args.socket_count = 3; + + assert_eq!( + VuRngConfig::try_from(many_socket_count_args) + .unwrap() + .generate_socket_paths(), + vec![ + Path::new("/some/socket_path0").to_path_buf(), + Path::new("/some/socket_path1").to_path_buf(), + Path::new("/some/socket_path2").to_path_buf(), + ] + ); + // Setting a invalid period should trigger an InvalidPeriodInput error. let mut invalid_period_args = args.clone(); invalid_period_args.period = VHU_RNG_MAX_PERIOD_MS + 1; @@ -208,8 +254,8 @@ mod tests { period_ms: 1000, max_bytes: 512, count: 1, - socket_path: "/invalid/path".to_string(), - rng_source: "/invalid/path".to_string(), + socket_path: "/invalid/path".into(), + rng_source: "/invalid/path".into(), }; // An invalid RNG source file should trigger an AccessRngSourceFile error. @@ -222,7 +268,7 @@ mod tests { // of the socket file. Since the latter is invalid, serving will throw // an error, forcing the thread to exit and the call to handle.join() // to fail. - config.rng_source = random_path.to_str().unwrap().to_string(); + config.rng_source = random_path; assert_matches!(start_backend(config).unwrap_err(), Error::ServeFailed(_)); } }