Skip to content

Commit 6b30dda

Browse files
authored
Batched push (#2763)
* Batched push * fix
1 parent 986bbf4 commit 6b30dda

File tree

2 files changed

+49
-3
lines changed

2 files changed

+49
-3
lines changed

src/rules/llvmrules.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1587,10 +1587,10 @@ end
15871587
offset
15881588
]
15891589
if width > 1
1590-
extract_value!(B, shadowin, idx - 1)
1590+
args[1] = extract_value!(B, args[1], idx - 1)
15911591
end
15921592

1593-
if get_runtime_activity(gutils) && endB == nothing
1593+
if get_runtime_activity(gutils) && endB === nothing
15941594
cond = icmp!(B, LLVM.API.LLVMIntNE, fval, args[1])
15951595

15961596
nextB = add_reverse_block!(gutils, currentBlock, ogname*"_active")
@@ -1604,7 +1604,7 @@ end
16041604

16051605
LLVM.call!(B, fty, delF, args)
16061606
end
1607-
1607+
16081608
if endB !== nothing
16091609
br!(B, endB)
16101610
set_reverse_block!(gutils, endB)

test/array.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,52 @@ for RTA in (false, true)
202202
end
203203
@test x [2.3, 2.0]
204204

205+
206+
function pusher_ref(x, y, z)
207+
push!(x, y[])
208+
z[] = x[1] + x[2]
209+
nothing
210+
end
211+
212+
yr = Ref(2.0)
213+
dyr = Ref(0.0)
214+
dyr2 = Ref(0.0)
215+
216+
zr = Ref(2.0)
217+
dzr = Ref(1.0)
218+
dzr2 = Ref(2.7)
219+
220+
x = [2.3]
221+
dx = [0.0]
222+
dx2 = [0.0]
223+
rf = @static if VERSION < v"1.11-"
224+
nothing
225+
else
226+
dx.ref.mem
227+
end
228+
229+
rf2 = @static if VERSION < v"1.11-"
230+
nothing
231+
else
232+
dx2.ref.mem
233+
end
234+
235+
Enzyme.autodiff(set_runtime_activity(Reverse, RTA), pusher_ref, BatchDuplicated(x, (dx, dx2)), BatchDuplicated(yr, (dyr, dyr2)), BatchDuplicated(zr, (dzr, dzr2)))
236+
237+
@test 1.0 dyr[]
238+
@test 2.7 dyr2[]
239+
240+
@static if VERSION < v"1.11-"
241+
@test dx [1.0]
242+
@test dx2 [2.7]
243+
else
244+
@test dx [0.0, 0.0]
245+
@test dx2 [0.0, 0.0]
246+
@test rf [1.0]
247+
@test rf2 [2.7]
248+
end
249+
@test x [2.3, 2.0]
250+
205251
function double_push(x)
206252
a = [0.5]
207253
push!(a, 1.0)

0 commit comments

Comments
 (0)