@@ -57,6 +57,37 @@ GKO_REGISTER_OPERATION(initialize, ir::initialize);
5757} // namespace chebyshev
5858
5959
60+ template <typename ValueType>
61+ Chebyshev<ValueType>::Chebyshev(const Factory* factory,
62+ std::shared_ptr<const LinOp> system_matrix)
63+ : EnableLinOp<Chebyshev>(factory->get_executor (),
64+ gko::transpose(system_matrix->get_size ())),
65+ EnableSolverBase<Chebyshev>{std::move (system_matrix)},
66+ EnableIterativeBase<Chebyshev>{
67+ stop::combine (factory->get_parameters ().criteria )},
68+ parameters_{factory->get_parameters ()}
69+ {
70+ if (parameters_.generated_solver ) {
71+ this ->set_solver (parameters_.generated_solver );
72+ } else if (parameters_.solver ) {
73+ this ->set_solver (
74+ parameters_.solver ->generate (this ->get_system_matrix ()));
75+ } else {
76+ this ->set_solver (matrix::Identity<ValueType>::create (
77+ this ->get_executor (), this ->get_size ()));
78+ }
79+ this ->set_default_initial_guess (parameters_.default_initial_guess );
80+ center_ = (std::get<0 >(parameters_.foci ) + std::get<1 >(parameters_.foci )) /
81+ ValueType{2 };
82+ foci_direction_ =
83+ (std::get<1 >(parameters_.foci ) - std::get<0 >(parameters_.foci )) /
84+ ValueType{2 };
85+ // if changing the lower/upper eig, need to reset it to zero
86+ num_generated_scalar_ = 0 ;
87+ num_max_generation_ = 3 ;
88+ }
89+
90+
6091template <typename ValueType>
6192void Chebyshev<ValueType>::set_solver(std::shared_ptr<const LinOp> new_solver)
6293{
@@ -185,12 +216,29 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
185216 GKO_SOLVER_VECTOR (residual, dense_b);
186217 GKO_SOLVER_VECTOR (inner_solution, dense_b);
187218 GKO_SOLVER_VECTOR (update_solution, dense_b);
219+
188220 // Use the scalar first
189- auto num_keep = this ->get_parameters ().num_keep ;
221+ // get the iteration information from stopping criterion.
222+ if (auto combined =
223+ std::dynamic_pointer_cast<const gko::stop::Combined::Factory>(
224+ this ->get_stop_criterion_factory ())) {
225+ for (const auto & factory : combined->get_parameters ().criteria ) {
226+ if (auto iter_stop = std::dynamic_pointer_cast<
227+ const gko::stop::Iteration::Factory>(factory)) {
228+ num_max_generation_ = std::max (
229+ num_max_generation_, iter_stop->get_parameters ().max_iters );
230+ }
231+ }
232+ } else if (auto iter_stop = std::dynamic_pointer_cast<
233+ const gko::stop::Iteration::Factory>(
234+ this ->get_stop_criterion_factory ())) {
235+ num_max_generation_ = std::max (num_max_generation_,
236+ iter_stop->get_parameters ().max_iters );
237+ }
190238 auto alpha = this ->template create_workspace_scalar <ValueType>(
191- GKO_SOLVER_TRAITS ::alpha, num_keep + 1 );
239+ GKO_SOLVER_TRAITS ::alpha, num_max_generation_ + 1 );
192240 auto beta = this ->template create_workspace_scalar <ValueType>(
193- GKO_SOLVER_TRAITS ::beta, num_keep + 1 );
241+ GKO_SOLVER_TRAITS ::beta, num_max_generation_ + 1 );
194242
195243 GKO_SOLVER_ONE_MINUS_ONE ();
196244
@@ -218,39 +266,50 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
218266 int iter = -1 ;
219267 while (true ) {
220268 ++iter;
221- this ->template log <log::Logger::iteration_complete>(
222- this , iter, residual_ptr, dense_x);
223-
224269 if (iter == 0 ) {
225270 // In iter 0, the iteration and residual are updated.
226- if (stop_criterion->update ()
227- .num_iterations (iter)
228- .residual (residual_ptr)
229- .solution (dense_x)
230- .check (relative_stopping_id, true , &stop_status,
231- &one_changed)) {
271+ bool all_stopped = stop_criterion->update ()
272+ .num_iterations (iter)
273+ .residual (residual_ptr)
274+ .solution (dense_x)
275+ .check (relative_stopping_id, true ,
276+ &stop_status, &one_changed);
277+ this ->template log <log::Logger::iteration_complete>(
278+ this , dense_b, dense_x, iter, residual_ptr, nullptr , nullptr ,
279+ &stop_status, all_stopped);
280+ if (all_stopped) {
232281 break ;
233282 }
234283 } else {
235284 // In the other iterations, the residual can be updated separately.
236- if (stop_criterion->update ()
237- .num_iterations (iter)
238- .solution (dense_x)
239- .check (relative_stopping_id, false , &stop_status,
240- &one_changed)) {
285+ bool all_stopped = stop_criterion->update ()
286+ .num_iterations (iter)
287+ .solution (dense_x)
288+ // we have the residual check later
289+ .ignore_residual_check (true )
290+ .check (relative_stopping_id, false ,
291+ &stop_status, &one_changed);
292+ if (all_stopped) {
293+ this ->template log <log::Logger::iteration_complete>(
294+ this , dense_b, dense_x, iter, nullptr , nullptr , nullptr ,
295+ &stop_status, all_stopped);
241296 break ;
242297 }
243298 residual_ptr = residual;
244299 // residual = b - A * x
245300 residual->copy_from (dense_b);
246301 this ->get_system_matrix ()->apply (neg_one_op, dense_x, one_op,
247302 residual);
248- if (stop_criterion->update ()
249- .num_iterations (iter)
250- .residual (residual_ptr)
251- .solution (dense_x)
252- .check (relative_stopping_id, true , &stop_status,
253- &one_changed)) {
303+ all_stopped = stop_criterion->update ()
304+ .num_iterations (iter)
305+ .residual (residual_ptr)
306+ .solution (dense_x)
307+ .check (relative_stopping_id, true , &stop_status,
308+ &one_changed);
309+ this ->template log <log::Logger::iteration_complete>(
310+ this , dense_b, dense_x, iter, residual_ptr, nullptr , nullptr ,
311+ &stop_status, all_stopped);
312+ if (all_stopped) {
254313 break ;
255314 }
256315 }
@@ -262,17 +321,18 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
262321 inner_solution->copy_from (residual_ptr);
263322 }
264323 solver_->apply (residual_ptr, inner_solution);
265- size_type index = (iter >= num_keep) ? num_keep : iter;
324+ size_type index =
325+ (iter >= num_max_generation_) ? num_max_generation_ : iter;
266326 auto alpha_scalar =
267327 alpha->create_submatrix (span{0 , 1 }, span{index, index + 1 });
268328 auto beta_scalar =
269329 beta->create_submatrix (span{0 , 1 }, span{index, index + 1 });
270330 if (iter == 0 ) {
271- if (num_generated_ < num_keep ) {
331+ if (num_generated_scalar_ < num_max_generation_ ) {
272332 alpha_scalar->fill (alpha_ref);
273333 // unused beta for first iteration, but fill zero
274334 beta_scalar->fill (zero<ValueType>());
275- num_generated_ ++;
335+ num_generated_scalar_ ++;
276336 }
277337 // x = x + alpha * inner_solution
278338 dense_x->add_scaled (alpha_scalar.get (), inner_solution);
@@ -286,12 +346,13 @@ void Chebyshev<ValueType>::apply_dense_impl(const VectorType* dense_b,
286346 }
287347 alpha_ref = ValueType{1.0 } / (center_ - beta_ref / alpha_ref);
288348 // The last one is always the updated one
289- if (num_generated_ < num_keep || iter >= num_keep) {
349+ if (num_generated_scalar_ < num_max_generation_ ||
350+ iter >= num_max_generation_) {
290351 alpha_scalar->fill (alpha_ref);
291352 beta_scalar->fill (beta_ref);
292353 }
293- if (num_generated_ < num_keep ) {
294- num_generated_ ++;
354+ if (num_generated_scalar_ < num_max_generation_ ) {
355+ num_generated_scalar_ ++;
295356 }
296357 // z = z + beta * p
297358 // p = z
0 commit comments