Skip to content

Commit 2498a1a

Browse files
committed
Improve the M-matrix check in RS
1 parent 206cd51 commit 2498a1a

3 files changed

Lines changed: 17 additions & 17 deletions

File tree

core/multigrid/rs.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,14 @@ void Rs<ValueType, IndexType>::generate()
5959
this->set_fine_op(rs_op_shared_ptr);
6060
}
6161
array<bool> is_m_matrix_array(exec, 1);
62-
exec->run(rs::make_check_m_matrix(rs_op, is_m_matrix_array));
63-
if (!exec->copy_val_to_host(is_m_matrix_array.get_const_data())) {
64-
GKO_NOT_SUPPORTED(
65-
"RS coarsening requires an M-matrix (strictly positive diagonal, "
66-
"non-positive off-diagonals).");
62+
if (!parameters_.skip_m_matrix_check) {
63+
exec->run(rs::make_check_m_matrix(rs_op, is_m_matrix_array));
64+
if (!exec->copy_val_to_host(is_m_matrix_array.get_const_data())) {
65+
GKO_NOT_SUPPORTED(
66+
"RS coarsening requires an M-matrix (strictly positive "
67+
"diagonal, "
68+
"non-positive off-diagonals).");
69+
}
6770
}
6871

6972
// define arrays

include/ginkgo/core/multigrid/rs.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ class Rs : public EnableLinOp<Rs<ValueType, IndexType>>,
7878
strength_threshold, 0.25);
7979

8080
bool GKO_FACTORY_PARAMETER_SCALAR(skip_sorting, false);
81+
82+
bool GKO_FACTORY_PARAMETER_SCALAR(skip_m_matrix_check, false);
8183
};
8284
GKO_ENABLE_LIN_OP_FACTORY(Rs, parameters, Factory);
8385
GKO_ENABLE_BUILD_METHOD(Factory);

reference/multigrid/rs_kernels.cpp

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ void check_m_matrix(std::shared_ptr<const ReferenceExecutor> exec,
3434
const auto col_idxs = matrix->get_const_col_idxs();
3535
const auto values = matrix->get_const_values();
3636

37-
bool is_m_matrix = true;
37+
auto is_m_matrix = is_m_matrix_array.get_data();
3838

3939
for (size_type row = 0; row < num_rows; ++row) {
4040
bool has_diag = false;
@@ -46,27 +46,22 @@ void check_m_matrix(std::shared_ptr<const ReferenceExecutor> exec,
4646
if (row == col) {
4747
has_diag = true;
4848
if (val <= 0.0) {
49-
is_m_matrix = false;
50-
break;
49+
*is_m_matrix = false;
50+
return;
5151
}
5252
} else {
5353
if (val > 0.0) {
54-
is_m_matrix = false;
55-
break;
54+
*is_m_matrix = false;
55+
return;
5656
}
5757
}
5858
}
5959

6060
if (!has_diag) {
61-
is_m_matrix = false;
62-
}
63-
64-
if (!is_m_matrix) {
65-
break;
61+
*is_m_matrix = false;
62+
return;
6663
}
6764
}
68-
69-
is_m_matrix_array.get_data()[0] = is_m_matrix;
7065
}
7166

7267
GKO_INSTANTIATE_FOR_EACH_VALUE_AND_INDEX_TYPE(

0 commit comments

Comments
 (0)