@@ -3,19 +3,27 @@ function objective_function_wrapper(objective_function::F, model, ps, st, data)
    return loss, Reactant.ignore_derivatives(stₙ), Reactant.ignore_derivatives(stats)
end

-function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
-    (_, _, dps, _, _), (loss, stₙ, stats) = Enzyme.gradient(
+function compute_gradients_internal!(
+    dps, objective_function::F, model, data, ps, st
+) where {F}
+    _, (loss, stₙ, stats) = Enzyme.autodiff(
        Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
        Const(objective_function_wrapper),
        Const(objective_function),
        Const(model),
-        ps,
+        Duplicated(ps, dps),
        Const(st),
        Const(data),
    )
    return dps, loss, stats, stₙ
end

+function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
+    return compute_gradients_internal!(
+        Enzyme.make_zero(ps), objective_function, model, data, ps, st
+    )
+end
+
Profiler.@annotate "Compile Compute Gradients" function Lux.Training.compute_gradients_impl(
    backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState
) where {F}
@@ -84,6 +92,19 @@ for inplace in ("!", "")
    @eval Profiler.@annotate "Compile Train Step" function Lux.Training.$(fname)(
        backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState
    ) where {F}
+        device = get_device((ts.parameters, ts.states, ts.optimizer_state, data))
+        @assert device isa ReactantDevice
+        is_sharded = device.device === nothing
+
+        dps = if backend.return_gradients isa True
+            Functors.fmap(Utils.zero, ts.parameters; exclude=MLDataDevices.isleaf)
+        else
+            nothing
+        end
+
+        # TODO: make it conditional
+        ps_cache = Functors.fmap(copy, ts.parameters; exclude=MLDataDevices.isleaf)
+
        compiled_grad_and_step_function = with_default_precision_config(ts.parameters) do
            @compile sync = backend.sync $(internal_fn)(
                objective_function,
@@ -92,7 +113,9 @@ for inplace in ("!", "")
                ts.parameters,
                ts.states,
                ts.optimizer_state,
-                backend.return_gradients,
+                dps,
+                is_sharded,
+                ps_cache,
            )
        end

@@ -103,11 +126,13 @@ for inplace in ("!", "")
            ts.parameters,
            ts.states,
            ts.optimizer_state,
-            backend.return_gradients,
+            dps,
+            is_sharded,
+            ps_cache,
        )

        cache = TrainingBackendCache(
-            backend, False(), nothing, (; compiled_grad_and_step_function)
+            backend, False(), dps, (; compiled_grad_and_step_function, is_sharded, ps_cache)
        )
        @set! ts.cache = cache
        @set! ts.objective_function = objective_function
@@ -132,7 +157,9 @@ for inplace in ("!", "")
            ts.parameters,
            ts.states,
            ts.optimizer_state,
-            backend.return_gradients,
+            ts.cache.dparameters,
+            ts.cache.extras.is_sharded,
+            ts.cache.extras.ps_cache,
        )

        @set! ts.states = st
@@ -143,24 +170,67 @@ for inplace in ("!", "")
        return grads, loss, stats, ts
    end

-    # XXX: Inplace version not actually inplace
    @eval function $(internal_fn)(
-        objective_function::F, model, data, ps, st, opt_state, ::False
+        objective_function::F,
+        model,
+        data,
+        ps,
+        st,
+        opt_state,
+        ::Nothing,
+        is_sharded::Bool,
+        ps_cache,
    ) where {F}
        dps, loss, stats, stₙ = compute_gradients_internal(
            objective_function, model, data, ps, st
        )
-        opt_state, ps = Optimisers.$(update_fn)(opt_state, ps, dps)
-        return nothing, ps, loss, stats, stₙ, opt_state
+
+        opt_state, psₙ = Optimisers.$(update_fn)(opt_state, ps, dps)
+        Functors.fmap(copyto!, ps_cache, psₙ; exclude=MLDataDevices.isleaf)
+        if is_sharded
+            # Ensure sharding of input and output states is consistent
+            mark_same_sharding_group(st, stₙ)
+        end
+
+        return nothing, ps_cache, loss, stats, stₙ, opt_state
    end

    @eval function $(internal_fn)(
-        objective_function::F, model, data, ps, st, opt_state, ::True
+        objective_function::F,
+        model,
+        data,
+        ps,
+        st,
+        opt_state,
+        dps,
+        is_sharded::Bool,
+        ps_cache,
    ) where {F}
-        dps, loss, stats, stₙ = compute_gradients_internal(
-            objective_function, model, data, ps, st
+        dps, loss, stats, stₙ = compute_gradients_internal!(
+            dps, objective_function, model, data, ps, st
        )
-        opt_state, ps = Optimisers.$(update_fn)(opt_state, ps, dps)
-        return dps, ps, loss, stats, stₙ, opt_state
+
+        opt_state, psₙ = Optimisers.$(update_fn)(opt_state, ps, dps)
+        Functors.fmap(copyto!, ps_cache, psₙ; exclude=MLDataDevices.isleaf)
+        if is_sharded
+            # Ensure sharding of input and output states is consistent
+            # mark_same_sharding_group(ps, psₙ)
+            mark_same_sharding_group(st, stₙ)
+        end
+
+        return dps, ps_cache, loss, stats, stₙ, opt_state
    end
end
+
+# TODO: think of a better way than a sharding group, since it inserts an optimization
+# barrier into the graph and prevents layout optimizations. Can we instead use result
+# sharding annotations here?
+function mark_same_sharding_group(args...)
+    return Functors.fmap(mark_same_sharding_group_inner, args...)
+end
+
+function mark_same_sharding_group_inner(arg1::Union{TracedRArray,TracedRNumber}, args...)
+    @opcall sharding_group(arg1, args...)
+    return nothing
+end
+mark_same_sharding_group_inner(arg1, args...) = nothing
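
For context, here is a minimal usage sketch of how the compiled train step above is reached from user code, assuming the standard Lux + Reactant entry point; the model, optimizer settings, and data are arbitrary placeholders, not part of this change:

using Lux, Optimisers, Random, Reactant

const dev = reactant_device()

model = Chain(Dense(2 => 32, relu), Dense(32 => 2))
ps, st = Lux.setup(Random.default_rng(), model) |> dev
x = rand(Float32, 2, 64) |> dev
y = rand(Float32, 2, 64) |> dev

ts = Training.TrainState(model, ps, st, Adam(0.001f0))

# The first call compiles the grad-and-step function defined above and caches it in
# ts.cache; later calls reuse the cached compiled function.
_, loss, _, ts = Training.single_train_step!(AutoEnzyme(), MSELoss(), (x, y), ts)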