From 8d2e5908860d97858fb356128dec8ca71baf3277 Mon Sep 17 00:00:00 2001 From: Thomas Barrett Date: Mon, 4 Sep 2023 14:10:45 -0700 Subject: [PATCH] rate-limiter: Allow RateLimiter to be shared between threads Signed-off-by: Thomas Barrett --- rate_limiter/src/lib.rs | 195 ++++++++++++++++++++++------------------ 1 file changed, 106 insertions(+), 89 deletions(-) diff --git a/rate_limiter/src/lib.rs b/rate_limiter/src/lib.rs index 977704668..922c402af 100644 --- a/rate_limiter/src/lib.rs +++ b/rate_limiter/src/lib.rs @@ -46,9 +46,11 @@ #[macro_use] extern crate log; +use std::io; use std::os::unix::io::{AsRawFd, RawFd}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Mutex; use std::time::{Duration, Instant}; -use std::{fmt, io}; use vmm_sys_util::timerfd::TimerFd; #[derive(Debug)] @@ -271,27 +273,27 @@ pub enum BucketUpdate { /// implementation. These events are meant to be consumed by the user of this struct. /// On each such event, the user must call the `event_handler()` method. pub struct RateLimiter { + inner: Mutex, + + // Internal flag that quickly determines timer state. + timer_active: AtomicBool, +} + +struct RateLimiterInner { bandwidth: Option, ops: Option, timer_fd: TimerFd, - // Internal flag that quickly determines timer state. - timer_active: bool, } -impl PartialEq for RateLimiter { - fn eq(&self, other: &RateLimiter) -> bool { - self.bandwidth == other.bandwidth && self.ops == other.ops - } -} - -impl fmt::Debug for RateLimiter { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "RateLimiter {{ bandwidth: {:?}, ops: {:?} }}", - self.bandwidth, self.ops - ) +impl RateLimiterInner { + // Arm the timer of the rate limiter with the provided `Duration` (which will fire only once). + fn activate_timer(&mut self, dur: Duration, flag: &AtomicBool) { + // Panic when failing to arm the timer (same handling in crate TimerFd::set_state()) + self.timer_fd + .reset(dur, None) + .expect("Can't arm the timer (unexpected 'timerfd_settime' failure)."); + flag.store(true, Ordering::Relaxed) } } @@ -355,35 +357,28 @@ impl RateLimiter { } Ok(RateLimiter { - bandwidth: bytes_token_bucket, - ops: ops_token_bucket, - timer_fd, - timer_active: false, + inner: Mutex::new(RateLimiterInner { + bandwidth: bytes_token_bucket, + ops: ops_token_bucket, + timer_fd, + }), + timer_active: AtomicBool::new(false), }) } - // Arm the timer of the rate limiter with the provided `Duration` (which will fire only once). - fn activate_timer(&mut self, dur: Duration) { - // Panic when failing to arm the timer (same handling in crate TimerFd::set_state()) - self.timer_fd - .reset(dur, None) - .expect("Can't arm the timer (unexpected 'timerfd_settime' failure)."); - self.timer_active = true; - } - /// Attempts to consume tokens and returns whether that is possible. /// /// If rate limiting is disabled on provided `token_type`, this function will always succeed. - pub fn consume(&mut self, tokens: u64, token_type: TokenType) -> bool { + pub fn consume(&self, tokens: u64, token_type: TokenType) -> bool { // If the timer is active, we can't consume tokens from any bucket and the function fails. - if self.timer_active { + if self.is_blocked() { return false; } - + let mut guard = self.inner.lock().unwrap(); // Identify the required token bucket. let token_bucket = match token_type { - TokenType::Bytes => self.bandwidth.as_mut(), - TokenType::Ops => self.ops.as_mut(), + TokenType::Bytes => guard.bandwidth.as_mut(), + TokenType::Ops => guard.ops.as_mut(), }; // Try to consume from the token bucket. if let Some(bucket) = token_bucket { @@ -393,8 +388,8 @@ impl RateLimiter { // register a timer to replenish the bucket and resume processing; // make sure there is only one running timer for this limiter. BucketReduction::Failure => { - if !self.timer_active { - self.activate_timer(TIMER_REFILL_DUR); + if !self.is_blocked() { + guard.activate_timer(TIMER_REFILL_DUR, &self.timer_active); } false } @@ -409,7 +404,10 @@ impl RateLimiter { // order to enforce the bandwidth limit we need to prevent // further calls to the rate limiter for // `ratio * refill_time` milliseconds. - self.activate_timer(Duration::from_millis((ratio * refill_time as f64) as u64)); + guard.activate_timer( + Duration::from_millis((ratio * refill_time as f64) as u64), + &self.timer_active, + ); true } } @@ -424,11 +422,12 @@ impl RateLimiter { /// /// Can be used to *manually* add tokens to a bucket. Useful for reverting a /// `consume()` if needed. - pub fn manual_replenish(&mut self, tokens: u64, token_type: TokenType) { + pub fn manual_replenish(&self, tokens: u64, token_type: TokenType) { + let mut guard = self.inner.lock().unwrap(); // Identify the required token bucket. let token_bucket = match token_type { - TokenType::Bytes => self.bandwidth.as_mut(), - TokenType::Ops => self.ops.as_mut(), + TokenType::Bytes => guard.bandwidth.as_mut(), + TokenType::Ops => guard.ops.as_mut(), }; // Add tokens to the token bucket. if let Some(bucket) = token_bucket { @@ -442,7 +441,7 @@ impl RateLimiter { /// budget for it. /// An event will be generated on the exported FD when the limiter 'unblocks'. pub fn is_blocked(&self) -> bool { - self.timer_active + self.timer_active.load(Ordering::Relaxed) } /// This function needs to be called every time there is an event on the @@ -451,11 +450,12 @@ impl RateLimiter { /// # Errors /// /// If the rate limiter is disabled or is not blocked, an error is returned. - pub fn event_handler(&mut self) -> Result<(), Error> { + pub fn event_handler(&self) -> Result<(), Error> { + let mut guard = self.inner.lock().unwrap(); loop { // Note: As we manually added the `O_NONBLOCK` flag to the FD, the following // `timer_fd::wait()` won't block (which is different from its default behavior.) - match self.timer_fd.wait() { + match guard.timer_fd.wait() { Err(e) => { let err: std::io::Error = e.into(); match err.kind() { @@ -469,7 +469,7 @@ impl RateLimiter { } } _ => { - self.timer_active = false; + self.timer_active.store(false, Ordering::Relaxed); return Ok(()); } } @@ -479,27 +479,18 @@ impl RateLimiter { /// Updates the parameters of the token buckets associated with this RateLimiter. // TODO: Please note that, right now, the buckets become full after being updated. pub fn update_buckets(&mut self, bytes: BucketUpdate, ops: BucketUpdate) { + let mut guard = self.inner.lock().unwrap(); match bytes { - BucketUpdate::Disabled => self.bandwidth = None, - BucketUpdate::Update(tb) => self.bandwidth = Some(tb), + BucketUpdate::Disabled => guard.bandwidth = None, + BucketUpdate::Update(tb) => guard.bandwidth = Some(tb), BucketUpdate::None => (), }; match ops { - BucketUpdate::Disabled => self.ops = None, - BucketUpdate::Update(tb) => self.ops = Some(tb), + BucketUpdate::Disabled => guard.ops = None, + BucketUpdate::Update(tb) => guard.ops = Some(tb), BucketUpdate::None => (), }; } - - /// Returns an immutable view of the inner bandwidth token bucket. - pub fn bandwidth(&self) -> Option<&TokenBucket> { - self.bandwidth.as_ref() - } - - /// Returns an immutable view of the inner ops token bucket. - pub fn ops(&self) -> Option<&TokenBucket> { - self.ops.as_ref() - } } impl AsRawFd for RateLimiter { @@ -510,7 +501,8 @@ impl AsRawFd for RateLimiter { /// Will return a negative value if rate limiting is disabled on both /// token types. fn as_raw_fd(&self) -> RawFd { - self.timer_fd.as_raw_fd() + let guard = self.inner.lock().unwrap(); + guard.timer_fd.as_raw_fd() } } @@ -525,6 +517,7 @@ impl Default for RateLimiter { #[cfg(test)] pub(crate) mod tests { use super::*; + use std::fmt; use std::thread; use std::time::Duration; @@ -557,11 +550,33 @@ pub(crate) mod tests { } impl RateLimiter { - fn get_token_bucket(&self, token_type: TokenType) -> Option<&TokenBucket> { - match token_type { - TokenType::Bytes => self.bandwidth.as_ref(), - TokenType::Ops => self.ops.as_ref(), - } + pub fn bandwidth(&self) -> Option { + let guard = self.inner.lock().unwrap(); + guard.bandwidth.clone() + } + + pub fn ops(&self) -> Option { + let guard = self.inner.lock().unwrap(); + guard.ops.clone() + } + } + + impl PartialEq for RateLimiter { + fn eq(&self, other: &RateLimiter) -> bool { + let self_guard = self.inner.lock().unwrap(); + let other_guard = other.inner.lock().unwrap(); + self_guard.bandwidth == other_guard.bandwidth && self_guard.ops == other_guard.ops + } + } + + impl fmt::Debug for RateLimiter { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let guard = self.inner.lock().unwrap(); + write!( + f, + "RateLimiter {{ bandwidth: {:?}, ops: {:?} }}", + guard.bandwidth, guard.ops + ) } } @@ -640,7 +655,7 @@ pub(crate) mod tests { #[test] fn test_rate_limiter_default() { - let mut l = RateLimiter::default(); + let l = RateLimiter::default(); // limiter should not be blocked assert!(!l.is_blocked()); @@ -659,14 +674,13 @@ pub(crate) mod tests { #[test] fn test_rate_limiter_new() { let l = RateLimiter::new(1000, 1001, 1002, 1003, 1004, 1005).unwrap(); - - let bw = l.bandwidth.unwrap(); + let bw = l.bandwidth().unwrap(); assert_eq!(bw.capacity(), 1000); assert_eq!(bw.one_time_burst(), 1001); assert_eq!(bw.refill_time_ms(), 1002); assert_eq!(bw.budget(), 1000); - let ops = l.ops.unwrap(); + let ops = l.ops().unwrap(); assert_eq!(ops.capacity(), 1003); assert_eq!(ops.one_time_burst(), 1004); assert_eq!(ops.refill_time_ms(), 1005); @@ -676,20 +690,20 @@ pub(crate) mod tests { #[test] fn test_rate_limiter_manual_replenish() { // rate limiter with limit of 1000 bytes/s and 1000 ops/s - let mut l = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap(); + let l = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap(); // consume 123 bytes assert!(l.consume(123, TokenType::Bytes)); l.manual_replenish(23, TokenType::Bytes); { - let bytes_tb = l.get_token_bucket(TokenType::Bytes).unwrap(); + let bytes_tb = l.bandwidth().unwrap(); assert_eq!(bytes_tb.budget(), 900); } // consume 123 ops assert!(l.consume(123, TokenType::Ops)); l.manual_replenish(23, TokenType::Ops); { - let bytes_tb = l.get_token_bucket(TokenType::Ops).unwrap(); + let bytes_tb = l.ops().unwrap(); assert_eq!(bytes_tb.budget(), 900); } } @@ -697,7 +711,7 @@ pub(crate) mod tests { #[test] fn test_rate_limiter_bandwidth() { // rate limiter with limit of 1000 bytes/s - let mut l = RateLimiter::new(1000, 0, 1000, 0, 0, 0).unwrap(); + let l = RateLimiter::new(1000, 0, 1000, 0, 0, 0).unwrap(); // limiter should not be blocked assert!(!l.is_blocked()); @@ -730,7 +744,7 @@ pub(crate) mod tests { #[test] fn test_rate_limiter_ops() { // rate limiter with limit of 1000 ops/s - let mut l = RateLimiter::new(0, 0, 0, 1000, 0, 1000).unwrap(); + let l = RateLimiter::new(0, 0, 0, 1000, 0, 1000).unwrap(); // limiter should not be blocked assert!(!l.is_blocked()); @@ -763,7 +777,7 @@ pub(crate) mod tests { #[test] fn test_rate_limiter_full() { // rate limiter with limit of 1000 bytes/s and 1000 ops/s - let mut l = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap(); + let l = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap(); // limiter should not be blocked assert!(!l.is_blocked()); @@ -799,7 +813,7 @@ pub(crate) mod tests { #[test] fn test_rate_limiter_overconsumption() { // initialize the rate limiter - let mut l = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap(); + let l = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap(); // try to consume 2.5x the bucket size // we are "borrowing" 1.5x the bucket size in tokens since // the bucket is full @@ -818,7 +832,7 @@ pub(crate) mod tests { assert!(!l.is_blocked()); // reset the rate limiter - let mut l = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap(); + let l = RateLimiter::new(1000, 0, 1000, 1000, 0, 1000).unwrap(); // try to consume 1.5x the bucket size // we are "borrowing" 1.5x the bucket size in tokens since // the bucket is full, should arm the timer to 0.5x replenish @@ -857,12 +871,12 @@ pub(crate) mod tests { fn test_update_buckets() { let mut x = RateLimiter::new(1000, 2000, 1000, 10, 20, 1000).unwrap(); - let initial_bw = x.bandwidth.clone(); - let initial_ops = x.ops.clone(); + let initial_bw = x.bandwidth(); + let initial_ops = x.ops(); x.update_buckets(BucketUpdate::None, BucketUpdate::None); - assert_eq!(x.bandwidth, initial_bw); - assert_eq!(x.ops, initial_ops); + assert_eq!(x.bandwidth(), initial_bw); + assert_eq!(x.ops(), initial_ops); let new_bw = TokenBucket::new(123, 0, 57).unwrap(); let new_ops = TokenBucket::new(321, 12346, 89).unwrap(); @@ -871,18 +885,21 @@ pub(crate) mod tests { BucketUpdate::Update(new_ops.clone()), ); - // We have manually adjust the last_update field, because it changes when update_buckets() - // constructs new buckets (and thus gets a different value for last_update). We do this so - // it makes sense to test the following assertions. - x.bandwidth.as_mut().unwrap().last_update = new_bw.last_update; - x.ops.as_mut().unwrap().last_update = new_ops.last_update; + { + let mut guard = x.inner.lock().unwrap(); + // We have manually adjust the last_update field, because it changes when update_buckets() + // constructs new buckets (and thus gets a different value for last_update). We do this so + // it makes sense to test the following assertions. + guard.bandwidth.as_mut().unwrap().last_update = new_bw.last_update; + guard.ops.as_mut().unwrap().last_update = new_ops.last_update; + } - assert_eq!(x.bandwidth, Some(new_bw)); - assert_eq!(x.ops, Some(new_ops)); + assert_eq!(x.bandwidth(), Some(new_bw)); + assert_eq!(x.ops(), Some(new_ops)); x.update_buckets(BucketUpdate::Disabled, BucketUpdate::Disabled); - assert_eq!(x.bandwidth, None); - assert_eq!(x.ops, None); + assert_eq!(x.bandwidth(), None); + assert_eq!(x.ops(), None); } #[test]