Skip to content

Commit c88ea63

Browse files
yhmtsaiMarcelKoch
andcommitted
extract the split residual update and update test
Co-authored-by: Marcel Koch <marcel.koch@kit.edu>
1 parent b745765 commit c88ea63

5 files changed

Lines changed: 136 additions & 96 deletions

File tree

core/solver/chebyshev.cpp

Lines changed: 13 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
4040

4141
#include "core/distributed/helpers.hpp"
4242
#include "core/solver/ir_kernels.hpp"
43+
#include "core/solver/residual_update.hpp"
4344
#include "core/solver/solver_base.hpp"
4445
#include "core/solver/solver_boilerplate.hpp"
4546

@@ -266,52 +267,20 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
266267
int iter = -1;
267268
while (true) {
268269
++iter;
269-
if (iter == 0) {
270-
// In iter 0, the iteration and residual are updated.
271-
bool all_stopped = stop_criterion->update()
272-
.num_iterations(iter)
273-
.residual(residual_ptr)
274-
.solution(dense_x)
275-
.check(relative_stopping_id, true,
276-
&stop_status, &one_changed);
277-
this->template log<log::Logger::iteration_complete>(
278-
this, dense_b, dense_x, iter, residual_ptr, nullptr, nullptr,
279-
&stop_status, all_stopped);
280-
if (all_stopped) {
281-
break;
282-
}
283-
} else {
284-
// In the other iterations, the residual can be updated separately.
285-
bool all_stopped = stop_criterion->update()
286-
.num_iterations(iter)
287-
.solution(dense_x)
288-
// we have the residual check later
289-
.ignore_residual_check(true)
290-
.check(relative_stopping_id, false,
291-
&stop_status, &one_changed);
292-
if (all_stopped) {
293-
this->template log<log::Logger::iteration_complete>(
294-
this, dense_b, dense_x, iter, nullptr, nullptr, nullptr,
295-
&stop_status, all_stopped);
296-
break;
297-
}
298-
residual_ptr = residual;
299-
// residual = b - A * x
300-
residual->copy_from(dense_b);
301-
this->get_system_matrix()->apply(neg_one_op, dense_x, one_op,
302-
residual);
303-
all_stopped = stop_criterion->update()
304-
.num_iterations(iter)
305-
.residual(residual_ptr)
306-
.solution(dense_x)
307-
.check(relative_stopping_id, true, &stop_status,
308-
&one_changed);
270+
auto log_func = [this](auto solver, auto dense_b, auto dense_x,
271+
auto iter, auto residual_ptr,
272+
array<stopping_status>& stop_status,
273+
bool all_stopped) {
309274
this->template log<log::Logger::iteration_complete>(
310-
this, dense_b, dense_x, iter, residual_ptr, nullptr, nullptr,
275+
solver, dense_b, dense_x, iter, residual_ptr, nullptr, nullptr,
311276
&stop_status, all_stopped);
312-
if (all_stopped) {
313-
break;
314-
}
277+
};
278+
bool all_stopped = residual_update(
279+
this, iter, one_op, neg_one_op, dense_b, dense_x, residual,
280+
residual_ptr, stop_criterion, relative_stopping_id, stop_status,
281+
one_changed, log_func);
282+
if (all_stopped) {
283+
break;
315284
}
316285

317286
if (solver_->apply_uses_initial_guess()) {

core/solver/ir.cpp

Lines changed: 14 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
4040

4141
#include "core/distributed/helpers.hpp"
4242
#include "core/solver/ir_kernels.hpp"
43+
#include "core/solver/residual_update.hpp"
4344
#include "core/solver/solver_base.hpp"
4445
#include "core/solver/solver_boilerplate.hpp"
4546

@@ -223,54 +224,23 @@ void Ir<ValueType>::apply_dense_impl(const VectorType* dense_b,
223224
while (true) {
224225
++iter;
225226

226-
if (iter == 0) {
227-
// In iter 0, the iteration and residual are updated.
228-
bool all_stopped = stop_criterion->update()
229-
.num_iterations(iter)
230-
.residual(residual_ptr)
231-
.solution(dense_x)
232-
.check(relative_stopping_id, true,
233-
&stop_status, &one_changed);
227+
auto log_func = [this](auto solver, auto dense_b, auto dense_x,
228+
auto iter, auto residual_ptr,
229+
array<stopping_status>& stop_status,
230+
bool all_stopped) {
234231
this->template log<log::Logger::iteration_complete>(
235-
this, dense_b, dense_x, iter, residual_ptr, nullptr, nullptr,
232+
solver, dense_b, dense_x, iter, residual_ptr, nullptr, nullptr,
236233
&stop_status, all_stopped);
237-
if (all_stopped) {
238-
break;
239-
}
240-
} else {
241-
// In the other iterations, the residual can be updated separately.
242-
bool all_stopped = stop_criterion->update()
243-
.num_iterations(iter)
244-
.solution(dense_x)
245-
// we have the residual check later
246-
.ignore_residual_check(true)
247-
.check(relative_stopping_id, false,
248-
&stop_status, &one_changed);
249-
if (all_stopped) {
250-
this->template log<log::Logger::iteration_complete>(
251-
this, dense_b, dense_x, iter, nullptr, nullptr, nullptr,
252-
&stop_status, all_stopped);
253-
break;
254-
}
255-
residual_ptr = residual;
256-
// residual = b - A * x
257-
residual->copy_from(dense_b);
258-
this->get_system_matrix()->apply(neg_one_op, dense_x, one_op,
259-
residual);
260-
all_stopped = stop_criterion->update()
261-
.num_iterations(iter)
262-
.residual(residual_ptr)
263-
.solution(dense_x)
264-
.check(relative_stopping_id, true, &stop_status,
265-
&one_changed);
266-
this->template log<log::Logger::iteration_complete>(
267-
this, dense_b, dense_x, iter, residual_ptr, nullptr, nullptr,
268-
&stop_status, all_stopped);
269-
if (all_stopped) {
270-
break;
271-
}
234+
};
235+
bool all_stopped = residual_update(
236+
this, iter, one_op, neg_one_op, dense_b, dense_x, residual,
237+
residual_ptr, stop_criterion, relative_stopping_id, stop_status,
238+
one_changed, log_func);
239+
if (all_stopped) {
240+
break;
272241
}
273242

243+
274244
if (solver_->apply_uses_initial_guess()) {
275245
// Use the inner solver to solve
276246
// A * inner_solution = residual

core/solver/residual_update.hpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*******************************<GINKGO LICENSE>******************************
2+
Copyright (c) 2017-2023, the Ginkgo authors
3+
All rights reserved.
4+
5+
Redistribution and use in source and binary forms, with or without
6+
modification, are permitted provided that the following conditions
7+
are met:
8+
9+
1. Redistributions of source code must retain the above copyright
10+
notice, this list of conditions and the following disclaimer.
11+
12+
2. Redistributions in binary form must reproduce the above copyright
13+
notice, this list of conditions and the following disclaimer in the
14+
documentation and/or other materials provided with the distribution.
15+
16+
3. Neither the name of the copyright holder nor the names of its
17+
contributors may be used to endorse or promote products derived from
18+
this software without specific prior written permission.
19+
20+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
21+
IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
22+
TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
23+
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24+
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25+
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26+
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27+
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28+
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31+
******************************<GINKGO LICENSE>*******************************/
32+
33+
#ifndef GKO_CORE_SOLVER_RESIDUAL_UPDATE_HPP_
34+
#define GKO_CORE_SOLVER_RESIDUAL_UPDATE_HPP_
35+
36+
37+
#include <ginkgo/core/base/array.hpp>
38+
#include <ginkgo/core/matrix/dense.hpp>
39+
#include <ginkgo/core/stop/criterion.hpp>
40+
41+
42+
namespace gko {
43+
namespace solver {
44+
45+
46+
template <typename SolverType, typename VectorType, typename ScalarType,
47+
typename LogFunc>
48+
bool residual_update(SolverType* solver, int iter, const ScalarType* one_op,
49+
const ScalarType* neg_one_op, const VectorType* dense_b,
50+
VectorType* dense_x, VectorType* residual,
51+
const VectorType*& residual_ptr,
52+
std::unique_ptr<gko::stop::Criterion>& stop_criterion,
53+
uint8 relative_stopping_id,
54+
array<stopping_status>& stop_status, bool& one_changed,
55+
LogFunc log)
56+
{
57+
if (iter == 0) {
58+
// In iter 0, the iteration and residual are updated.
59+
bool all_stopped =
60+
stop_criterion->update()
61+
.num_iterations(iter)
62+
.residual(residual_ptr)
63+
.solution(dense_x)
64+
.check(relative_stopping_id, true, &stop_status, &one_changed);
65+
log(solver, dense_b, dense_x, iter, residual_ptr, stop_status,
66+
all_stopped);
67+
return all_stopped;
68+
} else {
69+
// In the other iterations, the residual can be updated separately.
70+
bool all_stopped =
71+
stop_criterion->update()
72+
.num_iterations(iter)
73+
.solution(dense_x)
74+
// we have the residual check later
75+
.ignore_residual_check(true)
76+
.check(relative_stopping_id, false, &stop_status, &one_changed);
77+
if (all_stopped) {
78+
log(solver, dense_b, dense_x, iter, nullptr, stop_status,
79+
all_stopped);
80+
return all_stopped;
81+
}
82+
residual_ptr = residual;
83+
// residual = b - A * x
84+
residual->copy_from(dense_b);
85+
solver->get_system_matrix()->apply(neg_one_op, dense_x, one_op,
86+
residual);
87+
all_stopped =
88+
stop_criterion->update()
89+
.num_iterations(iter)
90+
.residual(residual_ptr)
91+
.solution(dense_x)
92+
.check(relative_stopping_id, true, &stop_status, &one_changed);
93+
log(solver, dense_b, dense_x, iter, residual_ptr, stop_status,
94+
all_stopped);
95+
return all_stopped;
96+
}
97+
}
98+
99+
100+
} // namespace solver
101+
} // namespace gko
102+
103+
#endif // GKO_CORE_SOLVER_RESIDUAL_UPDATE_HPP_

core/test/solver/chebyshev.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,7 @@ TYPED_TEST(Chebyshev, CanSetInnerSolverInFactory)
195195
auto chebyshev_factory =
196196
Solver::build()
197197
.with_criteria(
198-
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec),
199-
gko::stop::ResidualNorm<value_type>::build()
200-
.with_reduction_factor(r<value_type>::value)
201-
.on(this->exec))
198+
gko::stop::Iteration::build().with_max_iters(3u).on(this->exec)
202199
.with_solver(
203200
Solver::build()
204201
.with_criteria(

include/ginkgo/core/solver/chebyshev.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ class Chebyshev : public EnableLinOp<Chebyshev<ValueType>>,
158158
GKO_FACTORY_PARAMETER_VECTOR(criteria, nullptr);
159159

160160
/**
161-
* Inner solver factory. If not provided this will result in a
162-
* non-preconditioned Chebyshev iteration.
161+
* Inner solver (preconditioner) factory. If not provided this will
162+
* result in a non-preconditioned Chebyshev iteration.
163163
*/
164164
std::shared_ptr<const LinOpFactory> GKO_FACTORY_PARAMETER_SCALAR(
165165
solver, nullptr);
@@ -172,8 +172,9 @@ class Chebyshev : public EnableLinOp<Chebyshev<ValueType>>,
172172
generated_solver, nullptr);
173173

174174
/**
175-
* The pair of foci of ellipse. It is usually be {lower bound of eigval,
176-
* upper bound of eigval} for real matrices.
175+
* The pair of foci of ellipse, which covers the eigenvalues of
176+
* preconditioned system. It is usually be {lower bound of eigval, upper
177+
* bound of eigval} of preconditioned real matrices.
177178
*/
178179
std::pair<value_type, value_type> GKO_FACTORY_PARAMETER_VECTOR(
179180
foci, value_type{0}, value_type{1});

0 commit comments

Comments
 (0)