@@ -8,8 +8,7 @@ function seeded_autodiff_thunk(
88 forward, reverse = autodiff_thunk (rmode, FA, RA, typeof .(args)... )
99 tape, result, shadow_result = forward (f, args... )
1010 if RA <: Active
11- dresult_righttype = convert (typeof (result), dresult)
12- dinputs = only (reverse (f, args... , dresult_righttype, tape))
11+ dinputs = only (reverse (f, args... , dresult, tape))
1312 else
1413 shadow_result .+ = dresult # TODO : generalize beyond arrays
1514 dinputs = only (reverse (f, args... , tape))
@@ -32,8 +31,7 @@ function batch_seeded_autodiff_thunk(
3231 forward, reverse = autodiff_thunk (rmode_rightwidth, FA, RA, typeof .(args)... )
3332 tape, result, shadow_results = forward (f, args... )
3433 if RA <: Active
35- dresults_righttype = map (Fix1 (convert, typeof (result)), dresults)
36- dinputs = only (reverse (f, args... , dresults_righttype, tape))
34+ dinputs = only (reverse (f, args... , dresults, tape))
3735 else
3836 foreach (shadow_results, dresults) do d0, d
3937 d0 .+ = d # use recursive_add here?
@@ -141,13 +139,12 @@ function DI.value_and_pullback!(
141139 mode = reverse_split_withprimal (backend)
142140 f_and_df = force_annotation (get_f_and_df (f, backend, mode))
143141 RA = guess_activity (typeof (prep. y_example), mode)
144- dx_righttype = convert ( typeof (x), only (tx) )
145- make_zero! (dx_righttype )
142+ dx = only (tx)
143+ make_zero! (dx )
146144 annotated_contexts = translate (backend, mode, Val (1 ), contexts... )
147145 _, result = seeded_autodiff_thunk (
148- mode, only (ty), f_and_df, RA, Duplicated (x, dx_righttype ), annotated_contexts...
146+ mode, only (ty), f_and_df, RA, Duplicated (x, dx ), annotated_contexts...
149147 )
150- copyto_if_different_addresses! (only (tx), dx_righttype)
151148 return result, tx
152149end
153150
@@ -163,13 +160,11 @@ function DI.value_and_pullback!(
163160 mode = reverse_split_withprimal (backend)
164161 f_and_df = force_annotation (get_f_and_df (f, backend, mode, Val (B)))
165162 RA = batchify_activity (guess_activity (typeof (prep. y_example), mode), Val (B))
166- tx_righttype = map (Fix1 (convert, typeof (x)), tx)
167- make_zero! (tx_righttype)
163+ make_zero! (tx)
168164 annotated_contexts = translate (backend, mode, Val (B), contexts... )
169165 _, result = batch_seeded_autodiff_thunk (
170- mode, ty, f_and_df, RA, BatchDuplicated (x, tx_righttype ), annotated_contexts...
166+ mode, ty, f_and_df, RA, BatchDuplicated (x, tx ), annotated_contexts...
171167 )
172- foreach (copyto_if_different_addresses!, tx, tx_righttype)
173168 return result, tx
174169end
175170
@@ -187,10 +182,73 @@ end
187182
188183# # Gradient
189184
190- # ## Without preparation
185+ function DI. prepare_gradient (
186+ f:: F , :: AutoEnzyme{<:Union{ReverseMode,Nothing}} , x, contexts:: Vararg{DI.Context,C}
187+ ) where {F,C}
188+ return DI. NoGradientPrep ()
189+ end
190+
191+ # ## Enzyme gradient API (only constants)
192+
193+ function DI. gradient (
194+ f:: F ,
195+ :: DI.NoGradientPrep ,
196+ backend:: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}} ,
197+ x,
198+ contexts:: Vararg{DI.Constant,C} ,
199+ ) where {F,C}
200+ mode = reverse_noprimal (backend)
201+ f_and_df = get_f_and_df (f, backend, mode)
202+ annotated_contexts = translate (backend, mode, Val (1 ), contexts... )
203+ grads = gradient (mode, f_and_df, x, annotated_contexts... )
204+ return first (grads)
205+ end
206+
207+ function DI. value_and_gradient (
208+ f:: F ,
209+ :: DI.NoGradientPrep ,
210+ backend:: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}} ,
211+ x,
212+ contexts:: Vararg{DI.Constant,C} ,
213+ ) where {F,C}
214+ mode = reverse_withprimal (backend)
215+ f_and_df = get_f_and_df (f, backend, mode)
216+ annotated_contexts = translate (backend, mode, Val (1 ), contexts... )
217+ grads, result = gradient (mode, f_and_df, x, annotated_contexts... )
218+ return result, first (grads)
219+ end
220+
221+ function DI. gradient! (
222+ f:: F ,
223+ grad,
224+ :: DI.NoGradientPrep ,
225+ backend:: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}} ,
226+ x,
227+ ) where {F}
228+ mode = reverse_noprimal (backend)
229+ f_and_df = get_f_and_df (f, backend, mode)
230+ gradient! (mode, grad, f_and_df, x)
231+ return grad
232+ end
233+
234+ function DI. value_and_gradient! (
235+ f:: F ,
236+ grad,
237+ :: DI.NoGradientPrep ,
238+ backend:: AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}} ,
239+ x,
240+ ) where {F}
241+ mode = reverse_withprimal (backend)
242+ f_and_df = get_f_and_df (f, backend, mode)
243+ _, result = gradient! (mode, grad, f_and_df, x)
244+ return result, grad
245+ end
246+
247+ # ## Generic
191248
192249function DI. gradient (
193250 f:: F ,
251+ :: DI.NoGradientPrep ,
194252 backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
195253 x,
196254 contexts:: Vararg{DI.Context,C} ,
213271
214272function DI. value_and_gradient (
215273 f:: F ,
274+ :: DI.NoGradientPrep ,
216275 backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
217276 x,
218277 contexts:: Vararg{DI.Context,C} ,
@@ -233,73 +292,34 @@ function DI.value_and_gradient(
233292 end
234293end
235294
236- # ## With preparation
237-
238- struct EnzymeGradientPrep{G} <: DI.GradientPrep
239- grad_righttype:: G
240- end
241-
242- function DI. prepare_gradient (
243- f:: F , :: AutoEnzyme{<:Union{ReverseMode,Nothing}} , x, contexts:: Vararg{DI.Context,C}
244- ) where {F,C}
245- grad_righttype = make_zero (x)
246- return EnzymeGradientPrep (grad_righttype)
247- end
248-
249- function DI. gradient (
250- f:: F ,
251- :: EnzymeGradientPrep ,
252- backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
253- x,
254- contexts:: Vararg{DI.Context,C} ,
255- ) where {F,C}
256- return DI. gradient (f, backend, x, contexts... )
257- end
258-
259295function DI. gradient! (
260296 f:: F ,
261297 grad,
262- prep :: EnzymeGradientPrep ,
298+ :: DI.NoGradientPrep ,
263299 backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
264300 x,
265301 contexts:: Vararg{DI.Context,C} ,
266302) where {F,C}
267303 mode = reverse_noprimal (backend)
268304 f_and_df = get_f_and_df (f, backend, mode)
269- grad_righttype = grad isa typeof (x) ? grad : prep. grad_righttype
270- make_zero! (grad_righttype)
271305 annotated_contexts = translate (backend, mode, Val (1 ), contexts... )
272- autodiff (mode, f_and_df, Active, Duplicated (x, grad_righttype), annotated_contexts ... )
273- copyto_if_different_addresses! ( grad, grad_righttype )
306+ make_zero! (grad )
307+ autodiff (mode, f_and_df, Active, Duplicated (x, grad), annotated_contexts ... )
274308 return grad
275309end
276310
277- function DI. value_and_gradient (
278- f:: F ,
279- :: EnzymeGradientPrep ,
280- backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
281- x,
282- contexts:: Vararg{DI.Context,C} ,
283- ) where {F,C}
284- return DI. value_and_gradient (f, backend, x, contexts... )
285- end
286-
287311function DI. value_and_gradient! (
288312 f:: F ,
289313 grad,
290- prep :: EnzymeGradientPrep ,
314+ :: DI.NoGradientPrep ,
291315 backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
292316 x,
293317 contexts:: Vararg{DI.Context,C} ,
294318) where {F,C}
295319 mode = reverse_withprimal (backend)
296320 f_and_df = get_f_and_df (f, backend, mode)
297- grad_righttype = grad isa typeof (x) ? grad : prep. grad_righttype
298- make_zero! (grad_righttype)
299321 annotated_contexts = translate (backend, mode, Val (1 ), contexts... )
300- _, y = autodiff (
301- mode, f_and_df, Active, Duplicated (x, grad_righttype), annotated_contexts...
302- )
303- copyto_if_different_addresses! (grad, grad_righttype)
322+ make_zero! (grad)
323+ _, y = autodiff (mode, f_and_df, Active, Duplicated (x, grad), annotated_contexts... )
304324 return y, grad
305325end
0 commit comments