Commit 4de5284

feat: support distributed training via TrainState API (#1529)
* feat: allow tracking numbers in ReactantDevice
* fix: iddict with no sharding
* feat: distributed training now works
* chore: run fmt
* chore: bump lux version
* chore: bump version
* fix: accumgrad implementation
* feat: some more progress towards good sharding
* test: use IFRT for testing
* fix: zero grads
* chore: bump reactant version

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent bf9a4e0 commit 4de5284
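For context, the commit extends the existing Training.TrainState workflow so the same training loop also drives sharded (multi-device) Reactant/XLA setups. The sketch below is not taken from this commit; it is a rough outline of how that API is typically driven, assuming the public Lux + Reactant surface (reactant_device, Training.TrainState, Training.single_train_step!) and using a made-up toy model and data:

using ADTypes, Enzyme, Lux, Optimisers, Random, Reactant

rng = Random.default_rng()
model = Chain(Dense(4 => 16, tanh), Dense(16 => 2))   # toy model for illustration

dev = reactant_device()                 # ReactantDevice from MLDataDevices
ps, st = Lux.setup(rng, model) |> dev   # parameters/states live on the XLA client

train_state = Training.TrainState(model, ps, st, Adam(0.001f0))

x = rand(rng, Float32, 4, 32) |> dev
y = rand(rng, Float32, 2, 32) |> dev

# One training step; with this commit the same call also exercises the
# sharded code path added in ext/LuxReactantExt/training.jl.
_, loss, _, train_state = Training.single_train_step!(
    AutoEnzyme(), MSELoss(), (x, y), train_state
)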

File tree

11 files changed: +176 -106 lines changed


.github/workflows/CommonCI.yml

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ jobs:
     runs-on: ${{ inputs.os }}
     env:
       TMPDIR: ${{ github.workspace }}/tmp
+      XLA_FLAGS: --xla_force_host_platform_device_count=8
     steps:
       - uses: actions/checkout@v5
       - name: Create TMPDIR
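The new XLA_FLAGS entry makes the XLA host (CPU) platform report eight virtual devices, so the sharded training path added in this commit can be exercised on CI machines without accelerators. Reproducing that locally might look like the sketch below (an assumption, not part of this commit; it presumes Reactant.devices() lists the default client's devices and that the flag is set before Reactant initializes XLA):

# Pretend a single host has 8 devices; must be set before loading Reactant.
ENV["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

using Reactant

@show length(Reactant.devices())   # expected to report 8 virtual CPU devices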

Project.toml

Lines changed: 3 additions & 3 deletions
@@ -1,7 +1,7 @@
 name = "Lux"
 uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
 authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
-version = "1.24.0"
+version = "1.25.0"

 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -98,7 +98,7 @@ LinearAlgebra = "1.10"
 LossFunctions = "0.11.1, 1"
 LuxCore = "1.4.2"
 LuxLib = "1.12.1"
-MLDataDevices = "1.12.1"
+MLDataDevices = "1.15"
 MLUtils = "0.4.4"
 MPI = "0.20.19"
 MacroTools = "0.5.13"
@@ -110,7 +110,7 @@ Optimisers = "0.4.6"
 PrecompileTools = "1.2.1"
 Preferences = "1.4.3"
 Random = "1.10"
-Reactant = "0.2.170"
+Reactant = "0.2.174"
 ReactantCore = "0.1.16"
 Reexport = "1.2.2"
 ReverseDiff = "1.15"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
@@ -66,7 +66,7 @@ OpenSSL_jll = "=3.0.16"
 Optimisers = "0.4.6"
 Printf = "1.10"
 Random = "1.10"
-Reactant = "0.2.170"
+Reactant = "0.2.173"
 StableRNGs = "1"
 StaticArrays = "1"
 WeightInitializers = "1"

ext/LuxReactantExt/LuxReactantExt.jl

Lines changed: 4 additions & 19 deletions
@@ -1,6 +1,7 @@
 module LuxReactantExt

-using Enzyme: Enzyme, Const
+using Enzyme: Enzyme, Active, Const, Duplicated
+using Functors: Functors
 using Preferences: load_preference
 using Optimisers: Optimisers
 using Reactant:
@@ -21,38 +22,22 @@ using Static: True, False
 using Lux: Lux, LuxOps, Training, Utils, StatefulLuxLayer
 using Lux.Training: TrainingBackendCache, ReactantBackend
 using LuxCore: LuxCore, AbstractLuxLayer
-using MLDataDevices: ReactantDevice, get_device
+using MLDataDevices: MLDataDevices, ReactantDevice, get_device

 Lux.is_extension_loaded(::Val{:Reactant}) = true

-Utils.to_rarray(x; kwargs...) = Reactant.to_rarray(x; kwargs...)
-
 Utils.contiguous(x::AnyTracedRArray) = ReactantCore.materialize_traced_array(x)

 Utils.eltype(::Type{<:TracedRArray{T,N}}) where {T,N} = T
 Utils.eltype(::Type{<:TracedRNumber{T}}) where {T} = T
 Utils.eltype(x::Reactant.AnyTracedRArray) = Reactant.unwrapped_eltype(x)

-function Utils.promote_to(::Type{T}, x::Number) where {T<:Number}
-    x isa Reactant.TracedType && return x
-    return Reactant.ConcreteRNumber{T}(x)
-end
-
-# For CUDA use `PrecisionConfig.HIGH`. For other backends use `PrecisionConfig.DEFAULT`.
 function default_precision_config(ps)
     precision_config_preference = lowercase(
         load_preference(Lux, "precision_config", "auto")
     )

-    if precision_config_preference == "auto"
-        rdev = get_device(ps)
-        rdev isa ReactantDevice || return PrecisionConfig.DEFAULT
-        device = rdev.device === missing ? Reactant.XLA.default_device() : rdev.device
-        device_kind = string(device)
-        contains(device_kind, "CUDA") && return PrecisionConfig.HIGH
-        return PrecisionConfig.DEFAULT
-    end
-
+    precision_config_preference == "auto" && return PrecisionConfig.DEFAULT
     precision_config_preference == "default" && return PrecisionConfig.DEFAULT
     precision_config_preference == "high" && return PrecisionConfig.HIGH
     precision_config_preference == "highest" && return PrecisionConfig.HIGHEST
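With the device sniffing removed, the "auto" setting now simply maps to PrecisionConfig.DEFAULT; the CUDA-oriented HIGH behaviour has to be requested explicitly through the preference that default_precision_config reads via load_preference. A sketch of opting in, assuming the preference is stored with Preferences.jl as that call suggests:

using Preferences, Lux

# Writes to LocalPreferences.toml; valid values, per the code above, are
# "auto", "default", "high", and "highest".
set_preferences!(Lux, "precision_config" => "high"; force=true)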

ext/LuxReactantExt/patches.jl

Lines changed: 12 additions & 7 deletions
@@ -3,12 +3,17 @@ Utils.vec(x::AnyTracedRArray) = ReactantCore.materialize_traced_array(vec(x))
 # XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
 Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g

-# Optimisers setup
-Profiler.@annotate "Optimisers Setup" function Lux.ReactantCompatibleOptimisers.optimisers_setup_with_jit(
-    opt, ps
-)
-    return @jit Optimisers.setup(opt, ps)
-end
-
 # rsqrt
 LuxOps.rsqrt(x::TracedRNumber) = @opcall rsqrt(x)
+
+# convert eltype
+function Utils.convert_eltype(
+    ::Type{T}, x::Reactant.ConcretePJRTNumber{S}
+) where {T<:Number,S}
+    return Reactant.ConcretePJRTNumber{T}(x)
+end
+function Utils.convert_eltype(
+    ::Type{T}, x::Reactant.ConcreteIFRTNumber{S}
+) where {T<:Number,S}
+    return Reactant.ConcreteIFRTNumber{T}(x)
+end
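These methods let Lux change the element type of scalars that already live on the Reactant runtime (numbers are now tracked by ReactantDevice, per the commit message) without moving them back to the host. A rough illustration, assuming the Lux + Reactant extension is loaded and reusing the ConcreteRNumber constructor form shown in the code removed from LuxReactantExt.jl above; Lux.Utils is an internal module, so this is illustration only:

using Reactant
import Lux

x = Reactant.ConcreteRNumber{Float32}(1.0f0)   # concrete scalar on the XLA runtime
y = Lux.Utils.convert_eltype(Float64, x)       # re-wrap with a new element type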

ext/LuxReactantExt/training.jl

Lines changed: 64 additions & 19 deletions
@@ -3,19 +3,29 @@ function objective_function_wrapper(objective_function::F, model, ps, st, data)
     return loss, Reactant.ignore_derivatives(stₙ), Reactant.ignore_derivatives(stats)
 end

-function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
-    (_, _, dps, _, _), (loss, stₙ, stats) = Enzyme.gradient(
+function compute_gradients_internal!(
+    dps, objective_function::F, model, data, ps, st, zeroed_grads::Bool=false
+) where {F}
+    zeroed_grads || Enzyme.make_zero!(dps)
+
+    _, (loss, stₙ, stats) = Enzyme.autodiff(
         Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
         Const(objective_function_wrapper),
         Const(objective_function),
         Const(model),
-        ps,
+        Duplicated(ps, dps),
         Const(st),
         Const(data),
     )
     return dps, loss, stats, stₙ
 end

+function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
+    return compute_gradients_internal!(
+        Enzyme.make_zero(ps), objective_function, model, data, ps, st, true
+    )
+end
+
 Profiler.@annotate "Compile Compute Gradients" function Lux.Training.compute_gradients_impl(
     backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState
 ) where {F}
@@ -80,34 +90,54 @@ for inplace in ("!", "")
         return ts
     end

+    ps_expr = if inplace == "!"
+        :(ps = ts.parameters)
+    else
+        :(ps = Functors.fmap(copy, ts.parameters; exclude=MLDataDevices.isleaf))
+    end
+
     # XXX: recompile with a warning if new input types are used
     @eval Profiler.@annotate "Compile Train Step" function Lux.Training.$(fname)(
         backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState
     ) where {F}
+        device = get_device((ts.parameters, ts.states, ts.optimizer_state, data))
+        @assert device isa ReactantDevice
+        is_sharded = device.device === nothing
+
+        dps = if backend.return_gradients isa True
+            Functors.fmap(Utils.zero, ts.parameters; exclude=MLDataDevices.isleaf)
+        else
+            nothing
+        end
+
+        $(ps_expr)
+
         compiled_grad_and_step_function = with_default_precision_config(ts.parameters) do
             @compile sync = backend.sync $(internal_fn)(
                 objective_function,
                 ts.model,
                 data,
-                ts.parameters,
+                ps,
                 ts.states,
                 ts.optimizer_state,
-                backend.return_gradients,
+                dps,
+                is_sharded,
             )
         end

         grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
             objective_function,
             ts.model,
             data,
-            ts.parameters,
+            ps,
             ts.states,
             ts.optimizer_state,
-            backend.return_gradients,
+            dps,
+            is_sharded,
         )

         cache = TrainingBackendCache(
-            backend, False(), nothing, (; compiled_grad_and_step_function)
+            backend, False(), dps, (; compiled_grad_and_step_function, is_sharded)
         )
         @set! ts.cache = cache
         @set! ts.objective_function = objective_function
@@ -120,7 +150,7 @@ for inplace in ("!", "")
     end

     @eval Profiler.@annotate "Train Step" function Lux.Training.$(fname)(
-        backend::ReactantBackend,
+        ::ReactantBackend,
         obj_fn::F,
         data,
         ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend},F},
@@ -132,7 +162,8 @@ for inplace in ("!", "")
             ts.parameters,
             ts.states,
             ts.optimizer_state,
-            backend.return_gradients,
+            ts.cache.dparameters,
+            ts.cache.extras.is_sharded,
         )

         @set! ts.states = st
@@ -143,24 +174,38 @@ for inplace in ("!", "")
         return grads, loss, stats, ts
     end

-    # XXX: Inplace version not actually inplace
     @eval function $(internal_fn)(
-        objective_function::F, model, data, ps, st, opt_state, ::False
+        objective_function::F, model, data, ps, st, opt_state, ::Nothing, is_sharded::Bool
     ) where {F}
         dps, loss, stats, stₙ = compute_gradients_internal(
             objective_function, model, data, ps, st
         )
-        opt_state, ps = Optimisers.$(update_fn)(opt_state, ps, dps)
-        return nothing, ps, loss, stats, stₙ, opt_state
+
+        opt_state, psₙ = Optimisers.update!(opt_state, ps, dps)
+        # Ensure sharding of input and output states are consistent
+        is_sharded && mark_same_sharding_group(st, stₙ)
+
+        return nothing, psₙ, loss, stats, stₙ, opt_state
     end

     @eval function $(internal_fn)(
-        objective_function::F, model, data, ps, st, opt_state, ::True
+        objective_function::F, model, data, ps, st, opt_state, dps, is_sharded::Bool
     ) where {F}
-        dps, loss, stats, stₙ = compute_gradients_internal(
-            objective_function, model, data, ps, st
+        dps, loss, stats, stₙ = compute_gradients_internal!(
+            dps, objective_function, model, data, ps, st
         )
-        opt_state, ps = Optimisers.$(update_fn)(opt_state, ps, dps)
-        return dps, ps, loss, stats, stₙ, opt_state
+
+        opt_state, psₙ = Optimisers.update!(opt_state, ps, dps)
+        # Ensure sharding of input and output states are consistent
+        is_sharded && mark_same_sharding_group(st, stₙ)
+
+        return dps, psₙ, loss, stats, stₙ, opt_state
     end
 end
+
+mark_same_sharding_group(args...) = Functors.fmap(mark_same_sharding_group_inner, args...)
+
+function mark_same_sharding_group_inner(arg1::Union{TracedRArray,TracedRNumber}, args...)
    return @opcall sharding_group(arg1, args...)
+end
+mark_same_sharding_group_inner(arg1, args...) = nothing
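Away from Reactant tracing, the pattern introduced by compute_gradients_internal! is Enzyme's Duplicated annotation with a preallocated gradient shadow that is re-zeroed before each call; Enzyme accumulates into the shadow, which is what the "zero grads" fix in the commit message addresses. A standalone sketch with a made-up objective function (not from this commit):

using Enzyme

loss(ps, x) = sum(abs2, ps .* x)   # toy stand-in for objective_function_wrapper

ps  = Float32[1.0, 2.0, 3.0]
x   = Float32[0.5, 0.5, 0.5]
dps = Enzyme.make_zero(ps)         # gradient shadow, allocated once and reused

for _ in 1:3
    Enzyme.make_zero!(dps)         # reset, since Duplicated accumulates into dps
    _, l = Enzyme.autodiff(
        Enzyme.ReverseWithPrimal, loss, Active, Duplicated(ps, dps), Const(x)
    )
    @show l dps
end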
