Skip to content

Commit a18d35c

Browse files
authored
Add caching and friendly tangents in forward mode (#980)
1 parent a3fa8e8 commit a18d35c

File tree

3 files changed

+146
-21
lines changed

3 files changed

+146
-21
lines changed

Project.toml

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Mooncake"
22
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
33
authors = ["Will Tebbutt, Hong Ge, and contributors"]
4-
version = "0.5.0"
4+
version = "0.5.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -92,4 +92,15 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
9292
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
9393

9494
[targets]
95-
test = ["AllocCheck", "Aqua", "BenchmarkTools", "DiffTests", "JET", "Logging", "Pkg", "Revise", "StableRNGs", "Test"]
95+
test = [
96+
"AllocCheck",
97+
"Aqua",
98+
"BenchmarkTools",
99+
"DiffTests",
100+
"JET",
101+
"Logging",
102+
"Pkg",
103+
"Revise",
104+
"StableRNGs",
105+
"Test",
106+
]

src/interface.jl

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -658,20 +658,82 @@ function value_and_gradient!!(
658658
end
659659
end
660660

661+
struct ForwardCache{R,IT<:Union{Nothing,Tuple},OP}
662+
rule::R
663+
input_tangents::IT
664+
output_primal::OP
665+
end
666+
661667
"""
662668
prepare_derivative_cache(fx...; config=Mooncake.Config())
663669
664670
Returns a cache used with [`value_and_derivative!!`](@ref). See that function for more info.
665671
"""
666-
@unstable function prepare_derivative_cache(fx...; config=Config())
667-
build_frule(fx...; config.debug_mode, config.silence_debug_messages)
672+
@unstable function prepare_derivative_cache(f, x::Vararg{Any,N}; config=Config()) where {N}
673+
fx = (f, x...)
674+
rule = build_frule(fx...; config.debug_mode, config.silence_debug_messages)
675+
676+
if config.friendly_tangents
677+
y = f(x...)
678+
input_tangents = map(zero_tangent, fx)
679+
output_primal = _copy_output(y)
680+
return ForwardCache(rule, input_tangents, output_primal)
681+
else
682+
return ForwardCache(rule, nothing, nothing)
683+
end
668684
end
669685

670686
"""
671-
value_and_derivative!!(rule::R, f::Dual, x::Vararg{Dual,N})
687+
value_and_derivative!!(cache::ForwardCache, f::Dual, x::Vararg{Dual,N})
672688
673689
Returns a `Dual` containing the result of applying forward-mode AD to compute the (Frechet)
674690
derivative of `primal(f)` at the primal values in `x` in the direction of the tangent values
675691
in `f` and `x`.
676692
"""
677-
value_and_derivative!!(rule::R, fx::Vararg{Dual,N}) where {R,N} = rule(fx...)
693+
value_and_derivative!!(cache::ForwardCache, fx::Vararg{Dual,N}) where {N} =
694+
cache.rule(fx...) # TODO: handle friendly tangents for the output here?
695+
696+
"""
697+
value_and_derivative!!(cache::ForwardCache, (f, df), (x, dx), ...)
698+
699+
Returns a tuple `(y, dy)` containing the result of applying forward-mode AD to compute the (Frechet) derivative of `primal(f)` at the primal values in `x` in the direction of the tangent values contained in `df` and `dx`.
700+
701+
Tuples are used as inputs and outputs instead of `Dual` numbers to accommodate the case where internal Mooncake tangent types do not coincide with tangents provided by the user (in which case we translate between "friendly tangents" and internal tangents using cache storage).
702+
703+
!!! info
704+
`cache` must be the output of [`prepare_derivative_cache`](@ref), and (fields of) `f` and `x` must be of the same size and shape as those used to construct the `cache`. This is to ensure that the gradient can be written to the memory allocated when the `cache` was built.
705+
706+
!!! warning
707+
`cache` owns any mutable state returned by this function, meaning that mutable components of values returned by it will be mutated if you run this function again with different arguments. Therefore, if you need to keep the values returned by this function around over multiple calls to this function with the same `cache`, you should take a copy (using `copy` or `deepcopy`) of them before calling again.
708+
"""
709+
function value_and_derivative!!(
710+
cache::ForwardCache, f::NTuple{2,Any}, x::Vararg{<:NTuple{2,Any},N}
711+
) where {N}
712+
fx = (f, x...) # to avoid method ambiguity
713+
friendly_tangents = !isnothing(cache.input_tangents)
714+
715+
input_primals = map(first, fx)
716+
input_friendly_tangents = map(last, fx)
717+
718+
# translate from friendly to native
719+
if friendly_tangents
720+
input_tangents = map(
721+
primal_to_tangent!!, cache.input_tangents, input_friendly_tangents
722+
)
723+
else
724+
input_tangents = input_friendly_tangents
725+
end
726+
727+
input_duals = map(Dual, input_primals, input_tangents)
728+
output = cache.rule(input_duals...)
729+
output_primal = primal(output)
730+
output_tangent = tangent(output)
731+
732+
# translate from native back to friendly
733+
if friendly_tangents
734+
output_friendly_tangent = tangent_to_primal!!(cache.output_primal, output_tangent)
735+
return output_primal, output_friendly_tangent
736+
else
737+
return output_primal, output_tangent
738+
end
739+
end

test/interface.jl

Lines changed: 67 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -88,25 +88,25 @@ end
8888

8989
cache = Mooncake.prepare_gradient_cache(f, x)
9090
v, dx = Mooncake.value_and_gradient!!(cache, f, x)
91-
@test dx[2] isa Mooncake.Tangent{@NamedTuple{x1::Float64,x2::Float64}}
92-
@test dx[2].fields == (; x1=2*x.x1, x2=cos(x.x2))
91+
@test dx[2] isa Mooncake.Tangent{@NamedTuple{x1::Float64, x2::Float64}}
92+
@test dx[2].fields == (; x1=2 * x.x1, x2=cos(x.x2))
9393

9494
cache = Mooncake.prepare_gradient_cache(
95-
f, x; config=Mooncake.Config(friendly_tangents=true)
95+
f, x; config=Mooncake.Config(; friendly_tangents=true)
9696
)
9797
v, dx = Mooncake.value_and_gradient!!(cache, f, x)
9898
@test dx[2] isa SimplePair
99-
@test dx[2] == SimplePair(2*x.x1, cos(x.x2))
99+
@test dx[2] == SimplePair(2 * x.x1, cos(x.x2))
100100

101101
rule = build_rrule(f, x)
102102

103103
v, dx = Mooncake.value_and_gradient!!(rule, f, x)
104-
@test dx[2] isa Mooncake.Tangent{@NamedTuple{x1::Float64,x2::Float64}}
105-
@test dx[2].fields == (; x1=2*x.x1, x2=cos(x.x2))
104+
@test dx[2] isa Mooncake.Tangent{@NamedTuple{x1::Float64, x2::Float64}}
105+
@test dx[2].fields == (; x1=2 * x.x1, x2=cos(x.x2))
106106

107107
v, dx = Mooncake.value_and_gradient!!(rule, f, x; friendly_tangents=true)
108108
@test dx[2] isa SimplePair
109-
@test dx[2] == SimplePair(2*x.x1, cos(x.x2))
109+
@test dx[2] == SimplePair(2 * x.x1, cos(x.x2))
110110
end
111111
end
112112
@testset "value_and_pullback!!" begin
@@ -162,7 +162,7 @@ end
162162
)
163163

164164
cache = Mooncake.prepare_pullback_cache(
165-
testf, x; config=Mooncake.Config(friendly_tangents=true)
165+
testf, x; config=Mooncake.Config(; friendly_tangents=true)
166166
)
167167
v, pb = Mooncake.value_and_pullback!!(cache, x̄, testf, x)
168168
@test has_equal_data(v, SimplePair(x.x1^2 + sin(x.x2), x.x1 * x.x2))
@@ -271,7 +271,6 @@ end
271271

272272
@testset "__exclude_unsupported_output , $(test_set)" for test_set in
273273
additional_test_set
274-
275274
try
276275
Mooncake.__exclude_unsupported_output(test_set[2])
277276
catch err
@@ -281,7 +280,6 @@ end
281280

282281
@testset "_copy_output & _copy_to_output!!, $(test_set)" for test_set in
283282
additional_test_set
284-
285283
original = test_set[2]
286284
try
287285
if isnothing(Mooncake.__exclude_unsupported_output(original))
@@ -304,11 +302,65 @@ end
304302
(; debug_mode=true, silence_debug_messages=true),
305303
]
306304
f = (x, y) -> x * y + cos(x)
307-
fx = (f, 5.0, 4.0)
308-
rule = Mooncake.prepare_derivative_cache(fx...; config=Mooncake.Config(; kwargs...))
309-
z = Mooncake.value_and_derivative!!(rule, map(zero_dual, fx)...)
310-
@test z isa Mooncake.Dual
311-
@test primal(z) == f(5.0, 4.0)
305+
g = (sp::SimplePair) -> SimplePair(f(sp.x1, sp.x2), 2.0)
306+
307+
x, y = 5.0, 4.0
308+
dx, dy = 3.0, 2.0
309+
fx = (f, x, y)
310+
dfx = (Mooncake.zero_tangent(f), dx, dy)
311+
z = f(x, y)
312+
dz = dx * y + x * dy + dx * (-sin(x))
313+
314+
fx_sp = (g, SimplePair(x, y))
315+
dfx_sp = (Mooncake.zero_tangent(g), SimplePair(dx, dy))
316+
z_sp = g(SimplePair(x, y))
317+
318+
@testset "Simple types" begin
319+
cache = Mooncake.prepare_derivative_cache(
320+
fx...; config=Mooncake.Config(; kwargs...)
321+
)
322+
323+
# legacy Dual interface
324+
z_and_dz_dual = Mooncake.value_and_derivative!!(
325+
cache, map(Mooncake.Dual, fx, dfx)...
326+
)
327+
@test z_and_dz_dual isa Mooncake.Dual
328+
@test Mooncake.primal(z_and_dz_dual) == z
329+
@test Mooncake.tangent(z_and_dz_dual) == dz
330+
331+
# new tuple interface
332+
z_and_dz_tup = Mooncake.value_and_derivative!!(cache, zip(fx, dfx)...)
333+
@test z_and_dz_tup isa Tuple{Float64,Float64}
334+
@test first(z_and_dz_tup) == z
335+
@test last(z_and_dz_tup) == dz
336+
end
337+
338+
@testset "Structured types" begin
339+
cache_sp_friendly = Mooncake.prepare_derivative_cache(
340+
fx_sp...; config=Mooncake.Config(; friendly_tangents=true, kwargs...)
341+
)
342+
# friendly input doesn't error
343+
z_and_dz_sp = Mooncake.value_and_derivative!!(
344+
cache_sp_friendly, zip(fx_sp, dfx_sp)...
345+
)
346+
# output is friendly
347+
@test z_and_dz_sp isa Tuple{SimplePair,SimplePair}
348+
@test first(z_and_dz_sp) == SimplePair(z, 2.0)
349+
@test last(z_and_dz_sp) == SimplePair(dz, 0.0)
350+
351+
cache_sp_unfriendly = Mooncake.prepare_derivative_cache(
352+
fx_sp...; config=Mooncake.Config(; friendly_tangents=false, kwargs...)
353+
)
354+
if get(kwargs, :debug_mode, false)
355+
@test_throws ErrorException Mooncake.value_and_derivative!!(
356+
cache_sp_unfriendly, zip(fx_sp, dfx_sp)...
357+
)
358+
else
359+
@test_throws TypeError Mooncake.value_and_derivative!!(
360+
cache_sp_unfriendly, zip(fx_sp, dfx_sp)...
361+
)
362+
end
363+
end
312364
end
313365

314366
@testset "selective zeroing of cotangents" begin

0 commit comments

Comments
 (0)