Skip to content

Commit a78b271

Browse files
committed
Comply with Dojo
1 parent 4fbf8db commit a78b271

File tree

10 files changed

+101
-83
lines changed

10 files changed

+101
-83
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.2.0"
55

66
[deps]
77
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
8+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
89
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
910
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1011

src/GraphBasedSystems.jl

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,35 @@
11
module GraphBasedSystems
22

3+
using LinearAlgebra
34
using SparseArrays
45
using StaticArrays
56
using Graphs
67

78

89
export System,
10+
Entry,
11+
912
full_matrix,
1013
full_vector,
1114
randomize!,
15+
reset_inverse_diagonals!,
1216

1317
children,
1418
connections,
1519
parents,
1620

1721
ldu_solve!,
22+
ldu_factorization!,
23+
ldu_backsubstitution!,
1824
lu_solve!,
25+
lu_factorization!,
26+
lu_backsubstitution!,
1927
ldlt_solve!,
20-
llt_solve!
28+
ldlt_factorization!,
29+
ldlt_backsubstitution!,
30+
llt_solve!,
31+
llt_factorization!,
32+
llt_backsubstitution!
2133

2234

2335
include(joinpath("util", "custom_static.jl"))
@@ -26,6 +38,9 @@ include(joinpath("system", "entry.jl"))
2638
include(joinpath("system", "system.jl"))
2739
include(joinpath("system", "setup_functions.jl"))
2840

41+
include(joinpath("system", "interface.jl"))
42+
include(joinpath("system", "dense.jl"))
43+
2944
include(joinpath("solvers", "lu.jl"))
3045
include(joinpath("solvers", "llt.jl"))
3146
include(joinpath("solvers", "ldlt.jl"))

src/solvers/ldlt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ function ldlt_factorization!(system::System)
2424
acyclic_children = system.acyclic_children
2525
cyclic_children = system.cyclic_children
2626

27+
reset_inverse_diagonals!(system)
28+
2729
for v in system.dfs_list
2830
for c in acyclic_children[v]
2931
ldlt_factorization_acyclic!(matrix_entries[v,v], matrix_entries[v,c], matrix_entries[c,c], diagonal_inverses[c])
@@ -83,7 +85,6 @@ function ldlt_backsubstitution!(system::System)
8385
end
8486

8587
function ldlt_solve!(system::System)
86-
reset_inverse_diagonals!(system)
8788
ldlt_factorization!(system)
8889
ldlt_backsubstitution!(system)
8990
return

src/solvers/ldu.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ function ldu_factorization!(system::System)
2424
acyclic_children = system.acyclic_children
2525
cyclic_children = system.cyclic_children
2626

27+
reset_inverse_diagonals!(system)
28+
2729
for v in system.dfs_list
2830
for c in acyclic_children[v]
2931
ldu_factorization_acyclic!(matrix_entries[v,v], matrix_entries[v,c], matrix_entries[c,c], matrix_entries[c,v], diagonal_inverses[c])
@@ -84,7 +86,6 @@ function ldu_backsubstitution!(system::System)
8486
end
8587

8688
function ldu_solve!(system::System)
87-
reset_inverse_diagonals!(system)
8889
ldu_factorization!(system)
8990
ldu_backsubstitution!(system)
9091
return

src/solvers/llt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ function llt_factorization!(system::System)
2828
acyclic_children = system.acyclic_children
2929
cyclic_children = system.cyclic_children
3030

31+
reset_inverse_diagonals!(system)
32+
3133
for v in system.dfs_list
3234
for c in acyclic_children[v]
3335
llt_factorization_acyclic!(matrix_entries[v,v], matrix_entries[v,c], matrix_entries[c,c], diagonal_inverses[c])
@@ -92,7 +94,6 @@ function llt_backsubstitution!(system::System)
9294
end
9395

9496
function llt_solve!(system::System)
95-
reset_inverse_diagonals!(system)
9697
llt_factorization!(system)
9798
llt_backsubstitution!(system)
9899
return

src/solvers/lu.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ function lu_factorization!(system::System)
2323
acyclic_children = system.acyclic_children
2424
cyclic_children = system.cyclic_children
2525

26+
reset_inverse_diagonals!(system)
27+
2628
for v in system.dfs_list
2729
for c in acyclic_children[v]
2830
lu_factorization_acyclic!(matrix_entries[v,v], matrix_entries[v,c], matrix_entries[c,c], matrix_entries[c,v], diagonal_inverses[c])
@@ -83,7 +85,6 @@ function lu_backsubstitution!(system::System)
8385
end
8486

8587
function lu_solve!(system::System)
86-
reset_inverse_diagonals!(system)
8788
lu_factorization!(system)
8889
lu_backsubstitution!(system)
8990
return

src/system/dense.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
full_matrix(system::System) = full_matrix(system.matrix_entries, issymmetric(system), system.dims, system.dims)
2+
# There probably exists a smarter way of getting the dense matrix from the spares one
3+
function full_matrix(matrix_entries::SparseMatrixCSC, symmetric::Bool, dimensions_rows, dimensions_cols)
4+
range_dict_rows = ranges(dimensions_rows)
5+
range_dict_cols = ranges(dimensions_cols)
6+
A = zeros(sum(dimensions_rows), sum(dimensions_cols))
7+
8+
for (i,row) in enumerate(matrix_entries.rowval)
9+
col = findfirst(x -> i<x, matrix_entries.colptr)-1
10+
A[range_dict_rows[row],range_dict_cols[col]] = matrix_entries[row,col].value
11+
if symmetric && col != row
12+
A[range_dict_cols[col],range_dict_rows[row]] = matrix_entries[row,col].value'
13+
end
14+
end
15+
return A
16+
end
17+
18+
full_vector(system::System) = full_vector(system.vector_entries, system.dims)
19+
full_vector(vector_entries::AbstractVector, dimensions_rows) = vcat([getfield(vector_entries[i],:value) for i=1:size(dimensions_rows)[1]]...)

src/system/entry.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ function Base.zero(::Type{Entry{ET}}) where ET
2525
return Entry{ET.parameters[2]}(dims...)
2626
end
2727
function Base.zero(::Type{Entry})
28-
return 0
28+
return nothing
2929
end
3030

3131
function randomize!(entry::Entry, rand_function = randn)

src/system/interface.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
@inline children(system::System, v) = outneighbors(system.dfs_graph, v) # all direct children of v
2+
@inline connections(system::System, v) = neighbors(system.graph, v) # all connected nodes of v
3+
@inline parents(system::System, v) = inneighbors(system.dfs_graph, v) # same elements as system.parents[v], but potentially different order
4+
5+
LinearAlgebra.issymmetric(::System{N, Symmetric}) where N = true
6+
LinearAlgebra.issymmetric(::System{N, Unsymmetric}) where N = false
7+
8+
function ranges(system::System{N}) where N
9+
dims = system.dims
10+
range_dict = Dict(1=>1:dims[1])
11+
for i=2:N
12+
range_dict[i] = last(range_dict[i-1])+1:sum(dims[1:i])
13+
end
14+
15+
return range_dict
16+
end
17+
function ranges(dims::AbstractVector)
18+
range_dict = Dict(1=>1:dims[1])
19+
for i=2:size(dims)[1]
20+
range_dict[i] = last(range_dict[i-1])+1:sum(dims[1:i])
21+
end
22+
23+
return range_dict
24+
end
25+
26+
function randomize!(system::System, rand_function = randn)
27+
for entry in system.matrix_entries.nzval
28+
randomize!(entry, rand_function)
29+
end
30+
for entry in system.vector_entries
31+
randomize!(entry, rand_function)
32+
end
33+
end
34+
35+
function randomize!(system::System{N,<:Symmetric}, rand_function = randn) where N
36+
matrix_entries = system.matrix_entries
37+
38+
for entry in matrix_entries.nzval
39+
randomize!(entry, rand_function)
40+
end
41+
for i=1:N
42+
matrix_entries[i,i].value += matrix_entries[i,i].value'
43+
end
44+
45+
for entry in system.vector_entries
46+
randomize!(entry, rand_function)
47+
end
48+
end
49+
50+
function reset_inverse_diagonals!(system::System)
51+
for entry in system.diagonal_inverses
52+
entry.isinverted = false
53+
end
54+
end

src/system/system.jl

Lines changed: 2 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ struct System{N,T}
1111
dfs_list::SVector{N,Int64} # depth-first search list of nodes [last-found node, ..., first-found node]
1212
graph::SimpleGraph{Int64} # the graph built from the adjacency matrix
1313
dfs_graph::SimpleDiGraph{Int64} # the directed graph built from the depth-first search
14+
dims::Vector{Int64} # Dimensions of the matrix entries
1415

1516
function System{T}(A, dims; force_static = false, symmetric = false) where T
1617
N = length(dims)
@@ -79,7 +80,7 @@ struct System{N,T}
7980
cyclic_children = [unique(vcat(cycles[i]...)) for i=1:N]
8081
cyclic_children = [intersect(dfs_list, cyclic_children[i]) for i=1:N]
8182

82-
new{N,S}(matrix_entries, vector_entries, diagonal_inverses, acyclic_children, cyclic_children, parents, dfs_list, full_graph, full_dfs_graph)
83+
new{N,S}(matrix_entries, vector_entries, diagonal_inverses, acyclic_children, cyclic_children, parents, dfs_list, full_graph, full_dfs_graph, dims)
8384
end
8485

8586
System(A, dims; force_static = false, symmetric = false) = System{Float64}(A, dims; force_static, symmetric)
@@ -94,79 +95,3 @@ function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, system::System{N,S}
9495
println(io, "System with "*string(N)*" nodes.")
9596
SparseArrays._show_with_braille_patterns(io, system.matrix_entries)
9697
end
97-
98-
99-
@inline children(system::System, v) = outneighbors(system.dfs_graph, v) # all direct children of v
100-
@inline connections(system::System, v) = neighbors(system.graph, v) # all connected nodes of v
101-
@inline parents(system::System, v) = inneighbors(system.dfs_graph, v) # same elements as system.parents[v], but potentially different order
102-
103-
dimensions(system::System{N}) where N = [size(system.vector_entries[i].value)[1] for i=1:N]
104-
function ranges(system::System{N}) where N
105-
dims = dimensions(system)
106-
range_dict = Dict(1=>1:dims[1])
107-
for i=2:N
108-
range_dict[i] = last(range_dict[i-1])+1:sum(dims[1:i])
109-
end
110-
111-
return range_dict
112-
end
113-
114-
# There probably exists a smarter way of getting the dense matrix from the spares one
115-
function full_matrix(system::System{N}) where N
116-
dims = dimensions(system)
117-
range_dict = ranges(system)
118-
A = zeros(sum(dims), sum(dims))
119-
120-
for (i,row) in enumerate(system.matrix_entries.rowval)
121-
col = findfirst(x -> i<x, system.matrix_entries.colptr)-1
122-
A[range_dict[row],range_dict[col]] = system.matrix_entries[row,col].value
123-
end
124-
return A
125-
end
126-
127-
function full_matrix(system::System{N,<:Symmetric}) where N
128-
dims = dimensions(system)
129-
range_dict = ranges(system)
130-
A = zeros(sum(dims),sum(dims))
131-
132-
for (i,row) in enumerate(system.matrix_entries.rowval)
133-
col = findfirst(x -> i<x, system.matrix_entries.colptr)-1
134-
A[range_dict[row],range_dict[col]] = system.matrix_entries[row,col].value
135-
if col != row
136-
A[range_dict[col],range_dict[row]] = system.matrix_entries[row,col].value'
137-
end
138-
end
139-
return A
140-
end
141-
142-
full_vector(system::System) = vcat(getfield.(system.vector_entries,:value)...)
143-
144-
function randomize!(system::System, rand_function = randn)
145-
for entry in system.matrix_entries.nzval
146-
randomize!(entry, rand_function)
147-
end
148-
for entry in system.vector_entries
149-
randomize!(entry, rand_function)
150-
end
151-
end
152-
153-
function randomize!(system::System{N,<:Symmetric}, rand_function = randn) where N
154-
matrix_entries = system.matrix_entries
155-
156-
for entry in matrix_entries.nzval
157-
randomize!(entry, rand_function)
158-
end
159-
for i=1:N
160-
matrix_entries[i,i].value += matrix_entries[i,i].value'
161-
end
162-
163-
for entry in system.vector_entries
164-
randomize!(entry, rand_function)
165-
end
166-
end
167-
168-
function reset_inverse_diagonals!(system::System)
169-
for entry in system.diagonal_inverses
170-
entry.isinverted = false
171-
end
172-
end

0 commit comments

Comments
 (0)