Commit 4bba967

feat: some more progress towards good sharding
1 parent f6643d9 commit 4bba967

4 files changed, +132 -18 lines changed

.github/workflows/CommonCI.yml

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ jobs:
     runs-on: ${{ inputs.os }}
     env:
       TMPDIR: ${{ github.workspace }}/tmp
+      XLA_FLAGS: --xla_force_host_platform_device_count=8
     steps:
       - uses: actions/checkout@v5
       - name: Create TMPDIR

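The new XLA_FLAGS entry asks XLA to expose 8 virtual devices on the CPU-only CI host, which is what lets the distributed training test added below build its (2, 4) device mesh. A rough local reproduction sketch, assuming the flag is picked up before Reactant initializes its XLA client (the check itself is illustrative, not part of this commit):

# Assumed setup: set the env var before `using Reactant` in a fresh Julia
# session so the XLA client is created with the forced device count.
ENV["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

using Reactant

# The single host CPU should now show up as 8 virtual devices, enough for the
# 2x4 mesh used in the new "Reactant Distributed: Training API" test item.
@assert length(Reactant.devices()) >= 8
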
ext/LuxReactantExt/LuxReactantExt.jl

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,7 @@
 module LuxReactantExt

-using Enzyme: Enzyme, Const
+using Enzyme: Enzyme, Active, Const, Duplicated
+using Functors: Functors
 using Preferences: load_preference
 using Optimisers: Optimisers
 using Reactant:
@@ -21,7 +22,7 @@ using Static: True, False
 using Lux: Lux, LuxOps, Training, Utils, StatefulLuxLayer
 using Lux.Training: TrainingBackendCache, ReactantBackend
 using LuxCore: LuxCore, AbstractLuxLayer
-using MLDataDevices: ReactantDevice, get_device
+using MLDataDevices: MLDataDevices, ReactantDevice, get_device

 Lux.is_extension_loaded(::Val{:Reactant}) = true

ext/LuxReactantExt/training.jl

Lines changed: 86 additions & 16 deletions
@@ -3,19 +3,27 @@ function objective_function_wrapper(objective_function::F, model, ps, st, data)
     return loss, Reactant.ignore_derivatives(stₙ), Reactant.ignore_derivatives(stats)
 end

-function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
-    (_, _, dps, _, _), (loss, stₙ, stats) = Enzyme.gradient(
+function compute_gradients_internal!(
+    dps, objective_function::F, model, data, ps, st
+) where {F}
+    _, (loss, stₙ, stats) = Enzyme.autodiff(
         Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
         Const(objective_function_wrapper),
         Const(objective_function),
         Const(model),
-        ps,
+        Duplicated(ps, dps),
         Const(st),
         Const(data),
     )
     return dps, loss, stats, stₙ
 end

+function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
+    return compute_gradients_internal!(
+        Enzyme.make_zero(ps), objective_function, model, data, ps, st
+    )
+end
+
 Profiler.@annotate "Compile Compute Gradients" function Lux.Training.compute_gradients_impl(
     backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState
 ) where {F}
@@ -84,6 +92,19 @@ for inplace in ("!", "")
     @eval Profiler.@annotate "Compile Train Step" function Lux.Training.$(fname)(
         backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState
     ) where {F}
+        device = get_device((ts.parameters, ts.states, ts.optimizer_state, data))
+        @assert device isa ReactantDevice
+        is_sharded = device.device === nothing
+
+        dps = if backend.return_gradients isa True
+            Functors.fmap(Utils.zero, ts.parameters; exclude=MLDataDevices.isleaf)
+        else
+            nothing
+        end
+
+        # TODO: make it conditional
+        ps_cache = Functors.fmap(copy, ts.parameters; exclude=MLDataDevices.isleaf)
+
         compiled_grad_and_step_function = with_default_precision_config(ts.parameters) do
             @compile sync = backend.sync $(internal_fn)(
                 objective_function,
@@ -92,7 +113,9 @@ for inplace in ("!", "")
                 ts.parameters,
                 ts.states,
                 ts.optimizer_state,
-                backend.return_gradients,
+                dps,
+                is_sharded,
+                ps_cache,
             )
         end

@@ -103,11 +126,13 @@
             ts.parameters,
             ts.states,
             ts.optimizer_state,
-            backend.return_gradients,
+            dps,
+            is_sharded,
+            ps_cache,
         )

         cache = TrainingBackendCache(
-            backend, False(), nothing, (; compiled_grad_and_step_function)
+            backend, False(), dps, (; compiled_grad_and_step_function, is_sharded, ps_cache)
         )
         @set! ts.cache = cache
         @set! ts.objective_function = objective_function
@@ -132,7 +157,9 @@
             ts.parameters,
             ts.states,
             ts.optimizer_state,
-            backend.return_gradients,
+            ts.cache.dparameters,
+            ts.cache.extras.is_sharded,
+            ts.cache.extras.ps_cache,
         )

         @set! ts.states = st
@@ -143,24 +170,67 @@
         return grads, loss, stats, ts
     end

-    # XXX: Inplace version not actually inplace
     @eval function $(internal_fn)(
-        objective_function::F, model, data, ps, st, opt_state, ::False
+        objective_function::F,
+        model,
+        data,
+        ps,
+        st,
+        opt_state,
+        ::Nothing,
+        is_sharded::Bool,
+        ps_cache,
     ) where {F}
         dps, loss, stats, stₙ = compute_gradients_internal(
             objective_function, model, data, ps, st
         )
-        opt_state, ps = Optimisers.$(update_fn)(opt_state, ps, dps)
-        return nothing, ps, loss, stats, stₙ, opt_state
+
+        opt_state, psₙ = Optimisers.$(update_fn)(opt_state, ps, dps)
+        Functors.fmap(copyto!, ps_cache, psₙ; exclude=MLDataDevices.isleaf)
+        if is_sharded
+            # Ensure sharding of input and output states are consistent
+            mark_same_sharding_group(st, stₙ)
+        end
+
+        return nothing, ps_cache, loss, stats, stₙ, opt_state
     end

     @eval function $(internal_fn)(
-        objective_function::F, model, data, ps, st, opt_state, ::True
+        objective_function::F,
+        model,
+        data,
+        ps,
+        st,
+        opt_state,
+        dps,
+        is_sharded::Bool,
+        ps_cache,
     ) where {F}
-        dps, loss, stats, stₙ = compute_gradients_internal(
-            objective_function, model, data, ps, st
+        dps, loss, stats, stₙ = compute_gradients_internal!(
+            dps, objective_function, model, data, ps, st
         )
-        opt_state, ps = Optimisers.$(update_fn)(opt_state, ps, dps)
-        return dps, ps, loss, stats, stₙ, opt_state
+
+        opt_state, psₙ = Optimisers.$(update_fn)(opt_state, ps, dps)
+        Functors.fmap(copyto!, ps_cache, psₙ; exclude=MLDataDevices.isleaf)
+        if is_sharded
+            # Ensure sharding of input and output states are consistent
+            # mark_same_sharding_group(ps, psₙ)
+            mark_same_sharding_group(st, stₙ)
+        end
+
+        return dps, ps_cache, loss, stats, stₙ, opt_state
     end
 end
+
+# TODO: think of a better way than sharding group. Since this will insert an optimization
+# barrier in the graph and we wont be able to do layout optimizations. Can we instead
+# use result sharding annotations here?
+function mark_same_sharding_group(args...)
+    return Functors.fmap(mark_same_sharding_group_inner, args...)
+end
+
+function mark_same_sharding_group_inner(arg1::Union{TracedRArray,TracedRNumber}, args...)
+    @opcall sharding_group(arg1, args...)
+    return nothing
+end
+mark_same_sharding_group_inner(arg1, args...) = nothing

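The key change above is that compute_gradients_internal! now receives a preallocated gradient buffer and lets Enzyme accumulate into it through the Duplicated activity, rather than having Enzyme.gradient allocate fresh storage on every step. A stripped-down sketch of that pattern on a plain Julia array (no Reactant ABI, no Lux objective wrapper; sqloss and the variable names are illustrative only, not part of this commit):

using Enzyme: Enzyme, Const, Duplicated

# Toy scalar loss standing in for Lux's objective function wrapper.
sqloss(ps) = sum(abs2, ps)

ps = rand(3)
# Preallocated gradient buffer; it must be zeroed because Enzyme accumulates
# into the shadow rather than overwriting it.
dps = Enzyme.make_zero(ps)

# ReverseWithPrimal returns (derivatives, primal); the gradient w.r.t. `ps`
# lands in `dps`, the shadow half of the Duplicated pair.
_, loss = Enzyme.autodiff(
    Enzyme.ReverseWithPrimal, Const(sqloss), Duplicated(ps, dps)
)

@assert dps ≈ 2 .* ps
@assert loss ≈ sum(abs2, ps)

In the extension the same call runs under Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI), and the updated parameters are copied back into the preallocated ps_cache via Functors.fmap(copyto!, ps_cache, psₙ; exclude=MLDataDevices.isleaf) before being returned from the compiled train step.
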
test/reactant/training_tests.jl

Lines changed: 42 additions & 0 deletions
@@ -129,3 +129,45 @@ end
     )
     @test loss isa Number
 end
+
+@testitem "Reactant Distributed: Training API" tags = [:reactant] setup = [SharedTestSetup] begin
+    using Lux, Random, Reactant, Optimisers
+
+    ndevices = length(Reactant.devices())
+
+    # TODO: ensure lux tests are being run with IFRT
+    if ndevices ≥ 8 && Reactant.XLA.runtime() isa Val{:IFRT}
+        mesh = Sharding.Mesh(reshape(Reactant.devices()[1:8], (2, 4)), (:model, :batch))
+
+        model_device = reactant_device(;
+            sharding=Sharding.DimsSharding(mesh, (-2,), (:model,))
+        )
+        batch_device = reactant_device(;
+            sharding=Sharding.DimsSharding(mesh, (-1,), (:batch,))
+        )
+
+        model = Chain(
+            Chain(Dense(4 => 32), BatchNorm(32, relu)),
+            Chain(Dense(32 => 32), BatchNorm(32, relu)),
+            Dense(32 => 4),
+        )
+        ps, st = Lux.setup(Random.default_rng(), model) |> model_device
+
+        x = rand(Float32, 4, 128) |> batch_device
+        y = rand(Float32, 4, 128) |> batch_device
+
+        train_state = Training.TrainState(model, ps, st, Adam(0.001f0))
+
+        _, loss, _, train_state2 = Training.single_train_step(
+            AutoEnzyme(), MSELoss(), (x, y), train_state
+        )
+        @test loss isa Reactant.ConcreteRNumber
+        @test length(Reactant.XLA.devices(Reactant.XLA.sharding(loss.data))) == 8
+
+        _, loss, _, train_state = Training.single_train_step(
+            AutoEnzyme(), MSELoss(), (x, y), train_state
+        )
+        @test loss isa Reactant.ConcreteRNumber
+        @test length(Reactant.XLA.devices(Reactant.XLA.sharding(loss.data))) == 8
+    end
+end
