Skip to content

Commit

Permalink
Remove async-trait
Browse files Browse the repository at this point in the history
  • Loading branch information
jacob-pro committed Jan 21, 2024
1 parent eda8261 commit 8fe3403
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 92 deletions.
4 changes: 1 addition & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "actix-extensible-rate-limit"
version = "0.2.3"
version = "0.3.0"
edition = "2021"
license = "MIT OR Apache-2.0"
description = "Rate limiting middleware for actix-web"
Expand All @@ -9,11 +9,9 @@ homepage = "https://github.com/jacob-pro/actix-extensible-rate-limit"

[dependencies]
actix-web = { version = "4", default-features = false, features = ["macros"] }
async-trait = "0.1.68"
dashmap = { version = "5.4.0", optional = true }
futures = "0.3.28"
log = "0.4.19"
once_cell = "1.17.1"
redis = { version = "0.23.0", default-features = false, features = ["tokio-comp", "aio", "connection-manager"], optional = true }
thiserror = "1.0.40"

Expand Down
49 changes: 23 additions & 26 deletions src/backend/memory.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::backend::{Backend, SimpleBackend, SimpleInput, SimpleOutput};
use crate::backend::{Backend, Decision, SimpleBackend, SimpleInput, SimpleOutput};
use actix_web::rt::task::JoinHandle;
use actix_web::rt::time::Instant;
use async_trait::async_trait;
use dashmap::DashMap;
use std::convert::Infallible;
use std::sync::Arc;
Expand Down Expand Up @@ -68,7 +67,6 @@ impl Builder {
}
}

#[async_trait(?Send)]
impl Backend<SimpleInput> for InMemoryBackend {
type Output = SimpleOutput;
type RollbackToken = String;
Expand All @@ -77,7 +75,7 @@ impl Backend<SimpleInput> for InMemoryBackend {
async fn request(
&self,
input: SimpleInput,
) -> Result<(bool, Self::Output, Self::RollbackToken), Self::Error> {
) -> Result<(Decision, Self::Output, Self::RollbackToken), Self::Error> {
let now = Instant::now();
let mut count = 1;
let mut expiry = now
Expand Down Expand Up @@ -108,7 +106,7 @@ impl Backend<SimpleInput> for InMemoryBackend {
remaining: input.max_requests.saturating_sub(count),
reset: expiry,
};
Ok((allow, output, input.key))
Ok((Decision::from_allowed(allow), output, input.key))
}

async fn rollback(&self, token: Self::RollbackToken) -> Result<(), Self::Error> {
Expand All @@ -119,7 +117,6 @@ impl Backend<SimpleInput> for InMemoryBackend {
}
}

#[async_trait(?Send)]
impl SimpleBackend for InMemoryBackend {
async fn remove_key(&self, key: &str) -> Result<(), Self::Error> {
self.map.remove(key);
Expand Down Expand Up @@ -153,11 +150,11 @@ mod tests {
for _ in 0..5 {
// First 5 should be allowed
let (allow, _, _) = backend.request(input.clone()).await.unwrap();
assert!(allow);
assert!(allow.is_allowed());
}
// Sixth should be denied
let (allow, _, _) = backend.request(input.clone()).await.unwrap();
assert!(!allow);
assert!(!allow.is_allowed());
}

#[actix_web::test]
Expand All @@ -170,17 +167,17 @@ mod tests {
key: "KEY1".to_string(),
};
// Make first request, should be allowed
let (allow, _, _) = backend.request(input.clone()).await.unwrap();
assert!(allow);
let (decision, _, _) = backend.request(input.clone()).await.unwrap();
assert!(decision.is_allowed());
// Request again, should be denied
let (allow, _, _) = backend.request(input.clone()).await.unwrap();
assert!(!allow);
let (decision, _, _) = backend.request(input.clone()).await.unwrap();
assert!(decision.is_denied());
// Advance time and try again, should now be allowed
tokio::time::advance(MINUTE).await;
// We want to be sure the key hasn't been garbage collected, and we are testing the expiry logic
assert!(backend.map.contains_key("KEY1"));
let (allow, _, _) = backend.request(input).await.unwrap();
assert!(allow);
let (decision, _, _) = backend.request(input).await.unwrap();
assert!(decision.is_allowed());
}

#[actix_web::test]
Expand Down Expand Up @@ -224,20 +221,20 @@ mod tests {
key: "KEY1".to_string(),
};
// First of 2 should be allowed.
let (allow, output, _) = backend.request(input.clone()).await.unwrap();
assert!(allow);
let (decision, output, _) = backend.request(input.clone()).await.unwrap();
assert!(decision.is_allowed());
assert_eq!(output.remaining, 1);
assert_eq!(output.limit, 2);
assert_eq!(output.reset, Instant::now() + MINUTE);
// Second of 2 should be allowed.
let (allow, output, _) = backend.request(input.clone()).await.unwrap();
assert!(allow);
let (decision, output, _) = backend.request(input.clone()).await.unwrap();
assert!(decision.is_allowed());
assert_eq!(output.remaining, 0);
assert_eq!(output.limit, 2);
assert_eq!(output.reset, Instant::now() + MINUTE);
// Should be denied
let (allow, output, _) = backend.request(input).await.unwrap();
assert!(!allow);
let (decision, output, _) = backend.request(input).await.unwrap();
assert!(decision.is_denied());
assert_eq!(output.remaining, 0);
assert_eq!(output.limit, 2);
assert_eq!(output.reset, Instant::now() + MINUTE);
Expand Down Expand Up @@ -269,13 +266,13 @@ mod tests {
max_requests: 1,
key: "KEY1".to_string(),
};
let (allow, _, _) = backend.request(input.clone()).await.unwrap();
assert!(allow);
let (allow, _, _) = backend.request(input.clone()).await.unwrap();
assert!(!allow);
let (decision, _, _) = backend.request(input.clone()).await.unwrap();
assert!(decision.is_allowed());
let (decision, _, _) = backend.request(input.clone()).await.unwrap();
assert!(decision.is_denied());
backend.remove_key("KEY1").await.unwrap();
// Counter should have been reset
let (allow, _, _) = backend.request(input).await.unwrap();
assert!(allow);
let (decision, _, _) = backend.request(input).await.unwrap();
assert!(decision.is_allowed());
}
}
41 changes: 30 additions & 11 deletions src/backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,41 @@ pub mod memory;
pub mod redis;

pub use input_builder::{SimpleInputFunctionBuilder, SimpleInputFuture};
use std::future::Future;

use crate::HeaderCompatibleOutput;
use actix_web::rt::time::Instant;
use async_trait::async_trait;
use std::time::Duration;

#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Decision {
Allowed,
Denied,
}

impl Decision {
pub fn from_allowed(allowed: bool) -> Self {
if allowed {
Self::Allowed
} else {
Self::Denied
}
}

pub fn is_allowed(self) -> bool {
self == Self::Allowed
}

pub fn is_denied(self) -> bool {
matches!(self, Self::Denied)
}
}

/// Describes an implementation of a rate limiting store and algorithm.
///
/// To implement your own rate limiting backend it is recommended to use
/// [async_trait](https://github.com/dtolnay/async-trait), and add the `#[async_trait(?Send)]`
/// attribute onto your trait implementation.
///
/// A Backend is required to implement [Clone], usually this means wrapping your data store within
/// an [Arc](std::sync::Arc), although many connection pools already do so internally; there is no
/// need to wrap it twice.
#[async_trait(?Send)]
pub trait Backend<I: 'static = SimpleInput>: Clone {
type Output;
type RollbackToken;
Expand All @@ -38,10 +57,10 @@ pub trait Backend<I: 'static = SimpleInput>: Clone {
/// Returns a boolean of whether to allow or deny the request, arbitrary output that can be used
/// to transform the allowed and denied responses, and a token to allow the rate limit counter
/// to be rolled back in certain conditions.
async fn request(
fn request(
&self,
input: I,
) -> Result<(bool, Self::Output, Self::RollbackToken), Self::Error>;
) -> impl Future<Output = Result<(Decision, Self::Output, Self::RollbackToken), Self::Error>>;

/// Under certain conditions we may not want to rollback the request operation.
///
Expand All @@ -55,7 +74,8 @@ pub trait Backend<I: 'static = SimpleInput>: Clone {
/// # Arguments
///
/// * `token`: The token returned from the initial call to [Backend::request()].
async fn rollback(&self, token: Self::RollbackToken) -> Result<(), Self::Error>;
fn rollback(&self, token: Self::RollbackToken)
-> impl Future<Output = Result<(), Self::Error>>;
}

/// A default [Backend] Input structure.
Expand Down Expand Up @@ -85,12 +105,11 @@ pub struct SimpleOutput {
}

/// Additional functions for a [Backend] that uses [SimpleInput] and [SimpleOutput].
#[async_trait(?Send)]
pub trait SimpleBackend: Backend<SimpleInput, Output = SimpleOutput> {
/// Removes the bucket for a given rate limit key.
///
/// Intended to be used to reset a key before changing the interval.
async fn remove_key(&self, key: &str) -> Result<(), Self::Error>;
fn remove_key(&self, key: &str) -> impl Future<Output = Result<(), Self::Error>>;
}

impl HeaderCompatibleOutput for SimpleOutput {
Expand Down
53 changes: 25 additions & 28 deletions src/backend/redis.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::backend::{Backend, SimpleBackend, SimpleInput, SimpleOutput};
use crate::backend::{Backend, Decision, SimpleBackend, SimpleInput, SimpleOutput};
use actix_web::rt::time::Instant;
use actix_web::{HttpResponse, ResponseError};
use async_trait::async_trait;
use redis::aio::ConnectionManager;
use redis::AsyncCommands;
use std::borrow::Cow;
Expand Down Expand Up @@ -103,7 +102,6 @@ impl Builder {
}
}

#[async_trait(?Send)]
impl Backend<SimpleInput> for RedisBackend {
type Output = SimpleOutput;
type RollbackToken = String;
Expand All @@ -112,7 +110,7 @@ impl Backend<SimpleInput> for RedisBackend {
async fn request(
&self,
input: SimpleInput,
) -> Result<(bool, Self::Output, Self::RollbackToken), Self::Error> {
) -> Result<(Decision, Self::Output, Self::RollbackToken), Self::Error> {
let key = self.make_key(&input.key);
// https://github.com/actix/actix-extras/blob/master/actix-limitation/src/lib.rs#L123
let mut pipe = redis::pipe();
Expand Down Expand Up @@ -141,7 +139,7 @@ impl Backend<SimpleInput> for RedisBackend {
remaining: input.max_requests.saturating_sub(count),
reset: Instant::now() + Duration::from_secs(ttl as u64),
};
Ok((allow, output, input.key))
Ok((Decision::from_allowed(allow), output, input.key))
}

async fn rollback(&self, token: Self::RollbackToken) -> Result<(), Self::Error> {
Expand All @@ -168,7 +166,6 @@ impl Backend<SimpleInput> for RedisBackend {
}
}

#[async_trait(?Send)]
impl SimpleBackend for RedisBackend {
/// Note that the key prefix (if set) is automatically included, you do not need to prepend
/// it yourself.
Expand Down Expand Up @@ -208,12 +205,12 @@ mod tests {
};
for _ in 0..5 {
// First 5 should be allowed
let (allow, _, _) = backend.request(input.clone()).await.unwrap();
assert!(allow);
let (decision, _, _) = backend.request(input.clone()).await.unwrap();
assert!(decision.is_allowed());
}
// Sixth should be denied
let (allow, _, _) = backend.request(input.clone()).await.unwrap();
assert!(!allow);
let (decision, _, _) = backend.request(input.clone()).await.unwrap();
assert!(decision.is_denied());
}

#[actix_web::test]
Expand All @@ -225,15 +222,15 @@ mod tests {
key: "test_reset".to_string(),
};
// Make first request, should be allowed
let (allow, _, _) = backend.request(input.clone()).await.unwrap();
assert!(allow);
let (decision, _, _) = backend.request(input.clone()).await.unwrap();
assert!(decision.is_allowed());
// Request again, should be denied
let (allow, out, _) = backend.request(input.clone()).await.unwrap();
assert!(!allow);
let (decision, out, _) = backend.request(input.clone()).await.unwrap();
assert!(decision.is_denied());
// Sleep until reset, should now be allowed
tokio::time::sleep(Duration::from_secs(out.seconds_until_reset())).await;
let (allow, _, _) = backend.request(input).await.unwrap();
assert!(allow);
let (decision, _, _) = backend.request(input).await.unwrap();
assert!(decision.is_allowed());
}

#[actix_web::test]
Expand All @@ -245,20 +242,20 @@ mod tests {
key: "test_output".to_string(),
};
// First of 2 should be allowed.
let (allow, output, _) = backend.request(input.clone()).await.unwrap();
assert!(allow);
let (decision, output, _) = backend.request(input.clone()).await.unwrap();
assert!(decision.is_allowed());
assert_eq!(output.remaining, 1);
assert_eq!(output.limit, 2);
assert!(output.seconds_until_reset() > 0 && output.seconds_until_reset() <= 60);
// Second of 2 should be allowed.
let (allow, output, _) = backend.request(input.clone()).await.unwrap();
assert!(allow);
let (decision, output, _) = backend.request(input.clone()).await.unwrap();
assert!(decision.is_allowed());
assert_eq!(output.remaining, 0);
assert_eq!(output.limit, 2);
assert!(output.seconds_until_reset() > 0 && output.seconds_until_reset() <= 60);
// Should be denied
let (allow, output, _) = backend.request(input).await.unwrap();
assert!(!allow);
let (decision, output, _) = backend.request(input).await.unwrap();
assert!(decision.is_denied());
assert_eq!(output.remaining, 0);
assert_eq!(output.limit, 2);
assert!(output.seconds_until_reset() > 0 && output.seconds_until_reset() <= 60);
Expand Down Expand Up @@ -306,14 +303,14 @@ mod tests {
max_requests: 1,
key: "test_remove_key".to_string(),
};
let (allow, _, _) = backend.request(input.clone()).await.unwrap();
assert!(allow);
let (allow, _, _) = backend.request(input.clone()).await.unwrap();
assert!(!allow);
let (decision, _, _) = backend.request(input.clone()).await.unwrap();
assert!(decision.is_allowed());
let (decision, _, _) = backend.request(input.clone()).await.unwrap();
assert!(decision.is_denied());
backend.remove_key("test_remove_key").await.unwrap();
// Counter should have been reset
let (allow, _, _) = backend.request(input).await.unwrap();
assert!(allow);
let (decision, _, _) = backend.request(input).await.unwrap();
assert!(decision.is_allowed());
}

#[actix_web::test]
Expand Down
Loading

0 comments on commit 8fe3403

Please sign in to comment.