@@ -3,7 +3,7 @@ mod args;
33use args:: TestContextArgs ;
44use proc_macro:: TokenStream ;
55use quote:: { format_ident, quote} ;
6- use syn:: Ident ;
6+ use syn:: { Block , Ident } ;
77
88/// Macro to use on tests to add the setup/teardown functionality of your context.
99///
@@ -30,55 +30,44 @@ pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
3030 let args = syn:: parse_macro_input!( attr as TestContextArgs ) ;
3131
3232 let input = syn:: parse_macro_input!( item as syn:: ItemFn ) ;
33- let ret = & input. sig . output ;
34- let name = & input. sig . ident ;
35- let arguments = & input. sig . inputs ;
36- let body = & input. block ;
37- let attrs = & input. attrs ;
3833 let is_async = input. sig . asyncness . is_some ( ) ;
3934
40- let wrapped_name = format_ident ! ( "__test_context_wrapped_{}" , name) ;
35+ let ( new_input, context_arg_name) =
36+ extract_and_remove_context_arg ( input. clone ( ) , args. context_type . clone ( ) ) ;
4137
4238 let wrapper_body = if is_async {
43- async_wrapper_body ( args, & wrapped_name )
39+ async_wrapper_body ( args, & context_arg_name , & input . block )
4440 } else {
45- sync_wrapper_body ( args, & wrapped_name )
41+ sync_wrapper_body ( args, & context_arg_name , & input . block )
4642 } ;
4743
48- let async_tag = if is_async {
49- quote ! { async }
50- } else {
51- quote ! { }
52- } ;
53-
54- quote ! {
55- #( #attrs) *
56- #async_tag fn #name( ) #ret #wrapper_body
44+ let mut result_input = new_input;
45+ result_input. block = Box :: new ( syn:: parse2 ( wrapper_body) . unwrap ( ) ) ;
5746
58- #async_tag fn #wrapped_name( #arguments) #ret #body
59- }
60- . into ( )
47+ quote ! { #result_input } . into ( )
6148}
6249
63- fn async_wrapper_body ( args : TestContextArgs , wrapped_name : & Ident ) -> proc_macro2:: TokenStream {
50+ fn async_wrapper_body (
51+ args : TestContextArgs ,
52+ context_arg_name : & Option < syn:: Ident > ,
53+ body : & Block ,
54+ ) -> proc_macro2:: TokenStream {
6455 let context_type = args. context_type ;
6556 let result_name = format_ident ! ( "wrapped_result" ) ;
6657
58+ let binding = format_ident ! ( "test_ctx" ) ;
59+ let context_name = context_arg_name. as_ref ( ) . unwrap_or ( & binding) ;
60+
6761 let body = if args. skip_teardown {
6862 quote ! {
69- let ctx = <#context_type as test_context:: AsyncTestContext >:: setup( ) . await ;
70- let #result_name = std:: panic:: AssertUnwindSafe (
71- #wrapped_name( ctx)
72- ) . catch_unwind( ) . await ;
63+ let #context_name = <#context_type as test_context:: AsyncTestContext >:: setup( ) . await ;
64+ let #result_name = std:: panic:: AssertUnwindSafe ( async { #body } ) . catch_unwind( ) . await ;
7365 }
7466 } else {
7567 quote ! {
76- let mut ctx = <#context_type as test_context:: AsyncTestContext >:: setup( ) . await ;
77- let ctx_reference = & mut ctx;
78- let #result_name = std:: panic:: AssertUnwindSafe (
79- #wrapped_name( ctx_reference)
80- ) . catch_unwind( ) . await ;
81- <#context_type as test_context:: AsyncTestContext >:: teardown( ctx) . await ;
68+ let mut #context_name = <#context_type as test_context:: AsyncTestContext >:: setup( ) . await ;
69+ let #result_name = std:: panic:: AssertUnwindSafe ( async { #body } ) . catch_unwind( ) . await ;
70+ <#context_type as test_context:: AsyncTestContext >:: teardown( #context_name) . await ;
8271 }
8372 } ;
8473
@@ -93,25 +82,32 @@ fn async_wrapper_body(args: TestContextArgs, wrapped_name: &Ident) -> proc_macro
9382 }
9483}
9584
96- fn sync_wrapper_body ( args : TestContextArgs , wrapped_name : & Ident ) -> proc_macro2:: TokenStream {
85+ fn sync_wrapper_body (
86+ args : TestContextArgs ,
87+ context_arg_name : & Option < syn:: Ident > ,
88+ body : & Block ,
89+ ) -> proc_macro2:: TokenStream {
9790 let context_type = args. context_type ;
9891 let result_name = format_ident ! ( "wrapped_result" ) ;
9992
93+ let binding = format_ident ! ( "test_ctx" ) ;
94+ let context_name = context_arg_name. as_ref ( ) . unwrap_or ( & binding) ;
95+
10096 let body = if args. skip_teardown {
10197 quote ! {
102- let ctx = <#context_type as test_context:: TestContext >:: setup( ) ;
103- let #result_name = std:: panic:: catch_unwind( move || {
104- #wrapped_name( ctx)
105- } ) ;
98+ let mut #context_name = <#context_type as test_context:: TestContext >:: setup( ) ;
99+ let #result_name = std:: panic:: catch_unwind( std:: panic:: AssertUnwindSafe ( || {
100+ let #context_name = & mut #context_name;
101+ #body
102+ } ) ) ;
106103 }
107104 } else {
108105 quote ! {
109- let mut ctx = <#context_type as test_context:: TestContext >:: setup( ) ;
110- let mut pointer = std:: panic:: AssertUnwindSafe ( & mut ctx) ;
111- let #result_name = std:: panic:: catch_unwind( move || {
112- #wrapped_name( * pointer)
113- } ) ;
114- <#context_type as test_context:: TestContext >:: teardown( ctx) ;
106+ let mut #context_name = <#context_type as test_context:: TestContext >:: setup( ) ;
107+ let #result_name = std:: panic:: catch_unwind( std:: panic:: AssertUnwindSafe ( || {
108+ #body
109+ } ) ) ;
110+ <#context_type as test_context:: TestContext >:: teardown( #context_name) ;
115111 }
116112 } ;
117113
@@ -135,3 +131,40 @@ fn handle_result(result_name: Ident) -> proc_macro2::TokenStream {
135131 }
136132 }
137133}
134+
135+ fn extract_and_remove_context_arg (
136+ mut input : syn:: ItemFn ,
137+ expected_context_type : syn:: Type ,
138+ ) -> ( syn:: ItemFn , Option < syn:: Ident > ) {
139+ let mut context_arg_name = None ;
140+ let mut new_args = syn:: punctuated:: Punctuated :: new ( ) ;
141+
142+ for arg in & input. sig . inputs {
143+ // Extract function arg:
144+ if let syn:: FnArg :: Typed ( pat_type) = arg {
145+ // Extract arg identifier:
146+ if let syn:: Pat :: Ident ( pat_ident) = & * pat_type. pat {
147+ // Check that context arg is only ref or mutable ref:
148+ if let syn:: Type :: Reference ( type_ref) = & * pat_type. ty {
149+ // Check that context has expected type:
150+ if types_equal ( & type_ref. elem , & expected_context_type) {
151+ context_arg_name = Some ( pat_ident. ident . clone ( ) ) ;
152+ continue ;
153+ }
154+ }
155+ }
156+ }
157+ new_args. push ( arg. clone ( ) ) ;
158+ }
159+
160+ input. sig . inputs = new_args;
161+ ( input, context_arg_name)
162+ }
163+
164+ fn types_equal ( a : & syn:: Type , b : & syn:: Type ) -> bool {
165+ if let ( syn:: Type :: Path ( a_path) , syn:: Type :: Path ( b_path) ) = ( a, b) {
166+ return a_path. path . segments . last ( ) . unwrap ( ) . ident
167+ == b_path. path . segments . last ( ) . unwrap ( ) . ident ;
168+ }
169+ quote ! ( #a) . to_string ( ) == quote ! ( #b) . to_string ( )
170+ }
0 commit comments