Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Expand All @@ -31,6 +32,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
[extensions]
EnzymeBFloat16sExt = "BFloat16s"
EnzymeChainRulesCoreExt = "ChainRulesCore"
EnzymeFunctionWrappersExt = "FunctionWrappers"
EnzymeGPUArraysCoreExt = "GPUArraysCore"
EnzymeLogExpFunctionsExt = "LogExpFunctions"
EnzymeSpecialFunctionsExt = "SpecialFunctions"
Expand All @@ -42,6 +44,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5, 0.6"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.8.16"
FunctionWrappers = "1.1"
Enzyme_jll = "0.0.249"
GPUArraysCore = "0.1.6, 0.2"
GPUCompiler = "1.6.2"
Expand All @@ -59,6 +62,7 @@ julia = "1.10"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Expand Down
209 changes: 209 additions & 0 deletions ext/EnzymeFunctionWrappersExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
module EnzymeFunctionWrappersExt

using FunctionWrappers: FunctionWrapper
using EnzymeCore
using EnzymeCore.EnzymeRules
using Enzyme

# Helper to extract the raw function from a FunctionWrapper
@inline unwrap_fw(fw::FunctionWrapper) = fw.obj[]

# Helper to reconstruct an annotation with a cached primal value
@inline _reconstruct_arg(arg::Const, cached, overwritten::Bool) = arg
@inline function _reconstruct_arg(arg::Duplicated, cached, overwritten::Bool)
overwritten && cached !== nothing ? Duplicated(cached, arg.dval) : arg
end
@inline function _reconstruct_arg(arg::BatchDuplicated, cached, overwritten::Bool)
overwritten && cached !== nothing ? BatchDuplicated(cached, arg.dval) : arg
end
@inline _reconstruct_arg(arg::Active, cached, overwritten::Bool) = arg

# Helper for type-stable reverse return values
@inline _reverse_val(::Active{T}, grad, dret_val) where {T} = (grad * dret_val)::T
@inline _reverse_val(::Const, grad, dret_val) = nothing
@inline _reverse_val(::Duplicated, grad, dret_val) = nothing
@inline _reverse_val(::BatchDuplicated, grad, dret_val) = nothing

# ---------------------------------------------------------------------------
# Forward mode rule
# ---------------------------------------------------------------------------
# Single rule for both IIP (Nothing return) and OOP FunctionWrappers.
# Extracts the wrapped function and delegates to autodiff_deferred.
function EnzymeRules.forward(
config::EnzymeRules.FwdConfig,
func::Const{<:FunctionWrapper},
RT::Type{<:Annotation},
args::Annotation...,
)
raw_f = unwrap_fw(func.val)

# For IIP functions (Const{Nothing} return), needs_shadow is false but we
# still must propagate tangents into argument shadow arrays via AD.
if RT <: Const
# IIP or inactive return — run AD for tangent propagation into arg shadows
Enzyme.autodiff_deferred(Forward, Const(raw_f), Const{eltype(RT)}, args...)
if EnzymeRules.needs_primal(config)
return raw_f(map(x -> x.val, args)...)
else
return nothing
end
end

# OOP: shadow is needed. Always use Duplicated for autodiff_deferred
# (it rejects DuplicatedNoNeed).
RealRt = eltype(RT)
if EnzymeRules.needs_primal(config)
res = Enzyme.autodiff_deferred(ForwardWithPrimal, Const(raw_f), Duplicated, args...)
# autodiff ForwardWithPrimal returns (derivs, primal)
if EnzymeRules.width(config) == 1
return Duplicated(res[2]::RealRt, res[1]::RealRt)
else
return BatchDuplicated(res[2]::RealRt, res[1]::NTuple{EnzymeRules.width(config),RealRt})
end
else
res = Enzyme.autodiff_deferred(Forward, Const(raw_f), Duplicated, args...)
# autodiff Forward returns (derivs,)
if EnzymeRules.width(config) == 1
return res[1]::RealRt
else
return res[1]::NTuple{EnzymeRules.width(config),RealRt}
end
end
end

# ---------------------------------------------------------------------------
# Reverse mode rules
# ---------------------------------------------------------------------------

# augmented_primal: execute the forward pass, cache data for reverse
function EnzymeRules.augmented_primal(
config::EnzymeRules.RevConfig,
func::Const{<:FunctionWrapper{Ret}},
RT::Type{<:Annotation},
args::Annotation...,
) where {Ret}
raw_f = unwrap_fw(func.val)
ow = EnzymeRules.overwritten(config)
nargs = length(args)

# Cache copies of overwritten mutable args (needed for reverse pass)
cached_args = ntuple(Val(nargs)) do i
Base.@_inline_meta
# ow[1] is the function itself, ow[i+1] is the i-th argument
if ow[i + 1] && !(args[i] isa Const)
deepcopy(args[i].val)
else
nothing
end
end

# Execute the primal
primal_result = raw_f(map(x -> x.val, args)...)

primal = if EnzymeRules.needs_primal(config)
primal_result
else
nothing
end

shadow = if EnzymeRules.needs_shadow(config)
if Ret === Nothing
nothing
else
if EnzymeRules.width(config) == 1
Enzyme.make_zero(primal_result)
else
ntuple(Val(EnzymeRules.width(config))) do j
Base.@_inline_meta
Enzyme.make_zero(primal_result)
end
end
end
else
nothing
end

tape = (raw_f, cached_args)
return EnzymeRules.AugmentedReturn(primal, shadow, tape)
end

# reverse for IIP (Nothing return): accumulate gradients into dval arrays
function EnzymeRules.reverse(
config::EnzymeRules.RevConfig,
func::Const{<:FunctionWrapper{Nothing}},
::Type{<:Const{Nothing}},
tape,
args::Annotation...,
)
raw_f, cached_args = tape
ow = EnzymeRules.overwritten(config)
nargs = length(args)

new_args = ntuple(Val(nargs)) do i
Base.@_inline_meta
_reconstruct_arg(args[i], cached_args[i], ow[i + 1])
end

Enzyme.autodiff_deferred(Reverse, Const(raw_f), Const{Nothing}, new_args...)

return ntuple(Val(nargs)) do i
Base.@_inline_meta
nothing
end
end

# reverse for OOP with Active return: return scaled per-arg gradients
function EnzymeRules.reverse(
config::EnzymeRules.RevConfig,
func::Const{<:FunctionWrapper{Ret}},
dret::Active,
tape,
args::Annotation...,
) where {Ret}
raw_f, cached_args = tape
ow = EnzymeRules.overwritten(config)
nargs = length(args)

new_args = ntuple(Val(nargs)) do i
Base.@_inline_meta
_reconstruct_arg(args[i], cached_args[i], ow[i + 1])
end

# autodiff_deferred(Reverse, ..., Active, args...) returns ((grad1, grad2, ...),)
res = Enzyme.autodiff_deferred(Reverse, Const(raw_f), Active, new_args...)
grads = res[1]

return ntuple(Val(nargs)) do i
Base.@_inline_meta
_reverse_val(args[i], grads[i], dret.val)
end
end

# reverse for OOP with Duplicated/Const return type (non-Active)
function EnzymeRules.reverse(
config::EnzymeRules.RevConfig,
func::Const{<:FunctionWrapper{Ret}},
dret::Type{<:Annotation},
tape,
args::Annotation...,
) where {Ret}
if !(dret <: Const)
raw_f, cached_args = tape
ow = EnzymeRules.overwritten(config)
nargs = length(args)

new_args = ntuple(Val(nargs)) do i
Base.@_inline_meta
_reconstruct_arg(args[i], cached_args[i], ow[i + 1])
end

Enzyme.autodiff_deferred(Reverse, Const(raw_f), dret, new_args...)
end

return ntuple(Val(length(args))) do i
Base.@_inline_meta
nothing
end
end

end # module
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
EnzymeTestUtils = "12d8515a-0907-448a-8884-5fe00fdf1c5a"
Enzyme_jll = "7cc45869-7501-5eee-bdea-0790c847d4ef"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Expand Down
87 changes: 87 additions & 0 deletions test/ext/functionwrappers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
using Enzyme, Test
using FunctionWrappers: FunctionWrapper

@testset "FunctionWrappers Extension" begin

# In-place (IIP) test function: du[1] = p[1] * u[1]^2
f!(du, u, p) = (du[1] = p[1] * u[1]^2; nothing)

# Out-of-place (OOP) test function: returns p[1] * x^2
f_oop(x, p) = p[1] * x^2

@testset "IIP Forward Mode" begin
fw = FunctionWrapper{Nothing,Tuple{Vector{Float64},Vector{Float64},Vector{Float64}}}(f!)

u = [2.0]; du = zeros(1); p = [3.0]
ddu = zeros(1); du_u = [1.0]

# Differentiate through FunctionWrapper
Enzyme.autodiff(Forward, fw, Const{Nothing},
Duplicated(du, ddu), Duplicated(u, du_u), Const(p))

# Compare with raw function
u2 = [2.0]; du2 = zeros(1); ddu2 = zeros(1); du_u2 = [1.0]
Enzyme.autodiff(Forward, f!, Const{Nothing},
Duplicated(du2, ddu2), Duplicated(u2, du_u2), Const(p))

@test ddu ≈ ddu2
# ddu[1] should be d/du(p*u^2) * du_u = 3.0 * 2 * 2.0 * 1.0 = 12.0
@test ddu[1] ≈ 12.0
end

@testset "IIP Reverse Mode" begin
fw = FunctionWrapper{Nothing,Tuple{Vector{Float64},Vector{Float64},Vector{Float64}}}(f!)

u = [2.0]; du = zeros(1); p = [3.0]
ddu = [1.0]; du_u = zeros(1)

Enzyme.autodiff(Reverse, fw, Const{Nothing},
Duplicated(du, ddu), Duplicated(u, du_u), Const(p))

# Compare with raw function
u2 = [2.0]; du2 = zeros(1); ddu2 = [1.0]; du_u2 = zeros(1)
Enzyme.autodiff(Reverse, f!, Const{Nothing},
Duplicated(du2, ddu2), Duplicated(u2, du_u2), Const(p))

@test du_u ≈ du_u2
# du/du[1] of (du[1] = p[1]*u[1]^2) with seed ddu[1]=1.0:
# = p[1] * 2 * u[1] = 3.0 * 2 * 2.0 = 12.0
@test du_u[1] ≈ 12.0
end

@testset "OOP Forward Mode" begin
fw_oop = FunctionWrapper{Float64,Tuple{Float64,Vector{Float64}}}(f_oop)

x = 3.0; p = [2.0]
dx = 1.0

res = Enzyme.autodiff(Forward, fw_oop, Duplicated,
Duplicated(x, dx), Const(p))

# Compare with raw function
res2 = Enzyme.autodiff(Forward, f_oop, Duplicated,
Duplicated(x, dx), Const(p))

@test res[1] ≈ res2[1]
# d/dx(p*x^2) = 2*p*x = 2*2.0*3.0 = 12.0
@test res[1] ≈ 12.0
end

@testset "OOP Reverse Mode" begin
fw_oop = FunctionWrapper{Float64,Tuple{Float64,Vector{Float64}}}(f_oop)

x = 3.0; p = [2.0]

res = Enzyme.autodiff(Reverse, fw_oop, Active,
Active(x), Const(p))

# Compare with raw function
res2 = Enzyme.autodiff(Reverse, f_oop, Active,
Active(x), Const(p))

@test res[1][1] ≈ res2[1][1]
# d/dx(p*x^2) = 2*p*x = 2*2.0*3.0 = 12.0
@test res[1][1] ≈ 12.0
end

end