Skip to content

Commit 0c56bec

Browse files
Safety make_zero! on repeated Enzyme calls with caches
This should ensure that the caches are always zero'd for the derivatives.
1 parent 456ea4b commit 0c56bec

File tree

4 files changed

+5
-1
lines changed

4 files changed

+5
-1
lines changed

Diff for: Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ DiffEqNoiseProcess = "5.19"
6161
Distributed = "1"
6262
Distributions = "0.25"
6363
EllipsisNotation = "1"
64-
Enzyme = "0.12"
64+
Enzyme = "0.12.12"
6565
FiniteDiff = "2"
6666
ForwardDiff = "0.10"
6767
FunctionProperties = "0.1"

Diff for: src/derivative_wrappers.jl

+2
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,8 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,
705705

706706
isautojacvec = get_jacvec(sensealg)
707707

708+
Enzyme.make_zero!(_tmp6)
709+
708710
if inplace_sensitivity(S)
709711
if W === nothing
710712
Enzyme.autodiff(Enzyme.Reverse, Enzyme.Duplicated(S.diffcache.pf, _tmp6),

Diff for: src/gauss_adjoint.jl

+1
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,7 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
498498
vtmp4 = vec(tmp4)
499499
vtmp4 .= λ
500500
out .= 0
501+
Enzyme.make_zero!(tmp6)
501502
Enzyme.autodiff(
502503
Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const,
503504
Enzyme.Duplicated(tmp3, tmp4),

Diff for: src/quadrature_adjoint.jl

+1
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand)
294294
tmp3, tmp4, tmp6 = paramjac_config
295295
tmp4 .= λ
296296
out .= 0
297+
Enzyme.make_zero!(tmp6)
297298
Enzyme.autodiff(
298299
Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const,
299300
Enzyme.Duplicated(tmp3, tmp4),

0 commit comments

Comments
 (0)