move enzyme to extension (#2467)
CarloLucibello authored Jul 10, 2024
1 parent 36abc73 commit 8c15898
Showing 8 changed files with 55 additions and 41 deletions.
3 changes: 2 additions & 1 deletion Project.toml
@@ -6,7 +6,6 @@ version = "0.14.17"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
@@ -26,13 +25,15 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[extensions]
FluxAMDGPUExt = "AMDGPU"
FluxCUDAExt = "CUDA"
FluxCUDAcuDNNExt = ["CUDA", "cuDNN"]
FluxEnzymeExt = "Enzyme"
FluxMetalExt = "Metal"

[compat]
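With Enzyme moved from [deps] to [weakdeps] and registered under [extensions], the FluxEnzymeExt module is only built and loaded once Enzyme is present alongside Flux. A minimal sketch (not part of the commit) of how this plays out in a session, assuming Julia 1.9 or later where package extensions exist:

    using Flux
    Base.get_extension(Flux, :FluxEnzymeExt)  # returns nothing; Enzyme not loaded yet

    using Enzyme                               # loading the weak dependency triggers the extension
    Base.get_extension(Flux, :FluxEnzymeExt)  # now returns the FluxEnzymeExt module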
47 changes: 47 additions & 0 deletions ext/FluxEnzymeExt/FluxEnzymeExt.jl
@@ -0,0 +1,47 @@
module FluxEnzymeExt

using Flux
import Flux.Train: train!, _rule_to_state
import Flux.Optimise
import Optimisers
import Enzyme
using Enzyme: EnzymeRules, Active, Const, Duplicated, autodiff, ReverseWithPrimal
using ProgressLogging: @withprogress, @logprogress

_make_zero_internal!(x::AbstractArray) = fill!(x, 0)
_make_zero_internal!(x) = x
_make_zero!(model) = fmap(_make_zero_internal!, model)

_applyloss(loss, model, d...) = loss(model, d...)

EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true

using Flux: _old_to_new # from src/deprecations.jl
train!(loss, model::Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing) =
    train!(loss, model, data, _old_to_new(opt); cb)

function train!(loss, model::Duplicated, data, rule::Optimisers.AbstractRule; cb = nothing)
    train!(loss, model, data, _rule_to_state(model, rule); cb)
end

function train!(loss, model::Duplicated, data, opt; cb = nothing)
    isnothing(cb) || error("""train! does not support callback functions.
                             For more control use a loop with `gradient` and `update!`.""")
    @withprogress for (i,d) in enumerate(data)
        d_splat = d isa Tuple ? d : (d,)

        _make_zero!(model.dval)
        _, l = Enzyme.autodiff(ReverseWithPrimal, _applyloss,
                               Active, Const(loss), model, map(Const, d_splat)...)

        if !isfinite(l)
            throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
        end
        opt, model2 = Optimisers.update!(opt, model.val, model.dval)
        model = Duplicated(model2, model.dval)

        @logprogress Base.haslength(data) ? i/length(data) : nothing
    end
end

end # FluxEnzymeExt
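The methods above dispatch on Duplicated, so wrapping the model in Enzyme.Duplicated is what selects the Enzyme path inside train!. A minimal usage sketch, not part of the commit; the model, data, optimiser choice (Adam), and mse loss are illustrative:

    using Flux, Enzyme  # Enzyme must be loaded so FluxEnzymeExt provides these methods

    model = Chain(Dense(2 => 16, relu), Dense(16 => 1))
    dup   = Duplicated(model, Enzyme.make_zero(model))   # primal model plus a shadow holding gradients

    data      = [(randn(Float32, 2, 32), randn(Float32, 1, 32)) for _ in 1:10]
    opt_state = Flux.setup(Adam(1e-3), model)
    loss(m, x, y) = Flux.mse(m(x), y)

    # Gradients are computed by Enzyme because the model argument is a Duplicated:
    Flux.train!(loss, dup, data, opt_state)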
3 changes: 1 addition & 2 deletions src/deprecations.jl
@@ -109,8 +109,7 @@ train!(loss, ps::Params, data, opt::Optimisers.AbstractRule; cb=nothing) = error

train!(loss, model, data, opt::Optimise.AbstractOptimiser; cb=nothing) =
    train!(loss, model, data, _old_to_new(opt); cb)
train!(loss, model::Enzyme.Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing) =
    train!(loss, model, data, _old_to_new(opt); cb)


# Next, to use the new `setup` with the still-exported old-style `Adam` etc:
import .Train: setup
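What survives here is the generic fallback: an old-style Flux.Optimise optimiser is converted by _old_to_new and handed on, so this file no longer needs its own Enzyme.Duplicated copy of the deprecation (the extension defines one itself, shown above). Roughly, and assuming the same illustrative loss/model/data as the sketch above:

    # Old-style optimiser still works; it is converted to an Optimisers.jl rule internally.
    Flux.train!(loss, model, data, Flux.Optimise.Descent(0.1))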
1 change: 0 additions & 1 deletion src/functor.jl
@@ -3,7 +3,6 @@ using LinearAlgebra: Cholesky
using Zygote: IdSet
import Functors: Functors, @functor, functor, fmap, isleaf
using SparseArrays: AbstractSparseArray
using Enzyme

"""
testmode!(model, [mode]) -> model
2 changes: 0 additions & 2 deletions src/losses/utils.jl
@@ -1,4 +1,3 @@
import Enzyme

"""
xlogx(x)
@@ -38,4 +37,3 @@ end
_check_sizes(ŷ, y) = nothing # pass-through, for constant label e.g. y = 1

ChainRulesCore.@non_differentiable _check_sizes(ŷ::Any, y::Any)
Enzyme.EnzymeRules.inactive(::typeof(_check_sizes), args...) = true
29 changes: 1 addition & 28 deletions src/train.jl
@@ -5,7 +5,6 @@ using Optimisers: Optimisers
using Functors: fmap, fmapstructure
using ..Flux: Flux # used only in docstring
import ..Flux.Optimise: train!, update! # during 0.13, we add methods to the old functions
import Enzyme

export setup, train!

@@ -53,12 +52,6 @@ function setup(rule::Optimisers.AbstractRule, model)
    state
end

_make_zero_internal!(x::AbstractArray) = fill!(x, 0)
_make_zero_internal!(x) = x
_make_zero!(model) = fmap(_make_zero_internal!, model)

_applyloss(loss, model, d...) = loss(model, d...)

"""
train!(loss, model, data, opt_state)
@@ -67,7 +60,7 @@ according to a particular optimisation rule encoded in `opt_state`.
Iterates through `data` once, evaluating for each `d in data` either
`loss(model, d...)` if `d isa Tuple`, or else `loss(model, d)` for other `d`.
If `model` is an Enzyme.Duplicated, gradients will be computed with Enzyme,
If `model` is an Enzyme.Duplicated and `Enzyme.jl` is loaded, gradients will be computed with Enzyme,
otherwise they will be computed with Zygote.
For example, with these definitions...
@@ -122,32 +115,12 @@ function train!(loss, model, data, opt; cb = nothing)
        @logprogress Base.haslength(data) ? i/length(data) : nothing
    end
end
function train!(loss, model::Enzyme.Duplicated, data, opt; cb = nothing)
    isnothing(cb) || error("""train! does not support callback functions.
                             For more control use a loop with `gradient` and `update!`.""")
    @withprogress for (i,d) in enumerate(data)
        d_splat = d isa Tuple ? d : (d,)

        _make_zero!(model.dval)
        _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model, map(Enzyme.Const, d_splat)...)

        if !isfinite(l)
            throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
        end
        opt, model2 = Optimisers.update!(opt, model.val, model.dval)
        model = Enzyme.Duplicated(model2, model.dval)

        @logprogress Base.haslength(data) ? i/length(data) : nothing
    end
end

# This method lets you use Optimisers.Descent() without setup, when there is no state
function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)
    train!(loss, model, data, _rule_to_state(model, rule); cb)
end
function train!(loss, model::Enzyme.Duplicated, data, rule::Optimisers.AbstractRule; cb = nothing)
    train!(loss, model, data, _rule_to_state(model, rule); cb)
end

function _rule_to_state(model, rule::Optimisers.AbstractRule)
    state = setup(rule, model)
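For reference, the manual loop that the callback error message points to looks roughly like this on the default Zygote path (a sketch, assuming loss, model, data, and opt_state as in the train! docstring):

    for d in data
        d_splat = d isa Tuple ? d : (d,)
        # Explicit-style gradient with respect to the model (Zygote by default):
        grads = Flux.gradient(m -> loss(m, d_splat...), model)
        Flux.update!(opt_state, model, grads[1])
    end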
7 changes: 2 additions & 5 deletions test/ext_enzyme/enzyme.jl
@@ -1,15 +1,12 @@
using Test
using Flux

using Enzyme
using Enzyme: Enzyme, make_zero, Active, Duplicated, ReverseWithPrimal

using Functors
using FiniteDifferences
using CUDA

_make_zero(x::Union{Number,AbstractArray}) = zero(x)
_make_zero(x) = x
make_zero(model) = fmap(_make_zero, model)
## make_differential(model) = fmapstructure(make_zero, model) # NOT SUPPORTED, See https://github.com/EnzymeAD/Enzyme.jl/issues/1329

function gradient_fd(f, x...)
    x = [cpu(x) for x in x]
4 changes: 2 additions & 2 deletions test/train.jl
@@ -4,10 +4,10 @@ import Optimisers

using Test
using Random
using Enzyme
import Enzyme

function train_enzyme!(fn, model, args...; kwargs...)
    Flux.train!(fn, Duplicated(model, Enzyme.make_zero(model)), args...; kwargs...)
    Flux.train!(fn, Enzyme.Duplicated(model, Enzyme.make_zero(model)), args...; kwargs...)
end

for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
