From 89d55ab6a2635fb2477333fe69894aa7105459a9 Mon Sep 17 00:00:00 2001 From: Kevin Mehall Date: Mon, 14 Jul 2025 08:31:41 -0600 Subject: [PATCH 1/4] EndpointRead: improve semantics for num_transfers --- src/io/read.rs | 120 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 79 insertions(+), 41 deletions(-) diff --git a/src/io/read.rs b/src/io/read.rs index 0645148..4b69b0f 100644 --- a/src/io/read.rs +++ b/src/io/read.rs @@ -88,33 +88,41 @@ impl EndpointRead { Self { endpoint, reading: None, - num_transfers: 0, + num_transfers: 1, transfer_size, read_timeout: Duration::MAX, } } - /// Set the number of transfers to maintain pending at all times. + /// Set the number of concurrent transfers. /// - /// A value of 0 means that transfers will only be submitted when calling - /// `read()` or `fill_buf()` and the buffer is empty. To maximize throughput, - /// a value of 2 or more is recommended for applications that stream data - /// continuously. + /// A value of 1 (default) means that transfers will only be submitted when + /// calling `read()` or `fill_buf()` and the buffer is empty. To maximize + /// throughput, a value of 2 or more is recommended for applications that + /// stream data continuously so that the host controller can continue to + /// receive data while the application processes the data from a completed + /// transfer. + /// + /// A value of 0 means no further transfers will be submitted. Existing + /// transfers will complete normally, and subsequent calls to `read()` and + /// `fill_buf()` will return zero bytes (EOF). /// /// This submits more transfers when increasing the number, but does not - /// cancel transfers when decreasing it. + /// [cancel transfers](Self::cancel_all) when decreasing it. pub fn set_num_transfers(&mut self, num_transfers: usize) { self.num_transfers = num_transfers; - while self.endpoint.pending() < num_transfers { + // Leave the last transfer to be submitted by `read` such that + // a value of `1` only has transfers pending within `read` calls. + while self.endpoint.pending() < num_transfers.saturating_sub(1) { let buf = self.endpoint.allocate(self.transfer_size); self.endpoint.submit(buf); } } - /// Set the number of transfers to maintain pending at all times. + /// Set the number of concurrent transfers. /// - /// See [Self::set_num_transfers] -- this is for method chaining with `EndpointRead::new()`. + /// See [Self::set_num_transfers] (this version is for method chaining). pub fn with_num_transfers(mut self, num_transfers: usize) -> Self { self.set_num_transfers(num_transfers); self @@ -141,8 +149,14 @@ impl EndpointRead { /// Cancel all pending transfers. /// - /// They will be re-submitted on the next read. + /// This sets [`num_transfers`](Self::set_num_transfers) to 0, so no further + /// transfers will be submitted. Any data buffered before the transfers are cancelled + /// can be read, and then the read methods will return 0 bytes (EOF). + /// + /// Call [`num_transfers`](Self::set_num_transfers) with a non-zero value + /// to resume receiving data. pub fn cancel_all(&mut self) { + self.num_transfers = 0; self.endpoint.cancel_all(); } @@ -183,15 +197,19 @@ impl EndpointRead { } } - fn start_read(&mut self) { - let t = usize::max(1, self.num_transfers); - if self.endpoint.pending() < t { + fn start_read(&mut self) -> bool { + if self.endpoint.pending() < self.num_transfers { + // Re-use the last completed buffer if available self.resubmit(); - while self.endpoint.pending() < t { + while self.endpoint.pending() < self.num_transfers { + // Allocate more buffers for any remaining transfers let buf = self.endpoint.allocate(self.transfer_size); self.endpoint.submit(buf); } } + + // If num_transfers is 0 and all transfers are complete + self.endpoint.pending() > 0 } #[inline] @@ -208,38 +226,46 @@ impl EndpointRead { } } - fn wait(&mut self) -> Result<(), std::io::Error> { - self.start_read(); - let c = self.endpoint.wait_next_complete(self.read_timeout); - let c = c.ok_or(std::io::Error::new( - std::io::ErrorKind::TimedOut, - "timeout waiting for read", - ))?; - self.reading = Some(ReadBuffer { - pos: 0, - buf: c.buffer, - status: c.status, - }); - Ok(()) + fn wait(&mut self) -> Result { + if self.start_read() { + let c = self.endpoint.wait_next_complete(self.read_timeout); + let c = c.ok_or(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "timeout waiting for read", + ))?; + self.reading = Some(ReadBuffer { + pos: 0, + buf: c.buffer, + status: c.status, + }); + Ok(true) + } else { + Ok(false) + } } #[cfg(any(feature = "tokio", feature = "smol"))] - fn poll(&mut self, cx: &mut Context<'_>) -> Poll<()> { - self.start_read(); - let c = ready!(self.endpoint.poll_next_complete(cx)); - self.reading = Some(ReadBuffer { - pos: 0, - buf: c.buffer, - status: c.status, - }); - Poll::Ready(()) + fn poll(&mut self, cx: &mut Context<'_>) -> Poll { + if self.start_read() { + let c = ready!(self.endpoint.poll_next_complete(cx)); + self.reading = Some(ReadBuffer { + pos: 0, + buf: c.buffer, + status: c.status, + }); + Poll::Ready(true) + } else { + Poll::Ready(false) + } } #[cfg(any(feature = "tokio", feature = "smol"))] #[inline] fn poll_fill_buf(&mut self, cx: &mut Context<'_>) -> Poll> { while !self.has_data() { - ready!(self.poll(cx)); + if !ready!(self.poll(cx)) { + return Poll::Ready(Ok(&[])); + } } Poll::Ready(self.remaining()) } @@ -251,7 +277,12 @@ impl EndpointRead { cx: &mut Context<'_>, ) -> Poll> { while !self.has_data_or_short_end() { - ready!(self.poll(cx)); + if !ready!(self.poll(cx)) { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "ended without short packet", + ))); + } } Poll::Ready(self.remaining()) } @@ -271,7 +302,9 @@ impl BufRead for EndpointRead { #[inline] fn fill_buf(&mut self) -> Result<&[u8], std::io::Error> { while !self.has_data() { - self.wait()?; + if !self.wait()? { + return Ok(&[]); + } } self.remaining() } @@ -393,7 +426,12 @@ impl BufRead for EndpointReadUntilShortPacket<'_, EpTyp #[inline] fn fill_buf(&mut self) -> Result<&[u8], std::io::Error> { while !self.reader.has_data_or_short_end() { - self.reader.wait()?; + if !self.reader.wait()? { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "ended without short packet", + )); + } } self.reader.remaining() } From a3c57e9336955a59beccccf499765298c45deb1a Mon Sep 17 00:00:00 2001 From: Kevin Mehall Date: Tue, 15 Jul 2025 22:12:02 -0600 Subject: [PATCH 2/4] EndpointRead: Error type for EndpointReadUntilShortPacket::consume_end --- src/io/read.rs | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/io/read.rs b/src/io/read.rs index 4b69b0f..d9d01d0 100644 --- a/src/io/read.rs +++ b/src/io/read.rs @@ -1,4 +1,5 @@ use std::{ + error::Error, io::{BufRead, Read}, time::Duration, }; @@ -384,6 +385,19 @@ pub struct EndpointReadUntilShortPacket<'a, EpType: BulkOrInterrupt> { reader: &'a mut EndpointRead, } +/// Error returned by [`EndpointReadUntilShortPacket::consume_end()`] +/// when the reader is not at the end of a short packet. +#[derive(Debug)] +pub struct ExpectedShortPacket; + +impl std::fmt::Display for ExpectedShortPacket { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "expected short packet") + } +} + +impl Error for ExpectedShortPacket {} + impl EndpointReadUntilShortPacket<'_, EpType> { /// Check if the underlying endpoint has reached the end of a short packet. /// @@ -402,12 +416,12 @@ impl EndpointReadUntilShortPacket<'_, EpType> { /// to read the next message. /// /// Returns an error and does nothing if the reader [is not at the end of a short packet](Self::is_end). - pub fn consume_end(&mut self) -> Result<(), ()> { + pub fn consume_end(&mut self) -> Result<(), ExpectedShortPacket> { if self.is_end() { self.reader.reading.as_mut().unwrap().clear_short_packet(); Ok(()) } else { - Err(()) + Err(ExpectedShortPacket) } } } From eb12376bc252e0d50ecc6c62b8d5b6b4b2dd4795 Mon Sep 17 00:00:00 2001 From: Kevin Mehall Date: Tue, 15 Jul 2025 22:56:32 -0600 Subject: [PATCH 3/4] EndpointWrite: require num_transfers is nonzero --- src/io/write.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/io/write.rs b/src/io/write.rs index d6fad2d..0925ae1 100644 --- a/src/io/write.rs +++ b/src/io/write.rs @@ -49,7 +49,10 @@ impl EndpointWrite { /// If more than `num_transfers` transfers are pending, calls to `write` /// will block or async methods will return `Pending` until a transfer /// completes. + /// + /// Panics if `num_transfers` is zero. pub fn set_num_transfers(&mut self, num_transfers: usize) { + assert!(num_transfers > 0, "num_transfers must be greater than zero"); self.num_transfers = num_transfers; } From 371b91c85fcd2df4940b54e9955b2c0d210d0feb Mon Sep 17 00:00:00 2001 From: Kevin Mehall Date: Sun, 15 Jun 2025 10:12:21 -0600 Subject: [PATCH 4/4] EndpointRead / EndpointWrite examples with tokio and smol --- Cargo.toml | 10 ++++ examples/bulk_io.rs | 18 ++++++- examples/bulk_io_smol.rs | 99 +++++++++++++++++++++++++++++++++++++++ examples/bulk_io_tokio.rs | 95 +++++++++++++++++++++++++++++++++++++ 4 files changed, 221 insertions(+), 1 deletion(-) create mode 100644 examples/bulk_io_smol.rs create mode 100644 examples/bulk_io_tokio.rs diff --git a/Cargo.toml b/Cargo.toml index 22878b3..a5c61a4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ slab = "0.4.9" [dev-dependencies] env_logger = "0.11" futures-lite = "2.0" +tokio = { version = "1", features = ["rt", "macros", "io-util", "rt-multi-thread"] } [target.'cfg(any(target_os="linux", target_os="android"))'.dependencies] rustix = { version = "1.0.1", features = ["fs", "event", "net", "time", "mm"] } @@ -49,3 +50,12 @@ unexpected_cfgs = { level = "warn", check-cfg = ['cfg(fuzzing)'] } [package.metadata.docs.rs] all-features = true + +[[example]] +name = "bulk_io_smol" +required-features = ["smol"] + +[[example]] +name = "bulk_io_tokio" +required-features = ["tokio"] + diff --git a/examples/bulk_io.rs b/examples/bulk_io.rs index 0a41518..1f91ddc 100644 --- a/examples/bulk_io.rs +++ b/examples/bulk_io.rs @@ -31,7 +31,8 @@ fn main() { let mut reader = main_interface .endpoint::(0x83) .unwrap() - .reader(128); + .reader(128) + .with_num_transfers(8); writer.write_all(&[1; 16]).unwrap(); writer.write_all(&[2; 256]).unwrap(); @@ -47,6 +48,21 @@ fn main() { dbg!(reader.fill_buf().unwrap().len()); + let mut buf = [0; 1000]; + for len in 0..1000 { + reader.read_exact(&mut buf[..len]).unwrap(); + writer.write_all(&buf[..len]).unwrap(); + } + + reader.cancel_all(); + loop { + let n = reader.read(&mut buf).unwrap(); + dbg!(n); + if n == 0 { + break; + } + } + let echo_interface = device.claim_interface(1).wait().unwrap(); echo_interface.set_alt_setting(1).wait().unwrap(); diff --git a/examples/bulk_io_smol.rs b/examples/bulk_io_smol.rs new file mode 100644 index 0000000..7ebedc3 --- /dev/null +++ b/examples/bulk_io_smol.rs @@ -0,0 +1,99 @@ +use std::time::Duration; + +use nusb::{ + transfer::{Bulk, In, Out}, + MaybeFuture, +}; + +use futures_lite::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt}; + +fn main() { + env_logger::init(); + let di = nusb::list_devices() + .wait() + .unwrap() + .find(|d| d.vendor_id() == 0x59e3 && d.product_id() == 0x00aa) + .expect("device should be connected"); + + println!("Device info: {di:?}"); + + futures_lite::future::block_on(async { + let device = di.open().await.unwrap(); + + let main_interface = device.claim_interface(0).await.unwrap(); + + let mut writer = main_interface + .endpoint::(0x03) + .unwrap() + .writer(128) + .with_num_transfers(8); + + let mut reader = main_interface + .endpoint::(0x83) + .unwrap() + .reader(128) + .with_num_transfers(8); + + writer.write_all(&[1; 16]).await.unwrap(); + writer.write_all(&[2; 256]).await.unwrap(); + writer.flush().await.unwrap(); + writer.write_all(&[3; 64]).await.unwrap(); + writer.flush_end_async().await.unwrap(); + + let mut buf = [0; 16]; + reader.read_exact(&mut buf).await.unwrap(); + + let mut buf = [0; 64]; + reader.read_exact(&mut buf).await.unwrap(); + + dbg!(reader.fill_buf().await.unwrap().len()); + + let mut buf = [0; 1000]; + for len in 0..1000 { + reader.read_exact(&mut buf[..len]).await.unwrap(); + writer.write_all(&buf[..len]).await.unwrap(); + } + + reader.cancel_all(); + loop { + let n = reader.read(&mut buf).await.unwrap(); + dbg!(n); + if n == 0 { + break; + } + } + + let echo_interface = device.claim_interface(1).await.unwrap(); + echo_interface.set_alt_setting(1).await.unwrap(); + + let mut writer = echo_interface + .endpoint::(0x01) + .unwrap() + .writer(64) + .with_num_transfers(1); + let mut reader = echo_interface + .endpoint::(0x81) + .unwrap() + .reader(64) + .with_num_transfers(8) + .with_read_timeout(Duration::from_millis(100)); + + let mut pkt_reader = reader.until_short_packet(); + + writer.write_all(&[1; 16]).await.unwrap(); + writer.flush_end_async().await.unwrap(); + + writer.write_all(&[2; 128]).await.unwrap(); + writer.flush_end_async().await.unwrap(); + + let mut v = Vec::new(); + pkt_reader.read_to_end(&mut v).await.unwrap(); + assert_eq!(&v[..], &[1; 16]); + pkt_reader.consume_end().unwrap(); + + let mut v = Vec::new(); + pkt_reader.read_to_end(&mut v).await.unwrap(); + assert_eq!(&v[..], &[2; 128]); + pkt_reader.consume_end().unwrap(); + }) +} diff --git a/examples/bulk_io_tokio.rs b/examples/bulk_io_tokio.rs new file mode 100644 index 0000000..19b28e0 --- /dev/null +++ b/examples/bulk_io_tokio.rs @@ -0,0 +1,95 @@ +use std::time::Duration; + +use nusb::transfer::{Bulk, In, Out}; + +use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt}; + +#[tokio::main] +async fn main() { + env_logger::init(); + let di = nusb::list_devices() + .await + .unwrap() + .find(|d| d.vendor_id() == 0x59e3 && d.product_id() == 0x00aa) + .expect("device should be connected"); + + println!("Device info: {di:?}"); + + let device = di.open().await.unwrap(); + + let main_interface = device.claim_interface(0).await.unwrap(); + + let mut writer = main_interface + .endpoint::(0x03) + .unwrap() + .writer(128) + .with_num_transfers(8); + + let mut reader = main_interface + .endpoint::(0x83) + .unwrap() + .reader(128) + .with_num_transfers(8); + + writer.write_all(&[1; 16]).await.unwrap(); + writer.write_all(&[2; 256]).await.unwrap(); + writer.flush().await.unwrap(); + writer.write_all(&[3; 64]).await.unwrap(); + writer.flush_end_async().await.unwrap(); + + let mut buf = [0; 16]; + reader.read_exact(&mut buf).await.unwrap(); + + let mut buf = [0; 64]; + reader.read_exact(&mut buf).await.unwrap(); + + dbg!(reader.fill_buf().await.unwrap().len()); + + let mut buf = [0; 1000]; + for len in 0..1000 { + reader.read_exact(&mut buf[..len]).await.unwrap(); + writer.write_all(&buf[..len]).await.unwrap(); + } + + reader.cancel_all(); + loop { + let n = reader.read(&mut buf).await.unwrap(); + dbg!(n); + if n == 0 { + break; + } + } + + let echo_interface = device.claim_interface(1).await.unwrap(); + echo_interface.set_alt_setting(1).await.unwrap(); + + let mut writer = echo_interface + .endpoint::(0x01) + .unwrap() + .writer(64) + .with_num_transfers(1); + let mut reader = echo_interface + .endpoint::(0x81) + .unwrap() + .reader(64) + .with_num_transfers(8) + .with_read_timeout(Duration::from_millis(100)); + + let mut pkt_reader = reader.until_short_packet(); + + writer.write_all(&[1; 16]).await.unwrap(); + writer.flush_end_async().await.unwrap(); + + writer.write_all(&[2; 128]).await.unwrap(); + writer.flush_end_async().await.unwrap(); + + let mut v = Vec::new(); + pkt_reader.read_to_end(&mut v).await.unwrap(); + assert_eq!(&v[..], &[1; 16]); + pkt_reader.consume_end().unwrap(); + + let mut v = Vec::new(); + pkt_reader.read_to_end(&mut v).await.unwrap(); + assert_eq!(&v[..], &[2; 128]); + pkt_reader.consume_end().unwrap(); +}