Skip to content

Commit 32bf475

Browse files
committed
Split up tests. Test Float16 and Complex{Float32}. Test overriding mul!
1 parent e752f9a commit 32bf475

File tree

3 files changed

+52
-45
lines changed

3 files changed

+52
-45
lines changed

test/runtests.jl

+9-45
Original file line numberDiff line numberDiff line change
@@ -4,62 +4,26 @@ import SparseArrays: mul!, sprand
44
import Profile
55

66
# Helper function to run common test logic
7-
function run_common_tests(method!, buf::AbstractMatrix{T}, lhs, rhs, α, β, baseline) where {T <: Real}
7+
function run_common_tests(method!, buf::AbstractMatrix{T}, lhs, rhs, α, β, baseline) where {T <: Number}
88
method!(buf, lhs, rhs, α, β)
9-
@test buf baseline rtol=sqrt(eps(T))
9+
@test buf baseline rtol=sqrt(eps(real(T)))
1010
@test !any(isnan, buf)
1111

1212
# Test with negative α
1313
method!(buf, lhs, rhs, -α, β)
1414
method!(buf, lhs, rhs, α, β)
15-
@test buf baseline rtol=sqrt(eps(T))
15+
@test buf baseline rtol=sqrt(eps(real(T)))
1616
@test !any(isnan, buf)
1717
end
1818

1919
@testset "ThreadedDenseSparseMul Tests" begin
2020
@test ThreadedDenseSparseMul.get_num_threads() == Threads.nthreads()
21-
@testset "Dense-Sparse Multiplication" begin
22-
@testset "$T type" for T in [Float64, Float32]
23-
@testset "$method! implementation" for method! in [fastdensesparsemul!, fastdensesparsemul_threaded!]
24-
@testset "Trial $trial" for trial in 1:10
25-
lhs = rand(T, 50, 100)
26-
rhs = sprand(T, 100, 1_000, 0.1)
27-
baseline = lhs * rhs
28-
29-
buf = similar(baseline) .* false # fill buffer with zeros. Carefull with NaNs, see https://discourse.julialang.org/t/occasionally-nans-when-using-similar/48224/12
30-
31-
# Test basic multiplication
32-
run_common_tests(method!, buf, lhs, rhs, 1, 0, baseline)
33-
34-
# Test @view interface and β ≠ 0
35-
inds = rand(axes(lhs, 1), size(lhs, 1) ÷ 3)
36-
baseline[inds, :] .+= 2.5 * @view(lhs[inds, :]) * rhs
37-
38-
run_common_tests(method!, @view(buf[inds, :]), @view(lhs[inds, :]), rhs, 2.5, 1, @view(baseline[inds, :]))
39-
end
40-
end
41-
end
42-
end
43-
44-
@testset "Outer Product Multiplication" begin
45-
@testset "$T type" for T in [Float64, Float32]
46-
@testset "$method! implementation" for method! in [fastdensesparsemul_outer!, fastdensesparsemul_outer_threaded!]
47-
@testset "Trial $trial" for trial in 1:10
48-
lhs = rand(T, 50, 100)
49-
rhs = sprand(T, 100, 1_000, 0.1)
50-
k = rand(axes(rhs, 1))
51-
52-
baseline = lhs[:, k:k] * rhs[k:k, :]
53-
buf = similar(baseline) .* false
54-
55-
# Test basic outer product multiplication
56-
run_common_tests(method!, buf, @view(lhs[:, k]), rhs[k, :], 1, 0, baseline)
57-
58-
# Test with β ≠ 0
59-
baseline .+= 2.5 * lhs[:, (k+1):(k+1)] * rhs[(k+1):(k+1), :]
60-
run_common_tests(method!, buf, lhs[:, k+1], rhs[k+1, :], 2.5, 1, baseline)
61-
end
62-
end
21+
@testset "Override mul! ?" for override_mul! in [false, true]
22+
override_mul! && ThreadedDenseSparseMul.override_mul!()
23+
@testset "nthreads" for nthreads in [1, Threads.nthreads()]
24+
ThreadedDenseSparseMul.set_num_threads(nthreads)
25+
include("test_densesparsemul.jl")
26+
include("test_densesparseouter.jl")
6327
end
6428
end
6529
end

test/test_densesparsemul.jl

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
@testset "Dense-Sparse Multiplication" begin
2+
@testset "$T type" for T in [Float64, Float32, Float16, Complex{Float32}]
3+
@testset "$method! implementation" for method! in [fastdensesparsemul!, fastdensesparsemul_threaded!]
4+
@testset "Trial $trial" for trial in 1:10
5+
lhs = rand(T, 50, 100)
6+
rhs = sprand(T, 100, 1_000, 0.1)
7+
baseline = lhs * rhs
8+
9+
buf = similar(baseline) .* false # fill buffer with zeros. Carefull with NaNs, see https://discourse.julialang.org/t/occasionally-nans-when-using-similar/48224/12
10+
11+
# Test basic multiplication
12+
run_common_tests(method!, buf, lhs, rhs, 1, 0, baseline)
13+
14+
# Test @view interface and β ≠ 0
15+
inds = rand(axes(lhs, 1), size(lhs, 1) ÷ 3)
16+
baseline[inds, :] .+= 2.5 * @view(lhs[inds, :]) * rhs
17+
18+
run_common_tests(method!, @view(buf[inds, :]), @view(lhs[inds, :]), rhs, 2.5, 1, @view(baseline[inds, :]))
19+
end
20+
end
21+
end
22+
end

test/test_densesparseouter.jl

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
@testset "Outer Product Multiplication" begin
2+
@testset "$T type" for T in [Float64, Float32]
3+
@testset "$method! implementation" for method! in [fastdensesparsemul_outer!, fastdensesparsemul_outer_threaded!]
4+
@testset "Trial $trial" for trial in 1:10
5+
lhs = rand(T, 50, 100)
6+
rhs = sprand(T, 100, 1_000, 0.1)
7+
k = rand(axes(rhs, 1))
8+
9+
baseline = lhs[:, k:k] * rhs[k:k, :]
10+
buf = similar(baseline) .* false
11+
12+
# Test basic outer product multiplication
13+
run_common_tests(method!, buf, @view(lhs[:, k]), rhs[k, :], 1, 0, baseline)
14+
15+
# Test with β ≠ 0
16+
baseline .+= 2.5 * lhs[:, (k+1):(k+1)] * rhs[(k+1):(k+1), :]
17+
run_common_tests(method!, buf, lhs[:, k+1], rhs[k+1, :], 2.5, 1, baseline)
18+
end
19+
end
20+
end
21+
end

0 commit comments

Comments
 (0)