From 63b7ff4c7b55a0810eb8b886ab2d42caa4ab49e9 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Fri, 19 Sep 2025 16:42:56 -0700 Subject: [PATCH 001/121] Initial support for prefetching (over fetching) added to load instructions for CUDA --- include/targets/cuda/inline_ptx.h | 370 ++++++++++++++++++++++++--- include/targets/cuda/load_store.h | 61 +++-- include/targets/generic/load_store.h | 16 +- 3 files changed, 394 insertions(+), 53 deletions(-) diff --git a/include/targets/cuda/inline_ptx.h b/include/targets/cuda/inline_ptx.h index fa29eee35b..3bbccf2ab5 100644 --- a/include/targets/cuda/inline_ptx.h +++ b/include/targets/cuda/inline_ptx.h @@ -18,119 +18,425 @@ namespace quda { // If you're bored... // http://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st - __device__ inline void load_streaming_double4(double4 &a, const double4 *addr) +// Helper macro for prefetch size validation +#define VALIDATE_PREFETCH_SIZE(prefetch_size) \ + static_assert(prefetch_size == 0 || prefetch_size == 64 || prefetch_size == 128 || prefetch_size == 256, \ + "prefetch_size must be 0, 64, 128, or 256") + + // Valid values for prefetch_size: 0 (no prefetch), 64, 128, 256 + // Note: 256B prefetch requires SM 80+. For older architectures, 256B -> 128B + template __device__ inline void load_streaming_double4(double4 &a, const double4 *addr) { + VALIDATE_PREFETCH_SIZE(prefetch_size); + constexpr size_t prefetch_ = __COMPUTE_CAPABILITY__ < 800 ? 
0 : prefetch_size; + double x, y, z, w; - asm("ld.cs.global.v4.f64 {%0, %1, %2, %3}, [%4+0];" : "=d"(x), "=d"(y), "=d"(z), "=d"(w) : __PTR(addr)); + + if constexpr (prefetch_ == 0) { + // Plain streaming load, no prefetch hint + asm volatile("ld.global.cs.v4.f64 {%0, %1, %2, %3}, [%4];\n" : "=d"(x), "=d"(y), "=d"(z), "=d"(w) : "l"(addr)); + } else if constexpr (prefetch_ == 64) { + asm volatile("ld.global.cs.L2::64B.v4.f64 {%0, %1, %2, %3}, [%4];\n" + : "=d"(x), "=d"(y), "=d"(z), "=d"(w) + : "l"(addr)); + } else if constexpr (prefetch_ == 128) { + asm volatile("ld.global.cs.L2::128B.v4.f64 {%0, %1, %2, %3}, [%4];\n" + : "=d"(x), "=d"(y), "=d"(z), "=d"(w) + : "l"(addr)); + } else if constexpr (prefetch_ == 256) { + asm volatile("ld.global.cs.L2::256B.v4.f64 {%0, %1, %2, %3}, [%4];\n" + : "=d"(x), "=d"(y), "=d"(z), "=d"(w) + : "l"(addr)); + } + a.x = x; a.y = y; a.z = z; a.w = w; } - __device__ inline void load_streaming_double2(double2 &a, const double2* addr) + // Valid values for prefetch_size: 0 (no prefetch), 64, 128, 256 + // Note: 256B prefetch requires SM 80+. For older architectures, 256B -> 128B + template __device__ inline void load_streaming_double2(double2 &a, const double2 *addr) { + VALIDATE_PREFETCH_SIZE(prefetch_size); + constexpr size_t prefetch_ = __COMPUTE_CAPABILITY__ < 800 ? 
0 : prefetch_size; + double x, y; - asm("ld.cs.global.v2.f64 {%0, %1}, [%2+0];" : "=d"(x), "=d"(y) : __PTR(addr)); + + if constexpr (prefetch_ == 0) { + // Plain streaming load, no prefetch hint + asm volatile("ld.global.cs.v2.f64 {%0, %1}, [%2];\n" : "=d"(x), "=d"(y) : "l"(addr)); + } else if constexpr (prefetch_ == 64) { + asm volatile("ld.global.cs.L2::64B.v2.f64 {%0, %1}, [%2];\n" : "=d"(x), "=d"(y) : "l"(addr)); + } else if constexpr (prefetch_ == 128) { + asm volatile("ld.global.cs.L2::128B.v2.f64 {%0, %1}, [%2];\n" : "=d"(x), "=d"(y) : "l"(addr)); + } else if constexpr (prefetch_ == 256) { + asm volatile("ld.global.cs.L2::256B.v2.f64 {%0, %1}, [%2];\n" : "=d"(x), "=d"(y) : "l"(addr)); + } + a.x = x; a.y = y; } - __device__ inline void load_streaming_float8(float8 &v, const float8 *addr) + // Valid values for prefetch_size: 0 (no prefetch), 64, 128, 256 + // Note: 256B prefetch requires SM 80+. For older architectures, 256B -> 128B + template __device__ inline void load_streaming_float8(float8 &v, const float8 *addr) { + VALIDATE_PREFETCH_SIZE(prefetch_size); + constexpr size_t prefetch_ = __COMPUTE_CAPABILITY__ < 800 ? 
0 : prefetch_size; + float x, y, z, w, a, b, c, d; - asm("ld.cs.global.v8.f32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8+0];" - : "=f"(x), "=f"(y), "=f"(z), "=f"(w), "=f"(a), "=f"(b), "=f"(c), "=f"(d) - : __PTR(addr)); + + if constexpr (prefetch_ == 0) { + // Plain streaming load, no prefetch hint + asm volatile("ld.global.cs.v8.f32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];\n" + : "=f"(x), "=f"(y), "=f"(z), "=f"(w), "=f"(a), "=f"(b), "=f"(c), "=f"(d) + : "l"(addr)); + } else if constexpr (prefetch_ == 64) { + asm volatile("ld.global.cs.L2::64B.v8.f32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];\n" + : "=f"(x), "=f"(y), "=f"(z), "=f"(w), "=f"(a), "=f"(b), "=f"(c), "=f"(d) + : "l"(addr)); + } else if constexpr (prefetch_ == 128) { + asm volatile("ld.global.cs.L2::128B.v8.f32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];\n" + : "=f"(x), "=f"(y), "=f"(z), "=f"(w), "=f"(a), "=f"(b), "=f"(c), "=f"(d) + : "l"(addr)); + } else if constexpr (prefetch_ == 256) { + asm volatile("ld.global.cs.L2::256B.v8.f32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];\n" + : "=f"(x), "=f"(y), "=f"(z), "=f"(w), "=f"(a), "=f"(b), "=f"(c), "=f"(d) + : "l"(addr)); + } + v = {{x, y, z, w}, {a, b, c, d}}; } - __device__ inline void load_streaming_float4(float4 &a, const float4* addr) + // Valid values for prefetch_size: 0 (no prefetch), 64, 128, 256 + // Note: 256B prefetch requires SM 80+. For older architectures, 256B -> 128B + template __device__ inline void load_streaming_float4(float4 &a, const float4 *addr) { + VALIDATE_PREFETCH_SIZE(prefetch_size); + constexpr size_t prefetch_ = __COMPUTE_CAPABILITY__ < 800 ? 
0 : prefetch_size; + float x, y, z, w; - asm("ld.cs.global.v4.f32 {%0, %1, %2, %3}, [%4+0];" : "=f"(x), "=f"(y), "=f"(z), "=f"(w) : __PTR(addr)); + + if constexpr (prefetch_ == 0) { + // Plain streaming load, no prefetch hint + asm volatile("ld.global.cs.v4.f32 {%0, %1, %2, %3}, [%4];\n" : "=f"(x), "=f"(y), "=f"(z), "=f"(w) : "l"(addr)); + } else if constexpr (prefetch_ == 64) { + asm volatile("ld.global.cs.L2::64B.v4.f32 {%0, %1, %2, %3}, [%4];\n" + : "=f"(x), "=f"(y), "=f"(z), "=f"(w) + : "l"(addr)); + } else if constexpr (prefetch_ == 128) { + asm volatile("ld.global.cs.L2::128B.v4.f32 {%0, %1, %2, %3}, [%4];\n" + : "=f"(x), "=f"(y), "=f"(z), "=f"(w) + : "l"(addr)); + } else if constexpr (prefetch_ == 256) { + asm volatile("ld.global.cs.L2::256B.v4.f32 {%0, %1, %2, %3}, [%4];\n" + : "=f"(x), "=f"(y), "=f"(z), "=f"(w) + : "l"(addr)); + } + a.x = x; a.y = y; a.z = z; a.w = w; } - __device__ inline void load_cached_short4(short4 &a, const short4 *addr) + // Valid values for prefetch_size: 0 (no prefetch), 64, 128, 256 + // Note: 256B prefetch requires SM 80+. For older architectures, 256B -> 128B + template __device__ inline void load_cached_short4(short4 &a, const short4 *addr) { + VALIDATE_PREFETCH_SIZE(prefetch_size); + constexpr size_t prefetch_ = __COMPUTE_CAPABILITY__ < 800 ? 
0 : prefetch_size; + short x, y, z, w; - asm("ld.ca.global.v4.s16 {%0, %1, %2, %3}, [%4+0];" : "=h"(x), "=h"(y), "=h"(z), "=h"(w) : __PTR(addr)); + + if constexpr (prefetch_ == 0) { + // Plain cached load, no prefetch hint + asm volatile("ld.global.ca.v4.s16 {%0, %1, %2, %3}, [%4];\n" : "=h"(x), "=h"(y), "=h"(z), "=h"(w) : "l"(addr)); + } else if constexpr (prefetch_ == 64) { + asm volatile("ld.global.ca.L2::64B.v4.s16 {%0, %1, %2, %3}, [%4];\n" + : "=h"(x), "=h"(y), "=h"(z), "=h"(w) + : "l"(addr)); + } else if constexpr (prefetch_ == 128) { + asm volatile("ld.global.ca.L2::128B.v4.s16 {%0, %1, %2, %3}, [%4];\n" + : "=h"(x), "=h"(y), "=h"(z), "=h"(w) + : "l"(addr)); + } else if constexpr (prefetch_ == 256) { + asm volatile("ld.global.ca.L2::256B.v4.s16 {%0, %1, %2, %3}, [%4];\n" + : "=h"(x), "=h"(y), "=h"(z), "=h"(w) + : "l"(addr)); + } + a.x = x; a.y = y; a.z = z; a.w = w; } - __device__ inline void load_cached_short2(short2 &a, const short2 *addr) + // Valid values for prefetch_size: 0 (no prefetch), 64, 128, 256 + // Note: 256B prefetch requires SM 80+. For older architectures, 256B -> 128B + template __device__ inline void load_cached_short2(short2 &a, const short2 *addr) { + VALIDATE_PREFETCH_SIZE(prefetch_size); + constexpr size_t prefetch_ = __COMPUTE_CAPABILITY__ < 800 ? 
0 : prefetch_size; + short x, y; - asm("ld.ca.global.v2.s16 {%0, %1}, [%2+0];" : "=h"(x), "=h"(y) : __PTR(addr)); + + if constexpr (prefetch_ == 0) { + // Plain cached load, no prefetch hint + asm volatile("ld.global.ca.v2.s16 {%0, %1}, [%2];\n" : "=h"(x), "=h"(y) : "l"(addr)); + } else if constexpr (prefetch_ == 64) { + asm volatile("ld.global.ca.L2::64B.v2.s16 {%0, %1}, [%2];\n" : "=h"(x), "=h"(y) : "l"(addr)); + } else if constexpr (prefetch_ == 128) { + asm volatile("ld.global.ca.L2::128B.v2.s16 {%0, %1}, [%2];\n" : "=h"(x), "=h"(y) : "l"(addr)); + } else if constexpr (prefetch_ == 256) { + asm volatile("ld.global.ca.L2::256B.v2.s16 {%0, %1}, [%2];\n" : "=h"(x), "=h"(y) : "l"(addr)); + } + a.x = x; a.y = y; } - __device__ inline void load_global_short4(short4 &a, const short4 *addr) + // Valid values for prefetch_size: 0 (no prefetch), 64, 128, 256 + // Note: 256B prefetch requires SM 80+. For older architectures, 256B -> 128B + template __device__ inline void load_global_short4(short4 &a, const short4 *addr) { + VALIDATE_PREFETCH_SIZE(prefetch_size); + constexpr size_t prefetch_ = __COMPUTE_CAPABILITY__ < 800 ? 
0 : prefetch_size; + short x, y, z, w; - asm("ld.cg.global.v4.s16 {%0, %1, %2, %3}, [%4+0];" : "=h"(x), "=h"(y), "=h"(z), "=h"(w) : __PTR(addr)); + + if constexpr (prefetch_ == 0) { + // Plain global load, no prefetch hint + asm volatile("ld.global.cg.v4.s16 {%0, %1, %2, %3}, [%4];\n" : "=h"(x), "=h"(y), "=h"(z), "=h"(w) : "l"(addr)); + } else if constexpr (prefetch_ == 64) { + asm volatile("ld.global.cg.L2::64B.v4.s16 {%0, %1, %2, %3}, [%4];\n" + : "=h"(x), "=h"(y), "=h"(z), "=h"(w) + : "l"(addr)); + } else if constexpr (prefetch_ == 128) { + asm volatile("ld.global.cg.L2::128B.v4.s16 {%0, %1, %2, %3}, [%4];\n" + : "=h"(x), "=h"(y), "=h"(z), "=h"(w) + : "l"(addr)); + } else if constexpr (prefetch_ == 256) { + asm volatile("ld.global.cg.L2::256B.v4.s16 {%0, %1, %2, %3}, [%4];\n" + : "=h"(x), "=h"(y), "=h"(z), "=h"(w) + : "l"(addr)); + } + a.x = x; a.y = y; a.z = z; a.w = w; } - __device__ inline void load_global_short2(short2 &a, const short2 *addr) + // Valid values for prefetch_size: 0 (no prefetch), 64, 128, 256 + // Note: 256B prefetch requires SM 80+. For older architectures, 256B -> 128B + template __device__ inline void load_global_short2(short2 &a, const short2 *addr) { + VALIDATE_PREFETCH_SIZE(prefetch_size); + constexpr size_t prefetch_ = __COMPUTE_CAPABILITY__ < 800 ? 
0 : prefetch_size; + short x, y; - asm("ld.cg.global.v2.s16 {%0, %1}, [%2+0];" : "=h"(x), "=h"(y) : __PTR(addr)); + + if constexpr (prefetch_ == 0) { + // Plain global load, no prefetch hint + asm volatile("ld.global.cg.v2.s16 {%0, %1}, [%2];\n" : "=h"(x), "=h"(y) : "l"(addr)); + } else if constexpr (prefetch_ == 64) { + asm volatile("ld.global.cg.L2::64B.v2.s16 {%0, %1}, [%2];\n" : "=h"(x), "=h"(y) : "l"(addr)); + } else if constexpr (prefetch_ == 128) { + asm volatile("ld.global.cg.L2::128B.v2.s16 {%0, %1}, [%2];\n" : "=h"(x), "=h"(y) : "l"(addr)); + } else if constexpr (prefetch_ == 256) { + asm volatile("ld.global.cg.L2::256B.v2.s16 {%0, %1}, [%2];\n" : "=h"(x), "=h"(y) : "l"(addr)); + } + a.x = x; a.y = y; } - __device__ inline void load_global_float4(float4 &a, const float4* addr) + // Valid values for prefetch_size: 0 (no prefetch), 64, 128, 256 + // Note: 256B prefetch requires SM 80+. For older architectures, 256B -> 128B + template __device__ inline void load_global_float4(float4 &a, const float4 *addr) { + VALIDATE_PREFETCH_SIZE(prefetch_size); + constexpr size_t prefetch_ = __COMPUTE_CAPABILITY__ < 800 ? 
0 : prefetch_size; + float x, y, z, w; - asm("ld.cg.global.v4.f32 {%0, %1, %2, %3}, [%4+0];" : "=f"(x), "=f"(y), "=f"(z), "=f"(w) : __PTR(addr)); + + if constexpr (prefetch_ == 0) { + // Plain global load, no prefetch hint + asm volatile("ld.global.cg.v4.f32 {%0, %1, %2, %3}, [%4];\n" : "=f"(x), "=f"(y), "=f"(z), "=f"(w) : "l"(addr)); + } else if constexpr (prefetch_ == 64) { + asm volatile("ld.global.cg.L2::64B.v4.f32 {%0, %1, %2, %3}, [%4];\n" + : "=f"(x), "=f"(y), "=f"(z), "=f"(w) + : "l"(addr)); + } else if constexpr (prefetch_ == 128) { + asm volatile("ld.global.cg.L2::128B.v4.f32 {%0, %1, %2, %3}, [%4];\n" + : "=f"(x), "=f"(y), "=f"(z), "=f"(w) + : "l"(addr)); + } else if constexpr (prefetch_ == 256) { + asm volatile("ld.global.cg.L2::256B.v4.f32 {%0, %1, %2, %3}, [%4];\n" + : "=f"(x), "=f"(y), "=f"(z), "=f"(w) + : "l"(addr)); + } + a.x = x; a.y = y; a.z = z; a.w = w; } - __device__ inline void load_cached_float4(float4 &a, const float4* addr) + // Valid values for prefetch_size: 0 (no prefetch), 64, 128, 256 + // Note: 256B prefetch requires SM 80+. For older architectures, 256B -> 128B + template __device__ inline void load_cached_float4(float4 &a, const float4 *addr) { + VALIDATE_PREFETCH_SIZE(prefetch_size); + constexpr size_t prefetch_ = __COMPUTE_CAPABILITY__ < 800 ? 
0 : prefetch_size; + float x, y, z, w; - asm("ld.ca.global.v4.f32 {%0, %1, %2, %3}, [%4+0];" : "=f"(x), "=f"(y), "=f"(z), "=f"(w) : __PTR(addr)); + + if constexpr (prefetch_ == 0) { + // Plain cached load, no prefetch hint + asm volatile("ld.global.ca.v4.f32 {%0, %1, %2, %3}, [%4];\n" : "=f"(x), "=f"(y), "=f"(z), "=f"(w) : "l"(addr)); + } else if constexpr (prefetch_ == 64) { + asm volatile("ld.global.ca.L2::64B.v4.f32 {%0, %1, %2, %3}, [%4];\n" + : "=f"(x), "=f"(y), "=f"(z), "=f"(w) + : "l"(addr)); + } else if constexpr (prefetch_ == 128) { + asm volatile("ld.global.ca.L2::128B.v4.f32 {%0, %1, %2, %3}, [%4];\n" + : "=f"(x), "=f"(y), "=f"(z), "=f"(w) + : "l"(addr)); + } else if constexpr (prefetch_ == 256) { + asm volatile("ld.global.ca.L2::256B.v4.f32 {%0, %1, %2, %3}, [%4];\n" + : "=f"(x), "=f"(y), "=f"(z), "=f"(w) + : "l"(addr)); + } + a.x = x; a.y = y; a.z = z; a.w = w; } - __device__ inline void load_cached_float8(float8 &v, const float8 *addr) + // Valid values for prefetch_size: 0 (no prefetch), 64, 128, 256 + // Note: 256B prefetch requires SM 80+. For older architectures, 256B -> 128B + template __device__ inline void load_cached_float8(float8 &v, const float8 *addr) { + VALIDATE_PREFETCH_SIZE(prefetch_size); + constexpr size_t prefetch_ = __COMPUTE_CAPABILITY__ < 800 ? 
0 : prefetch_size; + float x, y, z, w, a, b, c, d; - asm("ld.ca.global.v8.f32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8+0];" - : "=f"(x), "=f"(y), "=f"(z), "=f"(w), "=f"(a), "=f"(b), "=f"(c), "=f"(d) - : __PTR(addr)); + + if constexpr (prefetch_ == 0) { + // Plain cached load, no prefetch hint + asm volatile("ld.global.ca.v8.f32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];\n" + : "=f"(x), "=f"(y), "=f"(z), "=f"(w), "=f"(a), "=f"(b), "=f"(c), "=f"(d) + : "l"(addr)); + } else if constexpr (prefetch_ == 64) { + asm volatile("ld.global.ca.L2::64B.v8.f32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];\n" + : "=f"(x), "=f"(y), "=f"(z), "=f"(w), "=f"(a), "=f"(b), "=f"(c), "=f"(d) + : "l"(addr)); + } else if constexpr (prefetch_ == 128) { + asm volatile("ld.global.ca.L2::128B.v8.f32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];\n" + : "=f"(x), "=f"(y), "=f"(z), "=f"(w), "=f"(a), "=f"(b), "=f"(c), "=f"(d) + : "l"(addr)); + } else if constexpr (prefetch_ == 256) { + asm volatile("ld.global.ca.L2::256B.v8.f32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];\n" + : "=f"(x), "=f"(y), "=f"(z), "=f"(w), "=f"(a), "=f"(b), "=f"(c), "=f"(d) + : "l"(addr)); + } + v = {{x, y, z, w}, {a, b, c, d}}; } - __device__ inline void load_cached_float2(float2 &a, const float2* addr) + // Valid values for prefetch_size: 0 (no prefetch), 64, 128, 256 + // Note: 256B prefetch requires SM 80+. For older architectures, 256B -> 128B + template __device__ inline void load_cached_float2(float2 &a, const float2 *addr) { + VALIDATE_PREFETCH_SIZE(prefetch_size); + constexpr size_t prefetch_ = __COMPUTE_CAPABILITY__ < 800 ? 
0 : prefetch_size; + float x, y; - asm("ld.ca.global.v2.f32 {%0, %1}, [%2+0];" : "=f"(x), "=f"(y) : __PTR(addr)); + + if constexpr (prefetch_ == 0) { + // Plain cached load, no prefetch hint + asm volatile("ld.global.ca.v2.f32 {%0, %1}, [%2];\n" : "=f"(x), "=f"(y) : "l"(addr)); + } else if constexpr (prefetch_ == 64) { + asm volatile("ld.global.ca.L2::64B.v2.f32 {%0, %1}, [%2];\n" : "=f"(x), "=f"(y) : "l"(addr)); + } else if constexpr (prefetch_ == 128) { + asm volatile("ld.global.ca.L2::128B.v2.f32 {%0, %1}, [%2];\n" : "=f"(x), "=f"(y) : "l"(addr)); + } else if constexpr (prefetch_ == 256) { + asm volatile("ld.global.ca.L2::256B.v2.f32 {%0, %1}, [%2];\n" : "=f"(x), "=f"(y) : "l"(addr)); + } + a.x = x; a.y = y; } - __device__ inline void load_cached_double4(double4 &a, const double4 *addr) + // Valid values for prefetch_size: 0 (no prefetch), 64, 128, 256 + // Note: 256B prefetch requires SM 80+. For older architectures, 256B -> 128B + template __device__ inline void load_cached_float(float &a, const float *addr) { + VALIDATE_PREFETCH_SIZE(prefetch_size); + constexpr size_t prefetch_ = __COMPUTE_CAPABILITY__ < 800 ? 0 : prefetch_size; + + float x; + + if constexpr (prefetch_ == 0) { + // Plain cached load, no prefetch hint + asm volatile("ld.global.ca.f32 {%0}, [%1];\n" : "=f"(x) : "l"(addr)); + } else if constexpr (prefetch_ == 64) { + asm volatile("ld.global.ca.L2::64B.f32 {%0}, [%1];\n" : "=f"(x) : "l"(addr)); + } else if constexpr (prefetch_ == 128) { + asm volatile("ld.global.ca.L2::128B.f32 {%0}, [%1];\n" : "=f"(x) : "l"(addr)); + } else if constexpr (prefetch_ == 256) { + asm volatile("ld.global.ca.L2::256B.f32 {%0}, [%1];\n" : "=f"(x) : "l"(addr)); + } + + a = x; + } + + // Valid values for prefetch_size: 0 (no prefetch), 64, 128, 256 + // Note: 256B prefetch requires SM 80+. 
For older architectures, 256B -> 128B + template __device__ inline void load_cached_double4(double4 &a, const double4 *addr) + { + VALIDATE_PREFETCH_SIZE(prefetch_size); + constexpr size_t prefetch_ = __COMPUTE_CAPABILITY__ < 800 ? 0 : prefetch_size; + double x, y, z, w; - asm("ld.ca.global.v4.f64 {%0, %1, %2, %3}, [%4+0];" : "=d"(x), "=d"(y), "=d"(z), "=d"(w) : __PTR(addr)); + + if constexpr (prefetch_ == 0) { + // Plain cached load, no prefetch hint + asm volatile("ld.global.ca.v4.f64 {%0, %1, %2, %3}, [%4];\n" : "=d"(x), "=d"(y), "=d"(z), "=d"(w) : "l"(addr)); + } else if constexpr (prefetch_ == 64) { + asm volatile("ld.global.ca.L2::64B.v4.f64 {%0, %1, %2, %3}, [%4];\n" + : "=d"(x), "=d"(y), "=d"(z), "=d"(w) + : "l"(addr)); + } else if constexpr (prefetch_ == 128) { + asm volatile("ld.global.ca.L2::128B.v4.f64 {%0, %1, %2, %3}, [%4];\n" + : "=d"(x), "=d"(y), "=d"(z), "=d"(w) + : "l"(addr)); + } else if constexpr (prefetch_ == 256) { + asm volatile("ld.global.ca.L2::256B.v4.f64 {%0, %1, %2, %3}, [%4];\n" + : "=d"(x), "=d"(y), "=d"(z), "=d"(w) + : "l"(addr)); + } + a.x = x; a.y = y; a.z = z; a.w = w; } - __device__ inline void load_cached_double2(double2 &a, const double2* addr) + // Valid values for prefetch_size: 0 (no prefetch), 64, 128, 256 + // Note: 256B prefetch requires SM 80+. For older architectures, 256B -> 128B + template __device__ inline void load_cached_double2(double2 &a, const double2 *addr) { + VALIDATE_PREFETCH_SIZE(prefetch_size); + constexpr size_t prefetch_ = __COMPUTE_CAPABILITY__ < 800 ? 
0 : prefetch_size; + double x, y; - asm("ld.ca.global.v2.f64 {%0, %1}, [%2+0];" : "=d"(x), "=d"(y) : __PTR(addr)); + + if constexpr (prefetch_ == 0) { + // Plain cached load, no prefetch hint + asm volatile("ld.global.ca.v2.f64 {%0, %1}, [%2];\n" : "=d"(x), "=d"(y) : "l"(addr)); + } else if constexpr (prefetch_ == 64) { + asm volatile("ld.global.ca.L2::64B.v2.f64 {%0, %1}, [%2];\n" : "=d"(x), "=d"(y) : "l"(addr)); + } else if constexpr (prefetch_ == 128) { + asm volatile("ld.global.ca.L2::128B.v2.f64 {%0, %1}, [%2];\n" : "=d"(x), "=d"(y) : "l"(addr)); + } else if constexpr (prefetch_ == 256) { + asm volatile("ld.global.ca.L2::256B.v2.f64 {%0, %1}, [%2];\n" : "=d"(x), "=d"(y) : "l"(addr)); + } + a.x = x; a.y = y; } diff --git a/include/targets/cuda/load_store.h b/include/targets/cuda/load_store.h index 9f2a51d0b8..29b2e50be3 100644 --- a/include/targets/cuda/load_store.h +++ b/include/targets/cuda/load_store.h @@ -15,53 +15,82 @@ namespace quda // pre-declaration of vector_load that we wish to specialize template struct vector_load_impl; + // pre-declaration of the prefetch type + template struct prefetch_t; + // CUDA specializations of the vector_load template <> struct vector_load_impl { - template __device__ inline void operator()(T &value, const void *ptr, int idx) - { + template + __device__ inline void operator()(T &value, const void *ptr, int idx, const prefetch_t &) { value = reinterpret_cast(ptr)[idx]; } - __device__ inline void operator()(float4 &value, const void *ptr, int idx) + template + __device__ inline void operator()(float4 &value, const void *ptr, int idx, const prefetch_t &) + { + load_cached_float4(value, reinterpret_cast(ptr) + idx); + } + + template + __device__ inline void operator()(float2 &value, const void *ptr, int idx, const prefetch_t &) { - load_cached_float4(value, reinterpret_cast(ptr) + idx); + load_cached_float2(value, reinterpret_cast(ptr) + idx); } - __device__ inline void operator()(float2 &value, const void *ptr, int idx) + 
template + __device__ inline void operator()(float &value, const void *ptr, int idx, const prefetch_t &) { - load_cached_float2(value, reinterpret_cast(ptr) + idx); + load_cached_float(value, reinterpret_cast(ptr) + idx); } #if __COMPUTE_CAPABILITY__ >= 1000 - __device__ inline void operator()(double4 &value, const void *ptr, int idx) + template + __device__ inline void operator()(double4 &value, const void *ptr, int idx, const prefetch_t &) { - load_cached_double4(value, reinterpret_cast(ptr) + idx); + load_cached_double4(value, reinterpret_cast(ptr) + idx); } - __device__ inline void operator()(float8 &value, const void *ptr, int idx) + template + __device__ inline void operator()(float8 &value, const void *ptr, int idx, const prefetch_t &) { - load_cached_float8(value, reinterpret_cast(ptr) + idx); + load_cached_float8(value, reinterpret_cast(ptr) + idx); } #endif - __device__ inline void operator()(double2 &value, const void *ptr, int idx) + template + __device__ inline void operator()(double2 &value, const void *ptr, int idx, const prefetch_t &) + { + load_cached_double2(value, reinterpret_cast(ptr) + idx); + } + + template + __device__ inline void operator()(short2 &value, const void *ptr, int idx, const prefetch_t &prefetch) + { + load_cached_short2(value, reinterpret_cast(ptr) + idx); + } + + template + __device__ inline void operator()(short4 &value, const void *ptr, int idx, const prefetch_t &prefetch) { - load_cached_double2(value, reinterpret_cast(ptr) + idx); + load_cached_short4(value, reinterpret_cast(ptr) + idx); } - __device__ inline void operator()(short8 &value, const void *ptr, int idx) + template + __device__ inline void operator()(short8 &value, const void *ptr, int idx, const prefetch_t &prefetch) { float4 tmp; - operator()(tmp, ptr, idx); + operator()(tmp, ptr, idx, prefetch); memcpy(&value, &tmp, sizeof(float4)); } - __device__ inline void operator()(char8 &value, const void *ptr, int idx) + template + __device__ inline void 
operator()(char8 &value, const void *ptr, int idx, const prefetch_t &prefetch) { float2 tmp; - operator()(tmp, ptr, idx); + operator()(tmp, ptr, idx, prefetch); memcpy(&value, &tmp, sizeof(float2)); } + }; // pre-declaration of vector_store that we wish to specialize diff --git a/include/targets/generic/load_store.h b/include/targets/generic/load_store.h index 3239aeaefc..93b847a4db 100644 --- a/include/targets/generic/load_store.h +++ b/include/targets/generic/load_store.h @@ -5,28 +5,34 @@ namespace quda { + template struct prefetch_t { + static constexpr int size = prefetch; + }; + /** @brief Non-specialized load operation */ template struct vector_load_impl { - template __device__ __host__ inline void operator()(T &value, const void *ptr, int idx) + template + __device__ __host__ inline void operator()(T &value, const void *ptr, int idx, const prefetch_t &) { value = reinterpret_cast(ptr)[idx]; } }; - template __device__ __host__ inline vector_t vector_load(const void *ptr, int idx) + template + __device__ __host__ inline vector_t vector_load_internal(const void *ptr, int idx) { vector_t value; - target::dispatch(value, ptr, idx); + target::dispatch(value, ptr, idx, prefetch_t()); return value; } - template + template __device__ __host__ inline array vector_load(const void *ptr, int idx) { using vector_t = typename VectorType::type; - auto value_v = vector_load(ptr, idx); + auto value_v = vector_load_internal(ptr, idx); array value_a; static_assert(sizeof(value_a) == sizeof(value_v), "array type and vector type are different sizes"); memcpy(&value_a, &value_v, sizeof(vector_t)); From 191105b34d6e1d4774d8bec8eb34139fca673850 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 30 Sep 2025 21:24:17 -0700 Subject: [PATCH 002/121] Fix for half precision --- include/color_spinor_field_order.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/color_spinor_field_order.h b/include/color_spinor_field_order.h index 46ad849079..0420d349c3 
100644 --- a/include/color_spinor_field_order.h +++ b/include/color_spinor_field_order.h @@ -1023,7 +1023,7 @@ namespace quda { real v[length_ghost]; norm_type nrm - = isFixed::value ? vector_load(ghost_norm[2 * dim + dir], parity * faceVolumeCB[dim] + x) : 0.0; + = isFixed::value ? vector_load(ghost_norm[2 * dim + dir], parity * faceVolumeCB[dim] + x)[0] : 0.0; #pragma unroll for (int i = 0; i < M; i++) { @@ -1161,7 +1161,7 @@ namespace quda auto norm_offset = offset / (sizeof(Float) < sizeof(float) ? sizeof(norm_type) / sizeof(Float) : 1); auto norm = reinterpret_cast(field + volumeCB * (2 * Nc * Ns)); #endif - norm_type nrm = isFixed::value ? vector_load(norm, x + parity * norm_offset) : 0.0; + norm_type nrm = isFixed::value ? vector_load(norm, x + parity * norm_offset)[0] : 0.0; #pragma unroll for (int i = 0; i < M; i++) { From 5b41229f7272b2bfb5b512d488d458395c192cfc Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 30 Sep 2025 21:24:40 -0700 Subject: [PATCH 003/121] Apply some missing OMP parallelization to host functions --- tests/utils/gauge_utils.cpp | 2 ++ tests/utils/host_utils.cpp | 1 + tests/utils/staggered_gauge_utils.cpp | 2 ++ tests/utils/staggered_host_utils.cpp | 2 ++ 4 files changed, 7 insertions(+) diff --git a/tests/utils/gauge_utils.cpp b/tests/utils/gauge_utils.cpp index 962dfb16ad..3b2dd007fd 100644 --- a/tests/utils/gauge_utils.cpp +++ b/tests/utils/gauge_utils.cpp @@ -433,6 +433,7 @@ template struct ApplyRandomU1Phase { auto gauge = reinterpret_cast(gauge_); for (int dir = 0; dir < 4; dir++) { +#pragma omp parallel for for (int i = 0; i < Vh; i++) { for (int parity = 0; parity < 2; parity++) { // create a random phase @@ -493,6 +494,7 @@ template struct ConstructRandomMatrixGaugeField { }; for (int dir = 0; dir < 4; dir++) { +#pragma omp parallel for for (int i = 0; i < Vh; i++) { for (int parity = 0; parity < 2; parity++) { real_t *link = gauge[dir] + (parity * Vh + i) * gauge_site_size; diff --git a/tests/utils/host_utils.cpp 
b/tests/utils/host_utils.cpp index 2a31748a7a..7b8807fcd1 100644 --- a/tests/utils/host_utils.cpp +++ b/tests/utils/host_utils.cpp @@ -181,6 +181,7 @@ void constructHostCloverField(void *clover, void *, QudaInvertParam &inv_param) template struct ConstructCloverField { void operator()(void *res, double norm, double diag) { +#pragma omp parallel for for (auto i = 0lu; i < static_cast(Vh); i++) { for (auto parity = 0lu; parity < 2lu; parity++) { auto clover_matrix = reinterpret_cast(res) + 72 * (parity * Vh + i); diff --git a/tests/utils/staggered_gauge_utils.cpp b/tests/utils/staggered_gauge_utils.cpp index c95dd87036..ab2a4c1346 100644 --- a/tests/utils/staggered_gauge_utils.cpp +++ b/tests/utils/staggered_gauge_utils.cpp @@ -61,6 +61,7 @@ void constructFatLongGaugeField(void *const *fatlink, void *const *longlink, Gau constructRandomGaugeField(longlink, param, precision, dslash_type); // incorporate non-trivial phase into long links for (int dir = 0; dir < 4; ++dir) { +#pragma omp parallel for for (int i = 0; i < Vh; ++i) { for (int parity = 0; parity < 2; parity++) { double phase = random_uniform_host(i, parity, 0, 2 * M_PI); @@ -93,6 +94,7 @@ void constructFatLongGaugeField(void *const *fatlink, void *const *longlink, Gau // incorporate non-trivial phase into long links for (int dir = 0; dir < 4; ++dir) { +#pragma omp parallel for for (int i = 0; i < Vh; ++i) { for (int parity = 0; parity < 2; parity++) { double phase = random_uniform_host(i, parity, 0, 2 * M_PI); diff --git a/tests/utils/staggered_host_utils.cpp b/tests/utils/staggered_host_utils.cpp index 1c5340661a..47884ca593 100644 --- a/tests/utils/staggered_host_utils.cpp +++ b/tests/utils/staggered_host_utils.cpp @@ -186,6 +186,7 @@ void computeTwoLinkCPU(void **twolink, su3_matrix **sitelinkEx) for (int dir = 0; dir < 4; ++dir) E[dir] = Z[dir] + 4; const int extended_volume = E[3] * E[2] * E[1] * E[0]; +#pragma omp parallel for for (int t = 0; t < Z[3]; ++t) { for (int z = 0; z < Z[2]; ++z) { for (int y 
= 0; y < Z[1]; ++y) { @@ -698,6 +699,7 @@ void constructStaggeredTestSpinorParam(quda::ColorSpinorParam *cs_param, const Q // data reordering routines template void reorderQDPtoMILC(Out *milc_out, In **qdp_in, int V, int siteSize) { +#pragma omp parallel for for (int i = 0; i < V; i++) { for (int dir = 0; dir < 4; dir++) { for (int j = 0; j < siteSize; j++) { From a2efb44f5d9008a039f07b2a7dc7e56adc7c4d89 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 30 Sep 2025 21:56:01 -0700 Subject: [PATCH 004/121] Fix for fine-grained accessor vector loads --- include/color_spinor_field_order.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/color_spinor_field_order.h b/include/color_spinor_field_order.h index 0420d349c3..008c25db8b 100644 --- a/include/color_spinor_field_order.h +++ b/include/color_spinor_field_order.h @@ -241,9 +241,9 @@ namespace quda constexpr int M = nSpinBlock * nColor * nVec; #pragma unroll for (int i = 0; i < M; i++) { - vec_t tmp - = vector_load(reinterpret_cast(in + parity * offset_cb), x_cb * N + chi * M + i); - memcpy(&out[i], &tmp, sizeof(vec_t)); + auto tmp + = vector_load(reinterpret_cast(in + parity * offset_cb), x_cb * N + chi * M + i); + memcpy(&out[i], &tmp, sizeof(tmp)); } } }; From c815076f33b3dc137d7984318758bca6ea279681 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 30 Sep 2025 21:56:49 -0700 Subject: [PATCH 005/121] Add prefetching instructions for CUDA --- include/targets/cuda/inline_ptx.h | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/include/targets/cuda/inline_ptx.h b/include/targets/cuda/inline_ptx.h index 3bbccf2ab5..adf92c4720 100644 --- a/include/targets/cuda/inline_ptx.h +++ b/include/targets/cuda/inline_ptx.h @@ -476,4 +476,13 @@ namespace quda { asm("st.cs.global.v2.s16 [%0+0], {%1, %2};" :: __PTR(addr), "h"(x), "h"(y)); } + __device__ __forceinline__ void prefetch_L1(const void *p) { asm volatile("prefetch.global.L1 [%0];" ::"l"(p)); } + + __device__ 
__forceinline__ void prefetch_L2(const void *p) { asm volatile("prefetch.global.L2 [%0];" ::"l"(p)); } + + __device__ __forceinline__ void prefetch_tma(const void *p, size_t bytes) + { + asm volatile("cp.async.bulk.prefetch.L2.global [%0], %1;\n" ::"l"(p), "r"(static_cast(bytes))); + } + } // namespace quda From 177c18ba203354b25e99c4e60714726aea0d923e Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Wed, 1 Oct 2025 09:55:19 -0700 Subject: [PATCH 006/121] Optimization of neighbor indexing for dslash kernels: use bitwise instead of logic operations when computing the neighboring index; this is branch free and fewer operations --- include/dslash_helper.cuh | 59 ++++++++++++++++++- include/dslash_quda.h | 1 + include/index_helper.cuh | 35 +---------- include/kernels/dslash_staggered.cuh | 21 +++++-- .../dslash_twisted_mass_preconditioned.cuh | 4 +- include/kernels/dslash_wilson.cuh | 4 +- include/kernels/laplace.cuh | 4 +- lib/color_spinor_field.cpp | 1 + 8 files changed, 81 insertions(+), 48 deletions(-) diff --git a/include/dslash_helper.cuh b/include/dslash_helper.cuh index 834b59425c..1714dd64d5 100644 --- a/include/dslash_helper.cuh +++ b/include/dslash_helper.cuh @@ -158,13 +158,68 @@ namespace quda #pragma unroll for (int d = 0; d < nDim; d++) { - coord.in_boundary[1][d] = coord[d] + arg.nFace >= arg.dc.X[d]; - coord.in_boundary[0][d] = coord[d] - arg.nFace < 0; + coord.in_boundary[1][d] = -(coord[d] + arg.nFace >= arg.dc.X[d]); + coord.in_boundary[0][d] = -(coord[d] - arg.nFace < 0); } return coord; } + /** + @brief Compute the checkerboard 1-d index for the nearest + neighbor + @param[in] x lattice coordinates + @param[in] mu dimension in which to add 1 + @param[in] dir direction (+1 or -1) + @param[in] arg parameter struct + @return 1-d checkerboard index + */ + template + __device__ __host__ inline int getNeighborIndexCB(const Coord &x, int mu, int dir, const Arg &arg) + { + switch (nFace) { + case 1: + switch (dir) { + case +1: // positive direction + switch 
(mu) { + case 0: return (x.X + 1 - (x.in_boundary[1][0] & arg.X[0])) >> 1; + case 1: return (x.X + arg.X[0] - (x.in_boundary[1][1] & arg.X2X1)) >> 1; + case 2: return (x.X + arg.X2X1 - (x.in_boundary[1][2] & arg.X3X2X1)) >> 1; + case 3: return (x.X + arg.X3X2X1 - (x.in_boundary[1][3] & arg.X4X3X2X1)) >> 1; + case 4: return (x.X + arg.X4X3X2X1 - (x.in_boundary[1][4] & arg.X5X4X3X2X1)) >> 1; + } + case -1: + switch (mu) { + case 0: return (x.X - 1 + (x.in_boundary[0][0] & arg.X[0])) >> 1; + case 1: return (x.X - arg.X[0] + (x.in_boundary[0][1] & arg.X2X1)) >> 1; + case 2: return (x.X - arg.X2X1 + (x.in_boundary[0][2] & arg.X3X2X1)) >> 1; + case 3: return (x.X - arg.X3X2X1 + (x.in_boundary[0][3] & arg.X4X3X2X1)) >> 1; + case 4: return (x.X - arg.X4X3X2X1 + (x.in_boundary[0][4] & arg.X5X4X3X2X1)) >> 1; + } + } + case 3: + switch (dir) { + case +1: // positive direction + switch (mu) { + case 0: return (x.X + 3 - (x.in_boundary[1][0] & arg.X[0])) >> 1; + case 1: return (x.X + 3 * arg.X[0] - (x.in_boundary[1][1] & arg.X2X1)) >> 1; + case 2: return (x.X + 3 * arg.X2X1 - (x.in_boundary[1][2] & arg.X3X2X1)) >> 1; + case 3: return (x.X + 3 * arg.X3X2X1 - (x.in_boundary[1][3] & arg.X4X3X2X1)) >> 1; + case 4: return (x.X + 3 * arg.X4X3X2X1 - (x.in_boundary[1][4] & arg.X5X4X3X2X1)) >> 1; + } + case -1: + switch (mu) { + case 0: return (x.X - 3 + (x.in_boundary[0][0] & arg.X[0])) >> 1; + case 1: return (x.X - 3 * arg.X[0] + (x.in_boundary[0][1] & arg.X2X1)) >> 1; + case 2: return (x.X - 3 * arg.X2X1 + (x.in_boundary[0][2] & arg.X3X2X1)) >> 1; + case 3: return (x.X - 3 * arg.X3X2X1 + (x.in_boundary[0][3] & arg.X4X3X2X1)) >> 1; + case 4: return (x.X - 3 * arg.X4X3X2X1 + (x.in_boundary[0][4] & arg.X5X4X3X2X1)) >> 1; + } + } + } + return 0; // should never reach here + } + /** @brief Compute whether this thread should be active for updating the a given offsetDim halo. 
For non-fused halo update kernels diff --git a/include/dslash_quda.h b/include/dslash_quda.h index f34a41de1a..5091e28674 100644 --- a/include/dslash_quda.h +++ b/include/dslash_quda.h @@ -35,6 +35,7 @@ namespace quda int X2X1; int X3X2X1; int X4X3X2X1; + int X5X4X3X2X1; int X2X1mX1; int X3X2X1mX2X1; diff --git a/include/index_helper.cuh b/include/index_helper.cuh index db58c0daed..5ea718aa8c 100644 --- a/include/index_helper.cuh +++ b/include/index_helper.cuh @@ -238,43 +238,10 @@ namespace quda { int X; // full lattice site index constexpr const int& operator[](int i) const { return x[i]; } constexpr int& operator[](int i) { return x[i]; } - array_2d in_boundary = {}; + array_2d in_boundary = {}; constexpr int size() const { return nDim; } }; - /** - @brief Compute the checkerboard 1-d index for the nearest - neighbor - @param[in] lattice coordinates - @param[in] mu dimension in which to add 1 - @param[in] dir direction (+1 or -1) - @param[in] arg parameter struct - @return 1-d checkboard index - */ - template - __device__ __host__ inline int getNeighborIndexCB(const Coord &x, int mu, int dir, const Arg &arg) - { - switch (dir) { - case +1: // positive direction - switch (mu) { - case 0: return (x.in_boundary[1][0] ? x.X - (arg.X[0] - 1) : x.X + 1) >> 1; - case 1: return (x.in_boundary[1][1] ? x.X - arg.X2X1mX1 : x.X + arg.X[0]) >> 1; - case 2: return (x.in_boundary[1][2] ? x.X - arg.X3X2X1mX2X1 : x.X + arg.X2X1) >> 1; - case 3: return (x.in_boundary[1][3] ? x.X - arg.X4X3X2X1mX3X2X1 : x.X + arg.X3X2X1) >> 1; - case 4: return (x.in_boundary[1][4] ? x.X - arg.X5X4X3X2X1mX4X3X2X1 : x.X + arg.X4X3X2X1) >> 1; - } - case -1: - switch (mu) { - case 0: return (x.in_boundary[0][0] ? x.X + (arg.X[0] - 1) : x.X - 1) >> 1; - case 1: return (x.in_boundary[0][1] ? x.X + arg.X2X1mX1 : x.X - arg.X[0]) >> 1; - case 2: return (x.in_boundary[0][2] ? x.X + arg.X3X2X1mX2X1 : x.X - arg.X2X1) >> 1; - case 3: return (x.in_boundary[0][3] ? 
x.X + arg.X4X3X2X1mX3X2X1 : x.X - arg.X3X2X1) >> 1; - case 4: return (x.in_boundary[0][4] ? x.X + arg.X5X4X3X2X1mX4X3X2X1 : x.X - arg.X4X3X2X1) >> 1; - } - } - return 0; // should never reach here - } - /** Compute the 4-d spatial index from the checkerboarded 1-d index at parity parity diff --git a/include/kernels/dslash_staggered.cuh b/include/kernels/dslash_staggered.cuh index ae46c6a900..27bd23d62f 100644 --- a/include/kernels/dslash_staggered.cuh +++ b/include/kernels/dslash_staggered.cuh @@ -90,6 +90,15 @@ namespace quda typedef Matrix, Arg::nColor> Link; const int their_spinor_parity = (arg.nParity == 2) ? 1 - parity : 0; + Coord coord1 = coord; + if constexpr (arg.improved) { // need to compute 1-hop in_boundary +#pragma unroll + for (int d = 0; d < 4; d++) { + coord1.in_boundary[1][d] = -(coord[d] + 1 >= arg.dc.X[d]); + coord1.in_boundary[0][d] = -(coord[d] - 1 < 0); + } + } + #pragma unroll for (int d = 0; d < 4; d++) { // loop over dimension @@ -105,7 +114,7 @@ namespace quda out[s] = mv_add(U, in, out[s]); } } else if (doBulk() && !ghost) { - const int fwd_idx = linkIndexP1(coord, arg.dc.X, d); + const int fwd_idx = getNeighborIndexCB<1>(coord1, d, 1, arg.dc); const Link U = arg.improved ? 
arg.U(d, coord.x_cb, parity) : arg.U(d, coord.x_cb, parity, StaggeredPhase(coord, d, +1, arg)); #pragma unroll for (auto s = 0; s < n_src_tile; s++) { @@ -117,7 +126,7 @@ namespace quda // improved - forward direction if (arg.improved && arg.dd_in.doHopping(coord, d, +3)) { - const bool ghost = coord.in_boundary[1][d] && isActive(active, thread_dim, d, coord, arg); + const bool ghost = coord.in_boundary[1][d] & isActive(active, thread_dim, d, coord, arg); if (doHalo(d) && ghost) { const int ghost_idx = ghostFaceIndexStaggered<1>(coord, arg.dc.X, d, arg.nFace); const Link L = arg.L(d, coord.x_cb, parity); @@ -128,7 +137,7 @@ namespace quda out[s] = mv_add(L, in, out[s]); } } else if (doBulk() && !ghost) { - const int fwd3_idx = linkIndexP3(coord, arg.dc.X, d); + const int fwd3_idx = getNeighborIndexCB<3>(coord, d, 1, arg.dc); const Link L = arg.L(d, coord.x_cb, parity); #pragma unroll for (auto s = 0; s < n_src_tile; s++) { @@ -153,7 +162,7 @@ namespace quda out[s] = mv_sub(conj(U), in, out[s]); } } else if (doBulk() && !ghost) { - const int back_idx = linkIndexM1(coord, arg.dc.X, d); + const int back_idx = getNeighborIndexCB<1>(coord1, d, -1, arg.dc); const int gauge_idx = back_idx; const Link U = arg.improved ? 
arg.U(d, gauge_idx, 1 - parity) : arg.U(d, gauge_idx, 1 - parity, StaggeredPhase(coord, d, -1, arg)); @@ -167,7 +176,7 @@ namespace quda // improved - backward direction if (arg.improved && arg.dd_in.doHopping(coord, d, -3)) { - const bool ghost = coord.in_boundary[0][d] && isActive(active, thread_dim, d, coord, arg); + const bool ghost = coord.in_boundary[0][d] & isActive(active, thread_dim, d, coord, arg); if (doHalo(d) && ghost) { const int ghost_idx = ghostFaceIndexStaggered<0>(coord, arg.dc.X, d, 1); const Link L = arg.L.Ghost(d, ghost_idx, 1 - parity); @@ -178,7 +187,7 @@ namespace quda out[s] = mv_sub(conj(L), in, out[s]); } } else if (doBulk() && !ghost) { - const int back3_idx = linkIndexM3(coord, arg.dc.X, d); + const int back3_idx = getNeighborIndexCB<3>(coord, d, -1, arg.dc); const int gauge_idx = back3_idx; const Link L = arg.L(d, gauge_idx, 1 - parity); #pragma unroll diff --git a/include/kernels/dslash_twisted_mass_preconditioned.cuh b/include/kernels/dslash_twisted_mass_preconditioned.cuh index 513a034acd..547385c75c 100644 --- a/include/kernels/dslash_twisted_mass_preconditioned.cuh +++ b/include/kernels/dslash_twisted_mass_preconditioned.cuh @@ -63,7 +63,7 @@ namespace quda if (arg.dd_in.doHopping(coord, d, +1)) { const int fwd_idx = getNeighborIndexCB(coord, d, +1, arg.dc); constexpr int proj_dir = dagger ? +1 : -1; - const bool ghost = coord.in_boundary[1][d] && isActive(active, thread_dim, d, coord, arg); + const bool ghost = coord.in_boundary[1][d] & isActive(active, thread_dim, d, coord, arg); if (doHalo(d) && ghost) { // we need to compute the face index if we are updating a face that isn't ours @@ -101,7 +101,7 @@ namespace quda const int back_idx = getNeighborIndexCB(coord, d, -1, arg.dc); const int gauge_idx = back_idx; constexpr int proj_dir = dagger ? 
-1 : +1; - const bool ghost = coord.in_boundary[0][d] && isActive(active, thread_dim, d, coord, arg); + const bool ghost = coord.in_boundary[0][d] & isActive(active, thread_dim, d, coord, arg); if (doHalo(d) && ghost) { // we need to compute the face index if we are updating a face that isn't ours diff --git a/include/kernels/dslash_wilson.cuh b/include/kernels/dslash_wilson.cuh index 8b66ee83e6..0b1fb49fb8 100644 --- a/include/kernels/dslash_wilson.cuh +++ b/include/kernels/dslash_wilson.cuh @@ -102,7 +102,7 @@ namespace quda const int gauge_idx = (Arg::nDim == 5 ? coord.x_cb % arg.dc.volume_4d_cb : coord.x_cb); constexpr int proj_dir = dagger ? +1 : -1; - const bool ghost = coord.in_boundary[1][d] && isActive(active, thread_dim, d, coord, arg); + const bool ghost = coord.in_boundary[1][d] & isActive(active, thread_dim, d, coord, arg); if (doHalo(d) && ghost) { // we need to compute the face index if we are updating a face that isn't ours @@ -131,7 +131,7 @@ namespace quda const int gauge_idx = (Arg::nDim == 5 ? back_idx % arg.dc.volume_4d_cb : back_idx); constexpr int proj_dir = dagger ? 
-1 : +1; - const bool ghost = coord.in_boundary[0][d] && isActive(active, thread_dim, d, coord, arg); + const bool ghost = coord.in_boundary[0][d] & isActive(active, thread_dim, d, coord, arg); if (doHalo(d) && ghost) { // we need to compute the face index if we are updating a face that isn't ours diff --git a/include/kernels/laplace.cuh b/include/kernels/laplace.cuh index 9d66c2dee4..b45ac9774f 100644 --- a/include/kernels/laplace.cuh +++ b/include/kernels/laplace.cuh @@ -86,7 +86,7 @@ namespace quda if (d != dir) { if (arg.dd_in.doHopping(coord, d, +1)) { // Forward gather - compute fwd offset for vector fetch - const bool ghost = coord.in_boundary[1][d] && isActive(active, thread_dim, d, coord, arg); + const bool ghost = coord.in_boundary[1][d] & isActive(active, thread_dim, d, coord, arg); if (doHalo(d) && ghost) { @@ -111,7 +111,7 @@ namespace quda const int back_idx = linkIndexM1(coord, arg.dc.X, d); const int gauge_idx = back_idx; - const bool ghost = coord.in_boundary[0][d] && isActive(active, thread_dim, d, coord, arg); + const bool ghost = coord.in_boundary[0][d] & isActive(active, thread_dim, d, coord, arg); if (doHalo(d) && ghost) { diff --git a/lib/color_spinor_field.cpp b/lib/color_spinor_field.cpp index 442ef13ab1..5288a52a97 100644 --- a/lib/color_spinor_field.cpp +++ b/lib/color_spinor_field.cpp @@ -384,6 +384,7 @@ namespace quda dc.X2X1 = X[1] * X[0]; dc.X3X2X1 = X[2] * X[1] * X[0]; dc.X4X3X2X1 = X[3] * X[2] * X[1] * X[0]; + dc.X5X4X3X2X1 = X[4] * X[3] * X[2] * X[1] * X[0]; dc.X2X1mX1 = (X[1] - 1) * X[0]; dc.X3X2X1mX2X1 = (X[2] - 1) * X[1] * X[0]; dc.X4X3X2X1mX3X2X1 = (X[3] - 1) * X[2] * X[1] * X[0]; From eae953d5c5725df215002d6975ff0f73e9e89be6 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Thu, 2 Oct 2025 23:38:47 -0700 Subject: [PATCH 007/121] Add support for creating a backward gauge field --- include/gauge_field.h | 11 +++++++ include/kernels/gauge_shift.cuh | 56 +++++++++++++++++++++++++++++++++ lib/CMakeLists.txt | 1 + 
lib/gauge_shift.cu | 41 ++++++++++++++++++++++++ 4 files changed, 109 insertions(+) create mode 100644 include/kernels/gauge_shift.cuh create mode 100644 lib/gauge_shift.cu diff --git a/include/gauge_field.h b/include/gauge_field.h index c355bd4818..c830938b58 100644 --- a/include/gauge_field.h +++ b/include/gauge_field.h @@ -669,6 +669,17 @@ namespace quda { */ void genericPrintMatrix(const GaugeField &a, int dim, int parity, unsigned int x_cb, int rank = 0); + /** + @brief Shift the gauge field by shift in each dimension and store + the resulting shifted field. This is used to move the backwards + links on to this site. The input field must be a padded field + with the ghost pre-exchanged if communications are enabled. + @param[out] out Output shifted field + @param[in] in Input shifted field + @param[in] shift value (1 or 3 supported) + */ + void shift(GaugeField &out, const GaugeField &in, int shift); + /** @brief This is a debugging function, where we cast a gauge field into a spinor field so we can compute its L1 norm. 
diff --git a/include/kernels/gauge_shift.cuh b/include/kernels/gauge_shift.cuh new file mode 100644 index 0000000000..abe369f439 --- /dev/null +++ b/include/kernels/gauge_shift.cuh @@ -0,0 +1,56 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace quda +{ + + template struct GaugeShiftArg : kernel_param<> { + using real = typename mapper::type; + using Link = Matrix, nColor>; + using Gauge = typename gauge_mapper::type; + + int X[4]; // true grid dimensions + Gauge out; + const Gauge in; + int shift; + + GaugeShiftArg(GaugeField &out, const GaugeField &in, int shift) : + kernel_param(dim3(in.VolumeCB(), 2, 4)), out(out), in(in), shift(shift) + { + for (int dir = 0; dir < 4; dir++) X[dir] = in.X()[dir]; + } + }; + + template struct GaugeShift { + const Arg &arg; + constexpr GaugeShift(const Arg &arg) : arg(arg) { } + static constexpr const char *filename() { return KERNEL_FILE; } + + __device__ __host__ void operator()(int x_cb, int parity, int dir) + { + using real = typename Arg::real; + using Link = typename Arg::Link; + + byte_array x = {}; + getCoords(x, x_cb, arg.X, parity); + + if (x[dir] < arg.shift && arg.comms_dim[dir]) { // on the boundary so we need to fetch from the ghost zone + const int ghost_idx = ghostFaceIndex<0, 4>(x, arg.X, dir, arg.shift); + Link U = arg.in.Ghost(dir, ghost_idx, 1 - parity); + arg.out(dir, x_cb, parity) = U; + } else { // simple shift + byte_array dx = {}; + dx[dir] = dx[dir] - arg.shift; + int x_cb_back = linkIndexShift(x, dx, arg.X); + Link U = arg.in(dir, x_cb_back, 1 - parity); + arg.out(dir, x_cb, parity) = U; + } + } + }; + +} // namespace quda diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 65f33a6772..9a614e09f5 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -28,6 +28,7 @@ set (QUDA_OBJS gauge_phase.cu timer.cpp solver.cpp inv_bicgstab_quda.cpp inv_cg_quda.cpp inv_bicgstabl_quda.cpp inv_multi_cg_quda.cpp inv_eigcg_quda.cpp gauge_ape.cu + gauge_shift.cu gauge_stout.cu 
gauge_hyp.cu gauge_wilson_flow.cu gauge_plaq.cu gauge_plaqrect.cu gauge_laplace.cpp gauge_observable.cpp inv_cgnr.cpp inv_cgne.cpp diff --git a/lib/gauge_shift.cu b/lib/gauge_shift.cu new file mode 100644 index 0000000000..8339508eb3 --- /dev/null +++ b/lib/gauge_shift.cu @@ -0,0 +1,41 @@ +#include +#include +#include +#include + +namespace quda +{ + + template class GaugeShifter : public TunableKernel3D + { + GaugeField &out; + const GaugeField &in; + int shift; + unsigned int minThreads() const { return in.VolumeCB(); } + + public: + GaugeShifter(GaugeField &out, const GaugeField &in, int shift) : + TunableKernel3D(in, 2, 4), out(out), in(in), shift(shift) + { + assert(shift == 1 || shift == 3); + apply(device::get_default_stream()); + } + + void apply(const qudaStream_t &stream) + { + TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); + GaugeShiftArg arg(out, in, shift); + launch(tp, stream, arg); + } + + long long bytes() const { return out.Bytes() + in.Bytes(); } + }; + + void shift(GaugeField &out, const GaugeField &in, int shift) + { + getProfile().TPSTART(QUDA_PROFILE_COMPUTE); + instantiate(out, in, shift); + getProfile().TPSTOP(QUDA_PROFILE_COMPUTE); + } + +} // namespace quda From 2540a1bdc785d3aa4ef00b3962d545d1b74dde1e Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 6 Oct 2025 21:44:49 -0700 Subject: [PATCH 008/121] Some small improvements to shift(GaugeField) function --- include/gauge_field.h | 2 +- lib/gauge_shift.cu | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/include/gauge_field.h b/include/gauge_field.h index c830938b58..a4f3d0b590 100644 --- a/include/gauge_field.h +++ b/include/gauge_field.h @@ -678,7 +678,7 @@ namespace quda { @param[in] in Input shifted field @param[in] shift value (1 or 3 supported) */ - void shift(GaugeField &out, const GaugeField &in, int shift); + GaugeField shift(const GaugeField &in, int shift); /** @brief This is a debugging function, where we cast a gauge field diff 
--git a/lib/gauge_shift.cu b/lib/gauge_shift.cu index 8339508eb3..652ef8aa50 100644 --- a/lib/gauge_shift.cu +++ b/lib/gauge_shift.cu @@ -31,11 +31,19 @@ namespace quda long long bytes() const { return out.Bytes() + in.Bytes(); } }; - void shift(GaugeField &out, const GaugeField &in, int shift) + GaugeField shift(const GaugeField &in, int shift) { getProfile().TPSTART(QUDA_PROFILE_COMPUTE); + if (in.GhostExchange() == QUDA_GHOST_EXCHANGE_EXTENDED) + errorQuda("Extended ghost exchange not supported"); + if (in.GhostExchange() == QUDA_GHOST_EXCHANGE_NO && comm_partitioned()) + errorQuda("comm_dim_partition() == true requires we have GhostExchange = QUDA_GHOST_EXCHANGE_PAD"); + GaugeFieldParam param(in); + param.create = QUDA_NULL_FIELD_CREATE; + GaugeField out(param); instantiate(out, in, shift); getProfile().TPSTOP(QUDA_PROFILE_COMPUTE); + return out; } } // namespace quda From e686437443b89e75c49508cfa152c7d07e693cfb Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 6 Oct 2025 22:06:56 -0700 Subject: [PATCH 009/121] Gauge shift should encode shift value in aux_string --- lib/gauge_shift.cu | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/gauge_shift.cu b/lib/gauge_shift.cu index 652ef8aa50..6a81de0246 100644 --- a/lib/gauge_shift.cu +++ b/lib/gauge_shift.cu @@ -18,6 +18,10 @@ namespace quda TunableKernel3D(in, 2, 4), out(out), in(in), shift(shift) { assert(shift == 1 || shift == 3); + strcat(aux, ",shift="); + char shift_str[16]; + u32toa(shift_str, shift); + strcat(aux, shift_str); apply(device::get_default_stream()); } From 676c643e73a82c293d298e6367deeed97400d879 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 6 Oct 2025 22:35:56 -0700 Subject: [PATCH 010/121] Add support for experimental double storage of gauge fields - disabled by default --- include/dslash.h | 6 ++++++ include/kernels/dslash_staggered.cuh | 29 ++++++++++++++++++++++++---- include/kernels/dslash_wilson.cuh | 18 +++++++++++++++-- lib/dslash_improved_staggered.hpp | 12 
++++++++++-- lib/dslash_staggered.hpp | 9 +++++++-- lib/dslash_wilson.hpp | 9 ++++++++- 6 files changed, 72 insertions(+), 11 deletions(-) diff --git a/include/dslash.h b/include/dslash.h index 8feb23d893..3e0906810d 100644 --- a/include/dslash.h +++ b/include/dslash.h @@ -9,6 +9,9 @@ #include #include +// enable experimental double store of gauge fields +//#define QUDA_DSLASH_DOUBLE_STORE + namespace quda { @@ -70,6 +73,9 @@ namespace quda char tile_str[16]; i32toa(tile_str, Arg::n_src_tile); strcat(aux_base, tile_str); +#ifdef QUDA_DSLASH_DOUBLE_STORE + strcat(aux_base, ",double_store"); +#endif } /** diff --git a/include/kernels/dslash_staggered.cuh b/include/kernels/dslash_staggered.cuh index 27bd23d62f..ebb55b9fff 100644 --- a/include/kernels/dslash_staggered.cuh +++ b/include/kernels/dslash_staggered.cuh @@ -43,7 +43,9 @@ namespace quda const Ghost halo; /** accessor for reading the halo */ F x[MAX_MULTI_RHS]; /** input vector when doing xpay */ const GU U; /** the gauge field */ + const GU Uback; /** the gauge field */ const GL L; /** the long gauge field */ + const GL Lback; /** the long gauge field */ const real a; /** xpay scale factor */ const real tboundary; /** temporal boundary condition */ @@ -54,13 +56,14 @@ namespace quda const real dagger_scale; StaggeredArg(cvector_ref &out, cvector_ref &in, - const ColorSpinorField &halo, const GaugeField &U, const GaugeField &L, double a, - cvector_ref &x, int parity, bool dagger, const int *comm_override) : + const ColorSpinorField &halo, const GaugeField &U, const GaugeField &Uback, const GaugeField &L, + const GaugeField &Lback, double a, cvector_ref &x, int parity, bool dagger, + const int *comm_override) : DslashArg < Float, nDim, DDArg, improved ? 3 : 1, n_src_tile > (out, in, halo, U, x, parity, dagger, a == 0.0 ? false : true, spin_project, comm_override), - halo_pack(halo, improved_ ? 3 : 1), halo(halo, improved_ ? 
3 : 1), U(U), L(L), a(a), tboundary(U.TBoundary()), - is_first_time_slice(comm_coord(3) == 0 ? true : false), + halo_pack(halo, improved_ ? 3 : 1), halo(halo, improved_ ? 3 : 1), U(U), Uback(Uback), L(L), Lback(Lback), a(a), + tboundary(U.TBoundary()), is_first_time_slice(comm_coord(3) == 0 ? true : false), is_last_time_slice(comm_coord(3) == comm_dim(3) - 1 ? true : false), dagger_scale(dagger ? static_cast(-1.0) : static_cast(1.0)) { @@ -154,8 +157,13 @@ namespace quda if (doHalo(d) && ghost) { const int ghost_idx2 = ghostFaceIndexStaggered<0>(coord, arg.dc.X, d, 1); const int ghost_idx = arg.improved ? ghostFaceIndexStaggered<0>(coord, arg.dc.X, d, 3) : ghost_idx2; +#ifdef QUDA_DSLASH_DOUBLE_STORE + const Link U = arg.improved ? arg.Uback(d, coord.x_cb, parity) : + arg.Uback(d, coord.x_cb, parity, StaggeredPhase(coord, d, -1, arg)); +#else const Link U = arg.improved ? arg.U.Ghost(d, ghost_idx2, 1 - parity) : arg.U.Ghost(d, ghost_idx2, 1 - parity, StaggeredPhase(coord, d, -1, arg)); +#endif #pragma unroll for (auto s = 0; s < n_src_tile; s++) { Vector in = arg.halo.Ghost(d, 0, ghost_idx + (src_idx + s) * arg.dc.ghostFaceCB[d], their_spinor_parity); @@ -163,9 +171,14 @@ namespace quda } } else if (doBulk() && !ghost) { const int back_idx = getNeighborIndexCB<1>(coord1, d, -1, arg.dc); +#ifdef QUDA_DSLASH_DOUBLE_STORE + const Link U = arg.improved ? arg.Uback(d, coord.x_cb, parity) : + arg.Uback(d, coord.x_cb, parity, StaggeredPhase(coord, d, -1, arg)); +#else const int gauge_idx = back_idx; const Link U = arg.improved ? 
arg.U(d, gauge_idx, 1 - parity) : arg.U(d, gauge_idx, 1 - parity, StaggeredPhase(coord, d, -1, arg)); +#endif #pragma unroll for (auto s = 0; s < n_src_tile; s++) { Vector in = arg.in[src_idx + s](back_idx, their_spinor_parity); @@ -179,7 +192,11 @@ namespace quda const bool ghost = coord.in_boundary[0][d] & isActive(active, thread_dim, d, coord, arg); if (doHalo(d) && ghost) { const int ghost_idx = ghostFaceIndexStaggered<0>(coord, arg.dc.X, d, 1); +#ifdef QUDA_DSLASH_DOUBLE_STORE + const Link L = arg.Lback(d, coord.x_cb, parity); +#else const Link L = arg.L.Ghost(d, ghost_idx, 1 - parity); +#endif #pragma unroll for (auto s = 0; s < n_src_tile; s++) { const Vector in @@ -188,8 +205,12 @@ namespace quda } } else if (doBulk() && !ghost) { const int back3_idx = getNeighborIndexCB<3>(coord, d, -1, arg.dc); +#ifdef QUDA_DSLASH_DOUBLE_STORE + const Link L = arg.Lback(d, coord.x_cb, parity); +#else const int gauge_idx = back3_idx; const Link L = arg.L(d, gauge_idx, 1 - parity); +#endif #pragma unroll for (auto s = 0; s < n_src_tile; s++) { const Vector in = arg.in[src_idx + s](back3_idx, their_spinor_parity); diff --git a/include/kernels/dslash_wilson.cuh b/include/kernels/dslash_wilson.cuh index 0b1fb49fb8..04aa4f50fe 100644 --- a/include/kernels/dslash_wilson.cuh +++ b/include/kernels/dslash_wilson.cuh @@ -38,19 +38,21 @@ namespace quda Ghost halo_pack; Ghost halo; const G U; /** the gauge field */ + const G Uback; /** the backwards gauge field */ const real a; /** xpay scale factor - can be -kappa or -kappa^2 */ /** parameters for distance preconditioning */ const real alpha0; const int t0; WilsonArg(cvector_ref &out, cvector_ref &in, const ColorSpinorField &halo, - const GaugeField &U, double a, cvector_ref &x, int parity, bool dagger, - const int *comm_override, double alpha0 = 0.0, int t0 = -1) : + const GaugeField &U, const GaugeField &Uback, double a, cvector_ref &x, + int parity, bool dagger, const int *comm_override, double alpha0 = 0.0, int t0 = -1) : 
DslashArg(out, in, halo, U, x, parity, dagger, a != 0.0 ? true : false, spin_project, comm_override), halo_pack(halo), halo(halo), U(U), + Uback(Uback), a(a), alpha0(alpha0), t0(t0) @@ -128,7 +130,11 @@ namespace quda if (arg.dd_in.doHopping(coord, d, -1)) { const real bwd_coeff = (d < 3) ? 1.0 : bwd_coeff_3; const int back_idx = getNeighborIndexCB(coord, d, -1, arg.dc); +#ifdef QUDA_DSLASH_DOUBLE_STORE + const int gauge_idx = (Arg::nDim == 5 ? coord.x_cb % arg.dc.volume_4d_cb : coord.x_cb); +#else const int gauge_idx = (Arg::nDim == 5 ? back_idx % arg.dc.volume_4d_cb : back_idx); +#endif constexpr int proj_dir = dagger ? -1 : +1; const bool ghost = coord.in_boundary[0][d] & isActive(active, thread_dim, d, coord, arg); @@ -140,14 +146,22 @@ namespace quda idx; const int gauge_ghost_idx = (Arg::nDim == 5 ? ghost_idx % arg.dc.ghostFaceCB[d] : ghost_idx); +#ifdef QUDA_DSLASH_DOUBLE_STORE + Link U = arg.Uback(d, gauge_idx, gauge_parity); +#else Link U = arg.U.Ghost(d, gauge_ghost_idx, 1 - gauge_parity); +#endif HalfVector in = arg.halo.Ghost(d, 0, ghost_idx + (src_idx * arg.Ls + coord.s) * arg.dc.ghostFaceCB[d], their_spinor_parity); out += bwd_coeff * (conj(U) * in).reconstruct(d, proj_dir); } else if (doBulk() && !ghost) { +#ifdef QUDA_DSLASH_DOUBLE_STORE + Link U = arg.Uback(d, gauge_idx, gauge_parity); +#else Link U = arg.U(d, gauge_idx, 1 - gauge_parity); +#endif Vector in = arg.in[src_idx](back_idx + coord.s * arg.dc.volume_4d_cb, their_spinor_parity); out += bwd_coeff * (conj(U) * in.project(d, proj_dir)).reconstruct(d, proj_dir); diff --git a/lib/dslash_improved_staggered.hpp b/lib/dslash_improved_staggered.hpp index 60588fd75e..07b6396ee8 100644 --- a/lib/dslash_improved_staggered.hpp +++ b/lib/dslash_improved_staggered.hpp @@ -153,8 +153,16 @@ namespace quda constexpr bool improved = true; constexpr QudaReconstructType recon_u = QUDA_RECONSTRUCT_NO; auto halo = ColorSpinorField::create_comms_batch(in, 3); - StaggeredArg arg(out, in, halo, U, L, a, x, parity, 
- dagger, comm_override); + +#ifdef QUDA_DSLASH_DOUBLE_STORE + GaugeField Uback = shift(U, 1); + GaugeField Lback = shift(L, 3); +#else + const GaugeField &Uback = U; + const GaugeField &Lback = L; +#endif + StaggeredArg arg(out, in, halo, U, Uback, L, Lback, a, x, + parity, dagger, comm_override); Staggered staggered(arg, out, in, halo, L); dslash::DslashPolicyTune policy(staggered, in, halo, profile); } diff --git a/lib/dslash_staggered.hpp b/lib/dslash_staggered.hpp index 51a15c9ae4..874fe731e5 100644 --- a/lib/dslash_staggered.hpp +++ b/lib/dslash_staggered.hpp @@ -49,12 +49,17 @@ namespace quda constexpr int nDim = 4; constexpr bool improved = false; auto halo = ColorSpinorField::create_comms_batch(in); +#ifdef QUDA_DSLASH_DOUBLE_STORE + GaugeField Uback = shift(U, 1); +#else + const GaugeField &Uback = shift(U, 1); +#endif if (U.StaggeredPhase() == QUDA_STAGGERED_PHASE_MILC || (U.LinkType() == QUDA_GENERAL_LINKS && U.Reconstruct() == QUDA_RECONSTRUCT_NO)) { if constexpr (is_enabled()) { StaggeredArg arg( - out, in, halo, U, U, a, x, parity, dagger, comm_override); + out, in, halo, U, Uback, U, Uback, a, x, parity, dagger, comm_override); Staggered staggered(arg, out, in, halo); dslash::DslashPolicyTune policy(staggered, in, halo, profile); @@ -64,7 +69,7 @@ namespace quda } else if (U.StaggeredPhase() == QUDA_STAGGERED_PHASE_TIFR) { if constexpr (is_enabled()) { StaggeredArg arg( - out, in, halo, U, U, a, x, parity, dagger, comm_override); + out, in, halo, U, Uback, U, Uback, a, x, parity, dagger, comm_override); Staggered staggered(arg, out, in, halo); dslash::DslashPolicyTune policy(staggered, in, halo, profile); diff --git a/lib/dslash_wilson.hpp b/lib/dslash_wilson.hpp index 80086142e4..c1fc823d3e 100644 --- a/lib/dslash_wilson.hpp +++ b/lib/dslash_wilson.hpp @@ -43,7 +43,14 @@ namespace quda { constexpr int nDim = 4; auto halo = ColorSpinorField::create_comms_batch(in); - WilsonArg arg(out, in, halo, U, a, x, parity, dagger, + +#ifdef 
QUDA_DSLASH_DOUBLE_STORE + GaugeField Uback = shift(U, 1); +#else + const GaugeField &Uback = U; +#endif + + WilsonArg arg(out, in, halo, U, Uback, a, x, parity, dagger, comm_override, alpha0, t0); Wilson wilson(arg, out, in, halo); dslash::DslashPolicyTune policy(wilson, in, halo, profile); From 9c2025b8cb5cec48ee3109c57a3380daa2dd857b Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 20 Oct 2025 16:09:01 -0700 Subject: [PATCH 011/121] Fix some issues with gauge shift: fix single-GPU builds and add half/quarter precision support --- include/kernels/gauge_shift.cuh | 2 +- lib/gauge_shift.cu | 22 ++++++++++++++++++---- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/include/kernels/gauge_shift.cuh b/include/kernels/gauge_shift.cuh index abe369f439..dced98e50c 100644 --- a/include/kernels/gauge_shift.cuh +++ b/include/kernels/gauge_shift.cuh @@ -39,7 +39,7 @@ namespace quda byte_array x = {}; getCoords(x, x_cb, arg.X, parity); - if (x[dir] < arg.shift && arg.comms_dim[dir]) { // on the boundary so we need to fetch from the ghost zone + if (x[dir] < arg.shift && arg.comms_dim[dir] > 1) { // on the boundary so we need to fetch from the ghost zone const int ghost_idx = ghostFaceIndex<0, 4>(x, arg.X, dir, arg.shift); Link U = arg.in.Ghost(dir, ghost_idx, 1 - parity); arg.out(dir, x_cb, parity) = U; diff --git a/lib/gauge_shift.cu b/lib/gauge_shift.cu index 6a81de0246..cc5997e4ee 100644 --- a/lib/gauge_shift.cu +++ b/lib/gauge_shift.cu @@ -6,7 +6,7 @@ namespace quda { - template class GaugeShifter : public TunableKernel3D + template class GaugeShifter : public TunableKernel3D { GaugeField &out; const GaugeField &in; int shift; @@ -28,8 +28,22 @@ namespace quda void apply(const qudaStream_t &stream) { TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); - GaugeShiftArg arg(out, in, shift); - launch(tp, stream, arg); + if (in.Reconstruct() == QUDA_RECONSTRUCT_NO) { + GaugeShiftArg arg(out, in, shift); + launch(tp, stream, arg); + } else if 
(in.Reconstruct() == QUDA_RECONSTRUCT_13) { + GaugeShiftArg arg(out, in, shift); + launch(tp, stream, arg); + } else if (in.Reconstruct() == QUDA_RECONSTRUCT_12) { + GaugeShiftArg arg(out, in, shift); + launch(tp, stream, arg); + } else if (in.Reconstruct() == QUDA_RECONSTRUCT_9) { + GaugeShiftArg arg(out, in, shift); + launch(tp, stream, arg); + } else if (in.Reconstruct() == QUDA_RECONSTRUCT_8) { + GaugeShiftArg arg(out, in, shift); + launch(tp, stream, arg); + } } long long bytes() const { return out.Bytes() + in.Bytes(); } @@ -45,7 +59,7 @@ namespace quda GaugeFieldParam param(in); param.create = QUDA_NULL_FIELD_CREATE; GaugeField out(param); - instantiate(out, in, shift); + instantiate(out, in, shift); getProfile().TPSTOP(QUDA_PROFILE_COMPUTE); return out; } From 721fbd523d6a75b5326c5e028c0814a729159aaf Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 20 Oct 2025 16:29:52 -0700 Subject: [PATCH 012/121] make doBulk and doHalo constexpr --- include/dslash_helper.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/dslash_helper.cuh b/include/dslash_helper.cuh index 1714dd64d5..0a207e376e 100644 --- a/include/dslash_helper.cuh +++ b/include/dslash_helper.cuh @@ -26,7 +26,7 @@ namespace quda @param[in] dim Dimension we are working on. If dim=-1 (default argument) then we return true if type is any halo kernel. 
*/ - template __host__ __device__ __forceinline__ bool doHalo(int dim = -1) + template __host__ __device__ __forceinline__ constexpr bool doHalo(int dim = -1) { switch (type) { case EXTERIOR_KERNEL_ALL: return true; @@ -44,7 +44,7 @@ namespace quda computation @param[in] dim Dimension we are working on */ - template __host__ __device__ __forceinline__ bool doBulk() + template __host__ __device__ __forceinline__ constexpr bool doBulk() { switch (type) { case EXTERIOR_KERNEL_ALL: From 02a4cb9cd3e65fd37fee0748f44fe272615ad07f Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Mon, 20 Oct 2025 17:30:00 -0700 Subject: [PATCH 013/121] Add target::is_thread_zero and target::is_lane_zero helper functions for executing single-thread regions of code. On CUDA install the latest version of CCCL via CPM since we need some new features --- include/targets/cuda/target_device.h | 51 ++++++++++++++++++++++++++++ include/targets/hip/target_device.h | 10 ++++++ lib/targets/cuda/target_cuda.cmake | 7 ++++ 3 files changed, 68 insertions(+) diff --git a/include/targets/cuda/target_device.h b/include/targets/cuda/target_device.h index ee7c646172..ac3b47dfdc 100644 --- a/include/targets/cuda/target_device.h +++ b/include/targets/cuda/target_device.h @@ -7,6 +7,8 @@ #include #endif +#include + #if defined(__CUDACC__) || defined(_NVHPC_CUDA) || (defined(__clang__) && defined(__CUDA__)) #define QUDA_CUDA_CC #endif @@ -171,6 +173,55 @@ namespace quda } } + template struct is_thread_zero_impl { + template bool operator()(const T &) { return true; } + }; + +#ifdef QUDA_CUDA_CC + template <> struct is_thread_zero_impl { + template __device__ bool operator()(const T &) + { + unsigned int tid = thread_idx_linear(); + unsigned int warp_id = tid / 32; + unsigned int uniform_warp_id = __shfl_sync(0xFFFFFFFF, warp_id, 0); // Broadcast from lane 0 + // unsigned int uniform_warp_id = __reduce_min_sync(~0, warp_id == 0); perhaps faster on sm_100 + return (uniform_warp_id == 0 && 
cuda::ptx::elect_sync(0xFFFFFFFF)); + } + }; +#endif + + /** + @brief Return true only for a single thread in a thread block. + This function assumes all warps in the thread block are + converged. Note that the single thread that is returned is not + necessarily thread 0 in the thread block. + @tparam dim The dimension of the thread block + @return true for a single thread in the thread block, else + false + */ + template __device__ __host__ inline bool is_thread_zero() + { + return target::dispatch(std::integral_constant()); + } + + template struct is_lane_zero_impl { + bool operator()() { return true; } + }; +#ifdef QUDA_CUDA_CC + template <> struct is_lane_zero_impl { + __device__ bool operator()() { return cuda::ptx::elect_sync(0xFFFFFFFF); } + }; +#endif + + /** + @brief Return true only for a single lane in a warp. + This function assumes the warp is converged. + Note that the single thread that is returned is not + necessarily lane 0 in the warp. + @return true for a single lane in the warp, else false + */ + __device__ __host__ inline bool is_lane_zero() { return target::dispatch(); } + template constexpr bool vectorize() { #ifdef QUDA_VECTORIZE_SINGLE diff --git a/include/targets/hip/target_device.h b/include/targets/hip/target_device.h index 897c9bdae1..4075604cf2 100644 --- a/include/targets/hip/target_device.h +++ b/include/targets/hip/target_device.h @@ -135,6 +135,16 @@ namespace quda } } + template __device__ __host__ inline bool is_thread_zero() + { + return thread_idx_linear() == 0; + } + + template __device__ __host__ inline bool is_lane_zero() + { + return (thread_idx_linear<3>() % 64) == 0; // switch this to warp_size + } + } // namespace target namespace device diff --git a/lib/targets/cuda/target_cuda.cmake b/lib/targets/cuda/target_cuda.cmake index 0c5fcb46a3..db6d52ed0a 100644 --- a/lib/targets/cuda/target_cuda.cmake +++ b/lib/targets/cuda/target_cuda.cmake @@ -419,6 +419,13 @@ if(CUDAToolkit_FOUND) target_link_libraries(quda INTERFACE 
CUDA::cudart_static) endif() +CPMAddPackage( + NAME CCCL + GITHUB_REPOSITORY nvidia/cccl + GIT_TAG main # Fetches the latest commit on the main branch +) +target_link_libraries(quda PRIVATE CCCL::CCCL) + # nvshmem enabled parts need SEPARABLE_COMPILATION ... if(QUDA_NVSHMEM) list(APPEND QUDA_DSLASH_OBJS dslash_constant_arg.cu) From 33b5f2f739fc477c45155770744e649f321f6ade Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 21 Oct 2025 12:48:57 -0700 Subject: [PATCH 014/121] Expose prefetching instructions --- include/targets/cuda/load_store.h | 18 ++++++++++++++++++ include/targets/generic/load_store.h | 15 +++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/include/targets/cuda/load_store.h b/include/targets/cuda/load_store.h index 29b2e50be3..161c93cbe5 100644 --- a/include/targets/cuda/load_store.h +++ b/include/targets/cuda/load_store.h @@ -156,6 +156,24 @@ namespace quda } }; + // pre-declaration of the prefetch_cache that we wish to specialize + template struct prefetch_cache_line_imp; + + // CUDA specialization of the prefetch_cache that uses inline ptx + template <> struct prefetch_cache_line_imp { + __device__ inline void operator()(const void *p) { prefetch_L2(p); } + }; + + // pre-declaration of the prefetch_cache that we wish to specialize + template struct prefetch_cache_bulk_imp; + +#if __COMPUTE_CAPABILITY__ >= 900 + // CUDA specialization of the prefetch_cache_bulk that uses TMA (requires Hopper+) + template <> struct prefetch_cache_bulk_imp { + __device__ inline void operator()(const void *p, size_t bytes) { prefetch_tma(p, bytes); } + }; +#endif + } // namespace quda #include "../generic/load_store.h" diff --git a/include/targets/generic/load_store.h b/include/targets/generic/load_store.h index 93b847a4db..8254509e74 100644 --- a/include/targets/generic/load_store.h +++ b/include/targets/generic/load_store.h @@ -64,4 +64,19 @@ namespace quda vector_store(ptr, idx, value_v); } + template struct prefetch_cache_line_imp { + __device__ 
__host__ inline void operator()(const void *) { } + }; + + __device__ __host__ inline void prefetch_cache_line(const void *p) { target::dispatch(p); } + + template struct prefetch_cache_bulk_imp { + __device__ __host__ inline void operator()(const void *, size_t) { } + }; + + __device__ __host__ inline void prefetch_cache_bulk(const void *p, size_t bytes) + { + target::dispatch(p, bytes); + } + } // namespace quda From ccf7a552a7d12ebf33cc5066f76998584946c158 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 21 Oct 2025 12:52:51 -0700 Subject: [PATCH 015/121] Add prefetching support to gauge and colorspinor fields --- include/color_spinor_field_order.h | 15 ++++++++++++++ include/gauge_field_order.h | 33 ++++++++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/include/color_spinor_field_order.h b/include/color_spinor_field_order.h index 008c25db8b..a227e5a829 100644 --- a/include/color_spinor_field_order.h +++ b/include/color_spinor_field_order.h @@ -1181,6 +1181,21 @@ namespace quda for (int i = 0; i < length / 2; i++) out[i] = complex(v[2 * i + 0], v[2 * i + 1]); } + __device__ __host__ inline void prefetch(int x, int parity = 0) const + { +#ifndef LEGACY_ACCESSOR_NORM + auto norm_offset = offset / (sizeof(Float) < sizeof(float) ? 
sizeof(norm_type) / sizeof(Float) : 1); + auto norm = reinterpret_cast(field + volumeCB * (2 * Nc * Ns)); +#endif + if constexpr (isFixed::value) prefetch_cache_line(norm + x + parity * norm_offset); + +#pragma unroll + for (int i = 0; i < M; i++) prefetch_cache_line(field + parity * offset + (volumeCB * i + x) * N); + + // now load any remainder + if constexpr (Nrem > 0) prefetch_cache_line(field + parity * offset + volumeCB * M * N + x * Nrem); + } + __device__ __host__ inline void save(const complex in[length / 2], int x, int parity = 0) const { real v[length]; diff --git a/include/gauge_field_order.h b/include/gauge_field_order.h index 4561f1f21f..7f8dc28ac5 100644 --- a/include/gauge_field_order.h +++ b/include/gauge_field_order.h @@ -1632,6 +1632,39 @@ namespace quda { reconstruct.Unpack(v, tmp, x, dir, phase, X, R); } + __device__ inline void prefetch(int x, int dir, int parity) const + { +#pragma unroll + for (int i = 0; i < M; i++) + prefetch_cache_line(gauge + parity * offset + dir * (M * N + Nrem) * stride + (i * stride + x) * N); + + // now load any remainder + if constexpr (Nrem > 0) + prefetch_cache_line(gauge + parity * offset + (dir * (M * N + Nrem) + M * N) * stride + x * Nrem); + + constexpr bool load_phase = (hasPhase && !(static_phase() && (reconLen == 13 || use_inphase))); + if constexpr (load_phase) prefetch_cache_line(gauge + parity * offset + phaseOffset + stride * dir + x); + } + + __device__ inline void prefetch_bulk(int x, int dir, int parity, int block_size) const + { + if (target::is_thread_zero()) { +#pragma unroll + for (int i = 0; i < M; i++) + prefetch_cache_bulk(gauge + parity * offset + dir * (M * N + Nrem) * stride + (i * stride + x) * N, + block_size * N * sizeof(Float)); + + // now load any remainder + if constexpr (Nrem > 0) + prefetch_cache_bulk(gauge + parity * offset + (dir * (M * N + Nrem) + M * N) * stride + x * Nrem, + block_size * Nrem * sizeof(Float)); + + constexpr bool load_phase = (hasPhase && !(static_phase() && 
(reconLen == 13 || use_inphase))); + if constexpr (load_phase) + prefetch_cache_bulk(gauge + parity * offset + phaseOffset + stride * dir, block_size * sizeof(Float)); + } + } + __device__ __host__ inline void save(const complex v[length / 2], int x, int dir, int parity) const { real tmp[reconLen]; From 0642f638bb761ed991dff43664d731c93a76fa41 Mon Sep 17 00:00:00 2001 From: maddyscientist Date: Tue, 21 Oct 2025 14:50:34 -0700 Subject: [PATCH 016/121] Add L2 gauge-field prefetching support to both Wilson and staggered dslash kernels. Disabled by default (set with with Arg::prefetch_distance parameter), and TMA prefetch will be added in next push --- include/dslash.h | 5 ++ include/dslash_helper.cuh | 4 + include/index_helper.cuh | 1 + include/kernel_helper.h | 2 + include/kernels/dslash_staggered.cuh | 124 +++++++++++++++++++------- include/kernels/dslash_wilson.cuh | 62 +++++++++++-- include/targets/cuda/tunable_kernel.h | 8 +- include/targets/hip/tunable_kernel.h | 13 ++- 8 files changed, 176 insertions(+), 43 deletions(-) diff --git a/include/dslash.h b/include/dslash.h index 3e0906810d..d34d83d42a 100644 --- a/include/dslash.h +++ b/include/dslash.h @@ -76,6 +76,11 @@ namespace quda #ifdef QUDA_DSLASH_DOUBLE_STORE strcat(aux_base, ",double_store"); #endif + if constexpr (Arg::prefetch_distance > 0) { + strcat(aux_base, ",prefetch="); + i32toa(tile_str, Arg::prefetch_distance); + strcat(aux_base, tile_str); + } } /** diff --git a/include/dslash_helper.cuh b/include/dslash_helper.cuh index 0a207e376e..314c27c43c 100644 --- a/include/dslash_helper.cuh +++ b/include/dslash_helper.cuh @@ -109,6 +109,7 @@ namespace quda if (kernel_type == INTERIOR_KERNEL) { coord.x_cb = idx; + coord.x_cb_0 = (target::block_idx().x - arg.pack_blocks) * target::block_dim().x; if (nDim == 5) coord.X = getCoords5CB(coord, idx, arg.dc.X, arg.X0h, parity, pc_type); else @@ -298,6 +299,7 @@ namespace quda static constexpr int n_src_tile = n_src_tile_; // how many RHS per thread static 
constexpr int max_regs = 0; // by default we don't limit register count static constexpr bool spill_shared = false; // whether a given kernel should use shared memory spilling + static constexpr int prefetch_distance = 0; // whether we are using prefetching in the dslash const int parity; // only use this for single parity fields const int nParity; // number of parities we're working on @@ -340,6 +342,7 @@ namespace quda int pack_blocks = 0; // total number of blocks used for packing in the dslash int exterior_dims = 0; // dimension to run in the exterior Dslash int exterior_blocks = 0; + int block_size = 0; DDArg dd_out; DDArg dd_in; @@ -707,6 +710,7 @@ namespace quda static constexpr KernelType kernel_type = kernel_type_; static constexpr int max_regs = Arg::max_regs; static constexpr bool spill_shared = Arg::spill_shared; + static constexpr bool is_dslash = true; Arg arg; dslash_functor_arg(const Arg &arg, unsigned int threads_x) : diff --git a/include/index_helper.cuh b/include/index_helper.cuh index 5ea718aa8c..c27215ce4e 100644 --- a/include/index_helper.cuh +++ b/include/index_helper.cuh @@ -234,6 +234,7 @@ namespace quda { array gx = {}; // nDim global lattice coordinates array gDim = {}; // global lattice dimensions int x_cb; // checkerboard lattice site index + int x_cb_0; // value of x_cb on first thread in block int s; // fifth dimension coord int X; // full lattice site index constexpr const int& operator[](int i) const { return x[i]; } diff --git a/include/kernel_helper.h b/include/kernel_helper.h index 14727c327a..bf8fd17d2a 100644 --- a/include/kernel_helper.h +++ b/include/kernel_helper.h @@ -19,7 +19,9 @@ namespace quda static constexpr bool check_bounds = check_bounds_; static constexpr int max_regs = 0; // by default we don't limit register count static constexpr bool spill_shared = false; // whether a given kernel should use shared memory spilling + static constexpr bool is_dslash = false; // whether the arg is for a dslash (with its nested arg 
struct) dim3 threads; /** number of active threads required */ + int block_size; /** product of thread block dimensions */ int comms_rank; /** per process value of comm_rank() */ int comms_rank_global; /** per process value comm_rank_global() */ int comms_coord[4]; /** array storing {comm_coord(0), ..., comm_coord(3)} */ diff --git a/include/kernels/dslash_staggered.cuh b/include/kernels/dslash_staggered.cuh index ebb55b9fff..fd383c0d3f 100644 --- a/include/kernels/dslash_staggered.cuh +++ b/include/kernels/dslash_staggered.cuh @@ -52,6 +52,7 @@ namespace quda const bool is_first_time_slice; /** are we on the first (global) time slice */ const bool is_last_time_slice; /** are we on the last (global) time slice */ static constexpr bool improved = improved_; + static constexpr int prefetch_distance = 0; const real dagger_scale; @@ -75,6 +76,43 @@ namespace quda } }; + /** + @brief Prefetch the gauge field into cache. + @param[in] dim The dimension we are presently working on + @param[in] dir The direction we are presently working on (0 = forwards, 1 = backwards) + @param[in] hop The hopping term we are presently working on (0 = 1-hop, 1 = 3-hop) + @param[in] coord Coordinates that we are working on with hop-3 boundary conditions evaluated + @param[in] coord1 Copy of coordinates that we are working on with hop-1 boundary conditions evaluated + @param[in] parity Parity that we are working on + @param[in] arg Parameter struct + */ + template + __device__ __host__ void prefetch(int dim, int dir, int hop, const coord_t &coord, const coord_t &coord1, int parity, + const Arg &arg) + { + if constexpr (arg.prefetch_distance == 0) return; + + if constexpr (arg.improved) { + int step = 4 * dim + 2 * dir + hop + arg.prefetch_distance; + if (step >= 16) return; + + // for TMA use arg.block_size and coord.x_cb_0 + // also should have warp uniform parity + int dim2 = step / 4; + switch (step % 4) { + case 0: arg.U.prefetch(coord.x_cb, dim2, parity); break; + case 1: 
arg.L.prefetch(coord.x_cb, dim2, parity); break; +#ifdef QUDA_DSLASH_DOUBLE_STORE + case 2: arg.Uback.prefetch(coord.x_cb, dim2, parity); break; + case 3: arg.Lback.prefetch(coord.x_cb, dim2, parity); break; +#else + case 2: arg.U.prefetch(getNeighborIndexCB<1>(coord1, dim2, -1, arg.dc), dim2, 1 - parity); break; + case 3: arg.L.prefetch(getNeighborIndexCB<3>(coord, dim2, -1, arg.dc), dim2, 1 - parity); break; +#endif + } + } + } + /** @brief Applies the off-diagonal part of the Staggered / Asqtad operator. @@ -107,7 +145,8 @@ namespace quda // standard - forward direction if (arg.dd_in.doHopping(coord, d, +1)) { - const bool ghost = (coord[d] + 1 >= arg.dc.X[d]) && isActive(active, thread_dim, d, coord, arg); + const bool ghost = coord1.in_boundary[1][d] & isActive(active, thread_dim, d, coord, arg); + if (doHalo(d) && ghost) { const int ghost_idx = ghostFaceIndexStaggered<1>(coord, arg.dc.X, d, 1); const Link U = arg.improved ? arg.U(d, coord.x_cb, parity) : arg.U(d, coord.x_cb, parity, StaggeredPhase(coord, d, +1, arg)); @@ -116,14 +155,20 @@ namespace quda Vector in = arg.halo.Ghost(d, 1, ghost_idx + (src_idx + s) * arg.dc.ghostFaceCB[d], their_spinor_parity); out[s] = mv_add(U, in, out[s]); } - } else if (doBulk() && !ghost) { - const int fwd_idx = getNeighborIndexCB<1>(coord1, d, 1, arg.dc); - const Link U = arg.improved ? arg.U(d, coord.x_cb, parity) : arg.U(d, coord.x_cb, parity, StaggeredPhase(coord, d, +1, arg)); + } + + if constexpr (doBulk()) { + if (!ghost) { + const int fwd_idx = getNeighborIndexCB<1>(coord1, d, 1, arg.dc); + const Link U = arg.improved ? 
arg.U(d, coord.x_cb, parity) : + arg.U(d, coord.x_cb, parity, StaggeredPhase(coord, d, +1, arg)); #pragma unroll - for (auto s = 0; s < n_src_tile; s++) { - Vector in = arg.in[src_idx + s](fwd_idx, their_spinor_parity); - out[s] = mv_add(U, in, out[s]); + for (auto s = 0; s < n_src_tile; s++) { + Vector in = arg.in[src_idx + s](fwd_idx, their_spinor_parity); + out[s] = mv_add(U, in, out[s]); + } } + prefetch(d, 0, 0, coord, coord1, parity, arg); } } @@ -139,20 +184,25 @@ namespace quda = arg.halo.Ghost(d, 1, ghost_idx + (src_idx + s) * arg.dc.ghostFaceCB[d], their_spinor_parity); out[s] = mv_add(L, in, out[s]); } - } else if (doBulk() && !ghost) { - const int fwd3_idx = getNeighborIndexCB<3>(coord, d, 1, arg.dc); - const Link L = arg.L(d, coord.x_cb, parity); + } + + if constexpr (doBulk()) { + if (!ghost) { + const int fwd3_idx = getNeighborIndexCB<3>(coord, d, 1, arg.dc); + const Link L = arg.L(d, coord.x_cb, parity); #pragma unroll - for (auto s = 0; s < n_src_tile; s++) { - const Vector in = arg.in[src_idx + s](fwd3_idx, their_spinor_parity); - out[s] = mv_add(L, in, out[s]); + for (auto s = 0; s < n_src_tile; s++) { + const Vector in = arg.in[src_idx + s](fwd3_idx, their_spinor_parity); + out[s] = mv_add(L, in, out[s]); + } } + prefetch(d, 0, 1, coord, coord1, parity, arg); } } if (arg.dd_in.doHopping(coord, d, -1)) { // Backward gather - compute back offset for spinor and gauge fetch - const bool ghost = (coord[d] - 1 < 0) && isActive(active, thread_dim, d, coord, arg); + const bool ghost = coord1.in_boundary[1][d] & isActive(active, thread_dim, d, coord, arg); if (doHalo(d) && ghost) { const int ghost_idx2 = ghostFaceIndexStaggered<0>(coord, arg.dc.X, d, 1); @@ -169,21 +219,26 @@ namespace quda Vector in = arg.halo.Ghost(d, 0, ghost_idx + (src_idx + s) * arg.dc.ghostFaceCB[d], their_spinor_parity); out[s] = mv_sub(conj(U), in, out[s]); } - } else if (doBulk() && !ghost) { - const int back_idx = getNeighborIndexCB<1>(coord1, d, -1, arg.dc); + } + + if 
constexpr (doBulk()) { + if (!ghost) { + const int back_idx = getNeighborIndexCB<1>(coord1, d, -1, arg.dc); #ifdef QUDA_DSLASH_DOUBLE_STORE - const Link U = arg.improved ? arg.Uback(d, coord.x_cb, parity) : - arg.Uback(d, coord.x_cb, parity, StaggeredPhase(coord, d, -1, arg)); + const Link U = arg.improved ? arg.Uback(d, coord.x_cb, parity) : + arg.Uback(d, coord.x_cb, parity, StaggeredPhase(coord, d, -1, arg)); #else - const int gauge_idx = back_idx; - const Link U = arg.improved ? arg.U(d, gauge_idx, 1 - parity) : - arg.U(d, gauge_idx, 1 - parity, StaggeredPhase(coord, d, -1, arg)); + const int gauge_idx = back_idx; + const Link U = arg.improved ? arg.U(d, gauge_idx, 1 - parity) : + arg.U(d, gauge_idx, 1 - parity, StaggeredPhase(coord, d, -1, arg)); #endif #pragma unroll - for (auto s = 0; s < n_src_tile; s++) { - Vector in = arg.in[src_idx + s](back_idx, their_spinor_parity); - out[s] = mv_sub(conj(U), in, out[s]); + for (auto s = 0; s < n_src_tile; s++) { + Vector in = arg.in[src_idx + s](back_idx, their_spinor_parity); + out[s] = mv_sub(conj(U), in, out[s]); + } } + prefetch(d, 1, 0, coord, coord1, parity, arg); } } @@ -203,19 +258,24 @@ namespace quda = arg.halo.Ghost(d, 0, ghost_idx + (src_idx + s) * arg.dc.ghostFaceCB[d], their_spinor_parity); out[s] = mv_sub(conj(L), in, out[s]); } - } else if (doBulk() && !ghost) { - const int back3_idx = getNeighborIndexCB<3>(coord, d, -1, arg.dc); + } + + if constexpr (doBulk()) { + if (!ghost) { + const int back3_idx = getNeighborIndexCB<3>(coord, d, -1, arg.dc); #ifdef QUDA_DSLASH_DOUBLE_STORE - const Link L = arg.Lback(d, coord.x_cb, parity); + const Link L = arg.Lback(d, coord.x_cb, parity); #else - const int gauge_idx = back3_idx; - const Link L = arg.L(d, gauge_idx, 1 - parity); + const int gauge_idx = back3_idx; + const Link L = arg.L(d, gauge_idx, 1 - parity); #endif #pragma unroll - for (auto s = 0; s < n_src_tile; s++) { - const Vector in = arg.in[src_idx + s](back3_idx, their_spinor_parity); - out[s] = 
mv_sub(conj(L), in, out[s]); + } } + prefetch(d, 1, 1, coord, coord1, parity, arg); } } } // nDim diff --git a/include/kernels/dslash_wilson.cuh b/include/kernels/dslash_wilson.cuh index 04aa4f50fe..3a937394d5 100644 --- a/include/kernels/dslash_wilson.cuh +++ b/include/kernels/dslash_wilson.cuh @@ -43,6 +43,7 @@ namespace quda /** parameters for distance preconditioning */ const real alpha0; const int t0; + static constexpr int prefetch_distance = 0; WilsonArg(cvector_ref &out, cvector_ref &in, const ColorSpinorField &halo, const GaugeField &U, const GaugeField &Uback, double a, cvector_ref &x, @@ -65,6 +66,41 @@ namespace quda } }; + /** + @tparam distance The distance away we are prefetching + @param[in] dim The dimension we are presently working on + @param[in] dir The direction we are presently working on (0 = forwards, 1 = backwards) + @param[in] coord Coordinates that we are working on + @param[in] parity Parity that we are working on + @param[in] arg Parameter struct + */ + template + __device__ __host__ void prefetch(int dim, int dir, const coord_t &coord, int parity, const Arg &arg) + { + if constexpr (arg.prefetch_distance == 0) return; + + int step = 2 * dim + dir + arg.prefetch_distance; + if (step >= 8) return; + + // for TMA use arg.block_size + int dim2 = step / 2; + // need warp uniform variants of these and parity + const int x_cb = (Arg::nDim == 5 ? coord.x_cb % arg.dc.volume_4d_cb : coord.x_cb); + + switch (step % 2) { + case 0: arg.U.prefetch(x_cb, dim2, parity); break; +#ifdef QUDA_DSLASH_DOUBLE_STORE + case 1: arg.Uback.prefetch(x_cb, dim2, parity); break; +#else + case 1: { + const int back_idx = getNeighborIndexCB(coord, dim2, -1, arg.dc); + const int idx1 = (Arg::nDim == 5 ? 
back_idx % arg.dc.volume_4d_cb : back_idx); + arg.U.prefetch(idx1, dim2, 1 - parity); + } break; +#endif + } + } + /** @brief Applies the off-diagonal part of the Wilson operator @@ -117,12 +153,16 @@ namespace quda their_spinor_parity); out += fwd_coeff * (U * in).reconstruct(d, proj_dir); - } else if (doBulk() && !ghost) { + } - Link U = arg.U(d, gauge_idx, gauge_parity); - Vector in = arg.in[src_idx](fwd_idx + coord.s * arg.dc.volume_4d_cb, their_spinor_parity); + if constexpr (doBulk()) { + if (!ghost) { + Link U = arg.U(d, gauge_idx, gauge_parity); + Vector in = arg.in[src_idx](fwd_idx + coord.s * arg.dc.volume_4d_cb, their_spinor_parity); + out += fwd_coeff * (U * in.project(d, proj_dir)).reconstruct(d, proj_dir); + } - out += fwd_coeff * (U * in.project(d, proj_dir)).reconstruct(d, proj_dir); + prefetch(d, 0, coord, parity, arg); } } @@ -155,16 +195,20 @@ namespace quda their_spinor_parity); out += bwd_coeff * (conj(U) * in).reconstruct(d, proj_dir); - } else if (doBulk() && !ghost) { + } + if (doBulk()) { + if (!ghost) { #ifdef QUDA_DSLASH_DOUBLE_STORE - Link U = arg.Uback(d, gauge_idx, gauge_parity); + Link U = arg.Uback(d, gauge_idx, gauge_parity); #else - Link U = arg.U(d, gauge_idx, 1 - gauge_parity); + Link U = arg.U(d, gauge_idx, 1 - gauge_parity); #endif - Vector in = arg.in[src_idx](back_idx + coord.s * arg.dc.volume_4d_cb, their_spinor_parity); + Vector in = arg.in[src_idx](back_idx + coord.s * arg.dc.volume_4d_cb, their_spinor_parity); + out += bwd_coeff * (conj(U) * in.project(d, proj_dir)).reconstruct(d, proj_dir); + } - out += bwd_coeff * (conj(U) * in.project(d, proj_dir)).reconstruct(d, proj_dir); + prefetch(d, 1, coord, parity, arg); } } } // nDim diff --git a/include/targets/cuda/tunable_kernel.h b/include/targets/cuda/tunable_kernel.h index 55219f5f93..46b599254e 100644 --- a/include/targets/cuda/tunable_kernel.h +++ b/include/targets/cuda/tunable_kernel.h @@ -57,6 +57,8 @@ namespace quda launch_device(const kernel_t &kernel, const 
TuneParam &tp, const qudaStream_t &stream, const Arg &arg) { checkSharedBytes(tp, arg); + const_cast(arg).block_size = tp.block.x * tp.block.y * tp.block.z; + if constexpr (Arg::is_dslash) const_cast(arg).arg.block_size = arg.block_size; #ifdef JITIFY launch_error = launch_jitify(kernel.name, tp, stream, arg); #else @@ -66,7 +68,7 @@ namespace quda return launch_error; } - template void check_arg_size(Arg&) + template void check_arg_size(Arg &) { static_assert(sizeof(Arg) <= device::max_constant_size(), "Parameter struct is greater than max constant size"); } @@ -76,6 +78,8 @@ namespace quda launch_device(const kernel_t &kernel, const TuneParam &tp, const qudaStream_t &stream, const Arg &arg) { checkSharedBytes(tp, arg); + const_cast(arg).block_size = tp.block.x * tp.block.y * tp.block.z; + if constexpr (Arg::is_dslash) const_cast(arg).arg.block_size = arg.block_size; #ifdef JITIFY // note we do the copy to constant memory after the kernel has been compiled in launch_jitify launch_error = launch_jitify(kernel.name, tp, stream, arg); @@ -99,6 +103,8 @@ namespace quda void launch_cuda(const TuneParam &tp, const qudaStream_t &stream, const Arg &arg) const { checkSharedBytes(tp, arg); + const_cast(arg).block_size = tp.block.x * tp.block.y * tp.block.z; + if constexpr (Arg::is_dslash) const_cast(arg).arg.block_size = arg.block_size; constexpr bool grid_stride = false; const_cast(this)->launch_device(KERNEL(raw_kernel), tp, stream, arg); } diff --git a/include/targets/hip/tunable_kernel.h b/include/targets/hip/tunable_kernel.h index 5447eeb25b..bb8b08f56b 100644 --- a/include/targets/hip/tunable_kernel.h +++ b/include/targets/hip/tunable_kernel.h @@ -54,17 +54,26 @@ namespace quda launch_device(const kernel_t &kernel, const TuneParam &tp, const qudaStream_t &stream, const Arg &arg) { checkSharedBytes(tp, arg); + const_cast(arg).block_size = tp.block.x * tp.block.y * tp.block.z; + if constexpr (Arg::is_dslash) const_cast(arg).arg.block_size = arg.block_size; 
setMaxActiveBlocks(kernel, tp); launch_error = qudaLaunchKernel(kernel, tp, stream, static_cast(&arg)); return launch_error; } + template void check_arg_size(Arg &) + { + static_assert(sizeof(Arg) <= device::max_constant_size(), "Parameter struct is greater than max constant size"); + } + template