Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ fn test_addition_commutative(tc: TestCase) {
This test will fail! Integer addition panics on overflow. Hegel will produce a minimal failing test case for us:

```
Draw 1: 1
Draw 2: 2147483647
let x = 1;
let y = 2147483647;
thread 'test_addition_commutative' (2) panicked at examples/readme.rs:8:16:
attempt to add with overflow
```
Expand Down
4 changes: 4 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
RELEASE_TYPE: patch

This release should significantly improve the format and quality of output
printing for failing test cases.
2 changes: 1 addition & 1 deletion hegel-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ proc-macro = true
[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
syn = { version = "2.0", features = ["full"] }
syn = { version = "2.0", features = ["full", "visit-mut"] }
48 changes: 48 additions & 0 deletions hegel-macros/src/explicit_test_case.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use proc_macro2::TokenStream;
use syn::ItemFn;

/// Check if an attribute's path is `hegel::test`.
fn has_hegel_test_attr(func: &ItemFn) -> bool {
func.attrs.iter().any(|attr| {
let segments: Vec<_> = attr.path().segments.iter().collect();
segments.len() == 2 && segments[0].ident == "hegel" && segments[1].ident == "test"
})
}

/// This macro always produces a compile error when it actually runs.
///
/// In correct usage (`#[hegel::test]` above, `#[hegel::explicit_test_case]` below),
/// `#[hegel::test]` processes first and consumes the explicit_test_case attributes
/// directly from `func.attrs`, so this macro never executes.
///
/// If this macro DOES execute, it means either:
/// - Wrong order: `#[hegel::explicit_test_case]` is above `#[hegel::test]`
/// - Bare function: no `#[hegel::test]` at all
pub fn expand_explicit_test_case(_attr: TokenStream, item: TokenStream) -> TokenStream {
let func: ItemFn = match syn::parse2(item) {
Ok(f) => f,
Err(e) => return e.to_compile_error(),
};

if has_hegel_test_attr(&func) {
// #[hegel::test] is below us, meaning we're in the wrong order.
// (If it were above us, it would have consumed our attribute before we ran.)
syn::Error::new_spanned(
&func.sig,
"#[hegel::explicit_test_case] must appear below #[hegel::test], not above it.\n\
Write:\n \
#[hegel::test]\n \
#[hegel::explicit_test_case(...)]\n \
fn my_test(tc: hegel::TestCase) { ... }",
)
.to_compile_error()
} else {
// No #[hegel::test] at all.
syn::Error::new_spanned(
&func.sig,
"#[hegel::explicit_test_case] can only be used together with #[hegel::test].\n\
Add #[hegel::test] above #[hegel::explicit_test_case].",
)
.to_compile_error()
}
}
292 changes: 289 additions & 3 deletions hegel-macros/src/hegel_test.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::collections::HashMap;

use proc_macro2::TokenStream;
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::{Expr, FnArg, Ident, ItemFn, Token};
use syn::visit_mut::VisitMut;
use syn::{Expr, FnArg, Ident, ItemFn, Pat, Token};

/// A single named argument in a `#[hegel::test(...)]` expression.
struct SettingArg {
Expand Down Expand Up @@ -59,6 +62,219 @@ impl Parse for TestArgs {
}
}

/// Extract a simple identifier from a pattern, handling type annotations.
fn extract_ident_from_pat(pat: &Pat) -> Option<String> {
match pat {
Pat::Ident(pat_ident) => Some(pat_ident.ident.to_string()),
Pat::Type(pat_type) => extract_ident_from_pat(&pat_type.pat),
_ => None,
}
}

/// Check if a `let` binding is of the form `let <ident> = <tc_ident>.draw(<one_arg>)`.
fn is_tc_draw_binding(node: &syn::Local, tc_ident: &str) -> Option<String> {
let var_name = extract_ident_from_pat(&node.pat)?;

let init = node.init.as_ref()?;
let method_call = match &*init.expr {
Expr::MethodCall(mc) => mc,
_ => return None,
};

if method_call.method != "draw" || method_call.args.len() != 1 {
return None;
}

let is_tc = match &*method_call.receiver {
Expr::Path(path) => path.path.is_ident(tc_ident),
_ => false,
};
if !is_tc {
return None;
}

Some(var_name)
}

/// Pass 1: Collect all draw variable names and determine per-name repeatable flags.
///
/// If any use of a name appears in a repeatable context (nested block, closure),
/// ALL uses of that name become repeatable. This ensures the runtime never sees
/// inconsistent repeatable flags for the same name.
struct DrawNameCollector {
tc_ident: String,
repeatable_depth: usize,
name_flags: HashMap<String, bool>,
}

impl VisitMut for DrawNameCollector {
fn visit_block_mut(&mut self, node: &mut syn::Block) {
self.repeatable_depth += 1;
syn::visit_mut::visit_block_mut(self, node);
self.repeatable_depth -= 1;
}

fn visit_expr_closure_mut(&mut self, node: &mut syn::ExprClosure) {
self.repeatable_depth += 1;
syn::visit_mut::visit_expr_closure_mut(self, node);
self.repeatable_depth -= 1;
}

fn visit_item_fn_mut(&mut self, _node: &mut syn::ItemFn) {}

fn visit_local_mut(&mut self, node: &mut syn::Local) {
syn::visit_mut::visit_local_mut(self, node);

if let Some(var_name) = is_tc_draw_binding(node, &self.tc_ident) {
let repeatable = self.repeatable_depth > 0;
let entry = self.name_flags.entry(var_name).or_insert(false);
if repeatable {
*entry = true;
}
}
}
}

/// Pass 2: Rewrite `let x = tc.draw(gen)` to `let x = tc.draw_named(gen, "x", repeatable)`.
///
/// Uses the pre-computed name_flags from DrawNameCollector so that every use of
/// a given name gets the same repeatable flag.
struct DrawRewriter {
tc_ident: String,
name_flags: HashMap<String, bool>,
}

impl VisitMut for DrawRewriter {
fn visit_item_fn_mut(&mut self, _node: &mut syn::ItemFn) {}

fn visit_local_mut(&mut self, node: &mut syn::Local) {
syn::visit_mut::visit_local_mut(self, node);

let var_name = match is_tc_draw_binding(node, &self.tc_ident) {
Some(name) => name,
None => return,
};

let repeatable = self.name_flags.get(&var_name).copied().unwrap_or(false);

let init = node.init.as_mut().unwrap();
let method_call = match &mut *init.expr {
Expr::MethodCall(mc) => mc,
_ => unreachable!(),
};

let span = method_call.method.span();
method_call.method = Ident::new("draw_named", span);
method_call.args.push(Expr::Lit(syn::ExprLit {
attrs: vec![],
lit: syn::Lit::Str(syn::LitStr::new(&var_name, span)),
}));
method_call.args.push(Expr::Lit(syn::ExprLit {
attrs: vec![],
lit: syn::Lit::Bool(syn::LitBool::new(repeatable, span)),
}));
}
}

/// A parsed explicit test case: a list of (name, expression_source) pairs.
struct ParsedExplicitTestCase {
entries: Vec<(String, String)>, // (name, expr_source)
}

/// Check if an attribute path matches `hegel::explicit_test_case`.
fn is_explicit_test_case_attr(attr: &syn::Attribute) -> bool {
let segments: Vec<_> = attr.path().segments.iter().collect();
segments.len() == 2 && segments[0].ident == "hegel" && segments[1].ident == "explicit_test_case"
}

/// Extract `#[hegel::explicit_test_case(...)]` attributes directly from `func.attrs`.
/// Returns the parsed test cases and removes the attributes from the list.
/// Returns `Err` with a compile error if any attribute is malformed.
fn extract_explicit_test_cases(
attrs: &mut Vec<syn::Attribute>,
) -> Result<Vec<ParsedExplicitTestCase>, TokenStream> {
let mut cases = Vec::new();
let mut error = None;
attrs.retain(|attr| {
if !is_explicit_test_case_attr(attr) {
return true;
}

let syn::Meta::List(list) = &attr.meta else {
error = Some(
syn::Error::new_spanned(
attr,
"#[hegel::explicit_test_case] requires arguments.\n\
Usage: #[hegel::explicit_test_case(name = value, ...)]",
)
.to_compile_error(),
);
return false;
};

let parsed: syn::Result<ExplicitTestCaseAttrArgs> = syn::parse2(list.tokens.clone());
match parsed {
Ok(args) if args.entries.is_empty() => {
error = Some(
syn::Error::new_spanned(
attr,
"#[hegel::explicit_test_case] requires at least one name = value pair.\n\
Usage: #[hegel::explicit_test_case(name = value, ...)]",
)
.to_compile_error(),
);
}
Ok(args) => {
let entries = args
.entries
.iter()
.map(|arg| {
let name = arg.name.to_string();
let expr = &arg.value;
let expr_source = quote::quote!(#expr).to_string();
(name, expr_source)
})
.collect();
cases.push(ParsedExplicitTestCase { entries });
}
Err(e) => {
error = Some(e.to_compile_error());
}
}
false // remove this attr
});
if let Some(err) = error {
return Err(err);
}
Ok(cases)
}

/// Parsed arguments for a single `#[hegel::explicit_test_case(name = expr, ...)]`.
struct ExplicitTestCaseAttrArgs {
entries: Vec<ExplicitTestCaseEntry>,
}

struct ExplicitTestCaseEntry {
name: Ident,
value: Expr,
}

impl syn::parse::Parse for ExplicitTestCaseAttrArgs {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let mut entries = Vec::new();
while !input.is_empty() {
let name: Ident = input.parse()?;
let _eq: Token![=] = input.parse()?;
let value: Expr = input.parse()?;
entries.push(ExplicitTestCaseEntry { name, value });
if !input.is_empty() {
let _comma: Token![,] = input.parse()?;
}
}
Ok(ExplicitTestCaseAttrArgs { entries })
}
}

pub fn expand_test(attr: proc_macro2::TokenStream, item: proc_macro2::TokenStream) -> TokenStream {
let test_args: TestArgs = if attr.is_empty() {
TestArgs {
Expand All @@ -72,7 +288,7 @@ pub fn expand_test(attr: proc_macro2::TokenStream, item: proc_macro2::TokenStrea
}
};

let func: ItemFn = match syn::parse2(item) {
let mut func: ItemFn = match syn::parse2(item) {
Ok(f) => f,
Err(e) => return e.to_compile_error(),
};
Expand Down Expand Up @@ -110,7 +326,47 @@ pub fn expand_test(attr: proc_macro2::TokenStream, item: proc_macro2::TokenStrea
}
}

let body = &func.block;
// Extract #[hegel::explicit_test_case(...)] attributes (they haven't been
// processed yet because #[hegel::test] runs first as the outermost attribute).
let explicit_cases = match extract_explicit_test_cases(&mut func.attrs) {
Ok(cases) => cases,
Err(err) => return err,
};

// Rewrite `let x = tc.draw(gen)` -> `let x = tc.draw_named(gen, "x", repeatable)`
//
// Two-pass approach:
// 1. Collect all draw variable names and determine per-name repeatable flags.
// If any use of a name is in a nested block/closure, all uses are repeatable.
// 2. Rewrite draws using the computed flags.
//
// We visit the function body's statements directly (not the block itself) so that
// the outermost block doesn't count as a nesting level.
let body = {
let mut body = (*func.block).clone();
if let Some(tc_name) = extract_ident_from_pat(param_pat) {
// Pass 1: collect names
let mut collector = DrawNameCollector {
tc_ident: tc_name.clone(),
repeatable_depth: 0,
name_flags: HashMap::new(),
};
for stmt in &mut body.stmts {
collector.visit_stmt_mut(stmt);
}

// Pass 2: rewrite
let mut rewriter = DrawRewriter {
tc_ident: tc_name,
name_flags: collector.name_flags,
};
for stmt in &mut body.stmts {
rewriter.visit_stmt_mut(stmt);
}
}
body
};

let test_name = func.sig.ident.to_string();

let settings_args_chain: Vec<TokenStream> = test_args
Expand All @@ -129,8 +385,38 @@ pub fn expand_test(attr: proc_macro2::TokenStream, item: proc_macro2::TokenStrea
None => quote! { hegel::Settings::new() #(#settings_args_chain)* },
};

// Generate explicit test case blocks (run before the property test).
let explicit_blocks: Vec<TokenStream> = explicit_cases
.iter()
.map(|case| {
let with_value_calls: Vec<TokenStream> = case
.entries
.iter()
.map(|(name, expr_source)| {
let expr: syn::Expr = syn::parse_str(expr_source).unwrap_or_else(|e| {
panic!("Failed to parse explicit_test_case expression: {}", e)
});
let source_lit = syn::LitStr::new(expr_source, proc_macro2::Span::call_site());
quote! {
.with_value(#name, #source_lit, #expr)
}
})
.collect();

quote! {
{
let __hegel_etc = hegel::ExplicitTestCase::new()
#(#with_value_calls)*;
__hegel_etc.run(|#param_pat: &hegel::ExplicitTestCase| #body);
}
}
})
.collect();

let new_body: TokenStream = quote! {
{
#(#explicit_blocks)*

hegel::Hegel::new(|#param_pat: #param_ty| #body)
.settings(#settings_expr)
.__database_key(format!("{}::{}", module_path!(), #test_name))
Expand Down
Loading