Skip to content
Merged
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
c0a44a0
checks for rule's state consistency
AstitvaAggarwal May 12, 2026
7494b30
use same rng for perturbations repetition
AstitvaAggarwal May 12, 2026
19d2c8a
seperate test function
AstitvaAggarwal May 12, 2026
bce1d4f
fix test_rrule_reuse: snapshot primal and use fdata for output cotangent
AstitvaAggarwal May 12, 2026
2a6503f
form
AstitvaAggarwal May 12, 2026
f00d7bf
fix test_rrule_reuse: skip deepcopy for NoFData output primals
AstitvaAggarwal May 12, 2026
c66a0da
tests for rule reuse, reorder reuse testset in test_rule
AstitvaAggarwal May 13, 2026
5b94de0
typo
AstitvaAggarwal May 13, 2026
2dccfa3
run Mooncake's skills for reviews
AstitvaAggarwal May 13, 2026
f31e723
seperate rng states for testsets
AstitvaAggarwal May 13, 2026
7ea8819
changes from reviews
AstitvaAggarwal May 15, 2026
953cb94
remove function
AstitvaAggarwal May 15, 2026
0d29772
Merge branch 'main' into Astitva/issue-681-cached-rules-test
AstitvaAggarwal May 15, 2026
402d0d0
Apply suggestions from code review
AstitvaAggarwal May 18, 2026
f0386fb
isolate rng in test_rrule_reuse so Correctness sees the same rng stat…
sunxd3 May 18, 2026
dad9342
Snapshot first-cycle observables in reuse helpers and drop negative t…
sunxd3 May 18, 2026
4eae271
Fix temporal divergence and Module-primal handling in reuse helpers
sunxd3 May 18, 2026
c42848a
Align reuse-helper variable names with codebase conventions
sunxd3 May 18, 2026
5724009
Merge branch 'main' into Astitva/issue-681-cached-rules-test
sunxd3 May 21, 2026
6f27e0a
Merge branch 'main' into Astitva/issue-681-cached-rules-test
sunxd3 May 21, 2026
10b5bfc
Merge branch 'main' into Astitva/issue-681-cached-rules-test
sunxd3 May 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,84 @@ _deepcopy(x::Module) = x

rrule_output_type(::Type{Ty}) where {Ty} = Tuple{Mooncake.fcodual_type(Ty),Any}

function test_frule_reuse(x_ẋ...; frule)
@nospecialize x_ẋ
x_ẋ_a = map(_deepcopy, x_ẋ)
x_ẋ_b = map(_deepcopy, x_ẋ)

# Snapshot every observable at the same point in each cycle. Without snapshots,
# an aliased mutable buffer would let call B overwrite call A's data; snapshotting
# only one side would compare different temporal points if a rule mutates inputs.
# Skip the deepcopy when tangent is NoTangent: such primals (e.g. Module-containing
# types like Core.TypeName) can't safely be deepcopied and can't be mutated either.
y_ẏ_a = frule(x_ẋ_a...)
y_primal_a = tangent(y_ẏ_a) isa NoTangent ? primal(y_ẏ_a) : _deepcopy(primal(y_ẏ_a))
ẏ_a = _deepcopy(tangent(y_ẏ_a))
ẋ_a = map(_deepcopy ∘ tangent, x_ẋ_a)

y_ẏ_b = frule(x_ẋ_b...)
y_primal_b = tangent(y_ẏ_b) isa NoTangent ? primal(y_ẏ_b) : _deepcopy(primal(y_ẏ_b))
ẏ_b = _deepcopy(tangent(y_ẏ_b))
ẋ_b = map(_deepcopy ∘ tangent, x_ẋ_b)

@test has_equal_data(y_primal_a, y_primal_b)
@test has_equal_data(ẏ_a, ẏ_b)
@test all(map(has_equal_data, ẋ_a, ẋ_b))
end

function test_rrule_reuse(rng::AbstractRNG, x_x̄...; rrule, output_tangent=nothing)
@nospecialize rng x_x̄
x = map(primal, x_x̄)
x̄_zero_a = map(zero_tangent, x)
x_x̄_a = map(
(x, x̄_f) -> fcodual_type(_typeof(x))(_deepcopy(x), x̄_f),
x,
map(Mooncake.fdata, x̄_zero_a),
)
x̄_zero_b = map(zero_tangent, x)
x_x̄_b = map(
(x, x̄_f) -> fcodual_type(_typeof(x))(_deepcopy(x), x̄_f),
x,
map(Mooncake.fdata, x̄_zero_b),
)

# Snapshot every observable at the same point in each cycle: post-forward for
# output primal/fdata, post-pullback for inputs and pullback returns. Without
# snapshots, aliased mutable buffers can alias-pass; snapshotting only one side
# would compare different temporal points since the pullback restores in-place
# mutations on the way back.
# Skip the deepcopy when tangent (fdata) is NoFData: such primals (e.g. Module-
# containing types like Core.TypeName) can't safely be deepcopied and the pullback
# has no fdata path through which it could mutate them.
y_ȳ_a, pb_a!! = rrule(x_x̄_a...)
y_primal_a = tangent(y_ȳ_a) isa NoFData ? primal(y_ȳ_a) : _deepcopy(primal(y_ȳ_a))
y_fdata_a = _deepcopy(tangent(y_ȳ_a))
ȳ_delta = if isnothing(output_tangent)
randn_tangent(rng, primal(y_ȳ_a))
else
output_tangent
end
ȳ_a = increment!!(
set_to_zero!!(zero_tangent(primal(y_ȳ_a), tangent(y_ȳ_a))), _deepcopy(ȳ_delta)
)
x̄_rvs_a = _deepcopy(pb_a!!(Mooncake.rdata(ȳ_a)))
x̄_fwds_a = map(_deepcopy ∘ Mooncake.fdata, x̄_zero_a)

y_ȳ_b, pb_b!! = rrule(x_x̄_b...)
y_primal_b = tangent(y_ȳ_b) isa NoFData ? primal(y_ȳ_b) : _deepcopy(primal(y_ȳ_b))
y_fdata_b = _deepcopy(tangent(y_ȳ_b))
ȳ_b = increment!!(
set_to_zero!!(zero_tangent(primal(y_ȳ_b), tangent(y_ȳ_b))), _deepcopy(ȳ_delta)
)
x̄_rvs_b = _deepcopy(pb_b!!(Mooncake.rdata(ȳ_b)))
x̄_fwds_b = map(_deepcopy ∘ Mooncake.fdata, x̄_zero_b)

@test has_equal_data(y_primal_a, y_primal_b)
@test has_equal_data(y_fdata_a, y_fdata_b)
@test has_equal_data(x̄_rvs_a, x̄_rvs_b)
@test all(map(has_equal_data, x̄_fwds_a, x̄_fwds_b))
end

function test_frule_interface(x_ẋ...; frule)
@nospecialize x_ẋ

Expand Down Expand Up @@ -1066,6 +1144,18 @@ function test_rule(
redirector = print_results ? ((f, x) -> f()) : redirect_stdout
ts = redirector(devnull) do
@testset "$(typeof(x))" begin
# Verify rules give identical results on a second call,
# i.e. the rule does not corrupt its internal state across calls.
@testset "Reuse" begin
if test_fwd && !interface_only
test_frule_reuse(x_ẋ...; frule)
end
if test_rvs && !interface_only
# Isolated rng so Reuse doesn't perturb Correctness's rng state.
test_rrule_reuse(Xoshiro(123), x_x̄...; rrule, output_tangent)
end
end

# Test that the interface is basically satisfied (checks types / memory addresses).
@testset "Interface (1)" begin
test_fwd && test_frule_interface(x_ẋ...; frule)
Expand Down
Loading