Skip to content

Commit 1613735

Browse files
authored
Fix multiarg fwd and fwd mode copy on 1.11 (#2469)
* Fix multiarg fwd and fwd mode copy on 1.11 * cleanup
1 parent 23c268c commit 1613735

File tree

3 files changed

+14
-13
lines changed

3 files changed

+14
-13
lines changed

src/rules/llvmrules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,7 @@ end
805805

806806
LLVM.memset!(
807807
B,
808-
inttoptr!(B, ev2, LLVM.PointerType(LLVM.IntType(8))),
808+
get_memory_data(B, callv),
809809
LLVM.ConstantInt(i8, 0, false),
810810
length,
811811
algn,

src/sugar.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -487,15 +487,6 @@ end
487487
res
488488
end
489489

490-
@inline function tupstack(x, outshape::Tuple{Vararg{Int}}, inshape::Tuple{Vararg{Int}})
491-
st = Base.stack(x)
492-
if length(outshape) == 1
493-
st
494-
else
495-
reshape(st, (outshape..., inshape...))
496-
end
497-
end
498-
499490
@inline specialize_output(output, input) = output
500491

501492
"""
@@ -744,12 +735,14 @@ gradient(Forward, mul, [2.0, 3.0], Const([2.7, 3.1]))
744735
if argnum > 0
745736
quote
746737
if $tmp[1] isa AbstractArray
747-
inshape = size($(vals[1]))
738+
inshape = size($(vals[i]))
748739
outshape = size($tmp[1])
740+
num = prod(outshape)
741+
749742
# st : outshape x total inputs
750743
tupstack($tmp, outshape, inshape)
751744
else
752-
specialize_output(TupleArray($tmp, size($arg)), $(vals[1]))
745+
specialize_output(TupleArray($tmp, size($arg)), $(vals[i]))
753746
end
754747
end
755748
else

test/sugar.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ mul_vector(x, y) = [x[1]*y[2], x[2]*y[1]]
1010
@test res[1] === inp
1111
end
1212

13+
14+
function diffsize(θ0, X)
15+
return copy(X)
16+
end
17+
18+
1319
@testset "Forward Multi-Arg Gradient" begin
1420
res = gradient(Forward, mul_scalar, [2.0, 3.0], [2.7, 3.1])
1521
@test res[1] [3.1, 2.7]
@@ -38,7 +44,9 @@ end
3844
@test res.derivs[1] [3.1, 2.7]
3945
@test res.derivs[2] [3.0, 2.0]
4046

41-
47+
res = gradient(Forward, diffsize, [2.0, 3.0], [2.7, 3.1, 4.5])
48+
@test res[1] zeros(3, 2)
49+
@test res[2] [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0]
4250

4351
res = gradient(Forward, mul_scalar, Const([2.0, 3.0]), [2.7, 3.1])
4452
@test res[1] == nothing

0 commit comments

Comments
 (0)