@@ -75,6 +75,117 @@ GKO_INSTANTIATE_FOR_EACH_INDEX_TYPE(
7575 GKO_DECLARE_ELIMINATION_FOREST_COMPUTE_SKELETON_TREE);
7676
7777
78+ template <typename IndexType>
79+ void compute_elimination_forest_parent_impl (
80+ std::shared_ptr<const Executor> host_exec, const IndexType* row_ptrs,
81+ const IndexType* cols, IndexType num_rows, IndexType* parent)
82+ {
83+ disjoint_sets<IndexType> subtrees{host_exec, num_rows};
84+ array<IndexType> subtree_root_array{host_exec,
85+ static_cast <size_type>(num_rows)};
86+ // pseudo-root one past the last row to deal with disconnected matrices
87+ const auto unattached = num_rows;
88+ auto subtree_root = subtree_root_array.get_data ();
89+ for (IndexType row = 0 ; row < num_rows; row++) {
90+ // so far the row is an unattached singleton subtree
91+ subtree_root[row] = row;
92+ parent[row] = unattached;
93+ auto row_rep = row;
94+ for (auto nz = row_ptrs[row]; nz < row_ptrs[row + 1 ]; nz++) {
95+ const auto col = cols[nz];
96+ // for each lower triangular entry
97+ if (col < row) {
98+ // find the subtree it is contained in
99+ const auto col_rep = subtrees.find (col);
100+ const auto col_root = subtree_root[col_rep];
101+ // if it is not yet attached, put it below row
102+ // and make row its new root
103+ if (parent[col_root] == unattached && col_root != row) {
104+ parent[col_root] = row;
105+ row_rep = subtrees.join (row_rep, col_rep);
106+ subtree_root[row_rep] = row;
107+ }
108+ }
109+ }
110+ }
111+ }
112+
113+
114+ template <typename IndexType>
115+ void compute_elimination_forest_children_impl (const IndexType* parent,
116+ IndexType size,
117+ IndexType* child_ptr,
118+ IndexType* child)
119+ {
120+ // count how many times each parent occurs, excluding pseudo-root at
121+ // parent == size
122+ std::fill_n (child_ptr, size + 2 , IndexType{});
123+ for (IndexType i = 0 ; i < size; i++) {
124+ const auto p = parent[i];
125+ if (p < size) {
126+ child_ptr[p + 2 ]++;
127+ }
128+ }
129+ // shift by 2 leads to exclusive prefix sum with 0 padding
130+ std::partial_sum (child_ptr, child_ptr + size + 2 , child_ptr);
131+ // we count the same again, this time shifted by 1 => exclusive prefix sum
132+ for (IndexType i = 0 ; i < size; i++) {
133+ const auto p = parent[i];
134+ child[child_ptr[p + 1 ]] = i;
135+ child_ptr[p + 1 ]++;
136+ }
137+ }
138+
139+
140+ template <typename IndexType>
141+ void compute_elimination_forest_postorder_impl (
142+ std::shared_ptr<const Executor> host_exec, const IndexType* parent,
143+ const IndexType* child_ptr, const IndexType* child, IndexType size,
144+ IndexType* postorder, IndexType* inv_postorder)
145+ {
146+ array<IndexType> current_child_array{host_exec,
147+ static_cast <size_type>(size + 1 )};
148+ current_child_array.fill (0 );
149+ auto current_child = current_child_array.get_data ();
150+ IndexType postorder_idx{};
151+ // for each tree in the elimination forest
152+ for (auto tree = child_ptr[size]; tree < child_ptr[size + 1 ]; tree++) {
153+ // start from the root
154+ const auto root = child[tree];
155+ auto cur_node = root;
156+ // traverse until we moved to the pseudo-root
157+ while (cur_node < size) {
158+ const auto first_child = child_ptr[cur_node];
159+ const auto num_children = child_ptr[cur_node + 1 ] - first_child;
160+ if (current_child[cur_node] >= num_children) {
161+ // if this node is completed, output it
162+ postorder[postorder_idx] = cur_node;
163+ inv_postorder[cur_node] = postorder_idx;
164+ cur_node = parent[cur_node];
165+ postorder_idx++;
166+ } else {
167+ // otherwise go to the next child node
168+ const auto old_node = cur_node;
169+ cur_node = child[first_child + current_child[old_node]];
170+ current_child[old_node]++;
171+ }
172+ }
173+ }
174+ }
175+
176+
177+ template <typename IndexType>
178+ void compute_elimination_forest_postorder_parent_impl (
179+ const IndexType* parent, const IndexType* inv_postorder, IndexType size,
180+ IndexType* postorder_parent)
181+ {
182+ for (IndexType row = 0 ; row < size; row++) {
183+ postorder_parent[inv_postorder[row]] =
184+ parent[row] == size ? size : inv_postorder[parent[row]];
185+ }
186+ }
187+
188+
78189template <typename IndexType>
79190void compute (std::shared_ptr<const DefaultExecutor> exec,
80191 const IndexType* row_ptrs, const IndexType* cols, size_type size,
0 commit comments