diff --git a/serial_test/src/lib.rs b/serial_test/src/lib.rs index b12d518..a4930ab 100644 --- a/serial_test/src/lib.rs +++ b/serial_test/src/lib.rs @@ -28,8 +28,7 @@ lazy_static! { Arc::new(RwLock::new(HashMap::new())); } -#[doc(hidden)] -pub fn serial_core(name: &str, function: fn()) { +fn check_new_key(name: &str) { // Check if a new key is needed. Just need a read lock, which can be done in sync with everyone else let new_key = { let unlock = LOCKS.read().unwrap(); @@ -43,6 +42,22 @@ pub fn serial_core(name: &str, function: fn()) { .deref_mut() .insert(name.to_string(), ReentrantMutex::new(())); } +} + +#[doc(hidden)] +pub fn serial_core_with_return(name: &str, function: fn() -> Result<(), E>) -> Result<(), E> { + check_new_key(name); + + let unlock = LOCKS.read().unwrap(); + // _guard needs to be named to avoid being instant dropped + let _guard = unlock.deref()[name].lock(); + function() +} + +#[doc(hidden)] +pub fn serial_core(name: &str, function: fn()) { + check_new_key(name); + let unlock = LOCKS.read().unwrap(); // _guard needs to be named to avoid being instant dropped let _guard = unlock.deref()[name].lock(); diff --git a/serial_test_derive/src/lib.rs b/serial_test_derive/src/lib.rs index e938e2d..932a601 100644 --- a/serial_test_derive/src/lib.rs +++ b/serial_test_derive/src/lib.rs @@ -7,6 +7,7 @@ use proc_macro::TokenStream; use proc_macro2::TokenTree; use quote::quote; use syn; +use std::ops::Deref; /// Allows for the creation of serialised Rust tests /// ```` @@ -78,6 +79,12 @@ fn serial_core( }; let ast: syn::ItemFn = syn::parse2(input).unwrap(); let name = ast.sig.ident; + let return_type = match ast.sig.output { + syn::ReturnType::Default => None, + syn::ReturnType::Type(_rarrow, ref box_type) => { + Some(box_type.deref()) + } + }; let block = ast.block; let attrs: Vec = ast .attrs @@ -96,13 +103,25 @@ fn serial_core( } }) .collect(); - let gen = quote! { - #(#attrs) - * - fn #name () { - serial_test::serial_core(#key, || { - #block - }); + let gen = if let Some(ret) = return_type { + quote! { + #(#attrs) + * + fn #name () -> #ret { + serial_test::serial_core_with_return(#key, || { + #block + }) + } + } + } else { + quote! { + #(#attrs) + * + fn #name () { + serial_test::serial_core(#key, || { + #block + }); + } } }; return gen.into(); diff --git a/serial_test_test/src/lib.rs b/serial_test_test/src/lib.rs index 1d51941..074c925 100644 --- a/serial_test_test/src/lib.rs +++ b/serial_test_test/src/lib.rs @@ -76,4 +76,10 @@ mod tests { init(); panic!("Testing panic"); } + + #[test] + #[serial] + fn test_can_return() -> Result<(), ()> { + Ok(()) + } }