From 0e145f89a698a6743be3459b72ca87b91f96a971 Mon Sep 17 00:00:00 2001 From: Kevin Mehall Date: Sat, 1 Feb 2025 19:42:37 -0700 Subject: [PATCH] Switch to rustix for netlink bind/recvfrom --- Cargo.toml | 1 - src/platform/linux_usbfs/hotplug.rs | 78 +++++++++-------------------- 2 files changed, 24 insertions(+), 55 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1d57c47..b0afabd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,6 @@ futures-lite = "1.13.0" [target.'cfg(any(target_os="linux", target_os="android"))'.dependencies] rustix = { version = "1.0.1", features = ["fs", "event", "net"] } -libc = "0.2.155" [target.'cfg(target_os="windows")'.dependencies] windows-sys = { version = "0.59.0", features = ["Win32_Devices_Usb", "Win32_Devices_DeviceAndDriverInstallation", "Win32_Foundation", "Win32_Devices_Properties", "Win32_Storage_FileSystem", "Win32_Security", "Win32_System_IO", "Win32_System_Registry", "Win32_System_Com"] } diff --git a/src/platform/linux_usbfs/hotplug.rs b/src/platform/linux_usbfs/hotplug.rs index a2d4814..28ddc04 100644 --- a/src/platform/linux_usbfs/hotplug.rs +++ b/src/platform/linux_usbfs/hotplug.rs @@ -1,16 +1,14 @@ -use libc::{sockaddr, sockaddr_nl, socklen_t, AF_NETLINK, MSG_DONTWAIT}; use log::{error, trace, warn}; use rustix::{ - fd::{AsFd, AsRawFd, OwnedFd}, - net::{netlink, socket_with, AddressFamily, SocketFlags, SocketType}, -}; -use std::{ - io::ErrorKind, - mem, - os::{raw::c_void, unix::prelude::BorrowedFd}, - path::Path, - task::Poll, + fd::{AsFd, OwnedFd}, + io::Errno, + net::{ + bind, + netlink::{self, SocketAddrNetlink}, + recvfrom, socket_with, AddressFamily, RecvFlags, SocketFlags, SocketType, + }, }; +use std::{mem::MaybeUninit, os::unix::prelude::BorrowedFd, path::Path, task::Poll}; use crate::{hotplug::HotplugEvent, Error}; @@ -31,23 +29,7 @@ impl LinuxHotplugWatch { SocketFlags::CLOEXEC, Some(netlink::KOBJECT_UEVENT), )?; - - unsafe { - // rustix doesn't support netlink yet (pending https://github.com/bytecodealliance/rustix/pull/1004) - // so use libc for now. - let mut addr: sockaddr_nl = mem::zeroed(); - addr.nl_family = AF_NETLINK as u16; - addr.nl_groups = UDEV_MULTICAST_GROUP; - let r = libc::bind( - fd.as_raw_fd(), - &addr as *const sockaddr_nl as *const sockaddr, - mem::size_of_val(&addr) as socklen_t, - ); - if r != 0 { - return Err(Error::last_os_error()); - } - } - + bind(&fd, &SocketAddrNetlink::new(0, UDEV_MULTICAST_GROUP))?; Ok(LinuxHotplugWatch { fd: Async::new(fd)?, }) @@ -67,40 +49,28 @@ impl LinuxHotplugWatch { } fn try_receive_event(fd: BorrowedFd) -> Option { - let mut buf = [0; 8192]; + let mut buf = [MaybeUninit::uninit(); 8192]; - let received = unsafe { - let mut addr: sockaddr_nl = mem::zeroed(); - let mut addrlen: socklen_t = mem::size_of_val(&addr) as socklen_t; - let r = libc::recvfrom( - fd.as_raw_fd(), - buf.as_mut_ptr() as *mut c_void, - buf.len(), - MSG_DONTWAIT, - &mut addr as *mut sockaddr_nl as *mut sockaddr, - &mut addrlen, - ); - if r >= 0 { - Ok((r as usize, addr.nl_groups)) - } else { - Err(Error::last_os_error()) + let (data, src) = match recvfrom(fd, &mut buf, RecvFlags::DONTWAIT) { + Ok(((buf, _), _, src)) => (buf, src), + Err(Errno::AGAIN | Errno::INTR) => return None, + Err(e) => { + error!("udev netlink socket recvfrom failed with {e}"); + return None; } }; - match received { - // udev messages will normally be sent to a multicast group, which only - // root can send to. Reject unicast messages that may be from anywhere. - Ok((size, groups)) if groups == UDEV_MULTICAST_GROUP => parse_packet(&buf[..size]), - Ok((_, src)) => { + // udev messages will normally be sent to a multicast group, which only + // root can send to. Reject unicast messages that may be from anywhere. + match src.map(SocketAddrNetlink::try_from).transpose() { + Ok(Some(nl)) if nl.groups() == UDEV_MULTICAST_GROUP => {} + src => { warn!("udev netlink socket received message from {src:?}"); - None - } - Err(e) if e.kind() == ErrorKind::WouldBlock => None, - Err(e) => { - error!("udev netlink socket recvfrom failed with {e}"); - None + return None; } } + + parse_packet(data) } fn parse_packet(buf: &[u8]) -> Option {