Commit 8c15898

move enzyme to extension (#2467)
1 parent 36abc73 commit 8c15898

File tree: 8 files changed, +55 -41 lines

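In short: Enzyme moves from Flux's hard dependencies to its weak dependencies, and all Enzyme-specific training code moves into a new package extension, FluxEnzymeExt. On Julia 1.9+ the extension is loaded automatically once both packages are in the session, so a plain `using Flux` no longer loads Enzyme at all. A rough sketch of the user-visible effect (illustration only, not part of the diff):

    using Flux     # Enzyme is no longer loaded; train! differentiates with Zygote

    using Enzyme   # loading Enzyme alongside Flux triggers FluxEnzymeExt, enabling
                   # the train!(loss, ::Enzyme.Duplicated, data, opt) methods shown below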

Project.toml

Lines changed: 2 additions & 1 deletion
@@ -6,7 +6,6 @@ version = "0.14.17"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
-Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
@@ -26,13 +25,15 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 [weakdeps]
 AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
 cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

 [extensions]
 FluxAMDGPUExt = "AMDGPU"
 FluxCUDAExt = "CUDA"
 FluxCUDAcuDNNExt = ["CUDA", "cuDNN"]
+FluxEnzymeExt = "Enzyme"
 FluxMetalExt = "Metal"

 [compat]
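Moving `Enzyme` under `[weakdeps]` and registering `FluxEnzymeExt = "Enzyme"` under `[extensions]` follows the same pattern the package already uses for AMDGPU, CUDA, and Metal. To confirm at runtime that the extension actually loaded, one can query it with `Base.get_extension` (a minimal sketch, assuming Julia 1.9+):

    using Flux, Enzyme

    # Returns the extension module once both packages are loaded, `nothing` otherwise.
    ext = Base.get_extension(Flux, :FluxEnzymeExt)
    ext === nothing && @warn "FluxEnzymeExt did not load"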

ext/FluxEnzymeExt/FluxEnzymeExt.jl

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+module FluxEnzymeExt
+
+using Flux
+import Flux.Train: train!, _rule_to_state
+import Flux.Optimise
+import Optimisers
+import Enzyme
+using Enzyme: EnzymeRules, Active, Const, Duplicated, autodiff, ReverseWithPrimal
+using ProgressLogging: @withprogress, @logprogress
+
+_make_zero_internal!(x::AbstractArray) = fill!(x, 0)
+_make_zero_internal!(x) = x
+_make_zero!(model) = fmap(_make_zero_internal!, model)
+
+_applyloss(loss, model, d...) = loss(model, d...)
+
+EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true
+
+using Flux: _old_to_new  # from src/deprecations.jl
+train!(loss, model::Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing) =
+  train!(loss, model, data, _old_to_new(opt); cb)
+
+function train!(loss, model::Duplicated, data, rule::Optimisers.AbstractRule; cb = nothing)
+  train!(loss, model, data, _rule_to_state(model, rule); cb)
+end
+
+function train!(loss, model::Duplicated, data, opt; cb = nothing)
+  isnothing(cb) || error("""train! does not support callback functions.
+                            For more control use a loop with `gradient` and `update!`.""")
+  @withprogress for (i,d) in enumerate(data)
+    d_splat = d isa Tuple ? d : (d,)
+
+    _make_zero!(model.dval)
+    _, l = Enzyme.autodiff(ReverseWithPrimal, _applyloss,
+                           Active, Const(loss), model, map(Const, d_splat)...)

+    if !isfinite(l)
+      throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
+    end
+    opt, model2 = Optimisers.update!(opt, model.val, model.dval)
+    model = Duplicated(model2, model.dval)
+
+    @logprogress Base.haslength(data) ? i/length(data) : nothing
+  end
+end
+
+end # FluxEnzymeExt
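End to end, the extension's `train!` method is exercised as below. This is a minimal sketch: the model, loss, and data are invented for illustration, and `Enzyme.make_zero` builds the zeroed shadow into which `Duplicated` accumulates gradients (the same pattern the updated tests use):

    using Flux, Enzyme, Optimisers

    model = Chain(Dense(2 => 3, tanh), Dense(3 => 1))          # toy model
    dup   = Enzyme.Duplicated(model, Enzyme.make_zero(model))  # primal + zeroed shadow

    loss(m, x, y) = Flux.mse(m(x), y)
    data = [(randn(Float32, 2, 8), randn(Float32, 1, 8)) for _ in 1:5]
    opt  = Flux.setup(Adam(1f-3), model)

    Flux.train!(loss, dup, data, opt)   # dispatches to the Duplicated method above

Each iteration zeroes `dup.dval`, calls `Enzyme.autodiff(ReverseWithPrimal, ...)` to obtain the loss value and the gradients in one pass, aborts on a non-finite loss, and then applies `Optimisers.update!`.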

src/deprecations.jl

Lines changed: 1 addition & 2 deletions
@@ -109,8 +109,7 @@ train!(loss, ps::Params, data, opt::Optimisers.AbstractRule; cb=nothing) = error

 train!(loss, model, data, opt::Optimise.AbstractOptimiser; cb=nothing) =
   train!(loss, model, data, _old_to_new(opt); cb)
-train!(loss, model::Enzyme.Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing) =
-  train!(loss, model, data, _old_to_new(opt); cb)
+

 # Next, to use the new `setup` with the still-exported old-style `Adam` etc:
 import .Train: setup

src/functor.jl

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@ using LinearAlgebra: Cholesky
 using Zygote: IdSet
 import Functors: Functors, @functor, functor, fmap, isleaf
 using SparseArrays: AbstractSparseArray
-using Enzyme

 """
     testmode!(model, [mode]) -> model

src/losses/utils.jl

Lines changed: 0 additions & 2 deletions
@@ -1,4 +1,3 @@
-import Enzyme

 """
     xlogx(x)
@@ -38,4 +37,3 @@ end
 _check_sizes(ŷ, y) = nothing  # pass-through, for constant label e.g. y = 1

 ChainRulesCore.@non_differentiable _check_sizes(ŷ::Any, y::Any)
-Enzyme.EnzymeRules.inactive(::typeof(_check_sizes), args...) = true
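The `EnzymeRules.inactive` marking for `_check_sizes` is not lost; it reappears in FluxEnzymeExt (line 17 of the new file above), where Enzyme is guaranteed to be loaded. Marking a function inactive tells Enzyme that it contributes nothing to derivatives and need not be differentiated through. In general form (a sketch with a hypothetical `_check_shapes` helper, not Flux's actual code):

    using Enzyme: EnzymeRules

    # A validation helper: its only effect is a possible throw, nothing differentiable.
    _check_shapes(a, b) = size(a) == size(b) || throw(DimensionMismatch("sizes differ"))

    # Tell Enzyme to skip this call when computing gradients:
    EnzymeRules.inactive(::typeof(_check_shapes), args...) = true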

src/train.jl

Lines changed: 1 addition & 28 deletions
@@ -5,7 +5,6 @@ using Optimisers: Optimisers
 using Functors: fmap, fmapstructure
 using ..Flux: Flux # used only in docstring
 import ..Flux.Optimise: train!, update! # during 0.13, we add methods to the old functions
-import Enzyme

 export setup, train!

@@ -53,12 +52,6 @@ function setup(rule::Optimisers.AbstractRule, model)
   state
 end

-_make_zero_internal!(x::AbstractArray) = fill!(x, 0)
-_make_zero_internal!(x) = x
-_make_zero!(model) = fmap(_make_zero_internal!, model)
-
-_applyloss(loss, model, d...) = loss(model, d...)
-
 """
     train!(loss, model, data, opt_state)

@@ -67,7 +60,7 @@ according to a particular optimisation rule encoded in `opt_state`.
 Iterates through `data` once, evaluating for each `d in data` either
 `loss(model, d...)` if `d isa Tuple`, or else `loss(model, d)` for other `d`.

-If `model` is an Enzyme.Duplicated, gradients will be computed with Enzyme,
+If `model` is an Enzyme.Duplicated and `Enzyme.jl` is loaded, gradients will be computed with Enzyme,
 otherwise they will be computed with Zygote.

 For example, with these definitions...
@@ -122,32 +115,12 @@ function train!(loss, model, data, opt; cb = nothing)
     @logprogress Base.haslength(data) ? i/length(data) : nothing
   end
 end
-function train!(loss, model::Enzyme.Duplicated, data, opt; cb = nothing)
-  isnothing(cb) || error("""train! does not support callback functions.
-                            For more control use a loop with `gradient` and `update!`.""")
-  @withprogress for (i,d) in enumerate(data)
-    d_splat = d isa Tuple ? d : (d,)
-
-    _make_zero!(model.dval)
-    _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, _applyloss, Enzyme.Active, Enzyme.Const(loss), model, map(Enzyme.Const, d_splat)...)
-
-    if !isfinite(l)
-      throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
-    end
-    opt, model2 = Optimisers.update!(opt, model.val, model.dval)
-    model = Enzyme.Duplicated(model2, model.dval)

-    @logprogress Base.haslength(data) ? i/length(data) : nothing
-  end
-end

 # This method let you use Optimisers.Descent() without setup, when there is no state
 function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)
   train!(loss, model, data, _rule_to_state(model, rule); cb)
 end
-function train!(loss, model::Enzyme.Duplicated, data, rule::Optimisers.AbstractRule; cb = nothing)
-  train!(loss, model, data, _rule_to_state(model, rule); cb)
-end

 function _rule_to_state(model, rule::Optimisers.AbstractRule)
   state = setup(rule, model)
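With the Enzyme methods gone, `src/train.jl` keeps only the Zygote-backed `train!`, and the docstring now notes that the Enzyme path additionally requires Enzyme.jl to be loaded. For orientation, the retained method behaves roughly like this explicit loop (a sketch, not the verbatim source; `train_sketch!` is a made-up name):

    using Flux, Optimisers

    function train_sketch!(loss, model, data, opt_state)
      for d in data
        d_splat = d isa Tuple ? d : (d,)
        g = Flux.gradient(m -> loss(m, d_splat...), model)[1]  # Zygote reverse mode
        Optimisers.update!(opt_state, model, g)                # apply the rule in place
      end
    end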

test/ext_enzyme/enzyme.jl

Lines changed: 2 additions & 5 deletions
@@ -1,15 +1,12 @@
 using Test
 using Flux

-using Enzyme
+using Enzyme: Enzyme, make_zero, Active, Duplicated, ReverseWithPrimal
+
 using Functors
 using FiniteDifferences
 using CUDA

-_make_zero(x::Union{Number,AbstractArray}) = zero(x)
-_make_zero(x) = x
-make_zero(model) = fmap(_make_zero, model)
-## make_differential(model) = fmapstructure(make_zero, model) # NOT SUPPORTED, See https://github.com/EnzymeAD/Enzyme.jl/issues/1329

 function gradient_fd(f, x...)
   x = [cpu(x) for x in x]
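The tests drop the hand-rolled `make_zero` helper in favour of `Enzyme.make_zero`, which recursively builds a zeroed copy of a nested structure to serve as the shadow. The behaviour relied on, sketched with an invented named tuple (assuming `make_zero` zeroes all differentiable leaves):

    using Enzyme

    θ = (W = Float32[1 2; 3 4], b = Float32[5, 6])
    shadow = Enzyme.make_zero(θ)   # same structure, all entries zeroed
    # shadow.W == zeros(Float32, 2, 2) and shadow.b == zeros(Float32, 2)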

test/train.jl

Lines changed: 2 additions & 2 deletions
@@ -4,10 +4,10 @@ import Optimisers

 using Test
 using Random
-using Enzyme
+import Enzyme

 function train_enzyme!(fn, model, args...; kwargs...)
-  Flux.train!(fn, Duplicated(model, Enzyme.make_zero(model)), args...; kwargs...)
+  Flux.train!(fn, Enzyme.Duplicated(model, Enzyme.make_zero(model)), args...; kwargs...)
 end

 for (trainfn!, name) in ((Flux.train!, "Zygote"), (train_enzyme!, "Enzyme"))
