Skip to content

Commit

Permalink
feat(client)!: add support for client timeouts
Browse files Browse the repository at this point in the history
  • Loading branch information
jbr committed Apr 6, 2024
1 parent e4b3a48 commit e80c3bb
Show file tree
Hide file tree
Showing 15 changed files with 208 additions and 15 deletions.
4 changes: 4 additions & 0 deletions async-std/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,8 @@ impl Connector for ClientConfig {
fn spawn<Fut: Future<Output = ()> + Send + 'static>(&self, fut: Fut) {
async_std::task::spawn(fut);
}

async fn delay(&self, duration: std::time::Duration) {
let _ = async_std::future::timeout(duration, std::future::pending::<()>()).await;
}
}
30 changes: 24 additions & 6 deletions client/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{Conn, IntoUrl, Pool, USER_AGENT};
use std::{fmt::Debug, sync::Arc};
use std::{fmt::Debug, sync::Arc, time::Duration};
use trillium_http::{
transport::BoxedTransport, HeaderName, HeaderValues, Headers, KnownHeaderName, Method,
ReceivedBodyState,
Expand All @@ -20,6 +20,7 @@ pub struct Client {
pool: Option<Pool<Origin, BoxedTransport>>,
base: Option<Arc<Url>>,
default_headers: Arc<Headers>,
timeout: Option<Duration>,
}

macro_rules! method {
Expand Down Expand Up @@ -74,9 +75,16 @@ impl Client {
pool: None,
base: None,
default_headers: Arc::new(default_request_headers()),
timeout: None,
}
}

method!(get, Get);
method!(post, Post);
method!(put, Put);
method!(delete, Delete);
method!(patch, Patch);

/// chainable method to remove a header from default request headers
pub fn without_default_header(mut self, name: impl Into<HeaderName<'static>>) -> Self {
self.default_headers_mut().remove(name);
Expand Down Expand Up @@ -160,6 +168,7 @@ impl Client {
response_body_state: ReceivedBodyState::Start,
config: self.config.clone(),
headers_finalized: false,
timeout: self.timeout,
}
}

Expand Down Expand Up @@ -209,11 +218,20 @@ impl Client {
Ok(())
}

method!(get, Get);
method!(post, Post);
method!(put, Put);
method!(delete, Delete);
method!(patch, Patch);
/// set the timeout for all conns this client builds
///
/// this can also be set with [`Conn::set_timeout`] and [`Conn::with_timeout`]
pub fn set_timeout(&mut self, timeout: Duration) {
self.timeout = Some(timeout);
}

/// set the timeout for all conns this client builds
///
/// this can also be set with [`Conn::set_timeout`] and [`Conn::with_timeout`]
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.set_timeout(timeout);
self
}
}

impl<T: Connector> From<T> for Client {
Expand Down
33 changes: 30 additions & 3 deletions client/src/conn.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::{pool::PoolEntry, util::encoding, Pool};
use encoding_rs::Encoding;
use futures_lite::{future::poll_once, io, AsyncReadExt, AsyncWriteExt};
use futures_lite::{future::poll_once, io, AsyncReadExt, AsyncWriteExt, FutureExt};
use memchr::memmem::Finder;
use size::{Base, Size};
use std::{
Expand All @@ -10,6 +10,7 @@ use std::{
ops::{Deref, DerefMut},
pin::Pin,
str::FromStr,
time::Duration,
};
use trillium_http::{
transport::BoxedTransport,
Expand Down Expand Up @@ -60,6 +61,7 @@ pub struct Conn {
pub(crate) response_body_state: ReceivedBodyState,
pub(crate) config: ArcedConnector,
pub(crate) headers_finalized: bool,
pub(crate) timeout: Option<Duration>,
}

/// default http user-agent header
Expand Down Expand Up @@ -492,6 +494,21 @@ impl Conn {
.and_then(|t| t.peer_addr().ok().flatten())
}

/// set the timeout for this conn
///
/// this can also be set on the client with [`Client::set_timeout`] and [`Client::with_timeout`]
pub fn set_timeout(&mut self, timeout: Duration) {
self.timeout = Some(timeout);
}

/// set the timeout for this conn
///
/// this can also be set on the client with [`Client::set_timeout`] and [`Client::with_timeout`]
pub fn with_timeout(mut self, timeout: Duration) -> Self {
self.set_timeout(timeout);
self
}

// --- everything below here is private ---

fn finalize_headers(&mut self) -> Result<()> {
Expand Down Expand Up @@ -566,7 +583,7 @@ impl Conn {
}

None => {
let mut transport = Connector::connect(&self.config, &self.url).await?;
let mut transport = self.config.connect(&self.url).await?;
log::debug!("opened new connection to {:?}", transport.peer_addr()?);
transport.write_all(&head).await?;
transport
Expand Down Expand Up @@ -895,7 +912,17 @@ impl IntoFuture for Conn {

fn into_future(mut self) -> Self::IntoFuture {
Box::pin(async move {
self.exec().await?;
if let Some(duration) = self.timeout {
let config = self.config.clone();
self.exec()
.or(async {
config.delay(duration).await;
Err(Error::TimedOut("Conn", duration))
})
.await?
} else {
self.exec().await?;
}
Ok(self)
})
}
Expand Down
8 changes: 6 additions & 2 deletions client/tests/one_hundred_continue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use async_channel::Sender;
use futures_lite::future;
use indoc::{formatdoc, indoc};
use pretty_assertions::assert_eq;
use std::future::Future;
use std::future::{Future, IntoFuture};
use test_harness::test;
use trillium_client::{Client, Conn, Error, Status, USER_AGENT};
use trillium_server_common::{Connector, Url};
Expand Down Expand Up @@ -229,14 +229,18 @@ impl Connector for TestConnector {
fn spawn<Fut: Future<Output = ()> + Send + 'static>(&self, fut: Fut) {
let _ = trillium_testing::spawn(fut);
}

async fn delay(&self, duration: std::time::Duration) {
trillium_testing::delay(duration).await
}
}

async fn test_conn(
setup: impl FnOnce(Client) -> Conn + Send + 'static,
) -> (TestTransport, impl Future<Output = Result<Conn, Error>>) {
let (sender, receiver) = async_channel::unbounded();
let client = Client::new(TestConnector(sender));
let conn_fut = trillium_testing::spawn(async move { setup(client).await });
let conn_fut = trillium_testing::spawn(setup(client).into_future());
let transport = receiver.recv().await.unwrap();
(transport, async move { conn_fut.await.unwrap() })
}
46 changes: 46 additions & 0 deletions client/tests/timeout.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
use std::time::Duration;
use trillium_client::Client;
use trillium_testing::ClientConfig;

async fn handler(conn: trillium::Conn) -> trillium::Conn {
if conn.path() == "/slow" {
trillium_testing::delay(Duration::from_secs(5)).await;
}
conn.ok("ok")
}

#[test]
fn timeout_on_conn() {
trillium_testing::with_server(handler, move |url| async move {
let client = Client::new(ClientConfig::new()).with_base(url);
let err = client
.get("/slow")
.with_timeout(Duration::from_millis(100))
.await
.unwrap_err();

assert_eq!(err.to_string(), "Conn took longer than 100ms");

assert!(client
.get("/")
.with_timeout(Duration::from_millis(100))
.await
.is_ok());

Ok(())
})
}

#[test]
fn timeout_on_client() {
trillium_testing::with_server(handler, move |url| async move {
let client = Client::new(ClientConfig::new())
.with_base(url)
.with_timeout(Duration::from_millis(100));
let err = client.get("/slow").await.unwrap_err();
assert_eq!(err.to_string(), "Conn took longer than 100ms");

assert!(client.get("/").await.is_ok());
Ok(())
})
}
6 changes: 5 additions & 1 deletion http/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{HeaderName, Version};
use std::{num::TryFromIntError, str::Utf8Error};
use std::{num::TryFromIntError, str::Utf8Error, time::Duration};
use thiserror::Error;

/// Concrete errors that occur within trillium's HTTP implementation
Expand Down Expand Up @@ -86,6 +86,10 @@ pub enum Error {
/// implementation on `ReceivedBody`
#[error("Received body too long. Maximum {0} bytes")]
ReceivedBodyTooLong(u64),

/// something took longer than was allowed
#[error("{0} took longer than {1:?}")]
TimedOut(&'static str, Duration),
}

/// this crate's result type
Expand Down
4 changes: 4 additions & 0 deletions http/tests/use_cases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,8 @@ where
fn spawn<SpawnFut: Future<Output = ()> + Send + 'static>(&self, fut: SpawnFut) {
trillium_testing::spawn(fut);
}

async fn delay(&self, duration: std::time::Duration) {
trillium_testing::delay(duration).await
}
}
4 changes: 4 additions & 0 deletions native-tls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ impl<T: Connector> Connector for NativeTlsConfig<T> {
fn spawn<Fut: Future<Output = ()> + Send + 'static>(&self, fut: Fut) {
self.tcp_config.spawn(fut)
}

async fn delay(&self, duration: std::time::Duration) {
self.tcp_config.delay(duration).await
}
}

/**
Expand Down
4 changes: 4 additions & 0 deletions rustls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ impl<C: Connector> Connector for RustlsConfig<C> {
fn spawn<Fut: Future<Output = ()> + Send + 'static>(&self, fut: Fut) {
self.tcp_config.spawn(fut)
}

async fn delay(&self, duration: std::time::Duration) {
self.tcp_config.delay(duration).await
}
}

#[derive(Debug)]
Expand Down
32 changes: 29 additions & 3 deletions server-common/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use std::{
io,
pin::Pin,
sync::Arc,
time::Duration,
};
/**
Interface for runtime and tls adapters for the trillium client
Expand All @@ -23,8 +24,11 @@ pub trait Connector: Send + Sync + 'static {
/// Initiate a connection to the provided url
fn connect(&self, url: &Url) -> impl Future<Output = io::Result<Self::Transport>> + Send;

/// spwan a future on the runtime
/// spawn and detach a future on the runtime
fn spawn<Fut: Future<Output = ()> + Send + 'static>(&self, fut: Fut);

/// wake in this amount of wall time
fn delay(&self, duration: Duration) -> impl Future<Output = ()> + Send;
}

/// An Arced and type-erased [`Connector`]
Expand Down Expand Up @@ -73,9 +77,17 @@ trait ObjectSafeConnector: Send + Sync + 'static {
'url: 'fut,
Self: 'fut;
fn spawn(&self, fut: Pin<Box<dyn Future<Output = ()> + Send + 'static>>);
fn delay<'connector, 'fut>(
&'connector self,
duration: Duration,
) -> Pin<Box<dyn Future<Output = ()> + Send + 'fut>>
where
'connector: 'fut,
Self: 'fut;
fn as_any(&self) -> &dyn Any;
fn as_mut_any(&mut self) -> &mut dyn Any;
}

impl<T: Connector> ObjectSafeConnector for T {
fn connect<'connector, 'url, 'fut>(
&'connector self,
Expand All @@ -86,10 +98,10 @@ impl<T: Connector> ObjectSafeConnector for T {
'url: 'fut,
Self: 'fut,
{
Box::pin(async move { T::connect(self, url).await.map(BoxedTransport::new) })
Box::pin(async move { Connector::connect(self, url).await.map(BoxedTransport::new) })
}
fn spawn(&self, fut: Pin<Box<dyn Future<Output = ()> + Send + 'static>>) {
T::spawn(self, fut)
Connector::spawn(self, fut)
}

fn as_any(&self) -> &dyn Any {
Expand All @@ -98,6 +110,16 @@ impl<T: Connector> ObjectSafeConnector for T {
fn as_mut_any(&mut self) -> &mut dyn Any {
self
}
fn delay<'connector, 'fut>(
&'connector self,
duration: Duration,
) -> Pin<Box<dyn Future<Output = ()> + Send + 'fut>>
where
'connector: 'fut,
Self: 'fut,
{
Box::pin(async move { Connector::delay(self, duration).await })
}
}

impl Connector for ArcedConnector {
Expand All @@ -109,4 +131,8 @@ impl Connector for ArcedConnector {
fn spawn<Fut: Future<Output = ()> + Send + 'static>(&self, fut: Fut) {
self.0.spawn(Box::pin(fut))
}

async fn delay(&self, duration: Duration) {
self.0.delay(duration).await
}
}
4 changes: 4 additions & 0 deletions smol/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,8 @@ impl Connector for ClientConfig {
fn spawn<Fut: Future<Output = ()> + Send + 'static>(&self, fut: Fut) {
async_global_executor::spawn(fut).detach();
}

async fn delay(&self, duration: std::time::Duration) {
async_io::Timer::after(duration).await;
}
}
Loading

0 comments on commit e80c3bb

Please sign in to comment.