# Patch summary (MatrixMarket.jl): route reading and writing through explicit format objects.
#
# src/MatrixMarket.jl  - adds include("format.jl") ahead of the existing includes.
# src/format.jl (new)  - declares abstract type MMFormat and two concrete formats:
#                        CoordinateFormat{T} (rows, cols, vals) and ArrayFormat{T} (vals).
#                        Each defines formattext, Base.eltype, Base.Tuple, ==, writeat!,
#                        Base.iterate, and a readout that applies parse_symmetric(symm)
#                        to the assembled matrix.
# src/mmread.jl        - mmread(stream, ...) now builds an MMReader, fills it via readlines!
#                        (one readline per entry) and returns readout(reader, retcoord);
#                        parseline(T, line) bundles find_splits/parse_row/parse_col/parse_val.
# src/mmwrite.jl       - mmwrite(stream, matrix) now builds an MMWriter, writes header(writer)
#                        and sizetext(writer), then one generate_entity line per entry;
#                        generate_entity takes (r, c, v, kind) and no longer appends the newline.
# test/                - new format.jl, reader.jl and writer.jl; mtx.jl is regrouped into nested
#                        @testsets and gains a test that reads NIST .mtx.gz files from a URL;
#                        runtests.jl adds `using TranscodingStreams` and registers the new files.
#
# NOTE(review): the text below appears whitespace-collapsed (several diff lines per physical
#   line); restore the one-diff-line-per-line layout before attempting to apply it.
# NOTE(review): MMReader checks nentry <= nrow * ncol with @assert and test/reader.jl expects
#   AssertionError; confirm an assertion (rather than e.g. the FileFormatException the tests
#   expect from the parse_* helpers) is the intended contract for bad input.
# NOTE(review): writeat! for ArrayFormat calls parse(T, line) with T taken from
#   parse_eltype(field), whereas the removed code always used parse(Float64, ...); confirm
#   every field that can reach the array path parses from a single-token line.
diff --git a/src/MatrixMarket.jl b/src/MatrixMarket.jl index f20b711..82a1f4e 100644 --- a/src/MatrixMarket.jl +++ b/src/MatrixMarket.jl @@ -7,6 +7,7 @@ using TranscodingStreams, CodecZlib export mmread, mmwrite, mminfo +include("format.jl") include("mminfo.jl") include("mmread.jl") include("mmwrite.jl") diff --git a/src/format.jl b/src/format.jl new file mode 100644 index 0000000..b9b83b1 --- /dev/null +++ b/src/format.jl @@ -0,0 +1,106 @@ +abstract type MMFormat end + +Base.length(f::MMFormat) = length(f.vals) + +function readout(f::MMFormat, nrow::Int, ncol::Int, nentry::Int, symm) + rep = formattext(f) + field = generate_eltype(eltype(f)) + return (Tuple(f)..., nrow, ncol, nentry, rep, field, symm) +end + +struct CoordinateFormat{T} <: MMFormat + rows::Vector{Int} + cols::Vector{Int} + vals::Vector{T} +end + +function CoordinateFormat(field, nentry) + T = parse_eltype(field) + rows = Vector{Int}(undef, nentry) + cols = Vector{Int}(undef, nentry) + vals = Vector{T}(undef, nentry) + return CoordinateFormat{T}(rows, cols, vals) +end + +function CoordinateFormat(A::SparseMatrixCSC{T}) where {T} + rows = rowvals(A) + vals = nonzeros(A) + n = size(A, 2) + cols = [repeat([j], length(nzrange(A, j))) for j in 1:n] + cols = collect(Iterators.flatten(cols)) + return CoordinateFormat{T}(rows, cols, vals) +end + +Base.eltype(::CoordinateFormat{T}) where T = T + +formattext(::CoordinateFormat) = "coordinate" + +Base.Tuple(f::CoordinateFormat) = (f.rows, f.cols, f.vals) + +Base.:(==)(x::CoordinateFormat, y::CoordinateFormat) = (x.rows == y.rows) && + (x.cols == y.cols) && (x.vals == y.vals) + +function writeat!(f::CoordinateFormat{T}, i::Int, line::String) where T + f.rows[i], f.cols[i], f.vals[i] = parseline(T, line) + return f +end + +function readout(f::CoordinateFormat, nrow::Int, ncol::Int, symm) + symfunc = parse_symmetric(symm) + return symfunc(sparse(f.rows, f.cols, f.vals, nrow, ncol)) +end + +function Base.iterate(f::CoordinateFormat, i::Integer=zero(length(f))) + i 
+= oneunit(i) + if i <= length(f) + return (f.rows[i], f.cols[i], f.vals[i]), i + else + return nothing + end +end + +struct ArrayFormat{T} <: MMFormat + vals::Vector{T} +end + +function ArrayFormat(field, nentry::Int) + T = parse_eltype(field) + return ArrayFormat(T, nentry) +end + +function ArrayFormat(::Type{T}, nentry::Int) where {T} + vals = Vector{T}(undef, nentry) + return ArrayFormat{T}(vals) +end + +ArrayFormat(nentry::Int) = ArrayFormat(Float64, nentry) + +ArrayFormat(A::AbstractMatrix{T}) where {T} = ArrayFormat{T}(reshape(A, :)) + +Base.eltype(::ArrayFormat{T}) where T = T + +formattext(::ArrayFormat) = "array" + +Base.Tuple(f::ArrayFormat) = (f.vals,) + +Base.:(==)(x::ArrayFormat, y::ArrayFormat) = (x.vals == y.vals) + +function writeat!(f::ArrayFormat{T}, i::Int, line::String) where T + f.vals[i] = parse(T, line) + return f +end + +function readout(f::ArrayFormat, nrow::Int, ncol::Int, symm) + A = reshape(f.vals, nrow, ncol) + symfunc = parse_symmetric(symm) + return symfunc(A) +end + +function Base.iterate(f::ArrayFormat, i::Integer=zero(length(f))) + i += oneunit(i) + if i <= length(f) + return f.vals[i], i + else + return nothing + end +end diff --git a/src/mmread.jl b/src/mmread.jl index 6c3ec60..b1a5606 100644 --- a/src/mmread.jl +++ b/src/mmread.jl @@ -27,36 +27,17 @@ function mmread(filename::String, infoonly::Bool=false, retcoord::Bool=false) end function mmread(stream::IO, infoonly::Bool=false, retcoord::Bool=false) - rows, cols, entries, rep, field, symm = mminfo(stream) - - infoonly && return rows, cols, entries, rep, field, symm - - T = parse_eltype(field) - symfunc = parse_symmetric(symm) - - if rep == "coordinate" - rn = Vector{Int}(undef, entries) - cn = Vector{Int}(undef, entries) - vals = Vector{T}(undef, entries) - for i in 1:entries - line = readline(stream) - splits = find_splits(line, num_splits(T)) - rn[i] = parse_row(line, splits) - cn[i] = parse_col(line, splits, T) - vals[i] = parse_val(line, splits, T) - end + nrow, ncol, 
nentry, rep, field, symm = mminfo(stream) - result = retcoord ? (rn, cn, vals, rows, cols, entries, rep, field, symm) : - symfunc(sparse(rn, cn, vals, rows, cols)) - else - vals = [parse(Float64, readline(stream)) for _ in 1:entries] - A = reshape(vals, rows, cols) - result = symfunc(A) - end + infoonly && return nrow, ncol, nentry, rep, field, symm - return result + reader = MMReader(nrow, ncol, nentry, rep, field, symm) + readlines!(reader, stream) + return readout(reader, retcoord) end +## Parsing + function parse_eltype(field::String) if field == "real" return Float64 @@ -107,6 +88,14 @@ end parse_val(line, splits, ::Type{Bool}) = true parse_val(line, splits, ::Type{T}) where {T} = parse(T, line[splits[2]:length(line)]) +function parseline(::Type{T}, line) where T + splits = find_splits(line, num_splits(T)) + r = parse_row(line, splits) + c = parse_col(line, splits, T) + v = parse_val(line, splits, T) + return r, c, v +end + num_splits(::Type{ComplexF64}) = 3 num_splits(::Type{Bool}) = 1 num_splits(elty) = 2 @@ -130,3 +119,36 @@ function find_splits(s::String, num) splits end + +## Reader + +struct MMReader{F <: MMFormat} + nrow::Int + ncol::Int + nentry::Int + rep::String + symm::String + format::F +end + +function MMReader(nrow::Integer, ncol::Integer, nentry::Integer, rep, field, symm) + @assert nentry <= nrow * ncol "given nentry ($nentry) is greater than the product of nrow and ncol ($(nrow * ncol))" + format = (rep == "coordinate") ? 
CoordinateFormat(field, nentry) : ArrayFormat(field, nentry) + return MMReader{typeof(format)}(nrow, ncol, nentry, rep, symm, format) +end + +function readlines!(reader::MMReader, stream::IO) + for i in 1:reader.nentry + line = readline(stream) + writeat!(reader.format, i, line) + end + return reader +end + +function readout(reader::MMReader, retcoord::Bool=false) + if retcoord + return readout(reader.format, reader.nrow, reader.ncol, reader.nentry, reader.symm) + else + return readout(reader.format, reader.nrow, reader.ncol, reader.symm) + end +end diff --git a/src/mmwrite.jl b/src/mmwrite.jl index 29ec352..61ab335 100644 --- a/src/mmwrite.jl +++ b/src/mmwrite.jl @@ -19,32 +19,24 @@ function mmwrite(filename::String, matrix::SparseMatrixCSC) close(stream) end -function mmwrite(stream::IO, matrix::SparseMatrixCSC) +function mmwrite(stream::IO, matrix::SparseMatrixCSC{T}) where {T} nl = get_newline() - elem = generate_eltype(eltype(matrix)) - sym = generate_symmetric(matrix) + elem = generate_eltype(T) + writer = MMWriter(matrix) + write(stream, header(writer)) + write(stream, nl) + write(stream, sizetext(writer)) + write(stream, nl) - # write header - write(stream, "%%MatrixMarket matrix coordinate $elem $sym$nl") - - # only use lower triangular part of symmetric and Hermitian matrices - if issymmetric(matrix) || ishermitian(matrix) - matrix = tril(matrix) - end - - # write matrix size and number of nonzeros - write(stream, "$(size(matrix, 1)) $(size(matrix, 2)) $(nnz(matrix))$nl") - - rows = rowvals(matrix) - vals = nonzeros(matrix) - for i in 1:size(matrix, 2) - for j in nzrange(matrix, i) - entity = generate_entity(i, j, rows, vals, elem) - write(stream, entity) - end + for (r, c, v) in writer.format + entity = generate_entity(r, c, v, elem) + write(stream, entity) + write(stream, nl) end end +## Generating + generate_eltype(::Type{<:Bool}) = "pattern" generate_eltype(::Type{<:Integer}) = "integer" generate_eltype(::Type{<:AbstractFloat}) = "real" @@ -61,14 
+53,13 @@ function generate_symmetric(m::AbstractMatrix) end end -function generate_entity(i, j, rows, vals, kind::String) - nl = get_newline() +function generate_entity(r, c, v, kind::String) if kind == "pattern" - return "$(rows[j]) $i$nl" + return "$r $c" elseif kind == "complex" - return "$(rows[j]) $i $(real(vals[j])) $(imag(vals[j]))$nl" + return "$r $c $(real(v)) $(imag(v))" else - return "$(rows[j]) $i $(vals[j])$nl" + return "$r $c $v" end end @@ -79,3 +70,44 @@ function get_newline() return "\n" end end + +## Writer + +struct MMWriter{F <: MMFormat} + nrow::Int + ncol::Int + nentry::Int + symm::String + format::F +end + +function MMWriter(A::AbstractMatrix{T}) where {T} + nrow, ncol = size(A) + nentry = nrow * ncol + vals = reshape(A, :) + symm = generate_symmetric(A) + format = ArrayFormat{T}(vals) + return MMWriter{typeof(format)}(nrow, ncol, nentry, symm, format) +end + +function MMWriter(A::SparseMatrixCSC) + nrow, ncol = size(A) + symm = generate_symmetric(A) + + # only use lower triangular part of symmetric and Hermitian matrices + if symm == "symmetric" || symm == "hermitian" + A = tril(A) + end + + nentry = nnz(A) + format = CoordinateFormat(A) + return MMWriter{typeof(format)}(nrow, ncol, nentry, symm, format) +end + +function header(writer::MMWriter) + rep = formattext(writer.format) + elem = generate_eltype(eltype(writer.format)) + return "%%MatrixMarket matrix $rep $elem $(writer.symm)" +end + +sizetext(writer::MMWriter) = "$(writer.nrow) $(writer.ncol) $(writer.nentry)" diff --git a/test/format.jl b/test/format.jl new file mode 100644 index 0000000..31b2435 --- /dev/null +++ b/test/format.jl @@ -0,0 +1,38 @@ +@testset "format" begin + @testset "CoordinateFormat" begin + T = Float64 + rows = [1, 2, 2, 3, 5, 7] + cols = [1, 1, 2, 3, 4, 4] + vals = T[1, 2, 3, 4, 5, 6] + A = sparse(rows, cols, vals) + + f = MatrixMarket.CoordinateFormat(rows, cols, vals) + @test MatrixMarket.CoordinateFormat(A) == f + @test length(f) == length(vals) + @test 
eltype(f) == T + @test MatrixMarket.formattext(f) == "coordinate" + @test Tuple(f) == (rows, cols, vals) + @test MatrixMarket.readout(f, 7, 4, "general") == A + + MatrixMarket.writeat!(f, 2, "3 1 7") + @test (f.rows[2], f.cols[2], f.vals[2]) == (3, 1, 7) + end + + @testset "ArrayFormat" begin + T = Float64 + vals = T[1, 2, 3, 4, 5, 6] + A = reshape(vals, 2, 3) + + f = MatrixMarket.ArrayFormat(vals) + @test MatrixMarket.ArrayFormat(A) == f + @test eltype(MatrixMarket.ArrayFormat(length(vals))) == Float64 + @test length(f) == length(vals) + @test eltype(f) == T + @test MatrixMarket.formattext(f) == "array" + @test Tuple(f) == (vals, ) + @test MatrixMarket.readout(f, 2, 3, "general") == A + + MatrixMarket.writeat!(f, 2, "7") + @test f.vals[2] == 7 + end +end diff --git a/test/mtx.jl b/test/mtx.jl index 2ce6674..c979377 100644 --- a/test/mtx.jl +++ b/test/mtx.jl @@ -10,72 +10,87 @@ testmatrices = download_unzip_nist_files() @testset "read/write mtx" begin - rows, cols, entries, rep, field, symm = mminfo(mtx_filename) - @test rows == 11 - @test cols == 12 - @test entries == 5 - @test rep == "coordinate" - @test field == "integer" - @test symm == "general" - - A = mmread(mtx_filename) - @test A isa SparseMatrixCSC - @test A == res - - newfilename = replace(mtx_filename, "test.mtx" => "test_write.mtx") - mmwrite(newfilename, res) - - f = open(mtx_filename) - sha_test = bytes2hex(sha256(read(f, String))) - close(f) - - f = open(newfilename) - sha_new = bytes2hex(sha256(read(f, String))) - close(f) - - @test sha_test == sha_new - rm(newfilename) + @testset "mminfo test.mtx" begin + rows, cols, entries, rep, field, symm = mminfo(mtx_filename) + @test rows == 11 + @test cols == 12 + @test entries == 5 + @test rep == "coordinate" + @test field == "integer" + @test symm == "general" + end + + @testset "mmread test.mtx" begin + A = mmread(mtx_filename) + @test A isa SparseMatrixCSC + @test A == res + end + + @testset "mmwrite test.mtx" begin + newfilename = replace(mtx_filename, 
"test.mtx" => "test_write.mtx") + mmwrite(newfilename, res) + + f = open(mtx_filename) + sha_test = bytes2hex(sha256(read(f, String))) + close(f) + + f = open(newfilename) + sha_new = bytes2hex(sha256(read(f, String))) + close(f) + + @test sha_test == sha_new + rm(newfilename) + end end @testset "read/write mtx.gz" begin gz_filename = mtx_filename * ".gz" - rows, cols, entries, rep, field, symm = mminfo(gz_filename) - @test rows == 11 - @test cols == 12 - @test entries == 5 - @test rep == "coordinate" - @test field == "integer" - @test symm == "general" - - A = mmread(gz_filename) - @test A isa SparseMatrixCSC - @test A == res - - newfilename = replace(gz_filename, "test.mtx.gz" => "test_write.mtx.gz") - mmwrite(newfilename, res) - - stream = GzipDecompressorStream(open(gz_filename)) - adjusted_content = replace(read(stream, String), "\n" => get_newline()) - sha_test = bytes2hex(sha256(adjusted_content)) - close(stream) - - stream = GzipDecompressorStream(open(newfilename)) - sha_new = bytes2hex(sha256(read(stream, String))) - close(stream) - - @test sha_test == sha_new - rm(newfilename) + @testset "mminfo test.mtx.gz" begin + rows, cols, entries, rep, field, symm = mminfo(gz_filename) + @test rows == 11 + @test cols == 12 + @test entries == 5 + @test rep == "coordinate" + @test field == "integer" + @test symm == "general" + end + + @testset "mmread test.mtx.gz" begin + A = mmread(gz_filename) + @test A isa SparseMatrixCSC + @test A == res + end + + @testset "mmwrite test.mtx.gz" begin + newfilename = replace(gz_filename, "test.mtx.gz" => "test_write.mtx.gz") + mmwrite(newfilename, res) + + stream = GzipDecompressorStream(open(gz_filename)) + adjusted_content = replace(read(stream, String), "\n" => get_newline()) + sha_test = bytes2hex(sha256(adjusted_content)) + close(stream) + + stream = GzipDecompressorStream(open(newfilename)) + sha_new = bytes2hex(sha256(read(stream, String))) + close(stream) + + @test sha_test == sha_new + rm(newfilename) + end end @testset 
"read/write NIST mtx files" begin # verify mmread(mmwrite(A)) == A for filename in filter(t -> endswith(t, ".mtx"), readdir()) new_filename = replace(filename, ".mtx" => "_.mtx") - A = MatrixMarket.mmread(filename) - MatrixMarket.mmwrite(new_filename, A) - new_A = MatrixMarket.mmread(new_filename) - @test new_A == A + @testset "$filename" begin + A = MatrixMarket.mmread(filename) + MatrixMarket.mmwrite(new_filename, A) + new_A = MatrixMarket.mmread(new_filename) + @test new_A == A + end + rm(new_filename) end end @@ -83,24 +98,43 @@ @testset "read/write NIST mtx.gz files" begin for gz_filename in filter(t -> endswith(t, ".mtx.gz"), readdir()) mtx_filename = replace(gz_filename, ".mtx.gz" => ".mtx") - - # reading from .mtx and .mtx.gz must be identical - A_gz = MatrixMarket.mmread(gz_filename) - + new_filename = replace(gz_filename, ".mtx.gz" => "_.mtx.gz") A = MatrixMarket.mmread(mtx_filename) - @test A_gz == A - # writing to .mtx and .mtx.gz must be identical - new_filename = replace(gz_filename, ".mtx.gz" => "_.mtx.gz") - mmwrite(new_filename, A) + @testset "mmread $gz_filename" begin + # reading from .mtx and .mtx.gz must be identical + A_gz = MatrixMarket.mmread(gz_filename) + @test A_gz == A + end - new_A = MatrixMarket.mmread(new_filename) - @test new_A == A + @testset "mmwrite $gz_filename" begin + # writing to .mtx and .mtx.gz must be identical + mmwrite(new_filename, A) + new_A = MatrixMarket.mmread(new_filename) + @test new_A == A + end rm(new_filename) end end + @testset "read from online NIST mtx.gz files" begin + for (collectionname, setname, matrixname) in testmatrices + url = "https://math.nist.gov/pub/MatrixMarket2/$collectionname/$setname/$matrixname.mtx.gz" + mtx_filename = string(collectionname, '_', setname, '_', matrixname, ".mtx") + A = MatrixMarket.mmread(mtx_filename) + + @testset "mmread $matrixname.mtx.gz" begin + # reading from .mtx and .mtx.gz must be identical + buffer = PipeBuffer() + stream = TranscodingStream(GzipDecompressor(), 
buffer) + Downloads.download(url, buffer) + A_gz = MatrixMarket.mmread(stream) + @test A_gz == A + end + end + end + # clean up for filename in filter(t -> endswith(t, ".mtx"), readdir()) rm(filename) diff --git a/test/reader.jl b/test/reader.jl new file mode 100644 index 0000000..4594024 --- /dev/null +++ b/test/reader.jl @@ -0,0 +1,23 @@ +@testset "reader" begin + reader = MatrixMarket.MMReader(7, 4, 6, "coordinate", "real", "general") + @test reader.nrow == 7 + @test reader.ncol == 4 + @test reader.nentry == 6 + @test eltype(reader.format) == Float64 + @test reader.format isa MatrixMarket.CoordinateFormat + @test_throws AssertionError MatrixMarket.MMReader(7, 4, 100, "coordinate", "real", "general") + @test MatrixMarket.readout(reader, true)[4:end] == (7, 4, 6, "coordinate", "real", "general") + + reader = MatrixMarket.MMReader(2, 3, 6, "array", "integer", "general") + @test reader.nrow == 2 + @test reader.ncol == 3 + @test reader.nentry == 6 + @test eltype(reader.format) == Int64 + @test reader.format isa MatrixMarket.ArrayFormat + @test MatrixMarket.readout(reader, true)[2:end] == (2, 3, 6, "array", "integer", "general") + + @test_throws MatrixMarket.FileFormatException MatrixMarket.parse_eltype("aaa") + @test_throws MatrixMarket.FileFormatException MatrixMarket.parse_symmetric("aaa") + @test MatrixMarket.parse_dimension("3 4", "array") == (3, 4, 12) + @test_throws MatrixMarket.FileFormatException MatrixMarket.parse_dimension("3 4", "coordinate") +end diff --git a/test/runtests.jl b/test/runtests.jl index 6ec7556..d71ce96 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using Downloads using GZip using SparseArrays using SHA +using TranscodingStreams using Test include("test_utils.jl") @@ -13,6 +14,9 @@ const NIST_FILELIST = download_nist_filelist() tests = [ "mtx", + "reader", + "writer", + "format", ] @testset "MatrixMarket.jl" begin diff --git a/test/writer.jl b/test/writer.jl new file mode 100644 index 0000000..66fb774 --- /dev/null +++ 
b/test/writer.jl @@ -0,0 +1,25 @@ +@testset "writer" begin + A = sparse(rand([0, 1], 3, 4)) + writer = MatrixMarket.MMWriter(A) + @test writer.nrow == size(A, 1) + @test writer.ncol == size(A, 2) + @test writer.nentry == nnz(A) + @test writer.symm == "general" + @test eltype(writer.format) == Int64 + @test writer.format isa MatrixMarket.CoordinateFormat + @test MatrixMarket.header(writer) == "%%MatrixMarket matrix coordinate integer general" + @test MatrixMarket.sizetext(writer) == "$(size(A, 1)) $(size(A, 2)) $(nnz(A))" + + A = rand(ComplexF64, 3, 4) + writer = MatrixMarket.MMWriter(A) + @test writer.nrow == size(A, 1) + @test writer.ncol == size(A, 2) + @test writer.nentry == length(A) + @test writer.symm == "general" + @test eltype(writer.format) == ComplexF64 + @test writer.format isa MatrixMarket.ArrayFormat + @test MatrixMarket.header(writer) == "%%MatrixMarket matrix array complex general" + @test MatrixMarket.sizetext(writer) == "$(size(A, 1)) $(size(A, 2)) $(length(A))" + + @test_throws ErrorException MatrixMarket.generate_eltype(String) +end