diff --git a/virtio-devices/src/vsock/unix/muxer.rs b/virtio-devices/src/vsock/unix/muxer.rs index 842e02677..5ea8efb9b 100644 --- a/virtio-devices/src/vsock/unix/muxer.rs +++ b/virtio-devices/src/vsock/unix/muxer.rs @@ -38,6 +38,7 @@ //! To route all these events to their handlers, the muxer uses another `HashMap` object, //! mapping `RawFd`s to `EpollListener`s. +use std::cmp::max; use std::collections::{HashMap, HashSet}; use std::fs::File; use std::io::{self, ErrorKind, Read}; @@ -90,11 +91,13 @@ enum EpollListener { LocalStream(UnixStream), } +const PARTIALLY_READ_COMMAND_BUF_SIZE: usize = 32; + /// A partially read "CONNECT" command. #[derive(Default)] struct PartiallyReadCommand { /// The bytes of the command that have been read so far. - buf: [u8; 32], + buf: [u8; PARTIALLY_READ_COMMAND_BUF_SIZE], /// How much of `buf` has been used. len: usize, } @@ -435,7 +438,11 @@ impl VsockMuxer { // "connect" command that we're expecting. Some(EpollListener::LocalStream(_)) => { if let Some(EpollListener::LocalStream(stream)) = self.listener_map.get_mut(&fd) { - let port = Self::read_local_stream_port(&mut self.partial_command_map, stream); + let command = self + .partial_command_map + .entry(stream.as_raw_fd()) + .or_default(); + let port = Self::read_local_stream_port(command, stream); if let Err(Error::UnixRead(ref e)) = port && e.kind() == ErrorKind::WouldBlock @@ -443,6 +450,11 @@ impl VsockMuxer { return; } + // either we have `Ok(port)` or a fatal Error such as + // Error::InvalidPortRequest, either way we must remove + // the command from the map + self.partial_command_map.remove(&stream.as_raw_fd()); + let stream = match self.remove_listener(fd) { Some(EpollListener::LocalStream(s)) => s, _ => unreachable!(), @@ -480,55 +492,73 @@ impl VsockMuxer { } } + fn parse_port_from_read_command(command: &PartiallyReadCommand) -> Result { + // normally followed by the port and a `\n` + let connect_prefix: &str = "connect "; + + let opt_new_line_position = command.buf[..command.len].iter().position(|x| *x == b'\n'); + + // we need to read more to get a `connect ` statement + if command.len < connect_prefix.len() { + return match opt_new_line_position { + Some(_) => Err(Error::InvalidPortRequest), + None => Err(Error::UnixRead(std::io::ErrorKind::WouldBlock.into())), + }; + } + + // check for both upper and lower case connect statements + if !command.buf[..connect_prefix.len()].eq_ignore_ascii_case(connect_prefix.as_bytes()) { + return Err(Error::InvalidPortRequest); + } + + // we filled our buffer + if command.buf.len() == command.len && opt_new_line_position.is_none() { + return Err(Error::InvalidPortRequest); + } + + // we parsed correctly `connect ` but need to wait for `\n` + let new_line_position = + opt_new_line_position.ok_or(Error::UnixRead(std::io::ErrorKind::WouldBlock.into()))?; + + // we now have the newline, we will treat everything in between as the port + let port_string_as_bytes = &command.buf[connect_prefix.len()..new_line_position]; + + std::str::from_utf8(port_string_as_bytes) + .map_err(|_| Error::InvalidPortRequest)? + .trim() + .parse::() + .map_err(|_| Error::InvalidPortRequest) + } + /// Parse a host "connect" command, and extract the destination vsock port. /// fn read_local_stream_port( - partial_command_map: &mut HashMap, + command: &mut PartiallyReadCommand, stream: &mut UnixStream, ) -> Result { - let command = partial_command_map.entry(stream.as_raw_fd()).or_default(); + // the minimum connect statement that is still valid + let connect_min_statement: &str = "connect 0\n"; - // This is the minimum number of bytes that we should be able to read, when parsing a - // valid connection request. I.e. `b"connect 0\n".len()`. - const MIN_COMMAND_LEN: usize = 10; + // read the amount of bytes that are required for a valid connect + // with the minimum length (`connect_min_statement`). + // Then, continue with reading a single byte at a time, this is + // really inefficient but prevents us to read past the `\n` character + // which might swallow actual application data + // alternative: the bytes that might have been read beyond `\n` would need + // to be sent somehow via `MuxerConnection` prior to reading from `stream` again + // Another, currently unstable alternative: use UnixStream::peak to read the + // data without removing it from the queue. + // Issue: https://github.com/rust-lang/rust/issues/76923 + let read_bytes = stream + .read(&mut command.buf[command.len..max(connect_min_statement.len(), command.len + 1)]) + .map_err(Error::UnixRead)?; - // Bring in the minimum number of bytes that we should be able to read. - if command.len < MIN_COMMAND_LEN { - command.len += stream - .read(&mut command.buf[command.len..MIN_COMMAND_LEN]) - .map_err(Error::UnixRead)?; + if read_bytes == 0 { + return Err(Error::InvalidPortRequest); } - // Now, finish reading the destination port number, by bringing in one byte at a time, - // until we reach an EOL terminator (or our buffer space runs out). Yeah, not - // particularly proud of this approach, but it will have to do for now. - while command.len.checked_sub(1).map(|n| command.buf[n]) != Some(b'\n') - && command.len < command.buf.len() - { - command.len += stream - .read(&mut command.buf[command.len..=command.len]) - .map_err(Error::UnixRead)?; - } - - let command = partial_command_map.remove(&stream.as_raw_fd()).unwrap(); - - let mut word_iter = std::str::from_utf8(&command.buf[..command.len]) - .map_err(Error::ConvertFromUtf8)? - .split_whitespace(); - - word_iter - .next() - .ok_or(Error::InvalidPortRequest) - .and_then(|word| { - if word.to_lowercase() == "connect" { - Ok(()) - } else { - Err(Error::InvalidPortRequest) - } - }) - .and_then(|_| word_iter.next().ok_or(Error::InvalidPortRequest)) - .and_then(|word| word.parse::().map_err(Error::ParseInteger)) - .map_err(|e| Error::ReadStreamPort(Box::new(e))) + command.len += read_bytes; + Self::parse_port_from_read_command(command) } /// Add a new connection to the active connection pool. @@ -850,6 +880,7 @@ impl VsockMuxer { #[cfg(test)] mod tests { + use std::cmp::min; use std::io::Write; use std::path::{Path, PathBuf}; @@ -859,6 +890,18 @@ mod tests { use super::super::super::tests::TestContext as VsockTestContext; use super::*; + impl PartiallyReadCommand { + /// used to construct `PartiallyReadCommand` for tests + fn from_str(s: &str) -> Self { + let input_bytes = s.as_bytes(); + let mut command = PartiallyReadCommand::default(); + let len_to_copy = min(input_bytes.len(), PARTIALLY_READ_COMMAND_BUF_SIZE); + command.buf[..len_to_copy].copy_from_slice(&input_bytes[..len_to_copy]); + command.len = len_to_copy; + command + } + } + const PEER_CID: u32 = 3; const PEER_BUF_ALLOC: u32 = 64 * 1024; @@ -971,7 +1014,10 @@ mod tests { stream.write_all(buf.as_bytes()).unwrap(); // The muxer would now get notified that data is available for reading from the locally // initiated connection. - self.notify_muxer(); + // this needs to happen multiple times because the command may not be read at once + for _ in 0..buf.len() { + self.notify_muxer(); + } // Successfully reading and parsing the connection request should have removed the // LocalStream epoll listener and added a Connection epoll listener. @@ -1454,4 +1500,123 @@ mod tests { // not be any pending RX in the muxer. assert!(!ctx.muxer.has_pending_rx()); } + + #[test] + fn test_parse_command() { + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("")), + Err(Error::UnixRead(_)) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("\n")), + Err(Error::InvalidPortRequest) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("CONN\n")), + Err(Error::InvalidPortRequest) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("FOO ")), + Err(Error::UnixRead(_)) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("FOOFOOX ")), + Err(Error::InvalidPortRequest) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("CONNECT ")), + Err(Error::UnixRead(_)) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("connect ")), + Err(Error::UnixRead(_)) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("connect \n")), + Err(Error::InvalidPortRequest) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + "connect 1337" + )), + Err(Error::UnixRead(_)) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + "connect -1337\n" + )), + Err(Error::InvalidPortRequest) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + "connect 8589934592\n" + )), + Err(Error::InvalidPortRequest) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + "CONNECT 👾\n" + )), + Err(Error::InvalidPortRequest) + )); + let max_buf_length_no_newline = "CONNECT 1"; + assert_eq!( + max_buf_length_no_newline.len(), + PARTIALLY_READ_COMMAND_BUF_SIZE + ); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + max_buf_length_no_newline + )), + Err(Error::InvalidPortRequest) + )); + let max_buf_length_correct = "CONNECT 1\n"; + assert_eq!( + max_buf_length_correct.len(), + PARTIALLY_READ_COMMAND_BUF_SIZE + ); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + max_buf_length_correct + )), + Ok(1) + )); + + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + "connect 0\n" + )), + Ok(0) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + "connect 1337\n" + )), + Ok(1337) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + "CONNECT 1337\n" + )), + Ok(1337) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + "CONNECT 1337\n" + )), + Ok(1337) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + "CONNECT 1337 \n" + )), + Ok(1337) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + "CONNECT 1337 \n" + )), + Ok(1337) + )); + } }