
Commit fec90e9

feat: support AutoReactant (#1647)
* feat: support AutoReactant
* Apply suggestion from @avik-pal
* Update src/helpers/training.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent d1ce7ad commit fec90e9

File tree

9 files changed: +117 additions, −70 deletions


Project.toml

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 name = "Lux"
 uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
 authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
-version = "1.30.0"
+version = "1.31.0"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -100,7 +100,7 @@ TrackerExt = "Tracker"
 ZygoteExt = "Zygote"

 [compat]
-ADTypes = "1.15"
+ADTypes = "1.19"
 Adapt = "4.4"
 ArrayInterface = "7.17.1"
 CUDA = "5.8"

docs/src/manual/compiling_lux_models.md

Lines changed: 4 additions & 3 deletions
@@ -128,8 +128,8 @@ boilerplate. Simply follow the following steps:
    device. Note that you might want to use [`DeviceIterator`](@ref) to move the data
    loader to the device with an iterator.
 3. Construct a `TrainState` using [`Training.TrainState`](@ref).
-4. And most importantly use `AutoEnzyme` while calling [`Training.single_train_step!`](@ref)
-   or [`Training.single_train_step`](@ref).
+4. And most importantly use `AutoEnzyme`/`AutoReactant` while calling
+   [`Training.single_train_step!`](@ref) or [`Training.single_train_step`](@ref).

 ```@example compile_lux_model
 model = Chain(
@@ -152,7 +152,8 @@ function train_model(model, ps, st, dataloader)
     for iteration in 1:1000
         for (i, (xᵢ, yᵢ)) in enumerate(dataloader)
             _, loss, _, train_state = Training.single_train_step!(
-                AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state)
+                AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state
+            )
             if (iteration % 100 == 0 || iteration == 1) && i == 1
                 @printf("Iter: [%4d/%4d]\tLoss: %.8f\n", iteration, 1000, loss)
             end
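
For context on what the documented workflow now looks like end to end, here is a minimal sketch of a training loop using the new `AutoReactant` backend. The model, data shapes, optimiser, and iteration count are illustrative only; the API calls (`reactant_device`, `Training.TrainState`, `Training.single_train_step!`, `MSELoss`) come from the Lux docs and this diff.

```julia
using Lux, Reactant, Optimisers, Random

# Illustrative model and data; any Lux model trains the same way.
model = Chain(Dense(2 => 32, tanh), Dense(32 => 1))
dev = reactant_device()

ps, st = Lux.setup(Random.default_rng(), model) |> dev
x = rand(Float32, 2, 64) |> dev
y = rand(Float32, 1, 64) |> dev

function train!(train_state, x, y; iters=100)
    local loss
    for _ in 1:iters
        # AutoReactant() is the entry point added by this commit; AutoEnzyme()
        # continues to work on a ReactantDevice via the same compiled path.
        _, loss, _, train_state = Training.single_train_step!(
            AutoReactant(), MSELoss(), (x, y), train_state
        )
    end
    return train_state, loss
end

train_state = Training.TrainState(model, ps, st, Adam(0.01f0))
train_state, final_loss = train!(train_state, x, y)
```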

examples/SimpleRNN/main.jl

Lines changed: 1 addition & 1 deletion
@@ -165,7 +165,7 @@ function main(model_type)
     else
         model
     end
-    ad = dev isa ReactantDevice ? AutoEnzyme() : AutoZygote()
+    ad = dev isa ReactantDevice ? AutoReactant() : AutoZygote()

     for epoch in 1:25
         ## Train the model

ext/ReactantExt/ReactantExt.jl

Lines changed: 13 additions & 0 deletions
@@ -52,6 +52,19 @@ function with_default_precision_config(f::F, ps) where {F}
     )
 end

+function get_compile_options(backend::ReactantBackend)
+    (; compile_options, sync) = backend
+    @assert compile_options isa Union{Nothing,Reactant.CompileOptions}
+    if compile_options === nothing
+        sync === missing && return Reactant.CompileOptions()
+        return Reactant.CompileOptions(; sync)
+    end
+    if sync !== missing
+        @set! compile_options.sync = sync
+    end
+    return compile_options
+end
+
 include("patches.jl")
 include("training.jl")
 include("layers.jl")

ext/ReactantExt/training.jl

Lines changed: 6 additions & 9 deletions
@@ -108,7 +108,7 @@ function Lux.Training.compute_gradients_impl(
     else
         compiled_gradient_function = annotate_compile("Compute Gradients") do
             with_default_precision_config(ts.parameters) do
-                @compile sync = backend.sync compute_gradients_internal(
+                @compile compile_options = get_compile_options(backend) compute_gradients_internal(
                     objective_function, ts.model, data, ts.parameters, ts.states
                 )
             end
@@ -150,18 +150,15 @@ for inplace in ("!", "")
         else
             update_function = annotate_compile("Apply Gradients") do
                 with_default_precision_config(ts.parameters) do
-                    @compile sync = ts.cache.backend.sync Optimisers.$(update_fn)(
+                    @compile compile_options = get_compile_options(ts.cache.backend) Optimisers.$(
+                        update_fn
+                    )(
                         ts.optimizer_state, ts.parameters, grads
                     )
                 end
             end

-            if ts.cache isa TrainingBackendCache
-                @set! ts.cache.extras = merge(ts.cache.extras, (; update_function))
-            else
-                cache = TrainingBackendCache(backend, False(), nothing, (; update_function))
-                @set! ts.cache = cache
-            end
+            @set! ts.cache.extras = merge(ts.cache.extras, (; update_function))
         end

         opt_state, ps = annotate_execution("Apply Gradients", ts.step) do
@@ -206,7 +203,7 @@ for inplace in ("!", "")

         compiled_grad_and_step_function = annotate_compile("Train Step") do
             with_default_precision_config(ts.parameters) do
-                @compile sync = backend.sync compute_gradients_internal_and_step!(
+                @compile compile_options = get_compile_options(backend) compute_gradients_internal_and_step!(
                     objective_function,
                     ts.model,
                     data,

src/Lux.jl

Lines changed: 8 additions & 1 deletion
@@ -5,6 +5,7 @@ using ADTypes:
     AutoEnzyme,
     AutoForwardDiff,
     AutoMooncake,
+    AutoReactant,
     AutoReverseDiff,
     AutoTracker,
     AutoZygote
@@ -158,7 +159,13 @@ export Training
 export jacobian_vector_product, vector_jacobian_product
 export batched_jacobian
 export AutoEnzyme,
-    AutoForwardDiff, AutoMooncake, AutoReverseDiff, AutoTracker, AutoZygote, AutoForwardDiff
+    AutoForwardDiff,
+    AutoMooncake,
+    AutoReactant,
+    AutoReverseDiff,
+    AutoTracker,
+    AutoZygote,
+    AutoForwardDiff

 export BinaryCrossEntropyLoss,
     BinaryFocalLoss,

src/helpers/training.jl

Lines changed: 54 additions & 30 deletions
@@ -2,7 +2,13 @@ module Training

 using Adapt: Adapt
 using ADTypes:
-    AbstractADType, AutoEnzyme, AutoReverseDiff, AutoTracker, AutoZygote, AutoMooncake
+    AbstractADType,
+    AutoEnzyme,
+    AutoReverseDiff,
+    AutoTracker,
+    AutoZygote,
+    AutoMooncake,
+    AutoReactant
 using SciMLPublic: @public
 using ConcreteStructs: @concrete
 using FastClosures: @closure
@@ -161,7 +167,8 @@ end

 @concrete struct ReactantBackend
     return_gradients <: StaticBool
-    sync::Bool
+    sync <: Union{Bool,Missing}
+    compile_options
     ad <: AutoEnzyme
 end

@@ -247,10 +254,15 @@ const SYNC_DOCSTRING = """
    Reactant Backend.
 """

+const COMPILE_OPTIONS_DOCSTRING = """
+  - `compile_options`: Compile options for the reactant function. See
+    `Reactant.CompileOptions` for more details. This is only used for Reactant Backend.
+"""
+
 """
     compute_gradients(
         ad::AbstractADType, objective_function::Function, data, ts::TrainState;
-        sync::Bool=false
+        sync::Bool=false, compile_options::Union{Missing,Reactant.CompileOptions}=missing
     )

 Compute the gradients of the objective function wrt parameters stored in `ts`.
@@ -279,6 +291,7 @@ Compute the gradients of the objective function wrt parameters stored in `ts`.
 ## Keyword Arguments

 $(SYNC_DOCSTRING)
+$(COMPILE_OPTIONS_DOCSTRING)

 ## Return

@@ -304,10 +317,10 @@ A 4-Tuple containing:
    returned in step `i + 1` might be aliased by the old gradients. If you want to prevent
    this, simply use `copy(grads)` or `deepcopy(grads)` to make a copy of the gradients.
 """
-function compute_gradients(ad, obj_fn::F, data, ts::TrainState; sync::Bool=false) where {F}
+function compute_gradients(ad, obj_fn::F, data, ts::TrainState; kwargs...) where {F}
     dev_type = get_device_type((ts.parameters, ts.states))
     return compute_gradients_impl_with_allocator_cache(
-        maybe_wrap_adtype(ad, dev_type; sync), ts.allocator_cache, obj_fn, data, ts
+        maybe_wrap_adtype(ad, dev_type; kwargs...), ts.allocator_cache, obj_fn, data, ts
     )
 end

@@ -346,14 +359,33 @@ end
 maybe_wrap_adtype(backend::ReactantBackend, ::Any; kwargs...) = backend
 maybe_wrap_adtype(ad::AbstractADType, ::Any; kwargs...) = ad
 function maybe_wrap_adtype(
-    ad::AbstractADType,
+    ad::AutoEnzyme,
+    ::Type{ReactantDevice};
+    return_gradients::Utils.BoolType=True(),
+    sync::Union{Missing,Bool}=missing,
+    compile_options=nothing,
+)
+    return ReactantBackend(static(return_gradients), sync, compile_options, ad)
+end
+function maybe_wrap_adtype(
+    ad::AutoReactant,
     ::Type{ReactantDevice};
     return_gradients::Utils.BoolType=True(),
-    sync::Bool=false,
+    sync::Union{Missing,Bool}=missing,
+    compile_options=nothing,
 )
-    ad isa AutoEnzyme && return ReactantBackend(static(return_gradients), sync, ad)
-    throw(ArgumentError("Computing gradients for models on XLA is supported only with \
-                         Enzyme.jl (`AutoEnzyme`)."))
+    return ReactantBackend(static(return_gradients), sync, compile_options, ad.mode)
+end
+function maybe_wrap_adtype(ad::AutoReactant, ::Type{T}; kwargs...) where {T}
+    throw(ArgumentError("`AutoReactant` only supports ReactantDevice but got `$(T)`"))
+end
+function maybe_wrap_adtype(ad::AbstractADType, ::Type{ReactantDevice}; kwargs...)
+    throw(
+        ArgumentError(
+            "Computing gradients for models with Reactant is supported only with \
+             Enzyme.jl (`AutoEnzyme` or `AutoReactant`)."
+        ),
+    )
 end

 function generate_wrappers(::F, m, ps, st, data, ::False, ::StaticBool) where {F}
@@ -408,7 +440,9 @@ const RETURN_GRADIENTS_DOCSTRING = """

 """
     single_train_step!(
-        backend, obj_fn::F, data, ts::TrainState; return_gradients=True(), sync::Bool=false
+        backend, obj_fn::F, data, ts::TrainState;
+        return_gradients=True(), sync::Bool=false,
+        compile_options::Union{Nothing,Reactant.CompileOptions}=missing,
     )

 Perform a single training step. Computes the gradients using [`compute_gradients`](@ref) and
@@ -419,6 +453,7 @@ updates the parameters using [`apply_gradients!`](@ref). All backends supported

 $(RETURN_GRADIENTS_DOCSTRING)
 $(SYNC_DOCSTRING)
+$(COMPILE_OPTIONS_DOCSTRING)

 ## Return

@@ -427,16 +462,9 @@ only the parameters in `ts` are updated inplace. Users should be using the returned
 object for further training steps, else there is no caching and performance will be
 suboptimal (and absolutely terrible for backends like `AutoReactant`).
 """
-function single_train_step!(
-    backend,
-    obj_fn::F,
-    data,
-    ts::TrainState;
-    return_gradients::Utils.BoolType=True(),
-    sync::Bool=false,
-) where {F}
+function single_train_step!(backend, obj_fn::F, data, ts::TrainState; kwargs...) where {F}
     backend = maybe_wrap_adtype(
-        backend, get_device_type((ts.parameters, ts.states)); return_gradients, sync
+        backend, get_device_type((ts.parameters, ts.states)); kwargs...
     )
     return single_train_step_impl_with_allocator_cache!(
         backend, ts.allocator_cache, obj_fn, data, ts
@@ -445,7 +473,9 @@ end

 """
     single_train_step(
-        backend, obj_fn::F, data, ts::TrainState; return_gradients=True(), sync::Bool=false
+        backend, obj_fn::F, data, ts::TrainState;
+        return_gradients=True(), sync::Bool=false,
+        compile_options::Union{Nothing,Reactant.CompileOptions}=missing,
     )

 Perform a single training step. Computes the gradients using [`compute_gradients`](@ref) and
@@ -458,21 +488,15 @@ In most cases you should use [`single_train_step!`](@ref) instead of this function.

 $(RETURN_GRADIENTS_DOCSTRING)
 $(SYNC_DOCSTRING)
+$(COMPILE_OPTIONS_DOCSTRING)

 ## Return

 Returned values are the same as [`single_train_step!`](@ref).
 """
-function single_train_step(
-    backend,
-    obj_fn::F,
-    data,
-    ts::TrainState;
-    return_gradients::Utils.BoolType=True(),
-    sync::Bool=false,
-) where {F}
+function single_train_step(backend, obj_fn::F, data, ts::TrainState; kwargs...) where {F}
     backend = maybe_wrap_adtype(
-        backend, get_device_type((ts.parameters, ts.states)); return_gradients, sync
+        backend, get_device_type((ts.parameters, ts.states)); kwargs...
     )
     return single_train_step_impl(backend, obj_fn, data, ts)
 end
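
The new `maybe_wrap_adtype` methods also tighten the error behaviour: `AutoReactant` is rejected on non-Reactant devices, and non-Enzyme/non-Reactant AD types are rejected on a `ReactantDevice`. A small sketch of the first case (CPU model and data are illustrative; the quoted message is the one added in this diff):

```julia
using Lux, Optimisers, Random

model = Chain(Dense(2 => 8, tanh), Dense(8 => 1))
ps, st = Lux.setup(Random.default_rng(), model)          # plain CPU arrays
ts_cpu = Training.TrainState(model, ps, st, Adam(0.01f0))
x, y = rand(Float32, 2, 16), rand(Float32, 1, 16)

threw = try
    Training.single_train_step!(AutoReactant(), MSELoss(), (x, y), ts_cpu)
    false
catch err
    # Expected: ArgumentError("`AutoReactant` only supports ReactantDevice but got `CPUDevice`")
    err isa ArgumentError
end
@assert threw
```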

test/Project.toml

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ LuxTestUtils = {path = "../lib/LuxTestUtils"}
 MLDataDevices = {path = "../lib/MLDataDevices"}

 [compat]
-ADTypes = "1.10"
+ADTypes = "1.19"
 Adapt = "4"
 Aqua = "0.8.4"
 CPUSummary = "0.2.6"

test/reactant/training_tests.jl

Lines changed: 28 additions & 23 deletions
@@ -47,24 +47,22 @@
         end

         @testset for opt in (
-            Descent(0.01f0),
-            Momentum(0.01f0),
-            Adam(0.01f0),
-            AdamW(0.01f0),
-            OptimiserChain(AccumGrad(5), Adam(0.01f0)),
-        )
+                Descent(0.01f0),
+                Momentum(0.01f0),
+                Adam(0.01f0),
+                AdamW(0.01f0),
+                OptimiserChain(AccumGrad(5), Adam(0.01f0)),
+            ),
+            ad in (AutoEnzyme(), AutoReactant())
+
             ps, st = xdev(Lux.setup(StableRNG(1234), model))
             train_state = Training.TrainState(model, ps, st, opt)

             for epoch in 1:100, (xᵢ, yᵢ) in dataloader
                 grads, loss, stats, train_state = if version === :iip
-                    Training.single_train_step!(
-                        AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state
-                    )
+                    Training.single_train_step!(ad, MSELoss(), (xᵢ, yᵢ), train_state)
                 elseif version === :oop
-                    Training.single_train_step(
-                        AutoEnzyme(), MSELoss(), (xᵢ, yᵢ), train_state
-                    )
+                    Training.single_train_step(ad, MSELoss(), (xᵢ, yᵢ), train_state)
                 else
                     error("Invalid version: $(version)")
                 end
@@ -125,6 +123,11 @@ end
         AutoEnzyme(), MSELoss(), (x, x), train_state; return_gradients=Val(false)
     )
     @test loss isa Number
+
+    _, loss, stats, ts = Training.single_train_step(
+        AutoReactant(), MSELoss(), (x, x), train_state; return_gradients=Val(false)
+    )
+    @test loss isa Number
 end

 @testitem "Reactant Distributed: Training API" tags = [:reactant] setup = [SharedTestSetup] begin
@@ -152,19 +155,21 @@ end
     x = rand(Float32, 4, 128) |> batch_device
     y = rand(Float32, 4, 128) |> batch_device

-    train_state = Training.TrainState(model, ps, st, Adam(0.001f0))
+    @testset for ad in (AutoEnzyme(), AutoReactant())
+        train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

-    _, loss, _, train_state = Training.single_train_step(
-        AutoEnzyme(), MSELoss(), (x, y), train_state
-    )
-    @test loss isa Reactant.ConcreteRNumber
-    @test length(Reactant.XLA.devices(Reactant.XLA.sharding(loss.data))) == 8
+        _, loss, _, train_state = Training.single_train_step(
+            ad, MSELoss(), (x, y), train_state
+        )
+        @test loss isa Reactant.ConcreteRNumber
+        @test length(Reactant.XLA.devices(Reactant.XLA.sharding(loss.data))) == 8

-    _, loss, _, train_state = Training.single_train_step(
-        AutoEnzyme(), MSELoss(), (x, y), train_state
-    )
-    @test loss isa Reactant.ConcreteRNumber
-    @test length(Reactant.XLA.devices(Reactant.XLA.sharding(loss.data))) == 8
+        _, loss, _, train_state = Training.single_train_step(
+            ad, MSELoss(), (x, y), train_state
+        )
+        @test loss isa Reactant.ConcreteRNumber
+        @test length(Reactant.XLA.devices(Reactant.XLA.sharding(loss.data))) == 8
+    end
 end
 end