Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add peer certificates for mtls to Warp #1108

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@
Cargo.lock
.idea/
warp.iml

*.swp
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ percent-encoding = "2.1"
pin-project = "1.0"
tokio-rustls = { version = "0.25", optional = true }
rustls-pemfile = { version = "2.0", optional = true }
rustls-pki-types = { version = "1.9.0", optional = true }

[dev-dependencies]
pretty_env_logger = "0.5"
Expand All @@ -56,7 +57,7 @@ listenfd = "1.0"
default = ["multipart", "websocket"]
multipart = ["multer"]
websocket = ["tokio-tungstenite"]
tls = ["tokio-rustls", "rustls-pemfile"]
tls = ["tokio-rustls", "rustls-pemfile", "rustls-pki-types"]

# Enable compression-related filters
compression = ["compression-brotli", "compression-gzip"]
Expand Down
10 changes: 5 additions & 5 deletions src/filter/service.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::convert::Infallible;
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};

Expand All @@ -11,6 +10,7 @@ use pin_project::pin_project;
use crate::reject::IsReject;
use crate::reply::{Reply, Response};
use crate::route::{self, Route};
use crate::transport::PeerInfo;
use crate::{Filter, Request};

/// Convert a `Filter` into a `Service`.
Expand Down Expand Up @@ -70,14 +70,14 @@ where
<F::Future as TryFuture>::Error: IsReject,
{
#[inline]
pub(crate) fn call_with_addr(
pub(crate) fn call_with_peer_info(
&self,
req: Request,
remote_addr: Option<SocketAddr>,
peer_info: PeerInfo,
) -> FilteredFuture<F::Future> {
debug_assert!(!route::is_set(), "nested route::set calls");

let route = Route::new(req, remote_addr);
let route = Route::new(req, peer_info);
let fut = route::set(&route, || self.filter.filter(super::Internal));
FilteredFuture { future: fut, route }
}
Expand All @@ -99,7 +99,7 @@ where

#[inline]
fn call(&mut self, req: Request) -> Self::Future {
self.call_with_addr(req, None)
self.call_with_peer_info(req, Default::default())
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/filters/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ pub mod path;
pub mod query;
pub mod reply;
pub mod sse;
#[cfg(feature = "tls")]
pub mod mtls;
pub mod trace;
#[cfg(feature = "websocket")]
pub mod ws;
Expand Down
52 changes: 52 additions & 0 deletions src/filters/mtls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//! Mutual (client) TLS filters.

use std::convert::Infallible;

use rustls_pki_types::CertificateDer;

use crate::{
filter::{filter_fn_one, Filter},
route::Route,
};

/// Certificates is a iterable container of Certificates.
pub type Certificates = Vec<CertificateDer<'static>>;

/// Creates a `Filter` to get the peer certificates for the TLS connection.
///
/// If the underlying transport doesn't have peer certificates, this will yield
/// `None`.
///
/// # Example
///
/// ```
/// use warp::mtls::Certificates;
/// use warp::Filter;
///
/// let route = warp::mtls::peer_certificates()
/// .map(|certs: Option<Certificates>| {
/// println!("peer certificates = {:?}", certs.as_ref());
/// });
/// ```
pub fn peer_certificates(
) -> impl Filter<Extract = (Option<Certificates>,), Error = Infallible> + Copy {
filter_fn_one(|route| futures_util::future::ok(from_route(route)))
}

/// Testing
pub fn peer_certs_into_owned(certs: &Vec<CertificateDer<'_>>) -> Vec<CertificateDer<'static>> {
certs
.to_vec()
.iter()
.map(|cert| cert.clone().into_owned())
.collect()
}

fn from_route(route: &Route) -> Option<Certificates> {
route
.peer_certificates()
.read()
.unwrap()
.as_ref()
.map(peer_certs_into_owned)
}
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ pub use self::filters::compression;
#[cfg(feature = "multipart")]
#[doc(hidden)]
pub use self::filters::multipart;
#[cfg(feature = "tls")]
#[doc(hidden)]
pub use self::filters::mtls;
#[cfg(feature = "websocket")]
#[doc(hidden)]
pub use self::filters::ws;
Expand Down
14 changes: 10 additions & 4 deletions src/route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::net::SocketAddr;
use hyper::Body;

use crate::Request;
use crate::transport::PeerInfo;

scoped_thread_local!(static ROUTE: RefCell<Route>);

Expand All @@ -30,7 +31,7 @@ where
#[derive(Debug)]
pub(crate) struct Route {
body: BodyState,
remote_addr: Option<SocketAddr>,
peer_info: PeerInfo,
req: Request,
segments_index: usize,
}
Expand All @@ -42,7 +43,7 @@ enum BodyState {
}

impl Route {
pub(crate) fn new(req: Request, remote_addr: Option<SocketAddr>) -> RefCell<Route> {
pub(crate) fn new(req: Request, peer_info: PeerInfo) -> RefCell<Route> {
let segments_index = if req.uri().path().starts_with('/') {
// Skip the beginning slash.
1
Expand All @@ -52,7 +53,7 @@ impl Route {

RefCell::new(Route {
body: BodyState::Ready,
remote_addr,
peer_info,
req,
segments_index,
})
Expand Down Expand Up @@ -124,7 +125,12 @@ impl Route {
}

pub(crate) fn remote_addr(&self) -> Option<SocketAddr> {
self.remote_addr
self.peer_info.remote_addr
}

#[cfg(feature = "tls")]
pub(crate) fn peer_certificates(&self) -> crate::transport::PeerCertificates {
self.peer_info.peer_certificates.clone()
}

pub(crate) fn take_body(&mut self) -> Option<Body> {
Expand Down
9 changes: 7 additions & 2 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,14 @@ macro_rules! into_service {
let inner = crate::service($into);
make_service_fn(move |transport| {
let inner = inner.clone();
let remote_addr = Transport::remote_addr(transport);

let peer_info = crate::transport::PeerInfo {
remote_addr: Transport::remote_addr(transport),
peer_certificates: Transport::peer_certificates(transport),
};

future::ok::<_, Infallible>(service_fn(move |req| {
inner.call_with_addr(req, remote_addr)
inner.call_with_peer_info(req, peer_info.clone())
}))
})
}};
Expand Down
33 changes: 24 additions & 9 deletions src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,18 +113,19 @@ use crate::filters::ws::Message;
use crate::reject::IsReject;
use crate::reply::Reply;
use crate::route::{self, Route};
use crate::transport::PeerInfo;
use crate::Request;
#[cfg(feature = "websocket")]
use crate::{Sink, Stream};

#[cfg(feature = "tls")]
use crate::filters::mtls::Certificates;

use self::inner::OneOrTuple;

/// Starts a new test `RequestBuilder`.
pub fn request() -> RequestBuilder {
RequestBuilder {
remote_addr: None,
req: Request::default(),
}
Default::default()
}

/// Starts a new test `WsBuilder`.
Expand All @@ -137,9 +138,9 @@ pub fn ws() -> WsBuilder {
///
/// See [module documentation](crate::test) for an overview.
#[must_use = "RequestBuilder does nothing on its own"]
#[derive(Debug)]
#[derive(Debug, Default)]
pub struct RequestBuilder {
remote_addr: Option<SocketAddr>,
peer_info: PeerInfo,
req: Request,
}

Expand Down Expand Up @@ -248,7 +249,21 @@ impl RequestBuilder {
/// .remote_addr(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080));
/// ```
pub fn remote_addr(mut self, addr: SocketAddr) -> Self {
self.remote_addr = Some(addr);
self.peer_info.remote_addr = Some(addr);
self
}

/// Set the peer certificates of this request.
/// Default is no peer certificates.
///
/// # Example
/// ```
/// let req = warp::test::request()
/// .peer_certificates([rustls_pki_types::CertificateDer::from_slice(b"FAKE CERT")]);
/// ```
#[cfg(feature = "tls")]
pub fn peer_certificates(self, certs: impl Into<Certificates>) -> Self {
*self.peer_info.peer_certificates.write().unwrap() = Some(certs.into());
self
}

Expand Down Expand Up @@ -375,7 +390,7 @@ impl RequestBuilder {
// TODO: de-duplicate this and apply_filter()
assert!(!route::is_set(), "nested test filter calls");

let route = Route::new(self.req, self.remote_addr);
let route = Route::new(self.req, self.peer_info);
let mut fut = Box::pin(
route::set(&route, move || f.filter(crate::filter::Internal)).then(|result| {
let res = match result {
Expand Down Expand Up @@ -404,7 +419,7 @@ impl RequestBuilder {
{
assert!(!route::is_set(), "nested test filter calls");

let route = Route::new(self.req, self.remote_addr);
let route = Route::new(self.req, self.peer_info);
let mut fut = Box::pin(route::set(&route, move || {
f.filter(crate::filter::Internal)
}));
Expand Down
14 changes: 13 additions & 1 deletion src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ use hyper::server::conn::{AddrIncoming, AddrStream};
use tokio_rustls::rustls::server::WebPkiClientVerifier;
use tokio_rustls::rustls::{Error as TlsError, RootCertStore, ServerConfig};

use crate::transport::Transport;
use crate::filters::mtls::peer_certs_into_owned;
use crate::transport::{PeerCertificates, Transport};

/// Represents errors that can occur building the TlsConfig
#[derive(Debug)]
Expand Down Expand Up @@ -284,6 +285,10 @@ impl Transport for TlsStream {
fn remote_addr(&self) -> Option<SocketAddr> {
Some(self.remote_addr)
}

fn peer_certificates(&self) -> PeerCertificates {
self.peer_certs.clone()
}
}

enum State {
Expand All @@ -297,6 +302,7 @@ enum State {
pub(crate) struct TlsStream {
state: State,
remote_addr: SocketAddr,
peer_certs: PeerCertificates,
}

impl TlsStream {
Expand All @@ -306,6 +312,7 @@ impl TlsStream {
TlsStream {
state: State::Handshaking(accept),
remote_addr,
peer_certs: Default::default(),
}
}
}
Expand All @@ -320,6 +327,11 @@ impl AsyncRead for TlsStream {
match pin.state {
State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) {
Ok(mut stream) => {
let (_, conn) = stream.get_ref();
*pin.peer_certs.write().unwrap() = conn
.peer_certificates()
.map(|certs| peer_certs_into_owned(&certs.to_vec()));

let result = Pin::new(&mut stream).poll_read(cx, buf);
pin.state = State::Streaming(stream);
result
Expand Down
19 changes: 19 additions & 0 deletions src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,27 @@ use std::task::{Context, Poll};
use hyper::server::conn::AddrStream;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

#[cfg(feature = "tls")]
use crate::filters::mtls::Certificates;

#[cfg(feature = "tls")]
pub(crate) type PeerCertificates = std::sync::Arc<std::sync::RwLock<Option<Certificates>>>;
#[cfg(not(feature = "tls"))]
pub(crate) type PeerCertificates = ();

pub trait Transport: AsyncRead + AsyncWrite {
fn remote_addr(&self) -> Option<SocketAddr>;

fn peer_certificates(&self) -> PeerCertificates {
Default::default()
}
}

#[derive(Clone, Debug, Default)]
pub(crate) struct PeerInfo {
pub remote_addr: Option<SocketAddr>,
#[allow(dead_code)]
pub peer_certificates: PeerCertificates,
}

impl Transport for AddrStream {
Expand Down
24 changes: 24 additions & 0 deletions tests/mtls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#![deny(warnings)]
#![cfg(feature = "tls")]

use rustls_pki_types::CertificateDer;

#[tokio::test]
async fn peer_certificates_missing() {
let extract_peer_certs = warp::mtls::peer_certificates();

let req = warp::test::request();
let resp = req.filter(&extract_peer_certs).await.unwrap();
assert!(resp.is_none())
}

#[tokio::test]
async fn peer_certificates_present() {
let extract_peer_certs = warp::mtls::peer_certificates();

let cert = CertificateDer::<'_>::from_slice(b"TEST CERT");

let req = warp::test::request().peer_certificates([cert.clone()]);
let resp = req.filter(&extract_peer_certs).await.unwrap();
assert_eq!(resp.unwrap(), &[cert],)
}