@@ -3,19 +3,29 @@ function objective_function_wrapper(objective_function::F, model, ps, st, data)
     return loss, Reactant.ignore_derivatives(stₙ), Reactant.ignore_derivatives(stats)
 end

-function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
-    (_, _, dps, _, _), (loss, stₙ, stats) = Enzyme.gradient(
+function compute_gradients_internal!(
+    dps, objective_function::F, model, data, ps, st, zeroed_grads::Bool=false
+) where {F}
+    zeroed_grads || Enzyme.make_zero!(dps)
+
+    _, (loss, stₙ, stats) = Enzyme.autodiff(
         Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
         Const(objective_function_wrapper),
         Const(objective_function),
         Const(model),
-        ps,
+        Duplicated(ps, dps),
         Const(st),
         Const(data),
     )
     return dps, loss, stats, stₙ
 end

+function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
+    return compute_gradients_internal!(
+        Enzyme.make_zero(ps), objective_function, model, data, ps, st, true
+    )
+end
+
 Profiler.@annotate "Compile Compute Gradients" function Lux.Training.compute_gradients_impl(
     backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState
 ) where {F}
@@ -80,34 +90,54 @@ for inplace in ("!", "")
         return ts
     end

+    ps_expr = if inplace == "!"
+        :(ps = ts.parameters)
+    else
+        :(ps = Functors.fmap(copy, ts.parameters; exclude=MLDataDevices.isleaf))
+    end
+
     # XXX: recompile with a warning if new input types are used
     @eval Profiler.@annotate "Compile Train Step" function Lux.Training.$(fname)(
         backend::ReactantBackend, objective_function::F, data, ts::Training.TrainState
     ) where {F}
+        device = get_device((ts.parameters, ts.states, ts.optimizer_state, data))
+        @assert device isa ReactantDevice
+        is_sharded = device.device === nothing
+
+        dps = if backend.return_gradients isa True
+            Functors.fmap(Utils.zero, ts.parameters; exclude=MLDataDevices.isleaf)
+        else
+            nothing
+        end
+
+        $(ps_expr)
+
         compiled_grad_and_step_function = with_default_precision_config(ts.parameters) do
             @compile sync = backend.sync $(internal_fn)(
                 objective_function,
                 ts.model,
                 data,
-                ts.parameters,
+                ps,
                 ts.states,
                 ts.optimizer_state,
-                backend.return_gradients,
+                dps,
+                is_sharded,
             )
         end

         grads, ps, loss, stats, st, opt_state = compiled_grad_and_step_function(
             objective_function,
             ts.model,
             data,
-            ts.parameters,
+            ps,
             ts.states,
             ts.optimizer_state,
-            backend.return_gradients,
+            dps,
+            is_sharded,
         )

         cache = TrainingBackendCache(
-            backend, False(), nothing, (; compiled_grad_and_step_function)
+            backend, False(), dps, (; compiled_grad_and_step_function, is_sharded)
         )
         @set! ts.cache = cache
         @set! ts.objective_function = objective_function
@@ -120,7 +150,7 @@ for inplace in ("!", "")
     end

     @eval Profiler.@annotate "Train Step" function Lux.Training.$(fname)(
-        backend::ReactantBackend,
+        ::ReactantBackend,
         obj_fn::F,
         data,
         ts::Training.TrainState{<:TrainingBackendCache{<:ReactantBackend},F},
@@ -132,7 +162,8 @@ for inplace in ("!", "")
             ts.parameters,
             ts.states,
             ts.optimizer_state,
-            backend.return_gradients,
+            ts.cache.dparameters,
+            ts.cache.extras.is_sharded,
         )

         @set! ts.states = st
@@ -143,24 +174,38 @@ for inplace in ("!", "")
         return grads, loss, stats, ts
     end

-    # XXX: Inplace version not actually inplace
     @eval function $(internal_fn)(
-        objective_function::F, model, data, ps, st, opt_state, ::False
+        objective_function::F, model, data, ps, st, opt_state, ::Nothing, is_sharded::Bool
     ) where {F}
         dps, loss, stats, stₙ = compute_gradients_internal(
             objective_function, model, data, ps, st
         )
-        opt_state, ps = Optimisers.$(update_fn)(opt_state, ps, dps)
-        return nothing, ps, loss, stats, stₙ, opt_state
+
+        opt_state, psₙ = Optimisers.update!(opt_state, ps, dps)
+        # Ensure the sharding of input and output states is consistent
+        is_sharded && mark_same_sharding_group(st, stₙ)
+
+        return nothing, psₙ, loss, stats, stₙ, opt_state
     end

     @eval function $(internal_fn)(
-        objective_function::F, model, data, ps, st, opt_state, ::True
+        objective_function::F, model, data, ps, st, opt_state, dps, is_sharded::Bool
     ) where {F}
-        dps, loss, stats, stₙ = compute_gradients_internal(
-            objective_function, model, data, ps, st
+        dps, loss, stats, stₙ = compute_gradients_internal!(
+            dps, objective_function, model, data, ps, st
         )
-        opt_state, ps = Optimisers.$(update_fn)(opt_state, ps, dps)
-        return dps, ps, loss, stats, stₙ, opt_state
+
+        opt_state, psₙ = Optimisers.update!(opt_state, ps, dps)
+        # Ensure the sharding of input and output states is consistent
+        is_sharded && mark_same_sharding_group(st, stₙ)
+
+        return dps, psₙ, loss, stats, stₙ, opt_state
     end
 end
+
+mark_same_sharding_group(args...) = Functors.fmap(mark_same_sharding_group_inner, args...)
+
+function mark_same_sharding_group_inner(arg1::Union{TracedRArray,TracedRNumber}, args...)
+    return @opcall sharding_group(arg1, args...)
+end
+mark_same_sharding_group_inner(arg1, args...) = nothing
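
For reference, the rewritten gradient path preallocates a shadow dps, resets it with Enzyme.make_zero!, and lets Enzyme.autodiff accumulate into it through Duplicated(ps, dps). A minimal sketch of that pattern with plain Enzyme (no Reactant ABI; loss_fn, ps, and x are made-up names, not part of this PR):

using Enzyme

loss_fn(ps, x) = sum(abs2, ps .* x)

ps = rand(4)
x = rand(4)
dps = Enzyme.make_zero(ps)   # preallocated shadow, reused across steps

Enzyme.make_zero!(dps)       # reset before reuse, mirroring `zeroed_grads || Enzyme.make_zero!(dps)`
_, loss = Enzyme.autodiff(
    Enzyme.ReverseWithPrimal, loss_fn, Enzyme.Active,
    Enzyme.Duplicated(ps, dps), Enzyme.Const(x),
)
# dps now holds the gradient of loss w.r.t. ps; loss is the primal value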