Skip to content

Commit c420ff0

Browse files
Fix Enzyme interface (#372)
* Update utils.jl * fix Enzyme interface * change forward back * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 6f0d383 commit c420ff0

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

test/ad/utils.jl

+6-4
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,13 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
5656
if !(:EnzymeForwardCrash in broken)
5757
if forward_broken
5858
@test_broken(
59-
Enzyme.gradient(Enzyme.Forward, f, x)[1] finitediff,
59+
Enzyme.gradient(Forward, Enzyme.Const(f), x)[1] finitediff,
6060
rtol = rtol,
6161
atol = atol
6262
)
6363
else
6464
@test(
65-
Enzyme.gradient(Enzyme.Forward, f, x)[1] finitediff,
65+
Enzyme.gradient(Forward, Enzyme.Const(f), x)[1] finitediff,
6666
rtol = rtol,
6767
atol = atol
6868
)
@@ -72,13 +72,15 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
7272
if !(:EnzymeReverseCrash in broken)
7373
if reverse_broken
7474
@test_broken(
75-
Enzyme.gradient(Enzyme.Reverse, f, x)[1] finitediff,
75+
Enzyme.gradient(set_runtime_activity(Reverse), Enzyme.Const(f), x)[1]
76+
finitediff,
7677
rtol = rtol,
7778
atol = atol
7879
)
7980
else
8081
@test(
81-
Enzyme.gradient(Enzyme.Reverse, f, x)[1] finitediff,
82+
Enzyme.gradient(set_runtime_activity(Reverse), Enzyme.Const(f), x)[1]
83+
finitediff,
8284
rtol = rtol,
8385
atol = atol
8486
)

0 commit comments

Comments
 (0)