|
1 | | -// SPDX-FileCopyrightText: 2017 - 2024 The Ginkgo authors |
| 1 | +// SPDX-FileCopyrightText: 2017 - 2025 The Ginkgo authors |
2 | 2 | // |
3 | 3 | // SPDX-License-Identifier: BSD-3-Clause |
4 | 4 |
|
@@ -45,11 +45,12 @@ struct SolveStruct : gko::solver::SolveStruct { |
45 | 45 | csrsv2Info_t solve_info; |
46 | 46 | hipsparseSolvePolicy_t policy; |
47 | 47 | hipsparseMatDescr_t factor_descr; |
48 | | - int factor_work_size; |
| 48 | + array<char> factor_work_array; |
49 | 49 | void* factor_work_vec; |
50 | | - SolveStruct(bool is_upper, bool unit_diag) |
| 50 | + SolveStruct(std::shared_ptr<const Executor> exec, bool is_upper, |
| 51 | + bool unit_diag) |
| 52 | + : factor_work_array{exec} |
51 | 53 | { |
52 | | - factor_work_vec = nullptr; |
53 | 54 | GKO_ASSERT_NO_HIPSPARSE_ERRORS(hipsparseCreateMatDescr(&factor_descr)); |
54 | 55 | GKO_ASSERT_NO_HIPSPARSE_ERRORS( |
55 | 56 | hipsparseSetMatIndexBase(factor_descr, HIPSPARSE_INDEX_BASE_ZERO)); |
@@ -79,10 +80,6 @@ struct SolveStruct : gko::solver::SolveStruct { |
79 | 80 | if (solve_info) { |
80 | 81 | hipsparseDestroyCsrsv2Info(solve_info); |
81 | 82 | } |
82 | | - if (factor_work_vec != nullptr) { |
83 | | - hipFree(factor_work_vec); |
84 | | - factor_work_vec = nullptr; |
85 | | - } |
86 | 83 | } |
87 | 84 | }; |
88 | 85 |
|
@@ -114,29 +111,31 @@ void generate_kernel(std::shared_ptr<const HipExecutor> exec, |
114 | 111 | return; |
115 | 112 | } |
116 | 113 | if (sparselib::is_supported<ValueType, IndexType>::value) { |
117 | | - solve_struct = |
118 | | - std::make_shared<solver::hip::SolveStruct>(is_upper, unit_diag); |
| 114 | + solve_struct = std::make_shared<solver::hip::SolveStruct>( |
| 115 | + exec, is_upper, unit_diag); |
119 | 116 | if (auto hip_solve_struct = |
120 | 117 | std::dynamic_pointer_cast<solver::hip::SolveStruct>( |
121 | 118 | solve_struct)) { |
122 | 119 | auto handle = exec->get_sparselib_handle(); |
123 | 120 |
|
124 | 121 | { |
125 | 122 | sparselib::pointer_mode_guard pm_guard(handle); |
| 123 | + int factor_work_size{}; |
126 | 124 | sparselib::csrsv2_buffer_size( |
127 | 125 | handle, SPARSELIB_OPERATION_NON_TRANSPOSE, |
128 | 126 | matrix->get_size()[0], matrix->get_num_stored_elements(), |
129 | 127 | hip_solve_struct->factor_descr, matrix->get_const_values(), |
130 | 128 | matrix->get_const_row_ptrs(), matrix->get_const_col_idxs(), |
131 | | - hip_solve_struct->solve_info, |
132 | | - &hip_solve_struct->factor_work_size); |
| 129 | + hip_solve_struct->solve_info, &factor_work_size); |
133 | 130 |
|
134 | 131 | // allocate workspace |
135 | | - if (hip_solve_struct->factor_work_vec != nullptr) { |
136 | | - exec->free(hip_solve_struct->factor_work_vec); |
| 132 | + if (hip_solve_struct->factor_work_array.get_size() < |
| 133 | + factor_work_size) { |
| 134 | + hip_solve_struct->factor_work_array.resize_and_reset( |
| 135 | + factor_work_size); |
| 136 | + hip_solve_struct->factor_work_vec = |
| 137 | + hip_solve_struct->factor_work_array.get_data(); |
137 | 138 | } |
138 | | - hip_solve_struct->factor_work_vec = |
139 | | - exec->alloc<void*>(hip_solve_struct->factor_work_size); |
140 | 139 |
|
141 | 140 | sparselib::csrsv2_analysis( |
142 | 141 | handle, SPARSELIB_OPERATION_NON_TRANSPOSE, |
|
0 commit comments