diff --git a/option_parser/src/lib.rs b/option_parser/src/lib.rs index f83524a20..ff3d9ffff 100644 --- a/option_parser/src/lib.rs +++ b/option_parser/src/lib.rs @@ -9,6 +9,23 @@ use std::str::FromStr; use thiserror::Error; +mod private_trait { + // Voldemort trait that dispatches to `FromStr::from_str` on externally-defined types + // and to custom parsing code for types in this module. + pub trait Parseable + where + Self: Sized, + { + type Err; + // Actually does the parsing, but panics if the input doesn't have + // balanced quotes. This is fine because split_commas checks that the + // input has balanced quotes, and option names cannot contain anything + // that split_commas treats as special. + fn from_str(input: &str) -> Result::Err>; + } +} +use private_trait::Parseable; + #[derive(Default)] pub struct OptionParser { options: HashMap, @@ -34,38 +51,34 @@ type OptionParserResult = std::result::Result; fn split_commas(s: &str) -> OptionParserResult> { let mut list: Vec = Vec::new(); - let mut opened_brackets = 0; + let mut opened_brackets = 0u64; let mut in_quotes = false; let mut current = String::new(); for c in s.trim().chars() { match c { - '[' => { - opened_brackets += 1; - current.push('['); - } + // In quotes, only '"' is special + '"' => in_quotes = !in_quotes, + _ if in_quotes => {} + '[' => opened_brackets += 1, ']' => { - opened_brackets -= 1; - if opened_brackets < 0 { + if opened_brackets < 1 { return Err(OptionParserError::InvalidSyntax(s.to_owned())); } - current.push(']'); + opened_brackets -= 1; } - '"' => in_quotes = !in_quotes, - ',' => { - if opened_brackets > 0 || in_quotes { - current.push(',') - } else { - list.push(current); - current = String::new(); - } + ',' if opened_brackets == 0 => { + list.push(current); + current = String::new(); + continue; } - c => current.push(c), - } + _ => {} + }; + current.push(c); } list.push(current); - if opened_brackets != 0 || in_quotes { + if in_quotes || opened_brackets != 0 { return Err(OptionParserError::InvalidSyntax(s.to_owned())); } @@ -86,7 +99,6 @@ impl OptionParser { for option in split_commas(input)?.iter() { let parts: Vec<&str> = option.splitn(2, '=').collect(); - match self.options.get_mut(parts[0]) { None => return Err(OptionParserError::UnknownOption(parts[0].to_owned())), Some(value) => { @@ -106,6 +118,12 @@ impl OptionParser { } pub fn add(&mut self, option: &str) -> &mut Self { + // Check that option=value has balanced + // quotes and brackets iff value does. + assert!( + !option.contains(['"', '[', ']', '=', ',']), + "forbidden character in option name" + ); self.options.insert( option.to_owned(), OptionParserValue { @@ -133,7 +151,13 @@ impl OptionParser { self.options .get(option) .and_then(|v| v.value.clone()) - .and_then(|s| if s.is_empty() { None } else { Some(s) }) + .and_then(|s| { + if s.is_empty() { + None + } else { + Some(dequote(&s)) + } + }) } pub fn is_set(&self, option: &str) -> bool { @@ -143,12 +167,18 @@ impl OptionParser { .is_some() } - pub fn convert(&self, option: &str) -> OptionParserResult> { - match self.get(option) { + pub fn convert(&self, option: &str) -> OptionParserResult> { + match self.options.get(option).and_then(|v| v.value.as_ref()) { None => Ok(None), - Some(v) => Ok(Some(v.parse().map_err(|_| { - OptionParserError::Conversion(option.to_owned(), v.to_owned()) - })?)), + Some(v) => { + Ok(if v.is_empty() { + None + } else { + Some(Parseable::from_str(v).map_err(|_| { + OptionParserError::Conversion(option.to_owned(), v.to_owned()) + })?) + }) + } } } } @@ -161,7 +191,7 @@ pub enum ToggleParseError { InvalidValue(String), } -impl FromStr for Toggle { +impl Parseable for Toggle { type Err = ToggleParseError; fn from_str(s: &str) -> std::result::Result { @@ -216,7 +246,7 @@ pub enum IntegerListParseError { InvalidValue(String), } -impl FromStr for IntegerList { +impl Parseable for IntegerList { type Err = IntegerListParseError; fn from_str(s: &str) -> std::result::Result { @@ -300,6 +330,7 @@ impl TupleValue for Vec { } } +#[derive(PartialEq, Eq, Debug)] pub struct Tuple(pub Vec<(S, T)>); #[derive(Error, Debug)] @@ -314,31 +345,39 @@ pub enum TupleError { InvalidInteger(#[source] ParseIntError), } -impl FromStr for Tuple { +impl Parseable for Tuple { type Err = TupleError; fn from_str(s: &str) -> std::result::Result { let mut list: Vec<(S, T)> = Vec::new(); - let body = s .trim() .strip_prefix('[') .and_then(|s| s.strip_suffix(']')) .ok_or_else(|| TupleError::InvalidValue(s.to_string()))?; - let tuples_list = split_commas(body).map_err(TupleError::SplitOutsideBrackets)?; for tuple in tuples_list.iter() { - let items: Vec<&str> = tuple.split('@').collect(); - - if items.len() != 2 { - return Err(TupleError::InvalidValue((*tuple).to_string())); + let mut in_quotes = false; + let mut last_idx = 0; + let mut first_val = None; + for (idx, c) in tuple.as_bytes().iter().enumerate() { + match c { + b'"' => in_quotes = !in_quotes, + b'@' if !in_quotes => { + if last_idx != 0 { + return Err(TupleError::InvalidValue((*tuple).to_string())); + } + first_val = Some(&tuple[last_idx..idx]); + last_idx = idx + 1; + } + _ => {} + } } - - let item1 = items[0] - .parse::() - .map_err(|_| TupleError::InvalidValue(items[0].to_owned()))?; - let item2 = TupleValue::parse_value(items[1])?; - + let item1 = ::from_str( + first_val.ok_or(TupleError::InvalidValue((*tuple).to_string()))?, + ) + .map_err(|_| TupleError::InvalidValue(first_val.unwrap().to_owned()))?; + let item2 = TupleValue::parse_value(&tuple[last_idx..])?; list.push((item1, item2)); } @@ -355,16 +394,48 @@ pub enum StringListParseError { InvalidValue(String), } -impl FromStr for StringList { +fn dequote(s: &str) -> String { + let mut prev_byte = b'\0'; + let mut in_quotes = false; + let mut out: Vec = vec![]; + for i in s.bytes() { + if i == b'"' { + if prev_byte == b'"' && !in_quotes { + out.push(b'"'); + } + in_quotes = !in_quotes; + } else { + out.push(i); + } + prev_byte = i + } + assert!(!in_quotes, "split_commas didn't reject unbalanced quotes"); + // SAFETY: the non-ASCII bytes in the output are the same + // and in the same order as those in the input, so if the + // input is valid UTF-8 the output will be as well. + unsafe { String::from_utf8_unchecked(out) } +} + +impl Parseable for T +where + T: FromStr + Sized, +{ + type Err = ::Err; + fn from_str(s: &str) -> std::result::Result { + dequote(s).parse() + } +} + +impl Parseable for StringList { type Err = StringListParseError; fn from_str(s: &str) -> std::result::Result { - let string_list: Vec = s - .trim() - .trim_matches(|c| c == '[' || c == ']') - .split(',') - .map(|e| e.to_owned()) - .collect(); + let string_list: Vec = + split_commas(s.trim().trim_matches(|c| c == '[' || c == ']')) + .map_err(|_| StringListParseError::InvalidValue(s.to_owned()))? + .iter() + .map(|e| e.to_owned()) + .collect(); Ok(StringList(string_list)) } @@ -385,6 +456,7 @@ mod tests { .add("topology") .add("cmdline"); + assert_eq!(split_commas("\"\"").unwrap(), vec!["\"\""]); parser.parse("size=128M,hanging_param").unwrap_err(); parser .parse("size=128M,too_many_equals=foo=bar") @@ -395,6 +467,8 @@ mod tests { assert_eq!(parser.get("size"), Some("128M".to_owned())); assert!(!parser.is_set("mergeable")); assert!(parser.is_set("size")); + parser.parse("size=").unwrap(); + assert!(parser.get("size").is_none()); parser.parse("size=128M,mergeable=on").unwrap(); assert_eq!(parser.get("size"), Some("128M".to_owned())); @@ -416,6 +490,14 @@ mod tests { parser.parse("topology=[").unwrap_err(); parser.parse("topology=[[[]]]]").unwrap_err(); + parser.parse("topology=[\"@\"\"b\"@[1,2]]").unwrap(); + assert_eq!( + parser + .convert::>>("topology") + .unwrap() + .unwrap(), + Tuple(vec![("@\"b".to_owned(), vec![1, 2])]) + ); parser.parse("cmdline=\"console=ttyS0,9600n8\"").unwrap(); assert_eq!( @@ -425,4 +507,14 @@ mod tests { parser.parse("cmdline=\"").unwrap_err(); parser.parse("cmdline=\"\"\"").unwrap_err(); } + + #[test] + fn parse_bytes() { + assert_eq!(::from_str("a=\"b\"").unwrap(), "a=b"); + } + + #[test] + fn check_dequote() { + assert_eq!(dequote("a\u{3b2}\"a\"\"\""), "a\u{3b2}a\"") + } }