From cecdf88c8633748583c8a2d6201f95e150e6eb03 Mon Sep 17 00:00:00 2001 From: Kevin Mehall Date: Sun, 28 Jan 2024 15:42:50 -0700 Subject: [PATCH] linux: hotplug --- Cargo.toml | 5 +- src/platform/linux_usbfs/device.rs | 13 ++- src/platform/linux_usbfs/events.rs | 86 +++++++++++++---- src/platform/linux_usbfs/hotplug.rs | 138 +++++++++++++++++++++++++++- 4 files changed, 219 insertions(+), 23 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0adae90..b4998b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ env_logger = "0.10.0" futures-lite = "1.13.0" [target.'cfg(target_os="linux")'.dependencies] -rustix = { version = "0.38.17", features = ["fs", "event"] } +rustix = { version = "0.38.17", features = ["fs", "event", "net"] } [target.'cfg(target_os="windows")'.dependencies] windows-sys = { version = "0.48.0", features = ["Win32_Devices_Usb", "Win32_Devices_DeviceAndDriverInstallation", "Win32_Foundation", "Win32_Devices_Properties", "Win32_Storage_FileSystem", "Win32_Security", "Win32_System_IO", "Win32_System_Registry", "Win32_System_Com"] } @@ -34,3 +34,6 @@ io-kit-sys = "0.4.0" [lints.rust] unexpected_cfgs = { level = "warn", check-cfg = ['cfg(fuzzing)'] } + +[patch.crates-io] +rustix = { git = "https://github.com/kevinmehall/rustix.git", rev = "9b432db1b4ed6cd8ec58fd88815a785a03300ebe" } diff --git a/src/platform/linux_usbfs/device.rs b/src/platform/linux_usbfs/device.rs index f26e192..e813d99 100644 --- a/src/platform/linux_usbfs/device.rs +++ b/src/platform/linux_usbfs/device.rs @@ -11,6 +11,8 @@ use std::{ }; use log::{debug, error}; +use rustix::event::epoll; +use rustix::fd::AsFd; use rustix::{ fd::{AsRawFd, FromRawFd, OwnedFd}, fs::{Mode, OFlags}, @@ -22,6 +24,7 @@ use super::{ usbfs::{self, Urb}, SysfsPath, }; +use crate::platform::linux_usbfs::events::Watch; use crate::{ descriptors::{parse_concatenated_config_descriptors, DESCRIPTOR_LEN_DEVICE}, transfer::{ @@ -61,7 +64,11 @@ impl LinuxDevice { // because there's no Arc::try_new_cyclic let mut events_err = None; let arc = Arc::new_cyclic(|weak| { - let res = events::register(&fd, weak.clone()); + let res = events::register( + fd.as_fd(), + Watch::Device(weak.clone()), + epoll::EventFlags::OUT, + ); let events_id = *res.as_ref().unwrap_or(&usize::MAX); events_err = res.err(); LinuxDevice { @@ -109,7 +116,7 @@ impl LinuxDevice { // only returns ENODEV after all events are received, so unregister to // keep the event thread from spinning because we won't receive further events. // The drop impl will try to unregister again, but that's ok. - events::unregister_fd(&self.fd); + events::unregister_fd(self.fd.as_fd()); } Err(e) => { error!("Unexpected error {e} from REAPURBNDELAY"); @@ -282,7 +289,7 @@ impl LinuxDevice { impl Drop for LinuxDevice { fn drop(&mut self) { debug!("Closing device {}", self.events_id); - events::unregister(&self.fd, self.events_id) + events::unregister(self.fd.as_fd(), self.events_id) } } diff --git a/src/platform/linux_usbfs/events.rs b/src/platform/linux_usbfs/events.rs index 7d3e0af..7d2780b 100644 --- a/src/platform/linux_usbfs/events.rs +++ b/src/platform/linux_usbfs/events.rs @@ -1,12 +1,27 @@ +use atomic_waker::AtomicWaker; +/// Epoll based event loop for Linux. +/// +/// Launches a thread when opening the first device that polls +/// for events on usbfs devices and arbitrary file descriptors +/// (used for udev hotplug). +/// +/// ### Why not share an event loop with `tokio` or `async-io`? +/// +/// This event loop will call USBFS_REAP_URB on the event thread and +/// dispatch to the transfer's waker directly. Since all USB transfers +/// on a device use the same file descriptor, putting USB-specific +/// dispatch in the event loop avoids additonal synchronization. use once_cell::sync::OnceCell; use rustix::{ - event::epoll::{self, EventData}, - fd::OwnedFd, + event::epoll::{self, EventData, EventFlags}, + fd::{AsFd, BorrowedFd, OwnedFd}, io::retry_on_intr, }; use slab::Slab; use std::{ - sync::{Mutex, Weak}, + io, + sync::{Arc, Mutex, Weak}, + task::Waker, thread, }; @@ -15,9 +30,14 @@ use crate::Error; use super::Device; static EPOLL_FD: OnceCell = OnceCell::new(); -static DEVICES: Mutex>> = Mutex::new(Slab::new()); +static WATCHES: Mutex> = Mutex::new(Slab::new()); -pub(super) fn register(usb_fd: &OwnedFd, weak_device: Weak) -> Result { +pub(super) enum Watch { + Device(Weak), + Fd(Arc), +} + +pub(super) fn register(fd: BorrowedFd, watch: Watch, flags: EventFlags) -> Result { let mut start_thread = false; let epoll_fd = EPOLL_FD.get_or_try_init(|| { start_thread = true; @@ -25,8 +45,8 @@ pub(super) fn register(usb_fd: &OwnedFd, weak_device: Weak) -> Result) -> Result { + if let Some(device) = w.upgrade() { + drop(lock); + device.handle_events(); + // `device` gets dropped here. if it was the last reference, the LinuxDevice will be dropped. + // That will unregister its fd, so it's important that WATCHES is unlocked here, or we'd deadlock. + } + } + Watch::Fd(waker) => waker.wake(), } } } } + +pub(crate) struct Async { + pub(crate) inner: T, + waker: Arc, + id: usize, +} + +impl Async { + pub fn new(inner: T) -> Result { + let waker = Arc::new(AtomicWaker::new()); + let id = register(inner.as_fd(), Watch::Fd(waker.clone()), EventFlags::empty())?; + Ok(Async { inner, id, waker }) + } + + pub fn register(&self, waker: &Waker) -> Result<(), io::Error> { + self.waker.register(waker); + let epoll_fd = EPOLL_FD.get().unwrap(); + epoll::modify( + epoll_fd, + self.inner.as_fd(), + EventData::new_u64(self.id as u64), + EventFlags::ONESHOT | EventFlags::IN, + )?; + Ok(()) + } +} diff --git a/src/platform/linux_usbfs/hotplug.rs b/src/platform/linux_usbfs/hotplug.rs index ba0c6a9..fc10bfd 100644 --- a/src/platform/linux_usbfs/hotplug.rs +++ b/src/platform/linux_usbfs/hotplug.rs @@ -1,15 +1,147 @@ -use std::{io::ErrorKind, task::Poll}; +use log::{debug, error, warn}; +use rustix::{ + fd::{AsFd, OwnedFd}, + net::{ + bind, + netlink::{self, SocketAddrNetlink}, + recvfrom, socket_with, AddressFamily, RecvFlags, SocketAddrAny, SocketFlags, SocketType, + }, +}; +use std::{io::ErrorKind, os::unix::prelude::BorrowedFd, path::Path, task::Poll}; use crate::{hotplug::HotplugEvent, Error}; -pub(crate) struct LinuxHotplugWatch {} +use super::{enumeration::probe_device, events::Async, SysfsPath}; + +const UDEV_MAGIC: &[u8; 12] = b"libudev\0\xfe\xed\xca\xfe"; +const UDEV_MULTICAST_GROUP: u32 = 1 << 1; + +pub(crate) struct LinuxHotplugWatch { + fd: Async, +} impl LinuxHotplugWatch { pub(crate) fn new() -> Result { - Err(Error::new(ErrorKind::Unsupported, "Not implemented.")) + let fd = socket_with( + AddressFamily::NETLINK, + SocketType::RAW, + SocketFlags::CLOEXEC, + Some(netlink::KOBJECT_UEVENT), + )?; + bind(&fd, &SocketAddrNetlink::new(0, UDEV_MULTICAST_GROUP))?; + Ok(LinuxHotplugWatch { + fd: Async::new(fd)?, + }) } pub(crate) fn poll_next(&mut self, cx: &mut std::task::Context<'_>) -> Poll { + if let Some(event) = try_receive_event(self.fd.inner.as_fd()) { + return Poll::Ready(event); + } + + if let Err(e) = self.fd.register(cx.waker()) { + log::error!("failed to register udev socket with epoll: {e}"); + } + Poll::Pending } } + +fn try_receive_event(fd: BorrowedFd) -> Option { + let mut buf = [0; 8192]; + + match recvfrom(fd, &mut buf, RecvFlags::DONTWAIT) { + // udev messages will normally be sent to a multicast group, which only + // root can send to. Reject unicast messages that may be from anywhere. + Ok((size, Some(SocketAddrAny::Netlink(nl)))) if nl.groups() == UDEV_MULTICAST_GROUP => { + parse_packet(&buf[..size]) + } + Ok((_, src)) => { + warn!("udev netlink socket received message from {src:?}"); + None + } + Err(e) if e.kind() == ErrorKind::WouldBlock => None, + Err(e) => { + error!("udev netlink socket recvfrom failed with {e}"); + None + } + } +} + +fn parse_packet(buf: &[u8]) -> Option { + if buf.len() < 24 { + error!("packet too short: {buf:x?}"); + return None; + } + + if !buf.starts_with(UDEV_MAGIC) { + error!("packet does not start with expected header: {buf:x?}"); + return None; + } + + let properties_off = u32::from_ne_bytes(buf[16..20].try_into().unwrap()) as usize; + let properties_len = u32::from_ne_bytes(buf[20..24].try_into().unwrap()) as usize; + let Some(properties_buf) = buf.get(properties_off..properties_off + properties_len) else { + error!("properties offset={properties_off} length={properties_len} exceeds buffer length {len}", len = buf.len()); + return None; + }; + + let mut is_add = None; + let mut busnum = None; + let mut devnum = None; + let mut devpath = None; + + for (k, v) in parse_properties(properties_buf) { + debug!("uevent property {k} = {v}"); + match k { + "SUBSYSTEM" if v != "usb" => return None, + "DEVTYPE" if v != "usb_device" => return None, + "ACTION" => { + is_add = Some(match v { + "add" => true, + "remove" => false, + _ => return None, + }); + } + "BUSNUM" => { + busnum = v.parse::().ok(); + } + "DEVNUM" => { + devnum = v.parse::().ok(); + } + "DEVPATH" => { + devpath = Some(v); + } + _ => {} + } + } + + let is_add = is_add?; + let busnum = busnum?; + let devnum = devnum?; + let devpath = devpath?; + + if is_add { + let path = Path::new("/sys/").join(devpath.trim_start_matches('/')); + match probe_device(SysfsPath(path.clone())) { + Ok(d) => Some(HotplugEvent::Connected(d)), + Err(e) => { + error!("Failed to probe device {path:?}: {e}"); + None + } + } + } else { + Some(HotplugEvent::Disconnected(crate::DeviceId( + super::DeviceId { + bus: busnum, + addr: devnum, + }, + ))) + } +} + +/// Split nul-separated key=value pairs +fn parse_properties(buf: &[u8]) -> impl Iterator + '_ { + buf.split(|b| b == &0) + .filter_map(|entry| std::str::from_utf8(entry).ok()?.split_once('=')) +}