diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fcd7224..09f1dbd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,6 +21,16 @@ jobs: - run: cargo make deny-check - run: cargo make docs test: + services: + aws: + image: public.ecr.aws/localstack/localstack:4 + ports: + - "4566:4566" + env: + AWS_ACCESS_KEY_ID: test + AWS_SECRET_ACCESS_KEY: test + AWS_ENDPOINT_URL: http://localhost:4566 + AWS_REGION: eu-west-1 # Avoid duplicate jobs on PR from a branch on the same repo if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name runs-on: ubuntu-latest diff --git a/.rustfmt.toml b/.rustfmt.toml index 2907f0d..1babf79 100644 --- a/.rustfmt.toml +++ b/.rustfmt.toml @@ -1,3 +1,4 @@ +edition = "2021" reorder_imports = true reorder_modules = true max_width = 120 diff --git a/CHANGELOG.md b/CHANGELOG.md index 065d812..91ec771 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,22 @@ and this project adheres to --- +## [0.21.0] - 2025-02-10 + +# Fixed + +- Fixed docs.rs not building the documentation + +# Deprecated + +- Deprecated a lot of old auth0 APIs. See the docs for alternatives to use. + +# Added + +- DynamoDB cache provider + +--- + ## [0.20.0] - 2024-12-02 ### Added @@ -496,7 +512,9 @@ Request::rest(&bridge).send() The old API is still available but deprecated. It will be removed soon. -[Unreleased]: https://github.com/primait/bridge.rs/compare/0.20.0...HEAD + +[Unreleased]: https://github.com/primait/bridge.rs/compare/0.21.0...HEAD +[0.21.0]: https://github.com/primait/bridge.rs/compare/0.20.0...0.21.0 [0.20.0]: https://github.com/primait/bridge.rs/compare/0.19.0...0.20.0 [0.19.0]: https://github.com/primait/bridge.rs/compare/0.18.0...0.19.0 [0.18.0]: https://github.com/primait/bridge.rs/compare/0.17.0...0.18.0 diff --git a/Cargo.toml b/Cargo.toml index 12f51ce..570d0e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,13 +6,17 @@ license = "MIT" name = "prima_bridge" readme = "README.md" repository = "https://github.com/primait/bridge.rs" -version = "0.20.0" +version = "0.21.0" # See https://github.com/rust-lang/rust/issues/107557 -rust-version = "1.72" +rust-version = "1.81" [features] default = ["tracing_opentelemetry"] +# Feature set that should be used +# This exists to avoid compatibility issues with otel version conflicts +_docs = ["auth0", "cache-dynamodb", "grpc", "gzip", "redis-tls", "tracing_opentelemetry"] + auth0 = [ "rand", "redis", @@ -22,9 +26,12 @@ auth0 = [ "dashmap", "tracing", ] -grpc = ["tonic"] +grpc = [ "_any_otel_version", "tonic"] gzip = ["reqwest/gzip"] -redis-tls = ["redis/tls", "redis/tokio-native-tls-comp"] + +redis-tls = [ "redis", "redis/tls", "redis/tokio-native-tls-comp"] +cache-dynamodb = [ "aws-sdk-dynamodb" ] + tracing_opentelemetry = ["tracing_opentelemetry_0_27"] tracing_opentelemetry_0_21 = [ @@ -102,6 +109,7 @@ tonic = { version = "0.12", default-features = false, optional = true } tracing = { version = "0.1", optional = true } uuid = { version = ">=0.7.0, <2.0.0", features = ["serde", "v4"] } chacha20poly1305 = { version = "0.10.1", features = ["std"], optional = true } +aws-sdk-dynamodb = { version = "1.63", optional = true } reqwest-middleware = { version = "0.4.0", features = ["json", "multipart"] } http = "1.0.0" @@ -129,6 +137,7 @@ tracing-opentelemetry_0_27_pkg = { package = "tracing-opentelemetry", version = tracing-opentelemetry_0_28_pkg = { package = "tracing-opentelemetry", version = "0.28", optional = true } [dev-dependencies] +aws-config = { version = "1.5.16", features = ["behavior-version-latest"] } flate2 = "1.0" mockito = "1.0" tokio = { version = "1.16", features = ["macros", "rt-multi-thread"] } @@ -139,7 +148,11 @@ codegen-units = 1 lto = "thin" [package.metadata.docs.rs] -all-features = true +all-features = false +# Avoid opentelemetry version conflicts +features = [ + "_docs" +] rustdoc-args = ["--cfg", "docsrs"] [[example]] diff --git a/Dockerfile b/Dockerfile index f337834..3b209d6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM rust:1.72 +FROM rust:1.81 WORKDIR /code diff --git a/Makefile.toml b/Makefile.toml index 406290f..cf3ff35 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -42,7 +42,7 @@ dependencies = ["build"] [tasks.test-auth0] command = "cargo" -args = ["test", "--features=auth0,gzip", "${@}"] +args = ["test", "--features=auth0,gzip,cache-dynamodb", "${@}"] dependencies = ["build"] [tasks.test-all-otel-versions] @@ -80,7 +80,7 @@ dependencies = ["build"] command = "cargo" args = [ "clippy", - "--features=auth0,gzip", + "--features=auth0,cache-dynamodb,gzip", "--all-targets", "--", "-D", diff --git a/docker-compose.yml b/docker-compose.yml index 186f912..a4190db 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,6 +8,10 @@ services: CARGO_HOME: /home/app/.cargo CARGO_TARGET_DIR: /home/app/target CARGO_MAKE_DISABLE_UPDATE_CHECK: 1 + AWS_ACCESS_KEY_ID: test + AWS_SECRET_ACCESS_KEY: test + AWS_ENDPOINT_URL: http://aws:4566 + AWS_REGION: eu-west-1 volumes: - ".:/code" - "app:/home/app/" @@ -17,6 +21,12 @@ services: - "~/.gitignore:/home/app/.gitignore" tty: true stdin_open: true + depends_on: + - aws + aws: + image: public.ecr.aws/localstack/localstack:4 + ports: + - "4566:4566" volumes: app: diff --git a/examples/auth0.rs b/examples/auth0.rs index feb1921..9b0d0ca 100644 --- a/examples/auth0.rs +++ b/examples/auth0.rs @@ -35,7 +35,7 @@ const QUERY: &str = "query($input:JobsInput!){jobs(input:$input) {\nid\n title\n #[tokio::main] async fn main() { let bridge: Bridge = Bridge::builder() - .with_auth0(auth0::config()) + .with_refreshing_token(auth0::refreshing_token().await) .await .build(URL.parse().unwrap()); @@ -113,30 +113,30 @@ pub struct Job { mod auth0 { use std::time::Duration; - use prima_bridge::auth0::{CacheType, Config, StalenessCheckPercentage}; - - pub fn config() -> Config { - use reqwest::Url; - use std::str::FromStr; + use prima_bridge::auth0::{cache::InMemoryCache, Auth0Client, RefreshingToken, StalenessCheckPercentage}; + pub async fn refreshing_token() -> RefreshingToken { let token_url: String = std::env::var("TOKEN_URL").unwrap(); - let jwks_url: String = std::env::var("JWKS_URL").unwrap(); let client_id: String = std::env::var("CLIENT_ID").unwrap(); let client_secret: String = std::env::var("CLIENT_SECRET").unwrap(); let audience: String = std::env::var("AUDIENCE").unwrap(); - Config { - token_url: Url::from_str(token_url.as_str()).unwrap(), - jwks_url: Url::from_str(jwks_url.as_str()).unwrap(), - caller: "paperboy".to_string(), - audience, - cache_type: CacheType::Inmemory, - token_encryption_key: "32char_long_token_encryption_key".to_string(), - check_interval: Duration::from_secs(2), - staleness_check_percentage: StalenessCheckPercentage::new(0.1, 0.5), + let auth0_client = Auth0Client::new( + token_url.parse().unwrap(), + reqwest::Client::default(), client_id, client_secret, - scope: Some("profile email".to_string()), - } + ); + + RefreshingToken::new( + auth0_client, + Duration::from_secs(10), + StalenessCheckPercentage::default(), + Box::new(InMemoryCache::default()), + audience, + None, + ) + .await + .unwrap() } } diff --git a/src/auth0/cache/crypto.rs b/src/auth0/cache/crypto.rs index 343c4de..7e51de8 100644 --- a/src/auth0/cache/crypto.rs +++ b/src/auth0/cache/crypto.rs @@ -2,11 +2,17 @@ use chacha20poly1305::{aead::Aead, AeadCore, KeyInit, XChaCha20Poly1305}; use rand::thread_rng; use serde::{Deserialize, Serialize}; -use crate::auth0::errors::Auth0Error; - const NONCE_SIZE: usize = 24; -pub fn encrypt(value_ref: &T, token_encryption_key_str: &str) -> Result, Auth0Error> { +#[derive(thiserror::Error, Debug)] +pub enum CryptoError { + #[error(transparent)] + Serde(#[from] serde_json::Error), + #[error(transparent)] + ChaCha20Poly1305(#[from] chacha20poly1305::Error), +} + +pub fn encrypt(value_ref: &T, token_encryption_key_str: &str) -> Result, CryptoError> { let json: String = serde_json::to_string(value_ref)?; let enc = XChaCha20Poly1305::new_from_slice(token_encryption_key_str.as_bytes()).unwrap(); @@ -18,7 +24,7 @@ pub fn encrypt(value_ref: &T, token_encryption_key_str: &str) -> R Ok(ciphertext) } -pub fn decrypt(token_encryption_key_str: &str, encrypted: &[u8]) -> Result +pub fn decrypt(token_encryption_key_str: &str, encrypted: &[u8]) -> Result where for<'de> T: Deserialize<'de>, { @@ -28,7 +34,7 @@ where let nonce = encrypted.get(encrypted.len() - NONCE_SIZE..); let (Some(ciphertext), Some(nonce)) = (ciphertext, nonce) else { - return Err(Auth0Error::CryptoError(chacha20poly1305::Error)); + return Err(CryptoError::ChaCha20Poly1305(chacha20poly1305::Error)); }; let nonce = chacha20poly1305::XNonce::from_slice(nonce); diff --git a/src/auth0/cache/dynamodb.rs b/src/auth0/cache/dynamodb.rs new file mode 100644 index 0000000..bda83d2 --- /dev/null +++ b/src/auth0/cache/dynamodb.rs @@ -0,0 +1,207 @@ +use std::{error::Error, time::Duration}; + +pub use aws_sdk_dynamodb; +use aws_sdk_dynamodb::{ + client::Waiters, + operation::describe_table::DescribeTableError, + types::{ + AttributeDefinition, AttributeValue, KeySchemaElement, KeyType, ProvisionedThroughput, TimeToLiveSpecification, + }, +}; + +use crate::auth0::Token; + +use super::{Cache, CacheError}; + +#[derive(thiserror::Error, Debug)] +pub enum DynamoDBCacheError { + #[error("AWS error when interacting with dynamo cache: {0}")] + Aws(Box), + #[error("Data in database is wrong. Key: {0}")] + SchemaError(String), +} + +impl From for super::CacheError { + fn from(val: DynamoDBCacheError) -> Self { + CacheError(Box::new(val)) + } +} + +/// A cache using the AWS DynamoDB +#[derive(Debug)] +pub struct DynamoDBCache { + table_name: String, + client: aws_sdk_dynamodb::Client, +} + +impl DynamoDBCache { + /// Construct a DynamoDBCache instance which uses a given table name and client + /// + /// Note: this method doesn't correctly check whether a table with the given name exists during creation. + /// If needed you can call [DynamoDBCache::create_update_dynamo_table]. + /// DynamoDBCache expects client to have full aws permissions on the table_name table. + /// + /// To ensure the table is setup properly most users will want to call the + /// [DynamoDBCache::create_update_dynamo_table] function and let the library + /// do it for you. + /// + /// If you want to create the table yourself keep in mind that while schema changes + /// will be documented in the changelog, we do not consider the schema a part of semver's + /// guarantees and might alter it in patch/minor releases. If you disagree with this policy, + /// we are open to discussing it, open an issue. + /// + /// Currently bridge.rs expects a table with: + /// - one string key attribute, named `key` of type hash + /// - a time to live attribute named `expiration` + pub fn new(client: aws_sdk_dynamodb::Client, table_name: String) -> Self { + Self { client, table_name } + } + + /// Create table or update the schema for a table created by a previous bridge.rs release. + pub async fn create_update_dynamo_table(&self) -> Result<(), DynamoDBCacheError> { + match self + .client + .describe_table() + .table_name(&self.table_name) + .send() + .await + .map_err(|e| e.into_service_error()) + { + Ok(_) => return Ok(()), + Err(DescribeTableError::ResourceNotFoundException(_)) => (), + Err(e) => return Err(DynamoDBCacheError::Aws(Box::new(e))), + }; + + self.client + .create_table() + .table_name(self.table_name.clone()) + .attribute_definitions( + AttributeDefinition::builder() + .attribute_name("key".to_string()) + .attribute_type(aws_sdk_dynamodb::types::ScalarAttributeType::S) + .build() + // Unwraps here are fine, will be caught by tests + .unwrap(), + ) + .key_schema( + KeySchemaElement::builder() + .attribute_name("key") + .key_type(KeyType::Hash) + .build() + .unwrap(), + ) + .provisioned_throughput( + ProvisionedThroughput::builder() + .read_capacity_units(4) + .write_capacity_units(1) + .build() + .unwrap(), + ) + .send() + .await + .map_err(|e| DynamoDBCacheError::Aws(Box::new(e)))?; + + self.client + .wait_until_table_exists() + .table_name(&self.table_name) + .wait(Duration::from_secs(5)) + .await + .map_err(|e| DynamoDBCacheError::Aws(Box::new(e)))?; + + self.client + .update_time_to_live() + .table_name(self.table_name.clone()) + .time_to_live_specification( + TimeToLiveSpecification::builder() + .enabled(true) + .attribute_name("expiration") + .build() + .unwrap(), + ) + .send() + .await + .map_err(|e| DynamoDBCacheError::Aws(Box::new(e)))?; + + Ok(()) + } +} + +#[async_trait::async_trait] +impl Cache for DynamoDBCache { + async fn get_token(&self, client_id: &str, aud: &str) -> Result, CacheError> { + let key = super::token_key(client_id, aud); + let Some(attrs) = self + .client + .get_item() + .table_name(&self.table_name) + .key("key", AttributeValue::S(key.clone())) + .send() + .await + .map_err(|e| DynamoDBCacheError::Aws(Box::new(e)))? + .item + else { + return Ok(None); + }; + + let token = attrs + .get("token") + .and_then(|t| t.as_s().ok()) + .ok_or(DynamoDBCacheError::SchemaError(key.clone()))?; + + let token: Token = serde_json::from_str(token).unwrap(); + + Ok(Some(token)) + } + + async fn put_token(&self, client_id: &str, aud: &str, token: &Token) -> Result<(), CacheError> { + let key = super::token_key(client_id, aud); + let encoded = serde_json::to_string(token).unwrap(); + self.client + .put_item() + .table_name(&self.table_name) + .item("key", AttributeValue::S(key)) + .item("token", AttributeValue::S(encoded)) + .item( + "expiration", + AttributeValue::N(token.expire_date().timestamp().to_string()), + ) + .send() + .await + .map_err(|e| DynamoDBCacheError::Aws(Box::new(e)))?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use chrono::Utc; + + use super::*; + + #[tokio::test] + async fn dynamodb_cache_get_set_values() { + let aws_config = aws_config::from_env().load().await; + let client = aws_sdk_dynamodb::Client::new(&aws_config); + let table = "test_table".to_string(); + + client.delete_table().table_name(table.clone()).send().await.ok(); + + let cache = DynamoDBCache::new(client, table); + cache.create_update_dynamo_table().await.unwrap(); + + let client_id = "caller".to_string(); + let audience = "audience".to_string(); + + let result: Option = cache.get_token(&client_id, &audience).await.unwrap(); + assert!(result.is_none()); + + let token_str: &str = "token"; + let token: Token = Token::new(token_str.to_string(), Utc::now(), Utc::now()); + cache.put_token(&client_id, &audience, &token).await.unwrap(); + + let result: Option = cache.get_token(&client_id, &audience).await.unwrap(); + assert!(result.is_some()); + assert_eq!(result.unwrap().as_str(), token_str); + } +} diff --git a/src/auth0/cache/inmemory.rs b/src/auth0/cache/inmemory.rs index 22250c5..4486f58 100644 --- a/src/auth0/cache/inmemory.rs +++ b/src/auth0/cache/inmemory.rs @@ -1,42 +1,29 @@ use dashmap::DashMap; -use crate::auth0::cache::{self, crypto}; -use crate::auth0::errors::Auth0Error; +use crate::auth0::cache; +use crate::auth0::cache::Cache; use crate::auth0::token::Token; -use crate::auth0::{cache::Cache, Config}; -#[derive(Clone, Debug)] -pub struct InMemoryCache { - key_value: DashMap>, - encryption_key: String, - caller: String, - audience: String, -} +use super::CacheError; -impl InMemoryCache { - pub async fn new(config_ref: &Config) -> Result { - Ok(InMemoryCache { - key_value: DashMap::new(), - encryption_key: config_ref.token_encryption_key().to_string(), - caller: config_ref.caller().to_string(), - audience: config_ref.audience().to_string(), - }) - } +#[derive(Default, Clone, Debug)] +pub struct InMemoryCache { + key_value: DashMap, } #[async_trait::async_trait] impl Cache for InMemoryCache { - async fn get_token(&self) -> Result, Auth0Error> { - self.key_value - .get(&cache::token_key(&self.caller, &self.audience)) - .map(|value| crypto::decrypt(self.encryption_key.as_str(), value.as_slice())) - .transpose() + async fn get_token(&self, client_id: &str, aud: &str) -> Result, CacheError> { + let token = self + .key_value + .get(&cache::token_key(client_id, aud)) + .map(|v| v.to_owned()); + Ok(token) } - async fn put_token(&self, value_ref: &Token) -> Result<(), Auth0Error> { - let key: String = cache::token_key(&self.caller, &self.audience); - let encrypted_value: Vec = crypto::encrypt(value_ref, self.encryption_key.as_str())?; - let _ = self.key_value.insert(key, encrypted_value); + async fn put_token(&self, client_id: &str, aud: &str, token: &Token) -> Result<(), CacheError> { + let key: String = cache::token_key(client_id, aud); + let _ = self.key_value.insert(key, token.clone()); Ok(()) } } @@ -49,17 +36,19 @@ mod tests { #[tokio::test] async fn inmemory_cache_get_set_values() { - let server = mockito::Server::new_async().await; - let cache = InMemoryCache::new(&Config::test_config(&server)).await.unwrap(); + let client_id = "caller".to_string(); + let audience = "audience".to_string(); + + let cache = InMemoryCache::default(); - let result: Option = cache.get_token().await.unwrap(); + let result: Option = cache.get_token(&client_id, &audience).await.unwrap(); assert!(result.is_none()); let token_str: &str = "token"; let token: Token = Token::new(token_str.to_string(), Utc::now(), Utc::now()); - cache.put_token(&token).await.unwrap(); + cache.put_token(&client_id, &audience, &token).await.unwrap(); - let result: Option = cache.get_token().await.unwrap(); + let result: Option = cache.get_token(&client_id, &audience).await.unwrap(); assert!(result.is_some()); assert_eq!(result.unwrap().as_str(), token_str); } diff --git a/src/auth0/cache/mod.rs b/src/auth0/cache/mod.rs index 6160a51..4866067 100644 --- a/src/auth0/cache/mod.rs +++ b/src/auth0/cache/mod.rs @@ -1,22 +1,34 @@ +#[cfg(feature = "cache-dynamodb")] +#[cfg_attr(docsrs, doc(cfg(feature = "cache-dynamodb")))] +pub use dynamodb::DynamoDBCache; + pub use inmemory::InMemoryCache; pub use redis_impl::RedisCache; +use std::error::Error; -use crate::auth0::errors::Auth0Error; use crate::auth0::Token; mod crypto; -mod inmemory; -mod redis_impl; + +#[cfg(feature = "cache-dynamodb")] +#[cfg_attr(docsrs, doc(cfg(feature = "cache-dynamodb")))] +pub mod dynamodb; +pub mod inmemory; +pub mod redis_impl; const TOKEN_PREFIX: &str = "auth0rs_tokens"; // The version of the token for backwards incompatible changes const TOKEN_VERSION: &str = "2"; +#[derive(thiserror::Error, Debug)] +#[error(transparent)] +pub struct CacheError(pub Box); + #[async_trait::async_trait] pub trait Cache: Send + Sync + std::fmt::Debug { - async fn get_token(&self) -> Result, Auth0Error>; + async fn get_token(&self, client_id: &str, aud: &str) -> Result, CacheError>; - async fn put_token(&self, value_ref: &Token) -> Result<(), Auth0Error>; + async fn put_token(&self, client_id: &str, aud: &str, token: &Token) -> Result<(), CacheError>; } pub(in crate::auth0::cache) fn token_key(caller: &str, audience: &str) -> String { diff --git a/src/auth0/cache/redis_impl.rs b/src/auth0/cache/redis_impl.rs index be42bc7..c09c6fa 100644 --- a/src/auth0/cache/redis_impl.rs +++ b/src/auth0/cache/redis_impl.rs @@ -1,33 +1,46 @@ use redis::AsyncCommands; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; +use super::CacheError; use crate::auth0::cache::{self, crypto, Cache}; use crate::auth0::token::Token; -use crate::auth0::{Auth0Error, Config}; + +#[derive(Debug, thiserror::Error)] +pub enum RedisCacheError { + #[error(transparent)] + Serde(#[from] serde_json::Error), + #[error("redis error: {0}")] + Redis(#[from] redis::RedisError), + #[error("couldn't decrypt stored token: {0}")] + Crypto(#[from] crypto::CryptoError), +} + +impl From for super::CacheError { + fn from(val: RedisCacheError) -> Self { + CacheError(Box::new(val)) + } +} #[derive(Clone, Debug)] pub struct RedisCache { client: redis::Client, encryption_key: String, - caller: String, - audience: String, } impl RedisCache { - pub async fn new(config_ref: &Config) -> Result { - let client: redis::Client = redis::Client::open(config_ref.cache_type().redis_connection_url())?; + /// Redis connection string(eg. `"redis://{host}:{port}?{ParamKey1}={ParamKey2}"`) + pub async fn new(redis_connection_url: String, token_encryption_key: String) -> Result { + let client: redis::Client = redis::Client::open(redis_connection_url)?; // Ensure connection is fine. Should fail otherwise let _ = client.get_multiplexed_async_connection().await?; Ok(RedisCache { client, - encryption_key: config_ref.token_encryption_key().to_string(), - caller: config_ref.caller().to_string(), - audience: config_ref.audience().to_string(), + encryption_key: token_encryption_key, }) } - async fn get(&self, key: &str) -> Result, Auth0Error> + async fn get(&self, key: &str) -> Result, RedisCacheError> where for<'de> T: Deserialize<'de>, { @@ -38,23 +51,30 @@ impl RedisCache { .await? .map(|value| crypto::decrypt(self.encryption_key.as_str(), value.as_slice())) .transpose() + .map_err(Into::into) + } + + async fn put(&self, key: &str, lifetime_in_seconds: u64, v: T) -> Result<(), RedisCacheError> { + let mut connection = self.client.get_multiplexed_async_connection().await?; + + let encrypted_value: Vec = crypto::encrypt(&v, self.encryption_key.as_str())?; + let _: () = connection.set_ex(key, encrypted_value, lifetime_in_seconds).await?; + Ok(()) } } #[async_trait::async_trait] impl Cache for RedisCache { - async fn get_token(&self) -> Result, Auth0Error> { - let key: &str = &cache::token_key(&self.caller, &self.audience); - self.get(key).await + async fn get_token(&self, client_id: &str, audience: &str) -> Result, CacheError> { + let key: &str = &cache::token_key(client_id, audience); + self.get(key).await.map_err(Into::into) } - async fn put_token(&self, value_ref: &Token) -> Result<(), Auth0Error> { - let key: &str = &cache::token_key(&self.caller, &self.audience); - let mut connection = self.client.get_multiplexed_async_connection().await?; - let encrypted_value: Vec = crypto::encrypt(value_ref, self.encryption_key.as_str())?; - let expiration = value_ref.lifetime_in_seconds(); - let _: () = connection.set_ex(key, encrypted_value, expiration).await?; - Ok(()) + async fn put_token(&self, client_id: &str, audience: &str, value_ref: &Token) -> Result<(), CacheError> { + let key: &str = &cache::token_key(client_id, audience); + self.put(key, value_ref.lifetime_in_seconds(), value_ref) + .await + .map_err(Into::into) } } diff --git a/src/auth0/client.rs b/src/auth0/client.rs new file mode 100644 index 0000000..2a5622c --- /dev/null +++ b/src/auth0/client.rs @@ -0,0 +1,116 @@ +use crate::auth0::Auth0Error; +use chrono::{DateTime, Utc}; +use reqwest::{Client, Response, Url}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +use super::{Config, Token}; + +/// The successful response received from the authorization server containing the access token. +/// Related [RFC](https://www.rfc-editor.org/rfc/rfc6749#section-5.1) +#[derive(Deserialize, Serialize, Debug)] +struct FetchTokenResponse { + access_token: String, + scope: Option, + expires_in: i32, + token_type: String, +} + +#[derive(Serialize, Debug)] +struct FetchTokenRequest { + client_id: String, + client_secret: String, + audience: String, + grant_type: String, + scope: Option, +} + +#[derive(Deserialize, Debug)] +pub struct Claims { + #[allow(dead_code)] + #[serde(default)] + pub permissions: Vec, +} + +// Client for talking to auth0 +#[derive(Clone)] +pub struct Auth0Client { + token_url: Url, + client: Client, + client_id: String, + client_secret: String, +} + +impl Auth0Client { + pub fn new(token_url: Url, client: Client, client_id: String, client_secret: String) -> Self { + Self { + token_url, + client, + client_id, + client_secret, + } + } + + pub(super) fn from_config(config: &Config, client: &Client) -> Self { + Auth0Client::new( + config.token_url().clone(), + client.clone(), + config.client_id().to_string(), + config.client_secret().to_string(), + ) + } + + pub async fn fetch_token(&self, audience: &str, scope: Option<&str>) -> Result { + let request: FetchTokenRequest = FetchTokenRequest { + client_id: self.client_id.clone(), + client_secret: self.client_secret.clone(), + audience: audience.to_string(), + grant_type: "client_credentials".to_string(), + scope: scope.map(str::to_string), + }; + + let response: Response = self + .client + .post(self.token_url.clone()) + .json(&request) + .send() + .await + .map_err(|e| { + Auth0Error::JwtFetchError( + e.status().map(|v| v.as_u16()).unwrap_or_default(), + self.token_url.to_string(), + e, + ) + })?; + + let status_code: u16 = response.status().as_u16(); + if status_code != 200 { + return Err(Auth0Error::JwtFetchAuthError(status_code)); + } + + let response: FetchTokenResponse = response + .json() + .await + .map_err(|e| Auth0Error::JwtFetchDeserializationError(self.token_url.as_str().to_string(), e))?; + + let FetchTokenResponse { + access_token, + expires_in, + .. + } = response; + + // this is not the exact issue_date, nor the exact expire_date. But is a good approximation + // as long as we need it just to remove the key from the cache, and calculate the approximation + // of the token lifetime. If we need more correctness we can decrypt the token and get + // the exact issued_at (iat) and expiration (exp) + // reference: https://www.iana.org/assignments/jwt/jwt.xhtml + let issue_date: DateTime = Utc::now(); + let expire_date: DateTime = Utc::now() + Duration::from_secs(expires_in as u64); + + Ok(Token::new(access_token, issue_date, expire_date)) + } + + pub fn client_id(&self) -> &str { + &self.client_id + } +} diff --git a/src/auth0/config.rs b/src/auth0/config.rs index fbb1087..3158596 100644 --- a/src/auth0/config.rs +++ b/src/auth0/config.rs @@ -1,4 +1,4 @@ -use std::ops::RangeInclusive; +pub use super::StalenessCheckPercentage; use std::time::Duration; use reqwest::Url; @@ -116,33 +116,3 @@ impl CacheType { } } } - -#[derive(Clone)] -pub struct StalenessCheckPercentage(RangeInclusive); - -impl StalenessCheckPercentage { - pub fn new(min: f64, max: f64) -> Self { - assert!((0.0..=1.0).contains(&min)); - assert!((0.0..=1.0).contains(&max)); - assert!(min <= max); - - Self(min..=max) - } - - pub fn random_value_between(&self) -> f64 { - use rand::Rng; - rand::thread_rng().gen_range(self.0.clone()) - } -} - -impl Default for StalenessCheckPercentage { - fn default() -> Self { - Self(0.6..=0.9) - } -} - -impl From> for StalenessCheckPercentage { - fn from(range: RangeInclusive) -> Self { - Self::new(*range.start(), *range.end()) - } -} diff --git a/src/auth0/errors.rs b/src/auth0/errors.rs index 54b3333..feff6b3 100644 --- a/src/auth0/errors.rs +++ b/src/auth0/errors.rs @@ -1,5 +1,7 @@ use thiserror::Error; +use super::cache; + #[derive(Debug, Error)] pub enum Auth0Error { #[error(transparent)] @@ -12,8 +14,8 @@ pub enum Auth0Error { JwtFetchDeserializationError(String, reqwest::Error), #[error("failed to fetch jwt from {0}. Status code: {0}; error: {1}")] JwksHttpError(String, reqwest::Error), - #[error("redis error: {0}")] - RedisError(#[from] redis::RedisError), + #[error("cache error: {0}")] + CacheError(#[from] cache::CacheError), #[error(transparent)] CryptoError(#[from] chacha20poly1305::Error), } diff --git a/src/auth0/mod.rs b/src/auth0/mod.rs index c8c8984..fb135af 100644 --- a/src/auth0/mod.rs +++ b/src/auth0/mod.rs @@ -1,123 +1,59 @@ //! Stuff used to provide JWT authentication via Auth0 -use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}; - use reqwest::Client; -use tokio::task::JoinHandle; -use tokio::time::Interval; - -pub use config::{CacheType, Config, StalenessCheckPercentage}; -pub use errors::Auth0Error; -use util::ResultExt; -use crate::auth0::cache::Cache; -pub use crate::auth0::token::Token; - -mod cache; +pub mod cache; +mod client; mod config; mod errors; +mod refresh; mod token; mod util; -#[derive(Debug)] -pub struct Auth0 { - token_lock: Arc>, -} +use cache::{Cache, CacheError}; +pub use client::Auth0Client; +pub use config::{CacheType, Config}; +pub use errors::Auth0Error; +pub use refresh::RefreshingToken; +pub use token::Token; +pub use util::StalenessCheckPercentage; + +#[derive(Clone, Debug)] +pub struct Auth0(RefreshingToken); impl Auth0 { - pub async fn new(client_ref: &Client, config: Config) -> Result { - let cache: Arc = if config.is_inmemory_cache() { - Arc::new(cache::InMemoryCache::new(&config).await?) + #[deprecated(since = "0.21.0", note = "please use refreshing token")] + pub async fn new(client: &Client, config: Config) -> Result { + let cache: Box = if config.is_inmemory_cache() { + Box::new(cache::InMemoryCache::default()) } else { - Arc::new(cache::RedisCache::new(&config).await?) + let redis_conn = config.cache_type().redis_connection_url().to_string(); + let encryption_key = config.token_encryption_key().to_string(); + Box::new( + cache::RedisCache::new(redis_conn, encryption_key) + .await + .map_err(Into::::into)?, + ) }; - let token: Token = get_token(client_ref, &cache, &config).await?; - let token_lock: Arc> = Arc::new(RwLock::new(token)); - - start(token_lock.clone(), client_ref.clone(), cache.clone(), config).await; - - Ok(Self { token_lock }) + let client = client::Auth0Client::from_config(&config, client); + RefreshingToken::new( + client, + config.check_interval, + config.staleness_check_percentage, + cache, + config.audience, + config.scope, + ) + .await + .map(Self) } pub fn token(&self) -> Token { - read(&self.token_lock) + self.0.token().clone() } -} - -async fn start( - token_lock: Arc>, - client: Client, - cache: Arc, - config: Config, -) -> JoinHandle<()> { - tokio::spawn(async move { - let mut ticker: Interval = tokio::time::interval(*config.check_interval()); - - loop { - ticker.tick().await; - let token = match cache.get_token().await { - Ok(Some(token)) => token, - Ok(None) => read(&token_lock), - Err(e) => { - tracing::error!("Error reading cached JWT. Reason: {:?}", e); - read(&token_lock) - } - }; - - if token.needs_refresh(&config) { - tracing::info!("Refreshing JWT and JWKS"); - - match fetch_and_update_token(&client, &cache, &config).await { - Ok(token) => { - write(&token_lock, token); - } - Err(error) => tracing::error!("Failed to fetch JWT. Reason: {:?}", error), - } - } else if token.expire_date() > read(&token_lock).expire_date() { - write(&token_lock, token); - } - } - }) -} - -// Try to fetch the token from cache. If it's found return it; fetch from auth0 and put in cache otherwise -async fn get_token(client_ref: &Client, cache_ref: &Arc, config_ref: &Config) -> Result { - match cache_ref.get_token().await { - Ok(Some(token)) => Ok(token), - Ok(None) => fetch_and_update_token(client_ref, cache_ref, config_ref).await, - Err(Auth0Error::CryptoError(e)) => { - tracing::warn!("Crypto error({}) when attempting to decrypt cached token. Ignoring", e); - fetch_and_update_token(client_ref, cache_ref, config_ref).await - } - Err(e) => Err(e), + pub fn refreshing_token(self) -> RefreshingToken { + self.0 } } - -// Unconditionally fetch a new token and update the cache -async fn fetch_and_update_token( - client_ref: &Client, - cache_ref: &Arc, - config_ref: &Config, -) -> Result { - let token: Token = Token::fetch(client_ref, config_ref).await?; - let _ = cache_ref.put_token(&token).await.log_err("JWT cache set failed"); - - Ok(token) -} - -fn read(lock_ref: &Arc>) -> T { - let lock_guard: RwLockReadGuard = lock_ref.read().unwrap_or_else(|poison_error| poison_error.into_inner()); - (*lock_guard).clone() -} - -fn write(lock_ref: &Arc>, token: T) { - let mut lock_guard: RwLockWriteGuard = lock_ref - .write() - .unwrap_or_else(|poison_error| poison_error.into_inner()); - *lock_guard = token; -} - -#[cfg(test)] -mod tests {} diff --git a/src/auth0/refresh.rs b/src/auth0/refresh.rs new file mode 100644 index 0000000..c28f14d --- /dev/null +++ b/src/auth0/refresh.rs @@ -0,0 +1,124 @@ +use std::error::Error; +use std::sync::Arc; +use std::sync::Weak; +use std::time::Duration; + +use super::util::UnpoisonableRwLock; +use super::StalenessCheckPercentage; +use super::{client::Auth0Client, Auth0Error, Cache, Token}; + +#[derive(Debug, Clone)] +pub struct RefreshingToken(Arc>); + +impl RefreshingToken { + pub async fn new( + client: Auth0Client, + + check_interval: Duration, + staleness: StalenessCheckPercentage, + cache: Box, + + audience: String, + scope: Option, + ) -> Result { + let token = client.fetch_token(&audience, scope.as_deref()).await?; + let token = Arc::new(UnpoisonableRwLock::new(token)); + + let worker = RefreshWorker { + check_interval, + client, + cache, + audience, + staleness, + scope, + }; + + tokio::spawn(worker.refresh_loop(Arc::downgrade(&token))); + + Ok(Self(token)) + } + + pub fn token(&self) -> std::sync::RwLockReadGuard { + self.0.read() + } +} + +struct RefreshWorker { + check_interval: Duration, + client: Auth0Client, + audience: String, + scope: Option, + staleness: StalenessCheckPercentage, + + cache: Box, +} + +impl RefreshWorker { + async fn refresh_loop(self, weak_token: Weak>) { + let mut ticker = tokio::time::interval(self.check_interval); + + loop { + ticker.tick().await; + + let Some(token_rc) = weak_token.upgrade() else { + tracing::debug!("All references to auth0 token dropped. Stopping refresh thread"); + return; + }; + + let token = token_rc.read().clone(); + match self.check_refresh_token(token).await { + Ok(token) => token_rc.write(token), + Err(e) => { + tracing::warn!( + error = &e as &dyn Error, + "Failed to refresh auth0 token: {e}! Will retry in {}s", + self.check_interval.as_secs() + ) + } + }; + } + } + + async fn check_refresh_token(&self, cur_token: Token) -> Result { + let cached_token = match self.cache.get_token(self.client.client_id(), &self.audience).await { + Ok(v) => v, + Err(e) => { + tracing::error!("Error reading cached JWT. Reason: {:?}", e); + None + } + }; + + match cached_token { + Some(cached_token) if cached_token.is_stale(&self.staleness) => { + tracing::info!("Refreshing JWT and JWKS"); + + let token = self.fetch_and_update_token().await?; + Ok(token) + } + Some(cached_token) if cached_token.expire_date() > cur_token.expire_date() => Ok(cached_token), + None => { + tracing::debug!("New token expiry_date is lower current token. Not refreshing and trying to replace"); + self.cache + .put_token(self.client.client_id(), &self.audience, &cur_token) + .await?; + Ok(cur_token) + } + _ => Ok(cur_token), + } + } + + // Unconditionally fetch a new token and update the cache + async fn fetch_and_update_token(&self) -> Result { + let token: Token = self.client.fetch_token(&self.audience, self.scope.as_deref()).await?; + + if let Err(e) = self + .cache + .put_token(self.client.client_id(), &self.audience, &token) + .await + { + tracing::error!(error = ?e, "JWT cache set failed: {e}"); + }; + + Ok(token) + } +} diff --git a/src/auth0/token.rs b/src/auth0/token.rs index a248126..eaa4e68 100644 --- a/src/auth0/token.rs +++ b/src/auth0/token.rs @@ -1,11 +1,13 @@ use chrono::{DateTime, Utc}; -use reqwest::{Client, Response}; +use reqwest::Client; use serde::{Deserialize, Serialize}; -use std::time::Duration; +use crate::auth0::client::Auth0Client; use crate::auth0::errors::Auth0Error; use crate::auth0::Config; +use super::StalenessCheckPercentage; + #[derive(Serialize, Deserialize, Clone, Debug)] pub struct Token { token: String, @@ -22,46 +24,16 @@ impl Token { } } - pub async fn fetch(client_ref: &Client, config_ref: &Config) -> Result { - let request: FetchTokenRequest = FetchTokenRequest::from(config_ref); - let response: Response = client_ref - .post(config_ref.token_url().clone()) - .json(&request) - .send() - .await - .map_err(|e| { - Auth0Error::JwtFetchError( - e.status().map(|v| v.as_u16()).unwrap_or_default(), - config_ref.token_url().as_str().to_string(), - e, - ) - })?; - - let status_code: u16 = response.status().as_u16(); - if status_code != 200 { - return Err(Auth0Error::JwtFetchAuthError(status_code)); - } - - let response: FetchTokenResponse = response - .json() - .await - .map_err(|e| Auth0Error::JwtFetchDeserializationError(config_ref.token_url().as_str().to_string(), e))?; - - let access_token: String = response.access_token.clone(); - - // this is not the exact issue_date, nor the exact expire_date. But is a good approximation - // as long as we need it just to removes the key from the cache, and calculate the approximation - // of the token lifetime. If we need more correctness we can decrypt the token and get - // the exact issued_at (iat) and expiration (exp) - // reference: https://www.iana.org/assignments/jwt/jwt.xhtml - let issue_date: DateTime = Utc::now(); - let expire_date: DateTime = Utc::now() + Duration::from_secs(response.expires_in as u64); - - Ok(Self { - token: access_token, - issue_date, - expire_date, - }) + #[deprecated(since = "0.21.0", note = "please use Auth0Client instead")] + pub async fn fetch(client: &Client, config: &Config) -> Result { + Auth0Client::new( + config.token_url().clone(), + client.clone(), + config.client_id().to_string(), + config.client_secret().to_string(), + ) + .fetch_token(config.audience(), config.scope.as_deref()) + .await } pub fn as_str(&self) -> &str { @@ -95,49 +67,16 @@ impl Token { // Check if the token remaining lifetime it's less than a randomized percentage that is between // `max_token_remaining_life_percentage` and `min_token_remaining_life_percentage` - pub fn needs_refresh(&self, config_ref: &Config) -> bool { - self.remaining_life_percentage() < config_ref.staleness_check_percentage().random_value_between() + #[deprecated(since = "0.21.0", note = "please use the is_stale function instead")] + pub fn needs_refresh(&self, config: &Config) -> bool { + self.is_stale(config.staleness_check_percentage()) } - pub fn to_bearer(&self) -> String { - format!("Bearer {}", self.token.as_str()) + pub fn is_stale(&self, staleness: &StalenessCheckPercentage) -> bool { + self.remaining_life_percentage() < staleness.random_value_between() } -} - -/// The successful response received from the authorization server containing the access token. -/// Related [RFC](https://www.rfc-editor.org/rfc/rfc6749#section-5.1) -#[derive(Deserialize, Serialize, Debug)] -struct FetchTokenResponse { - access_token: String, - scope: Option, - expires_in: i32, - token_type: String, -} -#[derive(Serialize, Debug)] -struct FetchTokenRequest { - client_id: String, - client_secret: String, - audience: String, - grant_type: String, - scope: Option, -} - -impl From<&Config> for FetchTokenRequest { - fn from(config: &Config) -> Self { - Self { - client_id: config.client_id().to_string(), - client_secret: config.client_secret().to_string(), - audience: config.audience().to_string(), - grant_type: "client_credentials".to_string(), - scope: config.scope.clone(), - } + pub fn to_bearer(&self) -> String { + format!("Bearer {}", self.token.as_str()) } } - -#[derive(Deserialize, Debug)] -pub struct Claims { - #[allow(dead_code)] - #[serde(default)] - pub permissions: Vec, -} diff --git a/src/auth0/util.rs b/src/auth0/util.rs index fb6fe6f..e93ca96 100644 --- a/src/auth0/util.rs +++ b/src/auth0/util.rs @@ -1,15 +1,54 @@ -pub trait ResultExt { - fn log_err(self, label: &str) -> Result; -} - -impl ResultExt for Result { - fn log_err(self, label: &str) -> Result { - match self { - Ok(t) => Ok(t), - Err(e) => { - tracing::error!("{}: {}", label, e.to_string()); - Err(e) - } - } +use std::ops::RangeInclusive; +use std::sync::RwLock; + +#[derive(Clone)] +pub struct StalenessCheckPercentage(RangeInclusive); + +impl StalenessCheckPercentage { + pub fn new(min: f64, max: f64) -> Self { + assert!((0.0..=1.0).contains(&min)); + assert!((0.0..=1.0).contains(&max)); + assert!(min <= max); + + Self(min..=max) + } + + pub fn random_value_between(&self) -> f64 { + use rand::Rng; + rand::thread_rng().gen_range(self.0.clone()) + } +} + +impl Default for StalenessCheckPercentage { + fn default() -> Self { + Self(0.6..=0.9) + } +} + +impl From> for StalenessCheckPercentage { + fn from(range: RangeInclusive) -> Self { + Self::new(*range.start(), *range.end()) + } +} + +/// A wrapper around RwLock that cannot be poisoned, +/// by requiring the value inside to be completely replaced, +/// ensuring a panic can never occur while the lock is held +#[derive(Debug)] +pub struct UnpoisonableRwLock(RwLock); + +impl UnpoisonableRwLock { + pub fn new(v: T) -> Self { + Self(RwLock::new(v)) + } + + pub fn read(&self) -> std::sync::RwLockReadGuard { + // Unwrapping here is fine since a panic can never occur while we hold the lock + self.0.read().unwrap() + } + + pub fn write(&self, v: T) { + // Unwrapping here is fine since a panic can never occur while we hold the lock + *self.0.write().unwrap() = v; } } diff --git a/src/builder.rs b/src/builder.rs index a83d55c..206116b 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -5,6 +5,8 @@ use reqwest_middleware::Middleware; #[cfg(feature = "auth0")] use crate::auth0; +#[cfg(feature = "auth0")] +use crate::auth0::RefreshingToken; use crate::{Bridge, BridgeImpl, RedirectPolicy}; pub type BridgeBuilder = BridgeBuilderInner; @@ -13,21 +15,33 @@ pub type BridgeBuilder = BridgeBuilderInner; pub struct BridgeBuilderInner { inner: T, #[cfg(feature = "auth0")] - auth0: Option, + auth0: Option, } impl BridgeBuilderInner { /// Adds Auth0 JWT authentication to the requests made by the [Bridge]. #[cfg_attr(docsrs, doc(cfg(feature = "auth0")))] #[cfg(feature = "auth0")] + #[deprecated(since = "0.21.0", note = "please use with_refreshing_token instead")] pub async fn with_auth0(self, config: auth0::Config) -> Self { let client: reqwest::Client = reqwest::Client::new(); + #[allow(deprecated)] + let auth0 = auth0::Auth0::new(&client, config) + .await + .expect("Failed to create auth0 bridge") + .refreshing_token(); + + Self { + auth0: Some(auth0), + ..self + } + } + + #[cfg_attr(docsrs, doc(cfg(feature = "auth0")))] + #[cfg(feature = "auth0")] + pub async fn with_refreshing_token(self, refreshing_token: RefreshingToken) -> Self { Self { - auth0: Some( - auth0::Auth0::new(&client, config) - .await - .expect("Failed to create auth0 bridge"), - ), + auth0: Some(refreshing_token), ..self } } diff --git a/src/lib.rs b/src/lib.rs index 4b4fba0..90d35af 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -23,10 +23,13 @@ //! * `redis-tls` - add support for connecting to redis with tls //! * `grpc` - provides the [GrpcOtelInterceptor] for adding the opentelemetry context to the gRPC requests //! * `tracing_opentelemetry` - adds support for integration with opentelemetry. -//! This feature is an alias for the latest `tracing_opentelemetry_x_xx` feature. +//! This feature is an alias for the latest `tracing_opentelemetry_x_xx` feature. //! * `tracing_opentelemetry_x_xx` (e.g. `tracing_opentelemetry_0_27`) - adds support for integration with a particular opentelemetry version. -//! We are going to support at least the last 3 versions of opentelemetry. After that we might remove support for older otel version without it being a breaking change. +//! We are going to support at least the last 3 versions of opentelemetry. After that we might remove support for older otel version without it being a breaking change. +#[cfg(feature = "auth0")] +#[cfg_attr(docsrs, doc(cfg(feature = "auth0")))] +use auth0::RefreshingToken; use errors::PrimaBridgeError; use http::{header::HeaderName, HeaderValue, Method}; use reqwest::{multipart::Form, Url}; @@ -43,9 +46,10 @@ pub use self::{ response::Response, }; #[cfg(all(feature = "grpc", feature = "_any_otel_version"))] +#[cfg_attr(docsrs, doc(cfg(feature = "grpc")))] pub use request::grpc::{GrpcOtelInterceptedService, GrpcOtelInterceptor}; -mod builder; +pub mod builder; mod errors; pub mod prelude; mod redirect; @@ -61,18 +65,18 @@ pub type Bridge = BridgeImpl; /// A Bridge instance that's generic across the client. If the [BridgeBuilder] is used /// to construct a bridge with middleware, this type will be used to wrap the [reqwest_middleware::ClientWithMiddleware]. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct BridgeImpl { inner_client: T, endpoint: Url, #[cfg(feature = "auth0")] - auth0_opt: Option, + auth0_opt: Option, } /// A trait that abstracts the client used by the [BridgeImpl], such that both reqwest clients and reqwest /// clients with middleware can be used, more or less interchangeably. #[doc(hidden)] -pub trait BridgeClient: Sealed { +pub trait BridgeClient: Sealed + Clone { type Builder: PrimaRequestBuilderInner; fn request(&self, method: Method, url: Url) -> PrimaRequestBuilder; } @@ -251,7 +255,7 @@ impl Bridge { #[cfg_attr(docsrs, doc(cfg(feature = "auth0")))] /// Gets the JWT token used by the Bridge, if it has been configured with Auth0 authentication via [BridgeBuilder.with_auth0](BridgeBuilder#with_auth0). pub fn token(&self) -> Option { - self.auth0_opt.as_ref().map(|auth0| auth0.token()) + self.auth0_opt.as_ref().map(|auth0| auth0.token().clone()) } } diff --git a/src/request/mod.rs b/src/request/mod.rs index 38e42c9..520e5c5 100644 --- a/src/request/mod.rs +++ b/src/request/mod.rs @@ -120,7 +120,7 @@ pub trait DeliverableRequest<'a>: Sized + Sealed + 'a { #[cfg(feature = "auth0")] #[doc(hidden)] - fn get_auth0(&self) -> &Option; + fn get_auth0(&self) -> &Option; #[cfg(feature = "auth0")] #[doc(hidden)] diff --git a/src/request/request_type/graphql.rs b/src/request/request_type/graphql.rs index 0652a5d..660cb3c 100644 --- a/src/request/request_type/graphql.rs +++ b/src/request/request_type/graphql.rs @@ -176,7 +176,7 @@ impl<'a, Client: BridgeClient> DeliverableRequest<'a> for GraphQLRequest<'a, Cli } #[cfg(feature = "auth0")] - fn get_auth0(&self) -> &Option { + fn get_auth0(&self) -> &Option { &self.bridge.auth0_opt } diff --git a/src/request/request_type/rest.rs b/src/request/request_type/rest.rs index 109e4b8..706341e 100644 --- a/src/request/request_type/rest.rs +++ b/src/request/request_type/rest.rs @@ -144,7 +144,7 @@ impl<'a, Client: BridgeClient> DeliverableRequest<'a> for RestRequest<'a, Client } #[cfg(feature = "auth0")] - fn get_auth0(&self) -> &Option { + fn get_auth0(&self) -> &Option { &self.bridge.auth0_opt } diff --git a/tests/async_auth0/builder.rs b/tests/async_auth0/builder.rs index db0e13c..e3a168c 100644 --- a/tests/async_auth0/builder.rs +++ b/tests/async_auth0/builder.rs @@ -2,10 +2,12 @@ use async_trait::async_trait; use mockito::{Matcher, Server}; use reqwest::Url; -use crate::async_auth0::{config, Auth0Mocks}; -use prima_bridge::auth0::Config; +use crate::async_auth0::Auth0Mocks; +use prima_bridge::auth0::RefreshingToken; use prima_bridge::prelude::*; +use super::refreshing_token; + #[async_trait] pub(in crate::async_auth0) trait Auth0MocksExt { async fn create_bridge(&mut self, status_code: usize, body: &str) -> (Auth0Mocks, Bridge); @@ -34,9 +36,9 @@ pub(in crate::async_auth0) trait Auth0MocksExt { async fn create_bridge_with_binary_body_matcher(&mut self, body: &[u8]) -> (Auth0Mocks, Bridge); async fn create_bridge_with_auth0_get_token_match_body( &mut self, - body: &str, + body: serde_json::Value, req_token_body: serde_json::Value, - config: Config, + token: RefreshingToken, ) -> (Auth0Mocks, Bridge); } @@ -56,7 +58,10 @@ impl Auth0MocksExt for Server { let base_url = format!("{}/{}", self.url(), base); let url = Url::parse(base_url.as_str()).unwrap(); let mut mocks: Auth0Mocks = Auth0Mocks::new(self).await; - let bridge = Bridge::builder().with_auth0(config(self)).await.build(url); + let bridge = Bridge::builder() + .with_refreshing_token(refreshing_token(self).await) + .await + .build(url); let mock = self .mock("GET", format!("/{}/{}", base, path).as_str()) @@ -80,13 +85,13 @@ impl Auth0MocksExt for Server { async fn create_bridge_with_auth0_get_token_match_body( &mut self, - body: &str, + body: serde_json::Value, req_token_body: serde_json::Value, - config: Config, + token: RefreshingToken, ) -> (Auth0Mocks, Bridge) { let url = Url::parse(self.url().as_str()).unwrap(); let mut mocks: Auth0Mocks = Auth0Mocks::new_with_req_token_body_match(self, req_token_body).await; - let bridge = Bridge::builder().with_auth0(config).await.build(url); + let bridge = Bridge::builder().with_refreshing_token(token).await.build(url); let mock = self .mock("GET", "/") @@ -99,7 +104,7 @@ impl Auth0MocksExt for Server { bridge.token().unwrap().to_bearer().as_str(), ) .with_status(200) - .with_body(body) + .with_body(body.to_string()) .create_async() .await; @@ -111,7 +116,10 @@ impl Auth0MocksExt for Server { async fn create_bridge_with_path(&mut self, status_code: usize, body: &str, path: &str) -> (Auth0Mocks, Bridge) { let url = Url::parse(self.url().as_str()).unwrap(); let mut mocks: Auth0Mocks = Auth0Mocks::new(self).await; - let bridge = Bridge::builder().with_auth0(config(self)).await.build(url); + let bridge = Bridge::builder() + .with_refreshing_token(refreshing_token(self).await) + .await + .build(url); let mock = self .mock("GET", path) @@ -142,7 +150,10 @@ impl Auth0MocksExt for Server { ) -> (Auth0Mocks, Bridge) { let url = Url::parse(self.url().as_str()).unwrap(); let mut mocks: Auth0Mocks = Auth0Mocks::new(self).await; - let bridge = Bridge::builder().with_auth0(config(self)).await.build(url); + let bridge = Bridge::builder() + .with_refreshing_token(refreshing_token(self).await) + .await + .build(url); let mock = self .mock("GET", path) @@ -168,7 +179,10 @@ impl Auth0MocksExt for Server { async fn create_bridge_with_raw_body_matcher(&mut self, body: &str) -> (Auth0Mocks, Bridge) { let url = Url::parse(self.url().as_str()).unwrap(); let mut mocks: Auth0Mocks = Auth0Mocks::new(self).await; - let bridge = Bridge::builder().with_auth0(config(self)).await.build(url); + let bridge = Bridge::builder() + .with_refreshing_token(refreshing_token(self).await) + .await + .build(url); let mock = self .mock("GET", "/") @@ -193,7 +207,10 @@ impl Auth0MocksExt for Server { async fn create_bridge_with_header_matcher(&mut self, (name, value): (&str, &str)) -> (Auth0Mocks, Bridge) { let url = Url::parse(self.url().as_str()).unwrap(); let mut mocks: Auth0Mocks = Auth0Mocks::new(self).await; - let bridge = Bridge::builder().with_auth0(config(self)).await.build(url); + let bridge = Bridge::builder() + .with_refreshing_token(refreshing_token(self).await) + .await + .build(url); let mock = self .mock("GET", "/") @@ -219,7 +236,7 @@ impl Auth0MocksExt for Server { let url = Url::parse(self.url().as_str()).unwrap(); let mut mocks: Auth0Mocks = Auth0Mocks::new(self).await; let bridge = Bridge::builder() - .with_auth0(config(self)) + .with_refreshing_token(refreshing_token(self).await) .await .with_user_agent(user_agent) .build(url); @@ -247,7 +264,10 @@ impl Auth0MocksExt for Server { async fn create_bridge_with_json_body_matcher(&mut self, json: serde_json::Value) -> (Auth0Mocks, Bridge) { let url = Url::parse(self.url().as_str()).unwrap(); let mut mocks: Auth0Mocks = Auth0Mocks::new(self).await; - let bridge = Bridge::builder().with_auth0(config(self)).await.build(url); + let bridge = Bridge::builder() + .with_refreshing_token(refreshing_token(self).await) + .await + .build(url); let mock = self .mock("GET", "/") @@ -272,7 +292,10 @@ impl Auth0MocksExt for Server { async fn create_bridge_with_binary_body_matcher(&mut self, body: &[u8]) -> (Auth0Mocks, Bridge) { let url = Url::parse(self.url().as_str()).unwrap(); let mut mocks: Auth0Mocks = Auth0Mocks::new(self).await; - let bridge = Bridge::builder().with_auth0(config(self)).await.build(url); + let bridge = Bridge::builder() + .with_refreshing_token(refreshing_token(self).await) + .await + .build(url); let mock = self .mock("GET", "/") diff --git a/tests/async_auth0/graphql.rs b/tests/async_auth0/graphql.rs index c58999a..044124d 100644 --- a/tests/async_auth0/graphql.rs +++ b/tests/async_auth0/graphql.rs @@ -9,7 +9,9 @@ use serde_json::json; use prima_bridge::prelude::*; use prima_bridge::ParsedGraphqlResponseExt; -use crate::async_auth0::{config, Auth0Mocks}; +use crate::async_auth0::Auth0Mocks; + +use super::refreshing_token; #[derive(Deserialize, Clone, Debug, PartialEq)] struct Person { @@ -171,7 +173,10 @@ async fn create_gql_bridge(server: &mut Server, status_code: usize, query: &str, let mut mocks = Auth0Mocks::new(server).await; let url = Url::parse(&server.url()).unwrap(); - let bridge: Bridge = Bridge::builder().with_auth0(config(server)).await.build(url); + let bridge: Bridge = Bridge::builder() + .with_refreshing_token(refreshing_token(server).await) + .await + .build(url); let graphql_mock: Mock = server .mock("POST", "/") diff --git a/tests/async_auth0/mod.rs b/tests/async_auth0/mod.rs index 54ec80f..f6c2c78 100644 --- a/tests/async_auth0/mod.rs +++ b/tests/async_auth0/mod.rs @@ -1,30 +1,33 @@ -use std::str::FromStr; use std::time::Duration; use jsonwebtoken::Algorithm; use mockito::{Matcher, Mock, Server}; -use reqwest::Url; -use prima_bridge::auth0::{CacheType, Config, StalenessCheckPercentage}; +use prima_bridge::auth0::{cache::InMemoryCache, Auth0Client, RefreshingToken, StalenessCheckPercentage}; mod builder; mod graphql; mod rest; -fn config(server: &Server) -> Config { - Config { - token_url: Url::from_str(&format!("{}/{}", server.url().as_str(), "token")).unwrap(), - jwks_url: Url::from_str(&format!("{}/{}", server.url().as_str(), "jwks")).unwrap(), - caller: "caller".to_string(), - audience: "audience".to_string(), - cache_type: CacheType::Inmemory, - token_encryption_key: "32char_long_token_encryption_key".to_string(), - check_interval: Duration::from_secs(10), - staleness_check_percentage: StalenessCheckPercentage::default(), - client_id: "client_id".to_string(), - client_secret: "client_secret".to_string(), - scope: None, - } +async fn refreshing_token(server: &Server) -> RefreshingToken { + let token_url = format!("{}/token", server.url().as_str()).parse().unwrap(); + let auth0_client = Auth0Client::new( + token_url, + reqwest::Client::default(), + "caller".to_string(), + "client_secret".to_string(), + ); + + RefreshingToken::new( + auth0_client, + Duration::from_secs(10), + StalenessCheckPercentage::default(), + Box::new(InMemoryCache::default()), + "audience".to_string(), + None, + ) + .await + .unwrap() } struct Auth0Mocks { diff --git a/tests/async_auth0/rest.rs b/tests/async_auth0/rest.rs index 40ce5a9..2dd5096 100644 --- a/tests/async_auth0/rest.rs +++ b/tests/async_auth0/rest.rs @@ -1,6 +1,11 @@ use std::error::Error; +use std::time::Duration; use mockito::Server; +use prima_bridge::auth0::cache::InMemoryCache; +use prima_bridge::auth0::Auth0Client; +use prima_bridge::auth0::RefreshingToken; +use prima_bridge::auth0::StalenessCheckPercentage; use reqwest::header::{HeaderName, HeaderValue}; use serde::{Deserialize, Serialize}; use serde_json::json; @@ -8,7 +13,6 @@ use serde_json::json; use prima_bridge::prelude::*; use crate::async_auth0::builder::*; -use crate::async_auth0::config; #[derive(Deserialize, Clone, Debug, PartialEq, Serialize)] struct Data { @@ -28,18 +32,46 @@ async fn simple_request() -> Result<(), Box> { #[tokio::test] async fn simple_request_with_auth0_scope() -> Result<(), Box> { + let mut auth0_server = Server::new_async().await; + auth0_server.create_bridge(200, "{}").await; + let mut server = Server::new_async().await; - let mut cfg = config(&server); - cfg.scope = Some("profile email".to_string()); + + let client_id = "client_id"; + let client_secret = "client_secret"; + let audience = "audience"; + let scope = "profile email"; + + let token_url = format!("{}/token", auth0_server.url().as_str()).parse().unwrap(); + + let auth0_client = Auth0Client::new( + token_url, + reqwest::Client::default(), + client_id.to_string(), + client_secret.to_string(), + ); + + let token = RefreshingToken::new( + auth0_client, + Duration::from_secs(10), + StalenessCheckPercentage::default(), + Box::new(InMemoryCache::default()), + audience.to_string(), + Some(scope.to_string()), + ) + .await + .unwrap(); + let req_token_body = json!({ - "client_id": cfg.client_id.clone(), - "client_secret": cfg.client_secret.clone(), - "audience": cfg.audience.clone(), + "client_id": client_id, + "client_secret": client_secret, + "audience": audience, "grant_type": "client_credentials", - "scope": "profile email" + "scope": scope, }); + let (_m, bridge) = server - .create_bridge_with_auth0_get_token_match_body("{\"hello\": \"world!\"}", req_token_body, cfg) + .create_bridge_with_auth0_get_token_match_body(json!({"hello": "world!"}), req_token_body, token) .await; let result: String = RestRequest::new(&bridge).send().await?.get_data(&["hello"])?; @@ -222,9 +254,10 @@ async fn decompresses_gzip_responses() -> Result<(), Box> { .await; let _mocks = crate::async_auth0::Auth0Mocks::new(&mut server).await; + let token = crate::async_auth0::refreshing_token(&server).await; let bridge = Bridge::builder() - .with_auth0(crate::async_auth0::config(&server)) + .with_refreshing_token(token) .await .build(server.url().parse().unwrap());