Skip to content

Commit b745765

Browse files
committed
use iteration from stop criterion and update doc
1 parent 7f35a9a commit b745765

3 files changed

Lines changed: 246 additions & 70 deletions

File tree

core/solver/chebyshev.cpp

Lines changed: 90 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,37 @@ GKO_REGISTER_OPERATION(initialize, ir::initialize);
5757
} // namespace chebyshev
5858

5959

60+
template <typename ValueType>
61+
Chebyshev<ValueType>::Chebyshev(const Factory* factory,
62+
std::shared_ptr<const LinOp> system_matrix)
63+
: EnableLinOp<Chebyshev>(factory->get_executor(),
64+
gko::transpose(system_matrix->get_size())),
65+
EnableSolverBase<Chebyshev>{std::move(system_matrix)},
66+
EnableIterativeBase<Chebyshev>{
67+
stop::combine(factory->get_parameters().criteria)},
68+
parameters_{factory->get_parameters()}
69+
{
70+
if (parameters_.generated_solver) {
71+
this->set_solver(parameters_.generated_solver);
72+
} else if (parameters_.solver) {
73+
this->set_solver(
74+
parameters_.solver->generate(this->get_system_matrix()));
75+
} else {
76+
this->set_solver(matrix::Identity<ValueType>::create(
77+
this->get_executor(), this->get_size()));
78+
}
79+
this->set_default_initial_guess(parameters_.default_initial_guess);
80+
center_ = (std::get<0>(parameters_.foci) + std::get<1>(parameters_.foci)) /
81+
ValueType{2};
82+
foci_direction_ =
83+
(std::get<1>(parameters_.foci) - std::get<0>(parameters_.foci)) /
84+
ValueType{2};
85+
// if changing the lower/upper eig, need to reset it to zero
86+
num_generated_scalar_ = 0;
87+
num_max_generation_ = 3;
88+
}
89+
90+
6091
template <typename ValueType>
6192
void Chebyshev<ValueType>::set_solver(std::shared_ptr<const LinOp> new_solver)
6293
{
@@ -185,12 +216,29 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
185216
GKO_SOLVER_VECTOR(residual, dense_b);
186217
GKO_SOLVER_VECTOR(inner_solution, dense_b);
187218
GKO_SOLVER_VECTOR(update_solution, dense_b);
219+
188220
// Use the scalar first
189-
auto num_keep = this->get_parameters().num_keep;
221+
// get the iteration information from stopping criterion.
222+
if (auto combined =
223+
std::dynamic_pointer_cast<const gko::stop::Combined::Factory>(
224+
this->get_stop_criterion_factory())) {
225+
for (const auto& factory : combined->get_parameters().criteria) {
226+
if (auto iter_stop = std::dynamic_pointer_cast<
227+
const gko::stop::Iteration::Factory>(factory)) {
228+
num_max_generation_ = std::max(
229+
num_max_generation_, iter_stop->get_parameters().max_iters);
230+
}
231+
}
232+
} else if (auto iter_stop = std::dynamic_pointer_cast<
233+
const gko::stop::Iteration::Factory>(
234+
this->get_stop_criterion_factory())) {
235+
num_max_generation_ = std::max(num_max_generation_,
236+
iter_stop->get_parameters().max_iters);
237+
}
190238
auto alpha = this->template create_workspace_scalar<ValueType>(
191-
GKO_SOLVER_TRAITS::alpha, num_keep + 1);
239+
GKO_SOLVER_TRAITS::alpha, num_max_generation_ + 1);
192240
auto beta = this->template create_workspace_scalar<ValueType>(
193-
GKO_SOLVER_TRAITS::beta, num_keep + 1);
241+
GKO_SOLVER_TRAITS::beta, num_max_generation_ + 1);
194242

195243
GKO_SOLVER_ONE_MINUS_ONE();
196244

@@ -218,39 +266,50 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
218266
int iter = -1;
219267
while (true) {
220268
++iter;
221-
this->template log<log::Logger::iteration_complete>(
222-
this, iter, residual_ptr, dense_x);
223-
224269
if (iter == 0) {
225270
// In iter 0, the iteration and residual are updated.
226-
if (stop_criterion->update()
227-
.num_iterations(iter)
228-
.residual(residual_ptr)
229-
.solution(dense_x)
230-
.check(relative_stopping_id, true, &stop_status,
231-
&one_changed)) {
271+
bool all_stopped = stop_criterion->update()
272+
.num_iterations(iter)
273+
.residual(residual_ptr)
274+
.solution(dense_x)
275+
.check(relative_stopping_id, true,
276+
&stop_status, &one_changed);
277+
this->template log<log::Logger::iteration_complete>(
278+
this, dense_b, dense_x, iter, residual_ptr, nullptr, nullptr,
279+
&stop_status, all_stopped);
280+
if (all_stopped) {
232281
break;
233282
}
234283
} else {
235284
// In the other iterations, the residual can be updated separately.
236-
if (stop_criterion->update()
237-
.num_iterations(iter)
238-
.solution(dense_x)
239-
.check(relative_stopping_id, false, &stop_status,
240-
&one_changed)) {
285+
bool all_stopped = stop_criterion->update()
286+
.num_iterations(iter)
287+
.solution(dense_x)
288+
// we have the residual check later
289+
.ignore_residual_check(true)
290+
.check(relative_stopping_id, false,
291+
&stop_status, &one_changed);
292+
if (all_stopped) {
293+
this->template log<log::Logger::iteration_complete>(
294+
this, dense_b, dense_x, iter, nullptr, nullptr, nullptr,
295+
&stop_status, all_stopped);
241296
break;
242297
}
243298
residual_ptr = residual;
244299
// residual = b - A * x
245300
residual->copy_from(dense_b);
246301
this->get_system_matrix()->apply(neg_one_op, dense_x, one_op,
247302
residual);
248-
if (stop_criterion->update()
249-
.num_iterations(iter)
250-
.residual(residual_ptr)
251-
.solution(dense_x)
252-
.check(relative_stopping_id, true, &stop_status,
253-
&one_changed)) {
303+
all_stopped = stop_criterion->update()
304+
.num_iterations(iter)
305+
.residual(residual_ptr)
306+
.solution(dense_x)
307+
.check(relative_stopping_id, true, &stop_status,
308+
&one_changed);
309+
this->template log<log::Logger::iteration_complete>(
310+
this, dense_b, dense_x, iter, residual_ptr, nullptr, nullptr,
311+
&stop_status, all_stopped);
312+
if (all_stopped) {
254313
break;
255314
}
256315
}
@@ -262,17 +321,18 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
262321
inner_solution->copy_from(residual_ptr);
263322
}
264323
solver_->apply(residual_ptr, inner_solution);
265-
size_type index = (iter >= num_keep) ? num_keep : iter;
324+
size_type index =
325+
(iter >= num_max_generation_) ? num_max_generation_ : iter;
266326
auto alpha_scalar =
267327
alpha->create_submatrix(span{0, 1}, span{index, index + 1});
268328
auto beta_scalar =
269329
beta->create_submatrix(span{0, 1}, span{index, index + 1});
270330
if (iter == 0) {
271-
if (num_generated_ < num_keep) {
331+
if (num_generated_scalar_ < num_max_generation_) {
272332
alpha_scalar->fill(alpha_ref);
273333
// unused beta for first iteration, but fill zero
274334
beta_scalar->fill(zero<ValueType>());
275-
num_generated_++;
335+
num_generated_scalar_++;
276336
}
277337
// x = x + alpha * inner_solution
278338
dense_x->add_scaled(alpha_scalar.get(), inner_solution);
@@ -286,12 +346,13 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
286346
}
287347
alpha_ref = ValueType{1.0} / (center_ - beta_ref / alpha_ref);
288348
// The last one is always the updated one
289-
if (num_generated_ < num_keep || iter >= num_keep) {
349+
if (num_generated_scalar_ < num_max_generation_ ||
350+
iter >= num_max_generation_) {
290351
alpha_scalar->fill(alpha_ref);
291352
beta_scalar->fill(beta_ref);
292353
}
293-
if (num_generated_ < num_keep) {
294-
num_generated_++;
354+
if (num_generated_scalar_ < num_max_generation_) {
355+
num_generated_scalar_++;
295356
}
296357
// z = z + beta * p
297358
// p = z

include/ginkgo/core/solver/chebyshev.hpp

Lines changed: 14 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,12 @@ namespace solver {
5353

5454

5555
/**
56-
* Chebyshev iteration is an iterative method that uses another coarse
57-
* method to approximate the error of the current solution via the current
56+
* Chebyshev iteration is an iterative method that uses another inner
57+
* solver to approximate the error of the current solution via the current
5858
* residual. It has another term for the difference of solution. Moreover, this
59-
* method requires knowledge about the spectrum of the matrix.
59+
* method requires knowledge about the spectrum of the matrix. This
60+
* implementation follows the algorithm in "Templates for the Solution of Linear
61+
* Systems: Building Blocks for Iterative Methods, 2nd Edition".
6062
*
6163
* ```
6264
* solution = initial_guess
@@ -156,7 +158,8 @@ class Chebyshev : public EnableLinOp<Chebyshev<ValueType>>,
156158
GKO_FACTORY_PARAMETER_VECTOR(criteria, nullptr);
157159

158160
/**
159-
* Inner solver factory.
161+
* Inner solver factory. If not provided this will result in a
162+
* non-preconditioned Chebyshev iteration.
160163
*/
161164
std::shared_ptr<const LinOpFactory> GKO_FACTORY_PARAMETER_SCALAR(
162165
solver, nullptr);
@@ -181,11 +184,6 @@ class Chebyshev : public EnableLinOp<Chebyshev<ValueType>>,
181184
*/
182185
initial_guess_mode GKO_FACTORY_PARAMETER_SCALAR(
183186
default_initial_guess, initial_guess_mode::provided);
184-
185-
/**
186-
* The number of scalar to keep
187-
*/
188-
int GKO_FACTORY_PARAMETER_SCALAR(num_keep, 2);
189187
};
190188
GKO_ENABLE_LIN_OP_FACTORY(Chebyshev, parameters, Factory);
191189
GKO_ENABLE_BUILD_METHOD(Factory);
@@ -215,38 +213,16 @@ class Chebyshev : public EnableLinOp<Chebyshev<ValueType>>,
215213
{}
216214

217215
explicit Chebyshev(const Factory* factory,
218-
std::shared_ptr<const LinOp> system_matrix)
219-
: EnableLinOp<Chebyshev>(factory->get_executor(),
220-
gko::transpose(system_matrix->get_size())),
221-
EnableSolverBase<Chebyshev>{std::move(system_matrix)},
222-
EnableIterativeBase<Chebyshev>{
223-
stop::combine(factory->get_parameters().criteria)},
224-
parameters_{factory->get_parameters()}
225-
{
226-
if (parameters_.generated_solver) {
227-
this->set_solver(parameters_.generated_solver);
228-
} else if (parameters_.solver) {
229-
this->set_solver(
230-
parameters_.solver->generate(this->get_system_matrix()));
231-
} else {
232-
this->set_solver(matrix::Identity<ValueType>::create(
233-
this->get_executor(), this->get_size()));
234-
}
235-
this->set_default_initial_guess(parameters_.default_initial_guess);
236-
center_ =
237-
(std::get<0>(parameters_.foci) + std::get<1>(parameters_.foci)) /
238-
ValueType{2};
239-
// the absolute value of foci_direction is the focal direction
240-
foci_direction_ =
241-
(std::get<1>(parameters_.foci) - std::get<0>(parameters_.foci)) /
242-
ValueType{2};
243-
// if changing the lower/upper eig, need to reset it to zero
244-
num_generated_ = 0;
245-
}
216+
std::shared_ptr<const LinOp> system_matrix);
246217

247218
private:
248219
std::shared_ptr<const LinOp> solver_{};
249-
mutable int num_generated_;
220+
// num_generated_scalar_ is to track the number of generated scalar alpha
221+
// and beta.
222+
mutable size_type num_generated_scalar_;
223+
// num_max_generation_ is the number of keeping the generated scalar in
224+
// workspace.
225+
mutable size_type num_max_generation_;
250226
ValueType center_;
251227
ValueType foci_direction_;
252228
};

0 commit comments

Comments
 (0)