Skip to content

Commit 3de611b

Browse files
committed
wip reference implementation of full elimination tree computation
1 parent 28c7478 commit 3de611b

1 file changed

Lines changed: 111 additions & 0 deletions

File tree

reference/factorization/elimination_forest_kernels.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,117 @@ GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(
7575
GKO_DECLARE_ELIMINATION_FOREST_COMPUTE_SKELETON_TREE);
7676

7777

78+
template <typename IndexType>
79+
void compute_elimination_forest_parent_impl(
80+
std::shared_ptr<const Executor> host_exec, const IndexType* row_ptrs,
81+
const IndexType* cols, IndexType num_rows, IndexType* parent)
82+
{
83+
disjoint_sets<IndexType> subtrees{host_exec, num_rows};
84+
array<IndexType> subtree_root_array{host_exec,
85+
static_cast<size_type>(num_rows)};
86+
// pseudo-root one past the last row to deal with disconnected matrices
87+
const auto unattached = num_rows;
88+
auto subtree_root = subtree_root_array.get_data();
89+
for (IndexType row = 0; row < num_rows; row++) {
90+
// so far the row is an unattached singleton subtree
91+
subtree_root[row] = row;
92+
parent[row] = unattached;
93+
auto row_rep = row;
94+
for (auto nz = row_ptrs[row]; nz < row_ptrs[row + 1]; nz++) {
95+
const auto col = cols[nz];
96+
// for each lower triangular entry
97+
if (col < row) {
98+
// find the subtree it is contained in
99+
const auto col_rep = subtrees.find(col);
100+
const auto col_root = subtree_root[col_rep];
101+
// if it is not yet attached, put it below row
102+
// and make row its new root
103+
if (parent[col_root] == unattached && col_root != row) {
104+
parent[col_root] = row;
105+
row_rep = subtrees.join(row_rep, col_rep);
106+
subtree_root[row_rep] = row;
107+
}
108+
}
109+
}
110+
}
111+
}
112+
113+
114+
// Builds a CSR-like child list from a parent array.
//
// parent:    size entries; a value of size denotes the pseudo-root.
// child_ptr: size + 2 entries (output). On return the children of node p
//            are child[child_ptr[p]] .. child[child_ptr[p + 1] - 1], where
//            p == size addresses the nodes attached to the pseudo-root.
// child:     size entries (output), grouped by parent, ascending node index
//            within each group.
template <typename IndexType>
void compute_elimination_forest_children_impl(const IndexType* parent,
                                              IndexType size,
                                              IndexType* child_ptr,
                                              IndexType* child)
{
    const auto num_ptrs = size + 2;
    std::fill_n(child_ptr, num_ptrs, IndexType{});
    // histogram of the parents, stored two slots to the right; the
    // pseudo-root is left out because its slot would be out of bounds
    for (IndexType node = 0; node < size; ++node) {
        const auto node_parent = parent[node];
        if (node_parent >= size) {
            continue;
        }
        ++child_ptr[node_parent + 2];
    }
    // inclusive scan over the shifted counts: child_ptr[p + 1] now holds
    // the first output position for the children of p
    std::partial_sum(child_ptr, child_ptr + num_ptrs, child_ptr);
    // scatter the nodes; advancing each write position moves
    // child_ptr[p + 1] to the end of p's range, i.e. the begin of p + 1
    for (IndexType node = 0; node < size; ++node) {
        auto& write_pos = child_ptr[parent[node] + 1];
        child[write_pos] = node;
        ++write_pos;
    }
}
138+
139+
140+
template <typename IndexType>
141+
void compute_elimination_forest_postorder_impl(
142+
std::shared_ptr<const Executor> host_exec, const IndexType* parent,
143+
const IndexType* child_ptr, const IndexType* child, IndexType size,
144+
IndexType* postorder, IndexType* inv_postorder)
145+
{
146+
array<IndexType> current_child_array{host_exec,
147+
static_cast<size_type>(size + 1)};
148+
current_child_array.fill(0);
149+
auto current_child = current_child_array.get_data();
150+
IndexType postorder_idx{};
151+
// for each tree in the elimination forest
152+
for (auto tree = child_ptr[size]; tree < child_ptr[size + 1]; tree++) {
153+
// start from the root
154+
const auto root = child[tree];
155+
auto cur_node = root;
156+
// traverse until we moved to the pseudo-root
157+
while (cur_node < size) {
158+
const auto first_child = child_ptr[cur_node];
159+
const auto num_children = child_ptr[cur_node + 1] - first_child;
160+
if (current_child[cur_node] >= num_children) {
161+
// if this node is completed, output it
162+
postorder[postorder_idx] = cur_node;
163+
inv_postorder[cur_node] = postorder_idx;
164+
cur_node = parent[cur_node];
165+
postorder_idx++;
166+
} else {
167+
// otherwise go to the next child node
168+
const auto old_node = cur_node;
169+
cur_node = child[first_child + current_child[old_node]];
170+
current_child[old_node]++;
171+
}
172+
}
173+
}
174+
}
175+
176+
177+
// Translates the parent array into postorder numbering:
// for every node, postorder_parent[inv_postorder[node]] receives the
// postorder index of the node's parent. The pseudo-root value size is
// carried over unchanged, since it has no entry in inv_postorder.
template <typename IndexType>
void compute_elimination_forest_postorder_parent_impl(
    const IndexType* parent, const IndexType* inv_postorder, IndexType size,
    IndexType* postorder_parent)
{
    for (IndexType node = 0; node < size; ++node) {
        const auto node_parent = parent[node];
        auto& mapped_parent = postorder_parent[inv_postorder[node]];
        if (node_parent == size) {
            mapped_parent = size;
        } else {
            mapped_parent = inv_postorder[node_parent];
        }
    }
}
187+
188+
78189
template <typename IndexType>
79190
void compute(std::shared_ptr<const DefaultExecutor> exec,
80191
const IndexType* row_ptrs, const IndexType* cols, size_type size,

0 commit comments

Comments
 (0)