-
Notifications
You must be signed in to change notification settings - Fork 113
Expand file tree
/
Copy pathinvert_quda.h
More file actions
1669 lines (1352 loc) · 56.8 KB
/
invert_quda.h
File metadata and controls
1669 lines (1352 loc) · 56.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#pragma once
#include <quda.h>
#include <quda_internal.h>
#include <timer.h>
#include <dirac_quda.h>
#include <color_spinor_field.h>
#include <qio_field.h>
#include <eigensolve_quda.h>
#include <vector>
#include <memory>
#include <madwf_param.h>
namespace quda {
/**
SolverParam is the meta data used to define linear solvers.
*/
struct SolverParam {
/**
Which linear solver to use
*/
QudaInverterType inv_type;
/**
* The inner Krylov solver used in the preconditioner. Set to
* QUDA_INVALID_INVERTER to disable the preconditioner entirely.
*/
QudaInverterType inv_type_precondition;
/**
* Preconditioner instance, e.g., multigrid
*/
void *preconditioner;
/**
* Deflation operator
*/
void *deflation_op;
/**
* Whether to use the L2 relative residual, L2 absolute residual
* or Fermilab heavy-quark residual, or combinations therein to
* determine convergence. To require that multiple stopping
* conditions are satisfied, use a bitwise OR as follows:
*
* p.residual_type = (QudaResidualType) (QUDA_L2_RELATIVE_RESIDUAL
* | QUDA_HEAVY_QUARK_RESIDUAL);
*/
QudaResidualType residual_type;
/**< Whether deflate the initial guess */
bool deflate;
/**< Used to define deflation */
QudaEigParam eig_param;
/**< Whether to use an initial guess in the solver or not */
QudaUseInitGuess use_init_guess;
/**< Whether or not to allow a zero RHS solve for near-null vector generation */
bool compute_null_vector;
/**< Reliable update tolerance */
double delta;
/**< Whether to user alternative reliable updates (CG only at the moment) */
bool use_alternative_reliable;
/**< Whether to keep the partial solution accumulator in sloppy precision */
bool use_sloppy_partial_accumulator;
/**< This parameter determines how often we accumulate into the
solution vector from the direction vectors in the solver.
E.g., running with solution_accumulator_pipeline = 4, means we
will update the solution vector every four iterations using the
direction vectors from the prior four iterations. This
increases performance of mixed-precision solvers since it means
less high-precision vector round-trip memory travel, but
requires more low-precision memory allocation. */
int solution_accumulator_pipeline;
/**< This parameter determines how many consecutive reliable update
residual increases we tolerate before terminating the solver,
i.e., how long do we want to keep trying to converge */
int max_res_increase;
/**< This parameter determines how many total reliable update
residual increases we tolerate before terminating the solver,
i.e., how long do we want to keep trying to converge */
int max_res_increase_total;
/**< This parameter determines how many consecutive heavy-quark
residual increases we tolerate before terminating the solver,
i.e., how long do we want to keep trying to converge */
int max_hq_res_increase;
/**< This parameter determines how many total heavy-quark residual
restarts we tolerate before terminating the solver, i.e., how long
do we want to keep trying to converge */
int max_hq_res_restart_total;
/**< After how many iterations shall the heavy quark residual be updated */
int heavy_quark_check;
/**< Enable pipeline solver */
int pipeline;
/**< Solver tolerance in the L2 residual norm */
double tol;
/**< Solver tolerance in the L2 residual norm */
double tol_restart;
/**< Solver tolerance in the heavy quark residual norm */
double tol_hq;
/**< Whether to compute the true residual post solve */
bool compute_true_res;
/** Whether to declare convergence without checking the true residual */
bool sloppy_converge;
/**< Actual L2 residual norm achieved in solver */
double true_res;
/**< Actual heavy quark residual norm achieved in solver */
double true_res_hq;
/**< Maximum number of iterations in the linear solver */
int maxiter;
/**< The number of iterations performed by the solver */
int iter;
/**< The precision used by the QUDA solver */
QudaPrecision precision;
/**< The precision used by the QUDA sloppy operator */
QudaPrecision precision_sloppy;
/**< The precision used by the QUDA sloppy operator for multishift refinement */
QudaPrecision precision_refinement_sloppy;
/**< The precision used by the QUDA preconditioner */
QudaPrecision precision_precondition;
/**< The precision used by the QUDA eigensolver */
QudaPrecision precision_eigensolver;
/**< Whether the source vector should contain the residual vector
when the solver returns */
bool return_residual;
/**< Domain overlap to use in the preconditioning */
int overlap_precondition;
/**< Number of sources in the multi-src solver */
int num_src;
// Multi-shift solver parameters
/**< Number of offsets in the multi-shift solver */
int num_offset;
/** Offsets for multi-shift solver */
double offset[QUDA_MAX_MULTI_SHIFT];
/** Solver tolerance for each offset */
double tol_offset[QUDA_MAX_MULTI_SHIFT];
/** Solver tolerance for each shift when refinement is applied using the heavy-quark residual */
double tol_hq_offset[QUDA_MAX_MULTI_SHIFT];
/** Actual L2 residual norm achieved in solver for each offset */
double true_res_offset[QUDA_MAX_MULTI_SHIFT];
/** Iterated L2 residual norm achieved in multi shift solver for each offset */
double iter_res_offset[QUDA_MAX_MULTI_SHIFT];
/** Actual heavy quark residual norm achieved in solver for each offset */
double true_res_hq_offset[QUDA_MAX_MULTI_SHIFT];
/** Number of steps in s-step algorithms */
int Nsteps;
/** Maximum size of Krylov space used by solver */
int Nkrylov;
/** Number of preconditioner cycles to perform per iteration */
int precondition_cycle;
/** Tolerance in the inner solver */
double tol_precondition;
/** Maximum number of iterations allowed in the inner solver */
int maxiter_precondition;
/** Relaxation parameter used in GCR-DD (default = 1.0) */
double omega;
/** Basis for CA algorithms */
QudaPolynomialBasis ca_basis;
/** Minimum eigenvalue for Chebyshev CA basis */
double ca_lambda_min;
/** Maximum eigenvalue for Chebyshev CA basis */
double ca_lambda_max; // -1 -> power iter generate
/** Basis for CA algorithms in a preconditioner */
QudaPolynomialBasis ca_basis_precondition;
/** Minimum eigenvalue for Chebyshev CA basis in a preconditioner */
double ca_lambda_min_precondition;
/** Maximum eigenvalue for Chebyshev CA basis in a preconditioner */
double ca_lambda_max_precondition; // -1 -> power iter generate
/** Whether to use additive or multiplicative Schwarz preconditioning */
QudaSchwarzType schwarz_type;
/** The type of accelerator type to use for preconditioner */
QudaAcceleratorType accelerator_type_precondition;
/**< The time taken by the solver */
double secs;
/**< The Gflops rate of the solver */
double gflops;
// Incremental EigCG solver parameters
/**< The precision of the Ritz vectors */
QudaPrecision precision_ritz;//also search space precision
int n_ev; // number of eigenvectors produced by EigCG
int m;//Dimension of the search space
int deflation_grid;
int rhs_idx;
int eigcg_max_restarts;
int max_restart_num;
double inc_tol;
double eigenval_tol;
QudaVerbosity verbosity_precondition; //! verbosity to use for preconditioner
bool is_preconditioner; //! whether the solver acting as a preconditioner for another solver
bool global_reduction; //! whether to use a global or local (node) reduction for this solver
/** Whether the MG preconditioner (if any) is an instance of MG
(used internally in MG) or of multigrid_solver (used in the
interface)*/
bool mg_instance;
MadwfParam madwf_param;
/** Whether to perform advanced features in a preconditioning inversion,
including reliable updates, pipelining, and mixed precision. */
bool precondition_no_advanced_feature;
/** Which external lib to use in the solver */
QudaExtLibType extlib_type;
/**
Default constructor
*/
SolverParam() :
compute_null_vector(false),
compute_true_res(true),
sloppy_converge(false),
verbosity_precondition(QUDA_SILENT),
mg_instance(false)
{
;
}
/**
Constructor that matches the initial values to that of the
QudaInvertParam instance
@param param The QudaInvertParam instance from which the values are copied
*/
SolverParam(const QudaInvertParam ¶m) :
inv_type(param.inv_type),
inv_type_precondition(param.inv_type_precondition),
preconditioner(param.preconditioner),
deflation_op(param.deflation_op),
residual_type(param.residual_type),
deflate(param.eig_param != 0),
use_init_guess(param.use_init_guess),
compute_null_vector(false),
delta(param.reliable_delta),
use_alternative_reliable(param.use_alternative_reliable),
use_sloppy_partial_accumulator(param.use_sloppy_partial_accumulator),
solution_accumulator_pipeline(param.solution_accumulator_pipeline),
max_res_increase(param.max_res_increase),
max_res_increase_total(param.max_res_increase_total),
max_hq_res_increase(param.max_hq_res_increase),
max_hq_res_restart_total(param.max_hq_res_restart_total),
heavy_quark_check(param.heavy_quark_check),
pipeline(param.pipeline),
tol(param.tol),
tol_restart(param.tol_restart),
tol_hq(param.tol_hq),
compute_true_res(param.compute_true_res),
sloppy_converge(false),
true_res(param.true_res),
true_res_hq(param.true_res_hq),
maxiter(param.maxiter),
iter(param.iter),
precision(param.cuda_prec),
precision_sloppy(param.cuda_prec_sloppy),
precision_refinement_sloppy(param.cuda_prec_refinement_sloppy),
precision_precondition(param.cuda_prec_precondition),
precision_eigensolver(param.cuda_prec_eigensolver),
return_residual(false),
num_src(param.num_src),
num_offset(param.num_offset),
Nsteps(param.Nsteps),
Nkrylov(param.gcrNkrylov),
precondition_cycle(param.precondition_cycle),
tol_precondition(param.tol_precondition),
maxiter_precondition(param.maxiter_precondition),
omega(param.omega),
ca_basis(param.ca_basis),
ca_lambda_min(param.ca_lambda_min),
ca_lambda_max(param.ca_lambda_max),
ca_basis_precondition(param.ca_basis_precondition),
ca_lambda_min_precondition(param.ca_lambda_min_precondition),
ca_lambda_max_precondition(param.ca_lambda_max_precondition),
schwarz_type(param.schwarz_type),
accelerator_type_precondition(param.accelerator_type_precondition),
secs(param.secs),
gflops(param.gflops),
precision_ritz(param.cuda_prec_ritz),
n_ev(param.n_ev),
m(param.max_search_dim),
deflation_grid(param.deflation_grid),
rhs_idx(0),
eigcg_max_restarts(param.eigcg_max_restarts),
max_restart_num(param.max_restart_num),
inc_tol(param.inc_tol),
eigenval_tol(param.eigenval_tol),
verbosity_precondition(param.verbosity_precondition),
is_preconditioner(false),
global_reduction(true),
mg_instance(false),
precondition_no_advanced_feature(param.schwarz_type == QUDA_ADDITIVE_SCHWARZ),
extlib_type(param.extlib_type)
{
if (deflate) { eig_param = *(static_cast<QudaEigParam *>(param.eig_param)); }
for (int i=0; i<num_offset; i++) {
offset[i] = param.offset[i];
tol_offset[i] = param.tol_offset[i];
tol_hq_offset[i] = param.tol_hq_offset[i];
}
if (param.rhs_idx != 0
&& (param.inv_type == QUDA_INC_EIGCG_INVERTER || param.inv_type == QUDA_GMRESDR_PROJ_INVERTER)) {
rhs_idx = param.rhs_idx;
}
madwf_param.madwf_diagonal_suppressor = param.madwf_diagonal_suppressor;
madwf_param.madwf_ls = param.madwf_ls;
madwf_param.madwf_null_miniter = param.madwf_null_miniter;
madwf_param.madwf_null_tol = param.madwf_null_tol;
madwf_param.madwf_train_maxiter = param.madwf_train_maxiter;
madwf_param.madwf_param_load = param.madwf_param_load == QUDA_BOOLEAN_TRUE;
madwf_param.madwf_param_save = param.madwf_param_save == QUDA_BOOLEAN_TRUE;
if (madwf_param.madwf_param_load) madwf_param.madwf_param_infile = std::string(param.madwf_param_infile);
if (madwf_param.madwf_param_save) madwf_param.madwf_param_outfile = std::string(param.madwf_param_outfile);
}
SolverParam(const SolverParam ¶m) :
inv_type(param.inv_type),
inv_type_precondition(param.inv_type_precondition),
preconditioner(param.preconditioner),
deflation_op(param.deflation_op),
residual_type(param.residual_type),
deflate(param.deflate),
eig_param(param.eig_param),
use_init_guess(param.use_init_guess),
compute_null_vector(param.compute_null_vector),
delta(param.delta),
use_alternative_reliable(param.use_alternative_reliable),
use_sloppy_partial_accumulator(param.use_sloppy_partial_accumulator),
solution_accumulator_pipeline(param.solution_accumulator_pipeline),
max_res_increase(param.max_res_increase),
max_res_increase_total(param.max_res_increase_total),
heavy_quark_check(param.heavy_quark_check),
pipeline(param.pipeline),
tol(param.tol),
tol_restart(param.tol_restart),
tol_hq(param.tol_hq),
compute_true_res(param.compute_true_res),
sloppy_converge(param.sloppy_converge),
true_res(param.true_res),
true_res_hq(param.true_res_hq),
maxiter(param.maxiter),
iter(param.iter),
precision(param.precision),
precision_sloppy(param.precision_sloppy),
precision_refinement_sloppy(param.precision_refinement_sloppy),
precision_precondition(param.precision_precondition),
precision_eigensolver(param.precision_eigensolver),
return_residual(param.return_residual),
num_offset(param.num_offset),
Nsteps(param.Nsteps),
Nkrylov(param.Nkrylov),
precondition_cycle(param.precondition_cycle),
tol_precondition(param.tol_precondition),
maxiter_precondition(param.maxiter_precondition),
omega(param.omega),
ca_basis(param.ca_basis),
ca_lambda_min(param.ca_lambda_min),
ca_lambda_max(param.ca_lambda_max),
ca_basis_precondition(param.ca_basis_precondition),
ca_lambda_min_precondition(param.ca_lambda_min_precondition),
ca_lambda_max_precondition(param.ca_lambda_max_precondition),
schwarz_type(param.schwarz_type),
accelerator_type_precondition(param.accelerator_type_precondition),
secs(param.secs),
gflops(param.gflops),
precision_ritz(param.precision_ritz),
n_ev(param.n_ev),
m(param.m),
deflation_grid(param.deflation_grid),
rhs_idx(0),
eigcg_max_restarts(param.eigcg_max_restarts),
max_restart_num(param.max_restart_num),
inc_tol(param.inc_tol),
eigenval_tol(param.eigenval_tol),
verbosity_precondition(param.verbosity_precondition),
is_preconditioner(param.is_preconditioner),
global_reduction(param.global_reduction),
mg_instance(param.mg_instance),
madwf_param(param.madwf_param),
precondition_no_advanced_feature(param.precondition_no_advanced_feature),
extlib_type(param.extlib_type)
{
for (int i=0; i<num_offset; i++) {
offset[i] = param.offset[i];
tol_offset[i] = param.tol_offset[i];
tol_hq_offset[i] = param.tol_hq_offset[i];
}
if((param.inv_type == QUDA_INC_EIGCG_INVERTER || param.inv_type == QUDA_EIGCG_INVERTER) && m % 16){//current hack for the magma library
m = (m / 16) * 16 + 16;
warningQuda("\nSwitched eigenvector search dimension to %d\n", m);
}
if(param.rhs_idx != 0 && (param.inv_type==QUDA_INC_EIGCG_INVERTER || param.inv_type==QUDA_GMRESDR_PROJ_INVERTER)){
rhs_idx = param.rhs_idx;
}
}
~SolverParam() { }
/**
Update the QudaInvertParam with the data from this
@param param the QudaInvertParam to be updated
*/
void updateInvertParam(QudaInvertParam ¶m, int offset=-1) {
param.true_res = true_res;
param.true_res_hq = true_res_hq;
param.iter += iter;
comm_allreduce_sum(gflops);
param.gflops += gflops;
param.secs += secs;
if (offset >= 0) {
param.true_res_offset[offset] = true_res_offset[offset];
param.iter_res_offset[offset] = iter_res_offset[offset];
param.true_res_hq_offset[offset] = true_res_hq_offset[offset];
} else {
for (int i=0; i<num_offset; i++) {
param.true_res_offset[i] = true_res_offset[i];
param.iter_res_offset[i] = iter_res_offset[i];
param.true_res_hq_offset[i] = true_res_hq_offset[i];
}
}
//for incremental eigCG:
param.rhs_idx = rhs_idx;
param.ca_lambda_min = ca_lambda_min;
param.ca_lambda_max = ca_lambda_max;
param.ca_lambda_min_precondition = ca_lambda_min_precondition;
param.ca_lambda_max_precondition = ca_lambda_max_precondition;
if (deflate) *static_cast<QudaEigParam *>(param.eig_param) = eig_param;
}
void updateRhsIndex(QudaInvertParam ¶m) {
//for incremental eigCG:
rhs_idx = param.rhs_idx;
}
};
/**
   @brief Abstract base class for all linear solvers.  Holds the
   operators used at the various precisions, the solver parameters,
   and the (optional) deflation space.
*/
class Solver
{

protected:
  const DiracMatrix &mat;
  const DiracMatrix &matSloppy;
  const DiracMatrix &matPrecon;
  const DiracMatrix &matEig;

  SolverParam &param;
  TimeProfile &profile;
  int node_parity;
  EigenSolver *eig_solve; /**< Eigensolver object. */
  bool deflate_init;      /**< If true, the deflation space has been computed. */
  bool deflate_compute;   /**< If true, instruct the solver to create a deflation space. */
  bool recompute_evals;   /**< If true, instruct the solver to recompute evals from an existing deflation space. */
  std::vector<ColorSpinorField *> evecs; /**< Holds the eigenvectors. */
  std::vector<Complex> evals;            /**< Holds the eigenvalues. */

  /**
     @return Whether the solver precision differs from the sloppy precision
  */
  bool mixed() const { return param.precision != param.precision_sloppy; }

public:
  Solver(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
         const DiracMatrix &matEig, SolverParam &param, TimeProfile &profile);
  virtual ~Solver();

  virtual void operator()(ColorSpinorField &out, ColorSpinorField &in) = 0;

  virtual void blocksolve(ColorSpinorField &out, ColorSpinorField &in);

  /**
     @return Return the residual vector from the prior solve
  */
  virtual ColorSpinorField &get_residual()
  {
    errorQuda("Not implemented");
    static ColorSpinorField dummy;
    return dummy;
  }

  /**
     @brief a virtual method that performs the necessary training/preparation at the beginning of a solve.
     The default here is a no-op.
     @param Solver the solver to be used to collect the null space vectors.
     @param ColorSpinorField the vector used to perform the training.
  */
  virtual void train_param(Solver &, ColorSpinorField &)
  {
    // Do nothing
  }

  /**
     @brief a virtual method that performs the inversion and collect some vectors.
     The default here is a no-op and should not be called.
  */
  virtual void solve_and_collect(ColorSpinorField &, ColorSpinorField &, std::vector<ColorSpinorField *> &, int, double)
  {
    errorQuda("NOT implemented.");
  }

  void set_tol(double tol) { param.tol = tol; }
  void set_maxiter(int maxiter) { param.maxiter = maxiter; }

  const DiracMatrix &M() const { return mat; }
  const DiracMatrix &Msloppy() const { return matSloppy; }
  const DiracMatrix &Mprecon() const { return matPrecon; }
  const DiracMatrix &Meig() const { return matEig; }

  /**
     @return Whether the solver is only for Hermitian systems
  */
  virtual bool hermitian() = 0;

  /**
     @brief Generic solver setup and parameter checking
     @param[in] x Solution vector
     @param[in] b Source vector
  */
  void create(ColorSpinorField &x, const ColorSpinorField &b);

  /**
     @brief Solver factory
  */
  static Solver *create(SolverParam &param, const DiracMatrix &mat, const DiracMatrix &matSloppy,
                        const DiracMatrix &matPrecon, const DiracMatrix &matEig, TimeProfile &profile);

  /**
     @brief Set the solver L2 stopping condition
     @param[in] tol Desired solver tolerance
     @param[in] b2 L2 norm squared of the source vector
     @param[in] residual_type The type of residual we want to solve for
     @return L2 stopping condition
  */
  static double stopping(double tol, double b2, QudaResidualType residual_type);

  /**
     @brief Test for solver convergence
     @param[in] r2 L2 norm squared of the residual
     @param[in] hq2 Heavy quark residual
     @param[in] r2_tol Solver L2 tolerance
     @param[in] hq_tol Solver heavy-quark tolerance
     @return Whether converged
  */
  bool convergence(double r2, double hq2, double r2_tol, double hq_tol);

  /**
     @brief Test for HQ solver convergence -- ignore L2 residual
     @param[in] r2 L2 norm squared of the residual
     @param[in] hq2 Heavy quark residual
     @param[in] r2_tol Solver L2 tolerance
     @param[in] hq_tol Solver heavy-quark tolerance
     @return Whether converged
  */
  bool convergenceHQ(double r2, double hq2, double r2_tol, double hq_tol);

  /**
     @brief Test for L2 solver convergence -- ignore HQ residual
     @param[in] r2 L2 norm squared of the residual
     @param[in] hq2 Heavy quark residual
     @param[in] r2_tol Solver L2 tolerance
     @param[in] hq_tol Solver heavy-quark tolerance
     @return Whether converged
  */
  bool convergenceL2(double r2, double hq2, double r2_tol, double hq_tol);

  /**
     @brief Prints out the running statistics of the solver
     (requires a verbosity of QUDA_VERBOSE)
     @param[in] name Name of solver that called this
     @param[in] k iteration count
     @param[in] r2 L2 norm squared of the residual
     @param[in] b2 L2 norm squared of the source vector
     @param[in] hq2 Heavy quark residual
  */
  void PrintStats(const char *name, int k, double r2, double b2, double hq2);

  /**
     @brief Prints out the summary of the solver convergence
     (requires a verbosity of QUDA_SUMMARIZE).  Assumes
     SolverParam.true_res and SolverParam.true_res_hq has been set
     @param[in] name Name of solver that called this
     @param[in] k iteration count
     @param[in] r2 L2 norm squared of the residual
     @param[in] b2 L2 norm squared of the source vector
     @param[in] r2_tol Solver L2 tolerance
     @param[in] hq_tol Solver heavy-quark tolerance
  */
  void PrintSummary(const char *name, int k, double r2, double b2, double r2_tol, double hq_tol);

  /**
     @brief Returns the epsilon tolerance for a given precision, by default returns
     the solver precision.
     @param[in] prec Input precision, default value is solver precision
  */
  double precisionEpsilon(QudaPrecision prec = QUDA_INVALID_PRECISION) const;

  /**
     @brief Constructs the deflation space and eigensolver
     @param[in] meta A sample ColorSpinorField with which to instantiate
     the eigensolver
     @param[in] mat The operator to eigensolve
  */
  void constructDeflationSpace(const ColorSpinorField &meta, const DiracMatrix &mat);

  /**
     @brief Destroy the allocated deflation space
  */
  void destroyDeflationSpace();

  /**
     @brief Extends the deflation space to twice its size for SVD deflation
  */
  void extendSVDDeflationSpace();

  /**
     @brief Injects a deflation space into the solver from the
     vector argument.  Note the input space is reduced to zero size as a
     result of calling this function, with responsibility for the
     space transferred to the solver.
     @param[in,out] defl_space the deflation space we wish to
     transfer to the solver.
  */
  void injectDeflationSpace(std::vector<ColorSpinorField *> &defl_space);

  /**
     @brief Extracts the deflation space from the solver to the
     vector argument.  Note the solver deflation space is reduced to
     zero size as a result of calling this function, with
     responsibility for the space transferred to the argument.
     @param[in,out] defl_space the extracted deflation space.  On
     input, this vector should have zero size.
  */
  void extractDeflationSpace(std::vector<ColorSpinorField *> &defl_space);

  /**
     @brief Returns the size of deflation space
  */
  int deflationSpaceSize() const { return static_cast<int>(evecs.size()); }

  /**
     @brief Sets the deflation compute boolean
     @param[in] flag Set to this boolean value
  */
  void setDeflateCompute(bool flag) { deflate_compute = flag; }

  /**
     @brief Sets the recompute evals boolean
     @param[in] flag Set to this boolean value
  */
  void setRecomputeEvals(bool flag) { recompute_evals = flag; }

  /**
   * @brief Return flops
   * @return flops expended by this operator
   */
  virtual double flops() const { return 0; }
};
/**
@brief Conjugate-Gradient Solver.
*/
class CG : public Solver {
private:
// pointers to fields to avoid multiple creation overhead
ColorSpinorField *yp, *rp, *rnewp, *pp, *App, *tmpp, *tmp2p, *tmp3p, *rSloppyp, *xSloppyp;
bool init;
public:
CG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig,
SolverParam ¶m, TimeProfile &profile);
virtual ~CG();
/**
* @brief Run CG.
* @param out Solution vector.
* @param in Right-hand side.
*/
void operator()(ColorSpinorField &out, ColorSpinorField &in){
(*this)(out, in, nullptr, 0.0);
};
/**
* @brief Solve re-using an initial Krylov space defined by an initial r2_old_init and search direction p_init.
* @details This can be used when continuing a CG, e.g. as refinement step after a multi-shift solve.
* @param out Solution-vector.
* @param in Right-hand side.
* @param p_init Initial-search direction.
* @param r2_old_init [description]
*/
void operator()(ColorSpinorField &out, ColorSpinorField &in, ColorSpinorField *p_init, double r2_old_init);
void blocksolve(ColorSpinorField& out, ColorSpinorField& in);
virtual bool hermitian() { return true; } /** CG is only for Hermitian systems */
};
class CGNE : public CG
{
private:
DiracMMdag mmdag;
DiracMMdag mmdagSloppy;
DiracMMdag mmdagPrecon;
DiracMMdag mmdagEig;
ColorSpinorField xp;
ColorSpinorField yp;
bool init;
/**
@brief Initiate the fields needed by the solver
@param[in] x Solution vector
@param[in] b Source vector
*/
void create(ColorSpinorField &x, const ColorSpinorField &b);
public:
CGNE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig,
SolverParam ¶m, TimeProfile &profile);
void operator()(ColorSpinorField &out, ColorSpinorField &in);
/**
@return Return the residual vector from the prior solve
*/
ColorSpinorField &get_residual();
virtual bool hermitian() { return false; } /** CGNE is for any system */
};
class CGNR : public CG
{
private:
DiracMdagM mdagm;
DiracMdagM mdagmSloppy;
DiracMdagM mdagmPrecon;
DiracMdagM mdagmEig;
ColorSpinorField br;
bool init;
/**
@brief Initiate the fields needed by the solver
@param[in] x Solution vector
@param[in] b Source vector
*/
void create(ColorSpinorField &x, const ColorSpinorField &b);
public:
CGNR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, const DiracMatrix &matEig,
SolverParam ¶m, TimeProfile &profile);
void operator()(ColorSpinorField &out, ColorSpinorField &in);
/**
@return Return the residual vector from the prior solve
*/
ColorSpinorField &get_residual();
virtual bool hermitian() { return false; } /** CGNR is for any system */
};
class CG3 : public Solver
{
private:
// pointers to fields to avoid multiple creation overhead
ColorSpinorField *yp, *rp, *tmpp, *ArSp, *rSp, *xSp, *xS_oldp, *tmpSp, *rS_oldp, *tmp2Sp;
bool init;
public:
CG3(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam ¶m,
TimeProfile &profile);
virtual ~CG3();
void operator()(ColorSpinorField &out, ColorSpinorField &in);
virtual bool hermitian() { return true; } /** CG is only for Hermitian systems */
};
class CG3NE : public CG3
{
private:
DiracMMdag mmdag;
DiracMMdag mmdagSloppy;
DiracMMdag mmdagPrecon;
ColorSpinorField xp;
ColorSpinorField yp;
bool init;
/**
@brief Initiate the fields needed by the solver
@param[in] x Solution vector
@param[in] b Source vector
*/
void create(ColorSpinorField &x, const ColorSpinorField &b);
public:
CG3NE(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam ¶m,
TimeProfile &profile);
void operator()(ColorSpinorField &out, ColorSpinorField &in);
/**
@return Return the residual vector from the prior solve
*/
ColorSpinorField &get_residual();
virtual bool hermitian() { return false; } /** CG3NE is for any system */
};
class CG3NR : public CG3
{
private:
DiracMdagM mdagm;
DiracMdagM mdagmSloppy;
DiracMdagM mdagmPrecon;
ColorSpinorField br;
bool init;
/**
@brief Initiate the fields needed by the solver
@param[in] x Solution vector
@param[in] b Source vector
*/
void create(ColorSpinorField &x, const ColorSpinorField &b);
public:
CG3NR(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon, SolverParam ¶m,
TimeProfile &profile);
void operator()(ColorSpinorField &out, ColorSpinorField &in);
/**
@return Return the residual vector from the prior solve
*/
ColorSpinorField &get_residual();
virtual bool hermitian() { return false; } /** CG3NR is for any system */
};
class PreconCG : public Solver {
private:
std::shared_ptr<Solver> K;
SolverParam Kparam; // parameters for preconditioner solve
public:
PreconCG(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
const DiracMatrix &matEig, SolverParam ¶m, TimeProfile &profile);
virtual ~PreconCG();
void operator()(ColorSpinorField &out, ColorSpinorField &in)
{
std::vector<ColorSpinorField *> v_r(0);
this->solve_and_collect(out, in, v_r, 0, 0);
}
/**
@brief a virtual method that performs the inversion and collect the r vectors in PCG.
@param out the output vector
@param in the input vector
@param v_r the series of vectors that is to be collected
@param collect_miniter minimal iteration start from which the r vectors are to be collected
@param collect_tol maxiter tolerance start from which the r vectors are to be collected
*/
virtual void solve_and_collect(ColorSpinorField &out, ColorSpinorField &in, std::vector<ColorSpinorField *> &v_r,
int collect_miniter, double collect_tol);
virtual bool hermitian() { return true; } /** PCG is only Hermitian system */
};
class BiCGstab : public Solver {
private:
// pointers to fields to avoid multiple creation overhead
ColorSpinorField *yp, *rp, *pp, *vp, *tmpp, *tp;
bool init;
public:
BiCGstab(const DiracMatrix &mat, const DiracMatrix &matSloppy, const DiracMatrix &matPrecon,
const DiracMatrix &matEig, SolverParam ¶m, TimeProfile &profile);
virtual ~BiCGstab();
void operator()(ColorSpinorField &out, ColorSpinorField &in);
virtual bool hermitian() { return false; } /** BiCGStab is for any linear system */
};
/**
* @brief Optimized version of the BiCGstabL solver described in
* https://etna.math.kent.edu/vol.1.1993/pp11-32.dir/pp11-32.pdf
*/
class BiCGstabL : public Solver {
private:
const DiracMdagM matMdagM; // used by the eigensolver
/**
The size of the Krylov space that BiCGstabL uses.
*/
int n_krylov; // in the language of BiCGstabL, this is L.
int pipeline; // pipelining factor for legacyGramSchmidt
// Various coefficients and params needed on each iteration.
Complex rho0, rho1, alpha, omega, beta; // Various coefficients for the BiCG part of BiCGstab-L.
std::vector<Complex> gamma, gamma_prime, gamma_prime_prime; // Parameters for MR part of BiCGstab-L. (L+1) length.
std::vector<Complex> tau; // Parameters for MR part of BiCGstab-L. Tech. modified Gram-Schmidt coeffs. (L+1)x(L+1) length.
std::vector<double> sigma; // Parameters for MR part of BiCGstab-L. Tech. the normalization part of Gram-Scmidt. (L+1) length.
// pointers to fields to avoid multiple creation overhead
// full precision fields
std::unique_ptr<ColorSpinorField> r_fullp; //! Full precision residual.
std::unique_ptr<ColorSpinorField> yp; //! Full precision temporary.
// sloppy precision fields
std::unique_ptr<ColorSpinorField> tempp; //! Sloppy temporary vector.
std::vector<ColorSpinorField*> r; // Current residual + intermediate residual values, along the MR.
std::vector<ColorSpinorField*> u; // Search directions.
// Saved, preallocated vectors. (may or may not get used depending on precision.)
ColorSpinorField *x_sloppy_saved_p; //! Sloppy solution vector.
ColorSpinorField *r0_saved_p; //! Shadow residual, in BiCG language.
ColorSpinorField *r_sloppy_saved_p; //! Current residual, in BiCG language.
/**
@brief Internal routine for reliable updates. Made to not conflict with BiCGstab's implementation.
*/
int reliable(double &rNorm, double &maxrx, double &maxrr, const double &r2, const double &delta);
/**
* @brief Internal routine for performing the MR part of BiCGstab-L
*
* @param x_sloppy [out] sloppy accumulator for x
* @param fixed_iteration [in] whether or not this is for a fixed iteration solver
*/
void computeMR(ColorSpinorField &x_sloppy, bool fixed_iteration);
/**
Legacy routines that encapsulate the original pipelined Gram-Schmit.