Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ fn test_addition_commutative(tc: TestCase) {
This test will fail! Integer addition panics on overflow. Hegel will produce a minimal failing test case for us:

```
Draw 1: 1
Draw 2: 2147483647
let x = 1;
let y = 2147483647;
thread 'test_addition_commutative' (2) panicked at examples/readme.rs:8:16:
attempt to add with overflow
```
Expand Down
4 changes: 4 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
RELEASE_TYPE: patch

This release should significantly improve the format and quality of output
printing for failing test cases.
2 changes: 1 addition & 1 deletion hegel-macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ proc-macro = true
[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
syn = { version = "2.0", features = ["full"] }
syn = { version = "2.0", features = ["full", "visit-mut"] }
154 changes: 152 additions & 2 deletions hegel-macros/src/hegel_test.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::collections::HashMap;

use proc_macro2::TokenStream;
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::{Expr, FnArg, Ident, ItemFn, Token};
use syn::visit_mut::VisitMut;
use syn::{Expr, FnArg, Ident, ItemFn, Pat, Token};

/// A single named argument in a `#[hegel::test(...)]` expression.
struct SettingArg {
Expand Down Expand Up @@ -59,6 +62,120 @@ impl Parse for TestArgs {
}
}

/// Extract a simple identifier from a pattern, handling type annotations.
fn extract_ident_from_pat(pat: &Pat) -> Option<String> {
match pat {
Pat::Ident(pat_ident) => Some(pat_ident.ident.to_string()),
Pat::Type(pat_type) => extract_ident_from_pat(&pat_type.pat),
_ => None,
}
}

/// Check if a `let` binding is of the form `let <ident> = <tc_ident>.draw(<one_arg>)`.
fn is_tc_draw_binding(node: &syn::Local, tc_ident: &str) -> Option<String> {
let var_name = extract_ident_from_pat(&node.pat)?;

let init = node.init.as_ref()?;
let method_call = match &*init.expr {
Expr::MethodCall(mc) => mc,
_ => return None,
};

if method_call.method != "draw" || method_call.args.len() != 1 {
return None;
}

let is_tc = match &*method_call.receiver {
Expr::Path(path) => path.path.is_ident(tc_ident),
_ => false,
};
if !is_tc {
return None;
}

Some(var_name)
}

/// Pass 1: Collect all draw variable names and determine per-name repeatable flags.
///
/// If any use of a name appears in a repeatable context (nested block, closure),
/// ALL uses of that name become repeatable. This ensures the runtime never sees
/// inconsistent repeatable flags for the same name.
struct DrawNameCollector {
tc_ident: String,
repeatable_depth: usize,
name_flags: HashMap<String, bool>,
}

impl VisitMut for DrawNameCollector {
fn visit_block_mut(&mut self, node: &mut syn::Block) {
self.repeatable_depth += 1;
syn::visit_mut::visit_block_mut(self, node);
self.repeatable_depth -= 1;
}

fn visit_expr_closure_mut(&mut self, node: &mut syn::ExprClosure) {
self.repeatable_depth += 1;
syn::visit_mut::visit_expr_closure_mut(self, node);
self.repeatable_depth -= 1;
}

fn visit_item_fn_mut(&mut self, _node: &mut syn::ItemFn) {}

fn visit_local_mut(&mut self, node: &mut syn::Local) {
syn::visit_mut::visit_local_mut(self, node);

if let Some(var_name) = is_tc_draw_binding(node, &self.tc_ident) {
let repeatable = self.repeatable_depth > 0;
let entry = self.name_flags.entry(var_name).or_insert(false);
if repeatable {
*entry = true;
}
}
}
}

/// Pass 2: Rewrite `let x = tc.draw(gen)` to `let x = tc.draw_named(gen, "x", repeatable)`.
///
/// Uses the pre-computed name_flags from DrawNameCollector so that every use of
/// a given name gets the same repeatable flag.
struct DrawRewriter {
tc_ident: String,
name_flags: HashMap<String, bool>,
}

impl VisitMut for DrawRewriter {
fn visit_item_fn_mut(&mut self, _node: &mut syn::ItemFn) {}

fn visit_local_mut(&mut self, node: &mut syn::Local) {
syn::visit_mut::visit_local_mut(self, node);

let var_name = match is_tc_draw_binding(node, &self.tc_ident) {
Some(name) => name,
None => return,
};

let repeatable = self.name_flags.get(&var_name).copied().unwrap_or(false);

let init = node.init.as_mut().unwrap();
let method_call = match &mut *init.expr {
Expr::MethodCall(mc) => mc,
_ => unreachable!(),
};

let span = method_call.method.span();
method_call.method = Ident::new("draw_named", span);
method_call.args.push(Expr::Lit(syn::ExprLit {
attrs: vec![],
lit: syn::Lit::Str(syn::LitStr::new(&var_name, span)),
}));
method_call.args.push(Expr::Lit(syn::ExprLit {
attrs: vec![],
lit: syn::Lit::Bool(syn::LitBool::new(repeatable, span)),
}));
}
}

pub fn expand_test(attr: proc_macro2::TokenStream, item: proc_macro2::TokenStream) -> TokenStream {
let test_args: TestArgs = if attr.is_empty() {
TestArgs {
Expand Down Expand Up @@ -110,7 +227,40 @@ pub fn expand_test(attr: proc_macro2::TokenStream, item: proc_macro2::TokenStrea
}
}

let body = &func.block;
// Rewrite `let x = tc.draw(gen)` -> `let x = tc.draw_named(gen, "x", repeatable)`
//
// Two-pass approach:
// 1. Collect all draw variable names and determine per-name repeatable flags.
// If any use of a name is in a nested block/closure, all uses are repeatable.
// 2. Rewrite draws using the computed flags.
//
// We visit the function body's statements directly (not the block itself) so that
// the outermost block doesn't count as a nesting level.
let body = {
let mut body = (*func.block).clone();
if let Some(tc_name) = extract_ident_from_pat(param_pat) {
// Pass 1: collect names
let mut collector = DrawNameCollector {
tc_ident: tc_name.clone(),
repeatable_depth: 0,
name_flags: HashMap::new(),
};
for stmt in &mut body.stmts {
collector.visit_stmt_mut(stmt);
}

// Pass 2: rewrite
let mut rewriter = DrawRewriter {
tc_ident: tc_name,
name_flags: collector.name_flags,
};
for stmt in &mut body.stmts {
rewriter.visit_stmt_mut(stmt);
}
}
body
};

let test_name = func.sig.ident.to_string();

let settings_args_chain: Vec<TokenStream> = test_args
Expand Down
85 changes: 75 additions & 10 deletions src/test_case.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use crate::protocol::{Channel, Connection, SERVER_CRASHED_MESSAGE};
use crate::runner::Verbosity;
use ciborium::Value;
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
use std::sync::{Arc, LazyLock};

Expand Down Expand Up @@ -59,12 +60,13 @@ pub(crate) struct TestCaseGlobalData {
verbosity: Verbosity,
is_last_run: bool,
test_aborted: bool,
named_draw_counts: HashMap<String, usize>,
named_draw_repeatable: HashMap<String, bool>,
}

#[derive(Clone)]
pub(crate) struct TestCaseLocalData {
span_depth: usize,
draw_count: usize,
indent: usize,
on_draw: Rc<dyn Fn(&str)>,
}
Expand Down Expand Up @@ -125,10 +127,11 @@ impl TestCase {
verbosity,
is_last_run,
test_aborted: false,
named_draw_counts: HashMap::new(),
named_draw_repeatable: HashMap::new(),
})),
local: RefCell::new(TestCaseLocalData {
span_depth: 0,
draw_count: 0,
indent: 0,
on_draw,
}),
Expand All @@ -148,10 +151,37 @@ impl TestCase {
/// let s: String = tc.draw(generators::text());
/// }
/// ```
///
/// Note: when run inside a `#[hegel::test]`, `draw()` will typically be
/// rewritten to `draw_named()` with an appropriate variable name
/// in order to give better test output.
pub fn draw<T: std::fmt::Debug>(&self, generator: impl Generator<T>) -> T {
self.draw_named(generator, "unnamed", true)
}

/// Draw a value from a generator with a specific name for output.
///
/// When `repeatable` is true, a counter suffix is appended (e.g. `x_1`, `x_2`).
/// When `repeatable` is false, reusing the same name panics.
///
/// Using the same name with different values of `repeatable` is an error.
///
/// On the final replay of a failing test case, this prints:
/// - `let name = value;` (when not repeatable)
/// - `let name_N = value;` (when repeatable)
///
/// Note: although this is public API and you are welcome to use it,
/// it's not really intended for direct use. It is the target that
/// `#[hegel::test]` rewrites `draw()` calls to where appropriate.
pub fn draw_named<T: std::fmt::Debug>(
&self,
generator: impl Generator<T>,
name: &str,
repeatable: bool,
) -> T {
let value = generator.do_draw(self);
if self.local.borrow().span_depth == 0 {
self.record_draw(&value);
self.record_named_draw(&value, name, repeatable);
}
value
}
Expand Down Expand Up @@ -211,22 +241,57 @@ impl TestCase {
global: self.global.clone(),
local: RefCell::new(TestCaseLocalData {
span_depth: 0,
draw_count: 0,
indent: local.indent + extra_indent,
on_draw: local.on_draw.clone(),
}),
}
}

fn record_draw<T: std::fmt::Debug>(&self, value: &T) {
let mut local = self.local.borrow_mut();
local.draw_count += 1;
let count = local.draw_count;
fn record_named_draw<T: std::fmt::Debug>(&self, value: &T, name: &str, repeatable: bool) {
let mut global = self.global.borrow_mut();

match global.named_draw_repeatable.get(name) {
Some(&prev) if prev != repeatable => {
panic!(
"draw_named: name {:?} used with inconsistent repeatable flag (was {}, now {})",
name, prev, repeatable
);
}
_ => {
global
.named_draw_repeatable
.insert(name.to_string(), repeatable);
}
}

let count = global
.named_draw_counts
.entry(name.to_string())
.or_insert(0);
*count += 1;
let current_count = *count;
drop(global);

if !repeatable && current_count > 1 {
panic!(
"draw_named: name {:?} used more than once but repeatable is false",
name
);
}

let local = self.local.borrow();
let indent = local.indent;

let display_name = if repeatable {
format!("{}_{}", name, current_count)
} else {
name.to_string()
};

(local.on_draw)(&format!(
"{:indent$}Draw {}: {:?}",
"{:indent$}let {} = {:?};",
"",
count,
display_name,
value,
indent = indent
));
Expand Down
Loading