Skip to content

Commit dc2992c

Browse files
authored
feat: add support for the skip_teardown keyword (#40)
* feat: add support for the skip_teardown keyword
1 parent 8fd7841 commit dc2992c

File tree

5 files changed

+179
-29
lines changed

5 files changed

+179
-29
lines changed

README.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,4 +69,28 @@ async fn test_works(ctx: &mut MyAsyncContext) {
6969
}
7070
```
7171

72+
## Skipping the teardown execution
73+
74+
If what you need is to take full **ownership** of the context and don't care about the
75+
teardown execution for a specific test, you can use the `skip_teardown` keyword on the macro
76+
like this:
77+
78+
```rust
79+
use test_context::{test_context, TestContext};
80+
81+
struct MyContext {}
82+
83+
impl TestContext for MyContext {
84+
fn setup() -> MyContext {
85+
MyContext {}
86+
}
87+
}
88+
89+
#[test_context(MyContext, skip_teardown)]
90+
#[test]
91+
fn test_without_teardown(ctx: MyContext) {
92+
// Perform any operations that require full ownership of your context
93+
}
94+
```
95+
7296
License: MIT

test-context-macros/src/args.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
use syn::{parse::Parse, Ident, Token};
2+
3+
pub(crate) struct TestContextArgs {
4+
pub(crate) context_type: Ident,
5+
pub(crate) skip_teardown: bool,
6+
}
7+
8+
impl Parse for TestContextArgs {
9+
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
10+
let mut skip_teardown = false;
11+
let mut context_type: Option<Ident> = None;
12+
13+
while !input.is_empty() {
14+
let lookahead = input.lookahead1();
15+
if lookahead.peek(kw::skip_teardown) {
16+
if skip_teardown {
17+
return Err(input.error("expected only a single `skip_teardown` argument"));
18+
}
19+
let _ = input.parse::<kw::skip_teardown>()?;
20+
skip_teardown = true;
21+
} else if lookahead.peek(Ident) {
22+
if context_type.is_some() {
23+
return Err(input.error("expected only a single type identifier"));
24+
}
25+
context_type = Some(input.parse()?);
26+
} else if lookahead.peek(Token![,]) {
27+
let _ = input.parse::<Token![,]>()?;
28+
} else {
29+
return Err(lookahead.error());
30+
}
31+
}
32+
33+
Ok(TestContextArgs {
34+
context_type: context_type
35+
.ok_or(input.error("expected at least one type identifier"))?,
36+
skip_teardown,
37+
})
38+
}
39+
}
40+
41+
mod kw {
42+
syn::custom_keyword!(skip_teardown);
43+
}

test-context-macros/src/lib.rs

Lines changed: 68 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
mod args;
2+
3+
use args::TestContextArgs;
14
use proc_macro::TokenStream;
25
use quote::{format_ident, quote};
36
use syn::Ident;
@@ -24,7 +27,8 @@ use syn::Ident;
2427
/// ```
2528
#[proc_macro_attribute]
2629
pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
27-
let context_type = syn::parse_macro_input!(attr as syn::Ident);
30+
let args = syn::parse_macro_input!(attr as TestContextArgs);
31+
2832
let input = syn::parse_macro_input!(item as syn::ItemFn);
2933
let ret = &input.sig.output;
3034
let name = &input.sig.ident;
@@ -36,9 +40,9 @@ pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
3640
let wrapped_name = format_ident!("__test_context_wrapped_{}", name);
3741

3842
let wrapper_body = if is_async {
39-
async_wrapper_body(context_type, &wrapped_name)
43+
async_wrapper_body(args, &wrapped_name)
4044
} else {
41-
sync_wrapper_body(context_type, &wrapped_name)
45+
sync_wrapper_body(args, &wrapped_name)
4246
};
4347

4448
let async_tag = if is_async {
@@ -56,42 +60,77 @@ pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
5660
.into()
5761
}
5862

59-
fn async_wrapper_body(context_type: Ident, wrapped_name: &Ident) -> proc_macro2::TokenStream {
63+
fn async_wrapper_body(args: TestContextArgs, wrapped_name: &Ident) -> proc_macro2::TokenStream {
64+
let context_type = args.context_type;
65+
let result_name = format_ident!("wrapped_result");
66+
67+
let body = if args.skip_teardown {
68+
quote! {
69+
let ctx = <#context_type as test_context::AsyncTestContext>::setup().await;
70+
let #result_name = std::panic::AssertUnwindSafe(
71+
#wrapped_name(ctx)
72+
).catch_unwind().await;
73+
}
74+
} else {
75+
quote! {
76+
let mut ctx = <#context_type as test_context::AsyncTestContext>::setup().await;
77+
let ctx_reference = &mut ctx;
78+
let #result_name = std::panic::AssertUnwindSafe(
79+
#wrapped_name(ctx_reference)
80+
).catch_unwind().await;
81+
<#context_type as test_context::AsyncTestContext>::teardown(ctx).await;
82+
}
83+
};
84+
85+
let handle_wrapped_result = handle_result(result_name);
86+
6087
quote! {
6188
{
6289
use test_context::futures::FutureExt;
63-
let mut ctx = <#context_type as test_context::AsyncTestContext>::setup().await;
64-
let wrapped_ctx = &mut ctx;
65-
let result = async move {
66-
std::panic::AssertUnwindSafe(
67-
#wrapped_name(wrapped_ctx)
68-
).catch_unwind().await
69-
}.await;
70-
<#context_type as test_context::AsyncTestContext>::teardown(ctx).await;
71-
match result {
72-
Ok(returned_value) => returned_value,
73-
Err(err) => {
74-
std::panic::resume_unwind(err);
75-
}
76-
}
90+
#body
91+
#handle_wrapped_result
7792
}
7893
}
7994
}
8095

81-
fn sync_wrapper_body(context_type: Ident, wrapped_name: &Ident) -> proc_macro2::TokenStream {
82-
quote! {
83-
{
96+
fn sync_wrapper_body(args: TestContextArgs, wrapped_name: &Ident) -> proc_macro2::TokenStream {
97+
let context_type = args.context_type;
98+
let result_name = format_ident!("wrapped_result");
99+
100+
let body = if args.skip_teardown {
101+
quote! {
102+
let ctx = <#context_type as test_context::TestContext>::setup();
103+
let #result_name = std::panic::catch_unwind(move || {
104+
#wrapped_name(ctx)
105+
});
106+
}
107+
} else {
108+
quote! {
84109
let mut ctx = <#context_type as test_context::TestContext>::setup();
85-
let mut wrapper = std::panic::AssertUnwindSafe(&mut ctx);
86-
let result = std::panic::catch_unwind(move || {
87-
#wrapped_name(*wrapper)
110+
let mut pointer = std::panic::AssertUnwindSafe(&mut ctx);
111+
let #result_name = std::panic::catch_unwind(move || {
112+
#wrapped_name(*pointer)
88113
});
89114
<#context_type as test_context::TestContext>::teardown(ctx);
90-
match result {
91-
Ok(returned_value) => returned_value,
92-
Err(err) => {
93-
std::panic::resume_unwind(err);
94-
}
115+
}
116+
};
117+
118+
let handle_wrapped_result = handle_result(result_name);
119+
120+
quote! {
121+
{
122+
#body
123+
#handle_wrapped_result
124+
}
125+
}
126+
}
127+
128+
fn handle_result(result_name: Ident) -> proc_macro2::TokenStream {
129+
quote! {
130+
match #result_name {
131+
Ok(value) => value,
132+
Err(err) => {
133+
std::panic::resume_unwind(err);
95134
}
96135
}
97136
}

test-context/src/lib.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,30 @@
7777
//! assert_eq!(ctx.value, "Hello, World!");
7878
//! }
7979
//! ```
80+
//!
81+
//! # Skipping the teardown execution
82+
//!
83+
//! If what you need is to take full __ownership__ of the context and don't care about the
84+
//! teardown execution for a specific test, you can use the `skip_teardown` keyword on the macro
85+
//! like this:
86+
//!
87+
//! ```no_run
88+
//! use test_context::{test_context, TestContext};
89+
//!
90+
//! struct MyContext {}
91+
//!
92+
//! impl TestContext for MyContext {
93+
//! fn setup() -> MyContext {
94+
//! MyContext {}
95+
//! }
96+
//! }
97+
//!
98+
//! #[test_context(MyContext, skip_teardown)]
99+
//! #[test]
100+
//! fn test_without_teardown(ctx: MyContext) {
101+
//! // Perform any operations that require full ownership of your context
102+
//! }
103+
//! ```
80104
81105
// Reimported to allow for use in the macro.
82106
pub use futures;

test-context/tests/test.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,23 @@ fn use_different_name(test_data: &mut Context) {
111111
async fn use_different_name_async(test_data: &mut AsyncContext) {
112112
assert_eq!(test_data.n, 1);
113113
}
114+
115+
struct TeardownPanicContext {}
116+
117+
impl AsyncTestContext for TeardownPanicContext {
118+
async fn setup() -> Self {
119+
Self {}
120+
}
121+
122+
async fn teardown(self) {
123+
panic!("boom!");
124+
}
125+
}
126+
127+
#[test_context(TeardownPanicContext, skip_teardown)]
128+
#[tokio::test]
129+
async fn test_async_skip_teardown(mut _ctx: TeardownPanicContext) {}
130+
131+
#[test_context(TeardownPanicContext, skip_teardown)]
132+
#[test]
133+
fn test_sync_skip_teardown(mut _ctx: TeardownPanicContext) {}

0 commit comments

Comments
 (0)