Skip to content
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ EnzymeStaticArraysExt = "StaticArrays"
BFloat16s = "0.2, 0.3, 0.4, 0.5"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.8.8"
EnzymeCore = "0.8.9"
Enzyme_jll = "0.0.173"
GPUArraysCore = "0.1.6, 0.2"
GPUCompiler = "1.3"
Expand Down
2 changes: 1 addition & 1 deletion lib/EnzymeCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeCore"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
version = "0.8.8"
version = "0.8.9"

[compat]
Adapt = "3, 4"
Expand Down
41 changes: 41 additions & 0 deletions lib/EnzymeCore/src/EnzymeCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPri
export ReverseSplitModified, ReverseSplitWidth, ReverseHolomorphic, ReverseHolomorphicWithPrimal
export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, Annotation
export MixedDuplicated, BatchMixedDuplicated
export Seed, BatchSeed
export DefaultABI, FFIABI, InlineABI, NonGenABI
export BatchDuplicatedFunc
export within_autodiff
Expand Down Expand Up @@ -206,6 +207,46 @@ end
@inline batch_size(::BatchMixedDuplicated{T,N}) where {T,N} = N
@inline batch_size(::Type{BatchMixedDuplicated{T,N}}) where {T,N} = N

"""
batchify_activity(::Type{A}, ::Val{B})

Turn an activity (or [`Annotation`](@ref)) type `A` into the correct activity type for a batch of size `B`.

# Examples

```jldoctest
julia> using EnzymeCore

julia> EnzymeCore.batchify_activity(Active{Float64}, Val(2))
Active{Float64}

julia> EnzymeCore.batchify_activity(Duplicated{Vector{Float64}}, Val(2))
BatchDuplicated{Vector{Float64}, 2}
```
"""
batchify_activity(::Type{Active{T}}, ::Val{B}) where {T,B} = Active{T}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not that it matters much here, but I feel like we do have a missing BatchActive — in particular this already comes up for certain custom rules with batched active returns (I don't remember the specifics) that can't currently be expressed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could it be the cause for #2514?

# Batched counterparts for the duplicated activity families: each maps the
# single-shadow type to its width-`B` batch type with the same value type `T`.
function batchify_activity(::Type{Duplicated{T}}, ::Val{B}) where {T,B}
    return BatchDuplicated{T,B}
end

function batchify_activity(::Type{DuplicatedNoNeed{T}}, ::Val{B}) where {T,B}
    return BatchDuplicatedNoNeed{T,B}
end

function batchify_activity(::Type{MixedDuplicated{T}}, ::Val{B}) where {T,B}
    return BatchMixedDuplicated{T,B}
end

"""
Seed(dy)

Wrapper for a single adjoint to the return value in reverse mode.
"""
struct Seed{T}
dval::T
end

"""
BatchSeed(dys::NTuple)

Wrapper for a tuple of adjoints to the return value in reverse mode.
"""
struct BatchSeed{N, T}
dvals::NTuple{N, T}
end

"""
abstract type ABI

Expand Down
10 changes: 10 additions & 0 deletions lib/EnzymeCore/test/annotation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
using EnzymeCore
using EnzymeCore: batchify_activity
using Test

@testset "Batchify activity" begin
    # Each activity type paired with its expected width-2 batched counterpart.
    cases = [
        Active{Float64} => Active{Float64},
        Duplicated{Vector{Float64}} => BatchDuplicated{Vector{Float64},2},
        DuplicatedNoNeed{Vector{Float64}} => BatchDuplicatedNoNeed{Vector{Float64},2},
        MixedDuplicated{Tuple{Float64,Vector{Float64}}} => BatchMixedDuplicated{Tuple{Float64,Vector{Float64}},2},
    ]
    for (input, expected) in cases
        @test batchify_activity(input, Val(2)) == expected
    end
end
3 changes: 3 additions & 0 deletions lib/EnzymeCore/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,7 @@ using EnzymeCore
@testset "Mode modification" begin
include("mode_modification.jl")
end
@testset "Annotation" begin
include("annotation.jl")
end
end
8 changes: 7 additions & 1 deletion src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ import EnzymeCore:
DuplicatedNoNeed,
BatchDuplicated,
BatchDuplicatedNoNeed,
batchify_activity,
Seed,
BatchSeed,
ABI,
DefaultABI,
FFIABI,
Expand All @@ -52,14 +55,17 @@ import EnzymeCore:
clear_runtime_activity,
within_autodiff,
WithPrimal,
NoPrimal
NoPrimal,
Split
export Annotation,
Const,
Active,
Duplicated,
DuplicatedNoNeed,
BatchDuplicated,
BatchDuplicatedNoNeed,
Seed,
BatchSeed,
DefaultABI,
FFIABI,
InlineABI,
Expand Down
72 changes: 72 additions & 0 deletions src/sugar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1159,3 +1159,75 @@ grad
)
return nothing
end

"""
autodiff(
rmode::Union{ReverseMode,ReverseModeSplit},
f::Annotation,
dresult::Seed,
annotated_args...
)

Call [`autodiff_thunk`](@ref) in split mode, execute the forward pass, increment output adjoint with `dresult`, then execute the reverse pass.

Useful for computing pullbacks / VJPs for functions whose output is not a scalar, or when the scalar seed is not 1.
"""
function autodiff(
rmode::Union{ReverseMode{ReturnPrimal}, ReverseModeSplit{ReturnPrimal}},
f::FA,
dresult::Seed{RT},
args::Vararg{Annotation, N},
) where {ReturnPrimal, FA <: Annotation, RT, N}
rmode_split = Split(rmode)
RA = guess_activity(RT, rmode_split)
forward, reverse = autodiff_thunk(rmode_split, FA, RA, typeof.(args)...)
tape, result, shadow_result = forward(f, args...)
if RA <: Active
dinputs = only(reverse(f, args..., dresult.dval, tape))
else
Compiler.recursive_accumulate(shadow_result, dresult.dval)
dinputs = only(reverse(f, args..., tape))
end
if ReturnPrimal
return (dinputs, result)
else
return (dinputs,)
end
end

"""
autodiff(
rmode::Union{ReverseMode,ReverseModeSplit},
f::Annotation,
dresults::BatchSeed,
annotated_args...
)

Call [`autodiff_thunk`](@ref) in split mode, execute the forward pass, increment each output adjoint with the corresponding element from `dresults`, then execute the reverse pass.

Useful for computing pullbacks / VJPs for functions whose output is not a scalar, or when the scalar seed is not 1.
"""
function autodiff(
rmode::Union{ReverseMode{ReturnPrimal}, ReverseModeSplit{ReturnPrimal}},
f::FA,
dresults::BatchSeed{B,RT},
args::Vararg{Annotation, N},
) where {ReturnPrimal, B, FA <: Annotation, RT, N}
rmode_split_rightwidth = ReverseSplitWidth(Split(rmode), Val(B))
RA = batchify_activity(guess_activity(RT, rmode_split_rightwidth), Val(B))
forward, reverse = autodiff_thunk(rmode_split_rightwidth, FA, RA, typeof.(args)...)
tape, result, shadow_results = forward(f, args...)
if RA <: Active
dinputs = only(reverse(f, args..., dresults.dvals, tape))
else
foreach(shadow_results, dresults.dvals) do d0, d
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vchuravy knows better than me here, but I don't think we should use foreach to ensure type stability, instead explicitly using either a generated function, or ntuple to go through?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you have an example in mind where foreach would be unstable?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

foreach for tuples is fine. It boils down to ntuple

Compiler.recursive_accumulate(d0, d)
end
dinputs = only(reverse(f, args..., tape))
end
if ReturnPrimal
return (dinputs, result)
else
return (dinputs,)
end
end
69 changes: 69 additions & 0 deletions test/sugar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -650,3 +650,72 @@ end
# @show J_r_3(u, A, x)
# @show J_f_3(u, A, x)
end

@testset "Seeded reverse autodiff" begin
f(x::Vector{Float64}, y::Float64) = sum(abs2, x) * y
g(x::Vector{Float64}, y::Float64) = [f(x, y)]

x = [1.0, 2.0, 3.0]
y = 4.0
dx = similar(x)
dresult = 5.0
dxs = (similar(x), similar(x))
dresults = (5.0, 7.0)

@testset "simple" begin
for mode in (Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal)
make_zero!(dx)
dinputs_and_maybe_result = autodiff(mode, Const(f), Seed(dresult), Duplicated(x, dx), Active(y))
dinputs = first(dinputs_and_maybe_result)
@test isnothing(dinputs[1])
@test dinputs[2] == dresult * sum(abs2, x)
@test dx == dresult * 2x * y
if Enzyme.Split(mode) == ReverseSplitWithPrimal
@test last(dinputs_and_maybe_result) == f(x, y)
end
end

for mode in (Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal)
make_zero!(dx)
dinputs_and_maybe_result = autodiff(mode, Const(g), Seed([dresult]), Duplicated(x, dx), Active(y))
dinputs = first(dinputs_and_maybe_result)
@test isnothing(dinputs[1])
@test dinputs[2] == dresult * sum(abs2, x)
@test dx == dresult * 2x * y
if Enzyme.Split(mode) == ReverseSplitWithPrimal
@test last(dinputs_and_maybe_result) == g(x, y)
end
end
end

@testset "batch" begin
for mode in (Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal)
make_zero!(dxs)
dinputs_and_maybe_result = autodiff(mode, Const(f), BatchSeed(dresults), BatchDuplicated(x, dxs), Active(y))
dinputs = first(dinputs_and_maybe_result)
@test isnothing(dinputs[1])
@test dinputs[2][1] == dresults[1] * sum(abs2, x)
@test dinputs[2][2] == dresults[2] * sum(abs2, x)
@test dxs[1] == dresults[1] * 2x * y
@test dxs[2] == dresults[2] * 2x * y
if Enzyme.Split(mode) == ReverseSplitWithPrimal
@test last(dinputs_and_maybe_result) == f(x, y)
end
end

for mode in (Reverse, ReverseWithPrimal, ReverseSplitNoPrimal, ReverseSplitWithPrimal)
make_zero!(dxs)
dinputs_and_maybe_result = autodiff(mode, Const(g), BatchSeed(([dresults[1]], [dresults[2]])), BatchDuplicated(x, dxs), Active(y))
dinputs = first(dinputs_and_maybe_result)
@test isnothing(dinputs[1])
@test dinputs[2][1] == dresults[1] * sum(abs2, x)
@test dinputs[2][2] == dresults[2] * sum(abs2, x)
@test dxs[1] == dresults[1] * 2x * y
@test dxs[2] == dresults[2] * 2x * y
if Enzyme.Split(mode) == ReverseSplitWithPrimal
@test last(dinputs_and_maybe_result) == g(x, y)
end
end
end

end
Loading