diff --git a/bindings/rust/s2n-tls/src/error.rs b/bindings/rust/s2n-tls/src/error.rs index 9cdc731aa45..97e5c493e1e 100644 --- a/bindings/rust/s2n-tls/src/error.rs +++ b/bindings/rust/s2n-tls/src/error.rs @@ -8,7 +8,7 @@ use s2n_tls_sys::*; use std::{convert::TryFrom, ffi::CStr}; #[non_exhaustive] -#[derive(Debug, PartialEq)] +#[derive(Copy, Clone, Debug, PartialEq)] pub enum ErrorType { UnknownErrorType, NoError, @@ -47,8 +47,7 @@ impl From for ErrorType { } enum Context { - InvalidInput, - MissingWaker, + Bindings(ErrorType, &'static str, &'static str), Code(s2n_status_code::Type, Errno), Application(Box), } @@ -149,8 +148,16 @@ impl Pollable for T { } impl Error { - pub(crate) const INVALID_INPUT: Error = Self(Context::InvalidInput); - pub(crate) const MISSING_WAKER: Error = Self(Context::MissingWaker); + pub(crate) const INVALID_INPUT: Error = Self::bindings( + ErrorType::UsageError, + "InvalidInput", + "An input parameter was incorrect", + ); + pub(crate) const MISSING_WAKER: Error = Self::bindings( + ErrorType::UsageError, + "MissingWaker", + "Tried to perform an asynchronous operation without a configured waker", + ); /// Converts an io::Error into an s2n-tls Error pub fn io_error(err: std::io::Error) -> Error { @@ -167,6 +174,15 @@ impl Error { Self(Context::Application(error)) } + /// An error occured while running bindings code. + pub(crate) const fn bindings( + kind: ErrorType, + name: &'static str, + message: &'static str, + ) -> Self { + Self(Context::Bindings(kind, name, message)) + } + fn capture() -> Self { unsafe { let s2n_errno = s2n_errno_location(); @@ -184,8 +200,7 @@ impl Error { pub fn name(&self) -> &'static str { match self.0 { - Context::InvalidInput => "InvalidInput", - Context::MissingWaker => "MissingWaker", + Context::Bindings(_, name, _) => name, Context::Application(_) => "ApplicationError", Context::Code(code, _) => unsafe { // Safety: we assume the string has a valid encoding coming from s2n @@ -196,10 +211,7 @@ impl Error { pub fn message(&self) -> &'static str { match self.0 { - Context::InvalidInput => "A parameter was incorrect", - Context::MissingWaker => { - "Tried to perform an asynchronous operation without a configured waker" - } + Context::Bindings(_, _, msg) => msg, Context::Application(_) => "An error occurred while executing application code", Context::Code(code, _) => unsafe { // Safety: we assume the string has a valid encoding coming from s2n @@ -210,7 +222,7 @@ impl Error { pub fn debug(&self) -> Option<&'static str> { match self.0 { - Context::InvalidInput | Context::MissingWaker | Context::Application(_) => None, + Context::Bindings(_, _, _) | Context::Application(_) => None, Context::Code(code, _) => unsafe { let debug_info = s2n_strerror_debug(code, core::ptr::null()); @@ -230,7 +242,7 @@ impl Error { pub fn kind(&self) -> ErrorType { match self.0 { - Context::InvalidInput | Context::MissingWaker => ErrorType::UsageError, + Context::Bindings(error_type, _, _) => error_type, Context::Application(_) => ErrorType::Application, Context::Code(code, _) => unsafe { ErrorType::from(s2n_error_get_type(code)) }, } @@ -238,7 +250,7 @@ impl Error { pub fn source(&self) -> ErrorSource { match self.0 { - Context::InvalidInput | Context::MissingWaker => ErrorSource::Bindings, + Context::Bindings(_, _, _) => ErrorSource::Bindings, Context::Application(_) => ErrorSource::Application, Context::Code(_, _) => ErrorSource::Library, } @@ -270,7 +282,7 @@ impl Error { /// This API is currently incomplete and should not be relied upon. pub fn alert(&self) -> Option { match self.0 { - Context::InvalidInput | Context::MissingWaker | Context::Application(_) => None, + Context::Bindings(_, _, _) | Context::Application(_) => None, Context::Code(code, _) => { let mut alert = 0; let r = unsafe { s2n_error_get_alert(code, &mut alert) }; @@ -465,4 +477,17 @@ mod tests { .unwrap(); } } + + #[test] + fn bindings_error() { + let name = "TestError"; + let message = "Custom error for test"; + let kind = ErrorType::InternalError; + let error = Error::bindings(kind, name, message); + assert_eq!(error.kind(), kind); + assert_eq!(error.name(), name); + assert_eq!(error.message(), message); + assert_eq!(error.debug(), None); + assert_eq!(error.source(), ErrorSource::Bindings); + } }