Skip to content

Commit 504b778

Browse files
authored
Fix segfault when passing friendly tangents to "non-friendly" forward cache in 1.10 and formatting (#983)
1 parent a18d35c commit 504b778

File tree

3 files changed

+34
-14
lines changed

3 files changed

+34
-14
lines changed

src/interface.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -690,8 +690,10 @@ Returns a `Dual` containing the result of applying forward-mode AD to compute th
690690
derivative of `primal(f)` at the primal values in `x` in the direction of the tangent values
691691
in `f` and `x`.
692692
"""
693-
value_and_derivative!!(cache::ForwardCache, fx::Vararg{Dual,N}) where {N} =
694-
cache.rule(fx...) # TODO: handle friendly tangents for the output here?
693+
function value_and_derivative!!(cache::ForwardCache, fx::Vararg{Dual,N}) where {N}
694+
# TODO: check Dual coherence here like we do below?
695+
return cache.rule(fx...)
696+
end
695697

696698
"""
697699
value_and_derivative!!(cache::ForwardCache, (f, df), (x, dx), ...)
@@ -707,7 +709,7 @@ Tuples are used as inputs and outputs instead of `Dual` numbers to accommodate t
707709
`cache` owns any mutable state returned by this function, meaning that mutable components of values returned by it will be mutated if you run this function again with different arguments. Therefore, if you need to keep the values returned by this function around over multiple calls to this function with the same `cache`, you should take a copy (using `copy` or `deepcopy`) of them before calling again.
708710
"""
709711
function value_and_derivative!!(
710-
cache::ForwardCache, f::NTuple{2,Any}, x::Vararg{<:NTuple{2,Any},N}
712+
cache::ForwardCache, f::NTuple{2,Any}, x::Vararg{NTuple{2,Any},N}
711713
) where {N}
712714
fx = (f, x...) # to avoid method ambiguity
713715
friendly_tangents = !isnothing(cache.input_tangents)
@@ -725,6 +727,11 @@ function value_and_derivative!!(
725727
end
726728

727729
input_duals = map(Dual, input_primals, input_tangents)
730+
731+
if !friendly_tangents # in friendly mode, conversion should ensure tangent coherence
732+
error_if_incorrect_dual_types(input_duals...)
733+
end
734+
728735
output = cache.rule(input_duals...)
729736
output_primal = primal(output)
730737
output_tangent = tangent(output)

src/tangents/dual.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,20 @@ Check that the type of `tangent(x)` is the tangent type of the type of `primal(x
4949
"""
5050
verify_dual_type(x::Dual) = tangent_type(typeof(primal(x))) == typeof(tangent(x))
5151

52+
function error_if_incorrect_dual_types(duals::Vararg{Dual,N}) where {N}
53+
correct_types = map(verify_dual_type, duals)
54+
if !all(correct_types)
55+
primals = map(primal, duals)
56+
tangents = map(tangent, duals)
57+
throw(ArgumentError("""
58+
Tangent types do not match primal types:
59+
- primal types: $(map(typeof, primals))
60+
- provided tangent types: $(map(typeof, tangents))
61+
- required tangent types: $(map(tangent_type, map(typeof, primals)))
62+
"""))
63+
end
64+
end
65+
5266
@inline uninit_dual(x::P) where {P} = Dual(x, uninit_tangent(x))
5367

5468
# Always sharpen the first thing if it's a type so static dispatch remains possible.

test/interface.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ end
8888

8989
cache = Mooncake.prepare_gradient_cache(f, x)
9090
v, dx = Mooncake.value_and_gradient!!(cache, f, x)
91-
@test dx[2] isa Mooncake.Tangent{@NamedTuple{x1::Float64, x2::Float64}}
91+
@test dx[2] isa Mooncake.Tangent{@NamedTuple{x1::Float64,x2::Float64}}
9292
@test dx[2].fields == (; x1=2 * x.x1, x2=cos(x.x2))
9393

9494
cache = Mooncake.prepare_gradient_cache(
@@ -101,7 +101,7 @@ end
101101
rule = build_rrule(f, x)
102102

103103
v, dx = Mooncake.value_and_gradient!!(rule, f, x)
104-
@test dx[2] isa Mooncake.Tangent{@NamedTuple{x1::Float64, x2::Float64}}
104+
@test dx[2] isa Mooncake.Tangent{@NamedTuple{x1::Float64,x2::Float64}}
105105
@test dx[2].fields == (; x1=2 * x.x1, x2=cos(x.x2))
106106

107107
v, dx = Mooncake.value_and_gradient!!(rule, f, x; friendly_tangents=true)
@@ -271,6 +271,7 @@ end
271271

272272
@testset "__exclude_unsupported_output , $(test_set)" for test_set in
273273
additional_test_set
274+
274275
try
275276
Mooncake.__exclude_unsupported_output(test_set[2])
276277
catch err
@@ -280,6 +281,7 @@ end
280281

281282
@testset "_copy_output & _copy_to_output!!, $(test_set)" for test_set in
282283
additional_test_set
284+
283285
original = test_set[2]
284286
try
285287
if isnothing(Mooncake.__exclude_unsupported_output(original))
@@ -351,15 +353,12 @@ end
351353
cache_sp_unfriendly = Mooncake.prepare_derivative_cache(
352354
fx_sp...; config=Mooncake.Config(; friendly_tangents=false, kwargs...)
353355
)
354-
if get(kwargs, :debug_mode, false)
355-
@test_throws ErrorException Mooncake.value_and_derivative!!(
356-
cache_sp_unfriendly, zip(fx_sp, dfx_sp)...
357-
)
358-
else
359-
@test_throws TypeError Mooncake.value_and_derivative!!(
360-
cache_sp_unfriendly, zip(fx_sp, dfx_sp)...
361-
)
362-
end
356+
@test_throws ArgumentError Mooncake.value_and_derivative!!(
357+
cache_sp_unfriendly, zip(fx_sp, dfx_sp)...
358+
)
359+
@test_throws "Tangent types do not match primal types:" Mooncake.value_and_derivative!!(
360+
cache_sp_unfriendly, zip(fx_sp, dfx_sp)...
361+
)
363362
end
364363
end
365364

0 commit comments

Comments
 (0)