virtio-devices: refactor VSOCK "connect" parsing

The function `read_local_stream_port` had no proper handling for
unexpected or incomplete input.
When the control socket of the VSOCK device was closed without sending
the expected `CONNECT <PORT>\n` statement completely, the thread
got stuck in an infinite loop as it attempted to read from a closed
socket over and over again which never returned any data.

This resulted in the thread responsible for `epoll` being completely
blocked. New VSOCK connections could not be established and existing
ones became defunct, effectively leading to a Denial of Service of
the entire VSOCK device.

The issue can be reproduced by opening a socket and immediately
closing it.

```
socat - UNIX-CONNECT:/socket.vsock
<Ctrl-C>
```

Instead of applying a quick fix by handling the `EPOLLHUP` event before
reading, the function is refactored to remove the error-prone `while`
loop and multiple `read`s.
Notably, we now check if the number of bytes read is zero, which occurs
when `event_set == EPOLLHUP | EPOLLIN`, indicating that the socket has
been closed by the client.

Additionally, the actual parsing code is now extracted into a dedicated
function that is tested.

Fixes: #6798
Signed-off-by: Maximilian Güntner <code@mguentner.de>
This commit is contained in:
Maximilian Güntner 2025-09-14 00:29:14 +02:00 committed by Bo Chen
parent 1e8996f94f
commit d28d9eb34e

View file

@ -38,6 +38,7 @@
//! To route all these events to their handlers, the muxer uses another `HashMap` object,
//! mapping `RawFd`s to `EpollListener`s.
use std::cmp::max;
use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::io::{self, ErrorKind, Read};
@ -90,11 +91,13 @@ enum EpollListener {
LocalStream(UnixStream),
}
const PARTIALLY_READ_COMMAND_BUF_SIZE: usize = 32;
/// A partially read "CONNECT" command.
#[derive(Default)]
struct PartiallyReadCommand {
/// The bytes of the command that have been read so far.
buf: [u8; 32],
buf: [u8; PARTIALLY_READ_COMMAND_BUF_SIZE],
/// How much of `buf` has been used.
len: usize,
}
@ -435,7 +438,11 @@ impl VsockMuxer {
// "connect" command that we're expecting.
Some(EpollListener::LocalStream(_)) => {
if let Some(EpollListener::LocalStream(stream)) = self.listener_map.get_mut(&fd) {
let port = Self::read_local_stream_port(&mut self.partial_command_map, stream);
let command = self
.partial_command_map
.entry(stream.as_raw_fd())
.or_default();
let port = Self::read_local_stream_port(command, stream);
if let Err(Error::UnixRead(ref e)) = port
&& e.kind() == ErrorKind::WouldBlock
@ -443,6 +450,11 @@ impl VsockMuxer {
return;
}
// either we have `Ok(port)` or a fatal Error such as
// Error::InvalidPortRequest, either way we must remove
// the command from the map
self.partial_command_map.remove(&stream.as_raw_fd());
let stream = match self.remove_listener(fd) {
Some(EpollListener::LocalStream(s)) => s,
_ => unreachable!(),
@ -480,55 +492,73 @@ impl VsockMuxer {
}
}
fn parse_port_from_read_command(command: &PartiallyReadCommand) -> Result<u32> {
// normally followed by the port and a `\n`
let connect_prefix: &str = "connect ";
let opt_new_line_position = command.buf[..command.len].iter().position(|x| *x == b'\n');
// we need to read more to get a `connect ` statement
if command.len < connect_prefix.len() {
return match opt_new_line_position {
Some(_) => Err(Error::InvalidPortRequest),
None => Err(Error::UnixRead(std::io::ErrorKind::WouldBlock.into())),
};
}
// check for both upper and lower case connect statements
if !command.buf[..connect_prefix.len()].eq_ignore_ascii_case(connect_prefix.as_bytes()) {
return Err(Error::InvalidPortRequest);
}
// we filled our buffer
if command.buf.len() == command.len && opt_new_line_position.is_none() {
return Err(Error::InvalidPortRequest);
}
// we parsed correctly `connect ` but need to wait for `\n`
let new_line_position =
opt_new_line_position.ok_or(Error::UnixRead(std::io::ErrorKind::WouldBlock.into()))?;
// we now have the newline, we will treat everything in between as the port
let port_string_as_bytes = &command.buf[connect_prefix.len()..new_line_position];
std::str::from_utf8(port_string_as_bytes)
.map_err(|_| Error::InvalidPortRequest)?
.trim()
.parse::<u32>()
.map_err(|_| Error::InvalidPortRequest)
}
/// Parse a host "connect" command, and extract the destination vsock port.
///
fn read_local_stream_port(
partial_command_map: &mut HashMap<RawFd, PartiallyReadCommand>,
command: &mut PartiallyReadCommand,
stream: &mut UnixStream,
) -> Result<u32> {
let command = partial_command_map.entry(stream.as_raw_fd()).or_default();
// the minimum connect statement that is still valid
let connect_min_statement: &str = "connect 0\n";
// This is the minimum number of bytes that we should be able to read, when parsing a
// valid connection request. I.e. `b"connect 0\n".len()`.
const MIN_COMMAND_LEN: usize = 10;
// read the amount of bytes that are required for a valid connect
// with the minimum length (`connect_min_statement`).
// Then, continue with reading a single byte at a time, this is
// really inefficient but prevents us to read past the `\n` character
// which might swallow actual application data
// alternative: the bytes that might have been read beyond `\n` would need
// to be sent somehow via `MuxerConnection` prior to reading from `stream` again
// Another, currently unstable alternative: use UnixStream::peak to read the
// data without removing it from the queue.
// Issue: https://github.com/rust-lang/rust/issues/76923
let read_bytes = stream
.read(&mut command.buf[command.len..max(connect_min_statement.len(), command.len + 1)])
.map_err(Error::UnixRead)?;
// Bring in the minimum number of bytes that we should be able to read.
if command.len < MIN_COMMAND_LEN {
command.len += stream
.read(&mut command.buf[command.len..MIN_COMMAND_LEN])
.map_err(Error::UnixRead)?;
if read_bytes == 0 {
return Err(Error::InvalidPortRequest);
}
// Now, finish reading the destination port number, by bringing in one byte at a time,
// until we reach an EOL terminator (or our buffer space runs out). Yeah, not
// particularly proud of this approach, but it will have to do for now.
while command.len.checked_sub(1).map(|n| command.buf[n]) != Some(b'\n')
&& command.len < command.buf.len()
{
command.len += stream
.read(&mut command.buf[command.len..=command.len])
.map_err(Error::UnixRead)?;
}
let command = partial_command_map.remove(&stream.as_raw_fd()).unwrap();
let mut word_iter = std::str::from_utf8(&command.buf[..command.len])
.map_err(Error::ConvertFromUtf8)?
.split_whitespace();
word_iter
.next()
.ok_or(Error::InvalidPortRequest)
.and_then(|word| {
if word.to_lowercase() == "connect" {
Ok(())
} else {
Err(Error::InvalidPortRequest)
}
})
.and_then(|_| word_iter.next().ok_or(Error::InvalidPortRequest))
.and_then(|word| word.parse::<u32>().map_err(Error::ParseInteger))
.map_err(|e| Error::ReadStreamPort(Box::new(e)))
command.len += read_bytes;
Self::parse_port_from_read_command(command)
}
/// Add a new connection to the active connection pool.
@ -850,6 +880,7 @@ impl VsockMuxer {
#[cfg(test)]
mod tests {
use std::cmp::min;
use std::io::Write;
use std::path::{Path, PathBuf};
@ -859,6 +890,18 @@ mod tests {
use super::super::super::tests::TestContext as VsockTestContext;
use super::*;
impl PartiallyReadCommand {
/// used to construct `PartiallyReadCommand` for tests
fn from_str(s: &str) -> Self {
let input_bytes = s.as_bytes();
let mut command = PartiallyReadCommand::default();
let len_to_copy = min(input_bytes.len(), PARTIALLY_READ_COMMAND_BUF_SIZE);
command.buf[..len_to_copy].copy_from_slice(&input_bytes[..len_to_copy]);
command.len = len_to_copy;
command
}
}
const PEER_CID: u32 = 3;
const PEER_BUF_ALLOC: u32 = 64 * 1024;
@ -971,7 +1014,10 @@ mod tests {
stream.write_all(buf.as_bytes()).unwrap();
// The muxer would now get notified that data is available for reading from the locally
// initiated connection.
self.notify_muxer();
// this needs to happen multiple times because the command may not be read at once
for _ in 0..buf.len() {
self.notify_muxer();
}
// Successfully reading and parsing the connection request should have removed the
// LocalStream epoll listener and added a Connection epoll listener.
@ -1454,4 +1500,123 @@ mod tests {
// not be any pending RX in the muxer.
assert!(!ctx.muxer.has_pending_rx());
}
#[test]
fn test_parse_command() {
assert!(matches!(
VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("")),
Err(Error::UnixRead(_))
));
assert!(matches!(
VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("\n")),
Err(Error::InvalidPortRequest)
));
assert!(matches!(
VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("CONN\n")),
Err(Error::InvalidPortRequest)
));
assert!(matches!(
VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("FOO ")),
Err(Error::UnixRead(_))
));
assert!(matches!(
VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("FOOFOOX ")),
Err(Error::InvalidPortRequest)
));
assert!(matches!(
VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("CONNECT ")),
Err(Error::UnixRead(_))
));
assert!(matches!(
VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("connect ")),
Err(Error::UnixRead(_))
));
assert!(matches!(
VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str("connect \n")),
Err(Error::InvalidPortRequest)
));
assert!(matches!(
VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str(
"connect 1337"
)),
Err(Error::UnixRead(_))
));
assert!(matches!(
VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str(
"connect -1337\n"
)),
Err(Error::InvalidPortRequest)
));
assert!(matches!(
VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str(
"connect 8589934592\n"
)),
Err(Error::InvalidPortRequest)
));
assert!(matches!(
VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str(
"CONNECT 👾\n"
)),
Err(Error::InvalidPortRequest)
));
let max_buf_length_no_newline = "CONNECT 1";
assert_eq!(
max_buf_length_no_newline.len(),
PARTIALLY_READ_COMMAND_BUF_SIZE
);
assert!(matches!(
VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str(
max_buf_length_no_newline
)),
Err(Error::InvalidPortRequest)
));
let max_buf_length_correct = "CONNECT 1\n";
assert_eq!(
max_buf_length_correct.len(),
PARTIALLY_READ_COMMAND_BUF_SIZE
);
assert!(matches!(
VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str(
max_buf_length_correct
)),
Ok(1)
));
assert!(matches!(
VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str(
"connect 0\n"
)),
Ok(0)
));
assert!(matches!(
VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str(
"connect 1337\n"
)),
Ok(1337)
));
assert!(matches!(
VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str(
"CONNECT 1337\n"
)),
Ok(1337)
));
assert!(matches!(
VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str(
"CONNECT 1337\n"
)),
Ok(1337)
));
assert!(matches!(
VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str(
"CONNECT 1337 \n"
)),
Ok(1337)
));
assert!(matches!(
VsockMuxer::parse_port_from_read_command(&PartiallyReadCommand::from_str(
"CONNECT 1337 \n"
)),
Ok(1337)
));
}
}