diff --git a/lib/EnzymeCore/src/EnzymeCore.jl b/lib/EnzymeCore/src/EnzymeCore.jl
index cc73251167..5b1156c047 100644
--- a/lib/EnzymeCore/src/EnzymeCore.jl
+++ b/lib/EnzymeCore/src/EnzymeCore.jl
@@ -4,6 +4,7 @@ export Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPri
 export ReverseSplitModified, ReverseSplitWidth, ReverseHolomorphic, ReverseHolomorphicWithPrimal
 export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, Annotation
 export MixedDuplicated, BatchMixedDuplicated
+export Seed, BatchSeed
 export DefaultABI, FFIABI, InlineABI, NonGenABI
 export BatchDuplicatedFunc
 export within_autodiff
diff --git a/src/Enzyme.jl b/src/Enzyme.jl
index 6fdd84a778..925a5681b6 100644
--- a/src/Enzyme.jl
+++ b/src/Enzyme.jl
@@ -57,7 +57,8 @@ import EnzymeCore:
     NoPrimal,
     needs_primal,
     runtime_activity,
-    strong_zero
+    strong_zero,
+    Split
 export Annotation,
     Const,
     Active,
@@ -112,6 +113,7 @@ export autodiff,
 
 export jacobian, gradient, gradient!, hvp, hvp!, hvp_and_gradient!
 export batch_size, onehot, chunkedonehot
+export Seed, BatchSeed
 
 using LinearAlgebra
 import SparseArrays
diff --git a/src/sugar.jl b/src/sugar.jl
index 59e2915536..f15e6e28de 100644
--- a/src/sugar.jl
+++ b/src/sugar.jl
@@ -1349,3 +1349,160 @@ grad
     )
     return nothing
 end
+
+"""
+    batchify_activity(::Type{A}, ::Val{B})
+
+Turn an activity (or [`Annotation`](@ref)) type `A` into the correct activity type for a batch of size `B`.
+
+`Const` and `Active` activities do not depend on the batch size and are returned unchanged.
+
+# Examples
+
+```jldoctest
+julia> using Enzyme
+
+julia> Enzyme.batchify_activity(Active{Float64}, Val(2))
+Active{Float64}
+
+julia> Enzyme.batchify_activity(Duplicated{Vector{Float64}}, Val(2))
+BatchDuplicated{Vector{Float64}, 2}
+```
+"""
+batchify_activity(::Type{Const{T}}, ::Val{B}) where {T,B} = Const{T}
+batchify_activity(::Type{Active{T}}, ::Val{B}) where {T,B} = Active{T}
+batchify_activity(::Type{Duplicated{T}}, ::Val{B}) where {T,B} = BatchDuplicated{T,B}
+batchify_activity(::Type{DuplicatedNoNeed{T}}, ::Val{B}) where {T,B} = BatchDuplicatedNoNeed{T,B}
+batchify_activity(::Type{MixedDuplicated{T}}, ::Val{B}) where {T,B} = BatchMixedDuplicated{T,B}
+
+
+"""
+    Seed(dy)
+
+Wrapper for a single adjoint to the return value in reverse mode.
+
+The second type parameter stores the return activity obtained from `guess_activity(typeof(dy), Reverse)`.
+"""
+struct Seed{T,A}
+    dval::T
+
+    function Seed(dval::T) where T
+        A = guess_activity(T, Reverse)
+        return new{T,A}(dval)
+    end
+end
+
+"""
+    BatchSeed(dys::NTuple)
+
+Wrapper for a tuple of adjoints to the return value in reverse mode.
+
+The third type parameter stores the batched return activity, i.e. `batchify_activity` applied to `guess_activity(T, Reverse)`.
+"""
+struct BatchSeed{N, T, AB}
+    dvals::NTuple{N, T}
+
+    function BatchSeed(dvals::NTuple{N,T}) where {N,T}
+        A = guess_activity(T, Reverse)
+        AB = batchify_activity(A, Val(N))
+        return new{N,T,AB}(dvals)
+    end
+end
+
+"""
+    autodiff(
+        rmode::Union{ReverseMode,ReverseModeSplit},
+        f::Annotation,
+        dresult::Seed,
+        annotated_args...
+    )
+
+Call [`autodiff_thunk`](@ref) in split mode, execute the forward pass, increment output adjoint with `dresult`, then execute the reverse pass.
+
+Useful for computing pullbacks / VJPs for functions whose output is not a scalar, or when the scalar seed is not 1.
+
+Return `(dinputs,)`, or `(dinputs, result)` when `rmode` returns the primal.
+Throw an `ArgumentError` if the return activity stored in `dresult` is not `Active`, `Duplicated` or `MixedDuplicated`.
+"""
+function autodiff(
+        rmode::Union{ReverseMode{ReturnPrimal}, ReverseModeSplit{ReturnPrimal}},
+        f::FA,
+        dresult::Seed{RT,RA},
+        args::Vararg{Annotation, N},
+    ) where {ReturnPrimal, FA <: Annotation, RT, RA, N}
+    # Reject unsupported activities (e.g. `Const`) before compiling or running anything.
+    if !(RA <: Union{Active, Duplicated, MixedDuplicated})
+        throw(ArgumentError("seeded `autodiff` does not support return activity $RA"))
+    end
+    rmode_split = Split(rmode)
+    forward, reverse = autodiff_thunk(rmode_split, FA, RA, typeof.(args)...)
+    tape, result, shadow_result = forward(f, args...)
+    if RA <: Active
+        # The seed is passed to the reverse pass as an extra argument.
+        dinputs = only(reverse(f, args..., dresult.dval, tape))
+    elseif RA <: Duplicated
+        # The seed is accumulated into the shadow produced by the forward pass.
+        Compiler.recursive_accumulate(shadow_result, dresult.dval)
+        dinputs = only(reverse(f, args..., tape))
+    else # RA <: MixedDuplicated
+        # NOTE(review): test/seeded.jl marks the mixed-struct case with a TODO; confirm this seeding.
+        Compiler.recursive_accumulate(shadow_result, Ref(dresult.dval))
+        dinputs = only(reverse(f, args..., dresult.dval, tape))
+    end
+    if ReturnPrimal
+        return (dinputs, result)
+    else
+        return (dinputs,)
+    end
+end
+
+"""
+    autodiff(
+        rmode::Union{ReverseMode,ReverseModeSplit},
+        f::Annotation,
+        dresults::BatchSeed,
+        annotated_args...
+    )
+
+Call [`autodiff_thunk`](@ref) in split mode, execute the forward pass, increment each output adjoint with the corresponding element from `dresults`, then execute the reverse pass.
+
+Useful for computing pullbacks / VJPs for functions whose output is not a scalar, or when the scalar seed is not 1.
+
+Return `(dinputs,)`, or `(dinputs, result)` when `rmode` returns the primal.
+Throw an `ArgumentError` if the return activity stored in `dresults` is not `Active`, `BatchDuplicated` or `BatchMixedDuplicated`.
+"""
+function autodiff(
+        rmode::Union{ReverseMode{ReturnPrimal}, ReverseModeSplit{ReturnPrimal}},
+        f::FA,
+        dresults::BatchSeed{B,RT,RA},
+        args::Vararg{Annotation, N},
+    ) where {ReturnPrimal, B, FA <: Annotation, RT, RA, N}
+    # Reject unsupported activities (e.g. `Const`) before compiling or running anything.
+    if !(RA <: Union{Active, BatchDuplicated, BatchMixedDuplicated})
+        throw(ArgumentError("seeded `autodiff` does not support return activity $RA"))
+    end
+    rmode_split_rightwidth = ReverseSplitWidth(Split(rmode), Val(B))
+    forward, reverse = autodiff_thunk(rmode_split_rightwidth, FA, RA, typeof.(args)...)
+    tape, result, shadow_results = forward(f, args...)
+    if RA <: Active
+        # The whole tuple of seeds is passed to the reverse pass as an extra argument.
+        dinputs = only(reverse(f, args..., dresults.dvals, tape))
+    elseif RA <: BatchDuplicated
+        # Each seed is accumulated into the matching shadow produced by the forward pass.
+        foreach(shadow_results, dresults.dvals) do d0, d
+            Compiler.recursive_accumulate(d0, d)
+        end
+        dinputs = only(reverse(f, args..., tape))
+    else # RA <: BatchMixedDuplicated
+        # NOTE(review): test/seeded.jl marks the mixed-struct case with a TODO; confirm this seeding.
+        foreach(shadow_results, dresults.dvals) do d0, d
+            Compiler.recursive_accumulate(d0, Ref(d))
+        end
+        dinputs = only(reverse(f, args..., dresults.dvals, tape))
+    end
+    if ReturnPrimal
+        return (dinputs, result)
+    else
+        return (dinputs,)
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index e22e00929d..50c81d5650 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -3031,6 +3031,10 @@ end
 include("sugar.jl")
 include("errors.jl")
 
+@testset "Seeded autodiff" begin
+    include("seeded.jl")
+end
+
 @testset "Forward on Reverse" begin
 
     function speelpenning(y, x)
diff --git a/test/seeded.jl b/test/seeded.jl
new file mode 100644
index 0000000000..68c3962af2
--- /dev/null
+++ b/test/seeded.jl
@@ -0,0 +1,166 @@
+using Enzyme
+using Enzyme: batchify_activity
+using Test
+
+@testset "Batchify activity" begin
+    @test batchify_activity(Const{Float64}, Val(2)) == Const{Float64}
+    @test batchify_activity(Active{Float64}, Val(2)) == Active{Float64}
+    @test batchify_activity(Duplicated{Vector{Float64}}, Val(2)) == BatchDuplicated{Vector{Float64},2}
+    @test batchify_activity(DuplicatedNoNeed{Vector{Float64}}, Val(2)) == BatchDuplicatedNoNeed{Vector{Float64},2}
+    @test batchify_activity(MixedDuplicated{Tuple{Float64,Vector{Float64}}}, Val(2)) == BatchMixedDuplicated{Tuple{Float64,Vector{Float64}},2}
+end
+
+# the base case is a function returning (a(x, y), b(x, y))
+
+a(x::Vector{Float64}, y::Float64) = sum(abs2, x) * y
+b(x::Vector{Float64}, y::Float64) = sum(x) * abs2(y)
+
+struct MyStruct
+    bar::Float64
+    foo::Float64
+end
+
+mutable struct MyMutableStruct
+    bar::Float64
+    foo::Float64
+end
+
+Base.:(==)(s1::MyMutableStruct, s2::MyMutableStruct) = s1.bar == s2.bar && s1.foo == s2.foo
+
+struct MyMixedStruct
+    bar::Float64
+    foo::Vector{Float64}
+end
+
+Base.:(==)(s1::MyMixedStruct, s2::MyMixedStruct) = s1.bar == s2.bar && s1.foo == s2.foo
+
+f1(x, y) = a(x, y) + b(x, y)
+f2(x, y) = [a(x, y), b(x, y)]
+f3(x, y) = (a(x, y), b(x, y))
+f4(x, y) = MyStruct(a(x, y), b(x, y))
+f5(x, y) = MyMutableStruct(a(x, y), b(x, y))
+f6(x, y) = MyMixedStruct(a(x, y), [b(x, y)])
+
+x = [1.0, 2.0, 3.0]
+y = 4.0
+
+# output seeds, (a,b) case
+
+da = 5.0
+db = 7.0
+das = (5.0, 11.0)
+dbs = (7.0, 13.0)
+
+# input derivatives, (a,b) case
+
+dx_ref = da * 2x * y .+ db * abs2(y)
+dy_ref = da * sum(abs2, x) + db * sum(x) * 2y
+dxs_ref = (
+    das[1] * 2x * y .+ dbs[1] * abs2(y),
+    das[2] * 2x * y .+ dbs[2] * abs2(y)
+)
+dys_ref = (
+    das[1] * sum(abs2, x) + dbs[1] * sum(x) * 2y,
+    das[2] * sum(abs2, x) + dbs[2] * sum(x) * 2y
+)
+
+# input derivatives, (a+b) case
+
+dx1_ref = (da + db) * (2x * y .+ abs2(y))
+dy1_ref = (da + db) * (sum(abs2, x) + sum(x) * 2y)
+dxs1_ref = (
+    (das[1] + dbs[1]) * (2x * y .+ abs2(y)),
+    (das[2] + dbs[2]) * (2x * y .+ abs2(y))
+)
+dys1_ref = (
+    (das[1] + dbs[1]) * (sum(abs2, x) + sum(x) * 2y),
+    (das[2] + dbs[2]) * (sum(abs2, x) + sum(x) * 2y)
+)
+
+# output seeds, weird cases
+
+dz1 = da + db
+dzs1 = das .+ dbs
+
+dz2 = [da, db]
+dzs2 = ([das[1], dbs[1]], [das[2], dbs[2]])
+
+dz3 = (da, db)
+dzs3 = ((das[1], dbs[1]), (das[2], dbs[2]))
+
+dz4 = MyStruct(da, db)
+dzs4 = (MyStruct(das[1], dbs[1]), MyStruct(das[2], dbs[2]))
+
+dz5 = MyMutableStruct(da, db)
+dzs5 = (MyMutableStruct(das[1], dbs[1]), MyMutableStruct(das[2], dbs[2]))
+
+dz6 = MyMixedStruct(da, [db])
+dzs6 = (MyMixedStruct(das[1], [dbs[1]]), MyMixedStruct(das[2], [dbs[2]]))
+
+# validation
+
+function validate_seeded_autodiff(f, dz, dzs)
+    @testset for mode in (Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal)
+        @testset "Simple" begin
+            dx = make_zero(x)
+            dinputs_and_maybe_result = autodiff(mode, Const(f), Seed(dz), Duplicated(x, dx), Active(y))
+            dinputs = first(dinputs_and_maybe_result)
+            @test isnothing(dinputs[1])
+            if f === f1
+                @test dinputs[2] == dy1_ref
+                @test dx == dx1_ref
+            else
+                @test dinputs[2] == dy_ref
+                @test dx == dx_ref
+            end
+            if Enzyme.Split(mode) == ReverseSplitWithPrimal
+                @test last(dinputs_and_maybe_result) == f(x, y)
+            end
+        end
+
+        @testset "Batch" begin
+            dxs = (make_zero(x), make_zero(x))
+            dinputs_and_maybe_result = autodiff(mode, Const(f), BatchSeed(dzs), BatchDuplicated(x, dxs), Active(y))
+            dinputs = first(dinputs_and_maybe_result)
+            @test isnothing(dinputs[1])
+            if f === f1
+                @test dinputs[2][1] == dys1_ref[1]
+                @test dinputs[2][2] == dys1_ref[2]
+                @test dxs[1] == dxs1_ref[1]
+                @test dxs[2] == dxs1_ref[2]
+            else
+                @test dinputs[2][1] == dys_ref[1]
+                @test dinputs[2][2] == dys_ref[2]
+                @test dxs[1] == dxs_ref[1]
+                @test dxs[2] == dxs_ref[2]
+            end
+            if Enzyme.Split(mode) == ReverseSplitWithPrimal
+                @test last(dinputs_and_maybe_result) == f(x, y)
+            end
+        end
+    end
+end
+
+@testset "Scalar output" begin
+    validate_seeded_autodiff(f1, dz1, dzs1)
+end;
+
+@testset "Vector output" begin
+    validate_seeded_autodiff(f2, dz2, dzs2)
+end;
+
+@testset "Tuple output" begin
+    validate_seeded_autodiff(f3, dz3, dzs3)
+end;
+
+@testset "Struct output" begin
+    validate_seeded_autodiff(f4, dz4, dzs4)
+end;
+
+@testset "Mutable struct output" begin
+    validate_seeded_autodiff(f5, dz5, dzs5)
+end;
+
+@testset "Mixed struct output" begin
+    validate_seeded_autodiff(f6, dz6, dzs6) # TODO: debug this
+end;