# In-place value and gradient of `f` at `x` for the ReverseDiff backend (no contexts).
#
# `grad` is wrapped in a `MutableDiffResult` whose value slot is initialised with
# `zero(eltype(x))`, and that result is handed to `gradient!`: with the tape stored
# in `prep` when `compile` is true, otherwise with `f` and `prep.config`.
# Returns the value read back from the result, together with the caller's `grad`.
function DI.value_and_gradient!(
    f, grad, prep::ReverseDiffGradientPrep, ::AutoReverseDiff{compile}, x
) where {compile}
    # Placeholder value instead of a separate `f(x)` call (see ReverseDiff#251).
    result = MutableDiffResult(zero(eltype(x)), (grad,))  # ReverseDiff#251
    if compile
        result = gradient!(result, prep.tape, x)
    else
        result = gradient!(result, f, x, prep.config)
    end
    # NOTE(review): returning `grad` directly presumably relies on `gradient!` having
    # written into the wrapped array in place (see ReverseDiff#269) — confirm.
    return DR.value(result), grad  # ReverseDiff#269
end
10198
# Out-of-place value and gradient of `f` at `x` for the ReverseDiff backend
# (no contexts).
#
# A fresh gradient buffer is created with `similar(x)` and wrapped in a
# `MutableDiffResult`; `gradient!` is then called with the tape stored in `prep`
# when `compile` is true, otherwise with `f` and `prep.config`.
# Returns `(value, gradient)`, both read back from the result.
function DI.value_and_gradient(
    f, prep::ReverseDiffGradientPrep, backend::AutoReverseDiff{compile}, x
) where {compile}
    # GradientResult tries to mutate an SArray
    result = MutableDiffResult(zero(eltype(x)), (similar(x),))
    if compile
        result = gradient!(result, prep.tape, x)
    else
        result = gradient!(result, f, x, prep.config)
    end
    return DR.value(result), DR.gradient(result)
end
108111
109112function DI. gradient! (
@@ -144,23 +147,19 @@ function DI.value_and_gradient!(
144147 contexts:: Vararg{DI.Context,C} ,
145148) where {C}
146149 fc = DI. with_contexts (f, contexts... )
147- y = fc (x) # TODO : ReverseDiff#251
148- result = DiffResult (y, (grad,))
150+ result = MutableDiffResult (zero (eltype (x)), (grad,)) # ReverseDiff#251
149151 result = gradient! (result, fc, x, prep. config)
150- y = DR. value (result)
151- grad === DR. gradient (result) || copyto! (grad, DR. gradient (result))
152- return y, grad
152+ return DR. value (result), grad # ReverseDiff#269
153153end
154154
# Out-of-place value and gradient of `f` at `x` with additional contexts.
#
# The contexts are folded into `f` via `DI.with_contexts`, giving `fc`. A fresh
# gradient buffer (`similar(x)`) is wrapped in a `MutableDiffResult`, which is passed
# to `gradient!` together with `fc` and `prep.config`.
# Returns `(value, gradient)`, both read back from the result.
function DI.value_and_gradient(
    f, prep::ReverseDiffGradientPrep, ::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}
) where {C}
    fc = DI.with_contexts(f, contexts...)
    # GradientResult tries to mutate an SArray
    result = MutableDiffResult(zero(eltype(x)), (similar(x),))
    result = gradient!(result, fc, x, prep.config)
    return DR.value(result), DR.gradient(result)
end
165164
166165function DI. gradient! (
@@ -310,31 +309,23 @@ end
310309
311310# ## Without contexts
312311
# Preparation object for Hessian computations with the ReverseDiff backend.
#
# - `gradient_prep`: a `ReverseDiffGradientPrep`, reused by the
#   `value_gradient_and_hessian` methods for the gradient part.
# - `hessian_config`: set to `nothing` or to a `HessianConfig` by `prepare_hessian`.
# - `hessian_tape`: set to `nothing` or to a compiled `HessianTape` by
#   `prepare_hessian`.
@kwdef struct ReverseDiffHessianPrep{G<:ReverseDiffGradientPrep,HC,HT} <: DI.HessianPrep
    gradient_prep::G
    hessian_config::HC
    hessian_tape::HT
end
319317
# Build a `ReverseDiffHessianPrep` for `f` at `x` (no contexts).
#
# The gradient preparation is always delegated to `DI.prepare_gradient`.
# When `compile` is true, a `HessianTape` is recorded and compiled and no
# `HessianConfig` is stored; otherwise a `HessianConfig` is created and no tape
# is stored.
function DI.prepare_hessian(f, backend::AutoReverseDiff{compile}, x) where {compile}
    gradient_prep = DI.prepare_gradient(f, backend, x)
    if compile
        hessian_tape = ReverseDiff.compile(HessianTape(f, x))
        return ReverseDiffHessianPrep(;
            gradient_prep, hessian_config=nothing, hessian_tape=hessian_tape
        )
    else
        hessian_config = HessianConfig(x)
        return ReverseDiffHessianPrep(;
            gradient_prep, hessian_config=hessian_config, hessian_tape=nothing
        )
    end
end
@@ -360,47 +351,32 @@ function DI.hessian(
360351end
361352
# In-place value, gradient and Hessian of `f` at `x` (no contexts).
#
# The value comes from a direct call `f(x)`; the gradient and Hessian are delegated
# to `DI.gradient!` (with `prep.gradient_prep`) and `DI.hessian!` (with `prep`).
# Returns `(y, grad, hess)` using the caller-supplied `grad` and `hess`.
function DI.value_gradient_and_hessian!(
    f, grad, hess, prep::ReverseDiffHessianPrep, backend::AutoReverseDiff{compile}, x
) where {compile}
    y = f(x)
    DI.gradient!(f, grad, prep.gradient_prep, backend, x)
    DI.hessian!(f, hess, prep, backend, x)
    return y, grad, hess
end
378361
# Out-of-place value, gradient and Hessian of `f` at `x` (no contexts).
#
# The value comes from a direct call `f(x)`; the gradient and Hessian are delegated
# to `DI.gradient` (with `prep.gradient_prep`) and `DI.hessian` (with `prep`).
# Returns `(y, grad, hess)`.
function DI.value_gradient_and_hessian(
    f, prep::ReverseDiffHessianPrep, backend::AutoReverseDiff{compile}, x
) where {compile}
    y = f(x)
    grad = DI.gradient(f, prep.gradient_prep, backend, x)
    hess = DI.hessian(f, prep, backend, x)
    return y, grad, hess
end
391370
392371# ## With contexts
393372
# Build a `ReverseDiffHessianPrep` for `f` at `x` with additional contexts.
#
# The gradient preparation is delegated to `DI.prepare_gradient` (contexts forwarded).
# This method always stores a `HessianConfig` and never a tape.
function DI.prepare_hessian(
    f, backend::AutoReverseDiff, x, contexts::Vararg{DI.Context,C}
) where {C}
    gradient_prep = DI.prepare_gradient(f, backend, x, contexts...)
    hessian_config = HessianConfig(x)
    return ReverseDiffHessianPrep(;
        gradient_prep, hessian_config=hessian_config, hessian_tape=nothing
    )
end
406382
@@ -428,27 +404,25 @@ function DI.value_gradient_and_hessian!(
428404 grad,
429405 hess,
430406 prep:: ReverseDiffHessianPrep ,
431- :: AutoReverseDiff ,
407+ backend :: AutoReverseDiff ,
432408 x,
433409 contexts:: Vararg{DI.Context,C} ,
434410) where {C}
435- fc = DI. with_contexts (f, contexts... )
436- y = fc (x) # TODO : ReverseDiff#251
437- result = DiffResult (y, (grad, hess))
438- result = hessian! (result, fc, x) # TODO : add prep.hessian_config
439- y = DR. value (result)
440- # grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
441- grad = gradient! (grad, fc, x, prep. gradient_config) # TODO : ReverseDiff#251
442- hess === DR. hessian (result) || copyto! (hess, DR. hessian (result))
411+ y = f (x, map (DI. unwrap, contexts)... )
412+ DI. gradient! (f, grad, prep. gradient_prep, backend, x, contexts... )
413+ DI. hessian! (f, hess, prep, backend, x, contexts... )
443414 return y, grad, hess
444415end
445416
# Out-of-place value, gradient and Hessian of `f` at `x` with additional contexts.
#
# The value comes from calling `f` on `x` followed by the unwrapped contexts
# (`DI.unwrap` applied to each); the gradient and Hessian are delegated to
# `DI.gradient` (with `prep.gradient_prep`) and `DI.hessian` (with `prep`),
# forwarding the contexts unchanged.
# Returns `(y, grad, hess)`.
function DI.value_gradient_and_hessian(
    f,
    prep::ReverseDiffHessianPrep,
    backend::AutoReverseDiff,
    x,
    contexts::Vararg{DI.Context,C},
) where {C}
    y = f(x, map(DI.unwrap, contexts)...)
    grad = DI.gradient(f, prep.gradient_prep, backend, x, contexts...)
    hess = DI.hessian(f, prep, backend, x, contexts...)
    return y, grad, hess
end
0 commit comments