Skip to content

Commit 46d46ca

Browse files
authored
Merge pull request #7 from janbruedigam/add_system_algebra
Add system algebra
2 parents cdbf6a2 + 4d65d35 commit 46d46ca

20 files changed

+393
-134
lines changed

benchmark/adjacency_matrix.jl

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
A = [
2+
0 1 0 1 1 0 1 0 1 0
3+
1 0 1 0 0 0 0 0 0 0
4+
0 1 0 0 0 0 0 0 0 0
5+
1 0 0 0 1 0 0 0 0 0
6+
1 0 0 1 0 1 0 0 0 0
7+
0 0 0 0 1 0 0 0 0 0
8+
1 0 0 0 0 0 0 1 0 0
9+
0 0 0 0 0 0 1 0 1 1
10+
1 0 0 0 0 0 0 1 0 0
11+
0 0 0 0 0 0 0 1 0 0
12+
]
13+
14+
B = [
15+
0 1 0 0 0 1 0 0 0
16+
1 0 0 0 0 1 0 0 0
17+
0 0 0 0 0 1 1 0 1
18+
0 0 0 0 0 0 1 1 0
19+
0 0 0 0 0 0 0 1 1
20+
1 1 1 0 0 0 1 0 0
21+
0 0 1 1 0 1 0 1 1
22+
0 0 0 1 1 0 1 0 1
23+
0 0 1 0 1 0 1 1 0
24+
]
25+
26+
C = [
27+
0 1 1 0 0 0
28+
1 0 0 1 0 0
29+
1 0 0 1 1 0
30+
0 1 1 0 0 1
31+
0 0 1 0 0 1
32+
0 0 0 1 1 0
33+
]
34+
35+
D = [
36+
0 1 1 0 1 0
37+
1 0 0 1 0 1
38+
1 0 0 1 0 0
39+
0 1 1 0 0 0
40+
1 0 0 0 0 1
41+
0 1 0 0 1 0]
42+
43+
ZAA = zeros(Int64, 10, 10)
44+
ZAB = zeros(Int64, 10, 9)
45+
ZAC = zeros(Int64, 10, 6)
46+
ZAD = zeros(Int64, 10, 6)
47+
ZBB = zeros(Int64, 9, 9)
48+
ZBC = zeros(Int64, 9, 6)
49+
ZBD = zeros(Int64, 9, 6)
50+
ZCC = zeros(Int64, 6, 6)
51+
ZCD = zeros(Int64, 6, 6)
52+
53+
ZBA = ZAB'
54+
ZCA = ZAC'
55+
ZDA = ZAD'
56+
ZCB = ZBC'
57+
ZDB = ZBD'
58+
ZDC = ZCD'
59+
60+
61+
# Graph 1 is disconnected, 2-3-4-5 are connected, 6 is disconnected, 7 is disconnected
62+
63+
A = [
64+
A ZAA ZAB ZAC ZAA ZAA ZAD
65+
ZAA A ZAB ZAC ZAA ZAA ZAD
66+
ZBA ZBA B ZBC ZBA ZBA ZBD
67+
ZCA ZCA ZCB C ZCA ZCA ZCD
68+
ZAA ZAA ZAB ZAC A ZAA ZAD
69+
ZAA ZAA ZAB ZAC ZAA A ZAD
70+
ZDA ZDA ZDB ZDC ZDA ZDA D
71+
]
72+
73+
A[13,21] = A[21,13] = 1
74+
A[16,30] = A[30,16] = 1
75+
A[20,36] = A[36,20] = 1

benchmark/example_benchmark.jl

+17-93
Original file line numberDiff line numberDiff line change
@@ -1,85 +1,12 @@
11
using GraphBasedSystems
22
using LinearAlgebra
33

4-
A = [
5-
0 1 0 1 1 0 1 0 1 0
6-
1 0 1 0 0 0 0 0 0 0
7-
0 1 0 0 0 0 0 0 0 0
8-
1 0 0 0 1 0 0 0 0 0
9-
1 0 0 1 0 1 0 0 0 0
10-
0 0 0 0 1 0 0 0 0 0
11-
1 0 0 0 0 0 0 1 0 0
12-
0 0 0 0 0 0 1 0 1 1
13-
1 0 0 0 0 0 0 1 0 0
14-
0 0 0 0 0 0 0 1 0 0
15-
]
4+
include("adjacency_matrix.jl")
165

17-
B = [
18-
0 1 0 0 0 1 0 0 0
19-
1 0 0 0 0 1 0 0 0
20-
0 0 0 0 0 1 1 0 1
21-
0 0 0 0 0 0 1 1 0
22-
0 0 0 0 0 0 0 1 1
23-
1 1 1 0 0 0 1 0 0
24-
0 0 1 1 0 1 0 1 1
25-
0 0 0 1 1 0 1 0 1
26-
0 0 1 0 1 0 1 1 0
27-
]
286

29-
C = [
30-
0 1 1 0 0 0
31-
1 0 0 1 0 0
32-
1 0 0 1 1 0
33-
0 1 1 0 0 1
34-
0 0 1 0 0 1
35-
0 0 0 1 1 0
36-
]
37-
38-
D = [
39-
0 1 1 0 1 0
40-
1 0 0 1 0 1
41-
1 0 0 1 0 0
42-
0 1 1 0 0 0
43-
1 0 0 0 0 1
44-
0 1 0 0 1 0]
45-
46-
ZAA = zeros(Int64,10,10)
47-
ZAB = zeros(Int64,10,9)
48-
ZAC = zeros(Int64,10,6)
49-
ZAD = zeros(Int64,10,6)
50-
ZBB = zeros(Int64,9,9)
51-
ZBC = zeros(Int64,9,6)
52-
ZBD = zeros(Int64,9,6)
53-
ZCC = zeros(Int64,6,6)
54-
ZCD = zeros(Int64,6,6)
55-
56-
ZBA = ZAB'
57-
ZCA = ZAC'
58-
ZDA = ZAD'
59-
ZCB = ZBC'
60-
ZDB = ZBD'
61-
ZDC = ZCD'
62-
63-
64-
# Graph 1 is disconnected, 2-3-4-5 are connected, 6 is disconnected, 7 is disconnected
65-
66-
A = [
67-
A ZAA ZAB ZAC ZAA ZAA ZAD
68-
ZAA A ZAB ZAC ZAA ZAA ZAD
69-
ZBA ZBA B ZBC ZBA ZBA ZBD
70-
ZCA ZCA ZCB C ZCA ZCA ZCD
71-
ZAA ZAA ZAB ZAC A ZAA ZAD
72-
ZAA ZAA ZAB ZAC ZAA A ZAD
73-
ZDA ZDA ZDB ZDC ZDA ZDA D
74-
]
75-
76-
A[13,21] = A[21,13] = 1
77-
A[16,30] = A[30,16] = 1
78-
A[20,36] = A[36,20] = 1
79-
80-
function initialize!_posdef!(system)
7+
function initialize!_posdef!(system::System{N}) where N
818
initialize!(system,rand)
82-
for i=1:size(A)[1]
9+
for i=1:N
8310
system.matrix_entries[i,i].value += 1000*I
8411
end
8512
end
@@ -89,20 +16,17 @@ system = System{Float64}(A, ones(Int,size(A)[1])*3)
8916
systemldlt = System{Float64}(A, ones(Int,size(A)[1])*3, symmetric=true)
9017
systemllt = System{Float64}(A, ones(Int,size(A)[1])*3, symmetric=true)
9118

92-
SUITE["ldu"] = @benchmarkable ldu_solve!($system) samples=2 setup=(initialize!($system))
93-
SUITE["lu"] = @benchmarkable lu_solve!($system) samples=2 setup=(initialize!($system))
94-
SUITE["ldlt"] = @benchmarkable ldlt_solve!($systemldlt) samples=2 setup=(initialize!($systemldlt))
95-
SUITE["llt"] = @benchmarkable llt_solve!($systemllt) samples=2 setup=(initialize!_posdef!($systemllt))
96-
97-
# A = [
98-
# 0 1 1 1 1
99-
# 1 0 1 1 1
100-
# 1 1 0 1 1
101-
# 1 1 1 0 1
102-
# 1 1 1 1 0
103-
# ]
104-
105-
106-
107-
108-
19+
SUITE["sparse_ldu"] = @benchmarkable ldu_solve!($system) setup=(initialize!($system))
20+
SUITE["sparse_lu"] = @benchmarkable lu_solve!($system) setup=(initialize!($system))
21+
SUITE["dense_lu"] = @benchmarkable lu(F)\f setup=(initialize!($system);F=full_matrix($system);f=full_vector($system))
22+
SUITE["sparse_ldlt"] = @benchmarkable ldlt_solve!($systemldlt) setup=(initialize!($systemldlt))
23+
SUITE["dense_ldlt"] = @benchmarkable bunchkaufman(F)\f setup=(initialize!($systemldlt);F=full_matrix($systemldlt);f=full_vector($systemldlt))
24+
SUITE["sparse_llt"] = @benchmarkable llt_solve!($systemllt) setup=(initialize!_posdef!($systemllt))
25+
SUITE["dense_llt"] = @benchmarkable cholesky(F)\f setup=(initialize!_posdef!($systemllt);F=full_matrix($systemllt);f=full_vector($systemllt))
26+
27+
SUITE["sparse_add"] = @benchmarkable +($system,$system) setup=(initialize!($system))
28+
SUITE["dense_add"] = @benchmarkable (+(F,F);+(f+f)) setup=(initialize!($system);F=full_matrix($system);f=full_vector($system))
29+
SUITE["sparse_mul"] = @benchmarkable *($system,$system) setup=(initialize!($system))
30+
SUITE["dense_mul"] = @benchmarkable *(F,F) setup=(initialize!($system);F=full_matrix($system))
31+
SUITE["sparse_solve"] = @benchmarkable \($system,Bmat) setup=(initialize!($system);Bmat=deepcopy(system.matrix_entries);initialize!($system))
32+
SUITE["dense_solve"] = @benchmarkable \(F1,F2) setup=(initialize!($system);F2=full_matrix($system);initialize!($system);F1=full_matrix($system))

src/GraphBasedSystems.jl

+8-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ module GraphBasedSystems
22

33
using LinearAlgebra
44
using SparseArrays
5+
using SparseArrays: widelength
6+
using SparseArrays.HigherOrderFns: _sumnnzs, _allocres, _map_zeropres!, _densestructure!
57
using StaticArrays
68
using Graphs
79

@@ -19,15 +21,19 @@ export System,
1921
parents,
2022

2123
ldu_solve!,
24+
ldu_matrix_solve!,
2225
ldu_factorization!,
2326
ldu_backsubstitution!,
2427
lu_solve!,
28+
lu_matrix_solve!,
2529
lu_factorization!,
2630
lu_backsubstitution!,
2731
ldlt_solve!,
32+
ldlt_matrix_solve!,
2833
ldlt_factorization!,
2934
ldlt_backsubstitution!,
3035
llt_solve!,
36+
llt_matrix_solve!,
3137
llt_factorization!,
3238
llt_backsubstitution!
3339

@@ -36,11 +42,12 @@ include(joinpath("util", "custom_static.jl"))
3642

3743
include(joinpath("system", "entry.jl"))
3844
include(joinpath("system", "system.jl"))
39-
include(joinpath("system", "setup_functions.jl"))
45+
include(joinpath("system", "graph_functions.jl"))
4046

4147
include(joinpath("system", "interface.jl"))
4248
include(joinpath("system", "dense.jl"))
4349

50+
include(joinpath("solvers", "matrix.jl"))
4451
include(joinpath("solvers", "lu.jl"))
4552
include(joinpath("solvers", "llt.jl"))
4653
include(joinpath("solvers", "ldlt.jl"))

src/solvers/ldlt.jl

+9
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,12 @@ function ldlt_solve!(system::System)
8989
ldlt_backsubstitution!(system)
9090
return
9191
end
92+
93+
function ldlt_matrix_solve!(system::System, matrix::SparseMatrixCSC{Entry, Int64}; keep_vector = true)
94+
keep_vector && (vector_entries = deepcopy(system.vector_entries))
95+
ldlt_factorization!(system)
96+
C = matrix_backsubsitution!(system, matrix, ldlt_backsubstitution!)
97+
keep_vector && (system.vector_entries .= vector_entries)
98+
99+
return C
100+
end

src/solvers/ldu.jl

+9
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,12 @@ function ldu_solve!(system::System)
9090
ldu_backsubstitution!(system)
9191
return
9292
end
93+
94+
function ldu_matrix_solve!(system::System, matrix::SparseMatrixCSC{Entry, Int64}; keep_vector = true)
95+
keep_vector && (vector_entries = deepcopy(system.vector_entries))
96+
ldu_factorization!(system)
97+
C = matrix_backsubsitution!(system, matrix, ldu_backsubstitution!)
98+
keep_vector && (system.vector_entries .= vector_entries)
99+
100+
return C
101+
end

src/solvers/llt.jl

+10-4
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,7 @@ function llt_backsubstitution_d!(vector, diagonal, diagonal_inverse)
5959
if diagonal_inverse.isinverted
6060
vector.value = diagonal_inverse.value * vector.value
6161
else
62-
invdiagonal = inv(diagonal.value)
63-
diagonal_inverse.value = invdiagonal
64-
diagonal_inverse.isinverted = true
65-
vector.value = diagonal_inverse.value * vector.value
62+
vector.value = diagonal.value \ vector.value
6663
end
6764
return
6865
end
@@ -98,3 +95,12 @@ function llt_solve!(system::System)
9895
llt_backsubstitution!(system)
9996
return
10097
end
98+
99+
function llt_matrix_solve!(system::System, matrix::SparseMatrixCSC{Entry, Int64}; keep_vector = true)
100+
keep_vector && (vector_entries = deepcopy(system.vector_entries))
101+
llt_factorization!(system)
102+
C = matrix_backsubsitution!(system, matrix, llt_backsubstitution!)
103+
keep_vector && (system.vector_entries .= vector_entries)
104+
105+
return C
106+
end

src/solvers/lu.jl

+9
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,12 @@ function lu_solve!(system::System)
8989
lu_backsubstitution!(system)
9090
return
9191
end
92+
93+
function lu_matrix_solve!(system::System, matrix::SparseMatrixCSC{Entry, Int64}; keep_vector = true)
94+
keep_vector && (vector_entries = deepcopy(system.vector_entries))
95+
lu_factorization!(system)
96+
C = matrix_backsubsitution!(system, matrix, lu_backsubstitution!)
97+
keep_vector && (system.vector_entries .= vector_entries)
98+
99+
return C
100+
end

src/solvers/matrix.jl

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Assumes quadratic matrices of same size
2+
function matrix_backsubsitution!(system::System{N}, matrix::SparseMatrixCSC{Entry, Int64}, backsubsitution!) where N
3+
vector_entries = system.vector_entries
4+
dims = system.dims
5+
6+
maxnnzC = Int(widelength(system.matrix_entries))
7+
C = _allocres((N,N), Int64, Entry, maxnnzC)
8+
# _densestructure!(C)
9+
10+
for i = 1:N
11+
for j = 1:N
12+
Bji = matrix[j,i]
13+
Bji isa Entry{Nothing} ? vector_entries[j] = Entry(dims[j], dims[i]) : vector_entries[j] = Bji
14+
end
15+
backsubsitution!(system)
16+
for j = 1:N
17+
C[j,i] = vector_entries[j]
18+
end
19+
end
20+
21+
return C
22+
end

src/system/dense.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
full_matrix(system::System) = full_matrix(system.matrix_entries, issymmetric(system), system.dims, system.dims)
22
# There probably exists a smarter way of getting the dense matrix from the spares one
3-
function full_matrix(matrix_entries::SparseMatrixCSC, symmetric::Bool, dimensions_rows, dimensions_cols)
3+
function full_matrix(matrix::SparseMatrixCSC, symmetric::Bool, dimensions_rows, dimensions_cols)
44
range_dict_rows = ranges(dimensions_rows)
55
range_dict_cols = ranges(dimensions_cols)
66
A = zeros(sum(dimensions_rows), sum(dimensions_cols))
77

8-
for (i,row) in enumerate(matrix_entries.rowval)
9-
col = findfirst(x -> i<x, matrix_entries.colptr)-1
10-
A[range_dict_rows[row],range_dict_cols[col]] = matrix_entries[row,col].value
8+
for (i,row) in enumerate(matrix.rowval)
9+
col = findfirst(x -> i<x, matrix.colptr)-1
10+
A[range_dict_rows[row],range_dict_cols[col]] = matrix[row,col].value
1111
if symmetric && col != row
12-
A[range_dict_cols[col],range_dict_rows[row]] = matrix_entries[row,col].value'
12+
A[range_dict_cols[col],range_dict_rows[row]] = matrix[row,col].value'
1313
end
1414
end
1515
return A
1616
end
1717

1818
full_vector(system::System) = full_vector(system.vector_entries, system.dims)
19-
full_vector(vector_entries::AbstractVector, dimensions_rows) = vcat([getfield(vector_entries[i],:value) for i=1:size(dimensions_rows)[1]]...)
19+
full_vector(vector_entries::AbstractVector, dimensions_rows::SVector{N}) where N = vcat([getfield(vector_entries[i], :value) for i=1:N]...)

0 commit comments

Comments
 (0)