diff --git a/async-std/src/client.rs b/async-std/src/client.rs index 3e8595abe6..5d6808661c 100644 --- a/async-std/src/client.rs +++ b/async-std/src/client.rs @@ -83,4 +83,8 @@ impl Connector for ClientConfig { fn spawn + Send + 'static>(&self, fut: Fut) { async_std::task::spawn(fut); } + + async fn delay(&self, duration: std::time::Duration) { + let _ = async_std::future::timeout(duration, std::future::pending::<()>()).await; + } } diff --git a/client/src/client.rs b/client/src/client.rs index 4de2f4b5f9..ce2e7f407f 100644 --- a/client/src/client.rs +++ b/client/src/client.rs @@ -1,5 +1,5 @@ use crate::{Conn, IntoUrl, Pool, USER_AGENT}; -use std::{fmt::Debug, sync::Arc}; +use std::{fmt::Debug, sync::Arc, time::Duration}; use trillium_http::{ transport::BoxedTransport, HeaderName, HeaderValues, Headers, KnownHeaderName, Method, ReceivedBodyState, @@ -20,6 +20,7 @@ pub struct Client { pool: Option>, base: Option>, default_headers: Arc, + timeout: Option, } macro_rules! method { @@ -74,9 +75,16 @@ impl Client { pool: None, base: None, default_headers: Arc::new(default_request_headers()), + timeout: None, } } + method!(get, Get); + method!(post, Post); + method!(put, Put); + method!(delete, Delete); + method!(patch, Patch); + /// chainable method to remove a header from default request headers pub fn without_default_header(mut self, name: impl Into>) -> Self { self.default_headers_mut().remove(name); @@ -160,6 +168,7 @@ impl Client { response_body_state: ReceivedBodyState::Start, config: self.config.clone(), headers_finalized: false, + timeout: self.timeout, } } @@ -209,11 +218,20 @@ impl Client { Ok(()) } - method!(get, Get); - method!(post, Post); - method!(put, Put); - method!(delete, Delete); - method!(patch, Patch); + /// set the timeout for all conns this client builds + /// + /// this can also be set with [`Conn::set_timeout`] and [`Conn::with_timeout`] + pub fn set_timeout(&mut self, timeout: Duration) { + self.timeout = Some(timeout); + } + + /// set the timeout for all conns this client builds + /// + /// this can also be set with [`Conn::set_timeout`] and [`Conn::with_timeout`] + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.set_timeout(timeout); + self + } } impl From for Client { diff --git a/client/src/conn.rs b/client/src/conn.rs index a370e860ea..4ee309c702 100644 --- a/client/src/conn.rs +++ b/client/src/conn.rs @@ -1,6 +1,6 @@ use crate::{pool::PoolEntry, util::encoding, Pool}; use encoding_rs::Encoding; -use futures_lite::{future::poll_once, io, AsyncReadExt, AsyncWriteExt}; +use futures_lite::{future::poll_once, io, AsyncReadExt, AsyncWriteExt, FutureExt}; use memchr::memmem::Finder; use size::{Base, Size}; use std::{ @@ -10,6 +10,7 @@ use std::{ ops::{Deref, DerefMut}, pin::Pin, str::FromStr, + time::Duration, }; use trillium_http::{ transport::BoxedTransport, @@ -60,6 +61,7 @@ pub struct Conn { pub(crate) response_body_state: ReceivedBodyState, pub(crate) config: ArcedConnector, pub(crate) headers_finalized: bool, + pub(crate) timeout: Option, } /// default http user-agent header @@ -492,6 +494,21 @@ impl Conn { .and_then(|t| t.peer_addr().ok().flatten()) } + /// set the timeout for this conn + /// + /// this can also be set on the client with [`Client::set_timeout`] and [`Client::with_timeout`] + pub fn set_timeout(&mut self, timeout: Duration) { + self.timeout = Some(timeout); + } + + /// set the timeout for this conn + /// + /// this can also be set on the client with [`Client::set_timeout`] and [`Client::with_timeout`] + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.set_timeout(timeout); + self + } + // --- everything below here is private --- fn finalize_headers(&mut self) -> Result<()> { @@ -566,7 +583,7 @@ impl Conn { } None => { - let mut transport = Connector::connect(&self.config, &self.url).await?; + let mut transport = self.config.connect(&self.url).await?; log::debug!("opened new connection to {:?}", transport.peer_addr()?); transport.write_all(&head).await?; transport @@ -895,7 +912,17 @@ impl IntoFuture for Conn { fn into_future(mut self) -> Self::IntoFuture { Box::pin(async move { - self.exec().await?; + if let Some(duration) = self.timeout { + let config = self.config.clone(); + self.exec() + .or(async { + config.delay(duration).await; + Err(Error::TimedOut("Conn", duration)) + }) + .await? + } else { + self.exec().await?; + } Ok(self) }) } diff --git a/client/tests/one_hundred_continue.rs b/client/tests/one_hundred_continue.rs index 8ed4ee6aa0..ba072cd273 100644 --- a/client/tests/one_hundred_continue.rs +++ b/client/tests/one_hundred_continue.rs @@ -2,7 +2,7 @@ use async_channel::Sender; use futures_lite::future; use indoc::{formatdoc, indoc}; use pretty_assertions::assert_eq; -use std::future::Future; +use std::future::{Future, IntoFuture}; use test_harness::test; use trillium_client::{Client, Conn, Error, Status, USER_AGENT}; use trillium_server_common::{Connector, Url}; @@ -229,6 +229,10 @@ impl Connector for TestConnector { fn spawn + Send + 'static>(&self, fut: Fut) { let _ = trillium_testing::spawn(fut); } + + async fn delay(&self, duration: std::time::Duration) { + trillium_testing::delay(duration).await + } } async fn test_conn( @@ -236,7 +240,7 @@ async fn test_conn( ) -> (TestTransport, impl Future>) { let (sender, receiver) = async_channel::unbounded(); let client = Client::new(TestConnector(sender)); - let conn_fut = trillium_testing::spawn(async move { setup(client).await }); + let conn_fut = trillium_testing::spawn(setup(client).into_future()); let transport = receiver.recv().await.unwrap(); (transport, async move { conn_fut.await.unwrap() }) } diff --git a/client/tests/timeout.rs b/client/tests/timeout.rs new file mode 100644 index 0000000000..9a5c24cd8f --- /dev/null +++ b/client/tests/timeout.rs @@ -0,0 +1,46 @@ +use std::time::Duration; +use trillium_client::Client; +use trillium_testing::ClientConfig; + +async fn handler(conn: trillium::Conn) -> trillium::Conn { + if conn.path() == "/slow" { + trillium_testing::delay(Duration::from_secs(5)).await; + } + conn.ok("ok") +} + +#[test] +fn timeout_on_conn() { + trillium_testing::with_server(handler, move |url| async move { + let client = Client::new(ClientConfig::new()).with_base(url); + let err = client + .get("/slow") + .with_timeout(Duration::from_millis(100)) + .await + .unwrap_err(); + + assert_eq!(err.to_string(), "Conn took longer than 100ms"); + + assert!(client + .get("/") + .with_timeout(Duration::from_millis(100)) + .await + .is_ok()); + + Ok(()) + }) +} + +#[test] +fn timeout_on_client() { + trillium_testing::with_server(handler, move |url| async move { + let client = Client::new(ClientConfig::new()) + .with_base(url) + .with_timeout(Duration::from_millis(100)); + let err = client.get("/slow").await.unwrap_err(); + assert_eq!(err.to_string(), "Conn took longer than 100ms"); + + assert!(client.get("/").await.is_ok()); + Ok(()) + }) +} diff --git a/http/src/error.rs b/http/src/error.rs index 4e483977fb..03e5cbca37 100644 --- a/http/src/error.rs +++ b/http/src/error.rs @@ -1,5 +1,5 @@ use crate::{HeaderName, Version}; -use std::{num::TryFromIntError, str::Utf8Error}; +use std::{num::TryFromIntError, str::Utf8Error, time::Duration}; use thiserror::Error; /// Concrete errors that occur within trillium's HTTP implementation @@ -86,6 +86,10 @@ pub enum Error { /// implementation on `ReceivedBody` #[error("Received body too long. Maximum {0} bytes")] ReceivedBodyTooLong(u64), + + /// something took longer than was allowed + #[error("{0} took longer than {1:?}")] + TimedOut(&'static str, Duration), } /// this crate's result type diff --git a/http/tests/use_cases.rs b/http/tests/use_cases.rs index 6135ccac66..a6a23a059d 100644 --- a/http/tests/use_cases.rs +++ b/http/tests/use_cases.rs @@ -60,4 +60,8 @@ where fn spawn + Send + 'static>(&self, fut: SpawnFut) { trillium_testing::spawn(fut); } + + async fn delay(&self, duration: std::time::Duration) { + trillium_testing::delay(duration).await + } } diff --git a/native-tls/src/client.rs b/native-tls/src/client.rs index 688356674f..b5094e3f60 100644 --- a/native-tls/src/client.rs +++ b/native-tls/src/client.rs @@ -102,6 +102,10 @@ impl Connector for NativeTlsConfig { fn spawn + Send + 'static>(&self, fut: Fut) { self.tcp_config.spawn(fut) } + + async fn delay(&self, duration: std::time::Duration) { + self.tcp_config.delay(duration).await + } } /** diff --git a/rustls/src/client.rs b/rustls/src/client.rs index 942992cf80..e24448d63c 100644 --- a/rustls/src/client.rs +++ b/rustls/src/client.rs @@ -125,6 +125,10 @@ impl Connector for RustlsConfig { fn spawn + Send + 'static>(&self, fut: Fut) { self.tcp_config.spawn(fut) } + + async fn delay(&self, duration: std::time::Duration) { + self.tcp_config.delay(duration).await + } } #[derive(Debug)] diff --git a/server-common/src/client.rs b/server-common/src/client.rs index bbb278c826..af5ca864df 100644 --- a/server-common/src/client.rs +++ b/server-common/src/client.rs @@ -8,6 +8,7 @@ use std::{ io, pin::Pin, sync::Arc, + time::Duration, }; /** Interface for runtime and tls adapters for the trillium client @@ -23,8 +24,11 @@ pub trait Connector: Send + Sync + 'static { /// Initiate a connection to the provided url fn connect(&self, url: &Url) -> impl Future> + Send; - /// spwan a future on the runtime + /// spawn and detach a future on the runtime fn spawn + Send + 'static>(&self, fut: Fut); + + /// wake in this amount of wall time + fn delay(&self, duration: Duration) -> impl Future + Send; } /// An Arced and type-erased [`Connector`] @@ -73,9 +77,17 @@ trait ObjectSafeConnector: Send + Sync + 'static { 'url: 'fut, Self: 'fut; fn spawn(&self, fut: Pin + Send + 'static>>); + fn delay<'connector, 'fut>( + &'connector self, + duration: Duration, + ) -> Pin + Send + 'fut>> + where + 'connector: 'fut, + Self: 'fut; fn as_any(&self) -> &dyn Any; fn as_mut_any(&mut self) -> &mut dyn Any; } + impl ObjectSafeConnector for T { fn connect<'connector, 'url, 'fut>( &'connector self, @@ -86,10 +98,10 @@ impl ObjectSafeConnector for T { 'url: 'fut, Self: 'fut, { - Box::pin(async move { T::connect(self, url).await.map(BoxedTransport::new) }) + Box::pin(async move { Connector::connect(self, url).await.map(BoxedTransport::new) }) } fn spawn(&self, fut: Pin + Send + 'static>>) { - T::spawn(self, fut) + Connector::spawn(self, fut) } fn as_any(&self) -> &dyn Any { @@ -98,6 +110,16 @@ impl ObjectSafeConnector for T { fn as_mut_any(&mut self) -> &mut dyn Any { self } + fn delay<'connector, 'fut>( + &'connector self, + duration: Duration, + ) -> Pin + Send + 'fut>> + where + 'connector: 'fut, + Self: 'fut, + { + Box::pin(async move { Connector::delay(self, duration).await }) + } } impl Connector for ArcedConnector { @@ -109,4 +131,8 @@ impl Connector for ArcedConnector { fn spawn + Send + 'static>(&self, fut: Fut) { self.0.spawn(Box::pin(fut)) } + + async fn delay(&self, duration: Duration) { + self.0.delay(duration).await + } } diff --git a/smol/src/client.rs b/smol/src/client.rs index ec8ca415a7..2414959621 100644 --- a/smol/src/client.rs +++ b/smol/src/client.rs @@ -83,4 +83,8 @@ impl Connector for ClientConfig { fn spawn + Send + 'static>(&self, fut: Fut) { async_global_executor::spawn(fut).detach(); } + + async fn delay(&self, duration: std::time::Duration) { + async_io::Timer::after(duration).await; + } } diff --git a/testing/src/lib.rs b/testing/src/lib.rs index e1650178b2..4128b7848a 100644 --- a/testing/src/lib.rs +++ b/testing/src/lib.rs @@ -150,6 +150,10 @@ cfg_if::cfg_if! { } pub use trillium_smol::async_global_executor::block_on; pub use trillium_smol::ClientConfig; + /// a future that wakes after this amount of time + pub async fn delay(duration: std::time::Duration) { + trillium_smol::async_io::Timer::after(duration).await; + } } else if #[cfg(feature = "async-std")] { /// runtime server config @@ -176,6 +180,15 @@ cfg_if::cfg_if! { pub fn client_config() -> impl Connector { ClientConfig::default() } + + /// a future that wakes after this amount of time + pub async fn delay(duration: std::time::Duration) { + let _ = trillium_async_std::async_std::future::timeout( + duration, + std::future::pending::<()>() + ).await; + } + } else if #[cfg(feature = "tokio")] { /// runtime server config pub fn config() -> Config { @@ -200,6 +213,12 @@ cfg_if::cfg_if! { pub fn client_config() -> impl Connector { ClientConfig::default() } + + /// a future that wakes after this amount of time + pub async fn delay(duration: std::time::Duration) { + trillium_tokio::tokio::time::sleep(duration).await; + } + } else { /// runtime server config pub fn config() -> Config { @@ -228,6 +247,17 @@ cfg_if::cfg_if! { rx.recv().await.ok() }) } + + /// a future that wakes after this amount of time + pub async fn delay(duration: std::time::Duration) { + let (sender, receiver) = async_channel::bounded::<()>(1); + std::thread::spawn(move || { + std::thread::sleep(duration); + let _ = sender.send_blocking(()); + }); + + let _ = receiver.recv().await; + } } } diff --git a/testing/src/runtimeless.rs b/testing/src/runtimeless.rs index 0f7dc8f3be..f9ddcd4b7d 100644 --- a/testing/src/runtimeless.rs +++ b/testing/src/runtimeless.rs @@ -109,6 +109,16 @@ impl Connector for RuntimelessClientConfig { fn spawn + Send + 'static>(&self, fut: Fut) { spawn(fut); } + + async fn delay(&self, duration: std::time::Duration) { + let (sender, receiver) = async_channel::bounded::<()>(1); + std::thread::spawn(move || { + std::thread::sleep(duration); + let _ = sender.send_blocking(()); + }); + + let _ = receiver.recv().await; + } } #[cfg(test)] diff --git a/testing/src/server_connector.rs b/testing/src/server_connector.rs index b272e6142b..916975d781 100644 --- a/testing/src/server_connector.rs +++ b/testing/src/server_connector.rs @@ -52,6 +52,10 @@ impl trillium_server_common::Connector for ServerConnector fn spawn + Send + 'static>(&self, fut: Fut) { crate::spawn(fut); } + + async fn delay(&self, duration: std::time::Duration) { + crate::delay(duration).await + } } /// build a connector from this handler diff --git a/tokio/src/client.rs b/tokio/src/client.rs index 5bea499d28..dbfabe0c2b 100644 --- a/tokio/src/client.rs +++ b/tokio/src/client.rs @@ -101,4 +101,8 @@ impl Connector for ClientConfig { fn spawn + Send + 'static>(&self, fut: Fut) { tokio::task::spawn(fut); } + + async fn delay(&self, duration: Duration) { + tokio::time::sleep(duration).await + } }