Skip to content

Commit e752f9a

Browse files
committed
Implement type piracy.
1 parent 4b0b5b1 commit e752f9a

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

src/ThreadedDenseSparseMul.jl

+2
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@ export fastdensesparsemul_outer!, fastdensesparsemul_outer_threaded!
99

1010
# Adapted from https://github.com/BacAmorim/ThreadedSparseCSR.jl/tree/main
1111
include("set_num_threads.jl")
12+
include("override.jl")
1213
function __init__()
1314
set_num_threads(Threads.nthreads())
1415
end
1516

17+
1618
const VecOrView{T} = Union{Vector{T}, SubArray{T, 1, Matrix{T}}}
1719
const MatOrView{T} = Union{Matrix{T}, SubArray{T, 2, Matrix{T}}}
1820

src/override.jl

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import LinearAlgebra: mul!, Adjoint
2+
3+
function override_mul!(threaded = true)
4+
if threaded
5+
@eval function mul!(C::MatOrView{T}, A::MatOrView{T}, B::SparseMatrixCSC{T}, α::Number, β::Number) where T
6+
fastdensesparsemul_threaded!(C, A, B, α, β)
7+
end
8+
@eval function mul!(C::MatOrView{T}, a::VecOrView{T}, b::Adjoint{<:SparseVector{T}}, α::Number, β::Number) where T
9+
fastdensesparsemul_outer_threaded!(C, a, b', α, β)
10+
end
11+
else
12+
@eval function mul!(C::MatOrView{T}, A::MatOrView{T}, B::SparseMatrixCSC{T}, α::Number, β::Number) where T
13+
fastdensesparsemul!(C, A, B, α, β)
14+
end
15+
@eval function mul!(C::MatOrView{T}, a::VecOrView{T}, b::Adjoint{<:SparseVector{T}}, α::Number, β::Number) where T
16+
fastdensesparsemul_outer!(C, a, b', α, β)
17+
end
18+
end
19+
end

0 commit comments

Comments
 (0)