-
Notifications
You must be signed in to change notification settings - Fork 90
feat: VJP utility based on autodiff_thunk
#2309
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 14 commits
6bf6299
e71a70a
3ffcd12
dafac0a
13f7a26
aa0825a
02b93d4
9fc2920
ba6c807
6b98941
1fa601e
cec5ef0
e0e3db2
d0f459e
d2bfa50
b02f776
c9cad07
b0eab62
fb78c02
0a55f3b
27546a9
c009e83
c5a8052
7eb7fe2
46885a5
533fd4e
21873f9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,6 +4,7 @@ export Forward, ForwardWithPrimal, Reverse, ReverseWithPrimal, ReverseSplitNoPri | |
| export ReverseSplitModified, ReverseSplitWidth, ReverseHolomorphic, ReverseHolomorphicWithPrimal | ||
| export Const, Active, Duplicated, DuplicatedNoNeed, BatchDuplicated, BatchDuplicatedNoNeed, Annotation | ||
| export MixedDuplicated, BatchMixedDuplicated | ||
| export Seed, BatchSeed | ||
| export DefaultABI, FFIABI, InlineABI, NonGenABI | ||
| export BatchDuplicatedFunc | ||
| export within_autodiff | ||
|
|
@@ -206,6 +207,46 @@ end | |
| @inline batch_size(::BatchMixedDuplicated{T,N}) where {T,N} = N | ||
| @inline batch_size(::Type{BatchMixedDuplicated{T,N}}) where {T,N} = N | ||
|
|
||
| """ | ||
| batchify_activity(::Type{A}, ::Val{B}) | ||
|
|
||
| Turn an activity (or [`Annotation`](@ref)) type `A` into the correct activity type for a batch of size `B`. | ||
|
|
||
| # Examples | ||
|
|
||
| ```jldoctest | ||
| julia> using EnzymeCore | ||
|
|
||
| julia> EnzymeCore.batchify_activity(Active{Float64}, Val(2)) | ||
| Active{Float64} | ||
|
|
||
| julia> EnzymeCore.batchify_activity(Duplicated{Vector{Float64}}, Val(2)) | ||
| BatchDuplicated{Vector{Float64}, 2} | ||
| ``` | ||
| """ | ||
| batchify_activity(::Type{Active{T}}, ::Val{B}) where {T,B} = Active{T} | ||
|
||
| batchify_activity(::Type{Duplicated{T}}, ::Val{B}) where {T,B} = BatchDuplicated{T,B} | ||
gdalle marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| batchify_activity(::Type{DuplicatedNoNeed{T}}, ::Val{B}) where {T,B} = BatchDuplicatedNoNeed{T,B} | ||
| batchify_activity(::Type{MixedDuplicated{T}}, ::Val{B}) where {T,B} = BatchMixedDuplicated{T,B} | ||
|
|
||
| """ | ||
| Seed(dy) | ||
|
|
||
| Wrapper for a single adjoint to the return value in reverse mode. | ||
| """ | ||
vchuravy marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| struct Seed{T} | ||
| dval::T | ||
| end | ||
|
|
||
| """ | ||
| BatchSeed(dys::NTuple) | ||
|
|
||
| Wrapper for a tuple of adjoints to the return value in reverse mode. | ||
| """ | ||
| struct BatchSeed{N, T} | ||
| dvals::NTuple{N, T} | ||
| end | ||
|
|
||
| """ | ||
| abstract type ABI | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
using EnzymeCore
using EnzymeCore: batchify_activity
using Test

@testset "Batchify activity" begin
    # Table-driven: each input activity paired with its expected batched form.
    width = Val(2)
    cases = [
        Active{Float64} => Active{Float64},
        Duplicated{Vector{Float64}} => BatchDuplicated{Vector{Float64},2},
        DuplicatedNoNeed{Vector{Float64}} => BatchDuplicatedNoNeed{Vector{Float64},2},
        MixedDuplicated{Tuple{Float64,Vector{Float64}}} => BatchMixedDuplicated{Tuple{Float64,Vector{Float64}},2},
    ]
    for (activity, expected) in cases
        @test batchify_activity(activity, width) == expected
    end
end
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1159,3 +1159,75 @@ grad | |
| ) | ||
| return nothing | ||
| end | ||
|
|
||
| """ | ||
| autodiff( | ||
| rmode::Union{ReverseMode,ReverseModeSplit}, | ||
| f::Annotation, | ||
| dresult::Seed, | ||
| annotated_args... | ||
| ) | ||
|
|
||
| Call [`autodiff_thunk`](@ref) in split mode, execute the forward pass, increment output adjoint with `dresult`, then execute the reverse pass. | ||
|
|
||
| Useful for computing pullbacks / VJPs for functions whose output is not a scalar, or when the scalar seed is not 1. | ||
| """ | ||
| function autodiff( | ||
| rmode::Union{ReverseMode{ReturnPrimal}, ReverseModeSplit{ReturnPrimal}}, | ||
| f::FA, | ||
gdalle marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| dresult::Seed{RT}, | ||
| args::Vararg{Annotation, N}, | ||
| ) where {ReturnPrimal, FA <: Annotation, RT, N} | ||
| rmode_split = Split(rmode) | ||
| RA = guess_activity(RT, rmode_split) | ||
| forward, reverse = autodiff_thunk(rmode_split, FA, RA, typeof.(args)...) | ||
| tape, result, shadow_result = forward(f, args...) | ||
| if RA <: Active | ||
gdalle marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| dinputs = only(reverse(f, args..., dresult.dval, tape)) | ||
| else | ||
| Compiler.recursive_accumulate(shadow_result, dresult.dval) | ||
| dinputs = only(reverse(f, args..., tape)) | ||
| end | ||
| if ReturnPrimal | ||
| return (dinputs, result) | ||
| else | ||
| return (dinputs,) | ||
| end | ||
| end | ||
|
|
||
| """ | ||
| autodiff( | ||
| rmode::Union{ReverseMode,ReverseModeSplit}, | ||
| f::Annotation, | ||
| dresults::BatchSeed, | ||
| annotated_args... | ||
| ) | ||
|
|
||
| Call [`autodiff_thunk`](@ref) in split mode, execute the forward pass, increment each output adjoint with the corresponding element from `dresults`, then execute the reverse pass. | ||
|
|
||
| Useful for computing pullbacks / VJPs for functions whose output is not a scalar, or when the scalar seed is not 1. | ||
| """ | ||
| function autodiff( | ||
| rmode::Union{ReverseMode{ReturnPrimal}, ReverseModeSplit{ReturnPrimal}}, | ||
| f::FA, | ||
| dresults::BatchSeed{B,RT}, | ||
| args::Vararg{Annotation, N}, | ||
| ) where {ReturnPrimal, B, FA <: Annotation, RT, N} | ||
| rmode_split_rightwidth = ReverseSplitWidth(Split(rmode), Val(B)) | ||
| RA = batchify_activity(guess_activity(RT, rmode_split_rightwidth), Val(B)) | ||
| forward, reverse = autodiff_thunk(rmode_split_rightwidth, FA, RA, typeof.(args)...) | ||
gdalle marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| tape, result, shadow_results = forward(f, args...) | ||
| if RA <: Active | ||
| dinputs = only(reverse(f, args..., dresults.dvals, tape)) | ||
| else | ||
| foreach(shadow_results, dresults.dvals) do d0, d | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @vchuravy knows better than me here, but I don't think we should use foreach to ensure type stability, instead explicitly using either a generated function, or ntuple to go through?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you have an example in mind where
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| Compiler.recursive_accumulate(d0, d) | ||
| end | ||
| dinputs = only(reverse(f, args..., tape)) | ||
| end | ||
| if ReturnPrimal | ||
| return (dinputs, result) | ||
| else | ||
| return (dinputs,) | ||
| end | ||
| end | ||
Uh oh!
There was an error while loading. Please reload this page.