diff --git a/CMakeLists.txt b/CMakeLists.txt index 7e93a258de..cb1a3ef606 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -239,7 +239,7 @@ option(QUDA_CTEST_SEP_DSLASH_POLICIES "Test Dslash policies separately in ctest option(QUDA_CTEST_DISABLE_BENCHMARKS "Disable benchmark test" ON) option(QUDA_FAST_COMPILE_REDUCE "enable fast compilation in blas and reduction kernels (single warp per reduction)" OFF) -option(QUDA_FAST_COMPILE_DSLASH "enable fast compilation in dslash kernels (~20% perf impact)" OFF) +option(QUDA_FAST_COMPILE_DSLASH "enable fast compilation in coarse grid dslash kernels (significant perf impact)" OFF) option(QUDA_OPENMP "enable OpenMP" OFF) set(QUDA_CXX_STANDARD diff --git a/ci/docker/Dockerfile.build b/ci/docker/Dockerfile.build index ed21db930c..f8d6aa7ab4 100644 --- a/ci/docker/Dockerfile.build +++ b/ci/docker/Dockerfile.build @@ -8,7 +8,9 @@ RUN apt-get update -qq && apt-get install -qq -y --no-install-recommends \ build-essential \ cmake \ wget \ - ninja-build && \ + ninja-build \ + git \ + ca-certificates && \ rm -rf /var/lib/apt/lists/* ARG MPICH_VERSION=3.3.2 diff --git a/include/color_spinor_field_order.h b/include/color_spinor_field_order.h index 2c46c23ea9..1d63a37900 100644 --- a/include/color_spinor_field_order.h +++ b/include/color_spinor_field_order.h @@ -241,9 +241,9 @@ namespace quda constexpr int M = nSpinBlock * nColor * nVec; #pragma unroll for (int i = 0; i < M; i++) { - vec_t tmp - = vector_load(reinterpret_cast(in + parity * offset_cb), x_cb * N + chi * M + i); - memcpy(&out[i], &tmp, sizeof(vec_t)); + auto tmp + = vector_load(reinterpret_cast(in + parity * offset_cb), x_cb * N + chi * M + i); + memcpy(&out[i], &tmp, sizeof(tmp)); } } }; @@ -1010,11 +1010,14 @@ namespace quda { for (int dim = 0; dim < 4; dim++) { for (int dir = 0; dir < 2; dir++) { - ghost[2 * dim + dir] = comm_dim_partitioned(dim) ? static_cast(ghost_[2 * dim + dir]) : nullptr; - ghost_norm[2 * dim + dir] = !comm_dim_partitioned(dim) ? 
- nullptr : - reinterpret_cast(static_cast(ghost_[2 * dim + dir]) - + nParity * length_ghost * faceVolumeCB[dim] * sizeof(Float)); + if (comm_dim_partitioned(dim) && ghost_[2 * dim + dir]) { + ghost[2 * dim + dir] = static_cast(ghost_[2 * dim + dir]); + ghost_norm[2 * dim + dir] = reinterpret_cast( + static_cast(ghost_[2 * dim + dir]) + nParity * length_ghost * faceVolumeCB[dim] * sizeof(Float)); + } else { + ghost[2 * dim + dir] = nullptr; + ghost_norm[2 * dim + dir] = nullptr; + } } } } @@ -1023,7 +1026,7 @@ namespace quda { real v[length_ghost]; norm_type nrm - = isFixed::value ? vector_load(ghost_norm[2 * dim + dir], parity * faceVolumeCB[dim] + x) : 0.0; + = isFixed::value ? vector_load(ghost_norm[2 * dim + dir], parity * faceVolumeCB[dim] + x)[0] : 0.0; #pragma unroll for (int i = 0; i < M; i++) { @@ -1123,16 +1126,9 @@ namespace quda using real = typename mapper::type; using complex = complex; using AllocInt = typename AllocType::type; - using norm_type = float; + using norm_t = float; Float *field = nullptr; - //#define LEGACY_ACCESSOR_NORM // legacy code where norm pointer and offset are stored instead of computed -#ifdef LEGACY_ACCESSOR_NORM - norm_type *norm = nullptr; -#endif AllocInt offset = 0; // offset can be 32-bit or 64-bit -#ifdef LEGACY_ACCESSOR_NORM - AllocInt norm_offset = 0; -#endif int volumeCB = 0; FloatNOrder() = default; @@ -1141,14 +1137,7 @@ namespace quda FloatNOrder(const ColorSpinorField &a, int nFace = 1, Float *buffer = 0, Float **ghost_ = 0) : GhostNOrder(a, nFace, ghost_), field(buffer ? buffer : a.data()), -#ifdef LEGACY_ACCESSOR_NORM - norm(buffer ? 
reinterpret_cast(reinterpret_cast(buffer) + a.NormOffset()) : - const_cast(reinterpret_cast(a.Norm()))), -#endif offset(a.Bytes() / (2 * sizeof(Float))), -#ifdef LEGACY_ACCESSOR_NORM - norm_offset(a.Bytes() / (2 * sizeof(norm_type))), -#endif volumeCB(a.VolumeCB()) { } @@ -1157,23 +1146,19 @@ namespace quda __device__ __host__ inline void load(complex out[length / 2], int x, int parity = 0) const { real v[length]; -#ifndef LEGACY_ACCESSOR_NORM - auto norm_offset = offset / (sizeof(Float) < sizeof(float) ? sizeof(norm_type) / sizeof(Float) : 1); - auto norm = reinterpret_cast(field + volumeCB * (2 * Nc * Ns)); -#endif - norm_type nrm = isFixed::value ? vector_load(norm, x + parity * norm_offset) : 0.0; - + auto norm_offset = (volumeCB * 2 * Nc * Ns + parity * offset) * sizeof(Float) / sizeof(norm_t); + norm_t nrm = isFixed::value ? vector_load(field, x + norm_offset)[0] : 0.0; #pragma unroll for (int i = 0; i < M; i++) { // first load from memory - auto vecTmp = vector_load(field + parity * offset, volumeCB * i + x); + auto vecTmp = vector_load(field, parity * offset, volumeCB * i + x); // now copy into output and scale copy_and_scale(v + i * N, vecTmp, nrm); } // now load any remainder if constexpr (Nrem > 0) { - auto vecTmp = vector_load(field + parity * offset + volumeCB * M * N, x); + auto vecTmp = vector_load(field, parity * offset + volumeCB * M * N, x); copy_and_scale(v + M * N, vecTmp, nrm); } @@ -1181,30 +1166,39 @@ namespace quda for (int i = 0; i < length / 2; i++) out[i] = complex(v[2 * i + 0], v[2 * i + 1]); } + __device__ __host__ inline void prefetch(int x, int parity = 0) const + { + auto norm_offset = (volumeCB * 2 * Nc * Ns + parity * offset) * sizeof(Float) / sizeof(norm_t); + if constexpr (isFixed::value) prefetch_cache_line(reinterpret_cast(field) + (x + norm_offset)); + +#pragma unroll + for (int i = 0; i < M; i++) prefetch_cache_line(field + (parity * offset + (volumeCB * i + x) * N)); + + // now load any remainder + if constexpr (Nrem > 0) 
prefetch_cache_line(field + (parity * offset + volumeCB * M * N + x * Nrem)); + } + __device__ __host__ inline void save(const complex in[length / 2], int x, int parity = 0) const { real v[length]; -#ifndef LEGACY_ACCESSOR_NORM - auto norm_offset = offset / (sizeof(Float) < sizeof(float) ? sizeof(norm_type) / sizeof(Float) : 1); - auto norm = reinterpret_cast(field + volumeCB * (2 * Nc * Ns)); -#endif + auto norm_offset = (volumeCB * 2 * Nc * Ns + parity * offset) * sizeof(Float) / sizeof(norm_t); + #pragma unroll for (int i = 0; i < length / 2; i++) { v[2 * i + 0] = in[i].real(); v[2 * i + 1] = in[i].imag(); } - norm_type scale = 0.0; - norm_type scale_inv = 0.0; + norm_t scale = 0.0; + norm_t scale_inv = 0.0; if constexpr (isFixed::value) { - norm_type max_[length / 2]; + norm_t max_[length / 2]; // two-pass to increase ILP (assumes length divisible by two, e.g. complex-valued) #pragma unroll - for (int i = 0; i < length / 2; i++) - max_[i] = fmaxf(fabsf((norm_type)v[i]), fabsf((norm_type)v[i + length / 2])); + for (int i = 0; i < length / 2; i++) max_[i] = fmaxf(fabsf((norm_t)v[i]), fabsf((norm_t)v[i + length / 2])); #pragma unroll for (int i = 0; i < length / 2; i++) scale = fmaxf(max_[i], scale); - norm[x + parity * norm_offset] = scale * fixedInvMaxValue::value; + reinterpret_cast(field)[x + norm_offset] = scale * fixedInvMaxValue::value; scale_inv = fdividef(fixedMaxValue::value, scale); } @@ -1214,14 +1208,14 @@ namespace quda // first do scalar copy converting into storage type copy_and_scale(vecTmp, v + i * N, scale_inv); // second do vectorized copy into memory - vector_store(field + parity * offset, volumeCB * i + x, vecTmp); + vector_store(field, parity * offset, volumeCB * i + x, vecTmp); } if constexpr (Nrem > 0) { array vecTmp; copy_and_scale(vecTmp, v + M * N, scale_inv); // second do vectorized copy into memory - vector_store(field + parity * offset + volumeCB * M * N, x, vecTmp); + vector_store(field, parity * offset + volumeCB * M * N, x, 
vecTmp); } } diff --git a/include/complex_quda.h b/include/complex_quda.h index 51a4fed2ca..c9ab6557d4 100644 --- a/include/complex_quda.h +++ b/include/complex_quda.h @@ -928,14 +928,14 @@ namespace quda template __host__ __device__ inline complex cmul(const complex &x, const complex &y) { complex rtn = mul2({x.real(), x.real()}, y); - return fma2({x.imag(), x.imag()}, {-y.imag(), y.real()}, rtn); + return fma2({-x.imag(), x.imag()}, {y.imag(), y.real()}, rtn); } template __host__ __device__ inline complex cmac(const complex &x, const complex &y, const complex &z) { complex w = fma2({x.real(), x.real()}, y, z); - return fma2({x.imag(), x.imag()}, {-y.imag(), y.real()}, w); + return fma2({-x.imag(), x.imag()}, {y.imag(), y.real()}, w); } template diff --git a/include/domain_decomposition.h b/include/domain_decomposition.h index 24e653ac37..8ada3ae905 100644 --- a/include/domain_decomposition.h +++ b/include/domain_decomposition.h @@ -39,8 +39,7 @@ namespace quda flags[(int)flag] = true; if ((int)flag == (int)DD::reset) { -#pragma unroll - for (auto i = 0u; i < (int)DD::size; i++) flags[i] = 0; + flags = {}; type = QUDA_DD_NO; } else if ((int)flag >= (int)DD::red_black_type) { type = QUDA_DD_RED_BLACK; diff --git a/include/dslash.h b/include/dslash.h index 8feb23d893..372790f420 100644 --- a/include/dslash.h +++ b/include/dslash.h @@ -8,6 +8,7 @@ #include #include #include +#include namespace quda { @@ -70,6 +71,18 @@ namespace quda char tile_str[16]; i32toa(tile_str, Arg::n_src_tile); strcat(aux_base, tile_str); + if constexpr (dslash_double_store()) strcat(aux_base, ",double_store"); + if constexpr (Arg::prefetch_distance > 0) { + strcat(aux_base, ",prefetch="); + i32toa(tile_str, Arg::prefetch_distance); + strcat(aux_base, tile_str); + if constexpr (dslash_prefetch_type() == PrefetchType::THREAD) + strcat(aux_base, ",prefetch=thread"); + else if constexpr (dslash_prefetch_type() == PrefetchType::BULK) + strcat(aux_base, ",prefetch=bulk"); + else if constexpr 
(dslash_prefetch_type() == PrefetchType::TENSOR) + strcat(aux_base, ",prefetch=tensor"); + } } /** @@ -130,7 +143,7 @@ namespace quda } } - inline void setParam(TuneParam &tp) + template inline void setParam(TuneParam &tp, const GaugeField &U, const GaugeField &L = {}) { // Need to reset ghost pointers prior to every call since the // ghost buffer may have been changed during policy tuning. @@ -173,6 +186,16 @@ namespace quda 0; tp.grid.x += arg.exterior_blocks; } + + if constexpr (dslash_prefetch_type() == PrefetchType::TENSOR && Arg::prefetch_distance > 0) { + Dslash::arg.U.tensor_desc = get_tensor_descriptor(U, tp.block.x); + Dslash::arg.Uback.tensor_desc = get_tensor_descriptor(U.shift(), tp.block.x); + if constexpr (improved) { + assert(!U.empty()); + Dslash::arg.L.tensor_desc = get_tensor_descriptor(L, tp.block.x); + Dslash::arg.Lback.tensor_desc = get_tensor_descriptor(L.shift(), tp.block.x); + } + } } virtual int blockStep() const override { return (arg.shmem & 64) ? 8 : 16; } @@ -219,6 +242,15 @@ namespace quda } } + virtual bool advanceBlockDim(TuneParam ¶m) const override + { + // if TMA is enabled we must keep parity separate in the block (2-d tuning) + if constexpr (dslash_prefetch_tma()) + return TunableKernel2D_base::advanceBlockDim(param); + else + return TunableKernel3D::advanceBlockDim(param); + } + virtual bool advanceTuneParam(TuneParam ¶m) const override { return advanceAux(param) || advanceSharedBytes(param) || advanceBlockDim(param) || advanceSharedCarveOut(param) @@ -268,6 +300,7 @@ namespace quda inline void launch(TuneParam &tp, const qudaStream_t &stream) { tp.set_max_shared_bytes = true; + if (dslash_prefetch_tma() && tp.block.z > 1) errorQuda("Z-dimension block size must be 1 when using TMA"); launch_device( tp, stream, dslash_functor_arg(arg, tp.block.x * tp.grid.x)); } diff --git a/include/dslash_helper.cuh b/include/dslash_helper.cuh index da002550ff..02d2fe2f6c 100644 --- a/include/dslash_helper.cuh +++ b/include/dslash_helper.cuh 
@@ -13,6 +13,7 @@ #include #include #include +#include constexpr quda::use_kernel_arg_p use_kernel_arg = quda::use_kernel_arg_p::TRUE; @@ -20,13 +21,48 @@ constexpr quda::use_kernel_arg_p use_kernel_arg = quda::use_kernel_arg_p::TRUE; namespace quda { + +#ifdef QUDA_DSLASH_DOUBLE_STORE + constexpr bool dslash_double_store() { return true; } +#else + constexpr bool dslash_double_store() { return false; } +#endif + + constexpr PrefetchType dslash_prefetch_type() + { +#if defined(QUDA_DSLASH_PREFETCH_TYPE_NONE) + return PrefetchType::NONE; +#elif defined(QUDA_DSLASH_PREFETCH_TYPE_THREAD) + return PrefetchType::THREAD; +#elif defined(QUDA_DSLASH_PREFETCH_TYPE_BULK) + return PrefetchType::BULK; +#elif defined(QUDA_DSLASH_PREFETCH_TYPE_TENSOR) + return PrefetchType::TENSOR; +#else +#error "Invalid or missing QUDA_DSLASH_PREFETCH_TYPE" +#endif + return PrefetchType::NONE; + } + +#if defined(NVSHMEM_COMMS) && (defined(QUDA_DSLASH_PREFETCH_TYPE_BULK) || defined(QUDA_DSLASH_PREFETCH_TYPE_TENSOR)) +#error NVSHMEM cannot be used in combination with TMA prefetching at present +#endif + + constexpr bool dslash_prefetch_tma() + { + return (dslash_prefetch_type() == PrefetchType::BULK || dslash_prefetch_type() == PrefetchType::TENSOR); + } + + static_assert(!dslash_prefetch_tma() || dslash_double_store(), + "Cannot use TMA prefetching unless QUDA_DSLASH_DOUBLE_STORE is enabled"); + /** @brief Helper function to determine if we should do halo computation @param[in] dim Dimension we are working on. If dim=-1 (default argument) then we return true if type is any halo kernel. 
*/ - template __host__ __device__ __forceinline__ bool doHalo(int dim = -1) + template __host__ __device__ __forceinline__ constexpr bool doHalo(int dim = -1) { switch (type) { case EXTERIOR_KERNEL_ALL: return true; @@ -44,7 +80,7 @@ namespace quda computation @param[in] dim Dimension we are working on */ - template __host__ __device__ __forceinline__ bool doBulk() + template __host__ __device__ __forceinline__ constexpr bool doBulk() { switch (type) { case EXTERIOR_KERNEL_ALL: @@ -109,6 +145,7 @@ namespace quda if (kernel_type == INTERIOR_KERNEL) { coord.x_cb = idx; + coord.x_cb_0 = (target::block_idx().x - arg.pack_blocks) * target::block_dim().x; if (nDim == 5) coord.X = getCoords5CB(coord, idx, arg.dc.X, arg.X0h, parity, pc_type); else @@ -158,13 +195,68 @@ namespace quda #pragma unroll for (int d = 0; d < nDim; d++) { - coord.in_boundary[1][d] = coord[d] + arg.nFace >= arg.dc.X[d]; - coord.in_boundary[0][d] = coord[d] - arg.nFace < 0; + coord.in_boundary[1][d] = -(coord[d] + arg.nFace >= arg.dc.X[d]); + coord.in_boundary[0][d] = -(coord[d] - arg.nFace < 0); } return coord; } + /** + @brief Compute the checkerboard 1-d index for the nearest + neighbor + @param[in] lattice coordinates + @param[in] mu dimension in which to add 1 + @param[in] dir direction (+1 or -1) + @param[in] arg parameter struct + @return 1-d checkboard index + */ + template + __device__ __host__ inline int getNeighborIndexCB(const Coord &x, int mu, int dir, const Arg &arg) + { + switch (nFace) { + case 1: + switch (dir) { + case +1: // positive direction + switch (mu) { + case 0: return (x.X + 1 - (x.in_boundary[1][0] & arg.X[0])) >> 1; + case 1: return (x.X + arg.X[0] - (x.in_boundary[1][1] & arg.X2X1)) >> 1; + case 2: return (x.X + arg.X2X1 - (x.in_boundary[1][2] & arg.X3X2X1)) >> 1; + case 3: return (x.X + arg.X3X2X1 - (x.in_boundary[1][3] & arg.X4X3X2X1)) >> 1; + case 4: return (x.X + arg.X4X3X2X1 - (x.in_boundary[1][4] & arg.X5X4X3X2X1)) >> 1; + } + case -1: + switch (mu) { + case 0: 
return (x.X - 1 + (x.in_boundary[0][0] & arg.X[0])) >> 1; + case 1: return (x.X - arg.X[0] + (x.in_boundary[0][1] & arg.X2X1)) >> 1; + case 2: return (x.X - arg.X2X1 + (x.in_boundary[0][2] & arg.X3X2X1)) >> 1; + case 3: return (x.X - arg.X3X2X1 + (x.in_boundary[0][3] & arg.X4X3X2X1)) >> 1; + case 4: return (x.X - arg.X4X3X2X1 + (x.in_boundary[0][4] & arg.X5X4X3X2X1)) >> 1; + } + } + case 3: + switch (dir) { + case +1: // positive direction + switch (mu) { + case 0: return (x.X + 3 - (x.in_boundary[1][0] & arg.X[0])) >> 1; + case 1: return (x.X + 3 * arg.X[0] - (x.in_boundary[1][1] & arg.X2X1)) >> 1; + case 2: return (x.X + 3 * arg.X2X1 - (x.in_boundary[1][2] & arg.X3X2X1)) >> 1; + case 3: return (x.X + 3 * arg.X3X2X1 - (x.in_boundary[1][3] & arg.X4X3X2X1)) >> 1; + case 4: return (x.X + 3 * arg.X4X3X2X1 - (x.in_boundary[1][4] & arg.X5X4X3X2X1)) >> 1; + } + case -1: + switch (mu) { + case 0: return (x.X - 3 + (x.in_boundary[0][0] & arg.X[0])) >> 1; + case 1: return (x.X - 3 * arg.X[0] + (x.in_boundary[0][1] & arg.X2X1)) >> 1; + case 2: return (x.X - 3 * arg.X2X1 + (x.in_boundary[0][2] & arg.X3X2X1)) >> 1; + case 3: return (x.X - 3 * arg.X3X2X1 + (x.in_boundary[0][3] & arg.X4X3X2X1)) >> 1; + case 4: return (x.X - 3 * arg.X4X3X2X1 + (x.in_boundary[0][4] & arg.X5X4X3X2X1)) >> 1; + } + } + } + return 0; // should never reach here + } + /** @brief Compute whether this thread should be active for updating the a given offsetDim halo. 
For non-fused halo update kernels @@ -243,7 +335,8 @@ namespace quda static constexpr int n_src_tile = n_src_tile_; // how many RHS per thread static constexpr int max_regs = 0; // by default we don't limit register count static constexpr bool spill_shared = false; // whether a given kernel should use shared memory spilling - + static constexpr int prefetch_distance = 0; // whether we are using prefetching in the dslash + static constexpr PrefetchType prefetch_type = dslash_prefetch_type(); const int parity; // only use this for single parity fields const int nParity; // number of parities we're working on const QudaReconstructType reconstruct; @@ -285,6 +378,7 @@ namespace quda int pack_blocks = 0; // total number of blocks used for packing in the dslash int exterior_dims = 0; // dimension to run in the exterior Dslash int exterior_blocks = 0; + int block_size = 0; DDArg dd_out; DDArg dd_in; @@ -655,6 +749,7 @@ namespace quda static constexpr KernelType kernel_type = kernel_type_; static constexpr int max_regs = Arg::max_regs; static constexpr bool spill_shared = Arg::spill_shared; + static constexpr bool is_dslash = true; Arg arg; dslash_functor_arg(const Arg &arg, unsigned int threads_x) : @@ -685,6 +780,14 @@ namespace quda __forceinline__ __device__ void operator()(int, int s, int parity, bool alive = true) { typename Arg::D dslash(*this); + + if constexpr (dslash_prefetch_tma()) { + // FIXME need warp uniform parity which is not composable with + // NVSHMEM since the latter requires blockDim.y and blockDim.z to + // cover the entire extent + parity = target::block_idx().z; // ensure parity is warp uniform + } + // for full fields set parity from z thread index else use arg setting if (arg.nParity == 1) parity = arg.parity; diff --git a/include/dslash_quda.h b/include/dslash_quda.h index f34a41de1a..4017baa69f 100644 --- a/include/dslash_quda.h +++ b/include/dslash_quda.h @@ -19,7 +19,7 @@ namespace quda int_fastdiv X[QUDA_MAX_DIM]; int Ls; - int volume_4d; + 
int_fastdiv volume_4d; int_fastdiv volume_4d_cb; int_fastdiv face_X[4]; @@ -35,11 +35,7 @@ namespace quda int X2X1; int X3X2X1; int X4X3X2X1; - - int X2X1mX1; - int X3X2X1mX2X1; - int X4X3X2X1mX3X2X1; - int X5X4X3X2X1mX4X3X2X1; + int X5X4X3X2X1; }; /** diff --git a/include/externals/CLI11.hpp b/include/externals/CLI11.hpp index a426c5bae4..9174a58890 100644 --- a/include/externals/CLI11.hpp +++ b/include/externals/CLI11.hpp @@ -63,6 +63,7 @@ #include #include #include +#include // Verbatim copy from CLI/Version.hpp: @@ -2485,7 +2486,7 @@ class AsNumberWithUnit : public Validator { /// "2 EiB" => 2^61 // Units up to exibyte are supported class AsSizeValue : public AsNumberWithUnit { public: - using result_t = uint64_t; + using result_t = std::uint64_t; /// If kb_is_1000 is true, /// interpret 'kb', 'k' as 1000 and 'kib', 'ki' as 1024 diff --git a/include/gauge_field.h b/include/gauge_field.h index c355bd4818..9332b5c1e8 100644 --- a/include/gauge_field.h +++ b/include/gauge_field.h @@ -1,9 +1,9 @@ #pragma once +#include #include #include #include - #include namespace quda { @@ -147,6 +147,7 @@ namespace quda { class GaugeField : public LatticeField { friend std::ostream &operator<<(std::ostream &output, const GaugeField ¶m); + friend GaugeField shift(const GaugeField &in, int shift); private: /** @@ -193,6 +194,10 @@ namespace quda { double tadpole = 0.0; double fat_link_max = 0.0; + mutable std::unique_ptr shifted + = nullptr; // shifted copy of the gauge field, used for double-store enabled dslash + bool is_shifted = false; // whether this instance is a shifted one + mutable array ghost = {}; // stores the ghost zone of the gauge field (non-native fields only) @@ -647,6 +652,20 @@ namespace quda { } } + /** + @brief Return the shifted gauge field by shift in each + dimension. Shifted field is cached for subsequent reuse. + @param[in] shift value (1 or 3 supported). If no argument + passed the shift is set to Nface. 
+ @return Reference to shifted field + */ + GaugeField &shift(int shift = -1) const; + + /** + @brief Resets the shifted field (if it exists). + */ + void shift_reset() const; + /** * @brief Print the site data * @param[in] parity Parity index @@ -669,6 +688,17 @@ namespace quda { */ void genericPrintMatrix(const GaugeField &a, int dim, int parity, unsigned int x_cb, int rank = 0); + /** + @brief Shift the gauge field by shift in each dimension and store + the resulting shifted field. This is used to move the backwards + links on to this site. The input field must be a padded field + with the ghost pre-exchanged if communications are enabled. + @param[in] in Input shifted field + @param[in] shift value (1 or 3 supported) + @return Shifted field + */ + GaugeField shift(const GaugeField &in, int shift); + /** @brief This is a debugging function, where we cast a gauge field into a spinor field so we can compute its L1 norm. diff --git a/include/gauge_field_order.h b/include/gauge_field_order.h index 827dde5bbf..938d0b4ea0 100644 --- a/include/gauge_field_order.h +++ b/include/gauge_field_order.h @@ -23,6 +23,7 @@ #include #include #include +#include namespace quda { @@ -997,7 +998,7 @@ namespace quda { type) */ template + QudaStaggeredPhase = QUDA_STAGGERED_PHASE_NO, bool = false> struct Reconstruct { using real = typename mapper::type; using complex = complex; @@ -1030,14 +1031,10 @@ namespace quda { __device__ __host__ inline void Unpack(complex out[N / 2], const real in[N], int, int, real, const I *, const int *) const { - if constexpr (isFixed::value) { -#pragma unroll - for (int i = 0; i < N / 2; i++) { out[i] = scale * complex(in[2 * i + 0], in[2 * i + 1]); } - } else { #pragma unroll - for (int i = 0; i < N / 2; i++) { out[i] = complex(in[2 * i + 0], in[2 * i + 1]); } - } + for (int i = 0; i < N / 2; i++) { out[i] = complex(in[2 * i + 0], in[2 * i + 1]); } } + __device__ __host__ inline real getPhase(const complex[]) const { return 0; } }; @@ -1052,36 +1049,40 
@@ namespace quda { @param isLastTimeSlide if we're on the last time slice of nodes @param ghostExchange if the field is extended or not (determines indexing type) */ - template - __device__ __host__ inline T timeBoundary(int idx, const I X[QUDA_MAX_DIM], const int R[QUDA_MAX_DIM], - T tBoundary, T scale, int firstTimeSliceBound, int lastTimeSliceBound, bool isFirstTimeSlice, - bool isLastTimeSlice, QudaGhostExchange ghostExchange = QUDA_GHOST_EXCHANGE_NO) - { + template + __device__ __host__ inline T timeBoundary(int idx, const I X[QUDA_MAX_DIM], const int R[QUDA_MAX_DIM], T tBoundary, + T scale, int firstTimeSliceBound, int lastTimeSliceBound, + bool isFirstTimeSlice, bool isLastTimeSlice, + QudaGhostExchange ghostExchange = QUDA_GHOST_EXCHANGE_NO) + { - // MWTODO: should this return tBoundary : scale or tBoundary*scale : scale + // MWTODO: should this return tBoundary : scale or tBoundary*scale : scale - if (ghostExchange_ == QUDA_GHOST_EXCHANGE_PAD - || (ghostExchange_ == QUDA_GHOST_EXCHANGE_INVALID && ghostExchange != QUDA_GHOST_EXCHANGE_EXTENDED)) { - if (idx >= firstTimeSliceBound) { // halo region on the first time slice - return isFirstTimeSlice ? tBoundary : scale; - } else if (idx >= lastTimeSliceBound) { // last link on the last time slice - return isLastTimeSlice ? tBoundary : scale; - } else { - return scale; - } - } else if (ghostExchange_ == QUDA_GHOST_EXCHANGE_EXTENDED - || (ghostExchange_ == QUDA_GHOST_EXCHANGE_INVALID && ghostExchange == QUDA_GHOST_EXCHANGE_EXTENDED)) { - if (idx >= (R[3] - 1) * X[0] * X[1] * X[2] / 2 && idx < R[3] * X[0] * X[1] * X[2] / 2) { - // the boundary condition is on the R[3]-1 time slice - return isFirstTimeSlice ? tBoundary : scale; - } else if (idx >= (X[3] - R[3] - 1) * X[0] * X[1] * X[2] / 2 && idx < (X[3] - R[3]) * X[0] * X[1] * X[2] / 2) { - // the boundary condition lies on the X[3]-R[3]-1 time slice - return isLastTimeSlice ? 
tBoundary : scale; - } else { - return scale; - } + if (ghostExchange_ == QUDA_GHOST_EXCHANGE_PAD + || (ghostExchange_ == QUDA_GHOST_EXCHANGE_INVALID && ghostExchange != QUDA_GHOST_EXCHANGE_EXTENDED)) { + + if (!shifted && idx >= firstTimeSliceBound) { // halo region on the first time slice + return isFirstTimeSlice ? tBoundary : scale; + } else if (shifted && idx < firstTimeSliceBound) { // shifted link on first time slice + return isFirstTimeSlice ? tBoundary : scale; + } else if (!shifted && idx >= lastTimeSliceBound) { // last link on the last time slice + return isLastTimeSlice ? tBoundary : scale; + } else { + return scale; + } + } else if (ghostExchange_ == QUDA_GHOST_EXCHANGE_EXTENDED + || (ghostExchange_ == QUDA_GHOST_EXCHANGE_INVALID && ghostExchange == QUDA_GHOST_EXCHANGE_EXTENDED)) { + if (idx >= (R[3] - 1) * X[0] * X[1] * X[2] / 2 && idx < R[3] * X[0] * X[1] * X[2] / 2) { + // the boundary condition is on the R[3]-1 time slice + return isFirstTimeSlice ? tBoundary : scale; + } else if (idx >= (X[3] - R[3] - 1) * X[0] * X[1] * X[2] / 2 && idx < (X[3] - R[3]) * X[0] * X[1] * X[2] / 2) { + // the boundary condition lies on the X[3]-R[3]-1 time slice + return isLastTimeSlice ? tBoundary : scale; + } else { + return scale; } - return scale; + } + return scale; } // not actually used - here for reference @@ -1104,8 +1105,8 @@ namespace quda { @tparam ghostExchange_ optional template the ghostExchange type to avoid the run-time overhead */ - template - struct Reconstruct<18, Float, QUDA_RECONSTRUCT_12, ghostExchange_> { + template + struct Reconstruct<18, Float, QUDA_RECONSTRUCT_12, ghostExchange_, phase, shifted> { using real = typename mapper::type; using complex = complex; const real anisotropy; @@ -1119,7 +1120,7 @@ namespace quda { Reconstruct(const GaugeField &u) : anisotropy(u.Anisotropy()), tBoundary(static_cast(u.TBoundary())), - firstTimeSliceBound(u.VolumeCB()), + firstTimeSliceBound(!shifted ? 
u.VolumeCB() : u.X()[0] * u.X()[1] * u.X()[2] / 2), lastTimeSliceBound((u.X()[3] - 1) * u.X()[0] * u.X()[1] * u.X()[2] / 2), isFirstTimeSlice(comm_coord(3) == 0 ? true : false), isLastTimeSlice(comm_coord(3) == comm_dim(3) - 1 ? true : false), @@ -1145,8 +1146,8 @@ namespace quda { const real u0 = dir < 3 ? anisotropy : - timeBoundary(idx, X, R, tBoundary, static_cast(1.0), firstTimeSliceBound, - lastTimeSliceBound, isFirstTimeSlice, isLastTimeSlice, ghostExchange); + timeBoundary(idx, X, R, tBoundary, static_cast(1.0), firstTimeSliceBound, + lastTimeSliceBound, isFirstTimeSlice, isLastTimeSlice, ghostExchange); // out[6] = u0*conj(out[1]*out[5] - out[2]*out[4]); out[6] = cmul(out[2], out[4]); @@ -1177,8 +1178,8 @@ namespace quda { @tparam ghostExchange_ optional template the ghostExchange type to avoid the run-time overhead */ - template - struct Reconstruct<18, Float, QUDA_RECONSTRUCT_10, ghostExchange_> { + template + struct Reconstruct<18, Float, QUDA_RECONSTRUCT_10, ghostExchange_, phase, shifted> { using real = typename mapper::type; using complex = complex; @@ -1225,8 +1226,8 @@ namespace quda { @tparam ghostExchange_ optional template the ghostExchange type to avoid the run-time overhead */ - template - struct Reconstruct<18, Float, QUDA_RECONSTRUCT_13, ghostExchange_, stag_phase> { + template + struct Reconstruct<18, Float, QUDA_RECONSTRUCT_13, ghostExchange_, stag_phase, shifted> { using real = typename mapper::type; using complex = complex; const Reconstruct<18, Float, QUDA_RECONSTRUCT_12, ghostExchange_> reconstruct_12; @@ -1249,25 +1250,27 @@ namespace quda { out[6] = cmul(out[2], out[4]); out[6] = cmac(out[1], out[5], -out[6]); - out[6] = scale_inv * conj(out[6]); + out[6] = conj(out[6]); out[7] = cmul(out[0], out[5]); out[7] = cmac(out[2], out[3], -out[7]); - out[7] = scale_inv * conj(out[7]); + out[7] = conj(out[7]); out[8] = cmul(out[1], out[3]); out[8] = cmac(out[0], out[4], -out[8]); - out[8] = scale_inv * conj(out[8]); + out[8] = conj(out[8]); 
if constexpr (stag_phase == QUDA_STAGGERED_PHASE_NO) { // dynamic phasing // Multiply the third row by exp(I*3*phase), since the cross product will end up in a scale factor of exp(-I*2*phase) real cos_sin[2]; sincospi(static_cast(3.0) * phase, &cos_sin[1], &cos_sin[0]); complex A(cos_sin[0], cos_sin[1]); - out[6] = cmul(A, out[6]); - out[7] = cmul(A, out[7]); - out[8] = cmul(A, out[8]); + A *= scale_inv; + out[6] = cmul(out[6], A); + out[7] = cmul(out[7], A); + out[8] = cmul(out[8], A); } else { // phase is +/- 1 so real multiply is sufficient + phase *= scale_inv; out[6] *= phase; out[7] *= phase; out[8] *= phase; @@ -1302,8 +1305,8 @@ namespace quda { @tparam ghostExchange_ optional template the ghostExchange type to avoid the run-time overhead */ - template - struct Reconstruct<18, Float, QUDA_RECONSTRUCT_8, ghostExchange_> { + template + struct Reconstruct<18, Float, QUDA_RECONSTRUCT_8, ghostExchange_, stag_phase, shifted> { using real = typename mapper::type; using complex = complex; const complex anisotropy; // imaginary value stores inverse @@ -1318,7 +1321,7 @@ namespace quda { Reconstruct(const GaugeField &u, real scale = 1.0) : anisotropy(u.Anisotropy() * scale, 1.0 / (u.Anisotropy() * scale)), tBoundary(static_cast(u.TBoundary()) * scale, 1.0 / (static_cast(u.TBoundary()) * scale)), - firstTimeSliceBound(u.VolumeCB()), + firstTimeSliceBound(!shifted ? u.VolumeCB() : u.X()[0] * u.X()[1] * u.X()[2] / 2), lastTimeSliceBound((u.X()[3] - 1) * u.X()[0] * u.X()[1] * u.X()[2] / 2), isFirstTimeSlice(comm_coord(3) == 0 ? true : false), isLastTimeSlice(comm_coord(3) == comm_dim(3) - 1 ? 
true : false), @@ -1389,29 +1392,31 @@ namespace quda { real r_inv2 = u0_inv * row_sum_inv; { complex A = cmul(conj(out[0]), out[3]); + complex u0A = u0 * A; // out[4] = -(conj(out[6])*conj(out[2]) + u0*A*out[1])*r_inv2; // U11 out[4] = cmul(conj(out[6]), conj(out[2])); - out[4] = cmac(u0 * A, out[1], out[4]); + out[4] = cmac(u0A, out[1], out[4]); out[4] = -r_inv2 * out[4]; // out[5] = (conj(out[6])*conj(out[1]) - u0*A*out[2])*r_inv2; // U12 out[5] = cmul(conj(out[6]), conj(out[1])); - out[5] = cmac(-u0 * A, out[2], out[5]); + out[5] = cmac(-u0A, out[2], out[5]); out[5] = r_inv2 * out[5]; } { complex A = cmul(conj(out[0]), out[6]); + complex u0A = u0 * A; // out[7] = (conj(out[3])*conj(out[2]) - u0*A*out[1])*r_inv2; // U21 out[7] = cmul(conj(out[3]), conj(out[2])); - out[7] = cmac(-u0 * A, out[1], out[7]); + out[7] = cmac(-u0A, out[1], out[7]); out[7] = r_inv2 * out[7]; // out[8] = -(conj(out[3])*conj(out[1]) + u0*A*out[2])*r_inv2; // U12 out[8] = cmul(conj(out[3]), conj(out[1])); - out[8] = cmac(u0 * A, out[2], out[8]); + out[8] = cmac(u0A, out[2], out[8]); out[8] = -r_inv2 * out[8]; } @@ -1433,8 +1438,8 @@ namespace quda { { complex u = dir < 3 ? 
anisotropy : - timeBoundary(idx, X, R, tBoundary, scale, firstTimeSliceBound, lastTimeSliceBound, - isFirstTimeSlice, isLastTimeSlice, ghostExchange); + timeBoundary(idx, X, R, tBoundary, scale, firstTimeSliceBound, lastTimeSliceBound, + isFirstTimeSlice, isLastTimeSlice, ghostExchange); Unpack(out, in, idx, dir, phase, X, R, scale, u); } @@ -1450,11 +1455,11 @@ namespace quda { @tparam ghostExchange_ optional template the ghostExchange type to avoid the run-time overhead */ - template - struct Reconstruct<18, Float, QUDA_RECONSTRUCT_9, ghostExchange_, stag_phase> { + template + struct Reconstruct<18, Float, QUDA_RECONSTRUCT_9, ghostExchange_, stag_phase, shifted> { using real = typename mapper::type; using complex = complex; - const Reconstruct<18, Float, QUDA_RECONSTRUCT_8, ghostExchange_> reconstruct_8; + const Reconstruct<18, Float, QUDA_RECONSTRUCT_8, ghostExchange_, stag_phase, shifted> reconstruct_8; const real scale; const real scale_inv; @@ -1551,18 +1556,19 @@ namespace quda { template + QudaGhostExchange ghostExchange_ = QUDA_GHOST_EXCHANGE_INVALID, bool use_inphase = false, bool shifted = false> struct FloatNOrder { - using Accessor = FloatNOrder; + using Accessor = FloatNOrder; using store_t = Float; static constexpr int length = length_; using real = typename mapper::type; using complex = complex; typedef typename AllocType::type AllocInt; - Reconstruct reconstruct; + Reconstruct reconstruct; static constexpr int reconLen = recon; static constexpr int hasPhase = (reconLen == 9 || reconLen == 13) ? 
1 : 0; + static constexpr bool loadPhase = hasPhase && !(static_phase() && (reconLen == 13 || use_inphase)); static constexpr int N = gauge::get_vector_order(reconLen - hasPhase); static constexpr int M = (reconLen - hasPhase) / N; static constexpr int Nrem = reconLen - hasPhase - M * N; @@ -1580,6 +1586,9 @@ namespace quda { const int geometry; const AllocInt phaseOffset; size_t bytes; + gauge::tensor_desc_t tensor_desc; + const real combined_scale; // Precomputed scale for copy_and_scale: fixedInvMaxValue * reconstruct.scale + const real phase_scale; // Precomputed scale for phase loading: fixedInvMaxValue * 2.0 (or just 2.0 for float) FloatNOrder(const GaugeField &u, Float *gauge_ = 0, Float **ghost_ = 0) : reconstruct(u), @@ -1590,7 +1599,18 @@ namespace quda { stride(u.Stride()), geometry(u.Geometry()), phaseOffset(u.PhaseOffset() / sizeof(Float)), - bytes(u.Bytes()) + bytes(u.Bytes()), + combined_scale([&]() { + if constexpr (recon == 18) { + // QUDA_RECONSTRUCT_NO: combine fixedInvMaxValue with reconstruct.scale + return isFixed::value ? fixedInvMaxValue::value * reconstruct.scale : 1.0; + } else { + // Other reconstruction types: only need fixedInvMaxValue (reconstruct.scale doesn't exist) + return isFixed::value ? fixedInvMaxValue::value : 1.0; + } + }()), + phase_scale(isFixed::value ? 
fixedInvMaxValue::value * static_cast(2.0) : + static_cast(2.0)) { if (geometry == QUDA_COARSE_GEOMETRY) errorQuda("This accessor does not support coarse-link fields (lacks support for bidirectional ghost zone"); @@ -1612,26 +1632,97 @@ namespace quda { #pragma unroll for (int i = 0; i < M; i++) { // first load from memory - auto vecTmp = vector_load(gauge + parity * offset + dir * (M * N + Nrem) * stride, i * stride + x); - // second do copy converting into register type - copy(tmp + i * N, vecTmp); + auto vecTmp = vector_load(gauge, parity * offset + dir * (M * N + Nrem) * stride, i * stride + x); + // second do copy converting into register type with combined scaling + copy_and_scale(tmp + i * N, vecTmp, combined_scale); } // now load any remainder if constexpr (Nrem > 0) { - auto vecTmp = vector_load(gauge + parity * offset + (dir * (M * N + Nrem) + M * N) * stride, x); - copy(tmp + M * N, vecTmp); + auto vecTmp = vector_load(gauge, parity * offset + (dir * (M * N + Nrem) + M * N) * stride, x); + copy_and_scale(tmp + M * N, vecTmp, combined_scale); } - constexpr bool load_phase = (hasPhase && !(static_phase() && (reconLen == 13 || use_inphase))); - if constexpr (load_phase) { - copy(phase, gauge[parity * offset + phaseOffset + stride * dir + x]); - phase *= static_cast(2.0); + if constexpr (loadPhase) { + if constexpr (isFixed::value) { + copy_and_scale(phase, gauge[parity * offset + phaseOffset + stride * dir + x], phase_scale); + } else { + copy(phase, gauge[parity * offset + phaseOffset + stride * dir + x]); + phase *= static_cast(2.0); + } } reconstruct.Unpack(v, tmp, x, dir, phase, X, R); } + __device__ __host__ inline void raw_load(array &v, int x, int dir, int parity) const + { +#pragma unroll + for (int i = 0; i < M; i++) { + // first load from memory + auto vecTmp = vector_load(gauge, parity * offset + dir * (M * N + Nrem) * stride, i * stride + x); + memcpy(&v[i * N], &vecTmp, sizeof(vecTmp)); + } + + // now load any remainder + if constexpr (Nrem > 
0) { + auto vecTmp = vector_load(gauge, parity * offset + (dir * (M * N + Nrem) + M * N) * stride, x); + memcpy(&v[M * N], &vecTmp, sizeof(vecTmp)); + } + + if constexpr (loadPhase) + memcpy(&v[M * N + Nrem], &gauge[parity * offset + phaseOffset + stride * dir + x], sizeof(store_t)); + } + + template __device__ inline void prefetch(int x, int dir, int parity, int block_size = 0) const + { + if constexpr (type == PrefetchType::THREAD) { // use per-thread prefetching +#pragma unroll + for (int i = 0; i < M; i++) + prefetch_cache_line(gauge + (parity * offset + dir * (M * N + Nrem) * stride + (i * stride + x) * N)); + + // now load any remainder + if constexpr (Nrem > 0) + prefetch_cache_line(gauge + (parity * offset + (dir * (M * N + Nrem) + M * N) * stride + x * Nrem)); + + if constexpr (loadPhase) prefetch_cache_line(gauge + (parity * offset + phaseOffset + stride * dir + x)); + } else if constexpr (type == PrefetchType::BULK) { // bulk prefetch + if (block_size == 0) block_size = blockDim.x; + if (target::is_thread_zero()) { +#pragma unroll + for (int i = 0; i < M; i++) + prefetch_cache_bulk(gauge + (parity * offset + dir * (M * N + Nrem) * stride + (i * stride + x) * N), + block_size * N * sizeof(Float)); + + // now load any remainder + if constexpr (Nrem > 0) + prefetch_cache_bulk(gauge + (parity * offset + (dir * (M * N + Nrem) + M * N) * stride + x * Nrem), + block_size * Nrem * sizeof(Float)); + + if constexpr (loadPhase) + prefetch_cache_bulk(gauge + (parity * offset + phaseOffset + stride * dir + x), block_size * sizeof(Float)); + } + } else if constexpr (type == PrefetchType::TENSOR) { // n-d tensor prefetch + if (target::is_thread_zero()) { + prefetch_cache_tensor_5d(tensor_desc.N, x, x / 16, 0, dir, parity); + if constexpr (Nrem > 0) prefetch_cache_tensor_4d(tensor_desc.Nrem, x, x / 16, dir, parity); + if constexpr (loadPhase) prefetch_cache_tensor_4d(tensor_desc.phase, x, x / 16, dir, parity); + } +#if 0 // L1 prefetching is a disabled experiment + } 
else { // L1 prefetching +#pragma unroll + for (int i = 0; i < M; i++) + prefetch_L1_cache_line(gauge + (parity * offset + dir * (M * N + Nrem) * stride + (i * stride + x) * N)); + + // now load any remainder + if constexpr (Nrem > 0) + prefetch_L1_cache_line(gauge + (parity * offset + (dir * (M * N + Nrem) + M * N) * stride + x * Nrem)); + + if constexpr (loadPhase) prefetch_L1_cache_line(gauge + (parity * offset + phaseOffset + stride * dir + x)); +#endif + } + } + __device__ __host__ inline void save(const complex v[length / 2], int x, int dir, int parity) const { real tmp[reconLen]; @@ -1644,7 +1735,7 @@ namespace quda { #pragma unroll for (int j = 0; j < N; j++) copy(vecTmp[j], tmp[i * N + j]); // second do vectorized copy into memory - vector_store(gauge + parity * offset + dir * (M * N + Nrem) * stride, x + i * stride, vecTmp); + vector_store(gauge, parity * offset + dir * (M * N + Nrem) * stride, x + i * stride, vecTmp); } // now save any remainder @@ -1653,7 +1744,7 @@ namespace quda { #pragma unroll for (int j = 0; j < Nrem; j++) copy(vecTmp[j], tmp[M * N + j]); // second do vectorized copy into memory - vector_store(gauge + parity * offset + (dir * (M * N + Nrem) + M * N) * stride, x, vecTmp); + vector_store(gauge, parity * offset + (dir * (M * N + Nrem) + M * N) * stride, x, vecTmp); } if constexpr (hasPhase) { @@ -1662,6 +1753,29 @@ namespace quda { } } + __device__ __host__ inline void raw_save(const array &v, int x, int dir, int parity) const + { +#pragma unroll + for (int i = 0; i < M; i++) { + array vecTmp; + // first do copy converting into storage type + memcpy(&vecTmp, &v[i * N], sizeof(vecTmp)); + // second do vectorized copy into memory + vector_store(gauge, parity * offset + dir * (M * N + Nrem) * stride, x + i * stride, vecTmp); + } + + // now save any remainder + if constexpr (Nrem > 0) { + array vecTmp; + memcpy(&vecTmp, &v[M * N], sizeof(vecTmp)); + // second do vectorized copy into memory + vector_store(gauge, parity * offset + (dir * (M 
* N + Nrem) + M * N) * stride, x, vecTmp); + } + + if constexpr (hasPhase) + memcpy(&gauge[parity * offset + phaseOffset + dir * stride + x], &v[M * N + Nrem], sizeof(store_t)); + } + /** @brief This accessor routine returns a gauge_wrapper to this object, allowing us to overload various operators for manipulating at @@ -1690,15 +1804,15 @@ namespace quda { // first do vectorized copy from memory into registers auto vecTmp = vector_load(ghost[dir], (i * 2 + parity) * faceVolumeCB[dir] + x); - // second do copy converting into register type - copy(tmp + i * N, vecTmp); + // second do copy converting into register type with combined scaling + copy_and_scale(tmp + i * N, vecTmp, combined_scale); } // now load any remainder if constexpr (Nrem > 0) { auto vecTmp - = vector_load(ghost[dir] + 2 * faceVolumeCB[dir] * M * N, parity * faceVolumeCB[dir] + x); - copy(tmp + M * N, vecTmp); + = vector_load(ghost[dir], 2 * faceVolumeCB[dir] * M * N, parity * faceVolumeCB[dir] + x); + copy_and_scale(tmp + M * N, vecTmp, combined_scale); } real phase = 0.; @@ -1707,8 +1821,13 @@ namespace quda { // if(stag_phase == QUDA_STAGGERED_PHASE_MILC ) { // phase = inphase < static_cast(0) ? 
static_cast(-0.5) : static_cast(0.5); // } else { - copy(phase, ghost[dir][2 * faceVolumeCB[dir] * (reconLen - 1) + parity * faceVolumeCB[dir] + x]); - phase *= static_cast(2.0); + if constexpr (isFixed::value) { + copy_and_scale(phase, ghost[dir][2 * faceVolumeCB[dir] * (reconLen - 1) + parity * faceVolumeCB[dir] + x], + phase_scale); + } else { + copy(phase, ghost[dir][2 * faceVolumeCB[dir] * (reconLen - 1) + parity * faceVolumeCB[dir] + x]); + phase *= static_cast(2.0); + } // } } reconstruct.Unpack(v, tmp, x, dir, phase, X, R); @@ -1739,7 +1858,7 @@ namespace quda { #pragma unroll for (int j = 0; j < Nrem; j++) copy(vecTmp[j], tmp[M * N + j]); // second do vectorized copy into memory - vector_store(ghost[dir] + 2 * faceVolumeCB[dir] * M * N, parity * faceVolumeCB[dir] + x, vecTmp); + vector_store(ghost[dir], 2 * faceVolumeCB[dir] * M * N, parity * faceVolumeCB[dir] + x, vecTmp); } if constexpr (hasPhase) { @@ -1790,27 +1909,36 @@ namespace quda { #pragma unroll for (int i = 0; i < M; i++) { // first do vectorized copy from memory - auto vecTmp = vector_load(ghost[dim] + dir * reconLen * 2 * geometry * R[dim] * faceVolumeCB[dim], + auto vecTmp = vector_load(ghost[dim], dir * reconLen * 2 * geometry * R[dim] * faceVolumeCB[dim], ((i * 2 + parity) * geometry + g) * R[dim] * faceVolumeCB[dim] + x); - // second do copy converting into register type - copy(tmp + i * N, vecTmp); + // second do copy converting into register type with combined scaling + copy_and_scale(tmp + i * N, vecTmp, combined_scale); } // now load any remainder if constexpr (Nrem > 0) { auto vecTmp - = vector_load(ghost[dim] + (dir * reconLen + M * N) * 2 * geometry * R[dim] * faceVolumeCB[dim], + = vector_load(ghost[dim], (dir * reconLen + M * N) * 2 * geometry * R[dim] * faceVolumeCB[dim], (parity * geometry + g) * R[dim] * faceVolumeCB[dim] + x); - copy(tmp + M * N, vecTmp); + copy_and_scale(tmp + M * N, vecTmp, combined_scale); } real phase = 0.; - if constexpr (hasPhase) - copy(phase, - 
ghost[dim][(dir * reconLen + M * N + Nrem) * 2 * geometry * R[dim] * faceVolumeCB[dim] - + (parity * geometry + g) * R[dim] * faceVolumeCB[dim] + x]); + if constexpr (hasPhase) { + if constexpr (isFixed::value) { + copy_and_scale(phase, + ghost[dim][(dir * reconLen + M * N + Nrem) * 2 * geometry * R[dim] * faceVolumeCB[dim] + + (parity * geometry + g) * R[dim] * faceVolumeCB[dim] + x], + phase_scale); + } else { + copy(phase, + ghost[dim][(dir * reconLen + M * N + Nrem) * 2 * geometry * R[dim] * faceVolumeCB[dim] + + (parity * geometry + g) * R[dim] * faceVolumeCB[dim] + x]); + phase *= static_cast(2.0); + } + } // use the extended_idx to determine the boundary condition reconstruct.Unpack(v, tmp, extended_idx, g, 2. * phase, X, R); @@ -1829,7 +1957,7 @@ namespace quda { #pragma unroll for (int j = 0; j < N; j++) copy(vecTmp[j], tmp[i * N + j]); // second do vectorized copy to memory - vector_store(ghost[dim] + dir * reconLen * 2 * geometry * R[dim] * faceVolumeCB[dim], + vector_store(ghost[dim], dir * reconLen * 2 * geometry * R[dim] * faceVolumeCB[dim], ((i * 2 + parity) * geometry + g) * R[dim] * faceVolumeCB[dim] + x, vecTmp); } @@ -1839,7 +1967,7 @@ namespace quda { #pragma unroll for (int j = 0; j < Nrem; j++) copy(vecTmp[j], tmp[M * N + j]); // second do vectorized copy into memory - vector_store(ghost[dim] + (dir * reconLen + M * N) * 2 * geometry * R[dim] * faceVolumeCB[dim], + vector_store(ghost[dim], (dir * reconLen + M * N) * 2 * geometry * R[dim] * faceVolumeCB[dim], (parity * geometry + g) * R[dim] * faceVolumeCB[dim] + x, vecTmp); } @@ -2538,20 +2666,20 @@ namespace quda { template + bool use_inphase = false, QudaGaugeFieldOrder order = QUDA_NATIVE_GAUGE_ORDER, bool shifted = false> struct gauge_mapper { - typedef gauge::FloatNOrder type; + typedef gauge::FloatNOrder type; }; template - struct gauge_mapper { + QudaGhostExchange ghostExchange, bool use_inphase, bool shifted> + struct gauge_mapper { typedef gauge::MILCOrder type; }; template - struct 
gauge_mapper { + QudaGhostExchange ghostExchange, bool use_inphase, bool shifted> + struct gauge_mapper { typedef gauge::QDPOrder type; }; diff --git a/include/index_helper.cuh b/include/index_helper.cuh index 35ec4bd0e5..7eff25a882 100644 --- a/include/index_helper.cuh +++ b/include/index_helper.cuh @@ -234,47 +234,15 @@ namespace quda { array gx = {}; // nDim global lattice coordinates array gDim = {}; // global lattice dimensions int x_cb; // checkerboard lattice site index + int x_cb_0; // value of x_cb on first thread in block int s; // fifth dimension coord int X; // full lattice site index constexpr const int& operator[](int i) const { return x[i]; } constexpr int& operator[](int i) { return x[i]; } - array_2d in_boundary = {}; + array_2d in_boundary = {}; constexpr int size() const { return nDim; } }; - /** - @brief Compute the checkerboard 1-d index for the nearest - neighbor - @param[in] lattice coordinates - @param[in] mu dimension in which to add 1 - @param[in] dir direction (+1 or -1) - @param[in] arg parameter struct - @return 1-d checkboard index - */ - template - __device__ __host__ inline int getNeighborIndexCB(const Coord &x, int mu, int dir, const Arg &arg) - { - switch (dir) { - case +1: // positive direction - switch (mu) { - case 0: return (x.in_boundary[1][0] ? x.X - (arg.X[0] - 1) : x.X + 1) >> 1; - case 1: return (x.in_boundary[1][1] ? x.X - arg.X2X1mX1 : x.X + arg.X[0]) >> 1; - case 2: return (x.in_boundary[1][2] ? x.X - arg.X3X2X1mX2X1 : x.X + arg.X2X1) >> 1; - case 3: return (x.in_boundary[1][3] ? x.X - arg.X4X3X2X1mX3X2X1 : x.X + arg.X3X2X1) >> 1; - case 4: return (x.in_boundary[1][4] ? x.X - arg.X5X4X3X2X1mX4X3X2X1 : x.X + arg.X4X3X2X1) >> 1; - } - case -1: - switch (mu) { - case 0: return (x.in_boundary[0][0] ? x.X + (arg.X[0] - 1) : x.X - 1) >> 1; - case 1: return (x.in_boundary[0][1] ? x.X + arg.X2X1mX1 : x.X - arg.X[0]) >> 1; - case 2: return (x.in_boundary[0][2] ? 
x.X + arg.X3X2X1mX2X1 : x.X - arg.X2X1) >> 1; - case 3: return (x.in_boundary[0][3] ? x.X + arg.X4X3X2X1mX3X2X1 : x.X - arg.X3X2X1) >> 1; - case 4: return (x.in_boundary[0][4] ? x.X + arg.X5X4X3X2X1mX4X3X2X1 : x.X - arg.X4X3X2X1) >> 1; - } - } - return 0; // should never reach here - } - /** Compute the 4-d spatial index from the checkerboarded 1-d index at parity parity @@ -839,7 +807,7 @@ namespace quda { // int idx = indexFromFaceIndex<4,QUDA_4D_PC,dim,nFace,0>(ghost_idx, parity, arg); template - constexpr int indexFromFaceIndexStaggered(int dim, int face_num, int face_idx_in, int parity, int nLayers, QudaPCType, const Arg &arg) + __host__ __device__ inline int indexFromFaceIndexStaggered(int dim, int face_num, int face_idx_in, int parity, int nLayers, QudaPCType, const Arg &arg) { const auto *X = arg.dc.X; // grid dimension const auto &V4 = arg.dc.volume_4d; // 4-d volume @@ -854,7 +822,7 @@ namespace quda { int s = face_idx_in / arg.dc.face_XYZT[dim]; int face_idx = face_idx_in - s * arg.dc.face_XYZT[dim]; - int dims[3] = {}; + std::remove_const_t> dims[3] = {}; int d1 = 0; #pragma unroll 4 for (int d2 = 0; d2 < 4; d2++) { // this will evaluate at compile time @@ -898,7 +866,7 @@ namespace quda { } template - constexpr int indexFromFaceIndexStaggered(int face_idx_in, int parity, const Arg &arg) + __host__ __device__ int indexFromFaceIndexStaggered(int face_idx_in, int parity, const Arg &arg) { return indexFromFaceIndexStaggered(dim, face_num, face_idx_in, parity, nLayers, type, arg); } diff --git a/include/kernel_helper.h b/include/kernel_helper.h index 14727c327a..075295f9b6 100644 --- a/include/kernel_helper.h +++ b/include/kernel_helper.h @@ -19,11 +19,14 @@ namespace quda static constexpr bool check_bounds = check_bounds_; static constexpr int max_regs = 0; // by default we don't limit register count static constexpr bool spill_shared = false; // whether a given kernel should use shared memory spilling + static constexpr bool is_dslash = false; // whether 
the arg is for a dslash (with its nested arg struct) dim3 threads; /** number of active threads required */ + int block_size; /** product of thread block dimensions */ int comms_rank; /** per process value of comm_rank() */ int comms_rank_global; /** per process value comm_rank_global() */ int comms_coord[4]; /** array storing {comm_coord(0), ..., comm_coord(3)} */ int comms_dim[4]; /** array storing {comm_dim(0), ..., comm_dim(3)} */ + int comms_dim_partitioned[4]; /** array storing {comm_dim_partitioned(0), ..., comm_dim_partitioned(3)} */ constexpr kernel_param() = default; @@ -32,7 +35,9 @@ comms_rank(comm_rank()), comms_rank_global(comm_rank_global()), comms_coord {comm_coord(0), comm_coord(1), comm_coord(2), comm_coord(3)}, - comms_dim {comm_dim(0), comm_dim(1), comm_dim(2), comm_dim(3)} + comms_dim {comm_dim(0), comm_dim(1), comm_dim(2), comm_dim(3)}, + comms_dim_partitioned {comm_dim_partitioned(0), comm_dim_partitioned(1), comm_dim_partitioned(2), + comm_dim_partitioned(3)} { } diff --git a/include/kernels/block_orthogonalize.cuh b/include/kernels/block_orthogonalize.cuh index e3e0868c7f..3db52ae6b0 100644 --- a/include/kernels/block_orthogonalize.cuh +++ b/include/kernels/block_orthogonalize.cuh @@ -80,7 +80,7 @@ namespace quda { }; template struct BlockOrtho_Params { - static constexpr int mVec = tile_size(); + static constexpr int mVec = tile_size(); using dot_t = array, mVec>; static constexpr int block_dim = 1; using BlockReduceDot = BlockReduce; @@ -90,7 +90,7 @@ namespace quda { template struct BlockOrtho_ : BlockOrtho_Params::Ops { const Arg &arg; - static constexpr unsigned block_size = Arg::block_size; + static constexpr unsigned block_size = Arg::block_size_cxpr; static constexpr int fineSpin = Arg::fineSpin; static constexpr int spinBlock = (fineSpin == 1) ?
1 : fineSpin / Arg::coarseSpin; // size of spin block static constexpr int nColor = Arg::nColor; diff --git a/include/kernels/dslash_coarse_mma.cuh b/include/kernels/dslash_coarse_mma.cuh index 0a8d5ea9ac..0cc3ac31e5 100644 --- a/include/kernels/dslash_coarse_mma.cuh +++ b/include/kernels/dslash_coarse_mma.cuh @@ -216,7 +216,7 @@ namespace quda // Initialize barrier. All `blockDim.x` threads in block participate. init(bar, blockDim.x * blockDim.y * blockDim.z); // Make initialized barrier visible in async proxy. - cde::fence_proxy_async_shared_cta(); + cuda::ptx::fence_proxy_async(); } // Syncthreads so initialized barrier is visible to all threads. __syncthreads(); diff --git a/include/kernels/dslash_domain_wall_4d_fused_m5.cuh b/include/kernels/dslash_domain_wall_4d_fused_m5.cuh index 46e0ae876a..ea90177228 100644 --- a/include/kernels/dslash_domain_wall_4d_fused_m5.cuh +++ b/include/kernels/dslash_domain_wall_4d_fused_m5.cuh @@ -25,6 +25,7 @@ namespace quda using DomainWall4DArg::threads; using DomainWall4DArg::x; using DomainWall4DArg::xpay; + using DomainWall4DArg::block_size; using F = typename DomainWall4DArg::F; diff --git a/include/kernels/dslash_domain_wall_5d.cuh b/include/kernels/dslash_domain_wall_5d.cuh index 0cb3190293..e1f9171763 100644 --- a/include/kernels/dslash_domain_wall_5d.cuh +++ b/include/kernels/dslash_domain_wall_5d.cuh @@ -25,7 +25,7 @@ namespace quda { // remove the batch dimension from these constants, since these are used for 5-d checkerboard indexing DslashArg::dc.X[4] = in.X(4); - DslashArg::dc.X5X4X3X2X1mX4X3X2X1 = (in.X(4) - 1) * DslashArg::dc.X4X3X2X1; + DslashArg::dc.X5X4X3X2X1 = in.X(4) * DslashArg::dc.X4X3X2X1; } }; diff --git a/include/kernels/dslash_mdw_fused.cuh b/include/kernels/dslash_mdw_fused.cuh index 2b57d5b0a9..65bfeb6c76 100644 --- a/include/kernels/dslash_mdw_fused.cuh +++ b/include/kernels/dslash_mdw_fused.cuh @@ -37,7 +37,7 @@ namespace quda { static constexpr bool reload = reload_; static constexpr bool 
spin_project = true; static constexpr bool spinor_direct_load = true; // false means texture load - using F = typename colorspinor_mapper::type; // color spin field order + using F = typename colorspinor_mapper::type; // color spin field order static constexpr bool gauge_direct_load = true; // false means texture load static constexpr QudaGhostExchange ghost = QUDA_GHOST_EXCHANGE_EXTENDED; // gauge field used is an extended one using G = typename gauge_mapper::type; // gauge field order diff --git a/include/kernels/dslash_staggered.cuh b/include/kernels/dslash_staggered.cuh index ae46c6a900..efaf33c5d7 100644 --- a/include/kernels/dslash_staggered.cuh +++ b/include/kernels/dslash_staggered.cuh @@ -1,12 +1,10 @@ #pragma once -#include #include #include #include #include -#include -#include // forthe packing kernel +#include // for the packing kernel namespace quda { @@ -33,23 +31,30 @@ namespace quda static constexpr QudaGhostExchange ghost = QUDA_GHOST_EXCHANGE_PAD; static constexpr bool use_inphase = improved_ ? 
false : true; static constexpr QudaStaggeredPhase phase = phase_; - using GU = typename gauge_mapper::type; - using GL = - typename gauge_mapper::type; + template + using GU = typename gauge_mapper::type; + template + using GL = typename gauge_mapper::type; F out[MAX_MULTI_RHS]; /** output vector field */ F in[MAX_MULTI_RHS]; /** input vector field */ const Ghost halo_pack; /** accessor for writing the halo */ const Ghost halo; /** accessor for reading the halo */ F x[MAX_MULTI_RHS]; /** input vector when doing xpay */ - const GU U; /** the gauge field */ - const GL L; /** the long gauge field */ + mutable GU U; /** the gauge field */ + mutable GU Uback; /** the gauge field */ + mutable GL L; /** the long gauge field */ + mutable GL Lback; /** the long gauge field */ const real a; /** xpay scale factor */ const real tboundary; /** temporal boundary condition */ const bool is_first_time_slice; /** are we on the first (global) time slice */ const bool is_last_time_slice; /** are we on the last (global) time slice */ static constexpr bool improved = improved_; + static constexpr int prefetch_distance = QUDA_DSLASH_PREFETCH_DISTANCE_STAGGERED; + static constexpr int prefetch_distance_l1 = 0; const real dagger_scale; @@ -59,11 +64,15 @@ namespace quda DslashArg < Float, nDim, DDArg, improved ? 3 : 1, n_src_tile > (out, in, halo, U, x, parity, dagger, a == 0.0 ? false : true, spin_project, comm_override), - halo_pack(halo, improved_ ? 3 : 1), halo(halo, improved_ ? 3 : 1), U(U), L(L), a(a), tboundary(U.TBoundary()), - is_first_time_slice(comm_coord(3) == 0 ? true : false), + halo_pack(halo, improved_ ? 3 : 1), halo(halo, improved_ ? 3 : 1), U(U), + Uback(dslash_double_store() ? U.shift(1) : U), L(L), Lback(dslash_double_store() ? L.shift(3) : L), a(a), + tboundary(U.TBoundary()), is_first_time_slice(comm_coord(3) == 0 ? true : false), is_last_time_slice(comm_coord(3) == comm_dim(3) - 1 ? true : false), dagger_scale(dagger ? 
static_cast(-1.0) : static_cast(1.0)) { + if (!improved && prefetch_distance > 7) + warningQuda("dslash prefetch distance %d is greater than pipeline length for naive staggered", prefetch_distance); + for (auto i = 0u; i < out.size(); i++) { this->out[i] = out[i]; this->in[i] = in[i]; @@ -72,6 +81,69 @@ } }; + /** + @brief Prefetch the gauge field into cache. + @param[in] dim The dimension we are presently working on + @param[in] dir The direction we are presently working on (1 = forwards, 0 = backwards) + @param[in] hop The hopping term we are presently working on (0 = 1 - hop, 1 = 3 - hop) + @param[in] coord Coordinates that we are working on with hop-3 boundary conditions evaluated + @param[in] coord1 Copy of coordinates that we are working on with hop-1 boundary conditions evaluated + @param[in] parity Parity that we are working on + @param[in] arg Parameter struct + */ + template + __device__ __host__ void prefetch(int dim, int dir, int hop, const coord_t &coord, const coord_t &coord1, int parity, + const Arg &arg) + { + int step = 4 * dim + 2 * dir + hop + distance; + if (step >= (Arg::improved ? 16 : 8)) return; + + // if using a TMA prefetch we need to use block's first coordinate + auto x_cb = dslash_prefetch_tma() ? coord.x_cb_0 : coord.x_cb; + x_cb = (Arg::nDim == 5 ?
x_cb % arg.dc.volume_4d_cb : x_cb); + + if constexpr (Arg::improved) { + int dim2 = step / 4; + switch (step % 4) { + case 0: arg.U.template prefetch(x_cb, dim2, parity); break; + case 1: arg.L.template prefetch(x_cb, dim2, parity); break; + case 2: + if constexpr (dslash_double_store()) + arg.Uback.template prefetch(x_cb, dim2, parity); + else + arg.U.template prefetch(getNeighborIndexCB<1>(coord1, dim2, -1, arg.dc), dim2, 1 - parity); + break; + case 3: + if constexpr (dslash_double_store()) + arg.Lback.template prefetch(x_cb, dim2, parity); + else + arg.L.template prefetch(getNeighborIndexCB<3>(coord, dim2, -1, arg.dc), dim2, 1 - parity); + break; + } + } else { + int dim2 = step / 2; + switch (step % 2) { + case 0: arg.U.template prefetch(x_cb, dim2, parity); break; + case 1: + if constexpr (dslash_double_store()) + arg.Uback.template prefetch(x_cb, dim2, parity); + else + arg.U.template prefetch(getNeighborIndexCB<1>(coord1, dim2, -1, arg.dc), dim2, 1 - parity); + break; + } + } + } + + template + __device__ __host__ void prefetch(int dim, int dir, int hop, const coord_t &coord, const coord_t &coord1, int parity, + const Arg &arg) + { + if constexpr (Arg::prefetch_distance_l1 > 0) // L1 prefetch + prefetch<3, Arg::prefetch_distance_l1>(dim, dir, hop, coord, coord1, parity, arg); + if constexpr (Arg::prefetch_distance > 0) // L2 prefetch + prefetch(dim, dir, hop, coord, coord1, parity, arg); + }; + /** @brief Applies the off-diagonal part of the Staggered / Asqtad operator. @@ -90,104 +162,141 @@ namespace quda typedef Matrix, Arg::nColor> Link; const int their_spinor_parity = (arg.nParity == 2) ? 
1 - parity : 0; + Coord coord1 = coord; + if constexpr (Arg::improved) { // need to compute 1-hop in_boundary +#pragma unroll + for (int d = 0; d < 4; d++) { + coord1.in_boundary[1][d] = -(coord[d] + 1 >= arg.dc.X[d]); + coord1.in_boundary[0][d] = -(coord[d] - 1 < 0); + } + } + #pragma unroll for (int d = 0; d < 4; d++) { // loop over dimension // standard - forward direction if (arg.dd_in.doHopping(coord, d, +1)) { - const bool ghost = (coord[d] + 1 >= arg.dc.X[d]) && isActive(active, thread_dim, d, coord, arg); + const bool ghost = coord1.in_boundary[1][d] & isActive(active, thread_dim, d, coord, arg); + if (doHalo(d) && ghost) { const int ghost_idx = ghostFaceIndexStaggered<1>(coord, arg.dc.X, d, 1); - const Link U = arg.improved ? arg.U(d, coord.x_cb, parity) : arg.U(d, coord.x_cb, parity, StaggeredPhase(coord, d, +1, arg)); + const Link U = dslash_double_store() ? + static_cast(arg.Uback.Ghost(d, ghost_idx, 1 - parity, StaggeredPhase(coord, d, +1, arg))) : + static_cast(arg.U(d, coord.x_cb, parity, StaggeredPhase(coord, d, +1, arg))); + #pragma unroll for (auto s = 0; s < n_src_tile; s++) { Vector in = arg.halo.Ghost(d, 1, ghost_idx + (src_idx + s) * arg.dc.ghostFaceCB[d], their_spinor_parity); out[s] = mv_add(U, in, out[s]); } - } else if (doBulk() && !ghost) { - const int fwd_idx = linkIndexP1(coord, arg.dc.X, d); - const Link U = arg.improved ? 
arg.U(d, coord.x_cb, parity) : arg.U(d, coord.x_cb, parity, StaggeredPhase(coord, d, +1, arg)); + } + + if constexpr (doBulk()) { + if (!ghost) { + const int fwd_idx = getNeighborIndexCB<1>(coord1, d, 1, arg.dc); + const Link U = arg.U(d, coord.x_cb, parity, StaggeredPhase(coord, d, +1, arg)); #pragma unroll - for (auto s = 0; s < n_src_tile; s++) { - Vector in = arg.in[src_idx + s](fwd_idx, their_spinor_parity); - out[s] = mv_add(U, in, out[s]); + for (auto s = 0; s < n_src_tile; s++) { + Vector in = arg.in[src_idx + s](fwd_idx, their_spinor_parity); + out[s] = mv_add(U, in, out[s]); + } } + prefetch(d, 0, 0, coord, coord1, parity, arg); // prefetch the gauge link Arg::prefetch_distance ahead } } // improved - forward direction if (arg.improved && arg.dd_in.doHopping(coord, d, +3)) { - const bool ghost = coord.in_boundary[1][d] && isActive(active, thread_dim, d, coord, arg); + const bool ghost = coord.in_boundary[1][d] & isActive(active, thread_dim, d, coord, arg); if (doHalo(d) && ghost) { const int ghost_idx = ghostFaceIndexStaggered<1>(coord, arg.dc.X, d, arg.nFace); - const Link L = arg.L(d, coord.x_cb, parity); + const Link L = dslash_double_store() ? 
static_cast(arg.Lback.Ghost(d, ghost_idx, 1 - parity)) : + static_cast(arg.L(d, coord.x_cb, parity)); #pragma unroll for (auto s = 0; s < n_src_tile; s++) { const Vector in = arg.halo.Ghost(d, 1, ghost_idx + (src_idx + s) * arg.dc.ghostFaceCB[d], their_spinor_parity); out[s] = mv_add(L, in, out[s]); } - } else if (doBulk() && !ghost) { - const int fwd3_idx = linkIndexP3(coord, arg.dc.X, d); - const Link L = arg.L(d, coord.x_cb, parity); + } + + if constexpr (doBulk()) { + if (!ghost) { + const int fwd3_idx = getNeighborIndexCB<3>(coord, d, 1, arg.dc); + const Link L = arg.L(d, coord.x_cb, parity); #pragma unroll - for (auto s = 0; s < n_src_tile; s++) { - const Vector in = arg.in[src_idx + s](fwd3_idx, their_spinor_parity); - out[s] = mv_add(L, in, out[s]); + for (auto s = 0; s < n_src_tile; s++) { + const Vector in = arg.in[src_idx + s](fwd3_idx, their_spinor_parity); + out[s] = mv_add(L, in, out[s]); + } } + prefetch(d, 0, 1, coord, coord1, parity, arg); // prefetch the gauge link Arg::prefetch_distance ahead } } if (arg.dd_in.doHopping(coord, d, -1)) { // Backward gather - compute back offset for spinor and gauge fetch - const bool ghost = (coord[d] - 1 < 0) && isActive(active, thread_dim, d, coord, arg); + const bool ghost = coord1.in_boundary[0][d] & isActive(active, thread_dim, d, coord, arg); if (doHalo(d) && ghost) { const int ghost_idx2 = ghostFaceIndexStaggered<0>(coord, arg.dc.X, d, 1); const int ghost_idx = arg.improved ? ghostFaceIndexStaggered<0>(coord, arg.dc.X, d, 3) : ghost_idx2; - const Link U = arg.improved ? 
arg.U.Ghost(d, ghost_idx2, 1 - parity) : - arg.U.Ghost(d, ghost_idx2, 1 - parity, StaggeredPhase(coord, d, -1, arg)); + const Link U + = static_cast(arg.U.Ghost(d, ghost_idx2, 1 - parity, StaggeredPhase(coord, d, -1, arg))); + #pragma unroll for (auto s = 0; s < n_src_tile; s++) { Vector in = arg.halo.Ghost(d, 0, ghost_idx + (src_idx + s) * arg.dc.ghostFaceCB[d], their_spinor_parity); out[s] = mv_sub(conj(U), in, out[s]); } - } else if (doBulk() && !ghost) { - const int back_idx = linkIndexM1(coord, arg.dc.X, d); - const int gauge_idx = back_idx; - const Link U = arg.improved ? arg.U(d, gauge_idx, 1 - parity) : - arg.U(d, gauge_idx, 1 - parity, StaggeredPhase(coord, d, -1, arg)); + } + + if constexpr (doBulk()) { + if (!ghost) { + const int back_idx = getNeighborIndexCB<1>(coord1, d, -1, arg.dc); + const Link U = dslash_double_store() ? + static_cast(arg.Uback(d, coord.x_cb, parity, StaggeredPhase(coord, d, -1, arg))) : + static_cast(arg.U(d, back_idx, 1 - parity, StaggeredPhase(coord, d, -1, arg))); + #pragma unroll - for (auto s = 0; s < n_src_tile; s++) { - Vector in = arg.in[src_idx + s](back_idx, their_spinor_parity); - out[s] = mv_sub(conj(U), in, out[s]); + for (auto s = 0; s < n_src_tile; s++) { + Vector in = arg.in[src_idx + s](back_idx, their_spinor_parity); + out[s] = mv_sub(conj(U), in, out[s]); + } } + prefetch(d, 1, 0, coord, coord1, parity, arg); // prefetch the gauge link Arg::prefetch_distance ahead } } // improved - backward direction if (arg.improved && arg.dd_in.doHopping(coord, d, -3)) { - const bool ghost = coord.in_boundary[0][d] && isActive(active, thread_dim, d, coord, arg); + const bool ghost = coord.in_boundary[0][d] & isActive(active, thread_dim, d, coord, arg); if (doHalo(d) && ghost) { const int ghost_idx = ghostFaceIndexStaggered<0>(coord, arg.dc.X, d, 1); - const Link L = arg.L.Ghost(d, ghost_idx, 1 - parity); + const Link L = static_cast(arg.L.Ghost(d, ghost_idx, 1 - parity)); #pragma unroll for (auto s = 0; s < n_src_tile; s++) { 
const Vector in = arg.halo.Ghost(d, 0, ghost_idx + (src_idx + s) * arg.dc.ghostFaceCB[d], their_spinor_parity); out[s] = mv_sub(conj(L), in, out[s]); } - } else if (doBulk() && !ghost) { - const int back3_idx = linkIndexM3(coord, arg.dc.X, d); - const int gauge_idx = back3_idx; - const Link L = arg.L(d, gauge_idx, 1 - parity); + } + + if constexpr (doBulk()) { + if (!ghost) { + const int back3_idx = getNeighborIndexCB<3>(coord, d, -1, arg.dc); + const Link L = dslash_double_store() ? static_cast(arg.Lback(d, coord.x_cb, parity)) : + static_cast(arg.L(d, back3_idx, 1 - parity)); #pragma unroll - for (auto s = 0; s < n_src_tile; s++) { - const Vector in = arg.in[src_idx + s](back3_idx, their_spinor_parity); - out[s] = mv_sub(conj(L), in, out[s]); + for (auto s = 0; s < n_src_tile; s++) { + const Vector in = arg.in[src_idx + s](back3_idx, their_spinor_parity); + out[s] = mv_sub(conj(L), in, out[s]); + } } + prefetch(d, 1, 1, coord, coord1, parity, arg); // prefetch the gauge link Arg::prefetch_distance ahead } } + } // nDim } diff --git a/include/kernels/dslash_twisted_mass_preconditioned.cuh b/include/kernels/dslash_twisted_mass_preconditioned.cuh index 513a034acd..547385c75c 100644 --- a/include/kernels/dslash_twisted_mass_preconditioned.cuh +++ b/include/kernels/dslash_twisted_mass_preconditioned.cuh @@ -63,7 +63,7 @@ namespace quda if (arg.dd_in.doHopping(coord, d, +1)) { const int fwd_idx = getNeighborIndexCB(coord, d, +1, arg.dc); constexpr int proj_dir = dagger ? +1 : -1; - const bool ghost = coord.in_boundary[1][d] && isActive(active, thread_dim, d, coord, arg); + const bool ghost = coord.in_boundary[1][d] & isActive(active, thread_dim, d, coord, arg); if (doHalo(d) && ghost) { // we need to compute the face index if we are updating a face that isn't ours @@ -101,7 +101,7 @@ namespace quda const int back_idx = getNeighborIndexCB(coord, d, -1, arg.dc); const int gauge_idx = back_idx; constexpr int proj_dir = dagger ? 
-1 : +1; - const bool ghost = coord.in_boundary[0][d] && isActive(active, thread_dim, d, coord, arg); + const bool ghost = coord.in_boundary[0][d] & isActive(active, thread_dim, d, coord, arg); if (doHalo(d) && ghost) { // we need to compute the face index if we are updating a face that isn't ours diff --git a/include/kernels/dslash_wilson.cuh b/include/kernels/dslash_wilson.cuh index 8b66ee83e6..75d5ce041c 100644 --- a/include/kernels/dslash_wilson.cuh +++ b/include/kernels/dslash_wilson.cuh @@ -28,7 +28,9 @@ namespace quda static constexpr bool distance_pc = distance_pc_; static constexpr bool gauge_direct_load = false; // false means texture load static constexpr QudaGhostExchange ghost = QUDA_GHOST_EXCHANGE_PAD; - typedef typename gauge_mapper::type G; + template + using G = typename gauge_mapper::type; typedef typename mapper::type real; @@ -37,11 +39,13 @@ namespace quda F x[MAX_MULTI_RHS]; /** input vector set when doing xpay */ Ghost halo_pack; Ghost halo; - const G U; /** the gauge field */ + mutable G U; /** the gauge field */ + mutable G Uback; /** the backwards gauge field */ const real a; /** xpay scale factor - can be -kappa or -kappa^2 */ /** parameters for distance preconditioning */ const real alpha0; const int t0; + static constexpr int prefetch_distance = QUDA_DSLASH_PREFETCH_DISTANCE_WILSON; WilsonArg(cvector_ref &out, cvector_ref &in, const ColorSpinorField &halo, const GaugeField &U, double a, cvector_ref &x, int parity, bool dagger, @@ -51,6 +55,7 @@ namespace quda halo_pack(halo), halo(halo), U(U), + Uback(dslash_double_store() ? 
U.shift(1) : U), a(a), alpha0(alpha0), t0(t0) @@ -63,6 +68,41 @@ } }; + /** + @tparam distance The distance away we are prefetching + @param[in] dim The dimension we are presently working on + @param[in] dir The direction we are presently working on (0 = forwards, 1 = backwards) + @param[in] coord Coordinates that we are working on + @param[in] parity Parity that we are working on + @param[in] arg Parameter struct + */ + template + __device__ __host__ void prefetch(int dim, int dir, const coord_t &coord, int parity, const Arg &arg) + { + if constexpr (Arg::prefetch_distance == 0) return; + + int step = 2 * dim + dir + Arg::prefetch_distance; + if (step >= 8) return; + + int dim2 = step / 2; + + // if using a bulk prefetch we need to use block's first coordinate + auto x_cb = dslash_prefetch_tma() ? coord.x_cb_0 : coord.x_cb; + x_cb = (Arg::nDim == 5 ? x_cb % arg.dc.volume_4d_cb : x_cb); + + switch (step % 2) { + case 0: arg.U.template prefetch(x_cb, dim2, parity); break; + case 1: + if constexpr (dslash_double_store()) { + arg.Uback.template prefetch(x_cb, dim2, parity); + } else { + int idx = getNeighborIndexCB(coord, dim2, -1, arg.dc); + arg.U.template prefetch(Arg::nDim == 5 ? idx % arg.dc.volume_4d_cb : idx, dim2, 1 - parity); + } + break; + } + } + /** @brief Applies the off-diagonal part of the Wilson operator @@ -102,7 +142,7 @@ const int gauge_idx = (Arg::nDim == 5 ? coord.x_cb % arg.dc.volume_4d_cb : coord.x_cb); constexpr int proj_dir = dagger ?
+1 : -1; - const bool ghost = coord.in_boundary[1][d] && isActive(active, thread_dim, d, coord, arg); + const bool ghost = coord.in_boundary[1][d] & isActive(active, thread_dim, d, coord, arg); if (doHalo(d) && ghost) { // we need to compute the face index if we are updating a face that isn't ours @@ -115,12 +155,16 @@ namespace quda their_spinor_parity); out += fwd_coeff * (U * in).reconstruct(d, proj_dir); - } else if (doBulk() && !ghost) { + } - Link U = arg.U(d, gauge_idx, gauge_parity); - Vector in = arg.in[src_idx](fwd_idx + coord.s * arg.dc.volume_4d_cb, their_spinor_parity); + if constexpr (doBulk()) { + if (!ghost) { + Link U = arg.U(d, gauge_idx, gauge_parity); + Vector in = arg.in[src_idx](fwd_idx + coord.s * arg.dc.volume_4d_cb, their_spinor_parity); + out += fwd_coeff * (U * in.project(d, proj_dir)).reconstruct(d, proj_dir); + } - out += fwd_coeff * (U * in.project(d, proj_dir)).reconstruct(d, proj_dir); + prefetch(d, 0, coord, parity, arg); // prefetch the gauge link Arg::prefetch_distance ahead } } @@ -128,10 +172,11 @@ namespace quda if (arg.dd_in.doHopping(coord, d, -1)) { const real bwd_coeff = (d < 3) ? 1.0 : bwd_coeff_3; const int back_idx = getNeighborIndexCB(coord, d, -1, arg.dc); - const int gauge_idx = (Arg::nDim == 5 ? back_idx % arg.dc.volume_4d_cb : back_idx); + int gauge_idx = dslash_double_store() ? coord.x_cb : back_idx; + if constexpr (Arg::nDim == 5) gauge_idx = gauge_idx % arg.dc.volume_4d_cb; constexpr int proj_dir = dagger ? -1 : +1; - const bool ghost = coord.in_boundary[0][d] && isActive(active, thread_dim, d, coord, arg); + const bool ghost = coord.in_boundary[0][d] & isActive(active, thread_dim, d, coord, arg); if (doHalo(d) && ghost) { // we need to compute the face index if we are updating a face that isn't ours @@ -140,17 +185,23 @@ namespace quda idx; const int gauge_ghost_idx = (Arg::nDim == 5 ? 
ghost_idx % arg.dc.ghostFaceCB[d] : ghost_idx); - Link U = arg.U.Ghost(d, gauge_ghost_idx, 1 - gauge_parity); + Link U = dslash_double_store() ? static_cast(arg.Uback(d, gauge_idx, gauge_parity)) : + static_cast(arg.U.Ghost(d, gauge_ghost_idx, 1 - gauge_parity)); HalfVector in = arg.halo.Ghost(d, 0, ghost_idx + (src_idx * arg.Ls + coord.s) * arg.dc.ghostFaceCB[d], their_spinor_parity); out += bwd_coeff * (conj(U) * in).reconstruct(d, proj_dir); - } else if (doBulk() && !ghost) { + } - Link U = arg.U(d, gauge_idx, 1 - gauge_parity); - Vector in = arg.in[src_idx](back_idx + coord.s * arg.dc.volume_4d_cb, their_spinor_parity); + if constexpr (doBulk()) { + if (!ghost) { + Link U = dslash_double_store() ? static_cast(arg.Uback(d, gauge_idx, gauge_parity)) : + static_cast(arg.U(d, gauge_idx, 1 - gauge_parity)); + Vector in = arg.in[src_idx](back_idx + coord.s * arg.dc.volume_4d_cb, their_spinor_parity); + out += bwd_coeff * (conj(U) * in.project(d, proj_dir)).reconstruct(d, proj_dir); + } - out += bwd_coeff * (conj(U) * in.project(d, proj_dir)).reconstruct(d, proj_dir); + prefetch(d, 1, coord, parity, arg); // prefetch the gauge link Arg::prefetch_distance ahead } } } // nDim diff --git a/include/kernels/extract_gauge_ghost.cuh b/include/kernels/extract_gauge_ghost.cuh index 42fef0b4ae..5ea2cdaa3f 100644 --- a/include/kernels/extract_gauge_ghost.cuh +++ b/include/kernels/extract_gauge_ghost.cuh @@ -24,7 +24,6 @@ namespace quda { int f[nDim][nDim]; bool localParity[nDim]; int faceVolumeCB[nDim]; - int comm_dim[QUDA_MAX_DIM]; const int offset; ExtractGhostArg(const GaugeField &u, Float **Ghost, int offset, uint64_t size) : kernel_param(dim3(size, 1, 1)), @@ -34,7 +33,6 @@ namespace quda { { for (int d=0; d= 2*arg.faceVolumeCB[dim]) return; @@ -128,7 +126,7 @@ namespace quda { int dim = parity_dim % Arg::nDim; // for now we never inject unless we have partitioned in that dimension - if (!arg.comm_dim[dim] && !Arg::extract) return; + if (!arg.comms_dim_partitioned[dim] && 
!Arg::extract) return; // linear index used for writing into ghost buffer if (X >= 2*arg.faceVolumeCB[dim]) return; diff --git a/include/kernels/gauge_shift.cuh b/include/kernels/gauge_shift.cuh new file mode 100644 index 0000000000..4726258242 --- /dev/null +++ b/include/kernels/gauge_shift.cuh @@ -0,0 +1,89 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace quda +{ + + template + struct GaugeShiftArg : kernel_param<> { + using real = typename mapper::type; + using Link = Matrix, nColor>; + using RawLink = array; + using Gauge = typename gauge_mapper::type; + static constexpr bool verify = verify_; + + int X[4]; // true grid dimensions + Gauge out; + const Gauge in; + int shift; + int volume_cb; + + GaugeShiftArg(GaugeField &out, const GaugeField &in, int shift) : + kernel_param(dim3(in.VolumeCB(), 2, 4)), out(out), in(in), shift(shift), volume_cb(in.VolumeCB()) + { + for (int dir = 0; dir < 4; dir++) X[dir] = in.X()[dir]; + } + }; + + template struct GaugeShift { + const Arg &arg; + constexpr GaugeShift(const Arg &arg) : arg(arg) { } + static constexpr const char *filename() { return KERNEL_FILE; } + + __device__ __host__ void operator()(int x_cb, int parity, int dir) + { + byte_array x = {}; + getCoords(x, x_cb, arg.X, parity); + + if constexpr (!Arg::verify) { + typename Arg::RawLink link; + if (x[dir] < arg.shift && arg.comms_dim_partitioned[dir]) { // on boundary so we fetch from ghost + const int ghost_idx = ghostFaceIndexStaggered<0>(x, arg.X, dir, 1); + arg.in.raw_load(link, arg.volume_cb + ghost_idx, dir, 1 - parity); + arg.out.raw_save(link, x_cb, dir, parity); + } else { // simple shift + byte_array dx = {}; + dx[dir] = dx[dir] - arg.shift; + int x_cb_back = linkIndexShift(x, dx, arg.X); + arg.in.raw_load(link, x_cb_back, dir, 1 - parity); + arg.out.raw_save(link, x_cb, dir, parity); + + if (x[dir] >= arg.X[dir] - arg.shift && arg.comms_dim_partitioned[dir]) { // write the ghost + const int ghost_idx = 
ghostFaceIndexStaggered<1>(x, arg.X, dir, arg.shift); + arg.in.raw_load(link, x_cb, dir, parity); + arg.out.raw_save(link, arg.volume_cb + ghost_idx, dir, 1 - parity); + } + } + } else { + // verify the shifting has worked + using Link = typename Arg::Link; + if (x[dir] < arg.shift && arg.comms_dim_partitioned[dir]) { + const int ghost_idx = ghostFaceIndexStaggered<0>(x, arg.X, dir, 1); + Link in = arg.in(dir, arg.volume_cb + ghost_idx, 1 - parity); + Link out = arg.out(dir, x_cb, parity); + assert(in.L1() == out.L1()); + } else { + byte_array dx = {}; + dx[dir] = dx[dir] - arg.shift; + int x_cb_back = linkIndexShift(x, dx, arg.X); + Link in = arg.in(dir, x_cb_back, 1 - parity); + Link out = arg.out(dir, x_cb, parity); + assert(in.L1() == out.L1()); + + if (x[dir] >= arg.X[dir] - arg.shift && arg.comms_dim_partitioned[dir]) { + const int ghost_idx = ghostFaceIndexStaggered<1>(x, arg.X, dir, arg.shift); + Link in = arg.in(dir, x_cb, parity); + Link out = arg.out.Ghost(dir, ghost_idx, 1 - parity); + assert(in.L1() == out.L1()); + } + } + } + } + }; + +} // namespace quda diff --git a/include/kernels/laplace.cuh b/include/kernels/laplace.cuh index 9d66c2dee4..b1a45a5c85 100644 --- a/include/kernels/laplace.cuh +++ b/include/kernels/laplace.cuh @@ -36,7 +36,7 @@ namespace quda const Ghost halo_pack; /** accessor used for writing the halo field */ const Ghost halo; /** accessor used for reading the halo field */ F x[MAX_MULTI_RHS]; /** input vector field for xpay*/ - const G U; /** the gauge field */ + mutable G U; /** the gauge field */ const real a; /** xpay scale factor - can be -kappa or -kappa^2 */ const real b; /** used by Wuppetal smearing kernel */ int dir; /** The direction from which to omit the derivative */ @@ -86,11 +86,10 @@ namespace quda if (d != dir) { if (arg.dd_in.doHopping(coord, d, +1)) { // Forward gather - compute fwd offset for vector fetch - const bool ghost = coord.in_boundary[1][d] && isActive(active, thread_dim, d, coord, arg); + const bool 
ghost = coord.in_boundary[1][d] & isActive(active, thread_dim, d, coord, arg); if (doHalo(d) && ghost) { - // const int ghost_idx = ghostFaceIndexStaggered<1>(coord, arg.dc.X, d, 1); const int ghost_idx = ghostFaceIndex<1>(coord, arg.dc.X, d, arg.nFace); const Link U = arg.U(d, coord.x_cb, parity); const Vector in = arg.halo.Ghost(d, 1, ghost_idx + src_idx * arg.dc.ghostFaceCB[d], their_spinor_parity); @@ -111,11 +110,10 @@ const int back_idx = linkIndexM1(coord, arg.dc.X, d); const int gauge_idx = back_idx; - const bool ghost = coord.in_boundary[0][d] && isActive(active, thread_dim, d, coord, arg); + const bool ghost = coord.in_boundary[0][d] & isActive(active, thread_dim, d, coord, arg); if (doHalo(d) && ghost) { - // const int ghost_idx = ghostFaceIndexStaggered<0>(coord, arg.dc.X, d, 1); const int ghost_idx = ghostFaceIndex<0>(coord, arg.dc.X, d, arg.nFace); const Link U = arg.U.Ghost(d, ghost_idx, 1 - parity); diff --git a/include/kernels/restrictor_mma.cuh b/include/kernels/restrictor_mma.cuh index 73f7f16b17..a1501a8ea2 100644 --- a/include/kernels/restrictor_mma.cuh +++ b/include/kernels/restrictor_mma.cuh @@ -174,11 +174,7 @@ namespace quda // block all-reduce thread_max using block_reduce_t = cub::BlockReduce; __shared__ typename block_reduce_t::TempStorage temp_storage; -#if CUDA_VERSION >= 12090 float block_max = block_reduce_t(temp_storage).Reduce(thread_max, ::cuda::maximum()); -#else - float block_max = block_reduce_t(temp_storage).Reduce(thread_max, ::cub::Max()); -#endif __shared__ float block_max_all; if (threadIdx.x + blockDim.x * (threadIdx.y + blockDim.y * threadIdx.z) == 0) { diff --git a/include/lattice_field.h b/include/lattice_field.h index da7538d680..3ad257d251 100644 --- a/include/lattice_field.h +++ b/include/lattice_field.h @@ -160,8 +160,9 @@ namespace quda { /** @brief Create the field as specified by the param @param[in] Parameter struct + @param[in] is_native_gauge Whether the field is a native gauge field */ - void
create(const LatticeFieldParam &param); + void create(const LatticeFieldParam &param, bool is_native_gauge); /** @brief Move the contents of a field to this @@ -500,7 +501,7 @@ namespace quda { @brief Constructor for creating a LatticeField from a LatticeFieldParam @param param Contains the metadata for creating the field */ - LatticeField(const LatticeFieldParam &param); + LatticeField(const LatticeFieldParam &param, bool is_native_gauge = false); /** @brief Destructor for LatticeField diff --git a/include/quda_define.h.in b/include/quda_define.h.in index 9b6c75f081..98e9177557 100644 --- a/include/quda_define.h.in +++ b/include/quda_define.h.in @@ -168,6 +168,36 @@ #define GPU_DISTANCE_PRECONDITIONING #endif +/** + * @def QUDA_DSLASH_DOUBLE_STORE + * @brief This macro sets whether to use double storage of the gauge + * field to simplify indexing in the Dslash kernels. + */ +#cmakedefine QUDA_DSLASH_DOUBLE_STORE + +/** + * @def QUDA_DSLASH_PREFETCH_TYPE + * @brief This macro sets whether to use + * the TMA for L2 prefetching: + * NONE - no prefetch + * THREAD - per thread prefetch + * BULK - TMA bulk prefetch + * TENSOR - TMA tensor descriptor prefetch + */ +#define QUDA_DSLASH_PREFETCH_TYPE_@QUDA_DSLASH_PREFETCH_TYPE@ + +/** + * @def QUDA_DSLASH_PREFETCH_DISTANCE_WILSON + * @brief This macro sets the prefetch distance for Wilson fermions + */ +#define QUDA_DSLASH_PREFETCH_DISTANCE_WILSON @QUDA_DSLASH_PREFETCH_DISTANCE_WILSON@ + +/** + * @def QUDA_DSLASH_PREFETCH_DISTANCE_STAGGERED + * @brief This macro sets the prefetch distance for staggered fermions + */ +#define QUDA_DSLASH_PREFETCH_DISTANCE_STAGGERED @QUDA_DSLASH_PREFETCH_DISTANCE_STAGGERED@ + #cmakedefine QUDA_MULTIGRID #ifdef QUDA_MULTIGRID /** diff --git a/include/quda_matrix.h b/include/quda_matrix.h index 8eb579dab3..3eef975308 100644 --- a/include/quda_matrix.h +++ b/include/quda_matrix.h @@ -103,7 +103,8 @@ namespace quda { the absolute column sums.
@return Compute L1 norm */ - __device__ __host__ inline real L1() { + __device__ __host__ inline real L1() const + { real l1 = 0; #pragma unroll for (int j=0; j __device__ inline T operator()(const T &value_, bool all, const reducer_t &r, const param_t &) { - using warp_reduce_t = cub::WarpReduce; + using warp_reduce_t = cub::WarpReduce; typename warp_reduce_t::TempStorage dummy_storage; warp_reduce_t warp_reduce(dummy_storage); T value = {}; @@ -111,7 +111,7 @@ namespace quda } if (all) { - using warp_scan_t = cub::WarpScan; + using warp_scan_t = cub::WarpScan; typename warp_scan_t::TempStorage dummy_storage; warp_scan_t warp_scan(dummy_storage); value = warp_scan.Broadcast(value, 0); diff --git a/include/targets/cuda/block_reduction_kernel.h b/include/targets/cuda/block_reduction_kernel.h index bf41cde6d3..639501c421 100644 --- a/include/targets/cuda/block_reduction_kernel.h +++ b/include/targets/cuda/block_reduction_kernel.h @@ -61,9 +61,9 @@ namespace quda @tparam block_size x-dimension block-size @param[in] arg Kernel argument */ - template struct BlockKernelArg : Arg_ { + template struct BlockKernelArg : Arg_ { using Arg = Arg_; - static constexpr unsigned int block_size = block_size_; + static constexpr unsigned int block_size_cxpr = block_size; BlockKernelArg(const Arg &arg) : Arg(arg) { } }; @@ -112,7 +112,7 @@ namespace quda */ template