rule reuse consistency checks #1172

Open
AstitvaAggarwal wants to merge 2 commits into main from Astitva/issue-681-cached-rules-test

AstitvaAggarwal commented May 12, 2026

Closes #681.

When a rule is compiled via build_frule/build_rrule, the resulting closure is intended to be reusable: you compile once and call many times. However, there was no test enforcing this. A rule whose closure captures mutable state that gets corrupted during a first call would pass all existing tests but silently return wrong derivatives on any subsequent use, which is exactly the class of bug described in DI #678.

This PR adds a "Reuse" testset inside Mooncake.TestUtils.test_rule to catch this. After the existing "Correctness" testset has run a complete forward+backward pass, "Reuse" calls test_frule_correctness/test_rrule_correctness a second time with the same compiled frule/rrule closure objects. Both correctness functions deep-copy their inputs at entry, so the only shared state between the two testsets is the rule closure itself. If the rule corrupts any mutable state it captures during a first call, the second call will produce wrong derivatives and "Reuse" will fail.
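
As a minimal sketch of the failure mode that "Reuse" targets (this is not Mooncake's API; `build_buggy_rule` is a hypothetical stand-in for a compiled rule), consider a closure that captures a buffer and never resets it:

```julia
# Hypothetical stand-in for a compiled rule; NOT Mooncake's build_rrule.
# The closure captures a scratch buffer that is never reset between calls.
function build_buggy_rule()
    acc = zeros(1)                    # captured mutable state
    return function (x::Float64)
        acc[1] += 2x                  # bug: accumulates instead of overwriting
        return x^2, acc[1]            # (primal, claimed derivative of x^2)
    end
end

rule = build_buggy_rule()
first_call = rule(3.0)    # derivative is correct on the first call
second_call = rule(3.0)   # stale captured state makes the derivative wrong
```

A test that only checks the first call would pass here; a second call with the same closure exposes the stale state.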

RNG snapshotting: the correctness functions draw random probe tangents from a shared rng. Without snapshotting, "Reuse" would consume the rng state left over from "Correctness" and generate different probe tangents, causing spurious failures unrelated to rule mutation:

  • BFloat16 arithmetic has limited precision (unit roundoff ~3.9e-3, i.e. eps(BFloat16) ≈ 7.8e-3), so only a narrow range of FD step sizes converges. A different tangent direction can cause all seven candidate step sizes to fail isapprox.

  • Functions with domain constraints, such as logpdf(Dirichlet(...), x), which returns -Inf when x is not on the simplex, produce NaN FD estimates if the probe tangent pushes the perturbed input outside the domain. Whether this happens depends on the specific tangent direction drawn.
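
The second bullet can be illustrated with a small sketch. The function `simplex_logdensity` and the helper `central_fd` below are hypothetical stand-ins (not Distributions.jl or Mooncake code) for a density that is -Inf off the simplex and for a central finite-difference estimate:

```julia
# Hypothetical density: finite on the open simplex, -Inf elsewhere.
simplex_logdensity(x) =
    (isapprox(sum(x), 1.0) && all(>(0), x)) ? sum(log, x) : -Inf

# Central finite-difference estimate of the directional derivative along t.
central_fd(f, x, t, h) = (f(x .+ h .* t) - f(x .- h .* t)) / (2h)

x = [0.5, 0.5]
h = 0.1

t_on = [1.0, -1.0]   # sums to zero, so x ± h*t stays on the simplex
t_off = [1.0, 1.0]   # changes sum(x), so both perturbed points leave the simplex

fd_on = central_fd(simplex_logdensity, x, t_on, h)
fd_off = central_fd(simplex_logdensity, x, t_off, h)   # (-Inf) - (-Inf) is NaN
```

With the same input and step size, only the tangent direction decides whether the estimate is finite.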

To isolate failures to genuine rule-mutation bugs, rng_snapshot = deepcopy(rng) is taken before testset "Correctness" and passed to testset "Reuse", ensuring both testsets use identical probe tangents.
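
A rough sketch of the snapshot idea, using only the Random standard library. `draw_tangent` is a hypothetical placeholder for the probe-tangent draw inside the correctness functions, and the seed is arbitrary:

```julia
using Random

# Hypothetical placeholder for the probe-tangent draw; not Mooncake's API.
draw_tangent(rng::AbstractRNG, n::Int) = randn(rng, n)

rng = Xoshiro(123)
rng_snapshot = deepcopy(rng)              # taken before the first draw

t_correctness = draw_tangent(rng, 3)      # advances `rng`
t_drifted = draw_tangent(rng, 3)          # what a second pass sees without the snapshot
t_reuse = draw_tangent(rng_snapshot, 3)   # replays the first draw exactly
```

Because `deepcopy` duplicates the generator state, the snapshot reproduces the first draw while the original rng has already moved on.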

CI Summary — GitHub Actions

Documentation Preview

Mooncake.jl documentation for PR #1172 is available at:
https://chalk-lab.github.io/Mooncake.jl/previews/PR1172/

Performance

Performance Ratio:
Ratio of time to compute gradient and time to compute function.
Warning: results are very approximate! See here for more context.

┌───────────────────────┬──────────┬──────────┬─────────────┬─────────┬─────────────┬────────┐
│                 Label │   Primal │ Mooncake │ MooncakeFwd │  Zygote │ ReverseDiff │ Enzyme │
│                String │   String │   String │      String │  String │      String │ String │
├───────────────────────┼──────────┼──────────┼─────────────┼─────────┼─────────────┼────────┤
│              sum_1000 │ 200.0 ns │     1.45 │        1.55 │    0.65 │         3.3 │   6.36 │
│             _sum_1000 │   1.1 μs │     6.09 │        1.02 │  3990.0 │        39.3 │   1.06 │
│          sum_sin_1000 │  7.43 μs │     2.48 │        1.12 │    1.67 │        10.9 │   1.79 │
│         _sum_sin_1000 │  4.71 μs │     3.83 │         2.6 │   382.0 │        17.3 │   3.02 │
│              kron_sum │ 201.0 μs │     13.5 │        3.24 │    7.56 │       510.0 │   19.6 │
│         kron_view_sum │ 277.0 μs │     12.6 │        5.06 │    27.2 │       476.0 │   13.8 │
│ naive_map_sin_cos_exp │  2.24 μs │     2.82 │         1.5 │ missing │        8.28 │    2.2 │
│       map_sin_cos_exp │  2.23 μs │     3.35 │        1.65 │    1.49 │        7.18 │   2.64 │
│ broadcast_sin_cos_exp │  2.28 μs │     3.02 │        1.54 │    4.29 │        1.41 │   2.09 │
│            simple_mlp │ 348.0 μs │     5.05 │        2.91 │    2.13 │        9.96 │   3.15 │
│                gp_lml │ 165.0 μs │     11.9 │        2.69 │    5.24 │     missing │   6.18 │
│    large_single_block │ 471.0 ns │     9.42 │        2.08 │  4740.0 │        31.7 │   2.08 │
└───────────────────────┴──────────┴──────────┴─────────────┴─────────┴─────────────┴────────┘

AstitvaAggarwal marked this pull request as draft May 12, 2026 12:47
AstitvaAggarwal marked this pull request as ready for review May 12, 2026 14:45
AstitvaAggarwal requested review from sunxd3 and yebai May 12, 2026 14:58
Comment thread src/test_utils.jl
end

# Snapshot of RNG so "Reuse" testset uses identical probe tangents to "Correctness" testset.
rng_snapshot = deepcopy(rng)
It would be better if we could eliminate the need for sharing random seeds.

Successfully merging this pull request may close these issues.

Test suite extension: check that cached rules run correctly