Skip to content

Commit 9227f3e

Browse files
killpop3770Vyacheslav Volkov
andauthored
Added the ability to work with rstest (#51)
* [FIX-CONFLICT-WITH-RSTEST-V2]: semi-stable version, fix test_context function, remove unnecessary field, remove wrapper function, add ability to work with rstest and same * [FIX-CONFLICT-WITH-RSTEST-V2]: stable version, fix bugs with AssertUnwindSafe in sync_wrapper_body, remove redundant code, fix two tests, add notes for this tests * [FIX-CONFLICT-WITH-RSTEST-V2]: stable version, fix context name extraction function, fix readme, fix docs, increase version of crate --------- Co-authored-by: Vyacheslav Volkov <[email protected]>
1 parent 5ce361b commit 9227f3e

File tree

6 files changed

+208
-56
lines changed

6 files changed

+208
-56
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ members = ["test-context", "test-context-macros"]
44

55
[workspace.package]
66
edition = "2021"
7-
version = "0.4.0"
7+
version = "0.5.0"
88
rust-version = "1.75.0"
99
homepage = "https://github.com/JasterV/test-context"
1010
repository = "https://github.com/JasterV/test-context"

README.md

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ enable the optional `tokio-runtime` feature so those steps run inside a Tokio ru
118118

119119
```toml
120120
[dependencies]
121-
test-context = { version = "0.4", features = ["tokio-runtime"] }
121+
test-context = { version = "0.5", features = ["tokio-runtime"] }
122122
```
123123

124124
With this feature, the crate tries to reuse an existing runtime; if none is present, it creates
@@ -127,7 +127,7 @@ tests annotated with `#[tokio::test]` continue to work as usual without the feat
127127

128128
## Skipping the teardown execution
129129

130-
If what you need is to take full **ownership** of the context and don't care about the
130+
Also, if you don't care about the
131131
teardown execution for a specific test, you can use the `skip_teardown` keyword on the macro
132132
like this:
133133

@@ -144,9 +144,73 @@ like this:
144144

145145
#[test_context(MyContext, skip_teardown)]
146146
#[test]
147-
fn test_without_teardown(ctx: MyContext) {
147+
fn test_without_teardown(ctx: &mut MyContext) {
148148
// Perform any operations that require full ownership of your context
149149
}
150150
```
151151

152+
## ⚠️ Ensure that the context type specified in the macro matches the test function argument type exactly
153+
154+
The error occurs when a context type with an absolute path is mixed with an it's alias.
155+
156+
For example:
157+
158+
```
159+
mod database {
160+
use test_context::TestContext;
161+
162+
pub struct Connection;
163+
164+
impl TestContext for :Connection {
165+
fn setup() -> Self {Connection}
166+
fn teardown(self) {...}
167+
}
168+
}
169+
```
170+
171+
✅The following code will work:
172+
```
173+
use database::Connection as DbConn;
174+
175+
#[test_context(DbConn)]
176+
#[test]
177+
fn test1(ctx: &mut DbConn) {
178+
//some test logic
179+
}
180+
181+
// or
182+
183+
use database::Connection
184+
185+
#[test_context(database::Connection)]
186+
#[test]
187+
fn test1(ctx: &mut database::Connection) {
188+
//some test logic
189+
}
190+
```
191+
192+
❌The following code will not work:
193+
```
194+
use database::Connection as DbConn;
195+
196+
#[test_context(database::Connection)]
197+
#[test]
198+
fn test1(ctx: &mut DbConn) {
199+
//some test logic
200+
}
201+
202+
// or
203+
204+
use database::Connection as DbConn;
205+
206+
#[test_context(DbConn)]
207+
#[test]
208+
fn test1(ctx: &mut database::Connection) {
209+
//some test logic
210+
}
211+
```
212+
213+
Type mismatches will cause context parsing to fail during either static analysis or compilation.
214+
215+
152216
License: MIT

test-context-macros/src/lib.rs

Lines changed: 76 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ mod args;
33
use args::TestContextArgs;
44
use proc_macro::TokenStream;
55
use quote::{format_ident, quote};
6-
use syn::Ident;
6+
use syn::{Block, Ident};
77

88
/// Macro to use on tests to add the setup/teardown functionality of your context.
99
///
@@ -30,55 +30,44 @@ pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
3030
let args = syn::parse_macro_input!(attr as TestContextArgs);
3131

3232
let input = syn::parse_macro_input!(item as syn::ItemFn);
33-
let ret = &input.sig.output;
34-
let name = &input.sig.ident;
35-
let arguments = &input.sig.inputs;
36-
let body = &input.block;
37-
let attrs = &input.attrs;
3833
let is_async = input.sig.asyncness.is_some();
3934

40-
let wrapped_name = format_ident!("__test_context_wrapped_{}", name);
35+
let (new_input, context_arg_name) =
36+
extract_and_remove_context_arg(input.clone(), args.context_type.clone());
4137

4238
let wrapper_body = if is_async {
43-
async_wrapper_body(args, &wrapped_name)
39+
async_wrapper_body(args, &context_arg_name, &input.block)
4440
} else {
45-
sync_wrapper_body(args, &wrapped_name)
41+
sync_wrapper_body(args, &context_arg_name, &input.block)
4642
};
4743

48-
let async_tag = if is_async {
49-
quote! { async }
50-
} else {
51-
quote! {}
52-
};
53-
54-
quote! {
55-
#(#attrs)*
56-
#async_tag fn #name() #ret #wrapper_body
44+
let mut result_input = new_input;
45+
result_input.block = Box::new(syn::parse2(wrapper_body).unwrap());
5746

58-
#async_tag fn #wrapped_name(#arguments) #ret #body
59-
}
60-
.into()
47+
quote! { #result_input }.into()
6148
}
6249

63-
fn async_wrapper_body(args: TestContextArgs, wrapped_name: &Ident) -> proc_macro2::TokenStream {
50+
fn async_wrapper_body(
51+
args: TestContextArgs,
52+
context_arg_name: &Option<syn::Ident>,
53+
body: &Block,
54+
) -> proc_macro2::TokenStream {
6455
let context_type = args.context_type;
6556
let result_name = format_ident!("wrapped_result");
6657

58+
let binding = format_ident!("test_ctx");
59+
let context_name = context_arg_name.as_ref().unwrap_or(&binding);
60+
6761
let body = if args.skip_teardown {
6862
quote! {
69-
let ctx = <#context_type as test_context::AsyncTestContext>::setup().await;
70-
let #result_name = std::panic::AssertUnwindSafe(
71-
#wrapped_name(ctx)
72-
).catch_unwind().await;
63+
let #context_name = <#context_type as test_context::AsyncTestContext>::setup().await;
64+
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
7365
}
7466
} else {
7567
quote! {
76-
let mut ctx = <#context_type as test_context::AsyncTestContext>::setup().await;
77-
let ctx_reference = &mut ctx;
78-
let #result_name = std::panic::AssertUnwindSafe(
79-
#wrapped_name(ctx_reference)
80-
).catch_unwind().await;
81-
<#context_type as test_context::AsyncTestContext>::teardown(ctx).await;
68+
let mut #context_name = <#context_type as test_context::AsyncTestContext>::setup().await;
69+
let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
70+
<#context_type as test_context::AsyncTestContext>::teardown(#context_name).await;
8271
}
8372
};
8473

@@ -93,25 +82,32 @@ fn async_wrapper_body(args: TestContextArgs, wrapped_name: &Ident) -> proc_macro
9382
}
9483
}
9584

96-
fn sync_wrapper_body(args: TestContextArgs, wrapped_name: &Ident) -> proc_macro2::TokenStream {
85+
fn sync_wrapper_body(
86+
args: TestContextArgs,
87+
context_arg_name: &Option<syn::Ident>,
88+
body: &Block,
89+
) -> proc_macro2::TokenStream {
9790
let context_type = args.context_type;
9891
let result_name = format_ident!("wrapped_result");
9992

93+
let binding = format_ident!("test_ctx");
94+
let context_name = context_arg_name.as_ref().unwrap_or(&binding);
95+
10096
let body = if args.skip_teardown {
10197
quote! {
102-
let ctx = <#context_type as test_context::TestContext>::setup();
103-
let #result_name = std::panic::catch_unwind(move || {
104-
#wrapped_name(ctx)
105-
});
98+
let mut #context_name = <#context_type as test_context::TestContext>::setup();
99+
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
100+
let #context_name = &mut #context_name;
101+
#body
102+
}));
106103
}
107104
} else {
108105
quote! {
109-
let mut ctx = <#context_type as test_context::TestContext>::setup();
110-
let mut pointer = std::panic::AssertUnwindSafe(&mut ctx);
111-
let #result_name = std::panic::catch_unwind(move || {
112-
#wrapped_name(*pointer)
113-
});
114-
<#context_type as test_context::TestContext>::teardown(ctx);
106+
let mut #context_name = <#context_type as test_context::TestContext>::setup();
107+
let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
108+
#body
109+
}));
110+
<#context_type as test_context::TestContext>::teardown(#context_name);
115111
}
116112
};
117113

@@ -135,3 +131,40 @@ fn handle_result(result_name: Ident) -> proc_macro2::TokenStream {
135131
}
136132
}
137133
}
134+
135+
fn extract_and_remove_context_arg(
136+
mut input: syn::ItemFn,
137+
expected_context_type: syn::Type,
138+
) -> (syn::ItemFn, Option<syn::Ident>) {
139+
let mut context_arg_name = None;
140+
let mut new_args = syn::punctuated::Punctuated::new();
141+
142+
for arg in &input.sig.inputs {
143+
// Extract function arg:
144+
if let syn::FnArg::Typed(pat_type) = arg {
145+
// Extract arg identifier:
146+
if let syn::Pat::Ident(pat_ident) = &*pat_type.pat {
147+
// Check that context arg is only ref or mutable ref:
148+
if let syn::Type::Reference(type_ref) = &*pat_type.ty {
149+
// Check that context has expected type:
150+
if types_equal(&type_ref.elem, &expected_context_type) {
151+
context_arg_name = Some(pat_ident.ident.clone());
152+
continue;
153+
}
154+
}
155+
}
156+
}
157+
new_args.push(arg.clone());
158+
}
159+
160+
input.sig.inputs = new_args;
161+
(input, context_arg_name)
162+
}
163+
164+
fn types_equal(a: &syn::Type, b: &syn::Type) -> bool {
165+
if let (syn::Type::Path(a_path), syn::Type::Path(b_path)) = (a, b) {
166+
return a_path.path.segments.last().unwrap().ident
167+
== b_path.path.segments.last().unwrap().ident;
168+
}
169+
quote!(#a).to_string() == quote!(#b).to_string()
170+
}

test-context/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ authors.workspace = true
1313
license.workspace = true
1414

1515
[dependencies]
16-
test-context-macros = { version = "0.4.0", path = "../test-context-macros/" }
16+
test-context-macros = { version = "0.5.0", path = "../test-context-macros/" }
1717
futures = "0.3"
1818

1919
[dev-dependencies]
20+
rstest = "0.26.1"
2021
tokio = { version = "1.0", features = ["macros", "rt"] }

test-context/src/lib.rs

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,8 @@
103103
//!
104104
//! # Skipping the teardown execution
105105
//!
106-
//! If what you need is to take full __ownership__ of the context and don't care about the
107-
//! teardown execution for a specific test, you can use the `skip_teardown` keyword on the macro
108-
//! like this:
106+
//! Also, if you don't care about the teardown execution for a specific test,
107+
//! you can use the `skip_teardown` keyword on the macro like this:
109108
//!
110109
//! ```no_run
111110
//! use test_context::{test_context, TestContext};
@@ -120,7 +119,7 @@
120119
//!
121120
//! #[test_context(MyContext, skip_teardown)]
122121
//! #[test]
123-
//! fn test_without_teardown(ctx: MyContext) {
122+
//! fn test_without_teardown(ctx: &mut MyContext) {
124123
//! // Perform any operations that require full ownership of your context
125124
//! }
126125
//! ```

test-context/tests/test.rs

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use std::marker::PhantomData;
22

3+
use rstest::rstest;
34
use test_context::{test_context, AsyncTestContext, TestContext};
45

56
struct Context {
@@ -165,11 +166,11 @@ impl AsyncTestContext for TeardownPanicContext {
165166

166167
#[test_context(TeardownPanicContext, skip_teardown)]
167168
#[tokio::test]
168-
async fn test_async_skip_teardown(mut _ctx: TeardownPanicContext) {}
169+
async fn test_async_skip_teardown(_ctx: &mut TeardownPanicContext) {}
169170

170171
#[test_context(TeardownPanicContext, skip_teardown)]
171172
#[test]
172-
fn test_sync_skip_teardown(mut _ctx: TeardownPanicContext) {}
173+
fn test_sync_skip_teardown(_ctx: &mut TeardownPanicContext) {}
173174

174175
struct GenericContext<T> {
175176
contents: T,
@@ -209,6 +210,60 @@ fn test_generic_with_string(ctx: &mut GenericContext<String>) {
209210

210211
#[test_context(GenericContext<u64>)]
211212
#[tokio::test]
212-
async fn test_async_generic(ctx: &mut GenericContext<u64>) {
213-
assert_eq!(ctx.contents, 1);
213+
async fn test_async_generic(test_ctx: &mut GenericContext<u64>) {
214+
assert_eq!(test_ctx.contents, 1);
215+
}
216+
217+
struct MyAsyncContext {
218+
what_the_of_life: u32,
219+
}
220+
221+
impl AsyncTestContext for MyAsyncContext {
222+
async fn setup() -> Self {
223+
println!("I guess...");
224+
MyAsyncContext {
225+
what_the_of_life: 42,
226+
}
227+
}
228+
229+
async fn teardown(self) {
230+
println!("Answer is {}", self.what_the_of_life);
231+
drop(self);
232+
}
233+
}
234+
235+
#[test_context(MyAsyncContext)]
236+
#[rstest]
237+
#[case("Hello, World!")]
238+
#[tokio::test]
239+
async fn test_async_generic_with_sync(#[case] value: String, test_ctx: &mut MyAsyncContext) {
240+
println!("Something happens sync... {}", value);
241+
assert_eq!(test_ctx.what_the_of_life, 42);
242+
}
243+
244+
struct MyContext {
245+
what_the_of_life: u32,
246+
}
247+
248+
impl TestContext for MyContext {
249+
fn setup() -> Self {
250+
println!("I guess...");
251+
MyContext {
252+
what_the_of_life: 42,
253+
}
254+
}
255+
256+
fn teardown(self) {
257+
println!("Answer is {}", self.what_the_of_life);
258+
drop(self);
259+
}
260+
}
261+
262+
#[test_context(MyContext)]
263+
#[rstest]
264+
#[case("Hello, World!")]
265+
#[test]
266+
fn test_async_generic_with_async(test_ctx: &mut MyContext, #[case] value: String) {
267+
println!("Something happens async... {}", value);
268+
assert_eq!(test_ctx.what_the_of_life, 42);
214269
}

0 commit comments

Comments
 (0)