diff --git a/src/test_utils.jl b/src/test_utils.jl index 9882f1c382..41c327224c 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -691,6 +691,84 @@ _deepcopy(x::Module) = x rrule_output_type(::Type{Ty}) where {Ty} = Tuple{Mooncake.fcodual_type(Ty),Any} +function test_frule_reuse(x_ẋ...; frule) + @nospecialize x_ẋ + x_ẋ_a = map(_deepcopy, x_ẋ) + x_ẋ_b = map(_deepcopy, x_ẋ) + + # Snapshot every observable at the same point in each cycle. Without snapshots, + # an aliased mutable buffer would let call B overwrite call A's data; snapshotting + # only one side would compare different temporal points if a rule mutates inputs. + # Skip the deepcopy when tangent is NoTangent: such primals (e.g. Module-containing + # types like Core.TypeName) can't safely be deepcopied and can't be mutated either. + y_ẏ_a = frule(x_ẋ_a...) + y_primal_a = tangent(y_ẏ_a) isa NoTangent ? primal(y_ẏ_a) : _deepcopy(primal(y_ẏ_a)) + ẏ_a = _deepcopy(tangent(y_ẏ_a)) + ẋ_a = map(_deepcopy ∘ tangent, x_ẋ_a) + + y_ẏ_b = frule(x_ẋ_b...) + y_primal_b = tangent(y_ẏ_b) isa NoTangent ? primal(y_ẏ_b) : _deepcopy(primal(y_ẏ_b)) + ẏ_b = _deepcopy(tangent(y_ẏ_b)) + ẋ_b = map(_deepcopy ∘ tangent, x_ẋ_b) + + @test has_equal_data(y_primal_a, y_primal_b) + @test has_equal_data(ẏ_a, ẏ_b) + @test all(map(has_equal_data, ẋ_a, ẋ_b)) +end + +function test_rrule_reuse(rng::AbstractRNG, x_x̄...; rrule, output_tangent=nothing) + @nospecialize rng x_x̄ + x = map(primal, x_x̄) + x̄_zero_a = map(zero_tangent, x) + x_x̄_a = map( + (x, x̄_f) -> fcodual_type(_typeof(x))(_deepcopy(x), x̄_f), + x, + map(Mooncake.fdata, x̄_zero_a), + ) + x̄_zero_b = map(zero_tangent, x) + x_x̄_b = map( + (x, x̄_f) -> fcodual_type(_typeof(x))(_deepcopy(x), x̄_f), + x, + map(Mooncake.fdata, x̄_zero_b), + ) + + # Snapshot every observable at the same point in each cycle: post-forward for + # output primal/fdata, post-pullback for inputs and pullback returns. Without + # snapshots, aliased mutable buffers can alias-pass; snapshotting only one side + # would compare different temporal points since the pullback restores in-place + # mutations on the way back. + # Skip the deepcopy when tangent (fdata) is NoFData: such primals (e.g. Module- + # containing types like Core.TypeName) can't safely be deepcopied and the pullback + # has no fdata path through which it could mutate them. + y_ȳ_a, pb_a!! = rrule(x_x̄_a...) + y_primal_a = tangent(y_ȳ_a) isa NoFData ? primal(y_ȳ_a) : _deepcopy(primal(y_ȳ_a)) + y_fdata_a = _deepcopy(tangent(y_ȳ_a)) + ȳ_delta = if isnothing(output_tangent) + randn_tangent(rng, primal(y_ȳ_a)) + else + output_tangent + end + ȳ_a = increment!!( + set_to_zero!!(zero_tangent(primal(y_ȳ_a), tangent(y_ȳ_a))), _deepcopy(ȳ_delta) + ) + x̄_rvs_a = _deepcopy(pb_a!!(Mooncake.rdata(ȳ_a))) + x̄_fwds_a = map(_deepcopy ∘ Mooncake.fdata, x̄_zero_a) + + y_ȳ_b, pb_b!! = rrule(x_x̄_b...) + y_primal_b = tangent(y_ȳ_b) isa NoFData ? primal(y_ȳ_b) : _deepcopy(primal(y_ȳ_b)) + y_fdata_b = _deepcopy(tangent(y_ȳ_b)) + ȳ_b = increment!!( + set_to_zero!!(zero_tangent(primal(y_ȳ_b), tangent(y_ȳ_b))), _deepcopy(ȳ_delta) + ) + x̄_rvs_b = _deepcopy(pb_b!!(Mooncake.rdata(ȳ_b))) + x̄_fwds_b = map(_deepcopy ∘ Mooncake.fdata, x̄_zero_b) + + @test has_equal_data(y_primal_a, y_primal_b) + @test has_equal_data(y_fdata_a, y_fdata_b) + @test has_equal_data(x̄_rvs_a, x̄_rvs_b) + @test all(map(has_equal_data, x̄_fwds_a, x̄_fwds_b)) +end + function test_frule_interface(x_ẋ...; frule) @nospecialize x_ẋ @@ -1066,6 +1144,18 @@ function test_rule( redirector = print_results ? ((f, x) -> f()) : redirect_stdout ts = redirector(devnull) do @testset "$(typeof(x))" begin + # Verify rules give identical results on a second call, + # i.e. the rule does not corrupt its internal state across calls. + @testset "Reuse" begin + if test_fwd && !interface_only + test_frule_reuse(x_ẋ...; frule) + end + if test_rvs && !interface_only + # Isolated rng so Reuse doesn't perturb Correctness's rng state. + test_rrule_reuse(Xoshiro(123), x_x̄...; rrule, output_tangent) + end + end + # Test that the interface is basically satisfied (checks types / memory addresses). @testset "Interface (1)" begin test_fwd && test_frule_interface(x_ẋ...; frule)