Closed
Description
Enzyme gives the wrong derivative for a sparse matrix times dense vector multiplication. The example below uses Zygote as a reference. I have also computed the derivatives by hand, and they match Zygote. The two Enzyme calls differentiate the same product. The only difference is that in the second call `b` is converted to a dense array before the multiplication.
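For reference, here is a sketch of the by-hand derivation. It assumes the convention Zygote uses for real-valued functions of complex inputs, where the gradient is ∂f/∂Re(A) + i·∂f/∂Im(A):

```math
\begin{aligned}
f(A) &= \sum_i \lvert y_i \rvert^2, \qquad y = A b, \qquad y_i = \sum_j A_{ij}\, b_j \\
df &= \sum_i 2\,\mathrm{Re}\big(\overline{y_i}\, dy_i\big) = \sum_{i,j} 2\,\mathrm{Re}\big(\overline{y_i}\, b_j\, dA_{ij}\big) \\
\frac{\partial f}{\partial\, \mathrm{Re} A_{ij}} + i\,\frac{\partial f}{\partial\, \mathrm{Im} A_{ij}} &= 2\, y_i\, \overline{b_j}
\quad\Longrightarrow\quad \nabla_A f = 2\,(A b)\, b^{\mathsf{H}}
\end{aligned}
```

For a sparse `a`, the shadow only holds the stored entries, so the expected result is this matrix restricted to the sparsity pattern of `a`. The formula depends only on the values of `b`, not on whether `b` is stored sparse or dense, so both Enzyme calls should produce the same shadow.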
MWE:
```julia
using Enzyme
import Zygote
using SparseArrays

"""
Used for sparse matrix times vector (dense or sparse) multiplication.
"""
function myfunc(a, b)
    return sum(abs2.(a * b))
end

n = 5
a = sprand(n, n, 0.5) + 1im*sprand(n, n, 0.5)   # sparse complex n×n matrix
b = sprand(n, 1, 0.5) + 1im*sprand(n, 1, 0.5)   # sparse complex n×1 matrix; Array(b) is its dense copy

# separate shadows, because Enzyme accumulates the gradient into them
da_sparse = Enzyme.make_zero(a)
da_dense = Enzyme.make_zero(a)

# compute gradient with Zygote
println("Zygote")
grad_sparse = Zygote.gradient(A -> myfunc(A, b), a)          # GOOD
grad_dense = Zygote.gradient(A -> myfunc(A, Array(b)), a)    # GOOD
@show Array(grad_sparse[1]) ≈ Array(grad_dense[1])

# compute gradient with Enzyme
println("")
println("Enzyme")
autodiff(Reverse, myfunc, Duplicated(a, da_sparse), Const(b))          # GOOD: matches Zygote
autodiff(Reverse, myfunc, Duplicated(a, da_dense), Const(Array(b)))    # BAD: differs from da_sparse
@show Array(da_sparse) ≈ Array(grad_sparse[1])
@show Array(da_sparse) ≈ Array(da_dense)

# show different results
println("")
println("Compare Results")
println("da_sparse")
display(da_sparse)
println("da_dense")
display(da_dense)
```
Output:
```
Zygote
Array(grad_sparse[1]) ≈ Array(grad_dense[1]) = true

Enzyme
Array(da_sparse) ≈ Array(grad_sparse[1]) = true
Array(da_sparse) ≈ Array(da_dense) = false

Compare Results
da_sparse
5×5 SparseMatrixCSC{ComplexF64, Int64} with 19 stored entries:
⋅                   ⋅                    1.91776+1.53486im    ⋅          0.875634+3.84073im
1.01186+1.10147im   -0.958577+1.08917im  0.77511+0.843753im   0.0+0.0im  0.130956+1.83277im
0.806518+2.30802im  -2.14352+1.01501im   0.617814+1.768im     ⋅          -0.988918+2.83604im
⋅                   -0.622128+2.52033im  1.92702+0.699208im   0.0+0.0im  1.72126+2.80094im
1.2+0.482232im      -0.341932+1.20706im  0.919231+0.369403im  0.0+0.0im  ⋅
da_dense
5×5 SparseMatrixCSC{ComplexF64, Int64} with 19 stored entries:
⋅                    ⋅                    -1.91776-1.53486im    ⋅          -3.93926-0.0128396im
-1.01186-1.10147im   -1.16641+0.862936im  -0.77511-0.843753im   0.0+0.0im  -1.81511-0.285631im
-0.806518-2.30802im  -2.30918+0.540953im  -0.617814-1.768im     ⋅          -2.54006-1.60287im
⋅                    -1.13882+2.33285im   -1.92702-0.699208im   0.0+0.0im  -3.11689+1.04544im
-1.2-0.482232im      -0.588405+1.10801im  -0.919231-0.369403im  0.0+0.0im  ⋅
```
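As a cross-check that needs neither Enzyme nor Zygote, the hand-derived gradient 2 (A b) bᴴ can be evaluated directly and compared with central finite differences. This is a minimal sketch using only the standard library. `handgrad` and `fdgrad_entry` are hypothetical helper names. The gradient convention (∂f/∂Re + im·∂f/∂Im, the one Zygote uses) is an assumption stated in the comments.

```julia
using LinearAlgebra, Random, SparseArrays

# Hand-derived gradient of f(A) = sum(abs2.(A * b)), i.e. 2 (A b) bᴴ, using the
# ∂f/∂Re(A) + im*∂f/∂Im(A) convention (assumed) and keeping only A's stored entries.
# `handgrad` is a hypothetical helper, not part of Enzyme, Zygote or SparseArrays.
function handgrad(A::SparseMatrixCSC, b::AbstractVecOrMat)
    bd = Array(b)                 # the formula does not care how b is stored
    y = A * bd                    # forward product A b
    G = 2 .* y * bd'              # dense gradient 2 (A b) bᴴ
    dA = copy(A)                  # same sparsity pattern as A
    rows, vals = rowvals(dA), nonzeros(dA)
    for j in 1:size(dA, 2), k in nzrange(dA, j)
        vals[k] = G[rows[k], j]   # overwrite each stored entry with its gradient
    end
    return dA
end

# Central finite difference of a real-valued f with respect to the k-th stored
# entry of A, perturbing its real and imaginary parts separately (hypothetical helper).
function fdgrad_entry(f, A::SparseMatrixCSC, k::Integer; h = 1e-6)
    function f_shifted(Δ)
        B = copy(A)
        nonzeros(B)[k] += Δ
        return f(B)
    end
    d_re = (f_shifted(h) - f_shifted(-h)) / (2h)
    d_im = (f_shifted(im * h) - f_shifted(-im * h)) / (2h)
    return d_re + im * d_im
end

myfunc(A, x) = sum(abs2.(A * x))

Random.seed!(0)
n = 5
a = sprand(n, n, 0.5) + 1im * sprand(n, n, 0.5)
b = sprand(n, 1, 0.5) + 1im * sprand(n, 1, 0.5)

ref = handgrad(a, b)
fd = [fdgrad_entry(A -> myfunc(A, Array(b)), a, k) for k in 1:nnz(a)]
@show isapprox(nonzeros(ref), fd; rtol = 1e-5)
```

For one fixed stored entry, `f` is a quadratic polynomial in the real or imaginary perturbation, so the central difference has no truncation error. Any mismatch beyond rounding would therefore point at the formula rather than at the step size `h`. Comparing `Array(da_sparse)` and `Array(da_dense)` against `Array(handgrad(a, b))` gives a reference that does not depend on Zygote.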