Add support for modifying devicelist during execution

This commit is contained in:
h7x4 2023-08-23 15:51:48 +02:00
parent 8506ede125
commit 6ed02934e3
No known key found for this signature in database
GPG key ID: 9F2F7D8250F35146
6 changed files with 353 additions and 92 deletions

View file

@ -10,7 +10,7 @@ description = "A library to run USB/IP server"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
tokio = { version = "1.22.0", features = ["rt", "net", "io-util"] }
tokio = { version = "1.22.0", features = ["rt", "net", "io-util", "sync"] }
log = "0.4.17"
num-traits = "0.2.15"
num-derive = "0.3.3"

View file

@ -11,14 +11,16 @@ async fn main() {
let handler =
Arc::new(Mutex::new(Box::new(usbip::cdc::UsbCdcAcmHandler::new())
as Box<dyn usbip::UsbInterfaceHandler + Send>));
let server = usbip::UsbIpServer::new_simulated(vec![usbip::UsbDevice::new(0).with_interface(
usbip::ClassCode::CDC as u8,
usbip::cdc::CDC_ACM_SUBCLASS,
0x00,
"Test CDC ACM",
usbip::cdc::UsbCdcAcmHandler::endpoints(),
handler.clone(),
)]);
let server = Arc::new(usbip::UsbIpServer::new_simulated(vec![
usbip::UsbDevice::new(0).with_interface(
usbip::ClassCode::CDC as u8,
usbip::cdc::CDC_ACM_SUBCLASS,
0x00,
"Test CDC ACM",
usbip::cdc::UsbCdcAcmHandler::endpoints(),
handler.clone(),
),
]));
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 3240);
tokio::spawn(usbip::server(addr, server));

View file

@ -12,19 +12,21 @@ async fn main() {
Box::new(usbip::hid::UsbHidKeyboardHandler::new_keyboard())
as Box<dyn usbip::UsbInterfaceHandler + Send>,
));
let server = usbip::UsbIpServer::new_simulated(vec![usbip::UsbDevice::new(0).with_interface(
usbip::ClassCode::HID as u8,
0x00,
0x00,
"Test HID",
vec![usbip::UsbEndpoint {
address: 0x81, // IN
attributes: 0x03, // Interrupt
max_packet_size: 0x08, // 8 bytes
interval: 10,
}],
handler.clone(),
)]);
let server = Arc::new(usbip::UsbIpServer::new_simulated(vec![
usbip::UsbDevice::new(0).with_interface(
usbip::ClassCode::HID as u8,
0x00,
0x00,
"Test HID",
vec![usbip::UsbEndpoint {
address: 0x81, // IN
attributes: 0x03, // Interrupt
max_packet_size: 0x08, // 8 bytes
interval: 10,
}],
handler.clone(),
),
]));
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 3240);
tokio::spawn(usbip::server(addr, server));

View file

@ -1,12 +1,13 @@
use env_logger;
use std::net::*;
use std::sync::Arc;
use std::time::Duration;
use usbip;
#[tokio::main]
async fn main() {
env_logger::init();
let server = usbip::UsbIpServer::new_from_host();
let server = Arc::new(usbip::UsbIpServer::new_from_host());
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 3240);
tokio::spawn(usbip::server(addr, server));

View file

@ -12,6 +12,7 @@ use std::sync::{Arc, Mutex};
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::net::TcpListener;
use tokio::sync::RwLock;
pub mod cdc;
mod consts;
@ -31,14 +32,19 @@ pub use setup::*;
pub use util::*;
/// Main struct of a USB/IP server
#[derive(Default)]
pub struct UsbIpServer {
devices: Vec<UsbDevice>,
available_devices: RwLock<Vec<UsbDevice>>,
used_devices: RwLock<HashMap<String, UsbDevice>>,
}
impl UsbIpServer {
/// Create a [UsbIpServer] with simulated devices
pub fn new_simulated(devices: Vec<UsbDevice>) -> Self {
Self { devices }
Self {
available_devices: RwLock::new(devices),
used_devices: RwLock::new(HashMap::new()),
}
}
fn with_devices(device_list: Vec<Device<GlobalContext>>) -> Vec<UsbDevice> {
@ -48,7 +54,7 @@ impl UsbIpServer {
let open_device = match dev.open() {
Ok(dev) => dev,
Err(err) => {
println!("Impossible to share {:?}: {}", dev, err);
warn!("Impossible to share {:?}: {}", dev, err);
continue;
}
};
@ -180,10 +186,11 @@ impl UsbIpServer {
devs.push(d)
}
Self {
devices: Self::with_devices(devs),
available_devices: RwLock::new(Self::with_devices(devs)),
..Default::default()
}
}
Err(_) => Self { devices: vec![] },
Err(_) => Default::default(),
}
}
@ -198,22 +205,61 @@ impl UsbIpServer {
devs.push(d)
}
Self {
devices: Self::with_devices(devs),
available_devices: RwLock::new(Self::with_devices(devs)),
..Default::default()
}
}
Err(_) => Self { devices: vec![] },
Err(_) => Default::default(),
}
}
pub async fn add_device(&self, device: UsbDevice) {
self.available_devices.write().await.push(device);
}
pub async fn remove_device(&self, bus_id: &str) -> Result<()> {
let mut available_devices = self.available_devices.write().await;
if let Some(device) = available_devices.iter().position(|d| d.bus_id == bus_id) {
available_devices.remove(device);
Ok(())
} else if let Some(device) = self
.used_devices
.read()
.await
.values()
.find(|d| d.bus_id == bus_id)
{
Err(std::io::Error::new(
ErrorKind::Other,
format!("Device {} is in use", device.bus_id),
))
} else {
Err(std::io::Error::new(
ErrorKind::NotFound,
format!("Device {} not found", bus_id),
))
}
}
}
async fn handler<T: AsyncReadExt + AsyncWriteExt + Unpin>(
pub async fn handler<T: AsyncReadExt + AsyncWriteExt + Unpin>(
mut socket: &mut T,
server: Arc<UsbIpServer>,
) -> Result<()> {
let mut current_import_device = None;
let mut current_import_device_id: Option<String> = None;
loop {
let mut command = [0u8; 4];
if let Err(err) = socket.read_exact(&mut command).await {
if let Some(dev_id) = current_import_device_id {
let mut used_devices = server.used_devices.write().await;
let mut available_devices = server.available_devices.write().await;
match used_devices.remove(&dev_id) {
Some(dev) => available_devices.push(dev),
None => unreachable!(),
}
}
if err.kind() == ErrorKind::UnexpectedEof {
info!("Remote closed the connection");
return Ok(());
@ -221,16 +267,24 @@ async fn handler<T: AsyncReadExt + AsyncWriteExt + Unpin>(
return Err(err);
}
}
let used_devices = server.used_devices.read().await;
let mut current_import_device = current_import_device_id
.clone()
.and_then(|ref id| used_devices.get(id));
match command {
[0x01, 0x11, 0x80, 0x05] => {
trace!("Got OP_REQ_DEVLIST");
let _status = socket.read_u32().await?;
let devices = server.available_devices.read().await;
// OP_REP_DEVLIST
socket.write_u32(0x01110005).await?;
socket.write_u32(0).await?;
socket.write_u32(server.devices.len() as u32).await?;
for dev in &server.devices {
socket.write_u32(devices.len() as u32).await?;
for dev in devices.iter() {
dev.write_dev_with_interfaces(&mut socket).await?;
}
trace!("Sent OP_REP_DEVLIST");
@ -240,13 +294,22 @@ async fn handler<T: AsyncReadExt + AsyncWriteExt + Unpin>(
let _status = socket.read_u32().await?;
let mut bus_id = [0u8; 32];
socket.read_exact(&mut bus_id).await?;
current_import_device_id = None;
current_import_device = None;
for device in &server.devices {
let mut expected = device.bus_id.as_bytes().to_vec();
std::mem::drop(used_devices);
let mut used_devices = server.used_devices.write().await;
let mut available_devices = server.available_devices.write().await;
for (i, dev) in available_devices.iter().enumerate() {
let mut expected = dev.bus_id.as_bytes().to_vec();
expected.resize(32, 0);
if expected == bus_id {
current_import_device = Some(device);
info!("Found device {:?}", device.path);
let dev = available_devices.remove(i);
let dev_id = dev.bus_id.clone();
used_devices.insert(dev.bus_id.clone(), dev);
current_import_device_id = dev_id.clone().into();
current_import_device = Some(used_devices.get(&dev_id).unwrap());
break;
}
}
@ -347,6 +410,22 @@ async fn handler<T: AsyncReadExt + AsyncWriteExt + Unpin>(
let mut padding = [0u8; 6 * 4];
socket.read_exact(&mut padding).await?;
std::mem::drop(used_devices);
let mut used_devices = server.used_devices.write().await;
let mut available_devices = server.available_devices.write().await;
let dev = match current_import_device_id
.clone()
.and_then(|ref k| used_devices.remove(k))
{
Some(dev) => dev,
None => unreachable!(),
};
available_devices.push(dev);
current_import_device_id = None;
// USBIP_RET_UNLINK
// command
socket.write_u32(0x4).await?;
@ -356,7 +435,7 @@ async fn handler<T: AsyncReadExt + AsyncWriteExt + Unpin>(
socket.write_u32(0).await?;
// status
socket.write_u32(0).await?;
socket.write_all(&mut padding).await?;
socket.write_all(&padding).await?;
}
_ => warn!("Got unknown command {:?}", command),
}
@ -364,16 +443,15 @@ async fn handler<T: AsyncReadExt + AsyncWriteExt + Unpin>(
}
/// Spawn a USB/IP server at `addr` using [TcpListener]
pub async fn server(addr: SocketAddr, server: UsbIpServer) {
pub async fn server(addr: SocketAddr, server: Arc<UsbIpServer>) {
let listener = TcpListener::bind(addr).await.expect("bind to addr");
let server = async move {
let usbip_server = Arc::new(server);
loop {
match listener.accept().await {
Ok((mut socket, _addr)) => {
info!("Got connection from {:?}", socket.peer_addr());
let new_server = usbip_server.clone();
let new_server = server.clone();
tokio::spawn(async move {
let res = handler(&mut socket, new_server).await;
info!("Handler ended with {:?}", res);
@ -391,12 +469,46 @@ pub async fn server(addr: SocketAddr, server: UsbIpServer) {
#[cfg(test)]
mod test {
use tokio::{net::TcpStream, task::JoinSet};
use super::*;
use crate::util::tests::*;
fn new_server_with_single_device() -> UsbIpServer {
UsbIpServer::new_simulated(vec![UsbDevice::new(0).with_interface(
ClassCode::CDC as u8,
cdc::CDC_ACM_SUBCLASS,
0x00,
"Test CDC ACM",
cdc::UsbCdcAcmHandler::endpoints(),
Arc::new(Mutex::new(
Box::new(cdc::UsbCdcAcmHandler::new()) as Box<dyn UsbInterfaceHandler + Send>
)),
)])
}
fn op_req_import(bus_id: u32) -> Vec<u8> {
let mut req = vec![0x01, 0x11, 0x80, 0x03, 0x00, 0x00, 0x00, 0x00];
let mut path = bus_id.to_string().as_bytes().to_vec();
path.resize(32, 0);
req.extend(path);
req
}
async fn attach_device(connection: &mut TcpStream, bus_id: u32) -> u32 {
let req = op_req_import(bus_id);
connection.write_all(req.as_slice()).await.unwrap();
connection.read_u32().await.unwrap();
let result = connection.read_u32().await.unwrap();
if result == 0 {
connection.read_exact(&mut vec![0; 0x138]).await.unwrap();
}
return result;
}
#[tokio::test]
async fn req_empty_devlist() {
let server = UsbIpServer { devices: vec![] };
let server = UsbIpServer::new_simulated(vec![]);
// OP_REQ_DEVLIST
let mut mock_socket = MockSocket::new(vec![0x01, 0x11, 0x80, 0x05, 0x00, 0x00, 0x00, 0x00]);
@ -410,20 +522,7 @@ mod test {
#[tokio::test]
async fn req_sample_devlist() {
let intf_handler = Arc::new(Mutex::new(
Box::new(cdc::UsbCdcAcmHandler::new()) as Box<dyn UsbInterfaceHandler + Send>
));
let server = UsbIpServer {
devices: vec![UsbDevice::new(0).with_interface(
ClassCode::CDC as u8,
cdc::CDC_ACM_SUBCLASS,
0x00,
"Test CDC ACM",
cdc::UsbCdcAcmHandler::endpoints(),
intf_handler.clone(),
)],
};
let server = new_server_with_single_device();
// OP_REQ_DEVLIST
let mut mock_socket = MockSocket::new(vec![0x01, 0x11, 0x80, 0x05, 0x00, 0x00, 0x00, 0x00]);
handler(&mut mock_socket, Arc::new(server)).await.ok();
@ -436,52 +535,192 @@ mod test {
#[tokio::test]
async fn req_import() {
let intf_handler = Arc::new(Mutex::new(
Box::new(cdc::UsbCdcAcmHandler::new()) as Box<dyn UsbInterfaceHandler + Send>
));
let server = UsbIpServer {
devices: vec![UsbDevice::new(0).with_interface(
ClassCode::CDC as u8,
cdc::CDC_ACM_SUBCLASS,
0x00,
"Test CDC ACM",
cdc::UsbCdcAcmHandler::endpoints(),
intf_handler.clone(),
)],
};
let server = new_server_with_single_device();
// OP_REQ_IMPORT
let mut req = vec![0x01, 0x11, 0x80, 0x03, 0x00, 0x00, 0x00, 0x00];
let mut path = "0".as_bytes().to_vec();
path.resize(32, 0);
req.extend(path);
let req = op_req_import(0);
let mut mock_socket = MockSocket::new(req);
handler(&mut mock_socket, Arc::new(server)).await.ok();
// OP_REQ_IMPORT
assert_eq!(mock_socket.output.len(), 0x140);
}
#[tokio::test]
async fn add_and_remove_10_devices() {
let server_ = Arc::new(UsbIpServer::new_simulated(vec![]));
let addr = get_free_address().await;
tokio::spawn(server(addr, server_.clone()));
let mut join_set = JoinSet::new();
let devices = (0..10).map(UsbDevice::new).collect::<Vec<_>>();
for device in devices.iter() {
let new_server = server_.clone();
let new_device = device.clone();
join_set.spawn(async move {
new_server.add_device(new_device).await;
});
}
for device in devices.iter() {
let new_server = server_.clone();
let new_device = device.clone();
join_set.spawn(async move {
new_server.remove_device(&new_device.bus_id).await.unwrap();
});
}
while join_set.join_next().await.is_some() {}
let device_len = server_.clone().available_devices.read().await.len();
assert_eq!(device_len, 0);
}
#[tokio::test]
async fn send_usb_traffic_while_adding_and_removing_devices() {
let server_ = Arc::new(new_server_with_single_device());
let addr = get_free_address().await;
tokio::spawn(server(addr, server_.clone()));
let cmd_loop_handle = tokio::spawn(async move {
let mut connection = poll_connect(addr).await;
let result = attach_device(&mut connection, 0).await;
assert_eq!(result, 0);
let cdc_loopback_bulk_cmd = vec![
0x00, 0x00, 0x00, 0x01, // command
0x00, 0x00, 0x00, 0x01, // seq num
0x00, 0x00, 0x00, 0x00, // dev id
0x00, 0x00, 0x00, 0x00, // OUT
0x00, 0x00, 0x00, 0x02, // ep 2
0x00, 0x00, 0x00, 0x00, // transfer flags
0x00, 0x00, 0x00, 0x08, // transfer buffer length 8
0x00, 0x00, 0x00, 0x00, // start frame
0x00, 0x00, 0x00, 0x00, // number of packets
0x00, 0x00, 0x00, 0x00, // interval
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Empty setup packet
0x01, 0x02, 0x03, 0x04, // data
0x05, 0x06, 0x07, 0x08, // data
];
loop {
connection
.write_all(cdc_loopback_bulk_cmd.as_slice())
.await
.unwrap();
let mut result = vec![0; 4 * 12];
connection.read_exact(&mut result).await.unwrap();
}
});
let add_and_remove_device_handle = tokio::spawn(async move {
let mut join_set = JoinSet::new();
let devices = (1..4).map(UsbDevice::new).collect::<Vec<_>>();
loop {
for device in devices.iter() {
let new_server = server_.clone();
let new_device = device.clone();
join_set.spawn(async move {
new_server.add_device(new_device).await;
});
}
for device in devices.iter() {
let new_server = server_.clone();
let new_device = device.clone();
join_set.spawn(async move {
new_server.remove_device(&new_device.bus_id).await.unwrap();
});
}
while join_set.join_next().await.is_some() {}
tokio::time::sleep(tokio::time::Duration::from_millis(20)).await;
}
});
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
cmd_loop_handle.abort();
add_and_remove_device_handle.abort();
}
#[tokio::test]
async fn only_single_connection_allowed_to_device() {
let server_ = Arc::new(new_server_with_single_device());
let addr = get_free_address().await;
tokio::spawn(server(addr, server_.clone()));
let mut first_connection = poll_connect(addr).await;
let mut second_connection = TcpStream::connect(addr).await.unwrap();
let result = attach_device(&mut first_connection, 0).await;
assert_eq!(result, 0);
let result = attach_device(&mut second_connection, 0).await;
assert_eq!(result, 1);
}
#[tokio::test]
async fn device_gets_released_on_cmd_unlink() {
let server_ = Arc::new(new_server_with_single_device());
let addr = get_free_address().await;
tokio::spawn(server(addr, server_.clone()));
let mut connection = poll_connect(addr).await;
let result = attach_device(&mut connection, 0).await;
assert_eq!(result, 0);
let unlink_req = vec![
0x00, 0x00, 0x00, 0x02, // cmd
0x00, 0x00, 0x00, 0x01, // seq_num
0x00, 0x00, 0x00, 0x00, // dev_id
0x00, 0x00, 0x00, 0x00, // direction
0x00, 0x00, 0x00, 0x00, // ep
0x00, 0x00, 0x00, 0x00, // seq_num_submit
0x00, 0x00, 0x00, 0x00, // padding
0x00, 0x00, 0x00, 0x00, // padding
0x00, 0x00, 0x00, 0x00, // padding
0x00, 0x00, 0x00, 0x00, // padding
0x00, 0x00, 0x00, 0x00, // padding
0x00, 0x00, 0x00, 0x00, // padding
];
connection.write_all(unlink_req.as_slice()).await.unwrap();
connection.read_exact(&mut vec![0; 4 * 5]).await.unwrap();
let result = connection.read_u32().await.unwrap();
connection.read_exact(&mut vec![0; 4 * 6]).await.unwrap();
assert_eq!(result, 0);
let result = attach_device(&mut connection, 0).await;
assert_eq!(result, 0);
}
#[tokio::test]
async fn device_gets_released_on_closed_socket() {
let server_ = Arc::new(new_server_with_single_device());
let addr = get_free_address().await;
tokio::spawn(server(addr, server_.clone()));
let mut connection = poll_connect(addr).await;
let result = attach_device(&mut connection, 0).await;
assert_eq!(result, 0);
std::mem::drop(connection);
let mut connection = TcpStream::connect(addr).await.unwrap();
let result = attach_device(&mut connection, 0).await;
assert_eq!(result, 0);
}
#[tokio::test]
async fn req_import_get_device_desc() {
let intf_handler = Arc::new(Mutex::new(
Box::new(cdc::UsbCdcAcmHandler::new()) as Box<dyn UsbInterfaceHandler + Send>
));
let server = UsbIpServer {
devices: vec![UsbDevice::new(0).with_interface(
ClassCode::CDC as u8,
cdc::CDC_ACM_SUBCLASS,
0x00,
"Test CDC ACM",
cdc::UsbCdcAcmHandler::endpoints(),
intf_handler.clone(),
)],
};
let server = new_server_with_single_device();
// OP_REQ_IMPORT
let mut req = vec![0x01, 0x11, 0x80, 0x03, 0x00, 0x00, 0x00, 0x00];
let mut path = "0".as_bytes().to_vec();
path.resize(32, 0);
req.extend(path);
let mut req = op_req_import(0);
// USBIP_CMD_SUBMIT
req.extend(vec![
0x00, 0x00, 0x00, 0x01, // command

View file

@ -24,10 +24,14 @@ pub fn verify_descriptor(desc: &[u8]) {
pub(crate) mod tests {
use std::{
io::*,
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
net::{TcpListener, TcpStream},
};
pub(crate) struct MockSocket {
pub input: Cursor<Vec<u8>>,
@ -73,4 +77,17 @@ pub(crate) mod tests {
Poll::Ready(Ok(()))
}
}
pub(crate) async fn get_free_address() -> SocketAddr {
let stream = TcpListener::bind("127.0.0.1:0").await.unwrap();
stream.local_addr().unwrap()
}
pub(crate) async fn poll_connect(addr: SocketAddr) -> TcpStream {
loop {
if let Ok(stream) = TcpStream::connect(addr).await {
return stream;
}
}
}
}