Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
124 commits
Select commit Hold shift + click to select a range
63b7ff4
Initial support for prefetching (over fetching) added to load instruc…
maddyscientist Sep 19, 2025
191105b
Fix for half precision
maddyscientist Oct 1, 2025
5b41229
Apply some missing OMP parallelization to host functions
maddyscientist Oct 1, 2025
a2efb44
Fix for fine-grained accessor vector loads
maddyscientist Oct 1, 2025
c815076
Add prefetching instructions for CUDA
maddyscientist Oct 1, 2025
177c18b
Optimization of neighbor indexing for dslash kernels: use bitwise ins…
maddyscientist Oct 1, 2025
eae953d
Add support for creating a backward gauge field
maddyscientist Oct 3, 2025
2540a1b
Some small improvements to shift(GaugeField) function
maddyscientist Oct 7, 2025
e686437
Gauge shift should encode shift value in aux_string
maddyscientist Oct 7, 2025
676c643
Add support for experimental double storage of gauge fields - disable…
maddyscientist Oct 7, 2025
9c2025b
Fix some issues with gauge shift: fix single-GPU builds and add half/…
maddyscientist Oct 20, 2025
721fbd5
make doBulk and doHalo constexpr
maddyscientist Oct 20, 2025
02a4cb9
Add target::is_thread_zero and target::is_lane_zero helper functions …
maddyscientist Oct 21, 2025
33b5f2f
Expose prefetching instructions
maddyscientist Oct 21, 2025
ccf7a55
Add prefetching support to gauge and colorspinor fields
maddyscientist Oct 21, 2025
0642f63
Add L2 gauge-field prefetching support to both Wilson and staggered d…
maddyscientist Oct 21, 2025
72a001f
QUDA_DSLASH_DOUBLE_STORE is now a CMake parameter
maddyscientist Oct 23, 2025
02e7bc3
Add TMA prefetch support for Wilson and staggered fermions (enabled w…
maddyscientist Oct 23, 2025
7bb5cdc
Add target::uniform helper which is used to create warp-uniform varia…
maddyscientist Oct 23, 2025
f42a507
Fix typo in last commit
maddyscientist Oct 23, 2025
e2df25f
Fix bug with non-double-store staggered dslash
maddyscientist Oct 27, 2025
3010aa6
Fix bug with parity setting
maddyscientist Oct 27, 2025
acfaf5b
Fix bulk prefetch of phase
maddyscientist Oct 27, 2025
67f8ce4
Add 3-d and 4-d TMA prefetch instructions
maddyscientist Oct 28, 2025
946bed0
first version of tensor descriptor TMA prefetch - almost certainly buggy
maddyscientist Oct 28, 2025
d772d5f
Fix some warnings and set Uback tensor descriptor for wilson dslash
maddyscientist Oct 28, 2025
60894ec
Add 5-d tensor prefetch instruction to CUDA. Introduce 3-operand var…
maddyscientist Nov 3, 2025
9910869
colorspinor::FloatNOrder load/save functions use 3-operand vector_loa…
maddyscientist Nov 3, 2025
b9a4d5f
Continued improvements to tensor TMA prefetch variant and gauge::Floa…
maddyscientist Nov 3, 2025
23992e0
Guard TMA tensor descriptor creation with __COMPUTE_CAPABILITY__ >= 900
maddyscientist Nov 4, 2025
f0f9afd
Optimization for fixed point gauge field load with QUDA_RECONSTRUCT_N…
maddyscientist Nov 4, 2025
cfaa705
Optimization of fixed-point phase rescaling
maddyscientist Nov 5, 2025
17d349c
Small optimization to recon-8 unpack, reduces reconstruct by 4 multip…
maddyscientist Nov 5, 2025
a5abce8
Fix backward hopping ghost boundary check in staggered dslash
maddyscientist Nov 5, 2025
c265884
Fix UBSAN error: avoid pointer arithmetic on null pointers
maddyscientist Nov 5, 2025
aee623d
Optimize vector_load/vector_store in gauge_field_order.h to reduce 64…
maddyscientist Nov 5, 2025
6cfc18a
Fix double-store dslash kernels when we have T partitioning - boundar…
maddyscientist Nov 5, 2025
168f097
Fix performance when using double-store gauge field: shifted gauge fi…
maddyscientist Nov 5, 2025
f11bd84
Dslash prefetch should distinguish in the aux string
maddyscientist Nov 5, 2025
a2a9b24
Added experimental optimization: replace parity * offset with bitmask…
maddyscientist Nov 6, 2025
7d17452
Optimization for staggered packing kernels: ensure we do division by …
maddyscientist Nov 8, 2025
27b725d
Optimize scale_inv multiplication in gauge field reconstruction
maddyscientist Nov 8, 2025
2e12a2c
Optimize the alternate path for i2f: with a pre-computed shift consta…
maddyscientist Nov 10, 2025
b67b9fb
Merge origin/feature/prefetch2
maddyscientist Nov 10, 2025
abed9ac
Revert "Added experimental optimization: replace parity * offset with…
maddyscientist Nov 10, 2025
50cc09a
Optimize FFMA2 issuance
maddyscientist Nov 16, 2025
4c9fa83
Add experiment with L1 prefetching for staggered dslash
maddyscientist Nov 17, 2025
9daba3f
No bank conflicts when doing L1 prefetch
maddyscientist Nov 20, 2025
8427323
Fix last commit
maddyscientist Nov 20, 2025
daa5a4f
Disable L1 prefetch experiment in dslash_staggered
maddyscientist Nov 20, 2025
4b0600a
Fix 32-byte alignment when gauge field is padded
maddyscientist Dec 4, 2025
bbd8ac6
Fix a double4 compiler conflict
maddyscientist Dec 9, 2025
1ed2db1
Fix conflict between block_size definitions
maddyscientist Dec 9, 2025
9de5021
Forbid NVSHMEM and TMA prefetching. Fix autotuner so that only valid…
maddyscientist Dec 9, 2025
30ae502
Fix ambiguity from multi-inheritance with fused DWF kernel
maddyscientist Dec 11, 2025
79934bb
Cleanup of abstraction of TMA to allow for clean building on modern a…
maddyscientist Dec 11, 2025
573d0be
Merge branch 'develop' of github.com:lattice/quda into feature/prefetch2
maddyscientist Dec 11, 2025
04b4fae
We should only be aligning the stride with native gauge fields
maddyscientist Dec 12, 2025
0cf1286
Remove FMA optimized I2F, as it introduces floating point rounding tha…
maddyscientist Dec 13, 2025
aaa629d
We only ever need to resize the pad when creating a gauge field from …
maddyscientist Dec 13, 2025
5653947
Tweak block CG tolerance for staggered eigensolver. Laplace eigensol…
maddyscientist Dec 15, 2025
c5cd669
Fix issue with MRHS Shamir DWF operator (pre-computed constant should…
maddyscientist Dec 16, 2025
20a70e4
Fix warning
maddyscientist Dec 16, 2025
74dd488
Fix bug in mdw_dslash5_tensor_core (was ignorant of the reworked acce…
maddyscientist Dec 16, 2025
b2e6e88
Minor optimization mdw_dslash5_tensor_core.cuh and fix quarter precision
maddyscientist Dec 17, 2025
9b5545f
Reduce carve-out autotuner overhead - default carve out step size is …
maddyscientist Dec 17, 2025
d7568e6
Backwards gauge tensor descriptor copy only done if double store enabled
maddyscientist Dec 17, 2025
c92f3cd
Hopefully fix compiler warning
maddyscientist Dec 18, 2025
35da04f
Fix HIP compilation
maddyscientist Dec 18, 2025
6041ec6
Always use ::cuda::maximum() now that we install our own CCCL
maddyscientist Dec 18, 2025
982f41b
Always use ::cuda::maximum() now that we install our own CCCL
maddyscientist Dec 18, 2025
60a746b
Update cub block interfaces
maddyscientist Dec 18, 2025
4918c98
Fix HIP load_store.h
maddyscientist Dec 18, 2025
af2be33
Fix compilation warning with CUDA clang
maddyscientist Dec 18, 2025
02baeaa
Add missing target_device.h
maddyscientist Dec 18, 2025
4b8352c
Fix clang warning
maddyscientist Dec 18, 2025
13a192b
Fix HIP function call
maddyscientist Dec 18, 2025
274cbad
Fix TMA instruction exposure
maddyscientist Dec 18, 2025
89e8886
Fix clang warning
maddyscientist Dec 19, 2025
866a389
Fix clang error
maddyscientist Dec 19, 2025
63b97b9
Fix another clang error
maddyscientist Dec 19, 2025
bcfaa50
Hopefully the last clang error
maddyscientist Dec 19, 2025
b95f9b4
I2F is encoded in half precision fields
maddyscientist Jan 8, 2026
1b73643
Remove LEGACY_ACCESSOR_NORM path from colorspinor::FloatNOrder, and o…
maddyscientist Jan 9, 2026
510b0a2
Use CCCL 3.1.4 instead of latest main branch commit
maddyscientist Jan 9, 2026
55ee7cc
Add some clarifying comments
maddyscientist Jan 9, 2026
9d21752
Fix compiler warning in domain_decomposition.h
maddyscientist Jan 9, 2026
8d04ac1
Add prefetching support for native staggered
maddyscientist Jan 10, 2026
ca2a85a
Remove stray debug asserts
maddyscientist Jan 13, 2026
0bc3ad3
Small clean up to tune_key
maddyscientist Jan 21, 2026
44b9000
tensor descriptor cache should work as expected now
maddyscientist Jan 22, 2026
96a3912
CMake will error out if TMA prefetch is requested but double-store is…
maddyscientist Jan 22, 2026
6360e16
Small cleanup to Wilson dslash
maddyscientist Jan 27, 2026
48e870b
indexfromFaceIndexStaggered should not be constexpr
maddyscientist Jan 27, 2026
051dd43
Fix compilation issue tripping up some CI
maddyscientist Jan 27, 2026
16a787c
Add 2-d TMA prefetch accessors
maddyscientist Jan 27, 2026
32fd0c3
Add run-time launch check when TMA is enabled to ensure parity is blo…
maddyscientist Jan 28, 2026
9fb3260
Cleanup of staggered dslash kernel
maddyscientist Jan 28, 2026
cc6e837
Add FloatNOrder raw_load and raw_save functions
maddyscientist Jan 28, 2026
3229363
Gauge shift now operates on raw packed elements
maddyscientist Jan 28, 2026
35e734a
Matrix::L1/L2/Linf method should be const qualified
maddyscientist Jan 29, 2026
a5055cc
Fix printing bug with LatticeField
maddyscientist Jan 29, 2026
e38501a
Add kernel_param::comms_dim_partitioned which mirrors comm_dim_partit…
maddyscientist Jan 29, 2026
a8d4a0a
Gauge shift kernel now fills in the ghost region of the shifted field…
maddyscientist Jan 30, 2026
73f46af
When double-store is enabled, when doing the halo update always read …
maddyscientist Jan 30, 2026
37cfc7b
Fix bug with staggered dslash test where partitioning was being reset…
maddyscientist Jan 30, 2026
e223bfa
Selecting the type of prefetching to use is now more verbose.
maddyscientist Feb 3, 2026
ea36ced
Runtime warning if dslash prefetch distance exceeds max for naive sta…
maddyscientist Feb 3, 2026
3b25ff5
Fix ROCm compilation
maddyscientist Feb 3, 2026
9b83fde
Make HIP shared memory helpers match CUDA versions
maddyscientist Feb 3, 2026
709b7f9
Blackwell now defaults to using BULK TMA prefetching with a prefetch …
maddyscientist Feb 4, 2026
305884e
Significant cleanup of TENSOR variant of prefetching. Descriptor not …
maddyscientist Feb 4, 2026
06413d0
Fix CI
maddyscientist Feb 4, 2026
dd77fc0
Fix type with twisted mass
maddyscientist Feb 4, 2026
a265269
Increase TuneKey::aux_n to prevent buffer overflow
maddyscientist Feb 4, 2026
f92570e
value to reference - fixes clang compilation issue
maddyscientist Feb 4, 2026
3ada421
Add git to docker file for CSCS
maddyscientist Feb 5, 2026
2125574
Fix deprecation warning with recent CUDA 13.1 regarding NVML temperat…
maddyscientist Feb 6, 2026
951a3ee
Make the NVML temperature query more robust for the change in interface
maddyscientist Feb 6, 2026
3c8ed1a
Fix CLI11 for modern compilers
maddyscientist Feb 10, 2026
8c7ba4d
Temporary change of default prefetch type on sm100 while doing some b…
weinbe2 Mar 3, 2026
a510234
Fix bug in gauge shift when writing its halo. Add some sanity checks…
maddyscientist Mar 12, 2026
d460006
Merge branch 'feature/prefetch2' of github.com:lattice/quda into feat…
maddyscientist Mar 12, 2026
b0f2a86
Revert "Temporary change of default prefetch type on sm100 while doin…
maddyscientist Mar 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ option(QUDA_CTEST_SEP_DSLASH_POLICIES "Test Dslash policies separately in ctest
option(QUDA_CTEST_DISABLE_BENCHMARKS "Disable benchmark test" ON)

option(QUDA_FAST_COMPILE_REDUCE "enable fast compilation in blas and reduction kernels (single warp per reduction)" OFF)
option(QUDA_FAST_COMPILE_DSLASH "enable fast compilation in dslash kernels (~20% perf impact)" OFF)
option(QUDA_FAST_COMPILE_DSLASH "enable fast compilation in coarse grid dslash kernels (significant perf impact)" OFF)

option(QUDA_OPENMP "enable OpenMP" OFF)
set(QUDA_CXX_STANDARD
Expand Down
4 changes: 3 additions & 1 deletion ci/docker/Dockerfile.build
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ RUN apt-get update -qq && apt-get install -qq -y --no-install-recommends \
build-essential \
cmake \
wget \
ninja-build && \
ninja-build \
git \
ca-certificates && \
rm -rf /var/lib/apt/lists/*

ARG MPICH_VERSION=3.3.2
Expand Down
82 changes: 38 additions & 44 deletions include/color_spinor_field_order.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,9 @@ namespace quda
constexpr int M = nSpinBlock * nColor * nVec;
#pragma unroll
for (int i = 0; i < M; i++) {
vec_t tmp
= vector_load<vec_t>(reinterpret_cast<const vec_t *>(in + parity * offset_cb), x_cb * N + chi * M + i);
memcpy(&out[i], &tmp, sizeof(vec_t));
auto tmp
= vector_load<Float, 2>(reinterpret_cast<const vec_t *>(in + parity * offset_cb), x_cb * N + chi * M + i);
memcpy(&out[i], &tmp, sizeof(tmp));
}
}
};
Expand Down Expand Up @@ -1010,11 +1010,14 @@ namespace quda
{
for (int dim = 0; dim < 4; dim++) {
for (int dir = 0; dir < 2; dir++) {
ghost[2 * dim + dir] = comm_dim_partitioned(dim) ? static_cast<Float *>(ghost_[2 * dim + dir]) : nullptr;
ghost_norm[2 * dim + dir] = !comm_dim_partitioned(dim) ?
nullptr :
reinterpret_cast<norm_type *>(static_cast<char *>(ghost_[2 * dim + dir])
+ nParity * length_ghost * faceVolumeCB[dim] * sizeof(Float));
if (comm_dim_partitioned(dim) && ghost_[2 * dim + dir]) {
ghost[2 * dim + dir] = static_cast<Float *>(ghost_[2 * dim + dir]);
ghost_norm[2 * dim + dir] = reinterpret_cast<norm_type *>(
static_cast<char *>(ghost_[2 * dim + dir]) + nParity * length_ghost * faceVolumeCB[dim] * sizeof(Float));
} else {
ghost[2 * dim + dir] = nullptr;
ghost_norm[2 * dim + dir] = nullptr;
}
}
}
}
Expand All @@ -1023,7 +1026,7 @@ namespace quda
{
real v[length_ghost];
norm_type nrm
= isFixed<Float>::value ? vector_load<float>(ghost_norm[2 * dim + dir], parity * faceVolumeCB[dim] + x) : 0.0;
= isFixed<Float>::value ? vector_load<float, 1>(ghost_norm[2 * dim + dir], parity * faceVolumeCB[dim] + x)[0] : 0.0;

#pragma unroll
for (int i = 0; i < M; i++) {
Expand Down Expand Up @@ -1123,16 +1126,9 @@ namespace quda
using real = typename mapper<Float>::type;
using complex = complex<real>;
using AllocInt = typename AllocType<huge_alloc>::type;
using norm_type = float;
using norm_t = float;
Float *field = nullptr;
//#define LEGACY_ACCESSOR_NORM // legacy code where norm pointer and offset are stored instead of computed
#ifdef LEGACY_ACCESSOR_NORM
norm_type *norm = nullptr;
#endif
AllocInt offset = 0; // offset can be 32-bit or 64-bit
#ifdef LEGACY_ACCESSOR_NORM
AllocInt norm_offset = 0;
#endif
int volumeCB = 0;

FloatNOrder() = default;
Expand All @@ -1141,14 +1137,7 @@ namespace quda
/**
   @brief Construct the accessor from a field (or an override buffer).
   @param a The ColorSpinorField we are accessing
   @param nFace Number of faces (forwarded to the ghost accessor base)
   @param buffer Optional override pointer to use instead of a.data()
   @param ghost_ Optional array of ghost-zone pointers
   Note: the norm pointer/offset are no longer stored; they are computed
   on the fly from field/offset/volumeCB in load()/save(), so the legacy
   LEGACY_ACCESSOR_NORM members must not be initialized here.
*/
FloatNOrder(const ColorSpinorField &a, int nFace = 1, Float *buffer = 0, Float **ghost_ = 0) :
GhostNOrder(a, nFace, ghost_),
field(buffer ? buffer : a.data<Float *>()),
offset(a.Bytes() / (2 * sizeof(Float))), // parity stride in units of Float
volumeCB(a.VolumeCB())
{
}
Expand All @@ -1157,54 +1146,59 @@ namespace quda
/**
   @brief Load the spinor for lattice site x, converting from storage
   precision to compute precision (rescaling by the site norm when the
   storage type is fixed point).
   @param[out] out Output complex array (length/2 elements)
   @param[in] x Checkerboard site index
   @param[in] parity Field parity
   Note: the stale LEGACY_ACCESSOR_NORM code path (which recomputed a
   separate norm pointer) has been removed; the norm is now addressed
   directly off the field pointer via norm_offset.
*/
__device__ __host__ inline void load(complex out[length / 2], int x, int parity = 0) const
{
real v[length];

// norm records live after the body of the field, indexed in norm_t units
auto norm_offset = (volumeCB * 2 * Nc * Ns + parity * offset) * sizeof(Float) / sizeof(norm_t);
norm_t nrm = isFixed<Float>::value ? vector_load<norm_t, 1>(field, x + norm_offset)[0] : 0.0;
#pragma unroll
for (int i = 0; i < M; i++) {
// first load from memory (3-operand form: base, offset, index)
auto vecTmp = vector_load<Float, N>(field, parity * offset, volumeCB * i + x);
// now copy into output and scale
copy_and_scale(v + i * N, vecTmp, nrm);
}

// now load any remainder that does not fill an N-wide vector
if constexpr (Nrem > 0) {
auto vecTmp = vector_load<Float, Nrem>(field, parity * offset + volumeCB * M * N, x);
copy_and_scale(v + M * N, vecTmp, nrm);
}

#pragma unroll
for (int i = 0; i < length / 2; i++) out[i] = complex(v[2 * i + 0], v[2 * i + 1]);
}

/**
   @brief Prefetch the storage for lattice site x, touching the same
   cache lines that a subsequent load(x, parity) would read (the norm
   record for fixed-point types, each N-wide vector chunk, and any
   remainder).  Relies on the project helper prefetch_cache_line —
   presumably a cache-line prefetch hint; no data is returned.
   @param[in] x Checkerboard site index
   @param[in] parity Field parity
*/
__device__ __host__ inline void prefetch(int x, int parity = 0) const
{
// norm records live after the field body; same addressing as load()/save()
auto norm_offset = (volumeCB * 2 * Nc * Ns + parity * offset) * sizeof(Float) / sizeof(norm_t);
if constexpr (isFixed<Float>::value) prefetch_cache_line(reinterpret_cast<norm_t *>(field) + (x + norm_offset));

#pragma unroll
for (int i = 0; i < M; i++) prefetch_cache_line(field + (parity * offset + (volumeCB * i + x) * N));

// now prefetch any remainder that does not fill an N-wide vector
if constexpr (Nrem > 0) prefetch_cache_line(field + (parity * offset + volumeCB * M * N + x * Nrem));
}

__device__ __host__ inline void save(const complex in[length / 2], int x, int parity = 0) const
{
real v[length];
#ifndef LEGACY_ACCESSOR_NORM
auto norm_offset = offset / (sizeof(Float) < sizeof(float) ? sizeof(norm_type) / sizeof(Float) : 1);
auto norm = reinterpret_cast<float *>(field + volumeCB * (2 * Nc * Ns));
#endif
auto norm_offset = (volumeCB * 2 * Nc * Ns + parity * offset) * sizeof(Float) / sizeof(norm_t);

#pragma unroll
for (int i = 0; i < length / 2; i++) {
v[2 * i + 0] = in[i].real();
v[2 * i + 1] = in[i].imag();
}

norm_type scale = 0.0;
norm_type scale_inv = 0.0;
norm_t scale = 0.0;
norm_t scale_inv = 0.0;
if constexpr (isFixed<Float>::value) {
norm_type max_[length / 2];
norm_t max_[length / 2];
// two-pass to increase ILP (assumes length divisible by two, e.g. complex-valued)
#pragma unroll
for (int i = 0; i < length / 2; i++)
max_[i] = fmaxf(fabsf((norm_type)v[i]), fabsf((norm_type)v[i + length / 2]));
for (int i = 0; i < length / 2; i++) max_[i] = fmaxf(fabsf((norm_t)v[i]), fabsf((norm_t)v[i + length / 2]));
#pragma unroll
for (int i = 0; i < length / 2; i++) scale = fmaxf(max_[i], scale);
norm[x + parity * norm_offset] = scale * fixedInvMaxValue<Float>::value;
reinterpret_cast<norm_t *>(field)[x + norm_offset] = scale * fixedInvMaxValue<Float>::value;
scale_inv = fdividef(fixedMaxValue<Float>::value, scale);
}

Expand All @@ -1214,14 +1208,14 @@ namespace quda
// first do scalar copy converting into storage type
copy_and_scale<Float, real, N>(vecTmp, v + i * N, scale_inv);
// second do vectorized copy into memory
vector_store(field + parity * offset, volumeCB * i + x, vecTmp);
vector_store(field, parity * offset, volumeCB * i + x, vecTmp);
}

if constexpr (Nrem > 0) {
array<Float, Nrem> vecTmp;
copy_and_scale<Float, real, Nrem>(vecTmp, v + M * N, scale_inv);
// second do vectorized copy into memory
vector_store(field + parity * offset + volumeCB * M * N, x, vecTmp);
vector_store(field, parity * offset + volumeCB * M * N, x, vecTmp);
}
}

Expand Down
4 changes: 2 additions & 2 deletions include/complex_quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -928,14 +928,14 @@ namespace quda
/**
   @brief Complex multiply x * y, built from paired 2-wide multiply and
   fused multiply-add helpers (mul2 / fma2) so both the real and
   imaginary components are produced by the same vector operations.
   @param x First complex operand
   @param y Second complex operand
   @return x * y
   Note: the diff residue left both the old and the updated return
   statements in place (the second was unreachable); only the updated
   form — which negates x.imag() so that y's components keep the
   {imag, real} pairing — is retained.
*/
template <typename real> __host__ __device__ inline complex<real> cmul(const complex<real> &x, const complex<real> &y)
{
// (xr + i*xi)(yr + i*yi) = (xr*yr - xi*yi) + i*(xr*yi + xi*yr)
complex<real> rtn = mul2({x.real(), x.real()}, y);
return fma2({-x.imag(), x.imag()}, {y.imag(), y.real()}, rtn);
}

/**
   @brief Complex multiply-accumulate: returns x * y + z, built from
   paired 2-wide fused multiply-add helpers (fma2).
   @param x First complex multiplicand
   @param y Second complex multiplicand
   @param z Complex addend
   @return x * y + z
   Note: the diff residue left both the old and the updated return
   statements in place (the second was unreachable); only the updated
   form — which negates x.imag() so that y's components keep the
   {imag, real} pairing, matching cmul — is retained.
*/
template <typename real>
__host__ __device__ inline complex<real> cmac(const complex<real> &x, const complex<real> &y, const complex<real> &z)
{
complex<real> w = fma2({x.real(), x.real()}, y, z);
return fma2({-x.imag(), x.imag()}, {y.imag(), y.real()}, w);
}

template <typename T1, typename T2, typename T3>
Expand Down
3 changes: 1 addition & 2 deletions include/domain_decomposition.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ namespace quda
flags[(int)flag] = true;

if ((int)flag == (int)DD::reset) {
#pragma unroll
for (auto i = 0u; i < (int)DD::size; i++) flags[i] = 0;
flags = {};
type = QUDA_DD_NO;
} else if ((int)flag >= (int)DD::red_black_type) {
type = QUDA_DD_RED_BLACK;
Expand Down
35 changes: 34 additions & 1 deletion include/dslash.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <tunable_nd.h>
#include <instantiate.h>
#include <instantiate_dslash.h>
#include <tma_helper.hpp>

namespace quda
{
Expand Down Expand Up @@ -70,6 +71,18 @@ namespace quda
char tile_str[16];
i32toa(tile_str, Arg::n_src_tile);
strcat(aux_base, tile_str);
if constexpr (dslash_double_store()) strcat(aux_base, ",double_store");
if constexpr (Arg::prefetch_distance > 0) {
strcat(aux_base, ",prefetch=");
i32toa(tile_str, Arg::prefetch_distance);
strcat(aux_base, tile_str);
if constexpr (dslash_prefetch_type() == PrefetchType::THREAD)
strcat(aux_base, ",prefetch=thread");
else if constexpr (dslash_prefetch_type() == PrefetchType::BULK)
strcat(aux_base, ",prefetch=bulk");
else if constexpr (dslash_prefetch_type() == PrefetchType::TENSOR)
strcat(aux_base, ",prefetch=tensor");
}
}

/**
Expand Down Expand Up @@ -130,7 +143,7 @@ namespace quda
}
}

inline void setParam(TuneParam &tp)
template <bool improved = false> inline void setParam(TuneParam &tp, const GaugeField &U, const GaugeField &L = {})
{
// Need to reset ghost pointers prior to every call since the
// ghost buffer may have been changed during policy tuning.
Expand Down Expand Up @@ -173,6 +186,16 @@ namespace quda
0;
tp.grid.x += arg.exterior_blocks;
}

if constexpr (dslash_prefetch_type() == PrefetchType::TENSOR && Arg::prefetch_distance > 0) {
Dslash::arg.U.tensor_desc = get_tensor_descriptor(U, tp.block.x);
Dslash::arg.Uback.tensor_desc = get_tensor_descriptor(U.shift(), tp.block.x);
if constexpr (improved) {
assert(!U.empty());
Dslash::arg.L.tensor_desc = get_tensor_descriptor(L, tp.block.x);
Dslash::arg.Lback.tensor_desc = get_tensor_descriptor(L.shift(), tp.block.x);
}
}
}

virtual int blockStep() const override { return (arg.shmem & 64) ? 8 : 16; }
Expand Down Expand Up @@ -219,6 +242,15 @@ namespace quda
}
}

/**
   @brief Advance the autotuner's block dimensions to the next candidate.
   When TMA prefetching is enabled the parity (z) dimension must not be
   folded into the thread block, so tuning is restricted to 2-d block
   advancement; otherwise the full 3-d advancement is used.
   @param[in,out] param Tuning parameters to advance
   @return Whether a further block-dimension candidate exists
*/
virtual bool advanceBlockDim(TuneParam &param) const override
{
// if TMA is enabled we must keep parity separate in the block (2-d tuning)
if constexpr (dslash_prefetch_tma())
return TunableKernel2D_base<false>::advanceBlockDim(param);
else
return TunableKernel3D::advanceBlockDim(param);
}

virtual bool advanceTuneParam(TuneParam &param) const override
{
return advanceAux(param) || advanceSharedBytes(param) || advanceBlockDim(param) || advanceSharedCarveOut(param)
Expand Down Expand Up @@ -268,6 +300,7 @@ namespace quda
/**
   @brief Launch the dslash kernel with the given tuning parameters.
   Requests the maximum shared-memory carve-out, and — when TMA
   prefetching is enabled — enforces at run time that the z block
   dimension is 1 (parity must stay out of the thread block, mirroring
   the restriction applied in advanceBlockDim).
   @param[in,out] tp Tuning parameters for this launch
   @param[in] stream Stream on which to launch
*/
inline void launch(TuneParam &tp, const qudaStream_t &stream)
{
tp.set_max_shared_bytes = true;
if (dslash_prefetch_tma() && tp.block.z > 1) errorQuda("Z-dimension block size must be 1 when using TMA");
launch_device<dslash_functor>(
tp, stream, dslash_functor_arg<D, P, dagger, xpay, kernel_type, Arg>(arg, tp.block.x * tp.grid.x));
}
Expand Down
Loading
Loading