usbip-rs/lib/src/device.rs
Davíð Steinn Geirsson 18a413870a feat: add UAC1 loopback test device and fix endpoint attribute dispatch
Add a simulated USB Audio Class 1 loopback device for testing
isochronous transfers. Audio sent to the playback OUT endpoint
(48kHz/16-bit/stereo) is looped back to the capture IN endpoint.

- Add UsbEndpoint::transfer_type() masking bmAttributes to bits 0-1,
  fixing dispatch for isochronous endpoints with sync-type sub-bits
- Update all endpoint attribute dispatch sites across the library
- Add UacLoopbackBuffer, UacControlHandler, UacStreamOutHandler,
  UacStreamInHandler in lib/src/uac.rs
- Add build_uac_loopback_device() builder function
- Add `test_uac connect` CLI subcommand
- Add 10 unit tests covering buffer, descriptors, and handler behavior
- Add design spec and implementation plan docs

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 01:43:31 +00:00

920 lines
41 KiB
Rust

use super::*;
use rusb::Version as rusbVersion;
/// A three-part version number as used by the USB `bcdUSB` / `bcdDevice`
/// fields.
///
/// Values can be obtained from a `rusb::Version` or from a raw 16-bit BCD
/// word (see the `From` implementations in this file) and turned back into
/// the two BCD bytes with `to_bcd_be`.
#[derive(Clone, Default, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Version {
/// Major part; the high byte of the 16-bit BCD word.
pub major: u8,
/// Minor part; bits 4..=7 of the 16-bit BCD word.
pub minor: u8,
/// Patch (sub-minor) part; bits 0..=3 of the 16-bit BCD word.
pub patch: u8,
}
/// Copies the three components reported by `rusb` into a [`Version`].
impl From<rusbVersion> for Version {
    fn from(value: rusbVersion) -> Self {
        let (major, minor, patch) = (value.major(), value.minor(), value.sub_minor());
        Self {
            major,
            minor,
            patch,
        }
    }
}
impl From<Version> for rusbVersion {
fn from(val: Version) -> Self {
rusbVersion(val.major, val.minor, val.patch)
}
}
/// Splits a raw 16-bit BCD word (e.g. `bcdDevice`) into its parts: the high
/// byte becomes `major`, the two nibbles of the low byte become `minor` and
/// `patch`.
impl From<u16> for Version {
    fn from(value: u16) -> Self {
        let [high, low] = value.to_be_bytes();
        Self {
            major: high,
            minor: low >> 4,
            patch: low & 0x0F,
        }
    }
}
impl Version {
    /// Reconstruct the 2-byte BCD value in big-endian order
    /// (e.g. `[0x02, 0x00]` for USB 2.0, `[0x01, 0x10]` for USB 1.1).
    ///
    /// The first byte is `major`; the second byte packs `minor` into the high
    /// nibble and `patch` into the low nibble. Each of `minor` and `patch` is
    /// reduced to its low four bits first, so an out-of-range `patch` can no
    /// longer spill into the `minor` nibble. This mirrors the masking done
    /// when a `Version` is built from a `u16`.
    pub fn to_bcd_be(&self) -> [u8; 2] {
        let minor_nibble = self.minor & 0x0F;
        let patch_nibble = self.patch & 0x0F;
        [self.major, (minor_nibble << 4) | patch_nibble]
    }
}
/// Represent a USB device
///
/// Holds the identification data that is serialised into the USB/IP device
/// record (`to_bytes`), the state needed to answer standard control requests
/// (`handle_urb`), and the handlers that service everything else.
#[derive(Clone, Default, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize))]
pub struct UsbDevice {
/// Device path; serialised as a 256-byte zero-padded field.
pub path: String,
/// Bus identifier; serialised as a 32-byte zero-padded field.
pub bus_id: String,
/// Bus number, serialised big-endian.
pub bus_num: u32,
/// Device number, serialised big-endian; `new` sets it to its `index` argument.
pub dev_num: u32,
/// Device speed, serialised big-endian; `new` uses `UsbSpeed::High`.
pub speed: u32,
/// `idVendor` of the device descriptor.
pub vendor_id: u16,
/// `idProduct` of the device descriptor.
pub product_id: u16,
/// `bcdDevice` of the device descriptor.
pub device_bcd: Version,
/// `bDeviceClass` of the device descriptor.
pub device_class: u8,
/// `bDeviceSubClass` of the device descriptor.
pub device_subclass: u8,
/// `bDeviceProtocol` of the device descriptor.
pub device_protocol: u8,
/// `bConfigurationValue`; also the answer to GET_CONFIGURATION.
pub configuration_value: u8,
/// `bNumConfigurations` of the device descriptor.
pub num_configurations: u8,
/// One entry per interface; the position in this vector is used as
/// `bInterfaceNumber`. Not serialised by serde.
#[cfg_attr(feature = "serde", serde(skip))]
pub interface_states: Vec<InterfaceState>,
/// Optional handler that receives device-recipient control requests the
/// library does not answer itself. Not serialised by serde.
#[cfg_attr(feature = "serde", serde(skip))]
pub device_handler: Option<Arc<Mutex<Box<dyn UsbDeviceHandler + Send>>>>,
/// `bcdUSB` of the device and device-qualifier descriptors.
pub usb_version: Version,
/// IN half of control endpoint zero (`new` uses address `0x80`).
pub ep0_in: UsbEndpoint,
/// OUT half of control endpoint zero (`new` uses address `0x00`).
pub ep0_out: UsbEndpoint,
// String descriptors: `string_pool` maps a descriptor index to its text.
// Index 0 is never allocated (GET_DESCRIPTOR answers it with the language
// ID list), so a `string_*` field of 0 means "no string".
/// String descriptor index -> text.
pub string_pool: HashMap<u8, String>,
/// `iConfiguration` string index, or 0 when unset.
pub string_configuration: u8,
/// `iManufacturer` string index, or 0 when unset.
pub string_manufacturer: u8,
/// `iProduct` string index, or 0 when unset.
pub string_product: u8,
/// `iSerial` string index, or 0 when unset.
pub string_serial: u8,
}
impl UsbDevice {
/// Create a simulated high-speed device with device number `index`.
///
/// The device starts configured (configuration value 1 of 1), owns a
/// control endpoint zero in both directions and gets four placeholder
/// strings (configuration, manufacturer, product, serial) allocated in that
/// order.
///
/// # Errors
/// Propagates the error of `new_string` if a placeholder string cannot be
/// allocated.
pub fn new(index: u32) -> std::io::Result<Self> {
    // Both halves of endpoint zero are identical apart from the address.
    let control_endpoint = |address: u8| UsbEndpoint {
        address,
        attributes: EndpointAttributes::Control as u8,
        max_packet_size: EP0_MAX_PACKET_SIZE,
        interval: 0,
    };

    let mut device = Self {
        path: String::from("/sys/bus/0/0/0"),
        bus_id: String::from("0-0-0"),
        dev_num: index,
        speed: UsbSpeed::High as u32,
        ep0_in: control_endpoint(0x80),
        ep0_out: control_endpoint(0x00),
        // configured by default
        configuration_value: 1,
        num_configurations: 1,
        ..Self::default()
    };

    // Allocation order determines the string indices (1..=4).
    device.string_configuration = device.new_string("Default Configuration")?;
    device.string_manufacturer = device.new_string("Manufacturer")?;
    device.string_product = device.new_string("Product")?;
    device.string_serial = device.new_string("Serial")?;
    Ok(device)
}
/// Replace the configuration string with `name`.
///
/// Returns the old value, if present.
///
/// # Errors
/// Propagates the error of `new_string` when no string index is free.
pub fn set_configuration_name(&mut self, name: &str) -> std::io::Result<Option<String>> {
    let previous = match self.string_configuration {
        0 => None,
        index => self.string_pool.remove(&index),
    };
    self.string_configuration = self.new_string(name)?;
    Ok(previous)
}
/// Unset configuration name and returns the old value, if present.
///
/// Afterwards `string_configuration` is 0.
pub fn unset_configuration_name(&mut self) -> Option<String> {
    // `take` hands back the old index and leaves 0 behind in one step.
    match std::mem::take(&mut self.string_configuration) {
        0 => None,
        index => self.string_pool.remove(&index),
    }
}
/// Replace the serial number string with `name`.
///
/// Returns the old value, if present.
///
/// # Errors
/// Propagates the error of `new_string` when no string index is free.
pub fn set_serial_number(&mut self, name: &str) -> std::io::Result<Option<String>> {
    let previous = match self.string_serial {
        0 => None,
        index => self.string_pool.remove(&index),
    };
    self.string_serial = self.new_string(name)?;
    Ok(previous)
}
/// Unset serial number and returns the old value, if present.
///
/// Afterwards `string_serial` is 0.
pub fn unset_serial_number(&mut self) -> Option<String> {
    // `take` hands back the old index and leaves 0 behind in one step.
    match std::mem::take(&mut self.string_serial) {
        0 => None,
        index => self.string_pool.remove(&index),
    }
}
/// Replace the product string with `name`.
///
/// Returns the old value, if present.
///
/// # Errors
/// Propagates the error of `new_string` when no string index is free.
pub fn set_product_name(&mut self, name: &str) -> std::io::Result<Option<String>> {
    let previous = match self.string_product {
        0 => None,
        index => self.string_pool.remove(&index),
    };
    self.string_product = self.new_string(name)?;
    Ok(previous)
}
/// Unset product name and returns the old value, if present.
///
/// Afterwards `string_product` is 0.
pub fn unset_product_name(&mut self) -> Option<String> {
    // `take` hands back the old index and leaves 0 behind in one step.
    match std::mem::take(&mut self.string_product) {
        0 => None,
        index => self.string_pool.remove(&index),
    }
}
/// Replace the manufacturer string with `name`.
///
/// Returns the old value, if present.
///
/// # Errors
/// Propagates the error of `new_string` when no string index is free.
pub fn set_manufacturer_name(&mut self, name: &str) -> std::io::Result<Option<String>> {
    let previous = match self.string_manufacturer {
        0 => None,
        index => self.string_pool.remove(&index),
    };
    self.string_manufacturer = self.new_string(name)?;
    Ok(previous)
}
/// Unset manufacturer name and returns the old value, if present.
///
/// Afterwards `string_manufacturer` is 0.
pub fn unset_manufacturer_name(&mut self) -> Option<String> {
    // `take` hands back the old index and leaves 0 behind in one step.
    match std::mem::take(&mut self.string_manufacturer) {
        0 => None,
        index => self.string_pool.remove(&index),
    }
}
/// Append an interface to the device (builder style).
///
/// `name`, when given, is stored in the string pool and referenced as the
/// interface string; otherwise the interface string index is 0. The
/// class-specific descriptor is obtained from `handler` once, here.
///
/// # Errors
/// Propagates the error of `new_string` when the name cannot be stored.
pub fn with_interface(
    mut self,
    interface_class: u8,
    interface_subclass: u8,
    interface_protocol: u8,
    name: Option<&str>,
    endpoints: Vec<UsbEndpoint>,
    handler: Arc<dyn UsbInterfaceHandler>,
) -> std::io::Result<Self> {
    // Allocate the name first so a full string pool aborts before the
    // handler is consulted.
    let string_interface = name
        .map(|text| self.new_string(text))
        .transpose()?
        .unwrap_or(0);
    let class_specific_descriptor = handler.get_class_specific_descriptor();

    let interface = UsbInterface {
        interface_class,
        interface_subclass,
        interface_protocol,
        endpoints,
        string_interface,
        class_specific_descriptor,
        handler,
    };
    self.interface_states.push(InterfaceState::new(interface));
    Ok(self)
}
/// Install `handler` as the device handler (builder style), releasing any
/// handler that was installed before.
pub fn with_device_handler(
    mut self,
    handler: Arc<Mutex<Box<dyn UsbDeviceHandler + Send>>>,
) -> Self {
    drop(self.device_handler.replace(handler));
    self
}
/// Store `s` in the string pool under the lowest free index and return that
/// index.
///
/// Index 0 is never handed out, so at most 255 strings fit.
///
/// # Errors
/// Returns an `ErrorKind::Other` error when indices 1..=255 are all taken.
pub fn new_string(&mut self, s: &str) -> std::io::Result<u8> {
    let free_index = (1..=u8::MAX).find(|index| !self.string_pool.contains_key(index));
    match free_index {
        Some(index) => {
            self.string_pool.insert(index, s.to_string());
            Ok(index)
        }
        None => Err(std::io::Error::new(
            std::io::ErrorKind::Other,
            "string pool exhausted (max 255 entries)",
        )),
    }
}
/// Look up the endpoint with address `ep`.
///
/// Returns the endpoint together with the index of the interface that owns
/// it; endpoint zero (either direction) belongs to no interface and yields
/// `None` as index. Only the endpoints of each interface's active setting
/// are searched. An interface whose state lock cannot be taken immediately
/// is skipped, so a lookup may return `None` while that interface is being
/// modified.
pub fn find_ep(&self, ep: u8) -> Option<(UsbEndpoint, Option<usize>)> {
    if ep == self.ep0_in.address {
        return Some((self.ep0_in, None));
    }
    if ep == self.ep0_out.address {
        return Some((self.ep0_out, None));
    }

    self.interface_states
        .iter()
        .enumerate()
        .find_map(|(intf_idx, state)| {
            // Lock contended (e.g. concurrent SET_INTERFACE holds a write lock):
            // skip this interface; the caller handles None gracefully.
            let inner = state.inner.try_read().ok()?;
            let hit = inner
                .active
                .endpoints
                .iter()
                .find(|candidate| candidate.address == ep)
                .map(|candidate| (*candidate, Some(intf_idx)));
            hit
        })
}
/// Serialise the device into the 312-byte USB/IP device record.
///
/// Layout: 256-byte path, 32-byte bus id (both truncated or zero-padded to
/// their width), big-endian bus number, device number, speed, vendor id and
/// product id, the two BCD bytes of `device_bcd`, then one byte each for
/// class, subclass, protocol, configuration value, number of configurations
/// and number of interfaces (truncated to 8 bits).
pub fn to_bytes(&self) -> Vec<u8> {
    const PATH_WIDTH: usize = 256;
    const BUS_ID_WIDTH: usize = 32;

    let mut record = Vec::with_capacity(312);

    let path = self.path.as_bytes();
    record.extend_from_slice(&path[..path.len().min(PATH_WIDTH)]);
    record.resize(PATH_WIDTH, 0);

    let bus_id = self.bus_id.as_bytes();
    record.extend_from_slice(&bus_id[..bus_id.len().min(BUS_ID_WIDTH)]);
    record.resize(PATH_WIDTH + BUS_ID_WIDTH, 0);

    for word in [self.bus_num, self.dev_num, self.speed] {
        record.extend_from_slice(&word.to_be_bytes());
    }
    for half_word in [self.vendor_id, self.product_id] {
        record.extend_from_slice(&half_word.to_be_bytes());
    }
    record.extend_from_slice(&self.device_bcd.to_bcd_be());
    record.extend_from_slice(&[
        self.device_class,
        self.device_subclass,
        self.device_protocol,
        self.configuration_value,
        self.num_configurations,
        self.interface_states.len() as u8,
    ]);
    record
}
/// Serialise the device record followed by four bytes per interface
/// (class, subclass, protocol, one padding byte) taken from each
/// interface's active setting.
///
/// An interface whose state lock cannot be taken immediately is emitted as
/// four zero bytes.
pub fn to_bytes_with_interfaces(&self) -> Vec<u8> {
    let mut record = self.to_bytes();
    record.reserve(4 * self.interface_states.len());
    for state in &self.interface_states {
        // Lock contended (e.g. concurrent SET_INTERFACE). Emit zeroed
        // interface descriptor — this is the informational device-list
        // response and the client can re-query later.
        let entry = state
            .inner
            .try_read()
            .map(|inner| {
                [
                    inner.active.interface_class,
                    inner.active.interface_subclass,
                    inner.active.interface_protocol,
                    0, // padding
                ]
            })
            .unwrap_or([0u8; 4]);
        record.extend_from_slice(&entry);
    }
    record
}
/// Parse a 312-byte USB/IP device descriptor.
/// This is the inverse of `to_bytes()`. Used by the client to extract
/// device metadata from the simplified handshake.
pub fn from_bytes(bytes: &[u8]) -> Self {
assert!(
bytes.len() >= 312,
"device descriptor must be at least 312 bytes"
);
let path = std::str::from_utf8(&bytes[0..256])
.unwrap_or("")
.trim_end_matches('\0')
.to_string();
let bus_id = std::str::from_utf8(&bytes[256..288])
.unwrap_or("")
.trim_end_matches('\0')
.to_string();
let bus_num = u32::from_be_bytes(bytes[288..292].try_into().unwrap());
let dev_num = u32::from_be_bytes(bytes[292..296].try_into().unwrap());
let speed = u32::from_be_bytes(bytes[296..300].try_into().unwrap());
let vendor_id = u16::from_be_bytes(bytes[300..302].try_into().unwrap());
let product_id = u16::from_be_bytes(bytes[302..304].try_into().unwrap());
let device_bcd = Version {
major: bytes[304],
minor: (bytes[305] >> 4) & 0xF,
patch: bytes[305] & 0xF,
};
let device_class = bytes[306];
let device_subclass = bytes[307];
let device_protocol = bytes[308];
let configuration_value = bytes[309];
let num_configurations = bytes[310];
Self {
path,
bus_id,
bus_num,
dev_num,
speed,
vendor_id,
product_id,
device_bcd,
device_class,
device_subclass,
device_protocol,
configuration_value,
num_configurations,
..Self::default()
}
}
/// Process one URB addressed to this device and produce the response.
///
/// * Control transfers are answered from the device's own state for the
///   standard requests handled here (descriptors, configuration, status,
///   alternate settings). Remaining control requests go to the handler of
///   the interface selected by the low byte of `wIndex`, or to the device
///   handler when one is installed.
/// * Every other transfer type is passed unchanged to the handler of the
///   active setting of interface `intf_idx`.
///
/// # Errors
/// * `InvalidInput` for an unknown interface index, alternate setting or
///   string index.
/// * `Unsupported` for requests nobody handles, for a non-control transfer
///   without `intf_idx`, and for endpoints without a known transfer type.
/// * Any error returned by the interface or device handler.
pub async fn handle_urb(
    &self,
    intf_idx: Option<usize>,
    request: UrbRequest,
) -> Result<UrbResponse> {
    use Direction::*;
    use EndpointAttributes::*;

    let ep = request.ep;
    match (ep.transfer_type(), ep.direction()) {
        (Some(Control), In) => self.handle_control_in_urb(request).await,
        (Some(Control), Out) => self.handle_control_out_urb(request).await,
        (Some(_), _) => match intf_idx {
            Some(idx) => {
                // `intf_idx` is supplied by the caller: report a bad index as
                // an error instead of panicking on out-of-bounds indexing.
                let state = self.interface_states.get(idx).ok_or_else(|| {
                    std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        format!("Invalid interface index: {idx}"),
                    )
                })?;
                let inner = state.inner.read().await;
                inner.active.handler.handle_urb(&inner.active, request)
            }
            None => {
                warn!("No interface for endpoint {:02x?}", ep);
                Err(std::io::Error::new(
                    std::io::ErrorKind::Unsupported,
                    format!("No interface for endpoint {ep:02x?}"),
                ))
            }
        },
        _ => {
            warn!("Unsupported transfer to {:?}", ep);
            Err(std::io::Error::new(
                std::io::ErrorKind::Unsupported,
                format!("Unsupported transfer to {ep:?}"),
            ))
        }
    }
}

/// Build the request that is handed on to interface/device handlers for a
/// control transfer: endpoint, buffer length, setup packet and data are
/// kept, every other field is reset to its default. The data buffer is moved
/// out of `request` rather than copied.
fn control_request_for_handlers(request: &mut UrbRequest) -> UrbRequest {
    UrbRequest {
        ep: request.ep,
        transfer_buffer_length: request.transfer_buffer_length,
        setup: request.setup.clone(),
        data: std::mem::take(&mut request.data),
        ..Default::default()
    }
}

/// Pass `request` to the device handler behind `handler`. A poisoned lock is
/// recovered and used anyway.
fn forward_to_device_handler(
    handler: &Mutex<Box<dyn UsbDeviceHandler + Send>>,
    request: UrbRequest,
) -> Result<UrbResponse> {
    let mut handler = handler.lock().unwrap_or_else(|e| e.into_inner());
    handler.handle_urb(request)
}

/// Pass `request` to the handler of the interface named by `wIndex`.
async fn forward_to_interface_handler(
    &self,
    setup_packet: &SetupPacket,
    request: UrbRequest,
) -> Result<UrbResponse> {
    // see https://www.beyondlogic.org/usbnutshell/usb6.shtml
    // only low 8 bits are valid
    let intf_index = setup_packet.index as usize & 0xFF;
    let state = match self.interface_states.get(intf_index) {
        Some(state) => state,
        None => {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("Invalid interface index: {}", setup_packet.index),
            ))
        }
    };
    let inner = state.inner.read().await;
    inner.active.handler.handle_urb(&inner.active, request)
}

/// Wrap a descriptor into a response, cutting it down to the requested
/// `wLength` when the host asked for fewer bytes than the descriptor has.
fn descriptor_response(mut desc: Vec<u8>, requested_len: u16) -> UrbResponse {
    desc.truncate(usize::from(requested_len));
    UrbResponse {
        data: desc,
        ..Default::default()
    }
}

/// Assemble the configuration descriptor including every alternate setting
/// of every interface, their class-specific descriptors and endpoints.
async fn build_configuration_descriptor(&self) -> Vec<u8> {
    use DescriptorType::*;

    // Standard Configuration Descriptor
    let mut desc = vec![
        0x09,                              // bLength
        Configuration as u8,               // bDescriptorType: Configuration
        0x00,                              // wTotalLength: to be filled below
        0x00,
        self.interface_states.len() as u8, // bNumInterfaces
        self.configuration_value,          // bConfigurationValue
        self.string_configuration,         // iConfiguration
        0x80,                              // bmAttributes: Bus Powered
        0x32,                              // bMaxPower: 100mA
    ];
    for (i, state) in self.interface_states.iter().enumerate() {
        let inner = state.inner.read().await;
        for (alt_idx, intf) in inner.alt_settings.iter().enumerate() {
            desc.extend_from_slice(&[
                0x09,                       // bLength
                Interface as u8,            // bDescriptorType: Interface
                i as u8,                    // bInterfaceNum
                alt_idx as u8,              // bAlternateSetting
                intf.endpoints.len() as u8, // bNumEndpoints
                intf.interface_class,       // bInterfaceClass
                intf.interface_subclass,    // bInterfaceSubClass
                intf.interface_protocol,    // bInterfaceProtocol
                intf.string_interface,      // iInterface
            ]);
            // class specific descriptor
            desc.extend_from_slice(&intf.class_specific_descriptor);
            // endpoint descriptors
            for endpoint in &intf.endpoints {
                desc.extend_from_slice(&[
                    0x07,                                  // bLength
                    Endpoint as u8,                        // bDescriptorType: Endpoint
                    endpoint.address,                      // bEndpointAddress
                    endpoint.attributes,                   // bmAttributes
                    endpoint.max_packet_size as u8,        // wMaxPacketSize (little-endian)
                    (endpoint.max_packet_size >> 8) as u8,
                    endpoint.interval,                     // bInterval
                ]);
            }
        }
    }
    // wTotalLength (little-endian)
    let total = desc.len() as u16;
    desc[2] = total as u8;
    desc[3] = (total >> 8) as u8;
    desc
}

/// Answer a standard GET_DESCRIPTOR request. BOS, configuration and
/// device-qualifier requests, and string indices missing from the string
/// pool, are passed to the device handler when one is installed.
async fn get_descriptor_response(
    &self,
    setup_packet: &SetupPacket,
    forwarded: UrbRequest,
) -> Result<UrbResponse> {
    use DescriptorType::*;

    // high byte: type
    match FromPrimitive::from_u16(setup_packet.value >> 8) {
        Some(Device) => {
            debug!("Get device descriptor");
            // Standard Device Descriptor
            let bcd_usb = self.usb_version.to_bcd_be();
            let bcd_device = self.device_bcd.to_bcd_be();
            let desc = vec![
                0x12,                              // bLength
                Device as u8,                      // bDescriptorType: Device
                bcd_usb[1],                        // bcdUSB (little-endian)
                bcd_usb[0],
                self.device_class,                 // bDeviceClass
                self.device_subclass,              // bDeviceSubClass
                self.device_protocol,              // bDeviceProtocol
                self.ep0_in.max_packet_size as u8, // bMaxPacketSize0
                self.vendor_id as u8,              // idVendor
                (self.vendor_id >> 8) as u8,
                self.product_id as u8,             // idProduct
                (self.product_id >> 8) as u8,
                bcd_device[1],                     // bcdDevice (little-endian)
                bcd_device[0],
                self.string_manufacturer,          // iManufacturer
                self.string_product,               // iProduct
                self.string_serial,                // iSerial
                self.num_configurations,           // bNumConfigurations
            ];
            Ok(Self::descriptor_response(desc, setup_packet.length))
        }
        Some(BOS) => {
            debug!("Get BOS descriptor");
            if let Some(lock) = &self.device_handler {
                return Self::forward_to_device_handler(lock, forwarded);
            }
            let desc = vec![
                0x05,      // bLength
                BOS as u8, // bDescriptorType: BOS
                0x05,      // wTotalLength
                0x00,
                0x00,      // bNumCapabilities
            ];
            Ok(Self::descriptor_response(desc, setup_packet.length))
        }
        Some(Configuration) => {
            debug!("Get configuration descriptor");
            // In passthrough mode, forward to the real device so that
            // IADs, class-specific descriptors, etc. are preserved.
            if let Some(lock) = &self.device_handler {
                return Self::forward_to_device_handler(lock, forwarded);
            }
            let desc = self.build_configuration_descriptor().await;
            Ok(Self::descriptor_response(desc, setup_packet.length))
        }
        Some(String) => {
            debug!("Get string descriptor");
            let index = setup_packet.value as u8;
            if index == 0 {
                // String Descriptor Zero, Specifying Languages Supported by the Device
                let desc = vec![
                    4,                            // bLength
                    DescriptorType::String as u8, // bDescriptorType
                    0x09,                         // wLANGID[0], en-US
                    0x04,
                ];
                Ok(Self::descriptor_response(desc, setup_packet.length))
            } else if let Some(s) = self.string_pool.get(&index) {
                // UNICODE String Descriptor
                // Truncate to fit in a u8 bLength: 2 header + 2 per code unit <= 255
                let mut units: Vec<u16> = s.encode_utf16().take(126).collect();
                // Cutting at 126 units may split a surrogate pair; never end
                // the descriptor on an unpaired high surrogate.
                if matches!(units.last().copied(), Some(0xD800..=0xDBFF)) {
                    units.pop();
                }
                let mut desc = Vec::with_capacity(2 + units.len() * 2);
                desc.push((2 + units.len() * 2) as u8); // bLength
                desc.push(DescriptorType::String as u8); // bDescriptorType
                for unit in &units {
                    desc.extend_from_slice(&unit.to_le_bytes());
                }
                Ok(Self::descriptor_response(desc, setup_packet.length))
            } else if let Some(lock) = &self.device_handler {
                // Forward unknown string indices to the device handler
                // (host passthrough: the real device knows its own strings)
                Self::forward_to_device_handler(lock, forwarded)
            } else {
                Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    format!("Invalid string index: {index}"),
                ))
            }
        }
        Some(DeviceQualifier) => {
            debug!("Get device qualifier descriptor");
            if let Some(lock) = &self.device_handler {
                return Self::forward_to_device_handler(lock, forwarded);
            }
            // Device_Qualifier Descriptor
            let bcd_usb = self.usb_version.to_bcd_be();
            let desc = vec![
                0x0A,                              // bLength
                DeviceQualifier as u8,             // bDescriptorType: Device Qualifier
                bcd_usb[1],                        // bcdUSB (little-endian)
                bcd_usb[0],
                self.device_class,                 // bDeviceClass
                self.device_subclass,              // bDeviceSubClass
                self.device_protocol,              // bDeviceProtocol
                self.ep0_in.max_packet_size as u8, // bMaxPacketSize0
                self.num_configurations,           // bNumConfigurations
                0x00,                              // bReserved
            ];
            Ok(Self::descriptor_response(desc, setup_packet.length))
        }
        _ => {
            warn!("unknown desc type: {setup_packet:x?}");
            Ok(UrbResponse::default())
        }
    }
}

/// Handle a control transfer in the device-to-host direction.
async fn handle_control_in_urb(&self, mut request: UrbRequest) -> Result<UrbResponse> {
    let forwarded = Self::control_request_for_handlers(&mut request);
    let setup_packet = &request.setup;
    debug!("Control IN setup={setup_packet:x?}");

    let standard: Option<StandardRequest> = FromPrimitive::from_u8(setup_packet.request);
    match (setup_packet.request_type, standard) {
        (0b10000000, Some(StandardRequest::GetDescriptor)) => {
            self.get_descriptor_response(setup_packet, forwarded).await
        }
        (0b10000000, Some(StandardRequest::GetConfiguration)) => {
            debug!("Get configuration value");
            Ok(UrbResponse {
                data: vec![self.configuration_value],
                ..Default::default()
            })
        }
        (0b10000000, Some(StandardRequest::GetStatus)) => {
            // Device recipient: self-powered=0, remote-wakeup=0
            debug!("Get status (device)");
            Ok(UrbResponse {
                data: vec![0x00, 0x00],
                ..Default::default()
            })
        }
        (0b10000001, Some(StandardRequest::GetStatus)) => {
            // Interface recipient: reserved, always zero
            debug!("Get status (interface)");
            Ok(UrbResponse {
                data: vec![0x00, 0x00],
                ..Default::default()
            })
        }
        (0b10000010, Some(StandardRequest::GetStatus)) => {
            // Endpoint recipient: halt=0
            debug!("Get status (endpoint)");
            Ok(UrbResponse {
                data: vec![0x00, 0x00],
                ..Default::default()
            })
        }
        (0b10000001, Some(StandardRequest::GetInterface)) => {
            let intf_index = setup_packet.index as usize & 0xFF;
            match self.interface_states.get(intf_index) {
                Some(state) => {
                    let inner = state.inner.read().await;
                    Ok(UrbResponse {
                        data: vec![inner.current_alt],
                        ..Default::default()
                    })
                }
                None => Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    format!("Invalid interface index: {intf_index}"),
                )),
            }
        }
        // to interface
        (request_type, _) if request_type & 0xF == 1 => {
            self.forward_to_interface_handler(setup_packet, forwarded)
                .await
        }
        (request_type, _) => match &self.device_handler {
            // to device
            // see https://www.beyondlogic.org/usbnutshell/usb6.shtml
            Some(lock) if request_type & 0xF == 0 => {
                Self::forward_to_device_handler(lock, forwarded)
            }
            _ => {
                warn!("Unhandled control IN: {setup_packet:x?}");
                Err(std::io::Error::new(
                    std::io::ErrorKind::Unsupported,
                    format!("Unhandled control IN: {setup_packet:x?}"),
                ))
            }
        },
    }
}

/// Handle a control transfer in the host-to-device direction.
async fn handle_control_out_urb(&self, mut request: UrbRequest) -> Result<UrbResponse> {
    let forwarded = Self::control_request_for_handlers(&mut request);
    let setup_packet = &request.setup;
    debug!("Control OUT setup={setup_packet:x?}");

    let standard: Option<StandardRequest> = FromPrimitive::from_u8(setup_packet.request);
    match (setup_packet.request_type, standard) {
        (0b00000000, Some(StandardRequest::SetConfiguration)) => {
            // Forward to physical device handler if present,
            // so endpoints are properly reset on the device
            if let Some(lock) = &self.device_handler {
                Self::forward_to_device_handler(lock, forwarded)?;
            }
            Ok(UrbResponse::default())
        }
        (0b00000001, Some(StandardRequest::SetInterface)) => {
            let intf_index = setup_packet.index as usize & 0xFF;
            let alt = setup_packet.value as u8;
            match self.interface_states.get(intf_index) {
                Some(state) => {
                    let mut inner = state.inner.write().await;
                    if (alt as usize) < inner.alt_settings.len() {
                        // Notify the handler so it can update the physical device
                        inner.active.handler.set_alt_setting(alt)?;
                        inner.active = inner.alt_settings[alt as usize].clone();
                        inner.current_alt = alt;
                        info!("SET_INTERFACE: intf={intf_index} alt={alt}");
                        Ok(UrbResponse::default())
                    } else {
                        Err(std::io::Error::new(
                            std::io::ErrorKind::InvalidInput,
                            format!("Invalid alt setting {alt} for interface {intf_index}"),
                        ))
                    }
                }
                None => Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    format!("Invalid interface index: {intf_index}"),
                )),
            }
        }
        (0b00000010, Some(StandardRequest::ClearFeature)) => {
            // Endpoint recipient: no-op (simulated device doesn't stall)
            debug!("Clear feature (endpoint)");
            Ok(UrbResponse::default())
        }
        (0b00000010, Some(StandardRequest::SetFeature)) => {
            // Endpoint recipient: no-op
            debug!("Set feature (endpoint)");
            Ok(UrbResponse::default())
        }
        (0b00000000, Some(StandardRequest::SetAddress)) => {
            // No-op: address already assigned by bus
            debug!("Set address (no-op)");
            Ok(UrbResponse::default())
        }
        // to interface
        (request_type, _) if request_type & 0xF == 1 => {
            self.forward_to_interface_handler(setup_packet, forwarded)
                .await
        }
        (request_type, _) => match &self.device_handler {
            // to device
            // see https://www.beyondlogic.org/usbnutshell/usb6.shtml
            Some(lock) if request_type & 0xF == 0 => {
                Self::forward_to_device_handler(lock, forwarded)
            }
            _ => {
                warn!("Unhandled control OUT: {setup_packet:x?}");
                Err(std::io::Error::new(
                    std::io::ErrorKind::Unsupported,
                    format!("Unhandled control OUT: {setup_packet:x?}"),
                ))
            }
        },
    }
}
}
/// A handler for URBs targeting the device itself (as opposed to one of its
/// interfaces).
pub trait UsbDeviceHandler: std::fmt::Debug {
/// Handle a URB (USB Request Block) targeting this device.
///
/// When the lower 4 bits of `bmRequestType` are zero and the URB is not handled by the library, this function is called.
/// `UsbDevice::handle_urb` also calls it for some GET_DESCRIPTOR requests
/// (BOS, configuration, device qualifier, and string indices missing from
/// the string pool) and for SET_CONFIGURATION.
fn handle_urb(&mut self, request: UrbRequest) -> Result<UrbResponse>;
/// Helper to downcast to actual struct
fn as_any(&self) -> &dyn Any;
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::util::tests::*;

    /// Replacing each of the four default strings reuses the freed index, so
    /// the pool keeps its size and every slot ends up holding the new text.
    #[test]
    fn test_set_string_descriptors() {
        setup_test_logger();
        let mut device = UsbDevice::new(0).unwrap();
        assert_eq!(device.string_pool.len(), 4);

        let old_configuration = device.set_configuration_name("test").unwrap();
        assert!(old_configuration.is_some());
        let old_manufacturer = device.set_manufacturer_name("test").unwrap();
        assert!(old_manufacturer.is_some());
        let old_product = device.set_product_name("test").unwrap();
        assert!(old_product.is_some());
        let old_serial = device.set_serial_number("test").unwrap();
        assert!(old_serial.is_some());

        assert_eq!(device.string_pool.len(), 4);
        for index in 1..=4u8 {
            assert_eq!(device.string_pool[&index], "test");
        }
    }

    /// Requesting a string descriptor index that is not in the pool fails
    /// when no device handler is installed.
    #[tokio::test]
    async fn test_invalid_string_index() {
        setup_test_logger();
        let device = UsbDevice::new(0).unwrap();

        let control_in = UsbEndpoint {
            address: 0x80, // IN
            attributes: EndpointAttributes::Control as u8,
            max_packet_size: EP0_MAX_PACKET_SIZE,
            interval: 0,
        };
        let setup = SetupPacket {
            request_type: 0b10000000,
            request: StandardRequest::GetDescriptor as u8,
            // string pool only contains 4 strings, 5 should be invalid
            value: ((DescriptorType::String as u16) << 8) | 5,
            index: 0,
            length: 0,
        };
        let request = UrbRequest {
            ep: control_in,
            transfer_buffer_length: 0,
            setup,
            data: vec![],
            ..Default::default()
        };

        let res = device.handle_urb(None, request).await;
        assert!(res.is_err());
    }

    /// `from_bytes` restores the identification fields written by `to_bytes`.
    #[test]
    fn test_from_bytes_round_trip() {
        setup_test_logger();
        let device = UsbDevice::new(0).unwrap();

        let bytes = device.to_bytes();
        assert_eq!(bytes.len(), 312);

        let parsed = UsbDevice::from_bytes(&bytes);
        assert_eq!(parsed.bus_num, device.bus_num);
        assert_eq!(parsed.dev_num, device.dev_num);
        assert_eq!(parsed.speed, device.speed);
        assert_eq!(parsed.vendor_id, device.vendor_id);
        assert_eq!(parsed.product_id, device.product_id);
        assert_eq!(parsed.device_class, device.device_class);
        assert_eq!(parsed.num_configurations, device.num_configurations);
        assert_eq!(parsed.path, device.path);
        assert_eq!(parsed.bus_id, device.bus_id);
    }
}