1+ mod args;
2+
3+ use args:: TestContextArgs ;
14use proc_macro:: TokenStream ;
25use quote:: { format_ident, quote} ;
36use syn:: Ident ;
@@ -24,7 +27,8 @@ use syn::Ident;
2427/// ```
2528#[ proc_macro_attribute]
2629pub fn test_context ( attr : TokenStream , item : TokenStream ) -> TokenStream {
27- let context_type = syn:: parse_macro_input!( attr as syn:: Ident ) ;
30+ let args = syn:: parse_macro_input!( attr as TestContextArgs ) ;
31+
2832 let input = syn:: parse_macro_input!( item as syn:: ItemFn ) ;
2933 let ret = & input. sig . output ;
3034 let name = & input. sig . ident ;
@@ -36,9 +40,9 @@ pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
3640 let wrapped_name = format_ident ! ( "__test_context_wrapped_{}" , name) ;
3741
3842 let wrapper_body = if is_async {
39- async_wrapper_body ( context_type , & wrapped_name)
43+ async_wrapper_body ( args , & wrapped_name)
4044 } else {
41- sync_wrapper_body ( context_type , & wrapped_name)
45+ sync_wrapper_body ( args , & wrapped_name)
4246 } ;
4347
4448 let async_tag = if is_async {
@@ -56,42 +60,77 @@ pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
5660 . into ( )
5761}
5862
59- fn async_wrapper_body ( context_type : Ident , wrapped_name : & Ident ) -> proc_macro2:: TokenStream {
63+ fn async_wrapper_body ( args : TestContextArgs , wrapped_name : & Ident ) -> proc_macro2:: TokenStream {
64+ let context_type = args. context_type ;
65+ let result_name = format_ident ! ( "wrapped_result" ) ;
66+
67+ let body = if args. skip_teardown {
68+ quote ! {
69+ let ctx = <#context_type as test_context:: AsyncTestContext >:: setup( ) . await ;
70+ let #result_name = std:: panic:: AssertUnwindSafe (
71+ #wrapped_name( ctx)
72+ ) . catch_unwind( ) . await ;
73+ }
74+ } else {
75+ quote ! {
76+ let mut ctx = <#context_type as test_context:: AsyncTestContext >:: setup( ) . await ;
77+ let ctx_reference = & mut ctx;
78+ let #result_name = std:: panic:: AssertUnwindSafe (
79+ #wrapped_name( ctx_reference)
80+ ) . catch_unwind( ) . await ;
81+ <#context_type as test_context:: AsyncTestContext >:: teardown( ctx) . await ;
82+ }
83+ } ;
84+
85+ let handle_wrapped_result = handle_result ( result_name) ;
86+
6087 quote ! {
6188 {
6289 use test_context:: futures:: FutureExt ;
63- let mut ctx = <#context_type as test_context:: AsyncTestContext >:: setup( ) . await ;
64- let wrapped_ctx = & mut ctx;
65- let result = async move {
66- std:: panic:: AssertUnwindSafe (
67- #wrapped_name( wrapped_ctx)
68- ) . catch_unwind( ) . await
69- } . await ;
70- <#context_type as test_context:: AsyncTestContext >:: teardown( ctx) . await ;
71- match result {
72- Ok ( returned_value) => returned_value,
73- Err ( err) => {
74- std:: panic:: resume_unwind( err) ;
75- }
76- }
90+ #body
91+ #handle_wrapped_result
7792 }
7893 }
7994}
8095
81- fn sync_wrapper_body ( context_type : Ident , wrapped_name : & Ident ) -> proc_macro2:: TokenStream {
82- quote ! {
83- {
96+ fn sync_wrapper_body ( args : TestContextArgs , wrapped_name : & Ident ) -> proc_macro2:: TokenStream {
97+ let context_type = args. context_type ;
98+ let result_name = format_ident ! ( "wrapped_result" ) ;
99+
100+ let body = if args. skip_teardown {
101+ quote ! {
102+ let ctx = <#context_type as test_context:: TestContext >:: setup( ) ;
103+ let #result_name = std:: panic:: catch_unwind( move || {
104+ #wrapped_name( ctx)
105+ } ) ;
106+ }
107+ } else {
108+ quote ! {
84109 let mut ctx = <#context_type as test_context:: TestContext >:: setup( ) ;
85- let mut wrapper = std:: panic:: AssertUnwindSafe ( & mut ctx) ;
86- let result = std:: panic:: catch_unwind( move || {
87- #wrapped_name( * wrapper )
110+ let mut pointer = std:: panic:: AssertUnwindSafe ( & mut ctx) ;
111+ let #result_name = std:: panic:: catch_unwind( move || {
112+ #wrapped_name( * pointer )
88113 } ) ;
89114 <#context_type as test_context:: TestContext >:: teardown( ctx) ;
90- match result {
91- Ok ( returned_value) => returned_value,
92- Err ( err) => {
93- std:: panic:: resume_unwind( err) ;
94- }
115+ }
116+ } ;
117+
118+ let handle_wrapped_result = handle_result ( result_name) ;
119+
120+ quote ! {
121+ {
122+ #body
123+ #handle_wrapped_result
124+ }
125+ }
126+ }
127+
128+ fn handle_result ( result_name : Ident ) -> proc_macro2:: TokenStream {
129+ quote ! {
130+ match #result_name {
131+ Ok ( value) => value,
132+ Err ( err) => {
133+ std:: panic:: resume_unwind( err) ;
95134 }
96135 }
97136 }
0 commit comments