From 8f64b48a3c345520397761b717abbaca87ab4819 Mon Sep 17 00:00:00 2001 From: Jacob Halsey Date: Sun, 21 Jan 2024 18:01:58 +0000 Subject: [PATCH] Redis backend now uses BITFIELD to store counts` --- CHANGES.md | 5 ++ Cargo.toml | 2 +- README.md | 22 ++++++++ src/backend/mod.rs | 2 +- src/backend/redis.rs | 124 +++++++++++++++++++++++-------------------- 5 files changed, 94 insertions(+), 61 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 5f399f9..4f04790 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,10 @@ # Changes +## 0.3.0 2024-01-21 + +- Breaking: Removes async-trait dependency +- Breaking: Redis backend now uses BITFIELD to store counts + ## 0.2.2 2022-04-19 - Improve documentation. diff --git a/Cargo.toml b/Cargo.toml index b8c677c..f2eade2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ actix-web = { version = "4", default-features = false, features = ["macros"] } dashmap = { version = "5.4.0", optional = true } futures = "0.3.28" log = "0.4.19" -redis = { version = "0.23.0", default-features = false, features = ["tokio-comp", "aio", "connection-manager"], optional = true } +redis = { version = "0.24.0", default-features = false, features = ["tokio-comp", "aio", "connection-manager"], optional = true } thiserror = "1.0.40" [features] diff --git a/README.md b/README.md index 9793ada..70bc4d0 100644 --- a/README.md +++ b/README.md @@ -52,3 +52,25 @@ async fn main() -> std::io::Result<()> { .await } ``` + +Try it out: + +``` +$ curl -v http://127.0.0.1:8080 +* Trying 127.0.0.1:8080... +* Connected to 127.0.0.1 (127.0.0.1) port 8080 (#0) +> GET / HTTP/1.1 +> Host: 127.0.0.1:8080 +> User-Agent: curl/7.83.1 +> Accept: */* +> +* Mark bundle as not supporting multiuse +< HTTP/1.1 404 Not Found +< content-length: 0 +< x-ratelimit-limit: 5 +< x-ratelimit-reset: 60 +< x-ratelimit-remaining: 4 +< date: Sun, 21 Jan 2024 16:52:27 GMT +< +* Connection #0 to host 127.0.0.1 left intact +``` \ No newline at end of file diff --git a/src/backend/mod.rs b/src/backend/mod.rs index fee8546..c191e0c 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -31,7 +31,7 @@ impl Decision { } pub fn is_allowed(self) -> bool { - self == Self::Allowed + matches!(self, Self::Allowed) } pub fn is_denied(self) -> bool { diff --git a/src/backend/redis.rs b/src/backend/redis.rs index e67d325..29e52e1 100644 --- a/src/backend/redis.rs +++ b/src/backend/redis.rs @@ -2,24 +2,13 @@ use crate::backend::{Backend, Decision, SimpleBackend, SimpleInput, SimpleOutput use actix_web::rt::time::Instant; use actix_web::{HttpResponse, ResponseError}; use redis::aio::ConnectionManager; -use redis::AsyncCommands; +use redis::{AsyncCommands, Cmd}; use std::borrow::Cow; use std::time::Duration; use thiserror::Error; -// https://github.com/mitsuhiko/redis-rs/issues/353 -macro_rules! async_transaction { - ($conn:expr, $keys:expr, $body:expr) => { - loop { - redis::cmd("WATCH").arg($keys).query_async($conn).await?; - - if let Some(response) = $body { - redis::cmd("UNWATCH").query_async($conn).await?; - break response; - } - } - }; -} +const BITFIELD_ENCODING: &str = "u63"; +const BITFIELD_OFFSET: u8 = 0; #[derive(Debug, Error)] pub enum Error { @@ -29,7 +18,7 @@ pub enum Error { #[from] redis::RedisError, ), - #[error("Unexpected negative TTL response")] + #[error("Unexpected negative TTL response for the rate limit key")] NegativeTtl, } @@ -58,7 +47,7 @@ impl RedisBackend { /// ```no_run /// # use actix_extensible_rate_limit::backend::redis::RedisBackend; /// # use redis::aio::ConnectionManager; - /// # async { + /// # async fn example() { /// let client = redis::Client::open("redis://127.0.0.1/").unwrap(); /// let manager = ConnectionManager::new(client).await.unwrap(); /// let backend = RedisBackend::builder(manager).build(); @@ -112,26 +101,37 @@ impl Backend for RedisBackend { input: SimpleInput, ) -> Result<(Decision, Self::Output, Self::RollbackToken), Self::Error> { let key = self.make_key(&input.key); - // https://github.com/actix/actix-extras/blob/master/actix-limitation/src/lib.rs#L123 + let mut pipe = redis::pipe(); pipe.atomic() - .cmd("SET") // Set key and value + // Increment the rate limit count + .cmd("BITFIELD") .arg(key.as_ref()) - .arg(0i64) - .arg("EX") // Set the specified expire time, in seconds. - .arg(input.interval.as_secs()) - .arg("NX") // Only set the key if it does not already exist. - .ignore() // --- ignore returned value of SET command --- - .cmd("INCR") // Increment key + .arg("OVERFLOW") + .arg("SAT") + .arg("INCRBY") + .arg(BITFIELD_ENCODING) + .arg(BITFIELD_OFFSET) + .arg(1) + .arg("GET") + .arg(BITFIELD_ENCODING) + .arg(BITFIELD_OFFSET) + // Set the key to expire (only if it doesn't already have an expiry) + .cmd("EXPIRE") .arg(key.as_ref()) - .cmd("TTL") // Return time-to-live of key + .arg(input.interval.as_secs()) + .arg("NX") + .ignore() + // Return time-to-live of key + .cmd("TTL") .arg(key.as_ref()); let mut con = self.connection.clone(); - let (count, ttl): (u64, i64) = pipe.query_async(&mut con).await?; + let (counts, ttl): (Vec, i64) = pipe.query_async(&mut con).await?; if ttl < 0 { - return Err(Self::Error::NegativeTtl); + return Err(Error::NegativeTtl); } + let count = *counts.first().expect("BITFIELD should return one value"); let allow = count <= input.max_requests; let output = SimpleOutput { @@ -144,24 +144,20 @@ impl Backend for RedisBackend { async fn rollback(&self, token: Self::RollbackToken) -> Result<(), Self::Error> { let key = self.make_key(&token); + let mut con = self.connection.clone(); - async_transaction!(&mut con, &[key.as_ref()], { - let old_val: Option = con.get(key.as_ref()).await?; - if let Some(old_val) = old_val { - if old_val >= 1 { - redis::pipe() - .atomic() - .decr::<_, u64>(key.as_ref(), 1) - .ignore() - .query_async::<_, Option<()>>(&mut con) - .await? - } else { - Some(()) - } - } else { - Some(()) - } - }); + let mut cmd = Cmd::new(); + cmd.arg("BITFIELD") + .arg(key.as_ref()) + .arg("OVERFLOW") + .arg("SAT") + .arg("INCRBY") + .arg(BITFIELD_ENCODING) + .arg(BITFIELD_OFFSET) + .arg(-1); + + cmd.query_async(&mut con).await?; + Ok(()) } } @@ -203,13 +199,17 @@ mod tests { max_requests: 5, key: "test_allow_deny".to_string(), }; - for _ in 0..5 { + for i in (0..5).rev() { // First 5 should be allowed - let (decision, _, _) = backend.request(input.clone()).await.unwrap(); + let (decision, output, _) = backend.request(input.clone()).await.unwrap(); + assert_eq!(output.remaining, i); + assert_eq!(output.limit, 5); assert!(decision.is_allowed()); } // Sixth should be denied - let (decision, _, _) = backend.request(input.clone()).await.unwrap(); + let (decision, output, _) = backend.request(input.clone()).await.unwrap(); + assert_eq!(output.remaining, 0); + assert_eq!(output.limit, 5); assert!(decision.is_denied()); } @@ -224,9 +224,11 @@ mod tests { // Make first request, should be allowed let (decision, _, _) = backend.request(input.clone()).await.unwrap(); assert!(decision.is_allowed()); - // Request again, should be denied + + // Request again immediately afterwards, should now be denied let (decision, out, _) = backend.request(input.clone()).await.unwrap(); assert!(decision.is_denied()); + // Sleep until reset, should now be allowed tokio::time::sleep(Duration::from_secs(out.seconds_until_reset())).await; let (decision, _, _) = backend.request(input).await.unwrap(); @@ -247,12 +249,14 @@ mod tests { assert_eq!(output.remaining, 1); assert_eq!(output.limit, 2); assert!(output.seconds_until_reset() > 0 && output.seconds_until_reset() <= 60); + // Second of 2 should be allowed. let (decision, output, _) = backend.request(input.clone()).await.unwrap(); assert!(decision.is_allowed()); assert_eq!(output.remaining, 0); assert_eq!(output.limit, 2); assert!(output.seconds_until_reset() > 0 && output.seconds_until_reset() <= 60); + // Should be denied let (decision, output, _) = backend.request(input).await.unwrap(); assert!(decision.is_denied()); @@ -281,18 +285,20 @@ mod tests { #[actix_web::test] async fn test_rollback_key_gone() { - let backend = make_backend("test_rollback_key_gone").await.build(); + let key = "test_rollback_key_gone"; + let backend = make_backend(key).await.build(); let mut con = backend.connection.clone(); - // The rollback could happen after the key has already expired - backend - .rollback("test_rollback_key_gone".to_string()) - .await - .unwrap(); - // In which case nothing should happen - assert!(!con - .exists::<_, bool>("test_rollback_key_gone") - .await - .unwrap()); + // The rollback could happen after the key has already expired / gone + backend.rollback(key.to_string()).await.unwrap(); + // In which case the count should remain at 0 + let mut cmd = Cmd::new(); + cmd.arg("BITFIELD") + .arg(key) + .arg("GET") + .arg(BITFIELD_ENCODING) + .arg(BITFIELD_OFFSET); + let value: Vec = cmd.query_async(&mut con).await.unwrap(); + assert_eq!(value[0], 0u64); } #[actix_web::test]