diff --git a/include/quda_milc_interface.h b/include/quda_milc_interface.h index fe04009437..baf87a252a 100644 --- a/include/quda_milc_interface.h +++ b/include/quda_milc_interface.h @@ -187,6 +187,16 @@ extern "C" { int laplaceDim; /** Dimension of Laplacian **/ } QudaTwoLinkQuarkSmearArgs_t; + /** + Options when loading deflation space + **/ + typedef enum QudaMilcEigLoad_s { + QUDA_MILC_EIG_LOAD, /** Load this parity evecs from MILC **/ + QUDA_MILC_EIG_COMPUTE, /** Compute this parity evecs (or load from file via QUDA) **/ + QUDA_MILC_EIG_FROM_OTHER_PARITY, /** Compute this parity evecs from the other parity **/ + QUDA_MILC_INVALID_EIG = QUDA_INVALID_ENUM + } QudaMilcEigLoad; + /** * Optional: Set the MPI Comm Handle if it is not MPI_COMM_WORLD * @@ -379,6 +389,46 @@ extern "C" { double* const final_fermilab_residual, int* num_iters); + /** + * Project the low modes off of a source of given parity. + * + * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) + * @param[in] source Source vector(s) + * @param[out] source Solution vector(s) + * @param[in] nvec Number of source/solution vectors + * @param[in] n_evec Number of low modes to project off of the source vectors + * @param[in] parity Parity to use + */ + void qudaProject(int external_precision, void **source, void **solution, int nvec, int n_evec, QudaParity parity); + + /** + * Get pointers to QUDA's deflation space objects. + * + * @param[out] evecs Pointer to eigenvectors + * @param[out] evals Pointer to eigenvalues + * @param[in] parity Parity of the deflation space to return + * @param[in] nvecs The number of eigenvectors + */ + void qudaGetDeflationSpace(void **evecs, double *evals, QudaParity parity, int nvecs); + + /** + * Load the deflation space (eigenvalues and eigenvectors) for a particular parity + * which is set in invargs. + * + * @param[in] external_precision Precision of host fields passed to QUDA (2 - double, 1 - single) + * @param[in] quda_precision Precision for QUDA to use (2 - double, 1 - single) + * @param[in] milc_fatlink Fat-link field on the host + * @param[in] milc_longlink Long-link field on the host + * @param[in] mass Quark mass + * @param[in] invargs Struct containing information for the inverter + * @param[in] eigargs Struct containing information for the eigensolver + * @param[in] evecs Evecs coming from MILC + * @param[in] loadtype Whether to load from MILC, from file, compute, or check + */ + void qudaLoadDeflationSpace(int external_precision, int quda_precision, const void *const milc_fatlink, + const void *const milc_longlink, double mass, QudaInvertArgs_t invargs, + QudaEigensolverArgs_t eigargs, void **evecs, QudaMilcEigLoad loadtype); + /** * Solve Ax=b for an improved staggered operator. All fields are fields * passed and returned are host (CPU) field in MILC order. This diff --git a/lib/milc_interface.cpp b/lib/milc_interface.cpp index c70a8a92fe..1f982b60d3 100644 --- a/lib/milc_interface.cpp +++ b/lib/milc_interface.cpp @@ -44,6 +44,10 @@ static const int num_colors = sizeof(colors)/sizeof(uint32_t); #define POP_RANGE #endif +namespace quda +{ + void setDiracEigParam(DiracParam &, QudaInvertParam *, bool, bool); +} static bool initialized = false; #ifdef MULTI_GPU @@ -62,9 +66,8 @@ static bool invalidate_quda_mg = true; static void *df_preconditioner = nullptr; -static void *preserved_deflation_space = nullptr; - -static bool deflation_init = false; +static void *preserved_deflation_space[2] = {nullptr, nullptr}; +static double preserved_evals_mass[2] = {-1.0, -1.0}; using namespace quda; using namespace quda::fermion_force; @@ -112,13 +115,16 @@ void qudaInit(QudaInitArgs_t input) void qudaCleanUpDeflationSpace() { qudamilc_called(__func__); - if (preserved_deflation_space) { - - deflation_space *space = reinterpret_cast(preserved_deflation_space); - logQuda(QUDA_VERBOSE, "Cleaning up deflation space of size %lu\n", space->evecs.size()); - space->evecs.clear(); - space->evals.clear(); - delete space; + for (int p = 0; p < 2; p++) { + if (preserved_deflation_space[p]) { + deflation_space *space = reinterpret_cast(preserved_deflation_space[p]); + logQuda(QUDA_VERBOSE, "Cleaning up parity %d deflation space of size %lu\n", p, space->evecs.size()); + space->evecs.clear(); + space->evals.clear(); + delete space; + preserved_deflation_space[p] = nullptr; + preserved_evals_mass[p] = -1.0; + } } qudamilc_called(__func__); } @@ -1053,6 +1059,7 @@ static void setColorSpinorParams(const int dim[4], QudaPrecision precision, Colo param->siteOrder = QUDA_EVEN_ODD_SITE_ORDER; param->fieldOrder = QUDA_SPACE_SPIN_COLOR_FIELD_ORDER; param->gammaBasis = QUDA_DEGRAND_ROSSI_GAMMA_BASIS; // meaningless, but required by the code. + param->pc_type = QUDA_4D_PC; param->create = QUDA_ZERO_FIELD_CREATE; } @@ -1192,6 +1199,303 @@ void qudaMultishiftInvert(int external_precision, int quda_precision, int num_of qudamilc_called(__func__, verbosity); } // qudaMultiShiftInvert +// Project the low modes off of source vector(s) +void qudaProject(int external_precision, void **source, void **solution, int nvec, int n_evec, QudaParity parity) +{ + static const QudaVerbosity verbosity = getVerbosity(); + qudamilc_called(__func__, verbosity); + logQuda(QUDA_VERBOSE, "Projecting %d low modes out of %d source vectors for parity %s\n", n_evec, nvec, + parity == QUDA_EVEN_PARITY ? "EVEN" : "ODD"); + QudaPrecision host_precision = (external_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION; + + // Multiple sweeps of projection to improve precision + int nsweeps = 2; + + // Check inputs + for (int i = 0; i < nvec; i++) + if (!source[i] || !solution[i]) errorQuda("Source or solution vector %d is null!", i); + + // MILC sends pointers to full parity vectors, but QUDA uses single parity vectors + // so for odd parity, need to use offset + int vec_offset = getColorVectorOffset(parity, false, localDim) * host_precision; + + // Device-side deflation space + if (parity != QUDA_EVEN_PARITY && parity != QUDA_ODD_PARITY) errorQuda("Invalid parity %d", parity); + deflation_space *space = reinterpret_cast(preserved_deflation_space[parity]); + if (!space) errorQuda("Failed to get %s parity deflation space!", parity == QUDA_EVEN_PARITY ? "EVEN" : "ODD"); + + // Wrap host vectors + ColorSpinorParam csParam; + setColorSpinorParams(localDim, host_precision, &csParam); + csParam.location = QUDA_CPU_FIELD_LOCATION; + csParam.create = QUDA_REFERENCE_FIELD_CREATE; + std::vector src_h(nvec); + std::vector sol_h(nvec); + for (int i = 0; i < nvec; i++) { + csParam.v = static_cast(static_cast(source[i]) + vec_offset); + src_h[i] = ColorSpinorField(csParam); + csParam.v = static_cast(static_cast(solution[i]) + vec_offset); + sol_h[i] = ColorSpinorField(csParam); + } + + // Setup device side vectors + ColorSpinorParam gpuParam(space->evecs[0]); + gpuParam.create = QUDA_ZERO_FIELD_CREATE; + std::vector src(nvec); + std::vector tmp(nvec); + for (int i = 0; i < nvec; i++) { + tmp[i] = ColorSpinorField(gpuParam); + src[i] = ColorSpinorField(gpuParam); + src[i] = src_h[i]; // Copy host sources to device sources + } + + // Do nsweeps of projection on device + for (int sweep = 0; sweep < nsweeps; sweep++) { + + for (int i = 0; i < nvec; i++) blas::zero(tmp[i]); + + // 1. Take block inner product: (V_i)^dag * src = s_i + std::vector s(n_evec * src.size()); + blas::block::cDotProduct(s, {space->evecs.begin(), space->evecs.begin() + n_evec}, {src.begin(), src.end()}); + + // 2. Build projected component: Sum_i V_i * s_i = tmp + blas::block::caxpy(s, {space->evecs.begin(), space->evecs.begin() + n_evec}, {tmp.begin(), tmp.end()}); + + // 3. Subtract projection in place: src = src - tmp + for (int i = 0; i < nvec; i++) blas::axpy(-1.0, tmp[i], src[i]); + ; + } + + // Copy solution back to host + for (int i = 0; i < nvec; i++) sol_h[i] = src[i]; + + qudamilc_called(__func__, verbosity); +} // qudaProject + +// Get pointers to QUDA's deflation space objects +// Useful for passing eigenvectors and eigenvalues back to MILC +void qudaGetDeflationSpace(void **evecs, double *evals, QudaParity parity, int Nvecs) +{ + static const QudaVerbosity verbosity = getVerbosity(); + qudamilc_called(__func__, verbosity); + + // Device-side deflation space + if (parity != QUDA_EVEN_PARITY && parity != QUDA_ODD_PARITY) errorQuda("Invalid parity %d", parity); + deflation_space *space = reinterpret_cast(preserved_deflation_space[parity]); + if (!space) errorQuda("Failed to get %s parity deflation space!", parity == QUDA_EVEN_PARITY ? "EVEN" : "ODD"); + if (static_cast(Nvecs) > space->evecs.size()) + errorQuda("Requested %d eigenvectors, but deflation space has only %lu", Nvecs, space->evecs.size()); + + // Copy eigenvectors if requested + if (evecs) { + // Set up host fields + ColorSpinorParam csParam(space->evecs[0]); + csParam.location = QUDA_CPU_FIELD_LOCATION; + csParam.create = QUDA_REFERENCE_FIELD_CREATE; + csParam.fieldOrder = QUDA_SPACE_SPIN_COLOR_FIELD_ORDER; + std::vector host_evecs(Nvecs); + for (int i = 0; i < Nvecs; i++) { + csParam.v = evecs[i]; + host_evecs[i] = ColorSpinorField(csParam); + host_evecs[i] = space->evecs[i]; // Copy to host + } + } + + // Copy eigenvalues if requested + if (evals) + for (int i = 0; i < Nvecs; i++) evals[i] = space->evals[i].real(); + + qudamilc_called(__func__, verbosity); +} // qudaGetDeflationSpace + +// Load single parity deflation space with eigenvectors generated from eigensolve, loaded from file, +// passed from MILC, or generated from other parity eigenvectors +void qudaLoadDeflationSpace(int external_precision, int quda_precision, const void *const fatlink, + const void *const longlink, double mass, QudaInvertArgs_t inv_args, + QudaEigensolverArgs_t eigargs, void **evecs, QudaMilcEigLoad load_type) +{ + static const QudaVerbosity verbosity = getVerbosity(); + qudamilc_called(__func__, verbosity); + + QudaPrecision host_precision = (external_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION; + QudaPrecision device_precision = (quda_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION; + QudaPrecision device_precision_sloppy; + switch (inv_args.mixed_precision) { + case 2: device_precision_sloppy = QUDA_HALF_PRECISION; break; + case 1: device_precision_sloppy = QUDA_SINGLE_PRECISION; break; + default: device_precision_sloppy = device_precision; + } + QudaParity parity = inv_args.evenodd; + if (parity != QUDA_EVEN_PARITY && parity != QUDA_ODD_PARITY) errorQuda("Invalid parity %d", parity); + QudaParity other_parity = parity == QUDA_EVEN_PARITY ? QUDA_ODD_PARITY : QUDA_EVEN_PARITY; + double epsilon = device_precision == QUDA_DOUBLE_PRECISION ? __DBL_EPSILON__ : __FLT_EPSILON__; + int n_evecs = eigargs.n_conv; + + // Load gauge fields if not done yet + if (invalidate_quda_gauge || !create_quda_gauge) { + QudaGaugeParam fat_param = newQudaGaugeParam(); + QudaGaugeParam long_param = newQudaGaugeParam(); + setGaugeParams(fat_param, long_param, longlink, localDim, host_precision, device_precision, device_precision_sloppy, + inv_args.tadpole, inv_args.naik_epsilon); + loadGaugeQuda(const_cast(fatlink), &fat_param); + if (longlink != nullptr) loadGaugeQuda(const_cast(longlink), &long_param); + invalidate_quda_gauge = false; + } + + // Load deflation space + if (load_type == QUDA_MILC_EIG_COMPUTE) { + // Main deflation space is obtained by calling the deflatable inverter with dummy source + // Incoming inv_args can have inv_args.max_iter=1 + logQuda(QUDA_VERBOSE, "Computing deflation space (or loading from file) for parity %d and mass %e\n", parity, mass); + + double final_residual, final_fermilab_residual; + int num_iters = 0; + + ColorSpinorParam csParam; + setColorSpinorParams(localDim, host_precision, &csParam); + csParam.location = QUDA_CPU_FIELD_LOCATION; + csParam.siteSubset = QUDA_FULL_SITE_SUBSET; // qudaInvertDeflatable expects full-parity vectors + csParam.x[0] *= 2; + ColorSpinorField source(csParam); + ColorSpinorField solution(csParam); + source.Source(QUDA_POINT_SOURCE, inv_args.evenodd, 0, 0); // using dummy point source + + qudaInvertDeflatable(external_precision, quda_precision, mass, inv_args, eigargs, 1.0, 0.0, fatlink, longlink, + static_cast(source.data()), static_cast(solution.data()), &final_residual, + &final_fermilab_residual, &num_iters); + + } else if (load_type == QUDA_MILC_EIG_FROM_OTHER_PARITY) { + logQuda(QUDA_VERBOSE, "Computing deflation space for parity %d from parity %d\n", parity, other_parity); + double other_parity_mass = preserved_evals_mass[other_parity]; + + // Get preserved other parity deflation space + deflation_space *other_parity_space = reinterpret_cast(preserved_deflation_space[other_parity]); + if (!other_parity_space) + errorQuda("Failed to get %s parity deflation space!", parity == QUDA_EVEN_PARITY ? "ODD" : "EVEN"); + if (other_parity_space->evecs.size() < static_cast(n_evecs)) + errorQuda("Other parity deflation space too small!"); + + // Setup new deflation space + ColorSpinorParam gpuParam(other_parity_space->evecs[0]); + deflation_space *space = new deflation_space; + space->svd = false; + resize(space->evecs, n_evecs, gpuParam); + space->evals.resize(n_evecs, 0.0); + + // Create Dirac operator + QudaInvertParam invertParam = newQudaInvertParam(); + setInvertParams(host_precision, device_precision, device_precision_sloppy, mass, 1.0, 0.0, inv_args.max_iter, 1e-1, + parity, verbosity, QUDA_CG_INVERTER, &invertParam); + invertParam.cuda_prec_eigensolver = eigargs.prec_eigensolver; + DiracParam diracEigParam; + setDiracEigParam(diracEigParam, &invertParam, true, false); + Dirac *dEig = Dirac::create(diracEigParam); + + // Temp vector on GPU + gpuParam.create = QUDA_ZERO_FIELD_CREATE; + ColorSpinorField temp(gpuParam); + + Complex n_unit(-1.0, 0.0); + + for (int i = 0; i < n_evecs; i++) { + + // Compute other parity eigenvector v_o = i*D_oe*v_e/\lambda_e + dEig->Dslash(temp, other_parity_space->evecs[i], parity); + auto norm2 = blas::norm2(temp); + blas::ax(1.0 / sqrt(norm2), temp); + space->evecs[i] = temp; + + // Compute eigenvalues, lambda_i = v_i^dag A v_i / (v_i^dag * v_i) + dEig->M(temp, space->evecs[i]); + auto eval = other_parity_space->evals[i]; // re-use eigenvalues by default + if (fabs(mass - other_parity_mass) > epsilon) { // recompute eigenvalues if mass doesn't match + auto vtAv = blas::cDotProduct(space->evecs[i], temp); + auto v2 = blas::norm2(space->evecs[i]); + eval = vtAv / sqrt(v2); + } + space->evals[i] = eval; + + // res^2 = |\lambda*v - A*v| + auto res = blas::caxpbyNorm(eval, space->evecs[i], n_unit, temp); + logQuda(QUDA_VERBOSE, "Eval[%04d] = (%+.16e,%+.16e), Residual = %+.16e\n", i, eval.real(), eval.imag(), + sqrt(res[0])); + } + delete dEig; + + // Preserve deflation space + preserved_deflation_space[parity] = space; + preserved_evals_mass[parity] = mass; + + } else if (load_type == QUDA_MILC_EIG_LOAD) { + + logQuda(QUDA_VERBOSE, "Loading deflation space of size %d for parity %d and mass %e\n", n_evecs, parity, mass); + + if (!evecs) errorQuda("qudaLoadDeflationSpace called with load_type QUDA_MILC_EIG_LOAD but evecs is null!"); + + QudaInvertParam invertParam = newQudaInvertParam(); + setInvertParams(host_precision, device_precision, device_precision_sloppy, mass, 1.0, 0.0, inv_args.max_iter, 1e-1, + parity, verbosity, QUDA_CG_INVERTER, &invertParam); + invertParam.cuda_prec_eigensolver = eigargs.prec_eigensolver; + ColorSpinorParam csParam; + setColorSpinorParams(localDim, host_precision, &csParam); + ColorSpinorParam gpuParam(csParam, invertParam, QUDA_CUDA_FIELD_LOCATION); + + // Setup deflation space + deflation_space *space = new deflation_space; + space->svd = false; + resize(space->evecs, n_evecs, gpuParam); + space->evals.resize(n_evecs, 0.0); + + // Create Dirac operator + DiracParam diracEigParam; + setDiracEigParam(diracEigParam, &invertParam, true, false); + Dirac *dEig = Dirac::create(diracEigParam); + + // Temp vector on GPU + gpuParam.create = QUDA_ZERO_FIELD_CREATE; + ColorSpinorField temp(gpuParam); + + Complex n_unit(-1.0, 0.0); + + // MILC sends pointer to full parity evecs, but QUDA uses single parity vectors + // so for odd parity, need to use offset + int evec_offset = getColorVectorOffset(parity, false, localDim) * host_precision; + + const lat_dim_t dims = {localDim[0], localDim[1], localDim[2], localDim[3]}; + for (int i = 0; i < n_evecs; i++) { + + // Copy each evec to host-side spinor and then to device-side deflation space + void *evec_ptr = static_cast(evecs[i]) + evec_offset; + ColorSpinorParam cpuParam(evec_ptr, invertParam, dims, true, QUDA_CPU_FIELD_LOCATION); + ColorSpinorField in_evec(cpuParam); + space->evecs[i] = in_evec; + + // Compute eigenvalue, lambda_i = v_i^dag A v_i / (v_i^dag * v_i) + dEig->M(temp, space->evecs[i]); + auto vtAv = blas::cDotProduct(space->evecs[i], temp); + auto v2 = blas::norm2(space->evecs[i]); + auto eval = vtAv / sqrt(v2); + space->evals[i] = eval; + + // Compute residual, res^2 = |\lambda*v - A*v| + auto res = blas::caxpbyNorm(eval, space->evecs[i], n_unit, temp); + logQuda(QUDA_SUMMARIZE, "Eval[%04d] = (%+.16e,%+.16e), Residual = %+.16e\n", i, eval.real(), eval.imag(), + sqrt(res[0])); + } + + delete dEig; + + // Preserve deflation space + preserved_deflation_space[parity] = space; + preserved_evals_mass[parity] = mass; + + } else { + errorQuda("Unrecognized load_type"); + } + + qudamilc_called(__func__, verbosity); +} // qudaLoadDeflationSpace + // Wrapper function for qudaInvertDeflatable to maintain backward compatibility with old(er) MILC void qudaInvert(int external_precision, int quda_precision, double mass, QudaInvertArgs_t inv_args, double target_residual, double target_fermilab_residual, const void *const fatlink, @@ -1242,12 +1546,10 @@ void qudaDslash(int external_precision, int quda_precision, QudaInvertArgs_t inv QudaPrecision host_precision = (external_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION; QudaPrecision device_precision = (quda_precision == 2) ? QUDA_DOUBLE_PRECISION : QUDA_SINGLE_PRECISION; QudaPrecision device_precision_sloppy = device_precision; - QudaGaugeParam fat_param = newQudaGaugeParam(); QudaGaugeParam long_param = newQudaGaugeParam(); setGaugeParams(fat_param, long_param, longlink, localDim, host_precision, device_precision, device_precision_sloppy, inv_args.tadpole, inv_args.naik_epsilon); - QudaInvertParam invertParam = newQudaInvertParam(); QudaParity local_parity = inv_args.evenodd; @@ -1255,10 +1557,8 @@ void qudaDslash(int external_precision, int quda_precision, QudaInvertArgs_t inv setInvertParams(host_precision, device_precision, device_precision_sloppy, 0.0, 0, 0, 0, 0.0, local_parity, verbosity, QUDA_CG_INVERTER, &invertParam); - ColorSpinorParam csParam; setColorSpinorParams(localDim, host_precision, &csParam); - // dirty hack to invalidate the cached gauge field without breaking interface compatability if (*num_iters == -1 || !canReuseResidentGauge(&invertParam)) invalidateGaugeQuda(); @@ -1269,13 +1569,10 @@ void qudaDslash(int external_precision, int quda_precision, QudaInvertArgs_t inv } if (longlink == nullptr) invertParam.dslash_type = QUDA_STAGGERED_DSLASH; - int src_offset = getColorVectorOffset(other_parity, false, localDim); int dst_offset = getColorVectorOffset(local_parity, false, localDim); - dslashQuda(static_cast(dst) + dst_offset * host_precision, static_cast(src) + src_offset * host_precision, &invertParam, local_parity); - if (!create_quda_gauge) invalidateGaugeQuda(); qudamilc_called(__func__, verbosity); @@ -1451,6 +1748,8 @@ void qudaInvertMsrcDeflatable(int external_precision, int quda_precision, double inv_args.tadpole, inv_args.naik_epsilon); QudaParity local_parity = inv_args.evenodd; + QudaParity other_parity = local_parity == QUDA_EVEN_PARITY ? QUDA_ODD_PARITY : QUDA_EVEN_PARITY; + if (local_parity != QUDA_EVEN_PARITY && local_parity != QUDA_ODD_PARITY) errorQuda("Invalid parity %d", local_parity); const double reliable_delta = 1e-1; QudaInvertParam invertParam = newQudaInvertParam(); @@ -1459,30 +1758,34 @@ void qudaInvertMsrcDeflatable(int external_precision, int quda_precision, double QUDA_CG_INVERTER, &invertParam); invertParam.num_src = num_src; - // Deflation for even parity solves + // Deflation parameters invertParam.eig_param = (qep.n_ev_deflate > 0) ? &qep : nullptr; - if (qep.n_ev_deflate > 0 && local_parity != QUDA_EVEN_PARITY) - errorQuda("MILC interface deflation currently only supports even parity solves."); - - if (eig_args.vec_in_parity != QUDA_EVEN_PARITY) - errorQuda("MILC interface deflation currently only supports even parity eigenvectors."); - invertParam.tol_restart = eig_args.tol_restart; - - // Eigensolver precision invertParam.cuda_prec_eigensolver = eig_args.prec_eigensolver; - // Preserve deflation space - if (invertParam.eig_param && qep.preserve_deflation) { - if (deflation_init) { - if (!preserved_deflation_space) errorQuda("Unexpected nullptr for preserved deflation space"); - qep.preserve_deflation_space = preserved_deflation_space; + // Deflation space + if (invertParam.eig_param && qep.preserve_deflation) { // if want deflation and use preserved space + qep.preserve_deflation_space = preserved_deflation_space[local_parity]; + if (!qep.preserve_deflation_space) { // if does not exist yet + // Check if other parity space exists + // If so, construct this parity deflation space from other parity deflation space + // Else, this is skipped and the deflation space is constructed via eigensolve during the call to the inverter + if (preserved_deflation_space[other_parity]) { + qudaLoadDeflationSpace(external_precision, quda_precision, fatlink, longlink, mass, inv_args, eig_args, nullptr, + QUDA_MILC_EIG_FROM_OTHER_PARITY); + // This parity deflation space should now exist + qep.preserve_deflation_space = preserved_deflation_space[local_parity]; + if (!qep.preserve_deflation_space) errorQuda("Failed to load deflation space!"); + } + } + // Check that preserved eigenvalues are for this mass + double epsilon = device_precision == QUDA_DOUBLE_PRECISION ? __DBL_EPSILON__ : __FLT_EPSILON__; + if (fabs(mass - preserved_evals_mass[local_parity]) > epsilon) { + logQuda(QUDA_VERBOSE, "Resetting eigenvalues to mass %e\n", invertParam.mass); + qep.preserve_evals = QUDA_BOOLEAN_FALSE; } } - ColorSpinorParam csParam; - setColorSpinorParams(localDim, host_precision, &csParam); - // dirty hack to invalidate the cached gauge field without breaking interface compatability if (*num_iters == -1 || !canReuseResidentGauge(&invertParam)) invalidateGaugeQuda(); @@ -1506,9 +1809,10 @@ void qudaInvertMsrcDeflatable(int external_precision, int quda_precision, double host_free(sln_pointer); host_free(src_pointer); + // Preserve deflation space if (invertParam.eig_param && qep.preserve_deflation) { - preserved_deflation_space = qep.preserve_deflation_space; - deflation_init = true; // signal that we have deflation space preserved + preserved_deflation_space[local_parity] = qep.preserve_deflation_space; + preserved_evals_mass[local_parity] = mass; } // The conventions for num_iters, final_residual, and final_fermilab_residual are taken from the