Skip to content

Support streaming reads from online .mtx or .mtx.gz files #54

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/MatrixMarket.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using TranscodingStreams, CodecZlib

export mmread, mmwrite, mminfo

include("format.jl")
include("mminfo.jl")
include("mmread.jl")
include("mmwrite.jl")
Expand Down
106 changes: 106 additions & 0 deletions src/format.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Abstract supertype of the in-memory Matrix Market storage layouts defined in
# this file (`CoordinateFormat` and `ArrayFormat`).
# The generic `Base.length(::MMFormat)` method reads a `vals` field, so every
# concrete subtype is expected to provide one.
abstract type MMFormat end

# Number of stored entries, i.e. the length of the `vals` vector of the format.
function Base.length(f::MMFormat)
    return length(f.vals)
end

# Raw read-out: return the storage vectors of `f` followed by the header
# information, flattened into a single tuple
# `(Tuple(f)..., nrow, ncol, nentry, rep, field, symm)`.
# `rep` comes from `formattext(f)` and `field` from `generate_eltype(eltype(f))`.
function readout(f::MMFormat, nrow::Int, ncol::Int, nentry::Int, symm)
    header = (nrow, ncol, nentry, formattext(f), generate_eltype(eltype(f)), symm)
    return (Tuple(f)..., header...)
end

# Coordinate (triplet) storage: entry `k` is the triple
# `(rows[k], cols[k], vals[k])`, which is also what iterating the format yields.
struct CoordinateFormat{T} <: MMFormat
# row index of each stored entry
rows::Vector{Int}
# column index of each stored entry
cols::Vector{Int}
# value of each stored entry
vals::Vector{T}
end

# Allocate uninitialised coordinate storage for `nentry` entries. The value type
# is whatever `parse_eltype(field)` returns for the header field string.
# The buffers are `undef`; they are meant to be filled through `writeat!`.
function CoordinateFormat(field, nentry)
    T = parse_eltype(field)
    return CoordinateFormat{T}(
        Vector{Int}(undef, nentry),
        Vector{Int}(undef, nentry),
        Vector{T}(undef, nentry),
    )
end

# Build coordinate storage from the stored entries of the sparse matrix `A`,
# in column-major order (the order in which `A` stores them).
#
# `findnz` is used instead of `rowvals`/`nonzeros` because:
#   * it returns fresh vectors, so a later `writeat!` on the format cannot
#     silently mutate `A` through aliased internal buffers;
#   * it returns exactly `nnz(A)` entries, whereas the internal buffers are not
#     guaranteed to have that exact length;
#   * it already produces the expanded column indices.
function CoordinateFormat(A::SparseMatrixCSC{T}) where {T}
    rows, cols, vals = findnz(A)
    # The inner constructor converts the index vectors to `Vector{Int}` when `A`
    # uses a different index type.
    return CoordinateFormat{T}(rows, cols, vals)
end

# Value type of the stored entries, taken from the type parameter.
function Base.eltype(::CoordinateFormat{T}) where {T}
    return T
end

# Representation keyword used for this layout in the Matrix Market header.
function formattext(::CoordinateFormat)
    return "coordinate"
end

# The three storage vectors as a tuple `(rows, cols, vals)`; no copies are made.
function Base.Tuple(f::CoordinateFormat)
    return (f.rows, f.cols, f.vals)
end

# Two coordinate formats are equal when their row, column and value vectors are
# pairwise equal; the comparison stops at the first vector pair that differs.
function Base.:(==)(x::CoordinateFormat, y::CoordinateFormat)
    x.rows == y.rows || return false
    x.cols == y.cols || return false
    return x.vals == y.vals
end

# Parse one coordinate data line with `parseline` and store the resulting
# (row, column, value) triple at position `i`. Returns `f`.
function writeat!(f::CoordinateFormat{T}, i::Int, line::String) where {T}
    r, c, v = parseline(T, line)
    f.rows[i] = r
    f.cols[i] = c
    f.vals[i] = v
    return f
end

# Assemble the stored triples into an `nrow` x `ncol` sparse matrix and pass it
# through the function that `parse_symmetric(symm)` returns.
function readout(f::CoordinateFormat, nrow::Int, ncol::Int, symm)
    apply_symmetry = parse_symmetric(symm)
    assembled = sparse(f.rows, f.cols, f.vals, nrow, ncol)
    return apply_symmetry(assembled)
end

# Iteration protocol: the state `i` is the index of the last entry produced
# (zero before the first one). Each step yields the `(row, col, val)` triple of
# the next entry, or `nothing` once all entries have been visited.
function Base.iterate(f::CoordinateFormat, i::Integer=zero(length(f)))
    next = i + oneunit(i)
    next > length(f) && return nothing
    return (f.rows[next], f.cols[next], f.vals[next]), next
end

# Dense ("array") storage: all values in one flat vector. `readout` reshapes
# this vector to `nrow` x `ncol`, i.e. the values are interpreted column by
# column.
struct ArrayFormat{T} <: MMFormat
# flat vector of the matrix values
vals::Vector{T}
end

# Allocate array storage for `nentry` values whose type is determined by
# `parse_eltype(field)`.
function ArrayFormat(field, nentry::Int)
    return ArrayFormat(parse_eltype(field), nentry)
end

# Allocate uninitialised array storage for `nentry` values of type `T`.
function ArrayFormat(::Type{T}, nentry::Int) where {T}
    return ArrayFormat{T}(Vector{T}(undef, nentry))
end

# Convenience constructor: `Float64` storage for `nentry` values.
function ArrayFormat(nentry::Int)
    return ArrayFormat(Float64, nentry)
end

# Wrap the values of the matrix `A`, flattened with `reshape(A, :)`, i.e. in
# column-major order.
function ArrayFormat(A::AbstractMatrix{T}) where {T}
    flat = reshape(A, :)
    return ArrayFormat{T}(flat)
end

# Value type of the stored entries, taken from the type parameter.
function Base.eltype(::ArrayFormat{T}) where {T}
    return T
end

# Representation keyword used for this layout in the Matrix Market header.
function formattext(::ArrayFormat)
    return "array"
end

# One-element tuple holding the value vector; no copy is made.
function Base.Tuple(f::ArrayFormat)
    return (f.vals,)
end

# Two array formats are equal when their value vectors are equal.
function Base.:(==)(x::ArrayFormat, y::ArrayFormat)
    return x.vals == y.vals
end

# Parse one array-format data line and store the value at position `i`.
# Returns `f`.
#
# Real and integer entries are a single number per line. Complex entries in a
# Matrix Market array file are written as two numbers on one line ("re im"),
# which `parse(Complex{...}, line)` cannot read, so that case is handled
# separately by `_parse_array_entry`.
function writeat!(f::ArrayFormat{T}, i::Int, line::String) where T
    f.vals[i] = _parse_array_entry(T, line)
    return f
end

# Generic entry parser: one value per line.
_parse_array_entry(::Type{T}, line) where {T} = parse(T, line)

# Complex entry parser. Anything `parse` already understood keeps working
# because that form is tried first; only when it fails and the line holds
# exactly two whitespace-separated tokens are they read as real and imaginary
# part.
function _parse_array_entry(::Type{T}, line) where {T<:Complex}
    direct = tryparse(T, line)
    direct === nothing || return direct
    parts = split(line)
    # Not the two-token form: let `parse` raise its usual error.
    length(parts) == 2 || return parse(T, line)
    R = real(T)
    return T(parse(R, parts[1]), parse(R, parts[2]))
end

# Reshape the flat value vector to `nrow` x `ncol` (shares memory with
# `f.vals`) and pass it through the function that `parse_symmetric(symm)`
# returns.
function readout(f::ArrayFormat, nrow::Int, ncol::Int, symm)
    dense = reshape(f.vals, nrow, ncol)
    return parse_symmetric(symm)(dense)
end

# Iteration protocol: the state `i` is the index of the last value produced
# (zero before the first one). Each step yields the next value, or `nothing`
# once all values have been visited.
function Base.iterate(f::ArrayFormat, i::Integer=zero(length(f)))
    next = i + oneunit(i)
    next > length(f) && return nothing
    return f.vals[next], next
end
74 changes: 48 additions & 26 deletions src/mmread.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,36 +27,17 @@ function mmread(filename::String, infoonly::Bool=false, retcoord::Bool=false)
end

function mmread(stream::IO, infoonly::Bool=false, retcoord::Bool=false)
rows, cols, entries, rep, field, symm = mminfo(stream)

infoonly && return rows, cols, entries, rep, field, symm

T = parse_eltype(field)
symfunc = parse_symmetric(symm)

if rep == "coordinate"
rn = Vector{Int}(undef, entries)
cn = Vector{Int}(undef, entries)
vals = Vector{T}(undef, entries)
for i in 1:entries
line = readline(stream)
splits = find_splits(line, num_splits(T))
rn[i] = parse_row(line, splits)
cn[i] = parse_col(line, splits, T)
vals[i] = parse_val(line, splits, T)
end
nrow, ncol, nentry, rep, field, symm = mminfo(stream)

result = retcoord ? (rn, cn, vals, rows, cols, entries, rep, field, symm) :
symfunc(sparse(rn, cn, vals, rows, cols))
else
vals = [parse(Float64, readline(stream)) for _ in 1:entries]
A = reshape(vals, rows, cols)
result = symfunc(A)
end
infoonly && return nrow, ncol, nentry, rep, field, symm

return result
reader = MMReader(nrow, ncol, nentry, rep, field, symm)
readlines!(reader, stream)
return readout(reader, retcoord)
end

## Parsing

function parse_eltype(field::String)
if field == "real"
return Float64
Expand Down Expand Up @@ -107,6 +88,14 @@ end
# Entries of element type `Bool` carry no value text: the result is always
# `true`.
function parse_val(line, splits, ::Type{Bool})
    return true
end

# Parse the value as `T` from the part of `line` that starts at index
# `splits[2]` and runs to `length(line)`.
function parse_val(line, splits, ::Type{T}) where {T}
    start = splits[2]
    return parse(T, line[start:length(line)])
end

# Split one coordinate data line with `find_splits` and return the parsed
# `(row, column, value)` triple for element type `T`.
function parseline(::Type{T}, line) where {T}
    splits = find_splits(line, num_splits(T))
    return parse_row(line, splits), parse_col(line, splits, T), parse_val(line, splits, T)
end

# Number of split positions requested from `find_splits` for a data line of the
# given element type: 3 for `ComplexF64`, 1 for `Bool`, 2 for everything else.
function num_splits(::Type{ComplexF64})
    return 3
end

function num_splits(::Type{Bool})
    return 1
end

function num_splits(::Any)
    return 2
end
Expand All @@ -130,3 +119,36 @@ function find_splits(s::String, num)

splits
end

## Reader

# State for reading one matrix: the header information plus the storage
# (`format`) that `readlines!` fills line by line and `readout` converts into
# the final result.
struct MMReader{F <: MMFormat}
# number of rows of the matrix
nrow::Int
# number of columns of the matrix
ncol::Int
# number of data lines that `readlines!` reads from the stream
nentry::Int
# representation string; "coordinate" selects `CoordinateFormat` in the
# outer constructor, anything else selects `ArrayFormat`
rep::String
# symmetry string, passed on to `readout`
symm::String
# storage that receives the parsed entries
format::F
end

# Create a reader for a matrix with the given header information and allocate
# the matching storage: `CoordinateFormat` when `rep == "coordinate"`,
# `ArrayFormat` otherwise.
#
# Throws an `AssertionError` when `nentry` exceeds `nrow * ncol`.
function MMReader(nrow::Integer, ncol::Integer, nentry::Integer, rep, field, symm)
    # The product is computed in a wider integer type: for very large sparse
    # dimensions `nrow * ncol` overflows `Int` and would reject a valid file.
    capacity = widemul(nrow, ncol)
    # Thrown explicitly (same exception type and message as before) so the
    # check cannot be compiled out the way an `@assert` may be.
    if nentry > capacity
        msg = "given nentry ($nentry) is greater than the product of nrow and ncol " *
              "($capacity)"
        throw(AssertionError(msg))
    end
    format = if rep == "coordinate"
        CoordinateFormat(field, nentry)
    else
        ArrayFormat(field, nentry)
    end
    return MMReader{typeof(format)}(nrow, ncol, nentry, rep, symm, format)
end

# Read exactly `reader.nentry` lines from `stream` and store the i-th line at
# position `i` of the reader's format via `writeat!`. Returns `reader`.
function readlines!(reader::MMReader, stream::IO)
    fmt = reader.format
    idx = 1
    while idx <= reader.nentry
        writeat!(fmt, idx, readline(stream))
        idx += 1
    end
    return reader
end

# Produce the final result of a read. With `retcoord == true` the raw storage
# plus header information is returned (the `readout` method that takes
# `nentry`); otherwise the assembled matrix is returned.
function readout(reader::MMReader, retcoord::Bool=false)
    fmt = reader.format
    retcoord || return readout(fmt, reader.nrow, reader.ncol, reader.symm)
    return readout(fmt, reader.nrow, reader.ncol, reader.nentry, reader.symm)
end
84 changes: 58 additions & 26 deletions src/mmwrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,24 @@ function mmwrite(filename::String, matrix::SparseMatrixCSC)
close(stream)
end

function mmwrite(stream::IO, matrix::SparseMatrixCSC)
function mmwrite(stream::IO, matrix::SparseMatrixCSC{T}) where {T}
nl = get_newline()
elem = generate_eltype(eltype(matrix))
sym = generate_symmetric(matrix)
elem = generate_eltype(T)
writer = MMWriter(matrix)
write(stream, header(writer))
write(stream, nl)
write(stream, sizetext(writer))
write(stream, nl)

# write header
write(stream, "%%MatrixMarket matrix coordinate $elem $sym$nl")

# only use lower triangular part of symmetric and Hermitian matrices
if issymmetric(matrix) || ishermitian(matrix)
matrix = tril(matrix)
end

# write matrix size and number of nonzeros
write(stream, "$(size(matrix, 1)) $(size(matrix, 2)) $(nnz(matrix))$nl")

rows = rowvals(matrix)
vals = nonzeros(matrix)
for i in 1:size(matrix, 2)
for j in nzrange(matrix, i)
entity = generate_entity(i, j, rows, vals, elem)
write(stream, entity)
end
for (r, c, v) in writer.format
entity = generate_entity(r, c, v, elem)
write(stream, entity)
write(stream, nl)
end
end

## Generating

# Map an element type to the field keyword used in the header line:
# `Bool` -> "pattern", other integers -> "integer", floats -> "real".
function generate_eltype(::Type{<:Bool})
    return "pattern"
end

function generate_eltype(::Type{<:Integer})
    return "integer"
end

function generate_eltype(::Type{<:AbstractFloat})
    return "real"
end
Expand All @@ -61,14 +53,13 @@ function generate_symmetric(m::AbstractMatrix)
end
end

function generate_entity(i, j, rows, vals, kind::String)
nl = get_newline()
function generate_entity(r, c, v, kind::String)
if kind == "pattern"
return "$(rows[j]) $i$nl"
return "$r $c"
elseif kind == "complex"
return "$(rows[j]) $i $(real(vals[j])) $(imag(vals[j]))$nl"
return "$r $c $(real(v)) $(imag(v))"
else
return "$(rows[j]) $i $(vals[j])$nl"
return "$r $c $v"
end
end

Expand All @@ -79,3 +70,44 @@ function get_newline()
return "\n"
end
end

## Writer

# State for writing one matrix: the size information and symmetry string that
# go into the header and size line, plus the storage (`format`) whose entries
# are written out.
struct MMWriter{F <: MMFormat}
# number of rows of the matrix
nrow::Int
# number of columns of the matrix
ncol::Int
# number of entries held in `format`
nentry::Int
# symmetry string placed at the end of the header line
symm::String
# storage holding the entries to write
format::F
end

# Build a writer for a dense matrix `A` using array storage. All
# `nrow * ncol` values are stored, flattened with `reshape(A, :)`
# (column-major order).
function MMWriter(A::AbstractMatrix{T}) where {T}
    nrow, ncol = size(A)
    nentry = nrow * ncol
    format = ArrayFormat{T}(reshape(A, :))
    # Every value is written, so the file has to be labelled "general". In the
    # Matrix Market array format a "symmetric"/"hermitian" header announces
    # packed lower-triangle storage, which is not what this writer emits;
    # labelling the full storage that way makes other readers misinterpret
    # the data.
    return MMWriter{typeof(format)}(nrow, ncol, nentry, "general", format)
end

# Build a writer for a sparse matrix using coordinate storage. The reported
# size is that of `A`. When `generate_symmetric(A)` says the matrix is
# "symmetric" or "hermitian", only its lower triangle is stored and counted.
function MMWriter(A::SparseMatrixCSC)
    nrow, ncol = size(A)
    symm = generate_symmetric(A)
    stored = symm in ("symmetric", "hermitian") ? tril(A) : A
    count = nnz(stored)
    format = CoordinateFormat(stored)
    return MMWriter{typeof(format)}(nrow, ncol, count, symm, format)
end

# Header line of the file (without a trailing newline):
# "%%MatrixMarket matrix <representation> <field> <symmetry>".
function header(writer::MMWriter)
    fmt = writer.format
    layout = formattext(fmt)
    field = generate_eltype(eltype(fmt))
    return "%%MatrixMarket matrix $layout $field $(writer.symm)"
end

# Size line of the file (without a trailing newline).
# Generic case (coordinate storage): "<rows> <cols> <entries>".
function sizetext(writer::MMWriter)
    return "$(writer.nrow) $(writer.ncol) $(writer.nentry)"
end

# Array storage: the Matrix Market array format has only "<rows> <cols>" on
# its size line; an extra entry count is not part of the format and trips up
# readers that expect exactly two numbers.
function sizetext(writer::MMWriter{<:ArrayFormat})
    return "$(writer.nrow) $(writer.ncol)"
end
38 changes: 38 additions & 0 deletions test/format.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Unit tests for the storage layouts in src/format.jl: construction, the
# accessor methods, read-out into a matrix, and in-place update via `writeat!`.
@testset "format" begin
    @testset "CoordinateFormat" begin
        elty = Float64
        ri = [1, 2, 2, 3, 5, 7]
        ci = [1, 1, 2, 3, 4, 4]
        nz = elty[1, 2, 3, 4, 5, 6]
        S = sparse(ri, ci, nz)

        coo = MatrixMarket.CoordinateFormat(ri, ci, nz)
        @test MatrixMarket.CoordinateFormat(S) == coo
        @test length(coo) == length(nz)
        @test eltype(coo) == elty
        @test MatrixMarket.formattext(coo) == "coordinate"
        @test Tuple(coo) == (ri, ci, nz)
        @test MatrixMarket.readout(coo, 7, 4, "general") == S

        # overwrite the second entry from a data line
        MatrixMarket.writeat!(coo, 2, "3 1 7")
        @test (coo.rows[2], coo.cols[2], coo.vals[2]) == (3, 1, 7)
    end

    @testset "ArrayFormat" begin
        elty = Float64
        flat = elty[1, 2, 3, 4, 5, 6]
        M = reshape(flat, 2, 3)

        arr = MatrixMarket.ArrayFormat(flat)
        @test MatrixMarket.ArrayFormat(M) == arr
        @test eltype(MatrixMarket.ArrayFormat(length(flat))) == Float64
        @test length(arr) == length(flat)
        @test eltype(arr) == elty
        @test MatrixMarket.formattext(arr) == "array"
        @test Tuple(arr) == (flat, )
        @test MatrixMarket.readout(arr, 2, 3, "general") == M

        # overwrite the second value from a data line
        MatrixMarket.writeat!(arr, 2, "7")
        @test arr.vals[2] == 7
    end
end
Loading