Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPri
export ReverseSplitModified, ReverseSplitWidth, ReverseHolomorphic, ReverseHolomorphicWithPrimal
export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, Annotation
export MixedDuplicated, BatchMixedDuplicated
export Seed, BatchSeed
export DefaultABI, FFIABI, InlineABI, NonGenABI
export BatchDuplicatedFunc
export within_autodiff
Expand Down
4 changes: 3 additions & 1 deletion src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ import EnzymeCore:
NoPrimal,
needs_primal,
runtime_activity,
strong_zero
strong_zero,
Split
export Annotation,
Const,
Active,
Expand Down Expand Up @@ -112,6 +113,7 @@ export autodiff,

export jacobian, gradient, gradient!, hvp, hvp!, hvp_and_gradient!
export batch_size, onehot, chunkedonehot
export Seed, BatchSeed

using LinearAlgebra
import SparseArrays
Expand Down
130 changes: 130 additions & 0 deletions src/sugar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1349,3 +1349,133 @@ grad
)
return nothing
end

"""
    batchify_activity(::Type{A}, ::Val{B})

Map the activity (or [`Annotation`](@ref)) type `A` to the activity type to use for a batch of size `B`.

`Active{T}` is returned unchanged, while `Duplicated{T}`, `DuplicatedNoNeed{T}` and
`MixedDuplicated{T}` are mapped to `BatchDuplicated{T,B}`, `BatchDuplicatedNoNeed{T,B}` and
`BatchMixedDuplicated{T,B}` respectively. Only these four annotation families are handled here.

# Examples

```jldoctest
julia> using Enzyme

julia> Enzyme.batchify_activity(Active{Float64}, Val(2))
Active{Float64}

julia> Enzyme.batchify_activity(Duplicated{Vector{Float64}}, Val(2))
BatchDuplicated{Vector{Float64}, 2}
```
"""
function batchify_activity(::Type{Active{T}}, ::Val{B}) where {T, B}
    # The batch width does not appear in the result for `Active`.
    return Active{T}
end

function batchify_activity(::Type{Duplicated{T}}, ::Val{B}) where {T, B}
    return BatchDuplicated{T, B}
end

function batchify_activity(::Type{DuplicatedNoNeed{T}}, ::Val{B}) where {T, B}
    return BatchDuplicatedNoNeed{T, B}
end

function batchify_activity(::Type{MixedDuplicated{T}}, ::Val{B}) where {T, B}
    return BatchMixedDuplicated{T, B}
end


"""
    Seed(dy)

Wrap a single adjoint `dy` of the return value, for use in reverse mode.

The wrapped value is stored in the field `dval`. The second type parameter records the
activity obtained from `guess_activity(T, Reverse)`, where `T` is the type of `dy`.
"""
struct Seed{T, A}
    dval::T

    # The activity parameter is derived from the seed type, never supplied by the caller.
    Seed(dval::T) where {T} = new{T, guess_activity(T, Reverse)}(dval)
end

"""
    BatchSeed(dys::NTuple)

Wrap a tuple of adjoints `dys` of the return value, for use in reverse mode.

The tuple is stored in the field `dvals`. The third type parameter records the batched
activity: `guess_activity(T, Reverse)` for the element type `T`, passed through
[`batchify_activity`](@ref) with the tuple length as batch size.
"""
struct BatchSeed{N, T, AB}
    dvals::NTuple{N, T}

    function BatchSeed(dvals::NTuple{N, T}) where {N, T}
        # Guess the activity of one element, then widen it to the batch size `N`.
        batched = batchify_activity(guess_activity(T, Reverse), Val(N))
        return new{N, T, batched}(dvals)
    end
end

"""
    autodiff(
        rmode::Union{ReverseMode,ReverseModeSplit},
        f::Annotation,
        dresult::Seed,
        annotated_args...
    )

Call [`autodiff_thunk`](@ref) in split mode, execute the forward pass, increment output adjoint with `dresult`, then execute the reverse pass.

Useful for computing pullbacks / VJPs for functions whose output is not a scalar, or when the scalar seed is not 1.

# Returns

`(dinputs,)`, or `(dinputs, result)` when `rmode` requests the primal, where `dinputs` is the
single element returned by the reverse thunk and `result` is the primal value from the forward pass.
"""
function autodiff(
        rmode::Union{ReverseMode{ReturnPrimal}, ReverseModeSplit{ReturnPrimal}},
        f::FA,
        dresult::Seed{RT, RA},
        args::Vararg{Annotation, N},
    ) where {ReturnPrimal, FA <: Annotation, RT, RA, N}
    # The return activity `RA` comes from the seed's type parameter.
    fwd_thunk, rev_thunk = autodiff_thunk(Split(rmode), FA, RA, typeof.(args)...)
    tape, result, shadow_result = fwd_thunk(f, args...)
    seed = dresult.dval
    dinputs = if RA <: Active
        # Active return: the seed is handed to the reverse pass directly.
        only(rev_thunk(f, args..., seed, tape))
    elseif RA <: Duplicated
        # Duplicated return: accumulate the seed into the shadow, then run the reverse pass.
        Compiler.recursive_accumulate(shadow_result, seed)
        only(rev_thunk(f, args..., tape))
    else
        # Intended for RA <: MixedDuplicated.
        # NOTE(review): this branch is a catch-all; nothing here checks that RA really is
        # MixedDuplicated, and whether the reverse thunk expects the seed argument for
        # this activity should be confirmed against `autodiff_thunk`.
        Compiler.recursive_accumulate(shadow_result, Ref(seed))
        only(rev_thunk(f, args..., seed, tape))
    end
    return ReturnPrimal ? (dinputs, result) : (dinputs,)
end

"""
    autodiff(
        rmode::Union{ReverseMode,ReverseModeSplit},
        f::Annotation,
        dresults::BatchSeed,
        annotated_args...
    )

Call [`autodiff_thunk`](@ref) in split mode, execute the forward pass, increment each output adjoint with the corresponding element from `dresults`, then execute the reverse pass.

Useful for computing pullbacks / VJPs for functions whose output is not a scalar, or when the scalar seed is not 1.

# Returns

`(dinputs,)`, or `(dinputs, result)` when `rmode` requests the primal, where `dinputs` is the
single element returned by the reverse thunk and `result` is the primal value from the forward pass.
"""
function autodiff(
        rmode::Union{ReverseMode{ReturnPrimal}, ReverseModeSplit{ReturnPrimal}},
        f::FA,
        dresults::BatchSeed{B, RT, RA},
        args::Vararg{Annotation, N},
    ) where {ReturnPrimal, B, FA <: Annotation, RT, RA, N}
    # Split mode whose width matches the number of seeds in the batch.
    batched_mode = ReverseSplitWidth(Split(rmode), Val(B))
    fwd_thunk, rev_thunk = autodiff_thunk(batched_mode, FA, RA, typeof.(args)...)
    tape, result, shadow_results = fwd_thunk(f, args...)
    seeds = dresults.dvals
    dinputs = if RA <: Active
        # Active return: the whole tuple of seeds is handed to the reverse pass.
        only(rev_thunk(f, args..., seeds, tape))
    elseif RA <: BatchDuplicated
        # Accumulate each seed into its matching shadow before the reverse pass.
        for (shadow, seed) in zip(shadow_results, seeds)
            Compiler.recursive_accumulate(shadow, seed)
        end
        only(rev_thunk(f, args..., tape))
    else
        # Intended for RA <: BatchMixedDuplicated.
        # NOTE(review): this branch is a catch-all; nothing here checks that RA really is
        # BatchMixedDuplicated, and whether the reverse thunk expects the seeds for this
        # activity should be confirmed against `autodiff_thunk`.
        for (shadow, seed) in zip(shadow_results, seeds)
            Compiler.recursive_accumulate(shadow, Ref(seed))
        end
        only(rev_thunk(f, args..., seeds, tape))
    end
    return ReturnPrimal ? (dinputs, result) : (dinputs,)
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3031,6 +3031,10 @@ end
include("sugar.jl")
include("errors.jl")

# Run the tests from seeded.jl (`autodiff` with `Seed` / `BatchSeed`) in their own testset.
@testset "Seeded autodiff" begin
include("seeded.jl")
end

@testset "Forward on Reverse" begin

function speelpenning(y, x)
Expand Down
165 changes: 165 additions & 0 deletions test/seeded.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
using Enzyme
using Enzyme: batchify_activity
using Test

# Each non-batched activity must map to its batched counterpart of width 2
# (`Active` maps to itself).
@testset "Batchify activity" begin
    width = Val(2)
    V = Vector{Float64}
    M = Tuple{Float64, V}
    @test batchify_activity(Active{Float64}, width) == Active{Float64}
    @test batchify_activity(Duplicated{V}, width) == BatchDuplicated{V, 2}
    @test batchify_activity(DuplicatedNoNeed{V}, width) == BatchDuplicatedNoNeed{V, 2}
    @test batchify_activity(MixedDuplicated{M}, width) == BatchMixedDuplicated{M, 2}
end

# the base case is a function returning (a(x, y), b(x, y))

# First scalar component: (sum of squares of `x`) times `y`.
function a(x::Vector{Float64}, y::Float64)
    return sum(abs2, x) * y
end

# Second scalar component: (sum of `x`) times `y` squared.
function b(x::Vector{Float64}, y::Float64)
    return sum(x) * abs2(y)
end

# Immutable output container with two `Float64` fields; `f4` returns one of these,
# built from `a(x, y)` and `b(x, y)`.
struct MyStruct
bar::Float64
foo::Float64
end

# Mutable output container with two `Float64` fields; `f5` returns one of these.
mutable struct MyMutableStruct
    bar::Float64
    foo::Float64
end

# Mutable structs compare by identity by default, so compare field by field instead.
function Base.:(==)(s1::MyMutableStruct, s2::MyMutableStruct)
    return s1.bar == s2.bar && s1.foo == s2.foo
end

# Keep `isequal` and `hash` consistent with the field-wise `==` above, so that values
# which compare equal also behave as equal in `Dict`, `Set` and `unique`.
function Base.isequal(s1::MyMutableStruct, s2::MyMutableStruct)
    return isequal(s1.bar, s2.bar) && isequal(s1.foo, s2.foo)
end

function Base.hash(s::MyMutableStruct, h::UInt)
    return hash(s.foo, hash(s.bar, hash(:MyMutableStruct, h)))
end

# Output container mixing a scalar field and a vector field; `f6` returns one of these.
struct MyMixedStruct
    bar::Float64
    foo::Vector{Float64}
end

# The default `==` falls back to `===`, which compares the vector field by identity;
# compare the contents instead.
function Base.:(==)(s1::MyMixedStruct, s2::MyMixedStruct)
    return s1.bar == s2.bar && s1.foo == s2.foo
end

# Keep `isequal` and `hash` consistent with the content-wise `==` above, so that values
# which compare equal also behave as equal in `Dict`, `Set` and `unique`.
function Base.isequal(s1::MyMixedStruct, s2::MyMixedStruct)
    return isequal(s1.bar, s2.bar) && isequal(s1.foo, s2.foo)
end

function Base.hash(s::MyMixedStruct, h::UInt)
    return hash(s.foo, hash(s.bar, hash(:MyMixedStruct, h)))
end

# Test functions: each packs a(x, y) and b(x, y) into a different kind of return value.

# Scalar: the sum of both components.
function f1(x, y)
    return a(x, y) + b(x, y)
end

# Vector holding both components.
function f2(x, y)
    return [a(x, y), b(x, y)]
end

# Tuple holding both components.
function f3(x, y)
    return (a(x, y), b(x, y))
end

# Immutable struct holding both components.
function f4(x, y)
    return MyStruct(a(x, y), b(x, y))
end

# Mutable struct holding both components.
function f5(x, y)
    return MyMutableStruct(a(x, y), b(x, y))
end

# Struct with a scalar field and a one-element vector field.
function f6(x, y)
    return MyMixedStruct(a(x, y), [b(x, y)])
end

# point at which every test function is differentiated
x = Float64[1, 2, 3]
y = 4.0

# output seeds, (a,b) case: one seed per component, then a batch of two seeds per component

da, db = 5.0, 7.0
das, dbs = (5.0, 11.0), (7.0, 13.0)

# input derivatives, (a,b) case: the components are seeded separately with (da, db)

dx_ref = da * 2x * y .+ db * abs2(y)
dy_ref = da * sum(abs2, x) + db * sum(x) * 2y
# batched versions: one reference per pair of seeds
dxs_ref = map((da_i, db_i) -> da_i * 2x * y .+ db_i * abs2(y), das, dbs)
dys_ref = map((da_i, db_i) -> da_i * sum(abs2, x) + db_i * sum(x) * 2y, das, dbs)

# input derivatives, (a+b) case: a single seed da + db applies to the summed output

dx1_ref = (da + db) * (2x * y .+ abs2(y))
dy1_ref = (da + db) * (sum(abs2, x) + sum(x) * 2y)
# batched versions: one reference per pair of seeds
dxs1_ref = map((da_i, db_i) -> (da_i + db_i) * (2x * y .+ abs2(y)), das, dbs)
dys1_ref = map((da_i, db_i) -> (da_i + db_i) * (sum(abs2, x) + sum(x) * 2y), das, dbs)

# output seeds, one per kind of return value: `dzK` matches the output of `fK`,
# and `dzsK` is the corresponding batch of two seeds

dz1 = da + db
dzs1 = map(+, das, dbs)

dz2 = [da, db]
dzs2 = map((da_i, db_i) -> [da_i, db_i], das, dbs)

dz3 = (da, db)
dzs3 = map(tuple, das, dbs)

dz4 = MyStruct(da, db)
dzs4 = map(MyStruct, das, dbs)

dz5 = MyMutableStruct(da, db)
dzs5 = map(MyMutableStruct, das, dbs)

dz6 = MyMixedStruct(da, [db])
dzs6 = map((da_i, db_i) -> MyMixedStruct(da_i, [db_i]), das, dbs)

# validation

# Check seeded reverse-mode `autodiff` of `f` at the global test point `(x, y)`.
#
# For each of the four reverse modes, differentiate `f` once with the single seed `dz`
# ("Simple") and once with the tuple of seeds `dzs` ("Batch"), then compare the derivative
# with respect to `x` (accumulated into the shadow) and with respect to `y` (returned)
# against the reference values defined at the top level of this file. When the mode
# returns the primal, it is also compared with `f(x, y)`.
function validate_seeded_autodiff(f, dz, dzs)
    # `f1` returns a + b and therefore has its own references; every other test function
    # is compared against the (a, b) references.
    dx_expected, dy_expected = f === f1 ? (dx1_ref, dy1_ref) : (dx_ref, dy_ref)
    dxs_expected, dys_expected = f === f1 ? (dxs1_ref, dys1_ref) : (dxs_ref, dys_ref)

    @testset for mode in (Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal)
        @testset "Simple" begin
            dx = make_zero(x)
            dinputs_and_maybe_result = autodiff(mode, Const(f), Seed(dz), Duplicated(x, dx), Active(y))
            dinputs = first(dinputs_and_maybe_result)
            # `x` is Duplicated, so its derivative lands in `dx` and nothing is returned for it.
            @test isnothing(dinputs[1])
            # Derivatives are floating-point results: compare approximately, not bitwise.
            @test dinputs[2] ≈ dy_expected
            @test dx ≈ dx_expected
            if Enzyme.Split(mode) == ReverseSplitWithPrimal
                @test last(dinputs_and_maybe_result) == f(x, y)
            end
        end

        @testset "Batch" begin
            dxs = (make_zero(x), make_zero(x))
            dinputs_and_maybe_result = autodiff(mode, Const(f), BatchSeed(dzs), BatchDuplicated(x, dxs), Active(y))
            dinputs = first(dinputs_and_maybe_result)
            @test isnothing(dinputs[1])
            @test dinputs[2][1] ≈ dys_expected[1]
            @test dinputs[2][2] ≈ dys_expected[2]
            @test dxs[1] ≈ dxs_expected[1]
            @test dxs[2] ≈ dxs_expected[2]
            if Enzyme.Split(mode) == ReverseSplitWithPrimal
                @test last(dinputs_and_maybe_result) == f(x, y)
            end
        end
    end
end

# One testset per kind of return value; each validates the single-seed and batched calls.
@testset "$label output" for (label, f, dz, dzs) in (
        ("Scalar", f1, dz1, dzs1),
        ("Vector", f2, dz2, dzs2),
        ("Tuple", f3, dz3, dzs3),
        ("Struct", f4, dz4, dzs4),
        ("Mutable struct", f5, dz5, dzs5),
        ("Mixed struct", f6, dz6, dzs6),  # TODO: debug this
    )
    validate_seeded_autodiff(f, dz, dzs)
end;