Skip to content

Commit 6bc880d

Browse files
committed
Manage SparseMatrix component datatypes during conversion.
Improve tests.
1 parent 7b29bc1 commit 6bc880d

File tree

4 files changed

+147
-70
lines changed

4 files changed

+147
-70
lines changed

src/SparseMatrixCSR.jl

+29-5
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,8 @@ Convert x to a value of type SymSparseMatrixCSR.
294294
"""
295295
convert(::Type{SparseMatrixCSR}, x::AbstractSparseMatrix) = convert(SparseMatrixCSR{1}, x)
296296

297-
function convert(::Type{SparseMatrixCSR{Bi}}, x::SparseMatrixCSR{Bj}) where {Bi,Bj}
298-
if Bi == Bj
297+
function convert(::Type{SparseMatrixCSR{Bi}}, x::SparseMatrixCSR{xBi}) where {Bi,xBi}
298+
if Bi == xBi
299299
return x
300300
else
301301
return SparseMatrixCSR{Bi}( x.m,
@@ -306,20 +306,44 @@ function convert(::Type{SparseMatrixCSR{Bi}}, x::SparseMatrixCSR{Bj}) where {Bi,
306306
end
307307
end
308308

309+
function convert(::Type{SparseMatrixCSR{Bi,Tv,Ti}}, x::SparseMatrixCSR{xBi,xTv,xTi}) where {Bi,Tv,Ti,xBi,xTv,xTi}
310+
if (Bi,Tv,Ti) == (xBi,xTv,xTi)
311+
return x
312+
else
313+
return SparseMatrixCSR{Bi}( x.m,
314+
x.n,
315+
convert(Vector{Ti}, copy(getptr(x)).-x.offset),
316+
convert(Vector{Ti}, copy(getindices(x)).-x.offset),
317+
convert(Vector{Tv}, copy(nonzeros(x))))
318+
end
319+
end
320+
309321
function convert(::Type{SparseMatrixCSR{Bi}}, x::SparseMatrixCSC) where {Bi}
310322
A = sparse(transpose(x))
311323
(m, n) = size(A)
312324
return SparseMatrixCSR{Bi}(m, n, getptr(A), getindices(A), nonzeros(A))
313325
end
314326

327+
function convert(::Type{SparseMatrixCSR{Bi,Tv,Ti}}, x::SparseMatrixCSC) where {Bi,Tv,Ti}
328+
A = sparse(transpose(x))
329+
return SparseMatrixCSR{Bi}( A.m,
330+
A.n,
331+
convert(Vector{Ti}, getptr(A)),
332+
convert(Vector{Ti}, getindices(A)),
333+
convert(Vector{Tv}, nonzeros(A)))
334+
end
335+
315336
"""
316337
function convert(::Type{SparseMatrixCSC}, x::SparseMatrixCSR)
317338
318339
Convert x to a value of type SparseMatrixCSC.
319340
"""
320-
function convert(::Type{SparseMatrixCSC}, x::SparseMatrixCSR{Bi}) where {Bi}
341+
function convert(::Type{SparseMatrixCSC{Tv,Ti}}, x::SparseMatrixCSR{xBi,xTv,xTi}) where {Tv,Ti,xBi,xTv,xTi}
321342
A = sparse(transpose(x))
322-
(m, n) = size(A)
323-
return SparseMatrixCSR{Bi}(m, n, getptr(A), getindices(A), nonzeros(A))
343+
return SparseMatrixCSR{Bi}( A.m,
344+
A.n,
345+
convert(Vector{Ti}, getptr(A)),
346+
convert(Vector{Ti}, getindices(A)),
347+
convert(Vector{Tv}, nonzeros(A)))
324348
end
325349

src/SymSparseMatrixCSR.jl

+11-2
Original file line numberDiff line numberDiff line change
@@ -221,10 +221,19 @@ Convert x to a value of type SymSparseMatrixCSR.
221221
"""
222222
convert(::Type{SymSparseMatrixCSR}, x::SymSparseMatrixCSR) = convert(SymSparseMatrixCSR{1}, x)
223223

224-
function convert(::Type{SymSparseMatrixCSR{Bi}}, x::SymSparseMatrixCSR{Bj}) where {Bi,Bj}
225-
if Bi == Bj
224+
function convert(::Type{SymSparseMatrixCSR{Bi}}, x::SymSparseMatrixCSR{xBi}) where {Bi,xBi}
225+
if Bi == xBi
226226
return x
227227
else
228228
return SymSparseMatrixCSR(convert(SparseMatrixCSR{Bi}, x.uppertrian))
229229
end
230230
end
231+
232+
function convert(::Type{SymSparseMatrixCSR{Bi,Tv,Ti}}, x::SymSparseMatrixCSR{xBi,xTv,xTi}) where {Bi,Tv,Ti,xBi,xTv,xTi}
233+
if (Bi,Tv,Ti) == (xBi,xTv,xTi)
234+
return x
235+
else
236+
return SymSparseMatrixCSR(convert(SparseMatrixCSR{Bi,Tv,Ti}, x.uppertrian))
237+
end
238+
end
239+

test/SparseMatrixCSR.jl

+55-35
Original file line numberDiff line numberDiff line change
@@ -2,55 +2,75 @@
22
maxnz=5
33
maxrows=5
44
maxcols=5
5+
int_types=(Int32,Int64)
6+
float_types=(Float32,Float64)
7+
Bi_types=(0,1)
58

6-
for T in (Int32,Int64,Float32,Float64)
7-
for Bi in (0,1)
8-
I = Vector{Int}()
9-
J = Vector{Int}()
10-
V = Vector{T}()
11-
for (ik, jk, vk) in zip(rand(1:maxrows, maxnz), rand(1:maxcols, maxnz), rand(1:T(maxnz), maxnz))
12-
push_coo!(SparseMatrixCSR,I,J,V,ik,jk,vk)
13-
end
14-
finalize_coo!(SparseMatrixCSR,I,J,V,maxrows,maxcols)
15-
CSC = sparse(I, J, V, maxrows,maxcols)
16-
CSR = sparsecsr(SparseMatrixCSR{Bi},I, J, V,maxrows,maxcols)
9+
for Ti in int_types
10+
for Tv in float_types
11+
for Bi in Bi_types
12+
I = Vector{Ti}()
13+
J = Vector{Ti}()
14+
V = Vector{Tv}()
15+
for (ik, jk, vk) in zip(rand(1:maxrows, maxnz), rand(1:maxcols, maxnz), rand(1:Tv(maxnz), maxnz))
16+
push_coo!(SparseMatrixCSR,I,J,V,ik,jk,vk)
17+
end
18+
finalize_coo!(SparseMatrixCSR,I,J,V,maxrows,maxcols)
19+
CSC = sparse(I, J, V, maxrows,maxcols)
20+
CSR = sparsecsr(SparseMatrixCSR{Bi},I, J, V,maxrows,maxcols)
1721

18-
display(CSR)
22+
show(devnull, CSR);
1923

20-
@test CSC == CSR
24+
@test CSC == CSR
25+
@test nnz(CSC) == count(!iszero, CSC) == nnz(CSR) == count(!iszero, CSR)
2126

22-
@test convert(SparseMatrixCSR{Bi}, CSR) === CSR
27+
@test hasrowmajororder(CSR) == true
28+
@test hascolmajororder(CSR) == false
29+
@test getptr(CSR) == CSR.rowptr
30+
@test getindices(CSR) == colvals(CSR)
2331

24-
CSRC = convert(SparseMatrixCSR{CSR.offset}, CSR)
25-
@test CSRC == CSR
26-
@test CSRC !== CSR
32+
TCSC = sparse(J, I, V, maxrows, maxcols)
33+
TCSR = sparsecsr(SparseMatrixCSR{Bi}, J, I, V, maxrows, maxcols)
2734

28-
@test nnz(CSC) == count(!iszero, CSC) == nnz(CSR) == count(!iszero, CSR)
35+
@test size(CSC)==size(CSR)==reverse(size(TCSC))==reverse(size(TCSC))
2936

30-
@test hasrowmajororder(CSR) == true
31-
@test hascolmajororder(CSR) == false
32-
@test getptr(CSR) == CSR.rowptr
33-
@test getindices(CSR) == colvals(CSR)
37+
@test [nzrange(CSC,col) for col in 1:size(CSC,2)] == [nzrange(TCSR,row) for row in 1:size(TCSR,1)]
38+
@test [nzrange(CSR,row) for row in 1:size(CSR,1)] == [nzrange(TCSC,col) for col in 1:size(TCSC,2)]
3439

35-
TCSC = sparse(J, I, V, maxrows, maxcols)
36-
TCSR = sparsecsr(SparseMatrixCSR{Bi}, J, I, V, maxrows, maxcols)
40+
@test nonzeros(CSC) == nonzeros(TCSR) && nonzeros(CSR) == nonzeros(TCSC)
3741

38-
@test size(CSC)==size(CSR)==reverse(size(TCSC))==reverse(size(TCSC))
42+
ICSC,JCSC,VCSC= findnz(CSC)
43+
ICSR,JCSR,VCSR= findnz(CSR)
3944

40-
@test [nzrange(CSC,col) for col in 1:size(CSC,2)] == [nzrange(TCSR,row) for row in 1:size(TCSR,1)]
41-
@test [nzrange(CSR,row) for row in 1:size(CSR,1)] == [nzrange(TCSC,col) for col in 1:size(TCSC,2)]
45+
@test sort(ICSC)==sort(JCSR) && sort(JCSC)==sort(ICSR) && sort(VCSC)==sort(VCSR)
4246

43-
@test nonzeros(CSC) == nonzeros(TCSR) && nonzeros(CSR) == nonzeros(TCSC)
47+
v = rand(size(CSC)[2])
48+
@test CSC*v == CSR*v
4449

45-
ICSC,JCSC,VCSC= findnz(CSC)
46-
ICSR,JCSR,VCSR= findnz(CSR)
50+
for cBi in Bi_types
51+
if Bi == cBi
52+
@test convert(SparseMatrixCSR{Bi}, CSR) === CSR
53+
else
54+
CSRC = convert(SparseMatrixCSR{cBi}, CSR)
55+
@test CSRC == CSR
56+
@test CSRC !== CSR
57+
end
58+
end
4759

48-
@test sort(ICSC)==sort(JCSR) && sort(JCSC)==sort(ICSR) && sort(VCSC)==sort(VCSR)
60+
for cTi in int_types
61+
for cTv in float_types
62+
if (Ti,Tv) == (cTi,cTv)
63+
@test convert(SparseMatrixCSR{Bi,Tv,Ti}, CSR) === CSR
64+
else
65+
CSRC = convert(SparseMatrixCSR{Bi,cTv,cTi}, CSR)
66+
@test CSRC == CSR
67+
@test CSRC !== CSR
68+
end
69+
end
70+
end
4971

50-
v = rand(size(CSC)[2])
51-
@test CSC*v == CSR*v
72+
end
5273
end
53-
end
54-
74+
end
5575
end
5676

test/SymSparseMatrixCSR.jl

+52-28
Original file line numberDiff line numberDiff line change
@@ -2,43 +2,67 @@
22
maxnz=5
33
maxrows=5
44
maxcols=5
5+
int_types=(Int32,Int64)
6+
float_types=(Float32,Float64)
7+
Bi_types=(0,1)
58

6-
for T in (Int32,Int64,Float32,Float64)
7-
for Bi in (0,1)
8-
I = Vector{Int}()
9-
J = Vector{Int}()
10-
V = Vector{T}()
11-
for (ik, jk, vk) in zip(rand(1:maxrows, maxnz), rand(1:maxcols, maxnz), rand(1:T(maxnz), maxnz) )
12-
push_coo!(SymSparseMatrixCSR,I,J,V,ik,jk,vk)
13-
end
14-
finalize_coo!(SymSparseMatrixCSR,I,J,V,maxrows, maxcols)
15-
SYMCSC = Symmetric(sparse(I, J, V, maxrows, maxcols),:U)
16-
SYMCSR = symsparsecsr(SymSparseMatrixCSR{Bi},I, J, V, maxrows, maxcols)
179

18-
@test size(SYMCSC)==size(SYMCSR)
19-
@test SYMCSC == SYMCSR
10+
for Ti in int_types
11+
for Tv in float_types
12+
for Bi in Bi_types
13+
I = Vector{Ti}()
14+
J = Vector{Ti}()
15+
V = Vector{Tv}()
16+
for (ik, jk, vk) in zip(rand(1:maxrows, maxnz), rand(1:maxcols, maxnz), rand(1:Tv(maxnz), maxnz) )
17+
push_coo!(SymSparseMatrixCSR,I,J,V,ik,jk,vk)
18+
end
19+
finalize_coo!(SymSparseMatrixCSR,I,J,V,maxrows, maxcols)
20+
SYMCSC = Symmetric(sparse(I, J, V, maxrows, maxcols),:U)
21+
SYMCSR = symsparsecsr(SymSparseMatrixCSR{Bi},I, J, V, maxrows, maxcols)
22+
23+
@test size(SYMCSC)==size(SYMCSR)
24+
@test SYMCSC == SYMCSR
2025

21-
@test convert(SymSparseMatrixCSR{Bi}, SYMCSR) === SYMCSR
26+
@test convert(SymSparseMatrixCSR{Bi}, SYMCSR) === SYMCSR
2227

23-
SYMCSRC = convert(SymSparseMatrixCSR{SYMCSR.uppertrian.offset}, SYMCSR)
24-
@test SYMCSRC == SYMCSR
25-
@test SYMCSRC !== SYMCSR
28+
@test hasrowmajororder(SYMCSR) == true
29+
@test hascolmajororder(SYMCSR) == false
30+
@test getptr(SYMCSR) == SYMCSR.uppertrian.rowptr
31+
@test getindices(SYMCSR) == colvals(SYMCSR)
2632

27-
@test hasrowmajororder(SYMCSR) == true
28-
@test hascolmajororder(SYMCSR) == false
29-
@test getptr(SYMCSR) == SYMCSR.uppertrian.rowptr
30-
@test getindices(SYMCSR) == colvals(SYMCSR)
33+
@test nnz(SYMCSC.data) == nnz(SYMCSR.uppertrian) <= nnz(SYMCSR)
34+
@test count(!iszero, SYMCSC.data) == count(!iszero, SYMCSR.uppertrian)
3135

32-
@test nnz(SYMCSC.data) == nnz(SYMCSR.uppertrian) <= nnz(SYMCSR)
33-
@test count(!iszero, SYMCSC.data) == count(!iszero, SYMCSR.uppertrian)
36+
ICSC,JCSC,VCSC= findnz(SYMCSC.data)
37+
ICSR,JCSR,VCSR= findnz(SYMCSR)
3438

35-
ICSC,JCSC,VCSC= findnz(SYMCSC.data)
36-
ICSR,JCSR,VCSR= findnz(SYMCSR)
39+
@test sort(ICSC)==sort(JCSR) && sort(JCSC)==sort(ICSR) && sort(VCSC)==sort(VCSR)
3740

38-
@test sort(ICSC)==sort(JCSR) && sort(JCSC)==sort(ICSR) && sort(VCSC)==sort(VCSR)
41+
v = rand(size(SYMCSC)[2])
42+
@test SYMCSC*v == SYMCSR*v
3943

40-
v = rand(size(SYMCSC)[2])
41-
@test SYMCSC*v == SYMCSR*v
44+
for cBi in Bi_types
45+
if Bi == cBi
46+
@test convert(SymSparseMatrixCSR{Bi}, SYMCSR) === SYMCSR
47+
else
48+
SYMCSRC = convert(SymSparseMatrixCSR{cBi}, SYMCSR)
49+
@test SYMCSRC == SYMCSR
50+
@test SYMCSRC !== SYMCSR
51+
end
52+
end
53+
54+
for cTi in int_types
55+
for cTv in float_types
56+
if (Ti,Tv) == (cTi,cTv)
57+
@test convert(SymSparseMatrixCSR{Bi,Tv,Ti}, SYMCSR) === SYMCSR
58+
else
59+
SYMCSRC = convert(SymSparseMatrixCSR{Bi,cTv,cTi}, SYMCSR)
60+
@test SYMCSRC == SYMCSR
61+
@test SYMCSRC !== SYMCSR
62+
end
63+
end
64+
end
65+
end
4266
end
4367
end
4468

0 commit comments

Comments
 (0)