Skip to content

Commit 19b8c6c

Browse files
authored
Support real and complex inputs for sparse matmul (#2353)
* Support real and complex inputs for sparse matmul * Tweak tests * Remove extra test since already covered * Apply runic format suggestions * Re-add forward mode sort * clean up * clean up formatting * Shrink tests to make them faster * try to tweak test
1 parent 2f37c0d commit 19b8c6c

File tree

2 files changed

+127
-100
lines changed

2 files changed

+127
-100
lines changed

src/internal_rules.jl

Lines changed: 83 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -729,15 +729,16 @@ function EnzymeRules.reverse(
729729
end
730730

731731

732-
function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig,
733-
func::Const{typeof(LinearAlgebra.mul!)},
734-
::Type{RT},
735-
C::Annotation{<:StridedVecOrMat},
736-
A::Annotation{<:SparseArrays.SparseMatrixCSCUnion},
737-
B::Annotation{<:StridedVecOrMat},
738-
α::Annotation{<:Number},
739-
β::Annotation{<:Number}
740-
) where {RT}
732+
function EnzymeRules.augmented_primal(
733+
config::EnzymeRules.RevConfig,
734+
func::Const{typeof(LinearAlgebra.mul!)},
735+
::Type{RT},
736+
C::Annotation{<:StridedVecOrMat},
737+
A::Annotation{<:SparseArrays.SparseMatrixCSCUnion},
738+
B::Annotation{<:StridedVecOrMat},
739+
α::Annotation{<:Number},
740+
β::Annotation{<:Number}
741+
) where {RT}
741742

742743
cache_C = !(isa(β, Const)) ? copy(C.val) : nothing
743744
# Always need to do forward pass otherwise primal may not be correct
@@ -755,37 +756,56 @@ function EnzymeRules.augmented_primal(config::EnzymeRules.RevConfig,
755756
nothing
756757
end
757758

759+
758760
# Check if A is overwritten and B is active (and thus required)
759-
cache_A = ( EnzymeRules.overwritten(config)[5]
760-
&& !(typeof(B) <: Const)
761-
&& !(typeof(C) <: Const)
762-
) ? copy(A.val) : nothing
763-
764-
cache_B = ( EnzymeRules.overwritten(config)[6]
765-
&& !(typeof(A) <: Const)
766-
&& !(typeof(C) <: Const)
767-
) ? copy(B.val) : nothing
761+
cache_A = (
762+
EnzymeRules.overwritten(config)[5]
763+
&& !(typeof(B) <: Const)
764+
&& !(typeof(C) <: Const)
765+
) ? copy(A.val) : nothing
766+
767+
cache_B = (
768+
EnzymeRules.overwritten(config)[6]
769+
&& !(typeof(A) <: Const)
770+
&& !(typeof(C) <: Const)
771+
) ? copy(B.val) : nothing
768772

769773
if !isa(α, Const)
770-
cache_α = A.val*B.val
774+
cache_α = A.val * B.val
771775
else
772776
cache_α = nothing
773777
end
774-
778+
775779
cache = (cache_C, cache_A, cache_B, cache_α)
776780

777781
return EnzymeRules.AugmentedReturn(primal, shadow, cache)
778782
end
779783

780-
function EnzymeRules.reverse(config::EnzymeRules.RevConfig,
781-
func::Const{typeof(LinearAlgebra.mul!)},
782-
::Type{RT}, cache,
783-
C::Annotation{<:StridedVecOrMat},
784-
A::Annotation{<:SparseArrays.SparseMatrixCSCUnion},
785-
B::Annotation{<:StridedVecOrMat},
786-
α::Annotation{<:Number},
787-
β::Annotation{<:Number}
788-
) where {RT}
784+
# This is required to handle arguments that mix real and complex numbers
785+
_project(::Type{<:Real}, x) = x
786+
_project(::Type{<:Real}, x::Complex) = real(x)
787+
_project(::Type{<:Complex}, x) = x
788+
789+
function _muladdproject!(::Type{<:Number}, dB::AbstractArray, A::AbstractArray, C::AbstractArray, α)
790+
return LinearAlgebra.mul!(dB, A, C, α, true)
791+
end
792+
793+
function _muladdproject!(::Type{<:Complex}, dB::AbstractArray{<:Real}, A::AbstractArray, C::AbstractArray, α::Number)
794+
tmp = A * C
795+
return dB .+= real.(α .* tmp)
796+
end
797+
798+
799+
function EnzymeRules.reverse(
800+
config::EnzymeRules.RevConfig,
801+
func::Const{typeof(LinearAlgebra.mul!)},
802+
::Type{RT}, cache,
803+
C::Annotation{<:StridedVecOrMat},
804+
A::Annotation{<:SparseArrays.SparseMatrixCSCUnion},
805+
B::Annotation{<:StridedVecOrMat},
806+
α::Annotation{<:Number},
807+
β::Annotation{<:Number}
808+
) where {RT}
789809

790810
cache_C, cache_A, cache_B, cache_α = cache
791811
Cval = !isnothing(cache_C) ? cache_C : C.val
@@ -795,67 +815,71 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig,
795815
N = EnzymeRules.width(config)
796816
if !isa(C, Const)
797817
dCs = C.dval
798-
dBs = isa(B, Const) ? dCs : B.dval
799-
818+
dBs = isa(B, Const) ? dCs : B.dval
800819
dα = if !isa(α, Const)
801-
if N == 1
802-
LinearAlgebra.dot(C.dval, cache_α)
803-
else
804-
ntuple(Val(N)) do i
805-
Base.@_inline_meta
806-
LinearAlgebra.dot(C.dval[i], cache_α)
807-
end
820+
if N == 1
821+
_project(typeof(α.val), conj(LinearAlgebra.dot(C.dval, cache_α)))
822+
else
823+
ntuple(Val(N)) do i
824+
Base.@_inline_meta
825+
_project(typeof(α.val), conj(LinearAlgebra.dot(C.dval[i], cache_α)))
808826
end
827+
end
809828
else
810829
nothing
811830
end
812831

813832
dβ = if !isa(β, Const)
814-
if N == 1
815-
LinearAlgebra.dot(C.dval, Cval)
816-
else
817-
ntuple(Val(N)) do i
818-
Base.@_inline_meta
819-
LinearAlgebra.dot(C.dval[i], Cval)
820-
end
833+
if N == 1
834+
_project(typeof(β.val), conj(LinearAlgebra.dot(C.dval, Cval)))
835+
else
836+
ntuple(Val(N)) do i
837+
Base.@_inline_meta
838+
_project(typeof(β.val), conj(LinearAlgebra.dot(C.dval[i], Cval)))
821839
end
840+
end
822841
else
823842
nothing
824843
end
825844

826845
for i in 1:N
827846
if !isa(A, Const)
828-
# dA .+= αdC*B'
829-
# You need to be careful so that dA sparsity pattern does not change. Otherwise
847+
# dA .+= α'dC*B'
848+
# You need to be careful so that dA sparsity pattern does not change. Otherwise
830849
# you will get incorrect gradients. So for now we do the slow and bad way of accumulating
831850
dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[i]
832851
dC = EnzymeRules.width(config) == 1 ? C.dval : C.dval[i]
833852
# Now accumulate to preserve the correct sparsity pattern
834853
I, J, _ = SparseArrays.findnz(dA)
835854
for k in eachindex(I, J)
836855
Ik, Jk = I[k], J[k]
837-
tmp = zero(eltype(dA))
838-
for ti in axes(dC,2)
839-
tmp += dC[Ik, ti]*Bval[Jk, ti]
856+
# May need to widen if the eltype differ
857+
tmp = zero(promote_type(eltype(dA), eltype(dC)))
858+
for ti in axes(dC, 2)
859+
tmp += dC[Ik, ti] * conj(Bval[Jk, ti])
840860
end
841-
dA[Ik, Jk] += α.val*tmp
861+
dA[Ik, Jk] += _project(eltype(dA), conj(α.val) * tmp)
842862
end
843863
# mul!(dA, dCs, Bval', α.val, true)
844864
end
845865

846866
if !isa(B, Const)
847867
#dB .+= α*A'*dC
848-
if N ==1
849-
func.val(dBs, Aval', dCs, α.val, true)
868+
# Get the type of all arguments since we may need to
869+
# project down to a smaller type during accumulation
870+
if N == 1
871+
Targs = promote_type(eltype(Aval), eltype(dCs), typeof(α.val))
872+
_muladdproject!(Targs, dBs, Aval', dCs, conj(α.val))
850873
else
851-
func.val(dBs[i], Aval', dCs[i], α.val, true)
874+
Targs = promote_type(eltype(Aval[i]), eltype(dCs[i]), typeof(α.val))
875+
_muladdproject!(Targs, dBs[i], Aval', dCs[i], conj(α.val))
852876
end
853877
end
854-
855-
if N==1
856-
dCs .*= β.val
878+
#dC = dC*conj(β.val)
879+
if N == 1
880+
dCs .*= _project(eltype(dCs), conj(β.val))
857881
else
858-
dCs[i] .*= β.val
882+
dCs[i] .*= _project(eltype(dCs[i]), conj(β.val))
859883
end
860884
end
861885
else
@@ -888,7 +912,7 @@ function EnzymeRules.reverse(config::EnzymeRules.RevConfig,
888912
nothing
889913
end
890914
end
891-
915+
892916
return (nothing, nothing, nothing, dα, dβ)
893917
end
894918

test/internal_rules.jl

Lines changed: 44 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -714,62 +714,65 @@ end
714714
# @test Enzyme.autodiff(Reverse, (x, y) -> begin y[] = f4(x); nothing end, Active(0.1), BatchDuplicated(Ref(0.0), (Ref(1.0), Ref(2.0)))) == (((0.0,0.0)),)
715715
end
716716

717+
717718
@testset "SparseArrays spmatvec reverse rule" begin
718-
C = zeros(18)
719-
M = sprand(18, 9, 0.1)
720-
v = randn(9)
721-
α = 2.0
722-
β = 1.0
719+
Ts = (Float64, ComplexF64)
723720

724-
for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated),
725-
Tα in (Const, Active), Tβ in (Const, Active)
721+
Ms = sprandn.(Ts, 5, 3, 0.3)
722+
vs = rand.(Ts, 3)
723+
αs = rand.(Ts)
724+
βs = rand.(Ts)
726725

727-
are_activities_compatible(Tret, Tret, TM, Tv, Tα, Tβ) || continue
728-
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, TM), (v, Tv), (α, Tα), (β, Tβ))
726+
for M in Ms, v in vs, α in αs, β in βs
727+
tout = promote_type(eltype(M), eltype(v), typeof(α), typeof(β))
728+
C = zeros(tout, 5)
729729

730-
end
730+
for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated),
731+
Tα in (Const, Active), Tβ in (Const, Active)
731732

733+
are_activities_compatible(Tret, Tret, TM, Tv, Tα, Tβ) || continue
734+
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, TM), (v, Tv), (α, Tα), (β, Tβ))
735+
end
732736

733-
for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated),
734-
Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false)
735-
are_activities_compatible(Tret, Tret, TM, Tv) || continue
736-
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const))
737-
end
737+
for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated),
738+
Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false)
739+
are_activities_compatible(Tret, Tret, TM, Tv) || continue
740+
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const))
741+
end
738742

739-
# Test with a const output and active α and β
740-
(_,_,_,dα, dβ), = autodiff(Reverse, LinearAlgebra.mul!, Const, Const(C), Const(M), Const(v), Active(α), Active(β))
741-
@test dα ≈ 0
742-
@test dβ ≈ 0
743+
test_reverse(LinearAlgebra.mul!, Const, (C, Const), (M, Const), (v, Const), (α, Active), (β, Active))
743744

745+
end
744746
end
745747

746-
@testset "SparseArrays spmatmat reverse rule" begin
747-
C = zeros(18, 11)
748-
M = sprand(18, 9, 0.1)
749-
v = randn(9, 11)
750-
α = 2.0
751-
β = 1.0
752-
753-
for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated),
754-
Tα in (Const, Active), Tβ in (Const, Active)
755748

756-
are_activities_compatible(Tret, Tret, TM, Tv, Tα, Tβ) || continue
757-
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, TM), (v, Tv), (α, Tα), (β, Tβ))
749+
@testset "SparseArrays spmatmat reverse rule" begin
750+
Ts = (Float64, ComplexF64)
751+
752+
for tm in Ts, tv in Ts, tα in Ts, tβ in Ts
753+
tout = promote_type(tm, tv, tα, tβ)
754+
C = zeros(tout, 5, 3)
755+
M = sprand(tm, 5, 3, 0.3)
756+
v = randn(tv, 3, 3)
757+
α = rand(tα)
758+
β = rand(tβ)
759+
760+
for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated), Tv in (Const, Duplicated, BatchDuplicated),
761+
Tα in (Const, Active), Tβ in (Const, Active)
762+
763+
are_activities_compatible(Tret, Tret, TM, Tv, Tα, Tβ) || continue
764+
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, TM), (v, Tv), (α, Tα), (β, Tβ))
765+
end
758766

759-
end
767+
for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated),
768+
Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false)
769+
are_activities_compatible(Tret, Tret, TM, Tv) || continue
770+
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const))
771+
end
760772

773+
test_reverse(LinearAlgebra.mul!, Const, (C, Const), (M, Const), (v, Const), (α, Active), (β, Active))
761774

762-
for Tret in (Duplicated, BatchDuplicated), TM in (Const, Duplicated, BatchDuplicated),
763-
Tv in (Const, Duplicated, BatchDuplicated), bα in (true, false), bβ in (true, false)
764-
are_activities_compatible(Tret, Tret, TM, Tv) || continue
765-
test_reverse(LinearAlgebra.mul!, Tret, (C, Tret), (M, Const), (v, Tv), (bα, Const), (bβ, Const))
766775
end
767-
768-
# Test with a const output and active α and β
769-
(_,_,_,dα, dβ), = autodiff(Reverse, LinearAlgebra.mul!, Const, Const(C), Const(M), Const(v), Active(α), Active(β))
770-
@test dα ≈ 0
771-
@test dβ ≈ 0
772-
773776
end
774777

775778
end # InternalRules

0 commit comments

Comments
 (0)