Conversation
| include("reverse_onearg.jl") | ||
| include("reverse_twoarg.jl") | ||
|
|
||
| end # module No newline at end of file |
There was a problem hiding this comment.
| end # module | |
| end # module |
| f::F, | ||
| ::AutoEnzyme{<:Union{ForwardMode,Nothing}}, | ||
| x, | ||
| tx::NTuple, | ||
| contexts::Vararg{DI.Context,C}, | ||
| ) where {F,C} | ||
| return DI.NoPushforwardPrep() | ||
| end | ||
|
|
||
| function DI.value_and_pushforward( | ||
| f::F, | ||
| ::DI.NoPushforwardPrep, | ||
| backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, |
There was a problem hiding this comment.
| f::F, | |
| ::AutoEnzyme{<:Union{ForwardMode,Nothing}}, | |
| x, | |
| tx::NTuple, | |
| contexts::Vararg{DI.Context,C}, | |
| ) where {F,C} | |
| return DI.NoPushforwardPrep() | |
| end | |
| function DI.value_and_pushforward( | |
| f::F, | |
| ::DI.NoPushforwardPrep, | |
| backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, | |
| f::F, | |
| ::AutoEnzyme{<:Union{ForwardMode, Nothing}}, | |
| x, | |
| tx::NTuple, | |
| contexts::Vararg{DI.Context, C}, | |
| ) where {F, C} | |
| f::F, | |
| ::DI.NoPushforwardPrep, | |
| backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, | |
| x, | |
| tx::NTuple{1}, | |
| contexts::Vararg{DI.Context, C}, | |
| ) where {F, C} |
| f::F, | ||
| ::DI.NoPushforwardPrep, | ||
| backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, | ||
| x, | ||
| tx::NTuple{B}, | ||
| contexts::Vararg{DI.Context,C}, | ||
| ) where {F,B,C} |
There was a problem hiding this comment.
| f::F, | |
| ::DI.NoPushforwardPrep, | |
| backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, | |
| x, | |
| tx::NTuple{B}, | |
| contexts::Vararg{DI.Context,C}, | |
| ) where {F,B,C} | |
| f::F, | |
| ::DI.NoPushforwardPrep, | |
| backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, | |
| x, | |
| tx::NTuple{B}, | |
| contexts::Vararg{DI.Context, C}, | |
| ) where {F, B, C} |
| f::F, | ||
| ::DI.NoPushforwardPrep, | ||
| backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, | ||
| x, | ||
| tx::NTuple{1}, | ||
| contexts::Vararg{DI.Context,C}, | ||
| ) where {F,C} |
There was a problem hiding this comment.
| f::F, | |
| ::DI.NoPushforwardPrep, | |
| backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, | |
| x, | |
| tx::NTuple{1}, | |
| contexts::Vararg{DI.Context,C}, | |
| ) where {F,C} | |
| f::F, | |
| ::DI.NoPushforwardPrep, | |
| backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, | |
| x, | |
| tx::NTuple{1}, | |
| contexts::Vararg{DI.Context, C}, | |
| ) where {F, C} |
| f::F, | ||
| ::DI.NoPushforwardPrep, | ||
| backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, | ||
| x, | ||
| tx::NTuple{B}, | ||
| contexts::Vararg{DI.Context,C}, | ||
| ) where {F,B,C} |
There was a problem hiding this comment.
| f::F, | |
| ::DI.NoPushforwardPrep, | |
| backend::AutoEnzyme{<:Union{ForwardMode,Nothing}}, | |
| x, | |
| tx::NTuple{B}, | |
| contexts::Vararg{DI.Context,C}, | |
| ) where {F,B,C} | |
| f::F, | |
| ::DI.NoPushforwardPrep, | |
| backend::AutoEnzyme{<:Union{ForwardMode, Nothing}}, | |
| x, | |
| tx::NTuple{B}, | |
| contexts::Vararg{DI.Context, C}, | |
| ) where {F, B, C} |
| f::F, ::AutoEnzyme{M,Nothing}, mode::Mode, ::Val{B}=Val(1) | ||
| ) where {F,M,B} | ||
| return f | ||
| end | ||
|
|
||
| @inline function get_f_and_df( | ||
| f::F, ::AutoEnzyme{M,<:Const}, mode::Mode, ::Val{B}=Val(1) | ||
| ) where {F,M,B} | ||
| return Const(f) | ||
| end | ||
|
|
||
| @inline function get_f_and_df( | ||
| f::F, | ||
| ::AutoEnzyme{ | ||
| M, | ||
| <:Union{ | ||
| Duplicated, | ||
| MixedDuplicated, |
There was a problem hiding this comment.
| f::F, ::AutoEnzyme{M,Nothing}, mode::Mode, ::Val{B}=Val(1) | |
| ) where {F,M,B} | |
| return f | |
| end | |
| @inline function get_f_and_df( | |
| f::F, ::AutoEnzyme{M,<:Const}, mode::Mode, ::Val{B}=Val(1) | |
| ) where {F,M,B} | |
| return Const(f) | |
| end | |
| @inline function get_f_and_df( | |
| f::F, | |
| ::AutoEnzyme{ | |
| M, | |
| <:Union{ | |
| Duplicated, | |
| MixedDuplicated, | |
| f::F, ::AutoEnzyme{M, Nothing}, mode::Mode, ::Val{B} = Val(1) | |
| ) where {F, M, B} | |
| f::F, ::AutoEnzyme{M, <:Const}, mode::Mode, ::Val{B} = Val(1) | |
| ) where {F, M, B} | |
| f::F, | |
| ::AutoEnzyme{ | |
| M, | |
| <:Union{ | |
| Duplicated, | |
| MixedDuplicated, | |
| BatchDuplicated, | |
| BatchMixedDuplicated, | |
| DuplicatedNoNeed, | |
| BatchDuplicatedNoNeed, | |
| }, | |
| mode::Mode, | |
| ::Val{B} = Val(1), | |
| ) where {F, M, B} |
| force_annotation(f::F) where {F<:Annotation} = f | ||
| force_annotation(f::F) where {F} = Const(f) | ||
|
|
||
| @inline function _translate( | ||
| ::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Constant,DI.BackendContext} |
There was a problem hiding this comment.
| force_annotation(f::F) where {F<:Annotation} = f | |
| force_annotation(f::F) where {F} = Const(f) | |
| @inline function _translate( | |
| ::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Constant,DI.BackendContext} | |
| force_annotation(f::F) where {F <: Annotation} = f | |
| ::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Constant, DI.BackendContext} | |
| ) where {B} | |
| backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.Cache | |
| ) where {B} |
| backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext | ||
| ) where {B} | ||
| return force_annotation(get_f_and_df(DI.unwrap(c), backend, mode, Val(B))) | ||
| end |
There was a problem hiding this comment.
| backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext | |
| ) where {B} | |
| return force_annotation(get_f_and_df(DI.unwrap(c), backend, mode, Val(B))) | |
| end | |
| backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext | |
| ) where {B} | |
| backend::AutoEnzyme, mode::Mode, ::Val{B}, contexts::Vararg{DI.Context, C} | |
| ) where {B, C} |
| set_err(mode::Mode, ::AutoEnzyme{<:Any,Nothing}) = EnzymeCore.set_err_if_func_written(mode) | ||
| set_err(mode::Mode, ::AutoEnzyme{<:Any,<:Annotation}) = mode |
There was a problem hiding this comment.
| set_err(mode::Mode, ::AutoEnzyme{<:Any,Nothing}) = EnzymeCore.set_err_if_func_written(mode) | |
| set_err(mode::Mode, ::AutoEnzyme{<:Any,<:Annotation}) = mode | |
| set_err(mode::Mode, ::AutoEnzyme{<:Any, Nothing}) = EnzymeCore.set_err_if_func_written(mode) | |
| set_err(mode::Mode, ::AutoEnzyme{<:Any, <:Annotation}) = mode |
| function annotate(::Type{BatchDuplicated{T,B}}, x, tx::NTuple{B}) where {T,B} | ||
| return BatchDuplicated(x, tx) | ||
| end |
There was a problem hiding this comment.
| function annotate(::Type{BatchDuplicated{T,B}}, x, tx::NTuple{B}) where {T,B} | |
| return BatchDuplicated(x, tx) | |
| end | |
| function annotate(::Type{BatchDuplicated{T, B}}, x, tx::NTuple{B}) where {T, B} | |
| batchify_activity(::Type{Active{T}}, ::Val{B}) where {T, B} = Active{T} | |
| batchify_activity(::Type{Duplicated{T}}, ::Val{B}) where {T, B} = BatchDuplicated{T, B} |
Benchmark Results
Benchmark PlotsA plot of the benchmark results have been uploaded as an artifact to the workflow run for this PR. |
| return set_err(ReverseSplitWithPrimal, backend) | ||
| end | ||
|
|
||
| set_err(mode::Mode, ::AutoEnzyme{<:Any,Nothing}) = EnzymeCore.set_err_if_func_written(mode) |
There was a problem hiding this comment.
@gdalle as discussed on slack this probably should be an extension to the set_err_if_func_written function to take an ADmode, so likely we have this in an EnzymeCoreADTypes ext?
| forward_withprimal(backend::AutoEnzyme{<:ForwardMode}) = WithPrimal(backend.mode) | ||
| forward_withprimal(::AutoEnzyme{Nothing}) = ForwardWithPrimal | ||
|
|
||
| reverse_noprimal(backend::AutoEnzyme{<:ReverseMode}) = NoPrimal(backend.mode) |
There was a problem hiding this comment.
@gdalle similarly here we can make an ADTypes ext func for get_mode_or_default(AutoEnzyme, defaultMode)
| dy_sametype = convert(typeof(y), only(prep.ty_copy)) | ||
| x_and_dx = Duplicated(x, dx_sametype) | ||
| y_and_dy = Duplicated(y, dy_sametype) | ||
| annotated_contexts = translate(backend, mode, Val(1), contexts...) |
There was a problem hiding this comment.
@gdalle I presume this can be moved into Enzyme.gradient! And have DI call that?
@gdalle this will enable the DI ext to more properly touch internals if need be