From d28d9eb34e071d30bddf4b9bba6b8772a70df87d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20G=C3=BCntner?= Date: Sun, 14 Sep 2025 00:29:14 +0200 Subject: [PATCH] virtio-devices: refactor VSOCK "connect" parsing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The function `read_local_stream_port` had no proper handling for unexpected or incomplete input. When the control socket of the VSOCK device was closed without sending the expected `CONNECT <PORT>\n` statement completely, the thread got stuck in an infinite loop as it attempted to read from a closed socket over and over again which never returned any data. This resulted in the thread responsible for `epoll` being completely blocked. New VSOCK connections could not be established and existing ones became defunct, effectively leading to a Denial of Service of the entire VSOCK device. The issue can be reproduced by opening a socket and immediately closing it. ``` socat - UNIX-CONNECT:/socket.vsock ``` Instead of applying a quick fix by handling the `EPOLLHUP` event before reading, the function is refactored to remove the error-prone `while` loop and multiple `read`s. Notably, we now check if the number of bytes read is zero, which occurs when `event_set == EPOLLHUP | EPOLLIN`, indicating that the socket has been closed by the client. Additionally, the actual parsing code is now extracted into a dedicated function that is tested. Fixes: #6798 Signed-off-by: Maximilian Güntner --- virtio-devices/src/vsock/unix/muxer.rs | 251 ++++++++++++++++++++----- 1 file changed, 208 insertions(+), 43 deletions(-) diff --git a/virtio-devices/src/vsock/unix/muxer.rs b/virtio-devices/src/vsock/unix/muxer.rs index 842e02677..5ea8efb9b 100644 --- a/virtio-devices/src/vsock/unix/muxer.rs +++ b/virtio-devices/src/vsock/unix/muxer.rs @@ -38,6 +38,7 @@ //! To route all these events to their handlers, the muxer uses another `HashMap` object, //! mapping `RawFd`s to `EpollListener`s. 
+use std::cmp::max; use std::collections::{HashMap, HashSet}; use std::fs::File; use std::io::{self, ErrorKind, Read}; @@ -90,11 +91,13 @@ enum EpollListener { LocalStream(UnixStream), } +const PARTIALLY_READ_COMMAND_BUF_SIZE: usize = 32; + /// A partially read "CONNECT" command. #[derive(Default)] struct PartiallyReadCommand { /// The bytes of the command that have been read so far. - buf: [u8; 32], + buf: [u8; PARTIALLY_READ_COMMAND_BUF_SIZE], /// How much of `buf` has been used. len: usize, } @@ -435,7 +438,11 @@ impl VsockMuxer { // "connect" command that we're expecting. Some(EpollListener::LocalStream(_)) => { if let Some(EpollListener::LocalStream(stream)) = self.listener_map.get_mut(&fd) { - let port = Self::read_local_stream_port(&mut self.partial_command_map, stream); + let command = self + .partial_command_map + .entry(stream.as_raw_fd()) + .or_default(); + let port = Self::read_local_stream_port(command, stream); if let Err(Error::UnixRead(ref e)) = port && e.kind() == ErrorKind::WouldBlock @@ -443,6 +450,11 @@ impl VsockMuxer { return; } + // either we have `Ok(port)` or a fatal Error such as + // Error::InvalidPortRequest, either way we must remove + // the command from the map + self.partial_command_map.remove(&stream.as_raw_fd()); + let stream = match self.remove_listener(fd) { Some(EpollListener::LocalStream(s)) => s, _ => unreachable!(), @@ -480,55 +492,73 @@ impl VsockMuxer { } } + fn parse_port_from_read_command(command: &PartiallyReadCommand) -> Result { + // normally followed by the port and a `\n` + let connect_prefix: &str = "connect "; + + let opt_new_line_position = command.buf[..command.len].iter().position(|x| *x == b'\n'); + + // we need to read more to get a `connect ` statement + if command.len < connect_prefix.len() { + return match opt_new_line_position { + Some(_) => Err(Error::InvalidPortRequest), + None => Err(Error::UnixRead(std::io::ErrorKind::WouldBlock.into())), + }; + } + + // check for both upper and lower case connect 
statements + if !command.buf[..connect_prefix.len()].eq_ignore_ascii_case(connect_prefix.as_bytes()) { + return Err(Error::InvalidPortRequest); + } + + // we filled our buffer + if command.buf.len() == command.len && opt_new_line_position.is_none() { + return Err(Error::InvalidPortRequest); + } + + // we parsed correctly `connect ` but need to wait for `\n` + let new_line_position = + opt_new_line_position.ok_or(Error::UnixRead(std::io::ErrorKind::WouldBlock.into()))?; + + // we now have the newline, we will treat everything in between as the port + let port_string_as_bytes = &command.buf[connect_prefix.len()..new_line_position]; + + std::str::from_utf8(port_string_as_bytes) + .map_err(|_| Error::InvalidPortRequest)? + .trim() + .parse::() + .map_err(|_| Error::InvalidPortRequest) + } + /// Parse a host "connect" command, and extract the destination vsock port. /// fn read_local_stream_port( - partial_command_map: &mut HashMap, + command: &mut PartiallyReadCommand, stream: &mut UnixStream, ) -> Result { - let command = partial_command_map.entry(stream.as_raw_fd()).or_default(); + // the minimum connect statement that is still valid + let connect_min_statement: &str = "connect 0\n"; - // This is the minimum number of bytes that we should be able to read, when parsing a - // valid connection request. I.e. `b"connect 0\n".len()`. - const MIN_COMMAND_LEN: usize = 10; + // read the amount of bytes that are required for a valid connect + // with the minimum length (`connect_min_statement`). + // Then, continue with reading a single byte at a time, this is + // really inefficient but prevents us from reading past the `\n` character + // which might swallow actual application data + // alternative: the bytes that might have been read beyond `\n` would need + // to be sent somehow via `MuxerConnection` prior to reading from `stream` again + // Another, currently unstable alternative: use UnixStream::peek to read the 
+ // Issue: https://github.com/rust-lang/rust/issues/76923 + let read_bytes = stream + .read(&mut command.buf[command.len..max(connect_min_statement.len(), command.len + 1)]) + .map_err(Error::UnixRead)?; - // Bring in the minimum number of bytes that we should be able to read. - if command.len < MIN_COMMAND_LEN { - command.len += stream - .read(&mut command.buf[command.len..MIN_COMMAND_LEN]) - .map_err(Error::UnixRead)?; + if read_bytes == 0 { + return Err(Error::InvalidPortRequest); } - // Now, finish reading the destination port number, by bringing in one byte at a time, - // until we reach an EOL terminator (or our buffer space runs out). Yeah, not - // particularly proud of this approach, but it will have to do for now. - while command.len.checked_sub(1).map(|n| command.buf[n]) != Some(b'\n') - && command.len < command.buf.len() - { - command.len += stream - .read(&mut command.buf[command.len..=command.len]) - .map_err(Error::UnixRead)?; - } - - let command = partial_command_map.remove(&stream.as_raw_fd()).unwrap(); - - let mut word_iter = std::str::from_utf8(&command.buf[..command.len]) - .map_err(Error::ConvertFromUtf8)? - .split_whitespace(); - - word_iter - .next() - .ok_or(Error::InvalidPortRequest) - .and_then(|word| { - if word.to_lowercase() == "connect" { - Ok(()) - } else { - Err(Error::InvalidPortRequest) - } - }) - .and_then(|_| word_iter.next().ok_or(Error::InvalidPortRequest)) - .and_then(|word| word.parse::().map_err(Error::ParseInteger)) - .map_err(|e| Error::ReadStreamPort(Box::new(e))) + command.len += read_bytes; + Self::parse_port_from_read_command(command) } /// Add a new connection to the active connection pool. 
@@ -850,6 +880,7 @@ impl VsockMuxer { #[cfg(test)] mod tests { + use std::cmp::min; use std::io::Write; use std::path::{Path, PathBuf}; @@ -859,6 +890,18 @@ mod tests { use super::super::super::tests::TestContext as VsockTestContext; use super::*; + impl PartiallyReadCommand { + /// used to construct `PartiallyReadCommand` for tests + fn from_str(s: &str) -> Self { + let input_bytes = s.as_bytes(); + let mut command = PartiallyReadCommand::default(); + let len_to_copy = min(input_bytes.len(), PARTIALLY_READ_COMMAND_BUF_SIZE); + command.buf[..len_to_copy].copy_from_slice(&input_bytes[..len_to_copy]); + command.len = len_to_copy; + command + } + } + const PEER_CID: u32 = 3; const PEER_BUF_ALLOC: u32 = 64 * 1024; @@ -971,7 +1014,10 @@ mod tests { stream.write_all(buf.as_bytes()).unwrap(); // The muxer would now get notified that data is available for reading from the locally // initiated connection. - self.notify_muxer(); + // this needs to happen multiple times because the command may not be read at once + for _ in 0..buf.len() { + self.notify_muxer(); + } // Successfully reading and parsing the connection request should have removed the // LocalStream epoll listener and added a Connection epoll listener. @@ -1454,4 +1500,123 @@ mod tests { // not be any pending RX in the muxer. 
assert!(!ctx.muxer.has_pending_rx()); } + + #[test] + fn test_parse_command() { + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("")), + Err(Error::UnixRead(_)) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("\n")), + Err(Error::InvalidPortRequest) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("CONN\n")), + Err(Error::InvalidPortRequest) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("FOO ")), + Err(Error::UnixRead(_)) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("FOOFOOX ")), + Err(Error::InvalidPortRequest) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("CONNECT ")), + Err(Error::UnixRead(_)) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("connect ")), + Err(Error::UnixRead(_)) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("connect \n")), + Err(Error::InvalidPortRequest) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + "connect 1337" + )), + Err(Error::UnixRead(_)) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + "connect -1337\n" + )), + Err(Error::InvalidPortRequest) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + "connect 8589934592\n" + )), + Err(Error::InvalidPortRequest) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + "CONNECT 👾\n" + )), + Err(Error::InvalidPortRequest) + )); + let max_buf_length_no_newline = "CONNECT 1"; + assert_eq!( + max_buf_length_no_newline.len(), + PARTIALLY_READ_COMMAND_BUF_SIZE + ); + 
assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + max_buf_length_no_newline + )), + Err(Error::InvalidPortRequest) + )); + let max_buf_length_correct = "CONNECT 1\n"; + assert_eq!( + max_buf_length_correct.len(), + PARTIALLY_READ_COMMAND_BUF_SIZE + ); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + max_buf_length_correct + )), + Ok(1) + )); + + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + "connect 0\n" + )), + Ok(0) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + "connect 1337\n" + )), + Ok(1337) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + "CONNECT 1337\n" + )), + Ok(1337) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + "CONNECT 1337\n" + )), + Ok(1337) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + "CONNECT 1337 \n" + )), + Ok(1337) + )); + assert!(matches!( + VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str( + "CONNECT 1337 \n" + )), + Ok(1337) + )); + } }