diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 42432041a8b01..cab8092e30260 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -103,6 +103,7 @@ cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention ke option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled dot product attention" OFF) cmake_dependent_option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF) option(onnxruntime_USE_FPA_INTB_GEMM "Build FpA IntB gemm cuda kernels" OFF) +option(onnxruntime_USE_INT4_KV_CACHE "Build cuda kernels for int4 kv cache" OFF) option(onnxruntime_QUICK_BUILD "Speed up build by skipping some kernels for faster development" OFF) option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF) @@ -125,6 +126,7 @@ option(onnxruntime_DONT_VECTORIZE "Do not vectorize operations in Eigen" OFF) option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF) option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump debug information about node inputs and outputs when executing the model." OFF) +option(onnxruntime_DUMP_TENSOR "Dump tensor inside kernel." OFF) cmake_dependent_option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS_ENABLE_DUMP_TO_SQLDB "Build dump debug information about node inputs and outputs with support for sql database." OFF "onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS" OFF) # When loading a delay loaded DLL, Windows searches the main EXE's folder first. @@ -627,7 +629,6 @@ else() check_cxx_compiler_flag(-Wparentheses HAS_PARENTHESES) check_cxx_compiler_flag(-Wshorten-64-to-32 HAS_SHORTEN_64_TO_32) check_cxx_compiler_flag(-Wstrict-aliasing HAS_STRICT_ALIASING) - check_nvcc_compiler_flag(-Wstrict-aliasing NVCC_HAS_STRICT_ALIASING) check_cxx_compiler_flag(-Wstringop-overflow HAS_STRINGOP_OVERFLOW) check_cxx_compiler_flag(-Wtautological-pointer-compare HAS_TAUTOLOGICAL_POINTER_COMPARE) check_cxx_compiler_flag(-Wundefined-var-template HAS_UNDEFINED_VAR_TEMPLATE) @@ -774,8 +775,13 @@ if (onnxruntime_USE_CUDA) endif() if (onnxruntime_QUICK_BUILD) - message( STATUS "Quick build mode: Flash attention limited to fp16 only") - list(APPEND ORT_PROVIDER_FLAGS -DORT_QUICK_BUILD=1) + message( STATUS "Quick build mode: Flash attention limited to head dimension 128 only") + list(APPEND ORT_PROVIDER_FLAGS -DORT_QUICK_BUILD=1) + endif() + + if (onnxruntime_USE_INT4_KV_CACHE) + message( STATUS "Enable int4 kv cache for CUDA EP") + list(APPEND ORT_PROVIDER_FLAGS -DUSE_INT4_KV_CACHE=1) endif() endif() @@ -1433,6 +1439,9 @@ if (Git_FOUND) if (onnxruntime_QUICK_BUILD) string(APPEND ORT_BUILD_INFO "quick-build=1, ") endif() + if (onnxruntime_USE_INT4_KV_CACHE) + string(APPEND ORT_BUILD_INFO "int4-kv-cache=1, ") + endif() endif() string(APPEND ORT_BUILD_INFO "build type=${CMAKE_BUILD_TYPE}") configure_file(onnxruntime_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/onnxruntime_config.h) @@ -1446,6 +1455,8 @@ if (onnxruntime_USE_CUDA) find_package(CUDAToolkit REQUIRED) if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.8) + add_definitions("-DENABLE_BF16") + message(STATUS "CUDA Toolkit version is greater or equal than 11.8, enable -DENABLE_BF16 flag") add_definitions("-DENABLE_FP8") message(STATUS "CUDA Toolkit version is greater or equal than 11.8, enable -DENABLE_FP8 flag") endif() @@ -1779,6 +1790,10 @@ if (onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS) 
add_compile_definitions(DEBUG_NODE_INPUTS_OUTPUTS) endif() +if (onnxruntime_DUMP_TENSOR) + add_compile_definitions(DUMP_TENSOR_LEVEL=1) +endif() + if (onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS) if (NOT CMAKE_SYSTEM_NAME STREQUAL "Linux") message(FATAL_ERROR "External custom operator schemas feature is only supported on Linux") diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake index f77a5dd78fcc5..f6efcb3fad6a9 100644 --- a/cmake/onnxruntime_providers_cpu.cmake +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -28,12 +28,10 @@ file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cu_srcs CONFIGURE_DEPENDS # Quick build mode: Filter flash attention kernels for faster development iteration. # - We keep only hdim128 fp16 flash attention kernels in quick build mode. # - All other listed head dimensions are excluded (e.g., 32, 64, 96, 192, 256). -# - This regex matches both `flash_fwd_hdim*` and `flash_fwd_split_hdim*` kernels. # If new head dimensions are added or removed, update this list to match the supported set. if(onnxruntime_QUICK_BUILD) message(STATUS "Quick build mode enabled: Only building hdim128 fp16 flash attention kernels") list(FILTER onnxruntime_cuda_contrib_ops_cu_srcs EXCLUDE REGEX "flash_fwd.*hdim(32|64|96|192|256)") - list(FILTER onnxruntime_cuda_contrib_ops_cu_srcs EXCLUDE REGEX "flash_fwd.*_bf16") endif() file(GLOB_RECURSE onnxruntime_js_contrib_ops_cc_srcs CONFIGURE_DEPENDS diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index b7b68ff324d9f..149ecdb969bd5 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2520,15 +2520,26 @@ This version of the operator has been available since version 1 of the 'com.micr ### **com.microsoft.GroupQueryAttention** - Group Query Self/Cross Attention. + Group Query Self/Cross Attention with KV Cache Quantization Support. - *Highly recommend using k-v cache share buffer for both CPU and CUDA. Enabled through IOBinding past and present kv. - Supports different number of heads for q and kv for CPU and CUDA. - Only supports causal and local attention. - Supports rotary position embedding for CPU and CUDA. - Supports packed input for CPU and CUDA. - Supports continuous decoding for batch_size == 1 for CPU and CUDA. + This operator implements causal grouped-query attention with past state (KV cache) support. + It also supports optional float8, int8, or int4 quantization for the KV cache to reduce the memory footprint. + **Cache Format:** + The past and present KV cache tensors are expected in BNSH format: `(batch_size, kv_num_heads, cache_sequence_length, head_size)`, where `cache_sequence_length` is the length of the cached key/value sequences, or the maximum sequence length when past and present buffer sharing is used. + + **Quantization:** + When quantization is enabled, the `past_key` and `past_value` inputs can be of type `float8e4m3fn`, `uint8`, or `int8`. The corresponding `k_scale` and `v_scale` tensors must be provided. + The operator outputs `present_key` and `present_value` in the same format as `past_key` and `past_value`. + + For 4-bit quantization, the data type is uint8, where each byte contains two 4-bit values. The bit width of the quantized KV cache is set with the `kv_cache_bit_width` attribute. + + The shapes of the `k_scale` and `v_scale` tensors shall be broadcastable to the `present_key` shape. + + **Quantization Modes (`k_quant_type`, `v_quant_type` attributes):** + - **"NONE"**: No quantization. + - **"PER_TENSOR"**: A single scale for the entire tensor.
Scale example shape: `[1]`. + - **"PER_CHANNEL"**: A scale for each channel. Scale example shape: `[1, num_heads_k, 1, head_size]`. #### Version @@ -2539,6 +2550,10 @@ This version of the operator has been available since version 1 of the 'com.micr
do_rotary : int
Whether to use rotary position embedding. Default value is 0.
+
k_quant_type : string
+
Quantization type for K cache. One of 'NONE', 'PER_TENSOR', 'PER_CHANNEL'.
+
kv_cache_bit_width : int
+
Bit width of quantized KV cache. Supported values are 8 and 4.
kv_num_heads : int (required)
Number of attention heads for k and v
local_window_size : int
@@ -2555,9 +2570,11 @@ This version of the operator has been available since version 1 of the 'com.micr
Use a smooth factor in softmax.
softcap : float
Softcap value for attention weights. Default value is 0.
+
v_quant_type : string
+
Quantization type for V cache. One of 'NONE', 'PER_TENSOR', 'PER_CHANNEL'.
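To make the `k_quant_type`/`v_quant_type` and `kv_cache_bit_width` attributes concrete, here is a minimal host-side sketch of how a scale relates quantized cache values to the original float values. The symmetric, zero-point-free scheme and the helper names are assumptions for illustration only; the schema exposes scales but no zero points, and the actual quantize/dequantize runs inside the CUDA kernels added by this PR.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Hypothetical helpers: symmetric int8 quantization with a single float scale,
// i.e. quantized = round(value / scale), dequantized = quantized * scale.
inline int8_t QuantizeToInt8(float value, float scale) {
  float q = std::nearbyint(value / scale);
  return static_cast<int8_t>(std::clamp(q, -128.0f, 127.0f));
}

inline float DequantizeFromInt8(int8_t value, float scale) {
  return static_cast<float>(value) * scale;
}

int main() {
  const float scale = 0.02f;  // PER_TENSOR: one scale for the whole cache
  const int8_t q = QuantizeToInt8(0.75f, scale);
  std::printf("q=%d, back to float=%f\n", q, DequantizeFromInt8(q, scale));
  return 0;
}
```

With `PER_CHANNEL`, the same arithmetic applies, but the scale is looked up per KV head and channel (see the lookup sketch after the inputs list) rather than being a single value.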
-#### Inputs (7 - 12) +#### Inputs (7 - 14)
query : T
@@ -2566,9 +2583,9 @@ This version of the operator has been available since version 1 of the 'com.micr
Key with shape (batch_size, kv_sequence_length, kv_hidden_size)
value (optional) : T
Value with shape (batch_size, kv_sequence_length, kv_hidden_size)
-
past_key (optional) : T
+
past_key (optional) : T_CACHE
past state key with support for format BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
-
past_value (optional) : T
+
past_value (optional) : T_CACHE
past state value with support for format BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
seqlens_k : M
1D Tensor of shape (batch_size). Equivalent to (total_sequence_lengths - 1).
@@ -2584,6 +2601,10 @@ This version of the operator has been available since version 1 of the 'com.micr
additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
head_sink (optional) : T
1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.
+
k_scale (optional) : T_KV_SCALE
+
Scale tensor for past_key.
+
v_scale (optional) : T_KV_SCALE
+
Scale tensor for past_value.
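Since `k_scale`/`v_scale` only need to be broadcastable to the `present_key` shape, a per-channel scale of shape `[1, num_heads_k, 1, head_size]` supplies one scale per KV head and channel and is reused across the batch and sequence dimensions, while a per-tensor scale of shape `[1]` is a single value. The indexing sketch below (hypothetical helper, host-side only) spells out what that broadcast means for a cache element at position `(b, n, s, h)`.

```cpp
#include <cstddef>

// Hypothetical lookup: which scale applies to cache element (b, n, s, h)?
// PER_TENSOR  -> scale[0] for every element.
// PER_CHANNEL -> scale has shape [1, num_heads_k, 1, head_size]; the batch and
//                sequence indices broadcast away, leaving scale[n * head_size + h].
inline float LookupKVScale(const float* scale, bool per_channel,
                           std::size_t n, std::size_t h, std::size_t head_size) {
  return per_channel ? scale[n * head_size + h] : scale[0];
}
```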
#### Outputs (3 - 4) @@ -2591,9 +2612,9 @@ This version of the operator has been available since version 1 of the 'com.micr
output : T
3D output tensor with shape (batch_size, sequence_length, hidden_size)
-
present_key : T
+
present_key : T_CACHE
present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
-
present_value : T
+
present_value : T_CACHE
present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.
output_qk (optional) : T
Values of QK matrix multiplication, either before or after softmax normalization
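When the cache is quantized to 4 bits, `present_key`/`present_value` are `uint8` tensors whose last dimension holds packed bytes rather than elements. The sketch below mirrors the shape computation the CUDA kernel in this PR uses when allocating the present outputs (`head_size` is validated elsewhere to be a multiple of 8, so the rounding never actually triggers).

```cpp
#include <cstdio>

// Mirrors the dense_head_size computation in the CUDA GroupQueryAttention
// kernel: two 4-bit values per uint8 byte, so the stored last dimension of
// present_key/present_value is head_size / 2 for 4-bit caches.
int DenseHeadSize(int head_size, int kv_cache_bit_width) {
  return (kv_cache_bit_width == 4) ? (head_size + 1) / 2 : head_size;
}

int main() {
  // Present shape: (batch_size, kv_num_heads, seqlen_present_kv_cache, dense_head_size)
  std::printf("head_size=128, 4-bit -> %d\n", DenseHeadSize(128, 4));  // 64
  std::printf("head_size=128, 8-bit -> %d\n", DenseHeadSize(128, 8));  // 128
  return 0;
}
```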
@@ -2604,6 +2625,10 @@ This version of the operator has been available since version 1 of the 'com.micr
T : tensor(float16), tensor(bfloat16), tensor(float)
Constrain input and output to float tensors.
+
T_CACHE : tensor(float), tensor(float16), tensor(bfloat16), tensor(uint8), tensor(int8), tensor(float8e4m3fn)
+
Constrain KV cache types.
+
T_KV_SCALE : tensor(float)
+
Constrain KV cache scale types.
M : tensor(int32)
Constrain mask to int tensor.
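For the 4-bit path, the `uint8` cache stores two 4-bit values per byte (hence the halved last dimension above). A minimal pack/unpack sketch follows; the nibble order and the signed two's-complement interpretation are assumptions for illustration, since the authoritative pack/unpack lives in the `group_query_attention_qdq.cuh` kernels this PR adds.

```cpp
#include <cstdint>
#include <cstdio>

// Assumed layout: element 2*i in the low nibble, element 2*i+1 in the high
// nibble, values interpreted as 4-bit two's complement in [-8, 7].
inline uint8_t PackInt4Pair(int8_t lo, int8_t hi) {
  return static_cast<uint8_t>((lo & 0x0F) | ((hi & 0x0F) << 4));
}

inline int8_t UnpackInt4(uint8_t packed, bool high) {
  int nibble = high ? (packed >> 4) : (packed & 0x0F);
  return static_cast<int8_t>((nibble & 0x08) ? (nibble | ~0x0F) : nibble);  // sign-extend
}

int main() {
  uint8_t byte = PackInt4Pair(-3, 7);
  std::printf("low=%d high=%d\n", UnpackInt4(byte, false), UnpackInt4(byte, true));  // -3 7
  return 0;
}
```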
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 08840c623b709..0230f2866fcb4 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -577,7 +577,7 @@ Do not modify directly.* |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T_CACHE**
*in* past_value:**T_CACHE**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*in* k_scale:**T_KV_SCALE**
*in* v_scale:**T_KV_SCALE**
*out* output:**T**
*out* present_key:**T_CACHE**
*out* present_value:**T_CACHE**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| @@ -1003,7 +1003,7 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T_CACHE**
*in* past_value:**T_CACHE**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*in* k_scale:**T_KV_SCALE**
*in* v_scale:**T_KV_SCALE**
*out* output:**T**
*out* present_key:**T_CACHE**
*out* present_value:**T_CACHE**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)
**T_CACHE** = tensor(bfloat16), tensor(float16), tensor(int8)
**T_KV_SCALE** = tensor(float)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -1484,7 +1484,7 @@ Do not modify directly.* |FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T_CACHE**
*in* past_value:**T_CACHE**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*in* position_ids:**tensor(int64)**
*in* attention_bias:**T**
*in* head_sink:**T**
*in* k_scale:**T_KV_SCALE**
*in* v_scale:**T_KV_SCALE**
*out* output:**T**
*out* present_key:**T_CACHE**
*out* present_value:**T_CACHE**
*out* output_qk:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 80d374d3f0b25..e886adac03f27 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -59,6 +59,13 @@ enum class QKOutputType : int { AFTER_SOFTMAX = 2 }; +// Enum to define quantization granularity. +enum class KVQuantizationType : int { + NONE = 0, + PER_TENSOR = 1, + PER_CHANNEL = 2, +}; + constexpr bool LAYOUT_BSNH = false; constexpr bool LAYOUT_BNSH = true; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h index 4ad11dce7e093..9a123e80adc18 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h @@ -96,6 +96,11 @@ struct GroupQueryAttentionParameters : AttentionParameters { AttentionQkvFormat past_kv_format; int zeros_count; int* zero_ptr; + + // Quantization parameters for KV cache + KVQuantizationType k_quant_type = KVQuantizationType::NONE; + KVQuantizationType v_quant_type = KVQuantizationType::NONE; + int kv_cache_bit_width = 0; }; // Parameters deduced from node attributes and inputs/outputs. diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index eb1560ac8e341..0fe37ae68c012 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -70,7 +70,8 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { seqlens_k, total_seqlen_tensor, scale_, - softcap_)); + softcap_, + 0)); ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids, attention_bias, diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index f7c54ad456925..1515adb1fc6ff 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -97,7 +97,7 @@ Status Check_QKV(const T* packed_qkv, const T* value, const int num_heads, const } template -Status CheckPast(const T* past_key, const T* past_value, int batch_size, int kv_num_heads, int head_size, +Status CheckPast(const T* past_key, const T* past_value, int batch_size, int kv_num_heads, int head_size, int kv_cache_bit_width, int& past_sequence_length) { const auto& past_key_dims = past_key->Shape().GetDims(); const auto& past_value_dims = past_value->Shape().GetDims(); @@ -141,15 +141,18 @@ Status CheckPast(const T* past_key, const T* past_value, int batch_size, int kv_ // We assume all sequence in past kv are right-padded to max or past sequence length past_sequence_length = static_cast(past_key_dims[2]); - if (past_key_dims[3] != head_size) { + // For 4-bit quantized KV cache, actual dimension is head_size / 2 because 2 nibbles are packed into one byte. + // Note that we have checked that head_size is a multiple of 8 in Check_QKV. + int packed_head_size = (kv_cache_bit_width == 4) ? 
(head_size / 2) : head_size; + if (past_key_dims[3] != packed_head_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_key' dimension 3 should be same as head_size, got ", - past_key_dims[3]); + past_key_dims[3], " expected ", packed_head_size); } - if (past_value_dims[3] != head_size) { + if (past_value_dims[3] != packed_head_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_value' dimension 3 should be same as head_size, got ", - past_value_dims[3]); + past_value_dims[3], " expected ", packed_head_size); } return Status::OK(); } @@ -203,7 +206,8 @@ Status CheckInputs(const T* query, const T* seqlens_k, const T* total_seqlen, float scale, - float softcap) { + float softcap, + int kv_cache_bit_width) { // Note: Here S* is seqlen_past_kv_cache, S+ is seqlen_present_kv_cache // past_key : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr // past_value : (B, N_k, S*, H) or (B, N_k, S+, H) or nullptr @@ -221,6 +225,18 @@ Status CheckInputs(const T* query, num_heads % kv_num_heads); } +#ifdef USE_INT4_KV_CACHE + if (kv_cache_bit_width != 0 && kv_cache_bit_width != 4 && kv_cache_bit_width != 8) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "kv_cache_bit_width must be 0, 4 or 8. Got kv_cache_bit_width == ", kv_cache_bit_width); + } +#else + if (kv_cache_bit_width != 0 && kv_cache_bit_width != 8) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "kv_cache_bit_width must be 0 or 8. Got kv_cache_bit_width == ", kv_cache_bit_width); + } +#endif + int batch_size = 0; int sequence_length = 0; int q_hidden_size = 0; @@ -239,7 +255,7 @@ Status CheckInputs(const T* query, // Check past-present KV int32_t past_sequence_length = 0; if (past_key != nullptr && past_value != nullptr) { - ORT_RETURN_IF_ERROR(CheckPast(past_key, past_value, batch_size, kv_num_heads, head_size, past_sequence_length)); + ORT_RETURN_IF_ERROR(CheckPast(past_key, past_value, batch_size, kv_num_heads, head_size, kv_cache_bit_width, past_sequence_length)); } else if (past_key != nullptr || past_value != nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'past_key' and 'past_value' shall be both present or both absent."); @@ -329,12 +345,13 @@ Status CheckInputs(const T* query, const T* total_seqlen, float scale, float softcap, + int kv_cache_bit_width, int max_threads_per_block) { if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); } - return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale, softcap); + return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, scale, softcap, kv_cache_bit_width); } template diff --git a/onnxruntime/contrib_ops/cpu/utils/debug_macros.h b/onnxruntime/contrib_ops/cpu/utils/debug_macros.h index 415612582ee4b..9669e0b8622c0 100644 --- a/onnxruntime/contrib_ops/cpu/utils/debug_macros.h +++ b/onnxruntime/contrib_ops/cpu/utils/debug_macros.h @@ -46,11 +46,11 @@ #define DUMP_TENSOR_D(...) #endif -#if (defined(__GNUC__) || defined(__clang__)) && !defined(NDEBUG) -#define DEBUG_PRINTF(fmt, ...) \ +#if (defined(__GNUC__) || defined(__clang__)) && (DUMP_TENSOR_LEVEL > 0) +#define DUMP_PRINTF(fmt, ...) \ std::printf("[DEBUG] " fmt "\n", ##__VA_ARGS__) #else -#define DEBUG_PRINTF(fmt, ...) 
\ - do { \ +#define DUMP_PRINTF(fmt, ...) \ + do { \ } while (0) #endif diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h index 1622bb6622412..486bf05bd86d5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h @@ -145,18 +145,21 @@ struct PackedMultiHeadAttentionData { bool use_memory_efficient_attention; }; -template +template struct GroupQueryAttentionData { // Input Tensors const T* query = nullptr; const T* key = nullptr; const T* value = nullptr; - const T* past_key = nullptr; - const T* past_value = nullptr; + const U* past_key = nullptr; + const U* past_value = nullptr; const T* cos_cache = nullptr; const T* sin_cache = nullptr; const T* head_sink = nullptr; + const float* k_scale = nullptr; + const float* v_scale = nullptr; + // Total sequence length for each batch. It has shape [batch_size]. int* total_seq_lens = nullptr; @@ -186,13 +189,18 @@ struct GroupQueryAttentionData { // Output Tensors T* output = nullptr; - void* present_key = nullptr; - void* present_value = nullptr; + U* present_key = nullptr; + U* present_value = nullptr; // Kernel Flags bool use_flash_attention = false; bool use_memory_efficient_attention = false; bool use_flash_attention_fast_decode = false; + bool use_xqa = false; + + // XQA buffer + void* xqa_buffer = nullptr; + size_t xqa_buffer_bytes = 0; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 83f94a31d1786..7aed9fe10afbd 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -145,12 +145,6 @@ bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_hea template bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_heads, size_t num_heads_k) { #ifdef ORT_QUICK_BUILD - // In quick build mode, only fp16 flash attention is built - constexpr bool is_bf16 = std::is_same::value; - if (is_bf16) { - return false; - } - if (head_size != 128) { return false; } diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h index eb1d6501a80f6..0e6c733313f06 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h @@ -66,14 +66,6 @@ #define LOCAL_SWITCH BOOL_SWITCH #endif -#ifdef ORT_QUICK_BUILD -// Quick build mode: only fp16 kernels are compiled -#define FP16_SWITCH(COND, ...) \ - [&] { \ - using elem_type = cutlass::half_t; \ - return __VA_ARGS__(); \ - }() -#else #define FP16_SWITCH(COND, ...) \ [&] { \ if (COND) { \ @@ -84,7 +76,6 @@ return __VA_ARGS__(); \ } \ }() -#endif #ifdef ORT_QUICK_BUILD // Quick build mode: only hdim128 kernels are compiled diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 29ef660e562e0..39154ca395fc1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -1,15 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include #include #include #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cuda_type_conversion.h" #include "core/platform/env_var_utils.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/group_query_attention.h" #include "contrib_ops/cpu/bert/group_query_attention_helper.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include "contrib_ops/cuda/bert/xqa/xqa_loader.h" #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" #include "contrib_ops/cpu/utils/debug_macros.h" @@ -21,23 +24,49 @@ namespace onnxruntime { namespace contrib { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - GroupQueryAttention, \ - kMSDomain, \ - 1, \ - T, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("M", {DataTypeImpl::GetTensorType()}) \ - .MayInplace(3, 1) \ - .MayInplace(4, 2) \ - .InputMemoryType(OrtMemTypeCPUInput, 6), \ - GroupQueryAttention); - -REGISTER_KERNEL_TYPED(MLFloat16) -REGISTER_KERNEL_TYPED(BFloat16) +namespace { +// Map string attribute to quantization type enum +KVQuantizationType StringToKVQuantizationType(std::string s) { + std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::toupper(c); }); + if (s == "NONE") { + return KVQuantizationType::NONE; + } + if (s == "PER_TENSOR") { + return KVQuantizationType::PER_TENSOR; + } + + if (s == "PER_CHANNEL") { + return KVQuantizationType::PER_CHANNEL; + } + return KVQuantizationType::NONE; +} +} // namespace + +#define REGISTER_KERNEL_TYPED(T, U) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GroupQueryAttention, \ + kMSDomain, \ + 1, \ + T##_##U, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T_CACHE", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T_KV_SCALE", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", {DataTypeImpl::GetTensorType()}) \ + .MayInplace(3, 1) /* past_key and present_key */ \ + .MayInplace(4, 2) /* past_value and present_value */ \ + .InputMemoryType(OrtMemTypeCPUInput, 6), /* total_sequence_length */ \ + GroupQueryAttention); + +REGISTER_KERNEL_TYPED(MLFloat16, MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16, BFloat16) +REGISTER_KERNEL_TYPED(MLFloat16, int8_t) +REGISTER_KERNEL_TYPED(BFloat16, int8_t) +#ifdef USE_INT4_KV_CACHE +REGISTER_KERNEL_TYPED(MLFloat16, uint8_t) +REGISTER_KERNEL_TYPED(BFloat16, uint8_t) +#endif constexpr const char* kDisableFlashDecode = "ORT_DISABLE_FLASH_DECODE"; @@ -51,8 +80,8 @@ constexpr const char* kDisableFlashDecode = "ORT_DISABLE_FLASH_DECODE"; // - Quantized KV Cache (Int8/Int4) via GroupQueryAttentionData // - Flash Attention and Memory Efficient Attention backends // -template -GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) +template +GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) : CudaKernel(info) { int64_t num_heads = 0; int64_t kv_num_heads = 0; @@ -69,6 +98,14 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) softcap_ = info.GetAttrOrDefault("softcap", 0.0f); use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1; + k_quant_type_ = StringToKVQuantizationType(info.GetAttrOrDefault("k_quant_type", "NONE")); + v_quant_type_ = StringToKVQuantizationType(info.GetAttrOrDefault("v_quant_type", "NONE")); + 
kv_cache_bit_width_ = static_cast(info.GetAttrOrDefault("kv_cache_bit_width", 0)); + + bool is_quantized = (k_quant_type_ != KVQuantizationType::NONE || v_quant_type_ != KVQuantizationType::NONE); + int default_enable_xqa = is_quantized ? 1 : 0; + enable_xqa_ = (std::is_same_v || std::is_same_v) && ParseEnvironmentVariableWithDefault("ORT_ENABLE_XQA", default_enable_xqa) != 0; + kernel_options_ = this->GetAttentionKernelOptions(); disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention(); @@ -99,8 +136,8 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) // 9. position_ids (Tensor) - Position indices for RoPE // 10. attention_bias (Tensor) - Not supported in this kernel // 11. head_sink (Tensor) - Attention sink for GPT-OSS -template -Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { +template +Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* query = context->Input(0); const Tensor* key = context->Input(1); const Tensor* value = context->Input(2); @@ -119,6 +156,35 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* position_ids = context->Input(9); const Tensor* attention_bias = context->Input(10); const Tensor* head_sink = context->Input(11); + const Tensor* k_scale = context->Input(12); + const Tensor* v_scale = context->Input(13); + + if (k_quant_type_ != KVQuantizationType::NONE) { + if (k_scale == nullptr) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "k_scale must be provided when k_quant_type is not NONE"); + } + + if (k_scale->DataType() != DataTypeImpl::GetType()) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "k_scale must be float tensor"); + } + } + + if (v_quant_type_ != KVQuantizationType::NONE) { + if (v_scale == nullptr) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "v_scale must be provided when v_quant_type is not NONE"); + } + if (v_scale->DataType() != DataTypeImpl::GetType()) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "v_scale must be float tensor"); + } + } if (attention_bias != nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -127,8 +193,10 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { auto& device_prop = GetDeviceProp(); GroupQueryAttentionParameters parameters; - typedef typename ToCudaType::MappedType CudaT; - GroupQueryAttentionData data; + + typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; + typedef typename onnxruntime::cuda::OrtToCudaType::type CudaU; + GroupQueryAttentionData data; ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, key, @@ -144,6 +212,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { total_seqlen, scale_, softcap_, + kv_cache_bit_width_, device_prop.maxThreadsPerBlock)); ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids, @@ -155,6 +224,9 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr; parameters.zeros_count = kZerosCount; parameters.zero_ptr = zeros_.get(); + parameters.k_quant_type = k_quant_type_; + parameters.v_quant_type = v_quant_type_; + parameters.kv_cache_bit_width = kv_cache_bit_width_; parameters.do_rotary = do_rotary_; parameters.rotary_interleaved = rotary_interleaved_; @@ -176,8 +248,11 @@ Status 
GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { Tensor* output = context->Output(0, output_shape); // Set up present KV output shapes + // For 4-bit quantization, we pack two 4-bit values into one uint8 byte. + // Therefore, the dense head size in the tensor shape is halved (rounded up). + int dense_head_size = (parameters.kv_cache_bit_width == 4) ? (parameters.head_size + 1) / 2 : parameters.head_size; std::vector present_dims = { - parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, parameters.head_size}; + parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, dense_head_size}; TensorShape present_shape(present_dims); context->Output(1, present_shape); // present_key @@ -203,24 +278,103 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.value = value == nullptr ? nullptr : reinterpret_cast(value->Data()); // Handle Past/Present pointers - data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data()); - data.present_key = reinterpret_cast(context->Output(1)->MutableData()); - data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast(past_value->Data()); - data.present_value = reinterpret_cast(context->Output(2)->MutableData()); + data.k_scale = k_scale == nullptr ? nullptr : reinterpret_cast(k_scale->DataRaw()); + data.v_scale = v_scale == nullptr ? nullptr : reinterpret_cast(v_scale->DataRaw()); + + data.past_key = (past_key == nullptr) ? nullptr : reinterpret_cast(past_key->Data()); + data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast(past_value->Data()); + + data.present_key = reinterpret_cast(context->Output(1)->MutableData()); + data.present_value = reinterpret_cast(context->Output(2)->MutableData()); // Compute past_present_share_buffer early since it's needed for flash attention path selection. // This compares the final pointer values after quantization handling. parameters.past_present_share_buffer = (data.past_key == data.present_key); + bool is_inputs_quantized = (k_quant_type_ != KVQuantizationType::NONE) || (v_quant_type_ != KVQuantizationType::NONE); + + // Allocate XQA scratch if needed (only for Flash Decoding path) + IAllocatorUniquePtr xqa_scratch_buffer; + // Check conditions to enable XQA (Extreme Query Attention) kernel for optimized decoding. + // XQA is a highly optimized kernel for generation phase (seq_len=1). + // Constraints: + // 1. Compute Capability SM 8.0+ (Ampere or newer). + // 2. Not the first prompt (decoding phase). + // 3. Sequence length is 1. + // 4. Past and Present KV cache share the same buffer (required for XQA specific memory access). + // 5. No Softcap (XQA doesn't support softcap). + // 6. Standard Softmax (no smooth softmax). + // 7. No local window attention (global attention only). + if (enable_xqa_ && + (device_prop.major >= 8) && + !parameters.is_first_prompt && + parameters.sequence_length == 1 && + parameters.past_present_share_buffer && + parameters.softcap == 0.0f && + !parameters.use_smooth_softmax && + parameters.local_window_size == -1) { + int group_size = parameters.num_heads / parameters.kv_num_heads; + + bool is_int8_quantized_supported = (k_quant_type_ == KVQuantizationType::PER_TENSOR && + v_quant_type_ == KVQuantizationType::PER_TENSOR && + data.k_scale == data.v_scale && // XQA requires k_scale and v_scale to be the same. Here requires k_scale and v_scale are same tensor. 
+ parameters.kv_cache_bit_width == 8 && + (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) && + (group_size == 4 || group_size == 8 || group_size == 16 || group_size == 32)); + + bool is_non_quantized_supported = !is_inputs_quantized && + (parameters.head_size == 256 || parameters.head_size == 128 || parameters.head_size == 64) && + (64 % group_size == 0); + + data.use_xqa = (is_non_quantized_supported || is_int8_quantized_supported); + + if (data.use_xqa) { + size_t xqa_internal_bytes = onnxruntime::contrib::cuda::GetXQAScratchSize( + GetDeviceProp(), + parameters.batch_size, + parameters.num_heads, + parameters.kv_num_heads, + parameters.head_size, + parameters.seqlen_present_kv_cache, + parameters.k_quant_type != KVQuantizationType::NONE ? XqaQuantType::kInt8 : XqaQuantType::kNone, + std::is_same::value); + assert(xqa_internal_bytes > 0); + // Calculate additional scratch needed for manual RoPE/Append in ExtremeDecoding + size_t xqa_total_bytes = xqa_internal_bytes; + if (parameters.do_rotary) { + // 1. Q_rotated buffer: B * N * H * sizeof(T) (if rotary) + // 2. K_rotated buffer: B * Nk * H * sizeof(T) (if rotary) + size_t element_size = sizeof(CudaT); + size_t q_bytes = parameters.batch_size * parameters.num_heads * parameters.head_size * element_size; + size_t k_bytes = parameters.batch_size * parameters.kv_num_heads * parameters.head_size * element_size; + q_bytes = (q_bytes + 255) / 256 * 256; + k_bytes = (k_bytes + 255) / 256 * 256; + xqa_total_bytes += q_bytes + k_bytes; + } + + xqa_scratch_buffer = this->GetScratchBuffer(xqa_total_bytes, context->GetComputeStream()); + data.xqa_buffer = xqa_scratch_buffer.get(); + data.xqa_buffer_bytes = xqa_internal_bytes; + + if (parameters.do_rotary) { + data.qkv_buffer = reinterpret_cast(reinterpret_cast(data.xqa_buffer) + xqa_internal_bytes); + } + } + } + + // Compute past_present_share_buffer early since it's needed for flash attention path selection. + // This compares the final pointer values after quantization handling. + #if USE_FLASH_ATTENTION - bool use_flash_attention = !disable_flash_attention_ && + bool use_flash_attention = !data.use_xqa && + !disable_flash_attention_ && onnxruntime::flash::is_supported(device_prop, parameters.head_size, parameters.num_heads, parameters.kv_num_heads); data.use_flash_attention = use_flash_attention; - data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.past_present_share_buffer; + data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.past_present_share_buffer && !is_inputs_quantized; if (use_flash_attention) { // Allocate Flash specific buffers (Softmax LSE, Accum) @@ -245,6 +399,16 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); + auto cuda_stream = static_cast(context->GetComputeStream()->GetHandle()); + if (softmax_lse_accum_bytes > 0) { + // Initialize to 0 is fine because Flash kernel will write -inf to it if needed. + // However, the standard Flash kernel often doesn't zero it globally. 
+ CUDA_RETURN_IF_ERROR(cudaMemsetAsync(softmax_lse_accum_buffer.get(), 0, softmax_lse_accum_bytes, cuda_stream)); + } + if (out_accum_bytes > 0) { + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(out_accum_buffer.get(), 0, out_accum_bytes, cuda_stream)); + } + data.softmax_lse = reinterpret_cast(softmax_lse_buffer.get()); data.softmax_lse_accum = reinterpret_cast(softmax_lse_accum_buffer.get()); data.out_accum = reinterpret_cast(out_accum_buffer.get()); @@ -281,7 +445,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { } #if USE_MEMORY_EFFICIENT_ATTENTION - if (!data.use_flash_attention) { + if (!data.use_xqa && !data.use_flash_attention) { // Fall back to memory efficient attention. int sm = (device_prop.major * 10) + device_prop.minor; bool use_memory_efficient_attention = @@ -313,6 +477,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { // This ensures allocation logic stays in sync with kernel usage auto buffer_req = GQABufferRequirements::Compute( parameters, + data.use_xqa, data.use_flash_attention, data.use_flash_attention_fast_decode, data.use_memory_efficient_attention); @@ -351,11 +516,29 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.head_sink = reinterpret_cast(head_sink->Data()); } - cublasHandle_t cublas = GetCublasHandle(context); +#if DUMP_TENSOR_LEVEL > 0 + DUMP_TENSOR_INIT(); + // Dump Scales + if (data.k_scale) { + if (parameters.k_quant_type == KVQuantizationType::PER_TENSOR) { + DUMP_TENSOR("k_scale", data.k_scale, 1, 1); + } else if (parameters.k_quant_type == KVQuantizationType::PER_CHANNEL) { + DUMP_TENSOR("k_scale", data.k_scale, parameters.kv_num_heads, 1, parameters.head_size); + } + } + if (data.v_scale) { + if (parameters.v_quant_type == KVQuantizationType::PER_TENSOR) { + DUMP_TENSOR("v_scale", data.v_scale, 1, 1); + } else if (parameters.v_quant_type == KVQuantizationType::PER_CHANNEL) { + DUMP_TENSOR("v_scale", data.v_scale, parameters.kv_num_heads, 1, parameters.head_size); + } + } +#endif - ORT_RETURN_IF_ERROR(QkvToContext( - device_prop, cublas, context->GetComputeStream(), parameters, data)); + cublasHandle_t cublas = GetCublasHandle(context); + ORT_RETURN_IF_ERROR((QkvToContext( + device_prop, cublas, context->GetComputeStream(), parameters, data))); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 2536da9fe1379..75a613724e746 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -15,7 +15,7 @@ namespace cuda { using namespace onnxruntime::cuda; -template +template class GroupQueryAttention final : public CudaKernel { public: GroupQueryAttention(const OpKernelInfo& info); @@ -35,6 +35,11 @@ class GroupQueryAttention final : public CudaKernel { bool disable_flash_attention_; bool disable_memory_efficient_attention_; bool disable_flash_decode_; + bool enable_xqa_; + + KVQuantizationType k_quant_type_; + KVQuantizationType v_quant_type_; + int kv_cache_bit_width_; static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256) IAllocatorUniquePtr zeros_; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 0b6da63b31af6..e0fa993db29bd 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ 
b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -42,11 +42,14 @@ limitations under the License. #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cpu/bert/attention_common.h" #include "contrib_ops/cuda/bert/group_query_attention_qkv.cuh" +#include "contrib_ops/cuda/bert/group_query_attention_qdq.cuh" +#include "contrib_ops/cuda/bert/xqa/xqa_loader.h" #include "contrib_ops/cuda/bert/rotary_embedding_impl.h" #include "contrib_ops/cuda/bert/rotary_common.cuh" #include "contrib_ops/cuda/bert/transformer_common.h" #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" #include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cuda_type_conversion.h" #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" @@ -74,44 +77,43 @@ namespace cuda { // 3. Ensuring synchronization between past and present KV caches when necessary. // 4. Launching the UnpackRoPEQuantizeAppend kernel to unpack, apply RoPE, and update caches. // 5. Returning strict Q, K, V pointers ready for the core attention operation. -template +template Status PrepareQKV( cudaStream_t stream, const int max_threads_per_block, const GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, - const T*& q, - const T*& k, - const T*& v) { + GroupQueryAttentionData& data, + const T*& q) { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int num_heads = parameters.num_heads; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; - using CudaT = typename ToCudaType::MappedType; - CudaT* q_out = data.qkv_buffer; + typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; + typedef typename onnxruntime::cuda::OrtToCudaType::type CudaU; + CudaT* q_out = reinterpret_cast(data.qkv_buffer); if (!parameters.is_packed_qkv && !parameters.do_rotary) { q_out = nullptr; } - CudaT* k_final_ptr = reinterpret_cast(data.present_key); - CudaT* v_final_ptr = reinterpret_cast(data.present_value); - int final_max_seqlen = parameters.seqlen_present_kv_cache; - bool final_is_bnsh = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + CudaT* k = reinterpret_cast(data.present_key); + CudaT* v = reinterpret_cast(data.present_value); + int max_cache_length = parameters.seqlen_present_kv_cache; + bool is_cache_bnsh = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); if (!parameters.past_present_share_buffer) { - size_t kv_buffer_size = (size_t)batch_size * kv_num_heads * final_max_seqlen * head_size * sizeof(CudaT); + size_t kv_buffer_size = (size_t)batch_size * kv_num_heads * max_cache_length * head_size * sizeof(CudaU); CUDA_CALL_THROW(cudaMemsetAsync(data.present_key, 0, kv_buffer_size, stream)); CUDA_CALL_THROW(cudaMemsetAsync(data.present_value, 0, kv_buffer_size, stream)); } + // Copy past KV to present KV if needed if (!parameters.past_present_share_buffer && data.past_key != nullptr && parameters.seqlen_past_kv_cache > 0) { - bool is_bnsh = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); - if (is_bnsh) { - size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * head_size * sizeof(CudaT); - size_t dst_pitch = (size_t)parameters.seqlen_present_kv_cache * head_size * sizeof(CudaT); + if (is_cache_bnsh) { + size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * head_size * sizeof(CudaU); + size_t dst_pitch = (size_t)parameters.seqlen_present_kv_cache * head_size * 
sizeof(CudaU); size_t width = src_pitch; size_t height = (size_t)batch_size * kv_num_heads; @@ -120,8 +122,8 @@ Status PrepareQKV( CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_value, dst_pitch, data.past_value, src_pitch, width, height, cudaMemcpyDeviceToDevice, stream)); } else { - size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * kv_num_heads * head_size * sizeof(CudaT); - size_t dst_pitch = (size_t)parameters.seqlen_present_kv_cache * kv_num_heads * head_size * sizeof(CudaT); + size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * kv_num_heads * head_size * sizeof(CudaU); + size_t dst_pitch = (size_t)parameters.seqlen_present_kv_cache * kv_num_heads * head_size * sizeof(CudaU); size_t width = src_pitch; size_t height = (size_t)batch_size; @@ -132,17 +134,17 @@ Status PrepareQKV( } } - ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppendKV( + ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppend( parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), - q_out, k_final_ptr, v_final_ptr, + q_out, k, v, data.k_scale, data.v_scale, num_heads, kv_num_heads, head_size, sequence_length, batch_size, - final_max_seqlen, data.past_seq_lens, + max_cache_length, data.past_seq_lens, reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved, - final_is_bnsh, + is_cache_bnsh, parameters.k_quant_type, parameters.kv_cache_bit_width, stream, max_threads_per_block)); if (q_out != nullptr) { @@ -150,17 +152,16 @@ Status PrepareQKV( } else { q = reinterpret_cast(data.query); } - k = reinterpret_cast(k_final_ptr); - v = reinterpret_cast(v_final_ptr); + return Status::OK(); } ////////// Auxiliary Kernels for KV prep // Concat new to past in present. 
Supports past BSNH or past BNSH -template +template Status LaunchConcatNewToPastKVHelper(GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, const void* new_key, const void* new_value, cudaStream_t stream, @@ -190,12 +191,12 @@ Status LaunchConcatNewToPastKVHelper(GroupQueryAttentionParameters& parameters, is_bsnh, data.past_seq_lens, data.total_seq_lens, - data.past_key, - data.past_value, + reinterpret_cast(data.past_key), + reinterpret_cast(data.past_value), reinterpret_cast(new_key), reinterpret_cast(new_value), - data.present_key, - data.present_value, + reinterpret_cast(data.present_key), + reinterpret_cast(data.present_value), stream, max_threads_per_block, past_only, @@ -207,9 +208,9 @@ Status LaunchConcatNewToPastKVHelper(GroupQueryAttentionParameters& parameters, } // Concat new to kv buffer in place -template +template Status LaunchConcatKVInPlace(GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, const void* new_key, const void* new_value, bool is_new_kv_bnsh_format, @@ -230,8 +231,8 @@ Status LaunchConcatKVInPlace(GroupQueryAttentionParameters& parameters, parameters.sequence_length, reinterpret_cast(new_key), reinterpret_cast(new_value), - data.present_key, - data.present_value, + reinterpret_cast(data.present_key), + reinterpret_cast(data.present_value), is_past_kv_bnsh_format, is_new_kv_bnsh_format, stream, @@ -553,13 +554,105 @@ Status LaunchGetSequenceLengths( } // Trace function for debugging -#define ORT_GQA_TRACE(func_name) \ - DEBUG_PRINTF("[GQA %s] is_packed_qkv: %d, is_first_prompt: %d, is_subsequent_prompt: %d, past_present_share_buffer: %d", \ - func_name, \ - static_cast(parameters.is_packed_qkv), \ - static_cast(parameters.is_first_prompt), \ - static_cast(parameters.is_subsequent_prompt), \ - static_cast(parameters.past_present_share_buffer)); +#define ORT_GQA_TRACE(func_name) \ + DUMP_PRINTF("[GQA %s] is_packed_qkv: %d, is_first_prompt: %d, is_subsequent_prompt: %d, past_present_share_buffer: %d", \ + func_name, \ + static_cast(parameters.is_packed_qkv), \ + static_cast(parameters.is_first_prompt), \ + static_cast(parameters.is_subsequent_prompt), \ + static_cast(parameters.past_present_share_buffer)); + +////////// Kernels (supports right padding but not left padding) +// Use flash attention for all workloads (rotary, kv append, attention, etc.). No extra kernel is used in this path. +// Currently, only decoding or subsequent prompt can use this path. First prompt will not use this path. 
+template +Status ExtremeDecoding( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + float scale) { + ORT_GQA_TRACE("ExtremeDecoding"); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + // const int kv_sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + // bool is_causal = parameters.is_unidirectional; + // bool is_bf16 = std::is_same::value; + + typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; + bool past_bsnh = (past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + + // Ultimate Fused Preprocessing: Unpack, RoPE Q, RoPE K, Quantize K/V, Append K/V + // This replaces all manual steps (Rotate Q, Rotate K, Quantize, StridedCopy) + CudaT* q_rot_ptr = reinterpret_cast(data.qkv_buffer); + const CudaT* q_input_for_xqa = q_rot_ptr; + if (q_rot_ptr == nullptr) { + q_input_for_xqa = reinterpret_cast(data.query); + } + + ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppend( + parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), + q_rot_ptr, // unpacked_q (can be null if !do_rotary) + data.present_key, + data.present_value, + data.k_scale, + data.v_scale, + num_heads, + kv_num_heads, + head_size, + sequence_length, + batch_size, + parameters.seqlen_present_kv_cache, // max_seqlen (capacity) + data.past_seq_lens, + data.cos_cache, + data.sin_cache, + parameters.do_rotary ? parameters.rotary_dim : 0, + data.position_ids, + parameters.rotary_interleaved, + !past_bsnh, // is_cache_bnsh + parameters.k_quant_type, + parameters.kv_cache_bit_width, + stream, + device_prop.maxThreadsPerBlock)); + + // Determine workspace size for XQA + void* xqa_workspace = data.xqa_buffer; + size_t xqa_workspace_size = data.xqa_buffer_bytes; + + // 5. Launch XQA + Status status = onnxruntime::contrib::cuda::LaunchXQAKernel( + device_prop, + stream, + q_input_for_xqa, + data.present_key, + data.present_value, + data.output, + batch_size, + num_heads, + kv_num_heads, + parameters.head_size, + parameters.seqlen_present_kv_cache, // max_seq_len (Capacity) + scale, + past_bsnh, + data.past_seq_lens, + data.k_scale, // kv_cache_scale + // Map KVQuantizationType (0=NONE, 1=TENSOR, 2=CHANNEL) to XqaQuantType (0=FP16/BF16, 1=INT8, 2=FP8) + (parameters.k_quant_type == KVQuantizationType::NONE) ? onnxruntime::contrib::cuda::XqaQuantType::kNone : onnxruntime::contrib::cuda::XqaQuantType::kInt8, + xqa_workspace, + xqa_workspace_size); + + // If XQA launch fails, debugging info + + return status; +} ////////// Kernels (supports right padding but not left padding) // Use flash attention for all workloads (rotary, kv append, attention, etc.). No extra kernel is used in this path. @@ -568,12 +661,12 @@ Status LaunchGetSequenceLengths( // Use flash attention for all workloads (rotary, kv append, attention, etc.). No extra kernel is used in this path. // Currently, only decoding or subsequent prompt can use this path. First prompt will not use this path. 
-template +template Status FlashDecoding( const cudaDeviceProp& device_prop, cudaStream_t stream, GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, float scale) { assert(!parameters.is_first_prompt && parameters.past_present_share_buffer); @@ -587,7 +680,7 @@ Status FlashDecoding( const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; bool is_causal = parameters.is_unidirectional; - bool is_bf16 = std::is_same::value; + bool is_bf16 = std::is_same::value || std::is_same::value; void* query = reinterpret_cast(const_cast(data.query)); void* key; @@ -613,6 +706,9 @@ Status FlashDecoding( bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + DUMP_PRINTF("[FlashDecoding] key=%p, value=%p, present_key=%p, present_value=%p, seqlens_k=%p, is_packed_qkv=%d", + key, value, present_key, present_value, seqlens_k, static_cast(parameters.is_packed_qkv)); + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, @@ -629,12 +725,12 @@ Status FlashDecoding( // Use extra kernel(s) for unpacking, rotary and kv append. // Flash attention is used for attention only. -template +template Status FlashAttention( const cudaDeviceProp& device_prop, cudaStream_t stream, GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, float scale) { const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; @@ -646,21 +742,14 @@ Status FlashAttention( AttentionQkvFormat past_kv_format = parameters.past_kv_format; bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; bool is_causal = parameters.is_unidirectional; - bool is_bf16 = std::is_same::value; + bool is_bf16 = std::is_same::value || std::is_same::value; DUMP_TENSOR_INIT(); const T* q_prep = nullptr; - const T* k_prep = nullptr; - const T* v_prep = nullptr; - ORT_RETURN_IF_ERROR(PrepareQKV(stream, max_threads_per_block, parameters, data, q_prep, k_prep, v_prep)); + ORT_RETURN_IF_ERROR((PrepareQKV(stream, max_threads_per_block, parameters, data, q_prep))); void* query = const_cast(q_prep); - (void)k_prep; // Key/value are now processed by PrepareQKV - (void)v_prep; - - bool use_packed_for_fa = false; - void* present_key = data.present_key; void* present_value = data.present_value; @@ -677,6 +766,9 @@ Status FlashAttention( // Use padded seq lens for first prompt since mha_fwd_kvcache assumes uniform seqlen_q. int* seq_lens = parameters.is_first_prompt ? data.padded_seq_lens : data.total_seq_lens; + // After PrepareQKV, the input for flash attention is no longer packed. 
+ constexpr bool is_packed_qkv_for_flash = false; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( device_prop, stream, query, present_key, present_value, kernel_new_k, kernel_new_v, @@ -689,22 +781,214 @@ Status FlashAttention( parameters.use_smooth_softmax, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), parameters.local_window_size - 1, - parameters.rotary_interleaved, use_packed_for_fa, 0, 1)); + parameters.rotary_interleaved, is_packed_qkv_for_flash, 0, 1)); DUMP_TENSOR("Total Seq Lens", data.total_seq_lens, batch_size, 1); DUMP_TENSOR("Past Seq Lens", data.past_seq_lens, batch_size, 1); return Status::OK(); } + +// Fallback path for decoding quantized kv cache, when XQA is not usable (due to softcap, window, etc.) +// We dequantize the cache and run standard Flash Attention. +template +Status DequantizeFlashAttentionFallback( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + float scale) { + assert(!parameters.is_first_prompt); // Only support first prompt for this function. + assert(parameters.k_quant_type != KVQuantizationType::NONE || parameters.v_quant_type != KVQuantizationType::NONE); + + ORT_GQA_TRACE("DequantizeFlashAttentionFallback"); + + // We need to dequantize the entire KV cache (present_key/value) into a float/half buffer (data.qkv_buffer). + // Layout in qkv_buffer: [Q (rotated)] [K_dequantized] [V_dequantized] + typedef typename onnxruntime::cuda::OrtToCudaType::type CudaT; + CudaT* q_rot = reinterpret_cast(data.qkv_buffer); + size_t q_elements = static_cast(parameters.batch_size) * parameters.sequence_length * parameters.num_heads * parameters.head_size; + size_t k_elements = static_cast(parameters.batch_size) * parameters.seqlen_present_kv_cache * parameters.kv_num_heads * parameters.head_size; + CudaT* k_dequant = q_rot + q_elements; + CudaT* v_dequant = k_dequant + k_elements; + + // Step 1: Update Quantized Cache + // We can use LaunchUnpackRoPEQuantizeAppend to unpack new QKV, apply RoPE, and append to quantized cache. + // This will also put rotated Q into q_rot. + ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppend( + parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), + q_rot, data.present_key, data.present_value, data.k_scale, data.v_scale, + parameters.num_heads, parameters.kv_num_heads, parameters.head_size, parameters.sequence_length, parameters.batch_size, + parameters.seqlen_present_kv_cache, data.past_seq_lens, + reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), + parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved, + (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH), + parameters.k_quant_type, parameters.kv_cache_bit_width, + stream, device_prop.maxThreadsPerBlock)); + + // Step 2: Dequantize Entire Cache + // We now have the updated quantized cache in data.present_key/value. We need to dequantize it to k_dequant/v_dequant. 
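  // Illustrative sizing aside for the [Q][K_dequant][V_dequant] scratch partition set up above in this
  // decode-only fallback (it asserts !is_first_prompt). The configuration is a hypothetical single-batch
  // decode step; all numbers are assumptions chosen only to make the arithmetic concrete.
  {
    constexpr int64_t kExB = 1, kExNewTokens = 1, kExNumHeads = 32, kExKvNumHeads = 8;
    constexpr int64_t kExHeadSize = 128, kExCacheCapacity = 4096, kExElemSize = 2;  // fp16/bf16 elements
    constexpr int64_t kExQElements = kExB * kExNewTokens * kExNumHeads * kExHeadSize;            // 4,096
    constexpr int64_t kExKElementsFull = kExB * kExCacheCapacity * kExKvNumHeads * kExHeadSize;  // 4,194,304
    // Matches the elem_size * (q_elements + 2 * k_elements_full) reservation (plus 256 bytes of alignment
    // slack) made for this fallback in GQABufferRequirements::Compute.
    static_assert(kExElemSize * (kExQElements + 2 * kExKElementsFull) == 16'785'408,
                  "about 16 MiB of dequantization scratch for this example");
  }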
+ bool is_bsnh = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); + + if (parameters.kv_cache_bit_width == 8) { + ORT_RETURN_IF_ERROR((LaunchDequantizeKV( + stream, k_dequant, reinterpret_cast(data.present_key), data.k_scale, + nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, + parameters.head_size, 8, parameters.k_quant_type, is_bsnh))); + + ORT_RETURN_IF_ERROR((LaunchDequantizeKV( + stream, v_dequant, reinterpret_cast(data.present_value), data.v_scale, + nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, + parameters.head_size, 8, parameters.v_quant_type, is_bsnh))); +#ifdef USE_INT4_KV_CACHE + } else if (parameters.kv_cache_bit_width == 4) { + // Int4 support if needed + ORT_RETURN_IF_ERROR((LaunchDequantizeKV( + stream, k_dequant, reinterpret_cast(data.present_key), data.k_scale, + nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, + parameters.head_size, 4, parameters.k_quant_type, is_bsnh))); + + ORT_RETURN_IF_ERROR((LaunchDequantizeKV( + stream, v_dequant, reinterpret_cast(data.present_value), data.v_scale, + nullptr, parameters.batch_size, parameters.kv_num_heads, parameters.seqlen_present_kv_cache, + parameters.head_size, 4, parameters.v_quant_type, is_bsnh))); +#endif + } + + // Step 3: Run Flash Attention on dequantized k/v + bool is_causal = parameters.is_unidirectional; + bool is_bf16 = std::is_same::value || std::is_same::value; + + // Use the total_seq_lens here since k_dequant/v_dequant has both past and new tokens. + void* seqlens_k_ptr = const_cast(reinterpret_cast(data.total_seq_lens)); + int local_window_size = parameters.local_window_size > 0 ? parameters.local_window_size - 1 : -1; + + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( + device_prop, stream, q_rot, k_dequant, v_dequant, nullptr /*new K*/, nullptr /*new V*/, data.output, + reinterpret_cast(data.softmax_lse), seqlens_k_ptr, nullptr /*cos_cache*/, nullptr /*sin_cache*/, + /*cache_batch_idx*/ nullptr, /*leftpad_k*/ nullptr, reinterpret_cast(const_cast(data.head_sink)), /*block_table*/ nullptr, + parameters.batch_size, parameters.num_heads, parameters.kv_num_heads, parameters.head_size, parameters.sequence_length, + parameters.seqlen_present_kv_cache, parameters.sequence_length, 0 /*rotary_dim = 0 as it is already rotated*/, + scale, parameters.softcap, is_causal, is_bf16, parameters.use_smooth_softmax, is_bsnh, parameters.num_splits, + reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), + local_window_size, parameters.rotary_interleaved, false, + 0, 1)); + + return Status::OK(); +} + +// Use Flash Attention for float key and value, then quantize key/value to int8 to save to k/v cache. +template +Status FlashAttentionAndQuantizeKV( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + float scale) { + assert(parameters.is_first_prompt); // Only support first prompt for this function. 
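  // Numeric aside for the symmetric int8 scheme used when writing the cache further below.
  // The scale and activation values are illustrative assumptions, not taken from this patch; the kernel
  // itself uses rintf, which rounds this example to the same value.
  {
    constexpr float kExScale = 0.05f;
    constexpr float kExValue = 1.337f;
    constexpr int kExQuantized = static_cast<int>(kExValue / kExScale + 0.5f);  // round-to-nearest for positive x
    static_assert(kExQuantized == 27, "1.337 / 0.05 = 26.74 rounds to 27, well inside [-128, 127]");
    static_assert(kExQuantized * kExScale > 1.34f && kExQuantized * kExScale < 1.36f,
                  "the dequantized value 1.35 differs from 1.337 by at most scale / 2");
  }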
+ assert(parameters.k_quant_type != KVQuantizationType::NONE || parameters.v_quant_type != KVQuantizationType::NONE); + + const int max_threads_per_block = device_prop.maxThreadsPerBlock; + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_num_heads = parameters.kv_num_heads; + const int num_heads = parameters.num_heads; + const int head_size = parameters.head_size; + + ORT_GQA_TRACE("FlashAttentionAndQuantizeKV"); + + bool past_bsnh = parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + + size_t q_elements = static_cast(batch_size) * sequence_length * num_heads * head_size; + size_t k_elements = static_cast(batch_size) * sequence_length * kv_num_heads * head_size; + + using CudaT = typename onnxruntime::cuda::OrtToCudaType::type; + CudaT* q_final = reinterpret_cast(data.qkv_buffer); + + // For FlashAttentionAndQuantizeKV, we need float K and V for attention. + // We'll write them to qkv_buffer. + CudaT* k_final = q_final + q_elements; + CudaT* v_final = k_final + k_elements; + + ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppend( + parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.value), + q_final, k_final, v_final, nullptr, nullptr, + num_heads, kv_num_heads, head_size, sequence_length, batch_size, + sequence_length, data.past_seq_lens, + reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), + parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved, + false, // BSNH for scratch + KVQuantizationType::NONE, + 0, // bit_width is 0 since we are not quantizing here. + stream, max_threads_per_block)); + + // 2. Run Float Flash Attention + bool is_causal = parameters.is_unidirectional; + bool is_bf16 = std::is_same::value || std::is_same::value; + + int local_window_size = parameters.local_window_size > 0 ? parameters.local_window_size - 1 : -1; + + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( + device_prop, stream, q_final, k_final, v_final, data.output, + reinterpret_cast(data.softmax_lse), + batch_size, num_heads, kv_num_heads, head_size, sequence_length, sequence_length, + scale, parameters.softcap, is_causal, is_bf16, parameters.use_smooth_softmax, + parameters.num_splits, + reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum), + true, // kv_bsnh = true (BSNH) + local_window_size)); + + // 3. 
Quantize K and V to present cache + if (parameters.k_quant_type != KVQuantizationType::NONE) { + if (parameters.kv_cache_bit_width == 8) { + ORT_RETURN_IF_ERROR((LaunchQuantizeKV( + stream, reinterpret_cast(data.present_key), reinterpret_cast(k_final), data.k_scale, + nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, + head_size, 8, parameters.k_quant_type, true, past_bsnh))); +#ifdef USE_INT4_KV_CACHE + } else if (parameters.kv_cache_bit_width == 4) { + ORT_RETURN_IF_ERROR((LaunchQuantizeKV( + stream, reinterpret_cast(data.present_key), reinterpret_cast(k_final), data.k_scale, + nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, + head_size, 4, parameters.k_quant_type, true, past_bsnh))); +#endif + } + } + + if (parameters.v_quant_type != KVQuantizationType::NONE) { + if (parameters.kv_cache_bit_width == 8) { + ORT_RETURN_IF_ERROR((LaunchQuantizeKV( + stream, reinterpret_cast(data.present_value), reinterpret_cast(v_final), data.v_scale, + nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, + head_size, 8, parameters.v_quant_type, true, past_bsnh))); +#ifdef USE_INT4_KV_CACHE + } else if (parameters.kv_cache_bit_width == 4) { + ORT_RETURN_IF_ERROR((LaunchQuantizeKV( + stream, reinterpret_cast(data.present_value), reinterpret_cast(v_final), data.v_scale, + nullptr, data.total_seq_lens, batch_size, kv_num_heads, sequence_length, parameters.seqlen_present_kv_cache, + head_size, 4, parameters.v_quant_type, true, past_bsnh))); +#endif + } + } + + return Status::OK(); +} #endif #if USE_MEMORY_EFFICIENT_ATTENTION -template +template Status EfficientAttention( const cudaDeviceProp& device_prop, cudaStream_t stream, GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data, + GroupQueryAttentionData& data, float scale) { const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int batch_size = parameters.batch_size; @@ -718,15 +1002,12 @@ Status EfficientAttention( ORT_GQA_TRACE("EfficientAttention"); const T* q_prep = nullptr; - const T* k_prep = nullptr; - const T* v_prep = nullptr; - ORT_RETURN_IF_ERROR(PrepareQKV(stream, max_threads_per_block, parameters, data, q_prep, k_prep, v_prep)); + ORT_RETURN_IF_ERROR((PrepareQKV(stream, max_threads_per_block, parameters, data, q_prep))); const void* query = reinterpret_cast(q_prep); - const void* key = reinterpret_cast(k_prep); - const void* value = reinterpret_cast(v_prep); - - const bool is_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + const void* key; + const void* value; + const bool is_kv_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; if (num_heads == kv_num_heads) { // Use present kv directly if not grouped key = reinterpret_cast(data.present_key); @@ -737,15 +1018,16 @@ Status EfficientAttention( float2* v_buff = reinterpret_cast(data.v); const float2* k_og = reinterpret_cast(data.present_key); const float2* v_og = reinterpret_cast(data.present_value); + ORT_RETURN_IF_ERROR(LaunchUngroup(parameters, k_buff, v_buff, k_og, v_og, present_sequence_length, - present_sequence_length, is_bsnh, stream, max_threads_per_block)); + present_sequence_length, is_kv_bsnh, stream, max_threads_per_block)); key = reinterpret_cast(data.k); value = reinterpret_cast(data.v); } MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; - p.is_bf16 = std::is_same::value; + p.is_bf16 = std::is_same::value || 
std::is_same::value; p.is_half = !p.is_bf16 && (sizeof(T) == 2); p.batch_size = batch_size; p.num_heads = num_heads; @@ -764,7 +1046,7 @@ Status EfficientAttention( p.key = key; p.value = value; p.attn_bias = nullptr; - p.is_kv_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + p.is_kv_bsnh = is_kv_bsnh; p.output = data.output; p.workspace = MemoryEfficientAttentionParams::need_workspace(p.v_head_size, sizeof(T) == sizeof(float)) ? data.fmha_buffer @@ -781,15 +1063,18 @@ Status EfficientAttention( ////////// API Functions -template +template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& /*cublas*/, Stream* ort_stream, GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data) { + GroupQueryAttentionData& data) { auto stream = static_cast(ort_stream->GetHandle()); const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(parameters.head_size)) : parameters.scale; + if (data.use_xqa) { + return ExtremeDecoding(device_prop, stream, parameters, data, scale); + } #if USE_FLASH_ATTENTION if (data.use_flash_attention_fast_decode) { @@ -797,6 +1082,14 @@ Status QkvToContext( } if (data.use_flash_attention) { + if (parameters.k_quant_type != KVQuantizationType::NONE || parameters.v_quant_type != KVQuantizationType::NONE) { + if (parameters.is_first_prompt) { + return FlashAttentionAndQuantizeKV(device_prop, stream, parameters, data, scale); + } else { + return DequantizeFlashAttentionFallback(device_prop, stream, parameters, data, scale); + } + } + return FlashAttention(device_prop, stream, parameters, data, scale); } #endif @@ -810,24 +1103,62 @@ Status QkvToContext( return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unfused Group Query Attention not implemented yet."); } -template struct GroupQueryAttentionData; -template struct GroupQueryAttentionData; +template struct GroupQueryAttentionData; +template struct GroupQueryAttentionData<__nv_bfloat16, __nv_bfloat16>; +template struct GroupQueryAttentionData; + +template Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data); + +template Status QkvToContext<__nv_bfloat16, __nv_bfloat16>( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData<__nv_bfloat16, __nv_bfloat16>& data); + +template Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data); + +template struct GroupQueryAttentionData<__nv_bfloat16, int8_t>; -template Status QkvToContext( +template Status QkvToContext<__nv_bfloat16, int8_t>( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, Stream* ort_stream, contrib::GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data); + GroupQueryAttentionData<__nv_bfloat16, int8_t>& data); -template Status QkvToContext( +template struct GroupQueryAttentionData; + +template Status QkvToContext( + const cudaDeviceProp& device_prop, + cublasHandle_t& cublas, + Stream* ort_stream, + contrib::GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data); + +template struct GroupQueryAttentionData<__nv_bfloat16, uint8_t>; + +template Status QkvToContext<__nv_bfloat16, uint8_t>( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, Stream* ort_stream, 
contrib::GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data); + GroupQueryAttentionData<__nv_bfloat16, uint8_t>& data); template Status LaunchUnpackQKV(const half* packed_qkv, half* unpacked_q, half* unpacked_k, half* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block); +template Status LaunchUnpackQKV<__nv_bfloat16, LAYOUT_BNSH>(const __nv_bfloat16* packed_qkv, __nv_bfloat16* unpacked_q, __nv_bfloat16* unpacked_k, __nv_bfloat16* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block); + +// BFloat16 variant is used in sparse attention. template Status LaunchUnpackQKV(const BFloat16* packed_qkv, BFloat16* unpacked_q, BFloat16* unpacked_k, BFloat16* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block); } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index 4ad71c5003e0e..8cd4b44b9832e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -15,13 +15,13 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template +template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, Stream* stream, contrib::GroupQueryAttentionParameters& parameters, - GroupQueryAttentionData& data); + GroupQueryAttentionData& data); template Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads, @@ -46,6 +46,7 @@ struct GQABufferRequirements { template static GQABufferRequirements Compute( const GroupQueryAttentionParameters& params, + bool use_xqa, bool use_flash_attention, bool use_flash_attention_fast_decode, bool use_memory_efficient_attention) { @@ -66,6 +67,15 @@ struct GQABufferRequirements { const size_t k_elements = batch_size * seq_len * kv_num_heads * head_size; const size_t v_elements = k_elements; + if (use_xqa) { + if (params.do_rotary || params.is_packed_qkv) { + // XQA need scratch for rotated/unpacked Q. + // RoPE K is written directly to cache by the fused kernel. + req.qkv_buffer_bytes = elem_size * q_elements; + } + return req; + } + if (use_flash_attention) { // Flash Attention path: // qkv_buffer is used for: @@ -75,11 +85,25 @@ struct GQABufferRequirements { // Logic: // - we generally only need Q buffer (for rotary Q) if we can write K/V directly to cache/output. - if (params.do_rotary || params.is_packed_qkv) { - // Just Q buffer needed for rotation/unpacking. - // K and V are written directly to present_key/value (unpacked/rotated/quantized/appended). 
+ bool is_quantized = params.k_quant_type != KVQuantizationType::NONE || + params.v_quant_type != KVQuantizationType::NONE; + + if (is_quantized) { + if (!params.is_first_prompt) { + // Decoding fallback: need full cache scratch for dequantization + // We need space for Q (rotated) + K (dequantized full) + V (dequantized full) + // Q is sequence_length (1), K/V are seqlen_present_kv_cache (Capacity) + const size_t k_elements_full = batch_size * static_cast(params.seqlen_present_kv_cache) * kv_num_heads * head_size; + // Align to 256 bytes for good measure + size_t total_bytes = elem_size * (q_elements + 2 * k_elements_full) + 256; + req.qkv_buffer_bytes = total_bytes; + } else { + req.qkv_buffer_bytes = elem_size * (q_elements + k_elements + v_elements); + } + } else if (params.do_rotary || params.is_packed_qkv) { req.qkv_buffer_bytes = elem_size * q_elements; } + } else if (use_memory_efficient_attention) { // Memory Efficient Attention path: // - qkv_buffer: for unpacking packed QKV or Q rotation @@ -110,14 +134,16 @@ Status LaunchGetSequenceLengths( const int max_threads_per_block); template -Status LaunchUnpackRoPEAppendKV( +Status LaunchUnpackRoPEAppend( const T* packed_qkv, const T* query, const T* key, const T* value, - T* unpacked_q, T* k_cache, T* v_cache, + T* unpacked_q, void* k_cache, void* v_cache, + const float* k_scale, const float* v_scale, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, const int max_seqlen, const int* past_seq_lens, const T* cos_cache, const T* sin_cache, const int rotary_dim, const int64_t* position_ids, const bool interleaved, - const bool is_cache_bnsh, cudaStream_t stream, const int max_threads_per_block); + const bool is_cache_bnsh, const KVQuantizationType k_quant_type, + const int bit_width, cudaStream_t stream, const int max_threads_per_block); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh new file mode 100644 index 0000000000000..3aa9d6d96789a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qdq.cuh @@ -0,0 +1,385 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#pragma once + +// Enable quantized KV cache support for INT8/INT4 +#define KV_QUANT_SUPPORTED 1 + +#include + +#include "contrib_ops/cuda/bert/group_query_attention_impl.h" +#include "contrib_ops/cpu/bert/attention_common.h" +#include "contrib_ops/cuda/bert/rotary_common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Constants for quantization bounds +constexpr int kInt4Min = -8; +constexpr int kInt4Max = 7; +constexpr int kInt8Min = -128; +constexpr int kInt8Max = 127; +constexpr int kInt4ZeroPacked = 0x88; // (0 + 8) | ((0 + 8) << 4) for INT4 zero padding +constexpr int kThreadsPerBlock = 256; + +template +struct TypeConverter { + __device__ static float to_float(T val) { return static_cast(val); } +}; + +template <> +struct TypeConverter { + __device__ static float to_float(half val) { return __half2float(val); } +}; + +template <> +struct TypeConverter<__nv_bfloat16> { + __device__ static float to_float(__nv_bfloat16 val) { return __bfloat162float(val); } +}; + +// ============================================================================ +// KV Cache Quantization/Dequantization Kernels +// ============================================================================ +// +// This file implements symmetric quantization for KV cache in GroupQueryAttention. +// Supports INT4 and INT8 with PER_TENSOR and PER_CHANNEL quantization modes. +// +// QUANTIZATION SCHEME: +// ------------------- +// INT4: Symmetric signed quantization +// - Range: [-8, 7] (signed 4-bit) +// - Formula: q = clamp(round(x / scale), -8, 7) +// - Rounding: Round-to-nearest (rintf) +// - Saturation: Clamp to [-8, 7] +// +// INT8: Symmetric signed quantization +// - Range: [-128, 127] (signed 8-bit) +// - Formula: q = clamp(round(x / scale), -128, 127) +// - Rounding: Round-to-nearest (rintf) +// - Saturation: Clamp to [-128, 127] +// +// BIT PACKING (INT4 only): +// ----------------------- +// Storage format: uint8_t, 2 values per byte +// packed_byte = ((q0 + 8) & 0x0F) | (((q1 + 8) & 0x0F) << 4) +// +// Where: +// - q0 (even element) → low nibble (bits 0-3) +// - q1 (odd element) → high nibble (bits 4-7) +// - +8 bias converts signed [-8, 7] to unsigned [0, 15] +// +// For odd head_size, last element q0 is paired with q1 = 0. +// +// SCALE TENSOR FORMAT: +// ------------------- +// Scales are always FP16/BF16 (type T), never quantized. +// +// PER_TENSOR: scale[0] - single scale for entire cache +// PER_CHANNEL: scale[head_idx * head_size + elem_idx] - one scale per channel +// +// MEMORY LAYOUT: +// ------------- +// Cache: BNSH (batch, num_heads, sequence_length, head_size) +// INT4: (head_size + 1) / 2 bytes per head +// INT8: head_size bytes per head +// ============================================================================ + +// Dequantization Kernel: Converts Quantized (Int8/Int4) KV cache back to Floating Point (T). +// Iterates over every individual element with one thread per element. 
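// Worked example of the INT4 packing scheme documented above. This helper is illustrative only (it is
// not used by the kernels in this file); it spells out the pack arithmetic so the compile-time checks
// below can confirm the formula and the kInt4ZeroPacked constant.
constexpr uint8_t PackInt4PairExample(int q0, int q1) {
  // +8 bias maps signed [-8, 7] to unsigned [0, 15]; even element -> low nibble, odd element -> high nibble.
  return static_cast<uint8_t>(((q0 + 8) & 0x0F) | (((q1 + 8) & 0x0F) << 4));
}
static_assert(PackInt4PairExample(0, 0) == kInt4ZeroPacked, "two zeros pack to 0x88");
static_assert(PackInt4PairExample(kInt4Min, kInt4Max) == 0xF0, "(-8, 7) packs to low nibble 0x0, high nibble 0xF");
static_assert((PackInt4PairExample(3, -2) & 0x0F) - 8 == 3, "low nibble round-trips the even element");
static_assert((PackInt4PairExample(3, -2) >> 4) - 8 == -2, "high nibble round-trips the odd element");
// The dequantization kernel below reverses exactly this: (packed & 0x0F) - 8 and (packed >> 4) - 8.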
+template +__global__ void DequantizeKernel(T* dequantized_data, + const T_QUANT* quantized_data, + const T_SCALE* scale, const int* past_seq_lens, + int batch_size, int num_heads, + int cache_sequence_length, + int head_size, int bit_width, + KVQuantizationType quant_type, + bool is_input_bsnh) { + int64_t total_elements = static_cast(batch_size) * num_heads * cache_sequence_length * head_size; + // For BIT_WIDTH=4, each T_QUANT (uint8) holds 2 elements. + int elements_per_head_packed = (bit_width == 4) ? (head_size + 1) / 2 : head_size; + + for (int64_t i = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; i < total_elements; + i += static_cast(blockDim.x) * gridDim.x) { + int h = static_cast(i % head_size); + int s = static_cast((i / head_size) % cache_sequence_length); + int n = static_cast((i / (head_size * cache_sequence_length)) % num_heads); + int b = static_cast((i / (num_heads * head_size * cache_sequence_length))); + + // Correctly identify padding in the past_kv cache. + // In the decoding case, `seqlens` contains `past_len + new_len - 1`. + // We need the actual past_len to mask the padding correctly. + if (past_seq_lens != nullptr) { + // For a given batch entry `b`, the actual length of the past sequence is `past_seq_lens[b]`. + // If `s` (the current sequence index) is beyond this length, it's padding and should be zeroed. + if (s >= past_seq_lens[b]) { + dequantized_data[i] = static_cast(0.0f); + continue; + } + } + + float scale_val = 1.0f; + if (quant_type == KVQuantizationType::PER_TENSOR) { + scale_val = static_cast(scale[0]); + } else { // PER_CHANNEL + int64_t scale_idx = static_cast(n) * head_size + h; + scale_val = static_cast(scale[scale_idx]); + } + float quantized_float; + int64_t input_idx = static_cast(b) * num_heads * cache_sequence_length * elements_per_head_packed + + static_cast(n) * cache_sequence_length * elements_per_head_packed + + static_cast(s) * elements_per_head_packed + + (bit_width == 4 ? h / 2 : h); + + if (is_input_bsnh) { + input_idx = static_cast(b) * cache_sequence_length * num_heads * elements_per_head_packed + + static_cast(s) * num_heads * elements_per_head_packed + + static_cast(n) * elements_per_head_packed + + (bit_width == 4 ? h / 2 : h); + } + + if (bit_width == 8) { + quantized_float = static_cast( + reinterpret_cast(quantized_data)[input_idx]); +#ifdef USE_INT4_KV_CACHE + } else if (bit_width == 4) { + const uint8_t packed_val = + reinterpret_cast(quantized_data)[input_idx]; + quantized_float = (h % 2 == 0) + ? 
static_cast((packed_val & 0x0F) - 8) + : static_cast((packed_val >> 4) - 8); +#endif + } + + dequantized_data[i] = static_cast(quantized_float * scale_val); + } +} + +template +Status LaunchDequantizeKV(cudaStream_t stream, T* dequantized_data, + const T_QUANT* quantized_data, const T_SCALE* scale, + const int* past_seq_lens, int batch_size, int num_heads, + int cache_sequence_length, + int head_size, int bit_width, + KVQuantizationType quant_type, + bool is_input_bsnh) { + if (cache_sequence_length == 0) return Status::OK(); + + // Output buffer uses cache_sequence_length stride + int64_t total_elements = static_cast(batch_size) * num_heads * cache_sequence_length * head_size; + const int blocks = static_cast((total_elements + kThreadsPerBlock - 1) / kThreadsPerBlock); + DequantizeKernel<<>>( + dequantized_data, quantized_data, scale, past_seq_lens, + batch_size, num_heads, cache_sequence_length, + head_size, bit_width, quant_type, is_input_bsnh); + + return CUDA_CALL(cudaGetLastError()); +} + +// Quantization Kernel: Converts Floating Point (T) cache to Quantized (Int8/Int4) values. +// Note: This kernel is used to quantize a full input tensor, e.g. during graph initialization +// or fallback paths. The main prompt path uses the fused UnpackRoPEAppend kernel. +template +__global__ void QuantizeKernel(T_QUANT* quantized_data, + const T* dequantized_data, const T_SCALE* scale, + const int* past_seq_lens, + const int* total_seq_lens, + int total_packed_elements, + int input_sequence_length, + int cache_sequence_length, int num_heads, int head_size, + int bit_width, KVQuantizationType quant_type, + bool is_input_bsnh, + bool is_output_bsnh) { + // elements_per_head_packed is the number of BYTES occupied by head_size elements. + int elements_per_head_packed = (bit_width == 4) ? (head_size + 1) / 2 : head_size; + + for (int64_t i = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; i < total_packed_elements; + i += static_cast(blockDim.x) * gridDim.x) { + int h_packed = static_cast(i % elements_per_head_packed); + int s = static_cast((i / elements_per_head_packed) % cache_sequence_length); + int n = static_cast((i / (elements_per_head_packed * cache_sequence_length)) % num_heads); + int b = static_cast(i / (num_heads * elements_per_head_packed * cache_sequence_length)); + + // If past_seq_lens is provided, skip the past data to preserve it. + // This is useful when we are appending new data to an existing quantized cache (shared buffer). + if (past_seq_lens != nullptr) { + if (s < past_seq_lens[b]) { + continue; + } + } + + // Zero out padding in the present_kv cache. + // `total_seq_lens` provides the total valid sequence length for each batch item. + // If the current sequence index `s` is in the padded region, write zero. + int total_valid_len_b = total_seq_lens[b]; + if (s >= total_valid_len_b) { + if (bit_width == 8) { + int64_t out_idx = i; + if (is_output_bsnh) { + int64_t b_idx = b; + int64_t n_idx = n; + int64_t s_idx = s; + int64_t h_idx = i % elements_per_head_packed; + out_idx = b_idx * cache_sequence_length * num_heads * elements_per_head_packed + + s_idx * num_heads * elements_per_head_packed + + n_idx * elements_per_head_packed + + h_idx; + } + reinterpret_cast(quantized_data)[out_idx] = 0; +#ifdef USE_INT4_KV_CACHE + } else if (bit_width == 4) { // INT4 + // With packed iteration, each thread handles one byte (2 values). + // Since we are in the padding region, write a zero byte. + // For BNSH/BSNH output, we need to calculate correct index. 
+ // Memory Safety: + // We iterate up to `total_packed_elements` which matches the allocated buffer size + // (batch_size * num_heads * cache_sequence_length * elements_per_head_packed). + // Since `h_idx` comes from `i % elements_per_head_packed`, `out_idx` is guaranteed + // to be within the buffer bounds. Writing kInt4ZeroPacked is safe. + int64_t out_idx = i; + if (is_output_bsnh) { + int64_t b_idx = b; + int64_t n_idx = n; + int64_t s_idx = s; + int64_t h_idx = i % elements_per_head_packed; + out_idx = b_idx * cache_sequence_length * num_heads * elements_per_head_packed + + s_idx * num_heads * elements_per_head_packed + + n_idx * elements_per_head_packed + + h_idx; + } + // INT4 uses +8 bias, so zero values pack to 0x88 + reinterpret_cast(quantized_data)[out_idx] = kInt4ZeroPacked; +#endif + } + continue; + } + + int64_t output_idx = i; + if (is_output_bsnh) { + int64_t b_idx = b; + int64_t n_idx = n; + int64_t s_idx = s; + int64_t h_idx = i % elements_per_head_packed; + output_idx = b_idx * cache_sequence_length * num_heads * elements_per_head_packed + + s_idx * num_heads * elements_per_head_packed + + n_idx * elements_per_head_packed + + h_idx; + } + + if (bit_width == 8) { + int h = h_packed; + float scale_val = 1.0f; + if (quant_type == KVQuantizationType::PER_TENSOR) { + scale_val = static_cast(scale[0]); + } else { // PER_CHANNEL + int scale_idx = n * head_size + h; + scale_val = static_cast(scale[scale_idx]); + } + + float inv_scale = (scale_val == 0.0f) ? 0.0f : 1.0f / scale_val; + int64_t flattened_input_idx = is_input_bsnh ? ((int64_t)b * input_sequence_length * num_heads * head_size + + (int64_t)s * num_heads * head_size + + (int64_t)n * head_size + + h) + : ((int64_t)b * num_heads * input_sequence_length * head_size + + (int64_t)n * input_sequence_length * head_size + + (int64_t)s * head_size + + h); + float val_float = static_cast(dequantized_data[flattened_input_idx]) * inv_scale; + + int32_t val_int32 = static_cast(rintf(val_float)); + reinterpret_cast(quantized_data)[output_idx] = + static_cast(max(kInt8Min, min(kInt8Max, val_int32))); +#ifdef USE_INT4_KV_CACHE + } else if (bit_width == 4) { + int h0 = h_packed * 2; + int h1 = h0 + 1; + + // Compute first nibble + float scale0 = 1.0f; + if (quant_type == KVQuantizationType::PER_TENSOR) { + scale0 = static_cast(scale[0]); + } else { + scale0 = static_cast(scale[n * head_size + h0]); + } + float inv_scale0 = (scale0 == 0.0f) ? 0.0f : 1.0f / scale0; + + int64_t input_idx0 = is_input_bsnh ? ((int64_t)b * input_sequence_length * num_heads * head_size + + (int64_t)s * num_heads * head_size + + (int64_t)n * head_size + + h0) + : ((int64_t)b * num_heads * input_sequence_length * head_size + + (int64_t)n * input_sequence_length * head_size + + (int64_t)s * head_size + + h0); + float val0 = static_cast(dequantized_data[input_idx0]) * inv_scale0; + int8_t q0 = static_cast(max(static_cast(kInt4Min), min(static_cast(kInt4Max), rintf(val0)))); + + // Compute second nibble if within head_size + int8_t q1 = 0; // Default to 0 (value 0) if padded + if (h1 < head_size) { + float scale1 = 1.0f; + if (quant_type == KVQuantizationType::PER_TENSOR) { + scale1 = static_cast(scale[0]); + } else { + scale1 = static_cast(scale[n * head_size + h1]); + } + float inv_scale1 = (scale1 == 0.0f) ? 0.0f : 1.0f / scale1; + + int64_t input_idx1 = is_input_bsnh ? 
((int64_t)b * input_sequence_length * num_heads * head_size + + (int64_t)s * num_heads * head_size + + (int64_t)n * head_size + + h1) + : ((int64_t)b * num_heads * input_sequence_length * head_size + + (int64_t)n * input_sequence_length * head_size + + (int64_t)s * head_size + + h1); + float val1 = static_cast(dequantized_data[input_idx1]) * inv_scale1; + q1 = static_cast(max(static_cast(kInt4Min), min(static_cast(kInt4Max), rintf(val1)))); + } else { + // Padding for odd head_size + q1 = 0; + } + + // Pack two 4-bit values into one byte with +8 bias to convert to unsigned [0,15] + // Low nibble: q0 (even element), High nibble: q1 (odd element) + uint8_t packed = ((q0 + 8) & 0x0F) | (((q1 + 8) & 0x0F) << 4); + reinterpret_cast(quantized_data)[output_idx] = packed; +#endif + } + } +} + +template +Status LaunchQuantizeKV(cudaStream_t stream, T_QUANT* quantized_data, + const T* dequantized_data, const T_SCALE* scale, + const int* past_seq_lens, + const int* total_seq_lens, + int batch_size, int num_heads, + int input_sequence_length, int cache_sequence_length, int head_size, int bit_width, + KVQuantizationType quant_type, + bool is_input_bsnh, + bool is_output_bsnh) { + assert(total_seq_lens != nullptr); + if (cache_sequence_length == 0) return Status::OK(); + + int elements_per_head_packed = (bit_width == 4) ? (head_size + 1) / 2 : head_size; + int total_packed_elements = batch_size * num_heads * cache_sequence_length * elements_per_head_packed; + + int blocks = (total_packed_elements + kThreadsPerBlock - 1) / kThreadsPerBlock; + + QuantizeKernel<<>>( + quantized_data, dequantized_data, scale, past_seq_lens, total_seq_lens, total_packed_elements, + input_sequence_length, cache_sequence_length, num_heads, head_size, bit_width, quant_type, is_input_bsnh, is_output_bsnh); + + return CUDA_CALL(cudaGetLastError()); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh index ddf24aff27442..d5c95be316a1f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh @@ -16,28 +16,32 @@ namespace onnxruntime { namespace contrib { namespace cuda { -// Fused kernel: Unpack QKV + Apply RoPE to Q and K + Append K/V directly to cache +// Fused kernel: Unpack QKV + Apply RoPE to Q and K + Append K/V directly to cache + Quantize if needed // -// OPTIMIZATION: This version uses Shared Memory to store the current head being processed. -// Shared memory allows RoPE dispatcher to access paired elements in non-interleaved mode -// (element i pairs with i ± rotary_dim/2) without global memory gathers. +// This kernel performs the following: +// 1. Unpacks Q, K, V from input tensor(s). The input can be a single packed QKV tensor +// or three separate Q, K, V tensors. +// 2. Applies Rotary Positional Embedding (RoPE) to Q and K if rotary_dim > 0. +// 3. Appends K and V to the KV cache at the correct sequence index (past_seq_len + s). +// - Performs on-the-fly quantization (Int8 or Int4) if configured (BIT_WIDTH < 16). +// - Supports both BNSH and BSNH layouts for the KV cache. +// 4. Writes the rotated Q back to global memory (unpacked_q) for the subsequent attention kernel. // -// Alignment Note: This kernel assumes that base pointers (packed_qkv, query, etc.) -// are 16-byte aligned and that head_size is a multiple of elements_per_thread. 
-// -// Grid Layout: -// blockIdx.x: sequence index (s) -> Max 2^31-1 (Supports very long context) -// blockIdx.y: head index (head_idx) -> Max 65535 -// blockIdx.z: batch index (b) -> Max 65535 -template +// Template Parameters: +// - T: The floating point type (half or BFloat16). +// - BIT_WIDTH: The bit width for KV cache quantization (16=none, 8=Int8, 4=Int4). +// - MAX_HEAD_SIZE: Maximum supported head size, used for shared memory allocation. +template __global__ void UnpackRoPEAppend( const T* packed_qkv, const T* query, const T* key, const T* value, T* unpacked_q, - T* k_cache, - T* v_cache, + void* k_cache, + void* v_cache, + const float* k_scale, + const float* v_scale, const int num_heads, const int kv_num_heads, const int head_size, @@ -49,14 +53,20 @@ __global__ void UnpackRoPEAppend( const int rotary_dim, const int64_t* position_ids, const bool interleaved, - const bool is_cache_bnsh) { + const bool is_cache_bnsh, + const bool per_channel) { using LoadT = float4; constexpr int elements_per_thread = sizeof(LoadT) / sizeof(T); + // Determine grid coordinates: + // - s: current sequence index (within the new tokens batch) + // - head_idx: global head index (0 to num_heads + 2*kv_num_heads - 1) + // - b: batch index const int s = blockIdx.x; const int head_idx = blockIdx.y; const int b = blockIdx.z; const int tid = threadIdx.x; + // h: the starting channel index for this thread (multiple elements per thread via LoadT) const int h = tid * elements_per_thread; // Guard work with 'valid' instead of early return to ensure all threads reach __syncthreads() @@ -64,16 +74,16 @@ __global__ void UnpackRoPEAppend( const int q_hidden = num_heads * head_size; const int k_hidden = kv_num_heads * head_size; - const int sequence_length = gridDim.x; + const int sequence_length = gridDim.x; // Number of new tokens in this launch __shared__ T shared_head[MAX_HEAD_SIZE]; - // Determine Head Type and Offset within hidden dimension + // Determine Head Type and Offset within the packed hidden dimension [Q, K, V] enum HeadType { QUERY, KEY, VALUE }; HeadType head_type; - int n; // Index within its specific type + int n; // Index relative to its specific type int offset_in_hidden; if (head_idx < num_heads) { @@ -91,7 +101,7 @@ __global__ void UnpackRoPEAppend( } // 1. Load data into Registers - T vals[elements_per_thread]; + alignas(16) T vals[elements_per_thread]; if (valid) { if (packed_qkv != nullptr) { const int64_t packed_idx = static_cast(b) * sequence_length * d + @@ -118,8 +128,9 @@ __global__ void UnpackRoPEAppend( } } - // 2. Process RoPE - // Optimization: Only use shared memory for non-interleaved mode + // 2. Process RoPE (Rotary Positional Embedding) + // Non-interleaved RoPE requires full head visibility to pair channels (h, h + d/2). + // We use shared memory as a staging buffer to allow any thread to access its pair. const bool is_qk = (head_type == QUERY || head_type == KEY); if (valid && rotary_dim > 0 && is_qk && !interleaved) { T* shared_ptr = &shared_head[h]; @@ -127,12 +138,13 @@ __global__ void UnpackRoPEAppend( } // CRITICAL: Barrier must be outside the 'if(valid)' and 'if(is_qk)' blocks - // to ensure every thread in the block participates. + // to ensure every thread in the block participates and shared memory is ready. __syncthreads(); if (valid && rotary_dim > 0 && is_qk) { const int past_seq_len = past_seq_lens[b]; const int64_t pos_base = static_cast(b) * sequence_length; + // Calculate global position for RoPE: use position_ids if provided, else rely on past_seq_len. 
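    // Worked example for the position selection below and the cos/sin lookup it feeds (all numbers are
    // illustrative): decoding one new token (s = 0) with 10 tokens already cached gives
    // pos_id = past_seq_len + s = 10 when position_ids is null; with rotary_dim = 64 the dispatcher then
    // reads cos/sin entry pos_id * (rotary_dim / 2) + channel, e.g. 10 * 32 + 5 = 325 for rotary channel 5.
    {
      constexpr int kExPos = 10, kExRotaryDim = 64, kExChannel = 5;
      static_assert(kExPos * (kExRotaryDim / 2) + kExChannel == 325,
                    "cos/sin caches are indexed as a flat [position, rotary_dim / 2] table");
    }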
int pos_id = (position_ids != nullptr) ? static_cast(position_ids[pos_base + s]) : (past_seq_len + s); const int h_idx = h / elements_per_thread; @@ -145,44 +157,136 @@ __global__ void UnpackRoPEAppend( 0); } - // 3. Store results back to Global Memory + // 3. Store results back to Global Memory (Unpacked Q and Quantized KV Cache) if (valid) { if (head_type == QUERY) { if (unpacked_q != nullptr) { + // Store rotated Q to global memory for the Attention kernel const int64_t q_out_idx = static_cast(b) * sequence_length * q_hidden + static_cast(s) * q_hidden + static_cast(n) * head_size + h; reinterpret_cast(unpacked_q)[q_out_idx / elements_per_thread] = *reinterpret_cast(vals); } } else { + // Store K or V into the KV cache at index (past_seqlen + s) const int cache_s = past_seq_lens[b] + s; if (cache_s < max_seqlen) { - T* cache_ptr = (head_type == KEY) ? k_cache : v_cache; + void* cache_ptr = (head_type == KEY) ? k_cache : v_cache; if (cache_ptr != nullptr) { - int64_t cache_idx = is_cache_bnsh ? (static_cast(b) * kv_num_heads * max_seqlen * head_size + static_cast(n) * max_seqlen * head_size + static_cast(cache_s) * head_size + h) : (static_cast(b) * max_seqlen * kv_num_heads * head_size + static_cast(cache_s) * kv_num_heads * head_size + static_cast(n) * head_size + h); - reinterpret_cast(cache_ptr)[cache_idx / elements_per_thread] = *reinterpret_cast(vals); + int64_t cache_idx; + if (is_cache_bnsh) { + // BNSH layout: [Batch, NumHeads, SeqLen, HeadSize] + // Note: For BIT_WIDTH=4, head_size refers to the number of UNPACKED elements. + // stride_s is the number of bytes occupied by head_size elements. + const int64_t stride_s = (BIT_WIDTH == 4) ? (head_size / 2) : head_size; + const int64_t stride_n = max_seqlen * stride_s; + const int64_t stride_b = kv_num_heads * stride_n; + cache_idx = static_cast(b) * stride_b + + static_cast(n) * stride_n + + static_cast(cache_s) * stride_s + + (BIT_WIDTH == 4 ? h / 2 : h); + } else { + // BSNH layout: [Batch, SeqLen, NumHeads, HeadSize] + const int64_t stride_n = (BIT_WIDTH == 4) ? (head_size / 2) : head_size; + const int64_t stride_s = kv_num_heads * stride_n; + const int64_t stride_b = max_seqlen * stride_s; + cache_idx = static_cast(b) * stride_b + + static_cast(cache_s) * stride_s + + static_cast(n) * stride_n + + (BIT_WIDTH == 4 ? h / 2 : h); + } + + if constexpr (BIT_WIDTH == 16 || BIT_WIDTH == 32) { + // No quantization: direct store + reinterpret_cast(cache_ptr)[cache_idx / elements_per_thread] = *reinterpret_cast(vals); + } else if constexpr (BIT_WIDTH == 8) { + // Int8 Quantization: 1 element per byte + const float* scale_buffer = (head_type == KEY) ? k_scale : v_scale; + uint64_t packed = 0; + for (int i = 0; i < elements_per_thread; ++i) { + float sc = per_channel ? scale_buffer[n * head_size + h + i] : scale_buffer[0]; + float inv_s = (sc == 0.0f) ? 0.0f : 1.0f / sc; + int8_t q = static_cast(max(-128.0f, min(127.0f, rintf(static_cast(vals[i]) * inv_s)))); + packed |= (static_cast(static_cast(q)) << (i * 8)); + } + // Store 8 elements (8 bytes) at once + reinterpret_cast(cache_ptr)[cache_idx / 8] = packed; + } else if constexpr (BIT_WIDTH == 4) { + // Int4 Quantization: 2 elements per byte + constexpr float kInt4Min = -8.0f; + constexpr float kInt4Max = 7.0f; + const float* scale_buffer = (head_type == KEY) ? k_scale : v_scale; + uint32_t packed = 0; + for (int i = 0; i < 4; ++i) { + // Elements are paired as (0,1), (2,3), etc. into single bytes. + float s0 = per_channel ? 
scale_buffer[n * head_size + h + i * 2] : scale_buffer[0]; + float s1 = per_channel ? scale_buffer[n * head_size + h + i * 2 + 1] : scale_buffer[0]; + int8_t q0 = static_cast(max(kInt4Min, min(kInt4Max, rintf(static_cast(vals[i * 2]) * (s0 == 0 ? 0 : 1.0f / s0))))); + int8_t q1 = static_cast(max(kInt4Min, min(kInt4Max, rintf(static_cast(vals[i * 2 + 1]) * (s1 == 0 ? 0 : 1.0f / s1))))); + uint8_t p = ((q0 + 8) & 0x0F) | (((q1 + 8) & 0x0F) << 4); + packed |= (static_cast(p) << (i * 8)); + } + // Store 8 elements (4 bytes) at once + reinterpret_cast(cache_ptr)[cache_idx / 4] = packed; + } } } } } } +// Internal dispatcher that selects the appropriate template specialization based on head_size. +// MAX_HEAD_SIZE is used to optimize shared memory usage and kernel performance. +template +Status DispatchUnpackRoPEAppendHeadSize( + const dim3& grid, const dim3& block, cudaStream_t stream, + const T* packed_qkv, const T* query, const T* key, const T* value, + T* unpacked_q, void* k_cache, void* v_cache, + const float* k_scale, const float* v_scale, + const int num_heads, const int kv_num_heads, const int head_size, const int d, + const int max_seqlen, const int* past_seq_lens, + const T* cos_cache, const T* sin_cache, const int rotary_dim, + const int64_t* position_ids, const bool interleaved, const bool is_cache_bnsh, const bool per_channel) { + if (head_size <= 64) { + UnpackRoPEAppend<<>>( + packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, + num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); + } else if (head_size <= 128) { + UnpackRoPEAppend<<>>( + packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, + num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); + } else if (head_size <= 256) { + UnpackRoPEAppend<<>>( + packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, k_scale, v_scale, + num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Head size (", head_size, ") exceeds maximum supported MAX_HEAD_SIZE (256)."); + } + return CUDA_CALL(cudaGetLastError()); +} + +// Public entry point to launch the Unpack+RoPE+Append kernel. +// Handles parameter validation, grid/block sizing, and bit-width dispatching. 
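// Illustrative aside on what the bit-width dispatch in the launcher below implies for cache storage
// (bit_width == 0 is treated as unquantized). The helper is not used anywhere; it only restates the
// per-head stride arithmetic from the kernel above (2 bytes per element for fp16/bf16, 1 byte for int8,
// 2 packed elements per byte for int4) for an assumed head_size of 128.
constexpr int CacheBytesPerHeadExample(int head_size, int bit_width) {
  return bit_width == 4 ? head_size / 2 : (bit_width == 8 ? head_size : head_size * 2);
}
static_assert(CacheBytesPerHeadExample(128, 16) == 256, "unquantized fp16/bf16 cache");
static_assert(CacheBytesPerHeadExample(128, 8) == 128, "int8 cache: 2x smaller per head");
static_assert(CacheBytesPerHeadExample(128, 4) == 64, "int4 cache: 4x smaller per head");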
template -Status LaunchUnpackRoPEAppendKV( +Status LaunchUnpackRoPEAppend( const T* packed_qkv, const T* query, const T* key, const T* value, - T* unpacked_q, T* k_cache, T* v_cache, + T* unpacked_q, void* k_cache, void* v_cache, + const float* k_scale, const float* v_scale, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, const int max_seqlen, const int* past_seq_lens, const T* cos_cache, const T* sin_cache, const int rotary_dim, const int64_t* position_ids, const bool interleaved, - const bool is_cache_bnsh, cudaStream_t stream, const int max_threads_per_block) { + const bool is_cache_bnsh, const KVQuantizationType k_quant_type, + const int bit_width, cudaStream_t stream, const int max_threads_per_block) { constexpr int elements_per_vector = sizeof(float4) / sizeof(T); if (head_size % elements_per_vector != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Head size must be divisible by vector size (16 bytes)."); } - // rotary_dim <= head_size check to prevent out-of-bounds in shared memory if (rotary_dim > head_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "rotary_dim (", rotary_dim, ") cannot exceed head_size (", head_size, ")."); } @@ -209,40 +313,30 @@ Status LaunchUnpackRoPEAppendKV( const dim3 grid(sequence_length, total_heads, batch_size); const dim3 block(threads_per_block); - // Dynamic dispatch for MAX_HEAD_SIZE templates to improve occupancy for common LLM head sizes - if (head_size <= 64) { - UnpackRoPEAppend<<>>( - packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, - num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, - cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh); - } else if (head_size <= 128) { - UnpackRoPEAppend<<>>( - packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, - num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, - cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh); - } else if (head_size <= 256) { - UnpackRoPEAppend<<>>( - packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, - num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, - cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Head size (", head_size, ") exceeds maximum supported MAX_HEAD_SIZE (256)."); + bool per_channel = (k_quant_type == KVQuantizationType::PER_CHANNEL); + + if (bit_width == 0) { + return DispatchUnpackRoPEAppendHeadSize( + grid, block, stream, packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, + k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); + } else if (bit_width == 8) { + return DispatchUnpackRoPEAppendHeadSize( + grid, block, stream, packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, + k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); +#ifdef USE_INT4_KV_CACHE + } else if (bit_width == 4) { + return DispatchUnpackRoPEAppendHeadSize( + grid, block, stream, packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, + k_scale, v_scale, num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh, per_channel); +#endif } - return CUDA_CALL(cudaGetLastError()); + 
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported bit_width (", bit_width, ") for GQA quantization."); } -// Explicit template instantiations -template Status LaunchUnpackRoPEAppendKV( - const half*, const half*, const half*, const half*, half*, half*, half*, - int, int, int, int, int, int, const int*, const half*, const half*, int, const int64_t*, bool, bool, - cudaStream_t, int); - -template Status LaunchUnpackRoPEAppendKV( - const BFloat16*, const BFloat16*, const BFloat16*, const BFloat16*, BFloat16*, BFloat16*, BFloat16*, - int, int, int, int, int, int, const int*, const BFloat16*, const BFloat16*, int, const int64_t*, bool, bool, - cudaStream_t, int); - } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_common.cuh b/onnxruntime/contrib_ops/cuda/bert/rotary_common.cuh index 1cab81e83b2ef..0f05ad4687962 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_common.cuh +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_common.cuh @@ -390,6 +390,102 @@ struct RotaryDispatcher { } }; +// ============================================================================ +// Specialization: float2 + __nv_bfloat16 +// ============================================================================ +template <> +struct RotaryDispatcher { + __device__ static void apply(float2& val, const float2* cos_cache, const float2* sin_cache, + const int rotary_dim, const int h_idx, const int pos_id, + const bool interleaved, const float2* new_kv_base, const int64_t in_offset) { + if (2 * h_idx * 2 >= rotary_dim) return; + + using namespace onnxruntime::cuda; + __nv_bfloat162* v_ptr = reinterpret_cast<__nv_bfloat162*>(&val); + __nv_bfloat162 v0 = v_ptr[0]; + __nv_bfloat162 v1 = v_ptr[1]; + const __nv_bfloat162* cos_ptr = reinterpret_cast(cos_cache); + const __nv_bfloat162* sin_ptr = reinterpret_cast(sin_cache); + int half_rot = rotary_dim / 2; + + if (interleaved) { + int f0 = 2 * h_idx; + int cs0 = pos_id * half_rot + f0; + + __nv_bfloat162 c_pair = cos_ptr[cs0 / 2]; + __nv_bfloat162 s_pair = sin_ptr[cs0 / 2]; + + float c0f = __bfloat162float(c_pair.x); + float s0f = __bfloat162float(s_pair.x); + float e0x = __bfloat162float(v0.x); + float e0y = __bfloat162float(v0.y); + v0.x = __float2bfloat16(e0x * c0f - e0y * s0f); + v0.y = __float2bfloat16(e0x * s0f + e0y * c0f); + + float c1f = __bfloat162float(c_pair.y); + float s1f = __bfloat162float(s_pair.y); + float e1x = __bfloat162float(v1.x); + float e1y = __bfloat162float(v1.y); + v1.x = __float2bfloat16(e1x * c1f - e1y * s1f); + v1.y = __float2bfloat16(e1x * s1f + e1y * c1f); + + } else { + const __nv_bfloat16* kv_ptr = reinterpret_cast(new_kv_base); + int base_idx = 4 * h_idx; + int64_t scalar_in_offset = in_offset * 4; + + auto rotate_element_bf16 = [&](int idx, __nv_bfloat16& val) { + if (idx >= rotary_dim) return; + int pair_idx = (idx < half_rot) ? (idx + half_rot) : (idx - half_rot); + float sign = (idx < half_rot) ? 
-1.0f : 1.0f; + int cos_idx = idx % half_rot; + int cs_idx = pos_id * half_rot + cos_idx; + + __nv_bfloat16 c_val = reinterpret_cast(cos_ptr)[cs_idx]; + __nv_bfloat16 s_val = reinterpret_cast(sin_ptr)[cs_idx]; + + float val_f = __bfloat162float(val); + float pair_f = __bfloat162float(kv_ptr[scalar_in_offset + pair_idx]); + float cf = __bfloat162float(c_val); + float sf = __bfloat162float(s_val); + + val = __float2bfloat16(val_f * cf + sign * pair_f * sf); + }; + + rotate_element_bf16(base_idx, v0.x); + rotate_element_bf16(base_idx + 1, v0.y); + rotate_element_bf16(base_idx + 2, v1.x); + rotate_element_bf16(base_idx + 3, v1.y); + } + v_ptr[0] = v0; + v_ptr[1] = v1; + } +}; + +// ============================================================================ +// Specialization: float4 + __nv_bfloat16 +// ============================================================================ +template <> +struct RotaryDispatcher { + __device__ static void apply(float4& val, const float4* cos_cache, const float4* sin_cache, + const int rotary_dim, const int h_idx, const int pos_id, + const bool interleaved, const float4* new_kv_base, const int64_t in_offset) { + float2 p1 = make_float2(val.x, val.y); + float2 p2 = make_float2(val.z, val.w); + const float2* c = reinterpret_cast(cos_cache); + const float2* s = reinterpret_cast(sin_cache); + const float2* b = reinterpret_cast(new_kv_base); + + RotaryDispatcher::apply(p1, c, s, rotary_dim, h_idx * 2, pos_id, interleaved, b, in_offset * 2); + RotaryDispatcher::apply(p2, c, s, rotary_dim, h_idx * 2 + 1, pos_id, interleaved, b, in_offset * 2); + + val.x = p1.x; + val.y = p1.y; + val.z = p2.x; + val.w = p2.y; + } +}; + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/RefChecker.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/RefChecker.cuh new file mode 100644 index 0000000000000..b5348848dac9b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/RefChecker.cuh @@ -0,0 +1,95 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include "cuda_hint.cuh" +#include "utils.cuh" +#include +#include +#include +#include +#include +#include + +struct RefChecker { + half q[8][32][32]; + half k[8][4][64][32]; + float qk[4][32][64]; + float tileRowMax[4][32]; + half x[4][32][64]; + half v[8][4][32][64]; + float tileRowSum[4][32]; + float acc1PerStep[4][32][256]; + half out[32][256]; + + void init() { +#define INIT_MEMBER(member) initMember(member, #member) + INIT_MEMBER(q); + INIT_MEMBER(k); + INIT_MEMBER(qk); + INIT_MEMBER(tileRowMax); + INIT_MEMBER(x); + INIT_MEMBER(v); + INIT_MEMBER(tileRowSum); + INIT_MEMBER(acc1PerStep); + INIT_MEMBER(out); +#undef INIT_MEMBER + } + + private: + template + void initMember(T& dst, char const* varName); +}; + +template +std::enable_if_t, float> || std::is_same_v, half>, std::string> +makeFileName(T (&dst)[d0][d1][d2][d3], char const* varName) { + std::stringstream ss; + ss << varName << '_' << d0 << 'x' << d1 << 'x' << d2 << 'x' << d3 << '_' + << (std::is_same_v, float> ? "f32" : "f16") << ".bin"; + return ss.str(); +} + +template +std::enable_if_t, float> || std::is_same_v, half>, std::string> +makeFileName(T (&dst)[d0][d1][d2], char const* varName) { + std::stringstream ss; + ss << varName << '_' << d0 << 'x' << d1 << 'x' << d2 << '_' + << (std::is_same_v, float> ? "f32" : "f16") << ".bin"; + return ss.str(); +} + +template +std::enable_if_t, float> || std::is_same_v, half>, std::string> +makeFileName(T (&dst)[d0][d1], char const* varName) { + std::stringstream ss; + ss << varName << '_' << d0 << 'x' << d1 << '_' << (std::is_same_v, float> ? "f32" : "f16") + << ".bin"; + return ss.str(); +} + +template +void RefChecker::initMember(T& dst, char const* varName) { + std::string const filename = makeFileName(dst, varName); + printf("loading %s\n", filename.c_str()); + namespace fs = std::filesystem; + assert(fs::exists(filename)); + assert(fs::file_size(filename) == sizeof(dst)); + std::ifstream fin(filename, std::ios::binary); + fin.read(reinterpret_cast(&dst), sizeof(dst)); + assert(fin); +} diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/barriers.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/barriers.cuh new file mode 100644 index 0000000000000..79a9ee4b5cdb7 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/barriers.cuh @@ -0,0 +1,409 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include "cuda_hint.cuh" +#include "defines.h" +#if !USE_CUSTOM_BARRIER +#include +using CtaBarrier = cuda::barrier; +#else + +#ifndef __CUDACC__ +#include +#endif + +#if __CUDACC_VER_MAJOR__ < 12 +#define STR_REL_CTA "" +#define STR_ACQ_CTA "" +#else +#define STR_REL_CTA ".release.cta" +#define STR_ACQ_CTA ".acquire.cta" +#endif + +enum class Scope : uint32_t { + CTA = 0, + CGA = 1, +}; + +enum class ArriveOrder : uint32_t { + RELEASE = 0, + RELAXED = 1, +}; + +enum class ArrivalToken : uint64_t { +}; + +template +class MBarrier // rename this to MBarrier +{ + public: + using ArrivalToken = ::ArrivalToken; + static constexpr Scope defaultScope = defaultScope_; + using arrival_token = ArrivalToken; + + __device__ inline MBarrier(uint32_t count) { + assert(count > 0); + asm volatile("mbarrier.init.b64 [%0], %1;\n" ::"l"(addr()), "r"(count) : "memory"); + } + + __device__ ~MBarrier() { + asm volatile("mbarrier.inval.b64 [%0];\n" ::"l"(addr()) : "memory"); + } + + template + __device__ inline mha::conditional_t arrive(uint32_t update = 1) { + ArrivalToken token{}; +#if __CUDA_ARCH__ >= 900 + if constexpr (scope == Scope::CTA) { + switch (order) { + case ArriveOrder::RELEASE: + asm volatile("mbarrier.arrive.release.cta.b64 %0, [%1], %2;\n" + : "=l"(token) + : "l"(addr()), "r"(update) + : "memory"); + break; + case ArriveOrder::RELAXED: + asm volatile("mbarrier.arrive.relaxed.cta.b64 %0, [%1], %2;\n" + : "=l"(token) + : "l"(addr()), "r"(update) + : "memory"); + break; + } + return token; + } else { + static_assert(scope == Scope::CGA); + switch (order) { + case ArriveOrder::RELEASE: + asm volatile("mbarrier.arrive.release.cluster.b64 _, [%0], %1;\n" ::"l"(addr()), "r"(update) + : "memory"); + break; + case ArriveOrder::RELAXED: + asm volatile("mbarrier.arrive.relaxed.cluster.b64 _, [%0], %1;\n" ::"l"(addr()), "r"(update) + : "memory"); + break; + } + return; + } +#else + static_assert(scope == Scope::CTA && order == ArriveOrder::RELEASE); + if (update > 1) { + asm volatile("mbarrier.arrive.noComplete" STR_REL_CTA ".b64 %0, [%1], %2;\n" + : "=l"(token) + : "l"(addr()), "r"(update - 1U) + : "memory"); + [[maybe_unused]] ArrivalToken refToken; + asm volatile("mbarrier.arrive" STR_REL_CTA ".b64 %0, [%1];\n" : "=l"(refToken) : "l"(addr()) : "memory"); + assert(token == refToken); + return token; + } else { + asm volatile("mbarrier.arrive" STR_REL_CTA ".b64 %0, [%1];\n" : "=l"(token) : "l"(addr()) : "memory"); + return token; + } +#endif + } + + __device__ inline bool isLocal() const { + uint32_t addrCtaRank{}; + asm("getctarank.u64 %0, %1;\n" : "=r"(addrCtaRank) : "l"(addr())); + uint32_t ctaRank{}; + asm("mov.u32 %0, %%cluster_ctarank;\n" : "=r"(ctaRank)); + return addrCtaRank == ctaRank; + } + + __device__ inline void remoteArrive(uint32_t update = 1) { +#if __CUDA_ARCH__ >= 900 + assert(!isLocal()); + asm volatile("mbarrier.arrive.release.cluster.shared::cluster.b64 _, [%0], %1;\n" + : + : "l"(__cvta_generic_to_shared(&mBar)), "r"(update) + : "memory"); +#else + asm volatile("trap;\n"); +#endif + } + + template + __device__ inline mha::conditional_t arrive_tx_relaxed(uint32_t txCount) { +#if __CUDA_ARCH__ >= 900 + if constexpr (scope == Scope::CTA) { + ArrivalToken token{}; + asm volatile("mbarrier.arrive.expect_tx.relaxed.cta.b64 %0, [%1], %2;\n" + : "=l"(token) + : "l"(addr()), "r"(txCount) + : "memory"); + return token; + } else { + asm volatile("mbarrier.arrive.expect_tx.relaxed.cluster.b64 _, [%0], %1;\n" ::"l"(addr()), "r"(txCount) + : "memory"); + return; + } +#else + 
asm volatile("trap;\n"); +#endif + } + + template + __device__ inline mha::conditional_t arrive_tx( + uint32_t txCount, uint32_t arriveCount = 1) { +#if __CUDA_ARCH__ >= 900 + if (arriveCount == 1) { + if constexpr (scope == Scope::CTA) { + ArrivalToken token{}; + switch (order) { + case ArriveOrder::RELEASE: + asm volatile("mbarrier.arrive.expect_tx.release.cta.b64 %0, [%1], %2;\n" + : "=l"(token) + : "l"(addr()), "r"(txCount) + : "memory"); + break; + case ArriveOrder::RELAXED: + asm volatile("mbarrier.arrive.expect_tx.relaxed.cta.b64 %0, [%1], %2;\n" + : "=l"(token) + : "l"(addr()), "r"(txCount) + : "memory"); + break; + } + return token; + } else { + static_assert(scope == Scope::CGA); + switch (order) { + case ArriveOrder::RELEASE: + asm volatile( + "mbarrier.arrive.expect_tx.release.cluster.b64 _, [%0], %1;\n" ::"l"(addr()), "r"(txCount) + : "memory"); + break; + case ArriveOrder::RELAXED: + asm volatile( + "mbarrier.arrive.expect_tx.relaxed.cluster.b64 _, [%0], %1;\n" ::"l"(addr()), "r"(txCount) + : "memory"); + break; + } + return; + } + } else { + if constexpr (scope == Scope::CTA) { + asm volatile("mbarrier.expect_tx.relaxed.cta.b64 [%0], %1;\n" ::"l"(addr()), "r"(txCount) : "memory"); + } else { + asm volatile("mbarrier.expect_tx.relaxed.cluster.b64 [%0], %1;\n" ::"l"(addr()), "r"(txCount) + : "memory"); + } + return arrive(arriveCount); + } +#else + asm volatile("trap;\n"); +#endif + } + + template + __device__ inline bool test_wait(ArrivalToken&& token) { + uint32_t ready{}; + if constexpr (scope == Scope::CGA) { + asm volatile( + "{\n" + ".reg .pred ready;\n" + "mbarrier.test_wait.acquire.cluster.b64 ready, [%1], %2;\n" + "selp.b32 %0, 1, 0, ready;\n" + "}\n" + : "=r"(ready) + : "l"(addr()), "l"(token) + : "memory"); + } else { + asm volatile( + "{\n" + ".reg .pred ready;\n" + "mbarrier.test_wait" STR_ACQ_CTA + ".b64 ready, [%1], %2;\n" + "selp.b32 %0, 1, 0, ready;\n" + "}\n" + : "=r"(ready) + : "l"(addr()), "l"(token) + : "memory"); + } + return ready != 0; + } + + template + __device__ inline bool test_wait_parity(bool parity) { + uint32_t ready{}; + if constexpr (scope == Scope::CGA) { + asm volatile( + "{\n" + ".reg .pred ready;\n" + "mbarrier.test_wait.parity.acquire.cluster.b64 ready, [%1], %2;\n" + "selp.b32 %0, 1, 0, ready;\n" + "}\n" + : "=r"(ready) + : "l"(addr()), "r"(uint32_t{parity}) + : "memory"); + } else { + asm volatile( + "{\n" + ".reg .pred ready;\n" + "mbarrier.test_wait.parity" STR_ACQ_CTA + ".b64 ready, [%1], %2;\n" + "selp.b32 %0, 1, 0, ready;\n" + "}\n" + : "=r"(ready) + : "l"(addr()), "r"(uint32_t{parity}) + : "memory"); + } + return ready != 0; + } +#if __CUDA_ARCH__ >= 900 + template + __device__ inline bool try_wait(ArrivalToken&& token) { + uint32_t ready{}; + if constexpr (scope == Scope::CGA) { + asm volatile( + "{\n" + ".reg .pred ready;\n" + "mbarrier.try_wait.acquire.cluster.b64 ready, [%1], %2, %3;\n" + "selp.b32 %0, 1, 0, ready;\n" + "}\n" + : "=r"(ready) + : "l"(addr()), "l"(token), "n"(kSUSPEND_TIME_HINT) + : "memory"); + } else { + asm volatile( + "{\n" + ".reg .pred ready;\n" + "mbarrier.try_wait.acquire.cta.b64 ready, [%1], %2, %3;\n" + "selp.b32 %0, 1, 0, ready;\n" + "}\n" + : "=r"(ready) + : "l"(addr()), "l"(token), "n"(kSUSPEND_TIME_HINT) + : "memory"); + } + return ready != 0; + } + + template + __device__ inline bool try_wait_parity(bool parity) { + uint32_t ready{}; + if constexpr (scope == Scope::CGA) { + asm volatile( + "{\n" + ".reg .pred ready;\n" + "mbarrier.try_wait.parity.acquire.cluster.b64 ready, [%1], %2, %3;\n" + 
"selp.b32 %0, 1, 0, ready;\n" + "}\n" + : "=r"(ready) + : "l"(addr()), "r"(uint32_t{parity}), "n"(kSUSPEND_TIME_HINT) + : "memory"); + } else { + asm volatile( + "{\n" + ".reg .pred ready;\n" + "mbarrier.try_wait.parity.acquire.cta.b64 ready, [%1], %2, %3;\n" + "selp.b32 %0, 1, 0, ready;\n" + "}\n" + : "=r"(ready) + : "l"(addr()), "r"(uint32_t{parity}), "n"(kSUSPEND_TIME_HINT) + : "memory"); + } + return ready != 0; + } +#endif + template + __device__ inline void wait(ArrivalToken&& token) { +#if __CUDA_ARCH__ >= 900 + poll([&]() { return try_wait(ArrivalToken{token}); }); +#else + poll([&]() { return test_wait(ArrivalToken{token}); }); +#endif + } + + // starting from `parity = false`. + template + __device__ inline void wait_parity(bool parity) { +#if __CUDA_ARCH__ >= 900 + poll([&]() { return try_wait_parity(parity); }); +#else + poll([&]() { return test_wait_parity(parity); }); +#endif + } + + template + __device__ inline mha::enable_if_t arrive_and_wait(uint32_t update = 1) { + wait(arrive(update)); + } + + private: + __device__ inline uint64_t addr() const { + return reinterpret_cast(&mBar); + } + + template + __device__ inline static void poll(F&& func) { + if constexpr (funcSupportsBlocking) { + while (!func()) { + } + } else { + float sleepDuration = 0.125F; + while (!func()) { + __nanosleep(uint32_t(sleepDuration)); + sleepDuration = sleepDuration * 1.25F + 0.F; + } + } + } + + public: + static constexpr uint32_t kSUSPEND_TIME_HINT = 0xFFFFFFFFU; + + private: + uint64_t mBar; +}; + +template +__device__ inline void init(MBarrier* bar, uint32_t count) { + new (bar) MBarrier{count}; +} + +using CtaBarrier = MBarrier; +using CgaBarrier = MBarrier; + +template +__device__ inline constexpr bool toParity(uint32_t i) { + return i % (nbBars * 2) / nbBars; +} + +class NamedBarrier { + public: + __device__ inline NamedBarrier(uint32_t idxBar, uint32_t arriveCount) + : mName{idxBar}, mArriveCount{arriveCount} { + assert(idxBar < 16 && arriveCount % 32 == 0); + } + + __device__ inline void arrive() const { + asm volatile("barrier.cta.arrive %0, %1;\n" ::"r"(mName), "r"(mArriveCount) : "memory"); + } + + __device__ inline void arrive_and_wait() const { + asm volatile("barrier.cta.sync %0, %1;\n" ::"r"(mName), "r"(mArriveCount) : "memory"); + } + + private: + uint32_t const mName; + uint32_t const mArriveCount; +}; + +__device__ inline void namedBarSync(uint32_t idxBar, uint32_t arriveCount) { + NamedBarrier bar{idxBar, arriveCount}; + bar.arrive_and_wait(); +} +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/cuda_hint.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/cuda_hint.cuh new file mode 100644 index 0000000000000..5350ea9ce12d2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/cuda_hint.cuh @@ -0,0 +1,44 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include "platform.h" + +#if IS_IN_IDE_PARSER + +#ifndef __CUDACC__ +#define __CUDACC__ 1 +#endif + +#ifndef __CUDA_ARCH__ +#define __CUDA_ARCH__ 900 +#endif + +#ifndef __CUDACC_VER_MAJOR__ +#define __CUDACC_VER_MAJOR__ 12 +#endif +#ifndef __CUDACC_VER_MINOR__ +#define __CUDACC_VER_MINOR__ 9 +#endif + +#if __CUDA_ARCH__ == 900 +#ifndef __CUDA_ARCH_FEAT_SM90_ALL +#define __CUDA_ARCH_FEAT_SM90_ALL +#endif +#endif + +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/defines.h b/onnxruntime/contrib_ops/cuda/bert/xqa/defines.h new file mode 100644 index 0000000000000..3c1e9d1c96434 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/defines.h @@ -0,0 +1,213 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "mha_stdheaders.cuh" + +#define STATIC_NB_K_HEADS 0 +#if STATIC_NB_K_HEADS +#define NB_K_HEADS 2 +#endif + +// allowed values are multiples of 16 in range [16, 256] +#ifndef HEAD_ELEMS +#define HEAD_ELEMS 128 +#endif + +// nbQHeads / nbKHeads for MQA/GQA +#ifndef HEAD_GRP_SIZE +#define HEAD_GRP_SIZE 8 +#endif + +#define IS_MLA (HEAD_GRP_SIZE == 128 && HEAD_ELEMS == 576) + +#if IS_MLA +#define INPUT_ELEM __nv_fp8_e4m3 +#define INPUT_ELEM2 __nv_fp8x2_e4m3 +#define HEAD_ELEMS_V 512 +#else +// 1 means fp16 and 0 means bf16 input/output +#ifndef INPUT_FP16 +#define INPUT_FP16 1 +#endif + +// Don't modify +#if INPUT_FP16 +#define INPUT_ELEM half +#define INPUT_ELEM2 half2 +#else +#define INPUT_ELEM __nv_bfloat16 +#define INPUT_ELEM2 __nv_bfloat162 +#endif +#endif + +// For beam search. Allowed values: 1, 4 +#ifndef BEAM_WIDTH +#define BEAM_WIDTH 1 +#endif + +#ifndef SPEC_DEC +#define SPEC_DEC 0 +#endif + +// M_TILESIZE: Number of query tokens processed per CTA in M dimension. +// For decoding (S=1), this equals group_size. Allowed values: 8, 16, 32 +#ifndef M_TILESIZE +#define M_TILESIZE 32 +#endif + +#if SPEC_DEC +using MaskType = uint32_t; +#endif + +// Enables SWAP AB optimization for speculative decoding when using a small, fixed Q_SEQ_LEN. +// NOTE: Requires a uniform input sequence length for the entire batch. +#ifdef SPEC_Q_SEQ_LEN +static_assert(SPEC_DEC, "SPEC_Q_SEQ_LEN should only be used when SPEC_DEC is enabled."); +#endif + +// 0: half/bf16 based on INPUT_FP16; 1: int8_t; 2: __nv_fp8_e4m3 +#ifndef CACHE_ELEM_ENUM +#define CACHE_ELEM_ENUM 0 +#endif + +// don't modify +#define USE_KV_CACHE true + +// don't modify +#ifndef ALLOW_MULTI_BLOCK_MODE +#define ALLOW_MULTI_BLOCK_MODE true +#endif + +// For paged KV cache. Allowed values: 0, 16, 32, 64, 128 +// 0 means contiguous KV cache (non-paged). 
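+
+// A minimal sketch for orientation (the helper and names below are illustrative assumptions, not
+// used by the kernels): with paging enabled, a token position splits into a page-table slot plus
+// an in-page offset; with TOKENS_PER_PAGE == 0 the position indexes a contiguous cache directly.
+struct PagedTokenPosSketch {
+  unsigned page;    // which entry of the sequence's page table holds this token
+  unsigned inPage;  // token offset inside that page
+};
+constexpr PagedTokenPosSketch toPagedTokenPosSketch(unsigned pos, unsigned tokensPerPage) {
+  // tokensPerPage == 0 models the contiguous (non-paged) cache: the position is used as-is.
+  return tokensPerPage == 0 ? PagedTokenPosSketch{0U, pos}
+                            : PagedTokenPosSketch{pos / tokensPerPage, pos % tokensPerPage};
+}
+static_assert(toPagedTokenPosSketch(70U, 64U).page == 1, "token 70 falls in the second 64-token page");
+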
+#ifndef TOKENS_PER_PAGE +#define TOKENS_PER_PAGE 0 +#endif + +// don't modify +#ifndef USE_PAGED_KV_CACHE +#define USE_PAGED_KV_CACHE (TOKENS_PER_PAGE > 0) +#endif + +// Paged KV Cache Format +// 0 - XQA Original +// 1 - separate K and V cache pools, each with layout (batch, seq_len, head, head_elem) for VLLM/SGLang +#ifdef USE_PAGED_KV_CACHE +#ifndef PAGED_KV_CACHE_LAYOUT +#define PAGED_KV_CACHE_LAYOUT 0 +#endif +#endif + +// don't modify +#define USE_BEAM_SEARCH (BEAM_WIDTH > 1) + +#if CACHE_ELEM_ENUM == 0 +#define PRAGMA_UNROLL_FP16_ONLY _Pragma("unroll") +#else +#define PRAGMA_UNROLL_FP16_ONLY _Pragma("unroll(1)") +#endif + +// good for short sequence length but bad for long sequence length. Only for mha.cu. +#ifndef SHORT_SEQ_OPT +#define SHORT_SEQ_OPT 1 +#endif + +#ifndef SLIDING_WINDOW +#define SLIDING_WINDOW 0 +#endif + +#ifndef SKIP_SOFTMAX_ATTN +#define SKIP_SOFTMAX_ATTN 0 +#endif + +#ifndef SKIP_SOFTMAX_ATTN_BLOCK_STATS +#define SKIP_SOFTMAX_ATTN_BLOCK_STATS 0 +#endif + +#ifndef SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE +#define SKIP_SOFTMAX_ATTN_FIX_THRESHOLD_GREATER_THAN_ONE 1 +#endif + +// 0 - no PDL +// 1 - naive PDL +// 2 - aggressive PDL (implemented only in mha_sm90.cu for now) +#ifndef ENABLE_PDL +#define ENABLE_PDL 2 +#endif + +#ifndef USE_INPUT_KV +#define USE_INPUT_KV 0 +#endif + +#if USE_INPUT_KV +// 0 - no RoPE +// 1 - NEOX style +// 2 - GPTJ style +#ifndef ROPE_STYLE +#define ROPE_STYLE 0 +#endif + +#if SPEC_DEC +#error "SPEC_DEC is not supported for USE_INPUT_KV" +#endif +#endif + +// Output element type: +// 0 - input element type +// 1 - KV cache element type +#ifndef LOW_PREC_OUTPUT +#define LOW_PREC_OUTPUT 0 +#endif + +#if LOW_PREC_OUTPUT +static_assert(CACHE_ELEM_ENUM != 0); +#endif + +// true should be better if warpTile.x * cacheElemSize < 128. otherwise use false. +#define GRP_LOAD_V (CACHE_ELEM_ENUM != 0) || (HEAD_ELEMS == 256 && USE_PAGED_KV_CACHE && BEAM_WIDTH > 1) + +// use custom barrier for NVRTC to avoid pulling in many headers +#ifndef USE_CUSTOM_BARRIER +#define USE_CUSTOM_BARRIER 1 +#endif + +#ifndef OPTIMIZE_FOR_LATENCY +#define OPTIMIZE_FOR_LATENCY 1 +#endif + +#ifndef IS_SPEC_DEC_TREE +#define IS_SPEC_DEC_TREE 1 // by default SPEC_DEC expect tree-based draft token structure +#endif + +#define DBG_BATCH_SIZE 2 +#define DBG_SEQ_LEN 256 * 4 + 3 +#define DBG_NB_CTAS_PER_SEQ 8 + +#include +#include +template +using ElemType = mha::conditional_t>>; + +// we used an optimization where exp(x-rowMax) is computed as: +/* bias = rowMax * log2e // shared for the whole row + exp(x-rowMax) = exp2f(x * log2e - bias) +*/ +// But this optimization is not numerically stable when (x * log2e - bias) is computed with FMA and x is too large. For +// this reason, don't set safeInitRowMax with a huge absolute value. +#define SAFE_INIT_ROW_MAX (-1e+5F) diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/gmma.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/gmma.cuh new file mode 100644 index 0000000000000..a93a5dceabf13 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/gmma.cuh @@ -0,0 +1,164 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "cuda_hint.cuh" +#include "mha_stdheaders.cuh" +#include "utils.cuh" +#ifndef __CUDACC__ +#include +#endif +#include +#include + +namespace gmma { + +enum class SwizzleMode : uint64_t { + kNONE = 0, + k128 = 1, + k64 = 2, + k32 = 3 +}; + +struct MatDesc { + uint64_t addr : 16; + uint64_t dimKOffset : 16; + uint64_t dimMNOffset : 16; + uint64_t pad0 : 1; + uint64_t baseOffset : 3; + uint64_t pad1 : 10; + SwizzleMode swizzle : 2; + + enum class Raw : uint64_t { + }; + + [[nodiscard]] __device__ inline MatDesc withAddr(void const* data) const { + MatDesc ret = *this; + ret.addr = encode(__cvta_generic_to_shared(data)); + return ret; + } + + static __device__ inline uint32_t encode(uint32_t val) { + return (val & 0x3FFFFU) >> 4; + } + + __device__ inline bool operator==(MatDesc const& other) const { + return raw() == other.raw(); + } + + __device__ inline Raw const& raw() const { + static_assert(sizeof(MatDesc) == 8); + return reinterpret_cast(*this); + } + + static __device__ inline MatDesc fromRaw(Raw const& raw) { + return reinterpret_cast(raw); + } +}; + +static_assert(sizeof(MatDesc) == 8); + +[[nodiscard]] __device__ inline MatDesc::Raw addAddr(MatDesc::Raw base, void const* data) { + assert((uint32_t(__cvta_generic_to_shared(data)) & ~0x3FFFFU) == 0); + MatDesc::Raw ret = base; + auto& u32x2 = reinterpret_cast(ret); + u32x2[0] += static_cast(__cvta_generic_to_shared(data)) >> 4; + return ret; +} + +__device__ inline MatDesc makeMatDesc(void const* data, uint32_t dimKByteOffset, uint32_t dimMNByteOffset, + void const* patternStartAddr, SwizzleMode swizzleMode) { + uint32_t const patternAddr = __cvta_generic_to_shared(patternStartAddr); + uint32_t const baseAlign = [&]() -> uint32_t { + switch (swizzleMode) { + case SwizzleMode::kNONE: + return 1; + case SwizzleMode::k128: + return 1024; + case SwizzleMode::k64: + return 512; + case SwizzleMode::k32: + return 256; + } + asm volatile("trap;\n"); + return 0; + }(); + assert(__cvta_generic_to_shared(data) % baseAlign == 0); + uint32_t const baseOffset = ((patternAddr % baseAlign == 0) ? 0U : ((patternAddr >> 0x7) & 0x7)); + return MatDesc{ + /*addr=*/MatDesc::encode(__cvta_generic_to_shared(data)), + /*dimKOffset=*/MatDesc::encode(dimKByteOffset), + /*dimMNOffset=*/MatDesc::encode(dimMNByteOffset), + /*pad0=*/0, + /*baseOffset=*/baseOffset, + /*pad1=*/0, + /*swizzle=*/swizzleMode, + }; +} + +__device__ inline MatDesc makeMatDesc( + void const* data, uint32_t dimKByteOffset, uint32_t dimMNByteOffset, SwizzleMode swizzleMode) { + return makeMatDesc(data, dimKByteOffset, dimMNByteOffset, data, swizzleMode); +} + +inline constexpr uint32_t instM = 64; + +template +inline constexpr uint32_t instK = 32 / sizeof(MathElem); + +inline constexpr uint32_t instNBase = 8; + +// for both a and b, outer-dim is gemm-K and inner-dim is gemm-M or gemm-N +// acc is used as both input and output. 
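+
+// A minimal sketch for orientation (the helper below is an illustrative assumption, unused by the
+// kernels): each wgmma tile here is 64 rows by n columns with f32 accumulators, so the 128 threads
+// of a warp group each hold n/2 floats, laid out as acc[n/8][2][2], i.e. n/8 slices of 8 columns
+// with a 2x2 register quad per slice per thread.
+constexpr uint32_t accFloatsPerThread(uint32_t n) {
+  return (n / instNBase) * 2U * 2U;  // == n / 2
+}
+static_assert(accFloatsPerThread(64) == 32, "an m64n64 f32 wgmma tile needs 32 accumulator registers per thread");
+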
+template +__device__ void mma_async_shmA( + float (&acc)[exactDiv(n, instNBase)][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal); +template +__device__ void mma_async_regA( + float (&acc)[exactDiv(n, instNBase)][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal); + +__device__ inline void fence() { + asm volatile("wgmma.fence.sync.aligned;\n"); +} + +__device__ inline void commit_group() { + asm volatile("wgmma.commit_group.sync.aligned;\n"); +} + +template +__device__ inline void wait_group() { + asm volatile("wgmma.wait_group.sync.aligned %0\n; " ::"n"(targetNbInFlightGroups)); +} + +template +constexpr SwizzleMode getSwizzleMode(Array2D const&) { + constexpr auto rowBytes = Array2D::rowBytes; + if constexpr (!swizzle) { + return SwizzleMode::kNONE; + } + if constexpr (rowBytes % 128 == 0) { + return SwizzleMode::k128; + } else if constexpr (rowBytes == 64) { + return SwizzleMode::k64; + } else { + static_assert(rowBytes == 32); + return SwizzleMode::k32; + } +} +} // namespace gmma + +#include "gmma_impl.cuh" diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/gmma_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/gmma_impl.cuh new file mode 100644 index 0000000000000..dbb7686d0c884 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/gmma_impl.cuh @@ -0,0 +1,4198 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "cuda_hint.cuh" +#include "mha_stdheaders.cuh" +#ifndef __CUDACC__ +#include +#endif +#include +#include + +namespace gmma { +// cog template. 
Do code generation with: pip install cogapp; cog -r $filename + +// clang-format off +/*[[[cog +import cog +reg_list = lambda beg,end: ", ".join([f"%{i}" for i in range(beg, end)]) +acc_placeholder = lambda n: "{%s}" % reg_list(0, n//2) +acc_registers = lambda n: "\n , ".join([f'"+f"(acc[{i}][0][0]), "+f"(acc[{i}][0][1]), "+f"(acc[{i}][1][0]), "+f"(acc[{i}][1][1])' for i in range(n//8)]) +ptx_eol = "\\n" +n_list = [8, 16, 24, 32, 64, 128, 256] +for n in n_list: + cog.outl(f''' +template<> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, {n}, false, false>(float(&acc)[{n//8}][2][2], MatDesc::Raw descA, +MatDesc::Raw descB, bool accHasVal) +{{ + if (accHasVal) {{ + asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k32.f32.e4m3.e4m3{ptx_eol}" + "{acc_placeholder(n)},{ptx_eol}" // d + "%{n//2},{ptx_eol}" //a-desc + "%{n//2+1},{ptx_eol}" //b-desc + "%{n//2+2}, 1, 1;{ptx_eol}" + : {acc_registers(n)} + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + }} + else {{ + asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k32.f32.e4m3.e4m3{ptx_eol}" + "{acc_placeholder(n)},{ptx_eol}" // d + "%{n//2},{ptx_eol}" //a-desc + "%{n//2+1},{ptx_eol}" //b-desc + "%{n//2+2}, 1, 1;{ptx_eol}" + : {acc_registers(n)} + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + }} +}} + +template<> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, {n}, false, false>(float(&acc)[{n//8}][2][2], uint32_t +const(&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{{ + if (accHasVal) {{ + asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k32.f32.e4m3.e4m3{ptx_eol}" + "{acc_placeholder(n)},{ptx_eol}" // d + "{{{reg_list(n//2, n//2 + 4)}}},{ptx_eol}" //a + "%{n//2+4},{ptx_eol}" //b-desc + "%{n//2+5}, 1, 1;{ptx_eol}" + : {acc_registers(n)} + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), "l"(reinterpret_cast(descB)), "n"(true)); + }} + else {{ + asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k32.f32.e4m3.e4m3{ptx_eol}" + "{acc_placeholder(n)},{ptx_eol}" // d + "{{{reg_list(n//2, n//2 + 4)}}},{ptx_eol}" //a + "%{n//2+4},{ptx_eol}" //b-desc + "%{n//2+5}, 1, 1;{ptx_eol}" + : {acc_registers(n)} + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), "l"(reinterpret_cast(descB)), "n"(false)); + }} +}} +''') + +for n in n_list: + for transA in [0, 1]: + for transB in [0, 1]: + for t,s in [('half', 'f16'), ('__nv_bfloat16', 'bf16')]: + cog.outl(f''' +template<> +__device__ inline void mma_async_shmA<{t}, {n}, {transA}, {transB}>(float(&acc)[{n//8}][2][2], MatDesc::Raw descA, +MatDesc::Raw descB, bool accHasVal) +{{ + if (accHasVal) {{ + asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k16.f32.{s}.{s}{ptx_eol}" + "{acc_placeholder(n)},{ptx_eol}" // d + "%{n//2},{ptx_eol}" //a-desc + "%{n//2+1},{ptx_eol}" //b-desc + "%{n//2+2}, 1, 1, {transA}, {transB};{ptx_eol}" + : {acc_registers(n)} + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + }} + else {{ + asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k16.f32.{s}.{s}{ptx_eol}" + "{acc_placeholder(n)},{ptx_eol}" // d + "%{n//2},{ptx_eol}" //a-desc + "%{n//2+1},{ptx_eol}" //b-desc + "%{n//2+2}, 1, 1, {transA}, {transB};{ptx_eol}" + : {acc_registers(n)} + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + }} +}} +''') + if transA == 0: + cog.outl(f''' +template<> +__device__ inline void mma_async_regA<{t}, {n}, {transA}, {transB}>(float(&acc)[{n//8}][2][2], uint32_t +const(&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) +{{ + if (accHasVal) 
{{ + asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k16.f32.{s}.{s}{ptx_eol}" + "{acc_placeholder(n)},{ptx_eol}" // d + "{{{reg_list(n//2, n//2 + 4)}}},{ptx_eol}" //a + "%{n//2+4},{ptx_eol}" //b-desc + "%{n//2+5}, 1, 1, {transB};{ptx_eol}" + : {acc_registers(n)} + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), "l"(reinterpret_cast(descB)), "n"(true)); + }} + else {{ + asm volatile("wgmma.mma_async.sync.aligned.m64n{n}k16.f32.{s}.{s}{ptx_eol}" + "{acc_placeholder(n)},{ptx_eol}" // d + "{{{reg_list(n//2, n//2 + 4)}}},{ptx_eol}" //a + "%{n//2+4},{ptx_eol}" //b-desc + "%{n//2+5}, 1, 1, {transB};{ptx_eol}" + : {acc_registers(n)} + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), "l"(reinterpret_cast(descB)), "n"(false)); + }} +}} +''') +]]]*/ +// clang-format on + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 8, false, false>( + float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 8, false, false>( + float (&acc)[1][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 16, false, false>( + float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), 
"+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 16, false, false>( + float (&acc)[2][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 24, false, false>( + float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 24, false, false>( + float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), 
"+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 32, false, false>( + float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 32, false, false>( + float (&acc)[4][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 64, false, false>( + float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 64, false, false>( + float (&acc)[8][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + 
"%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 128, false, false>( + float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), 
+ "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 128, false, false>( + float (&acc)[16][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, 
%6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_fp8_e4m3, 256, false, false>( + float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), 
"+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), 
"+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_fp8_e4m3, 256, false, false>( + float (&acc)[32][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), 
+ "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), 
"+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float 
(&acc)[1][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 0, 0>( + float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 8, 0, 0>( + float (&acc)[1][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float 
(&acc)[1][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 0, 1>( + float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 8, 0, 1>( + float (&acc)[1][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "{%4, %5, %6, %7},\n" // a + "%8,\n" // b-desc + "%9, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 1, 
0>( + float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 8, 1, 1>( + float (&acc)[1][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3},\n" // d + "%4,\n" // a-desc + "%5,\n" // b-desc + "%6, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[2][2][2], uint32_t 
const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 0, 0>( + float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 16, 0, 0>( + float (&acc)[2][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // 
b-desc + "%10, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[2][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 0, 1>( + float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 16, 0, 1>( + float (&acc)[2][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); 
+ } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "{%8, %9, %10, %11},\n" // a + "%12,\n" // b-desc + "%13, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 1, 0>( + float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), 
"l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 16, 1, 1>( + float (&acc)[2][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7},\n" // d + "%8,\n" // a-desc + "%9,\n" // b-desc + "%10, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), 
"r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 0, 0>( + float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 24, 0, 0>( + float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + 
"%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 0, 1>( + float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 24, 0, 1>( + float (&acc)[3][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), 
"+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "{%12, %13, %14, %15},\n" // a + "%16,\n" // b-desc + "%17, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 1, 0>( + float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, 
%11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 24, 1, 1>( + float (&acc)[3][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11},\n" // d + "%12,\n" // a-desc + "%13,\n" // b-desc + "%14, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : 
"l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[4][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 0, 0>( + float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 32, 0, 0>( + float (&acc)[4][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), 
"+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[4][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, 
%19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 0, 1>( + float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 32, 0, 1>( + float (&acc)[4][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "{%16, %17, %18, %19},\n" // a + "%20,\n" // b-desc + "%21, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> 
+__device__ inline void mma_async_shmA<half, 32, 1, 0>(
+    float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) {
+  if (accHasVal) {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n"  // d
+        "%16,\n"  // a-desc
+        "%17,\n"  // b-desc
+        "%18, 1, 1, 1, 0;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+          "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+          "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
+          "+f"(acc[3][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
+  } else {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n"  // d
+        "%16,\n"  // a-desc
+        "%17,\n"  // b-desc
+        "%18, 1, 1, 1, 0;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+          "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+          "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
+          "+f"(acc[3][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
+  }
+}
+
+template <>
+__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 1, 0>(
+    float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) {
+  if (accHasVal) {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n"  // d
+        "%16,\n"  // a-desc
+        "%17,\n"  // b-desc
+        "%18, 1, 1, 1, 0;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+          "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+          "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
+          "+f"(acc[3][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(true));
+  } else {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n"  // d
+        "%16,\n"  // a-desc
+        "%17,\n"  // b-desc
+        "%18, 1, 1, 1, 0;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+          "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+          "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]),
+          "+f"(acc[3][1][1])
+        : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false));
+  }
+}
+
+template <>
+__device__ inline void mma_async_shmA<half, 32, 1, 1>(
+    float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) {
+  if (accHasVal) {
+    asm volatile(
+        "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n"
+        "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n"  // d
+        "%16,\n"  // a-desc
+        "%17,\n"  // b-desc
+        "%18, 1, 1, 1, 1;\n"
+        : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]),
+          "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]),
+          "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]),
"+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 32, 1, 1>( + float (&acc)[4][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15},\n" // d + "%16,\n" // a-desc + "%17,\n" // b-desc + "%18, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, 
%12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[8][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 0, 0>( + float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, 
%4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 64, 0, 0>( + float (&acc)[8][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, 
%35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[8][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 
1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 0, 1>( + float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), 
"+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 64, 0, 1>( + float (&acc)[8][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "{%32, %33, %34, %35},\n" // a + "%36,\n" // b-desc + "%37, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), 
"+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 1, 0>( + float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + 
"+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 64, 1, 1>( + float (&acc)[8][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), 
"+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31},\n" // d + "%32,\n" // a-desc + "%33,\n" // b-desc + "%34, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, 
%2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[16][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), 
"+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 0, 0>( + float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), 
"+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 128, 0, 0>( + float (&acc)[16][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), 
"+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, 
%15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : 
"l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[16][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), 
"+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 0, 1>( + float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), 
"+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 128, 0, 1>( + float (&acc)[16][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "{%64, %65, %66, %67},\n" // a + "%68,\n" // b-desc + "%69, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "{%64, %65, %66, %67},\n" 
// a + "%68,\n" // b-desc + "%69, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm 
volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 1, 0>( + float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), 
"+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + 
"+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 128, 1, 1>( + float (&acc)[16][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + 
"+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63},\n" // d + "%64,\n" // a-desc + "%65,\n" // b-desc + "%66, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + 
"%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, 
%29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast<uint64_t const&>(descA)), "l"(reinterpret_cast<uint64_t const&>(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<half, 256, 0, 0>( + float (&acc)[32][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool
accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), 
"+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + 
"+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 0, 0>( + float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), 
"+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), 
"+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 256, 0, 0>( + float (&acc)[32][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), 
"+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), 
"+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), 
"+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), 
"+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA( + float (&acc)[32][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), 
"+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), 
"+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 0, 1>( + float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), 
"+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 0, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), 
"+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_regA<__nv_bfloat16, 256, 0, 1>( + float (&acc)[32][2][2], uint32_t const (&a)[2][2][1], MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), 
"+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "{%128, %129, %130, %131},\n" // a + "%132,\n" // b-desc + "%133, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), 
"+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "r"(a[0][0][0]), "r"(a[0][1][0]), "r"(a[1][0][0]), "r"(a[1][1][0]), + "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, 
%117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, 
%105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 1, 0>( + float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, 
%56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, 
%41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 0;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA( + float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + 
"wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : 
"l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), 
"+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +template <> +__device__ inline void mma_async_shmA<__nv_bfloat16, 256, 1, 1>( + float (&acc)[32][2][2], MatDesc::Raw descA, MatDesc::Raw descB, bool accHasVal) { + if (accHasVal) { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), "+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), 
"+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(true)); + } else { + asm volatile( + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16\n" + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, " + "%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, " + "%44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, " + "%65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, " + "%86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, " + "%106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, " + "%123, %124, %125, %126, %127},\n" // d + "%128,\n" // a-desc + "%129,\n" // b-desc + "%130, 1, 1, 1, 1;\n" + : "+f"(acc[0][0][0]), "+f"(acc[0][0][1]), "+f"(acc[0][1][0]), "+f"(acc[0][1][1]), "+f"(acc[1][0][0]), + "+f"(acc[1][0][1]), "+f"(acc[1][1][0]), "+f"(acc[1][1][1]), "+f"(acc[2][0][0]), "+f"(acc[2][0][1]), + "+f"(acc[2][1][0]), "+f"(acc[2][1][1]), "+f"(acc[3][0][0]), "+f"(acc[3][0][1]), "+f"(acc[3][1][0]), + "+f"(acc[3][1][1]), "+f"(acc[4][0][0]), "+f"(acc[4][0][1]), "+f"(acc[4][1][0]), "+f"(acc[4][1][1]), + "+f"(acc[5][0][0]), "+f"(acc[5][0][1]), "+f"(acc[5][1][0]), "+f"(acc[5][1][1]), "+f"(acc[6][0][0]), + "+f"(acc[6][0][1]), "+f"(acc[6][1][0]), "+f"(acc[6][1][1]), "+f"(acc[7][0][0]), "+f"(acc[7][0][1]), + "+f"(acc[7][1][0]), "+f"(acc[7][1][1]), "+f"(acc[8][0][0]), "+f"(acc[8][0][1]), "+f"(acc[8][1][0]), + "+f"(acc[8][1][1]), "+f"(acc[9][0][0]), "+f"(acc[9][0][1]), "+f"(acc[9][1][0]), "+f"(acc[9][1][1]), + "+f"(acc[10][0][0]), "+f"(acc[10][0][1]), "+f"(acc[10][1][0]), "+f"(acc[10][1][1]), "+f"(acc[11][0][0]), + "+f"(acc[11][0][1]), "+f"(acc[11][1][0]), "+f"(acc[11][1][1]), "+f"(acc[12][0][0]), "+f"(acc[12][0][1]), + "+f"(acc[12][1][0]), "+f"(acc[12][1][1]), "+f"(acc[13][0][0]), "+f"(acc[13][0][1]), "+f"(acc[13][1][0]), + "+f"(acc[13][1][1]), "+f"(acc[14][0][0]), "+f"(acc[14][0][1]), "+f"(acc[14][1][0]), "+f"(acc[14][1][1]), + "+f"(acc[15][0][0]), "+f"(acc[15][0][1]), "+f"(acc[15][1][0]), "+f"(acc[15][1][1]), "+f"(acc[16][0][0]), + "+f"(acc[16][0][1]), "+f"(acc[16][1][0]), "+f"(acc[16][1][1]), "+f"(acc[17][0][0]), "+f"(acc[17][0][1]), + "+f"(acc[17][1][0]), "+f"(acc[17][1][1]), "+f"(acc[18][0][0]), "+f"(acc[18][0][1]), "+f"(acc[18][1][0]), + "+f"(acc[18][1][1]), "+f"(acc[19][0][0]), "+f"(acc[19][0][1]), "+f"(acc[19][1][0]), "+f"(acc[19][1][1]), + "+f"(acc[20][0][0]), "+f"(acc[20][0][1]), "+f"(acc[20][1][0]), "+f"(acc[20][1][1]), "+f"(acc[21][0][0]), + "+f"(acc[21][0][1]), "+f"(acc[21][1][0]), "+f"(acc[21][1][1]), "+f"(acc[22][0][0]), "+f"(acc[22][0][1]), + "+f"(acc[22][1][0]), "+f"(acc[22][1][1]), "+f"(acc[23][0][0]), "+f"(acc[23][0][1]), "+f"(acc[23][1][0]), + "+f"(acc[23][1][1]), "+f"(acc[24][0][0]), "+f"(acc[24][0][1]), "+f"(acc[24][1][0]), "+f"(acc[24][1][1]), + "+f"(acc[25][0][0]), "+f"(acc[25][0][1]), "+f"(acc[25][1][0]), "+f"(acc[25][1][1]), "+f"(acc[26][0][0]), + "+f"(acc[26][0][1]), "+f"(acc[26][1][0]), 
"+f"(acc[26][1][1]), "+f"(acc[27][0][0]), "+f"(acc[27][0][1]), + "+f"(acc[27][1][0]), "+f"(acc[27][1][1]), "+f"(acc[28][0][0]), "+f"(acc[28][0][1]), "+f"(acc[28][1][0]), + "+f"(acc[28][1][1]), "+f"(acc[29][0][0]), "+f"(acc[29][0][1]), "+f"(acc[29][1][0]), "+f"(acc[29][1][1]), + "+f"(acc[30][0][0]), "+f"(acc[30][0][1]), "+f"(acc[30][1][0]), "+f"(acc[30][1][1]), "+f"(acc[31][0][0]), + "+f"(acc[31][0][1]), "+f"(acc[31][1][0]), "+f"(acc[31][1][1]) + : "l"(reinterpret_cast(descA)), "l"(reinterpret_cast(descB)), "n"(false)); + } +} + +//[[[end]]] +} // namespace gmma diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/hostUtils.h b/onnxruntime/contrib_ops/cuda/bert/xqa/hostUtils.h new file mode 100644 index 0000000000000..9b9c342671dd0 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/hostUtils.h @@ -0,0 +1,29 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +inline cudaLaunchConfig_t makeLaunchConfig( + dim3 const& gridDim, dim3 const& ctaDim, size_t dynShmBytes, cudaStream_t stream, bool usePDL) { + static cudaLaunchAttribute pdlAttr; + pdlAttr.id = cudaLaunchAttributeProgrammaticStreamSerialization; + pdlAttr.val.programmaticStreamSerializationAllowed = (usePDL ? 1 : 0); + + cudaLaunchConfig_t cfg{gridDim, ctaDim, dynShmBytes, stream, &pdlAttr, 1}; + return cfg; +} diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/ldgsts.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/ldgsts.cuh new file mode 100644 index 0000000000000..26390ad30056f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/ldgsts.cuh @@ -0,0 +1,64 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cuda_hint.cuh" +#ifndef __CUDACC__ +#include +#endif +#include "barriers.cuh" + +namespace ldgsts { +// @fixme: prefetch makes it slower on sm_86. Try on other platforms. +template +__device__ inline void copyAsync( + void* dst, void const* src, uint32_t srcSize = size) // srcSize == 0 means filling with zeros. 
+{ + static_assert(size == 4 || size == 8 || size == 16); + if constexpr (size == 16) { + // asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], 16, %2;\n" :: + // "l"(__cvta_generic_to_shared(dst)), "l"(src), "r"(srcSize)); + asm volatile("cp.async.cg.shared.global [%0], [%1], 16, %2;\n" ::"l"(__cvta_generic_to_shared(dst)), "l"(src), + "r"(srcSize)); + } else { + asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"l"(__cvta_generic_to_shared(dst)), "l"(src), + "n"(size), "r"(srcSize)); + } +} + +__device__ inline void commitGroup() { + asm volatile("cp.async.commit_group;\n"); +} + +// wait until only targetNbInFlightGroups groups are still in-flight. +template +__device__ inline void waitGroup() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(targetNbInFlightGroups)); +} + +// noInc = false: increase expected arrive count, in additional to increasing arrive count +// noInc = true: increases arrive count but does not modify expected arrive count +__device__ inline void barArrive(CtaBarrier& bar, bool noInc = false) { + if (noInc) { + asm volatile("cp.async.mbarrier.arrive.noinc.shared.b64 [%0];\n" ::"l"(__cvta_generic_to_shared(&bar))); + } else { + asm volatile("cp.async.mbarrier.arrive.shared.b64 [%0];\n" ::"l"(__cvta_generic_to_shared(&bar))); + } +} + +} // namespace ldgsts diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/mha.h b/onnxruntime/contrib_ops/cuda/bert/xqa/mha.h new file mode 100644 index 0000000000000..5aa78aa242306 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/mha.h @@ -0,0 +1,141 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MHA_H_COMMON +#define MHA_H_COMMON +#ifndef __CUDACC__ +#include +#endif +#include "defines.h" +#include "utils.h" +#if SPEC_DEC +#include "specDec.h" +#endif + +using CacheElem = ElemType; +// These depend on HEAD_ELEMS which is Global +constexpr uint32_t validElemsPerHead = HEAD_ELEMS; +constexpr bool isMLA = IS_MLA; +static_assert((isMLA || validElemsPerHead <= 256) && (sizeof(CacheElem) * validElemsPerHead) % 16 == 0); +constexpr uint32_t headElems = validElemsPerHead <= 64 ? 64 : (validElemsPerHead <= 128 ? 128 : (isMLA ? 
576 : 256)); +static_assert(headElems == 64 || headElems == 128 || headElems == 256 || headElems == 576, "not implemented"); +constexpr uint32_t beamWidth = BEAM_WIDTH; + +#if SPEC_DEC +__device__ constexpr uint32_t rowsPerBlock = M_TILESIZE; +#endif + +inline constexpr bool useSpecDec = SPEC_DEC; + +using InputElem = INPUT_ELEM; +using InputElem2 = INPUT_ELEM2; +#if !(SPEC_DEC) +constexpr uint32_t inputSeqLen = 1; // speculative decoding if > 1 +#endif + +constexpr bool useKVCache = USE_KV_CACHE; + +using SeqLenDataType = uint32_t; +#endif + +// Dependent definitions +#ifndef MHA_H_DEPENDENT +#define MHA_H_DEPENDENT +// This depends on HEAD_GRP_SIZE macro which changes per namespace +constexpr uint32_t headGrpSize = HEAD_GRP_SIZE; +#endif + +// Common Part 2 +#ifndef MHA_H_COMMON_2 +#define MHA_H_COMMON_2 + +constexpr bool usePagedKVCache = USE_PAGED_KV_CACHE; +constexpr uint32_t tokensPerPage = TOKENS_PER_PAGE; + +using IOHead = Vec; +using InputHead = IOHead; +using GMemCacheHead = Vec; + +constexpr uint32_t validElemsPerKHead = validElemsPerHead; +constexpr bool lowPrecOutput = LOW_PREC_OUTPUT; + +#if IS_MLA +constexpr uint32_t validElemsPerVHead = 512; +static_assert(lowPrecOutput == false); +using OutputHead = Vec<__nv_bfloat16, validElemsPerVHead>; +#else +constexpr uint32_t validElemsPerVHead = validElemsPerHead; +using OutputHead = mha::conditional_t; +#endif +using OutputElem = OutputHead::Elem; + +using PaddedInputHead = Vec; +using PaddedCacheHead = Vec; + +// impl detail, may be moved to mha.cu/mha_sm90.cu +constexpr bool isHeadPadded = (validElemsPerHead != headElems); + +constexpr bool useInputKV = USE_INPUT_KV; + +using GMemKVCacheHead = mha::conditional_t; + +using KVCachePageIndex = int32_t; // shape: KVCacheHead[nbKHeads][tokensPerPage]. Page index in the global pool of pages + +constexpr bool allowSlidingWindow = SLIDING_WINDOW; + +struct BeamSearchParams { + uint32_t const* __restrict__ indices; // shape: [batchSize][beamWidth][capacity] + uint32_t capacity; + uint32_t const* __restrict__ ctxLenList; // shape: [batchSize][beamWidth]. Should be [batchSize] but we have to + // match trt-llm API. 
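+  // Note: for every (batch, beam, position), `indices` stores the index of the beam whose cached
+  // K/V entry should be read at that position (values are in [0, beamWidth)); the IndexedHeadPtrImpl
+  // helpers in mhaUtils.cuh use it to gather heads from the correct beam. `capacity` is the allocated
+  // extent of the last dimension of `indices`.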
+}; + +// function declarations removed to prevent ambiguity with namespaces +// they are defined in mha_impl.cuh included in xqa_loader.cu + +#if STATIC_NB_K_HEADS +constexpr uint32_t nbKHeads = NB_K_HEADS; + +constexpr uint32_t nbVHeads = nbKHeads; +constexpr uint32_t nbQHeads = nbKHeads * headGrpSize; +constexpr uint32_t nbQKVHeads = nbQHeads + nbKHeads + nbVHeads; +#endif +constexpr uint32_t cacheElemSize = sizeof(CacheElem); +constexpr uint32_t inputElemSize = sizeof(InputElem); +constexpr uint32_t outputElemSize = sizeof(OutputElem); + +constexpr uint32_t ioHeadBytes = sizeof(IOHead); +constexpr uint32_t gmemCacheHeadBytes = sizeof(GMemCacheHead); + +constexpr uint32_t paddedInputHeadBytes = sizeof(PaddedInputHead); +constexpr uint32_t paddedCacheHeadBytes = sizeof(PaddedCacheHead); + +constexpr uint32_t allowMultiBlockMode = ALLOW_MULTI_BLOCK_MODE; + +enum class XQAKernelType : int32_t { + kAMPERE_WARP_SPECIALIZED = 0, + kHOPPER_WARP_SPECIALIZED = 1, + kSM120_MLA = 2 +}; + +#ifdef GENERATE_CUBIN +#define CUBIN_EXPORT extern "C" +#else +#define CUBIN_EXPORT static +#endif + +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/mhaUtils.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/mhaUtils.cuh new file mode 100644 index 0000000000000..ac37a2d1097ab --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/mhaUtils.cuh @@ -0,0 +1,435 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "ldgsts.cuh" +#include "mha.h" +#include "utils.cuh" + +// for beam search +template +struct IndexedHeadPtrImpl { + static_assert(tokensPerPage != 0 && nbPages != 0); + uint32_t const* indices; // values are in range [0, beamWidth) + Head* pool; + Vec const* pageIndices; + uint32_t nbKHeads; + uint32_t offset; // applied onto pool + pointers + + __device__ inline Head& operator[](uint32_t i) const { + return *(*this + i); + } + + __device__ inline Head* operator+(uint32_t i) const { + assert(indices[i] < beamWidth); + assert(nbPages == 1 || offset % tokensPerPage == 0); + auto const pageIdx = pageIndices[indices[i]][nbPages == 1 ? 0U : i / tokensPerPage]; + return pool + (tokensPerPage * nbKHeads * pageIdx + offset + i % tokensPerPage); + } +}; + +template +struct IndexedHeadPtrImpl { + uint32_t const* indices; // values are in range [0, beamWidth) + Head* pointer; + uint32_t offset; + uint32_t beamStride; + + __device__ inline Head& operator[](uint32_t i) const { + return *(*this + i); + } + + __device__ inline Head* operator+(uint32_t i) const { + assert(indices[i] < beamWidth); + return pointer + (beamStride * indices[i] + offset + i); + } +}; + +template +using IndexedHeadPtr = IndexedHeadPtrImpl; + +// for beamWidth = 1 +template +struct HeadPtr { + static_assert(tokensPerPage != 0 && nbPages != 0); + Head* pool; + Vec pageIndices; + uint32_t nbKHeads; + uint32_t offset; // offset inside the first page. 
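+  // Paged addressing: token i lives on page pageIndices[i / tokensPerPage] at slot i % tokensPerPage,
+  // and consecutive pages of one sequence need not be contiguous in the global pool. For example
+  // (illustrative numbers only), with tokensPerPage = 32 and nbKHeads = 8, token i = 70 maps to the
+  // page entry 70 / 32 == 2 at in-page slot 6, i.e. pool + 32 * 8 * pageIndices[2] + offset + 6 in
+  // the default layout below. A page index with bit 31 set marks an unallocated page and yields
+  // nullptr, which the async copy helpers turn into a zero-fill.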
+ + __device__ inline Head& operator[](uint32_t i) const { + return *(*this + i); + } + + __device__ inline Head* operator+(uint32_t i) const { +#if PAGED_KV_CACHE_LAYOUT == 1 && USE_PAGED_KV_CACHE + auto const pageIdx = pageIndices[nbPages == 1 ? 0U : i / tokensPerPage]; + return (pageIdx & (1U << 31)) + ? nullptr + : pool + (tokensPerPage * nbKHeads * pageIdx + offset + (i % tokensPerPage) * nbKHeads); +#else + assert(nbPages == 1 || offset % tokensPerPage == 0); + auto const pageIdx = pageIndices[nbPages == 1 ? 0U : i / tokensPerPage]; + return (pageIdx & (1U << 31)) ? nullptr + : pool + (tokensPerPage * nbKHeads * pageIdx + offset + i % tokensPerPage); +#endif + } +}; + +template +struct HeadPtr : TinyPtr { +}; + +// template +// #if BEAM_WIDTH == 1 +// using SrcHeadPtr = TinyPtr; +// #else +// using SrcHeadPtr = IndexedHeadPtr; +// #endif + +// @fixme: give evict first hint for last part. +template +__device__ inline void copyPartialHeadsAsync( + Warp const& warp, Array2D& dst, + uint32_t dstHeadOffset, SrcHeadPtr const& src, uint32_t idxPart, uint32_t nbAvailHeads = maxNbCopiedHeads, + LocalHeadIdxMap&& localHeadIdxMap = [](uint32_t x) { return x; }) { + static_assert(maxNbCopiedHeads <= dstNbHeads); + assert(idxPart < nbPartsPerHead); + assert(dstHeadOffset + maxNbCopiedHeads <= dstNbHeads); + assert(sizeof(Head) * (src.offset + maxNbCopiedHeads) <= (1ULL << 32)); + assert(!isFull || nbAvailHeads >= maxNbCopiedHeads); + constexpr uint32_t headBytes = sizeof(Head); + constexpr uint32_t partBytes = exactDiv(headBytes, nbPartsPerHead); + constexpr uint32_t warpLdBytes = partBytes * maxNbCopiedHeads; + constexpr uint32_t thrdLdBytes = exactDiv(warpLdBytes, warp_size); + assertIsPowerOf2(); + static_assert(thrdLdBytes >= grainBytes); + // a segment is responsible for loading one partial head collaboratively + constexpr uint32_t thrdsPerSeg = exactDiv(partBytes, grainBytes); + static_assert(thrdsPerSeg > 0 && thrdsPerSeg <= warp_size); + assertIsPowerOf2(); + assert(__shfl_sync(0xFU << (laneId() / 4 * 4), src.offset, 0, 4) == src.offset); + auto const warpLane = laneId(); + uint32_t const segIdx = warpLane / thrdsPerSeg; + uint32_t const segLane = warpLane % thrdsPerSeg; + constexpr uint32_t partsPerWarpInst = exactDiv(grainBytes * warp_size, partBytes); +#pragma unroll + for (uint32_t i = 0; i < thrdLdBytes / grainBytes; i++) { + uint32_t const idxHeadLocal = partsPerWarpInst * i + segIdx; + assert(idxHeadLocal < maxNbCopiedHeads); + bool const isHeadInBound = isFull || (idxHeadLocal < nbAvailHeads); + constexpr uint32_t grainsPerPart = exactDiv(partBytes, grainBytes); + using SrcHead = mha::decay_t; + constexpr uint32_t nbValidGrains = exactDiv(sizeof(SrcHead), grainBytes); + uint32_t const idxGrainInsideHead = grainsPerPart * idxPart + segLane; + bool const isGrainInBound = (!isHeadPadded || idxGrainInsideHead < nbValidGrains); + SrcHead const* const pSrcHead = src + localHeadIdxMap(idxHeadLocal); + bool const isValidPage = (pSrcHead != nullptr); + LdGrain const* const pSrc = reinterpret_cast(pSrcHead) + idxGrainInsideHead; + LdGrain* const pDst = &dst.template at(dstHeadOffset + idxHeadLocal, segLane); + assert(!hasBankConflict(pDst)); + ldgsts::copyAsync(pDst, pSrc, isValidPage && isHeadInBound && isGrainInBound ? 
grainBytes : 0u); + } +} + +template +__device__ inline void copyHeadsAsync( + uint32_t idxWarp, Array2D& dst, SrcHeadPtr const& src, + uint32_t nbAvailHeads = maxNbCopiedHeads, LocalHeadIdxMap&& localHeadIdxMap = [](uint32_t x) { return x; }) { + assert(idxWarp < nbWarps); + Warp const& warp = this_warp(); + constexpr uint32_t maxNbHeadsPerWarp = exactDiv(maxNbCopiedHeads, nbWarps); + uint32_t const dstHeadOffset = maxNbHeadsPerWarp * idxWarp; + uint32_t const warpNbAvailHeads = (dstHeadOffset < nbAvailHeads ? nbAvailHeads - dstHeadOffset : 0); + constexpr uint32_t idxPart = 0; + copyPartialHeadsAsync(warp, dst, dstHeadOffset, src, + idxPart, warpNbAvailHeads, [&](uint32_t x) { return localHeadIdxMap(dstHeadOffset + x); }); +} + +template +__device__ inline void copyGrains( + uint32_t idxWarp, LdGrain* dst, LdGrain const* src, uint32_t totalNbGrains = maxTotalNbGrains) { + assert((isFull && totalNbGrains == maxTotalNbGrains) || (!isFull && totalNbGrains <= maxTotalNbGrains)); + constexpr uint32_t nbThrds = warp_size * nbWarps; + uint32_t const tid = warp_size * idxWarp + laneId(); +// copy output to scratch +#pragma unroll + for (uint32_t i = 0; i < divUp(maxTotalNbGrains, nbThrds); i++) { + uint32_t const idx = nbThrds * i + tid; + if (!(isFull && maxTotalNbGrains % nbThrds == 0) && idx >= totalNbGrains) { + break; + } + if constexpr (isAsync) { + ldgsts::copyAsync(&dst[idx], &src[idx], grainBytes); + } else { + dst[idx] = src[idx]; + } + } +} + +// with ldmatrix, what we load for fp8 cache is T0:{e0,e1,e2,e3}; T1:{e4, e5, e6, e7}; T2:{e8,e9,e10,e11}; T3:{e12, e13, +// e14, e15}; When casted to fp16, it will be T0:{e0, e1}; T1{e4, e5};... | T0:{e2, e3}; T1{e6, e7}; ... We need to +// reorder Q to match that order. isFwd=false to revert the reorder. +template +__device__ inline void reorder16bQHeadsToMatch8bKCache(uint32_t idxWarp, Array2D& qHeads) { + assert(idxWarp < nbWarps); + constexpr uint32_t nbWarpIters = exactDiv(exactDiv(cols, 2) * rows, warp_size); // warps * iters + constexpr uint32_t nbWorkingWarps = mha::min(nbWarps, nbWarpIters); + if (idxWarp >= nbWorkingWarps) { + return; + } + static_assert(cols % 2 == 0); + uint32_t const tid = warp_size * idxWarp + laneId(); + constexpr uint32_t iterCols = exactDiv(warp_size * nbWorkingWarps, rows) * 2; + static_assert(cols % iterCols == 0, "fix this by reducing nbWorkingWarps, or use divUp and add runtime check"); + constexpr uint32_t nbIters = exactDiv(cols, iterCols); + static_assert(nbIters == exactDiv(nbWarpIters, nbWorkingWarps)); + uint32_t const r = tid % rows; + uint32_t const cInit = tid / rows * 2; +#pragma unroll + for (uint32_t n = 0; n < nbIters; n++) { + uint32_t const c = cInit + iterCols * n; + LdGrain const src[2] = { + qHeads.template at(r, c), + qHeads.template at(r, c + 1), + }; + auto const& s = reinterpret_cast const&>(src); + if constexpr (isFwd) { + qHeads.template at(r, c) = LdGrain{s[0], s[2], s[4], s[6]}; + qHeads.template at(r, c + 1) = LdGrain{s[1], s[3], s[5], s[7]}; + } else { + qHeads.template at(r, c) = LdGrain{s[0], s[4], s[1], s[5]}; + qHeads.template at(r, c + 1) = LdGrain{s[2], s[6], s[3], s[7]}; + } + } +} + +template +struct KVCacheList; + +template <> +struct KVCacheList { +#if PAGED_KV_CACHE_LAYOUT == 1 + GMemCacheHead* kCacheVLLM; + GMemCacheHead* vCacheVLLM; +#else + GMemKVCacheHead* pool; +#endif + KVCachePageIndex const* kvCachePageList; // shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq]. 
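+  // With the default layout the flat index for (idxReq, idxBeam, K-or-V, idxPage) is
+  // beamWidth * 2 * maxNbPagesPerSeq * idxReq + 2 * maxNbPagesPerSeq * idxBeam
+  // + maxNbPagesPerSeq * (isK ? 0 : 1) + idxPage; PAGED_KV_CACHE_LAYOUT == 1 instead keeps one
+  // per-request list indexed by maxNbPagesPerSeq * idxReq + idxPage (see getPage below).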
+ SeqLenDataType const* seqLenList; // shape: [batchSize][beamWidth] (for compatibility) + uint32_t maxNbPagesPerSeq; +}; + +template <> +struct KVCacheList { + GMemKVCacheHead* kData; // shape: KVCacheHead[batchSize][beamWidth][nbKHeads][capacity] + GMemKVCacheHead* vData; // shape: KVCacheHead[batchSize][beamWidth][nbKHeads][capacity] + SeqLenDataType const* seqLenList; // shape: [batchSize][beamWidth] (for compatibility) + uint32_t capacity; + bool isBSNH; + uint32_t extraSeqLen; +}; + +__device__ inline uint32_t getSeqLen(uint32_t const* seqLenList, uint32_t idxReq, uint32_t extraSeqLen) { + uint64_t cachePolicy; + asm("createpolicy.fractional.L2::evict_last.b64 %0;\n" : "=l"(cachePolicy)); + uint32_t len; + asm("ld.global.nc.L1::evict_last.L2::cache_hint.L2::256B.b32 %0, [%1], %2;\n" + : "=r"(len) + : "l"(&seqLenList[idxReq * beamWidth]), "l"(cachePolicy)); + for (uint32_t i = 0; i < beamWidth; i++) { + assert(len == seqLenList[idxReq * beamWidth + i]); + } + return len + extraSeqLen; +} + +template +__device__ inline uint32_t getCacheSeqLen(KVCacheList const& cacheList, uint32_t idxReq) { + return getSeqLen(cacheList.seqLenList, idxReq, cacheList.extraSeqLen); +} + +__device__ inline uint32_t getCtxCacheSeqLen(BeamSearchParams const& beamSearchParams, uint32_t idxReq) { + return getSeqLen(beamSearchParams.ctxLenList, idxReq, 0); +} + +template +__device__ inline Vec getPage(KVCacheList const& cacheList, bool isK, + uint32_t idxReq, uint32_t idxBeam, uint32_t idxPageBeg, uint32_t nbPages) { + auto const maxNbPagesPerSeq = cacheList.maxNbPagesPerSeq; + Vec ret; +#pragma unroll + for (uint32_t i = 0; i < nbLoadedPages; i++) { + uint32_t const idxPage = idxPageBeg + i; +#if PAGED_KV_CACHE_LAYOUT == 1 && USE_PAGED_KV_CACHE + ret[i] = (idxPage < nbPages ? cacheList.kvCachePageList[maxNbPagesPerSeq * idxReq + idxPage] : kBAD_PAGE_INDEX); +#else + ret[i] = (idxPage < nbPages ? cacheList.kvCachePageList[beamWidth * 2 * maxNbPagesPerSeq * idxReq + 2 * maxNbPagesPerSeq * idxBeam + maxNbPagesPerSeq * (isK ? 0U : 1U) + idxPage] + : kBAD_PAGE_INDEX); +#endif + } + return ret; +} + +template +__device__ inline void loadPagesForBeamSearchAsync(uint32_t idxWarp, + Vec, beamWidth>& dst, KVCacheList const& cacheList, bool isK, + uint32_t idxReq, uint32_t idxPageBeg, uint32_t nbPages) { + assert(idxWarp < nbWarps); + auto const maxNbPagesPerSeq = cacheList.maxNbPagesPerSeq; + static_assert(beamWidth < warp_size); + auto const tid = warp_size * idxWarp + laneId(); + auto const idxBeam = tid / nbLoadedPages; + auto const idxLoadedPage = tid % nbLoadedPages; + static_assert(warp_size * nbWarps >= beamWidth * nbLoadedPages); + if (idxBeam < beamWidth) { + constexpr uint32_t nbBytes = sizeof(KVCachePageIndex); + uint32_t const idxPage = idxPageBeg + idxLoadedPage; + ldgsts::copyAsync(&dst[idxBeam][idxLoadedPage], + &cacheList.kvCachePageList[beamWidth * 2 * maxNbPagesPerSeq * idxReq + 2 * maxNbPagesPerSeq * idxBeam + (isK ? 0U : maxNbPagesPerSeq) + idxPage], + idxPage < nbPages ? nbBytes : 0U); + } +} + +template +__device__ inline void loadIndicesForBeamSearchAsync(uint32_t idxWarp, Vec& dst, + BeamSearchParams const& params, uint32_t idxReq, uint32_t idxBeam, uint32_t uniformSeqOffset, uint32_t seqLen) { + constexpr uint32_t nbThreads = warp_size * nbWarps; + // constexpr uint32_t indicesPerInst = mha::min(exactDiv(grainBytes, sizeof(uint32_t)), divUp(length, nbThreads)); + // // @fixme: std::bit_ceil on length + constexpr uint32_t indicesPerInst = 1U; // to handle unaligned case. 
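+  // Each thread moves one 4-byte index per ldgsts::copyAsync call; positions at or past seqLen pass a
+  // copy size of 0, so out-of-range lanes become no-ops without extra branching.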
+ constexpr uint32_t bytesPerInst = sizeof(uint32_t) * indicesPerInst; + assertIsPowerOf2(); + uint32_t const capacity = params.capacity; + uint32_t const srcOffset = (idxReq * beamWidth + idxBeam) * capacity + uniformSeqOffset; + uint32_t const tid = warp_size * idxWarp + laneId(); + constexpr uint32_t indicesPerIter = indicesPerInst * nbThreads; +#pragma unroll + for (uint32_t i = 0; i < length / indicesPerIter; i++) { + uint32_t const idx = indicesPerIter * i + indicesPerInst * tid; + ldgsts::copyAsync(&dst[idx], ¶ms.indices[srcOffset + idx], + (isFullTile || uniformSeqOffset + idx < seqLen) ? bytesPerInst : 0); + } + if constexpr (length % indicesPerIter != 0) { + uint32_t const idx = indicesPerIter * (length / indicesPerIter) + indicesPerInst * tid; + if (idx < length) { + ldgsts::copyAsync(&dst[idx], ¶ms.indices[srcOffset + idx], + (isFullTile || uniformSeqOffset + idx < seqLen) ? bytesPerInst : 0); + } + } +} + +__device__ inline InputElem2 float2ToInputElem2(float2 src) { + InputElem2 dst; + if constexpr (mha::is_same_v) { + reinterpret_cast(dst) = __float22half2_rn(src); + return dst; + } else if constexpr (mha::is_same_v) { + reinterpret_cast(dst) = __float22bfloat162_rn(src); + return dst; + } else if constexpr (mha::is_same_v) { + reinterpret_cast<__nv_fp8x2_e4m3&>(dst) = __nv_fp8x2_e4m3{src}; + return dst; + } else { + trap(); + } +} + +template +using TokenOrNone = RealTypeOrNone; + +template +__device__ inline TokenOrNone arrive(CtaBarrier* pBarrier) { + if constexpr (real) { + return pBarrier->arrive(); + } else { + assert(pBarrier == nullptr); + return None{}; + } +} + +template +__device__ inline void wait(CtaBarrier* pBarrier, TokenOrNone&& token) { + if constexpr (real) { + pBarrier->wait(mha::move(token)); + } else { + assert(pBarrier == nullptr); + __syncwarp(); + } +} + +template +__device__ inline bool test_wait(CtaBarrier* pBarrier, TokenOrNone&& token) { + if constexpr (real) { + uint32_t complete; + asm volatile( + "{\n" + ".reg .pred complete;\n" + "mbarrier.test_wait.acquire.cta.shared::cta.b64 complete, [%1], %2;\n" + "selp.b32 %0, 1, 0, complete;\n}\n" + : "=r"(complete) + : "l"(__cvta_generic_to_shared(pBarrier)), "l"(token)); + return bool(complete); + } else { + return false; + } +} + +template +using ParityOrNone = RealTypeOrNone; + +template +__device__ inline void wait_parity(CtaBarrier* pBarrier, ParityOrNone parity) { + assert(real == (pBarrier != nullptr)); + if constexpr (real) { + pBarrier->wait_parity(parity); + } else { + __syncwarp(); + } +} + +template +__device__ inline bool test_wait_parity(CtaBarrier* pBarrier, ParityOrNone parity) { + assert(real == (pBarrier != nullptr)); + if constexpr (real) { +#if USE_CUSTOM_BARRIER + return pBarrier->test_wait_parity(parity); +#else + return pBarrier->try_wait_parity_for(parity, cuda::std::chrono::nanoseconds(0)); +#endif + } else { + return false; + } +} + +template +__device__ inline ParityOrNone& flip(ParityOrNone& flip) { + if constexpr (real) { + flip = !flip; + } + return flip; +} + +template +__device__ inline ParityOrNone getAndFlip(ParityOrNone& flag) { + ParityOrNone const ret = flag; + if constexpr (real) { + flag = !flag; + } + return ret; +} diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/mha_components.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/mha_components.cuh new file mode 100644 index 0000000000000..f17f326ef9a66 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/mha_components.cuh @@ -0,0 +1,172 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION 
& AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "mma.cuh" +#include "utils.cuh" + +using InstAcc = Array2D; + +template +using WarpAccT = Array2D; + +template +__device__ inline void applyMask( + Warp const& warp, Array2D& acc, uint32_t validColBeg, uint32_t validColEnd) { + uint32_t const idxInQuad = laneId() % 4; + uint32_t const idxQuad = laneId() / 4; +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < InstAcc::cols; j++) { + uint32_t const col = 8 * n + InstAcc::cols * idxInQuad + j; + if (col >= validColBeg && col < validColEnd) { + continue; + } +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < InstAcc::rows; i++) { + acc(m, n)(i, j) = mha::numeric_limits::lowest(); + } + } + } + } +} + +template +using QuadRegRowMaxT = Vec; // data is replicated across 4 threads in a MMA quad. +template +using ThrdRegRowMaxT = Vec; // unlike QuadRegRowMax, not replicated. +template +using UniformRescaleMaskT = Vec; // uniform and stored in UR +inline constexpr uint32_t quadPerWarp = warp_size / 4; + +// idxMat8 is the reduced row index in 8-row unit. 
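+// Each MMA quad (4 consecutive lanes) owns the accumulators of an 8-row slice. replicateValForQuad /
+// replicateForQuad broadcast the per-lane row statistics with __shfl_sync so every lane of a quad sees
+// the same value, and dedupFromQuad undoes that replication back to one value per lane.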
+template +__device__ inline float replicateValForQuad(Warp const& warp, Vec const& src, uint32_t idxMat8) { + assertWarpConverged(); + uint32_t const i = idxMat8 / 4; + uint32_t const j = idxMat8 % 4; + return __shfl_sync(~0U, src[i], quadPerWarp * j + laneId() / 4); +} + +template +__device__ inline QuadRegRowMaxT replicateForQuad(Warp const& warp, Vec const& src) { + assertWarpConverged(); + QuadRegRowMaxT dst{}; +#pragma unroll + for (uint32_t i = 0; i < src.size; i++) { +#pragma unroll + for (uint32_t j = 0; j < 4; j++) { + dst[i * 4 + j] = __shfl_sync(~0U, src[i], quadPerWarp * j + laneId() / 4); + assert(__float_as_int(dst[i * 4 + j]) == __float_as_int(replicateValForQuad(warp, src, i * 4 + j))); + } + } + return dst; +} + +template +__device__ inline ThrdRegRowMaxT dedupFromQuad(Warp const& warp, Vec const& src) { +#ifndef NDEBUG + for (uint32_t i = 0; i < src.size; i++) { + assert(__float_as_int(src[i]) == __float_as_int(__shfl_sync(~0U, src[i], laneId() / 4 * 4))); + } +#endif + ThrdRegRowMaxT dst{}; + uint32_t const lane = laneId(); + uint32_t const idxMat = lane / 8; + uint32_t const idxRow = lane % 8; +#pragma unroll + for (uint32_t i = 0; i < dst.size; i++) { +#pragma unroll + for (uint32_t j = 0; j < 4; j++) { + float const val = __shfl_sync(~0U, src[i * 4 + j], 4 * idxRow); + if (idxMat == j) { + dst[i] = val; + } + } + } +#ifndef NDEBUG // refcheck + QuadRegRowMaxT rep = replicateForQuad(warp, dst); +#pragma unroll + for (uint32_t i = 0; i < n; i++) { + assert(__float_as_int(src[i]) == __float_as_int(rep[i])); + __syncwarp(); + } +#endif + return dst; +} + +template +__device__ inline ThrdRegRowMaxT computeRowSumF8( + Warp const& warp, Array2D, exactDiv(tileM, 16), exactDiv(tileN, 16)> const& src) { + using WarpAcc = WarpAccT; + WarpAcc acc{}; + Vec<__nv_fp8x2_e4m3, 2> const bWord = {__nv_fp8x2_e4m3{float2{1, 1}}, __nv_fp8x2_e4m3{float2{1, 1}}}; + uint32_t const b[2][1] = {reinterpret_cast(bWord), reinterpret_cast(bWord)}; +#pragma unroll + for (uint32_t i = 0; i < WarpAcc::rows; i++) { +#pragma unroll + for (uint32_t k = 0; k < exactDiv(src.cols, 2); k++) { + mma<__nv_fp8_e4m3>(reinterpret_cast(acc(i, 0)), + reinterpret_cast(src(i, k * 2)), b); + } + } + QuadRegRowMaxT rowSum; + for (uint32_t i = 0; i < WarpAcc::rows; i++) { + for (uint32_t m = 0; m < InstAcc::rows; m++) { +#ifndef NDEBUG + assert(__float_as_int(acc(i, 0)(m, 0)) == __float_as_int(acc(i, 0)(m, 1))); + assert(__float_as_int(acc(i, 0)(m, 0)) == __float_as_int(__shfl_sync(~0U, acc(i, 0)(m, 0), laneId() / 4 * 4))); +#endif + rowSum[i * InstAcc::rows + m] = acc(i, 0)(m, 0); + } + } + return dedupFromQuad(warp, rowSum); +} + +template +__device__ inline ThrdRegRowMaxT computeRowSumF32(Warp const& warp, WarpAccT const& src) { + QuadRegRowMaxT rowSum{}; +#pragma unroll + for (uint32_t n = 0; n < src.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < InstAcc::cols; j++) { +#pragma unroll + for (uint32_t m = 0; m < src.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < InstAcc::rows; i++) { + if (n == 0 && j == 0) { + rowSum[m * InstAcc::rows + i] = src(m, n)(i, j); + } else { + rowSum[m * InstAcc::rows + i] += src(m, n)(i, j); + } + } + } + } + } + uint32_t const lane = laneId(); +#pragma unroll + for (uint32_t mask = 2; mask != 0; mask /= 2) { +#pragma unroll + for (uint32_t i = 0; i < rowSum.size; i++) { + rowSum[i] += __shfl_xor_sync(~0U, rowSum[i], mask); + } + } + return dedupFromQuad(warp, rowSum); +} diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/mha_impl.cuh 
b/onnxruntime/contrib_ops/cuda/bert/xqa/mha_impl.cuh new file mode 100644 index 0000000000000..810c33adb5364 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/mha_impl.cuh @@ -0,0 +1,2606 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "cuda_hint.cuh" +#include "defines.h" +#if !(IS_MLA) +#include "ldgsts.cuh" +#include "mha.h" +#include "mhaUtils.cuh" +#include "mha_components.cuh" +#include "mma.cuh" +#include "utils.cuh" + +#include +#include +#ifndef GENERATE_CUBIN +#include "hostUtils.h" +#include +#ifndef NDEBUG +#include +#endif +#endif + +// There are 4 ways to pass ctaRowMax backward from gemm1 warps to gemm0 warps: +// 1. Protect with xFwdBarriers+xBwdBarriers. This way, ctaRowMax is available to gemm0 warps together with x tiles and +// warpRowMax/warpRowSum. But ctaRowMax is required before warp tile online softmax, while the other buffers is needed +// only after online softmax. So xBwdBarriers wait will need to be moved before online softmax. +// 2. Similar to approach 1, but we add an additional register copy of ctaRowMax in gemm0 warps. It's loaded from smem +// ctaRowMax after warp tile online softmax, so the current warp tile can't use it. But we can pass it to next +// iteration so softmax of next tile can use it. The update will be delayed by 1 more iteration and we need one or two +// more registers. Alternatively, put the extra copy in shared memory, so we have double buffer for ctaRowMax. +// 3. Protected with dedicated backward barriers (xFwdBarriers + ctaRowmaxBwdBarriers). Then we don't have drawbacks of +// 1 or 2, but we need extra smem barriers and extra arrive/wait instructions. +// 4. No protection, just use volatile read/write. This approach gives most timely update and has lowest cost, but the +// result is non-deterministic up to an small numeric error. +// #define CTA_ROW_MAX_BACKWARD_METHOD 4 +// 1 is 8% slower than 4. 2/3 are 10% slower than 4. +#define CTA_ROW_MAX_BACKWARD_METHOD 1 + +static_assert(inputElemSize >= cacheElemSize); + +constexpr uint32_t cacheElemsPerGrain = exactDiv(grainBytes, cacheElemSize); +constexpr uint32_t inputElemsPerGrain = exactDiv(grainBytes, inputElemSize); +constexpr bool enableMicroFastPath = false; + +// x: horizontal stacking for cta horizontal tile size +// y: vertical stacking for cta vertical tile size +// z: must be 2 for warp specialization. 
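+// With {4, 1, 2} a CTA runs 4 * 1 * 2 = 8 warps (256 threads); the two z slices correspond to the two
+// specialized warp groups (gemm0 and gemm1) discussed in the comment above.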
+constexpr uint3 ctaShapeInWarps = {4, 1, 2}; + +static_assert(ctaShapeInWarps.z == 2); // for warp specialization +constexpr uint32_t nbWarpsPerCta = ctaShapeInWarps.x * ctaShapeInWarps.y * ctaShapeInWarps.z; +constexpr uint32_t ctaSize = warp_size * nbWarpsPerCta; + +#if SPEC_DEC +// Use 32 row size +constexpr uint32_t nbValidRows = rowsPerBlock; +static_assert(nbValidRows <= 32u); +#else +constexpr uint32_t nbValidRows = headGrpSize * beamWidth; +#endif +constexpr uint2 warpTile = {64, roundUp(nbValidRows, 16U)}; +static_assert(nbValidRows <= warpTile.y); + +constexpr uint32_t gemm1WarpsPerGrp = exactDiv(headElems, warpTile.x); +constexpr uint32_t gemm1NbWarpGrps = exactDiv(ctaShapeInWarps.x, gemm1WarpsPerGrp); // warp groups split along seqLen dim. + +constexpr uint2 ctaTile = {warpTile.x * ctaShapeInWarps.x, // if .x is greater than headSize, then gemm1 uses split-K + warpTile.y* ctaShapeInWarps.y}; + +constexpr uint32_t cvtExpansion = exactDiv(inputElemSize, cacheElemSize); + +#ifndef __CUDA_ARCH__ +constexpr uint32_t preferedKHeadPartBytes = 64; +__constant__ constexpr uint32_t cacheVTileSeqLen = 32; +#else +#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 1200 +constexpr uint32_t preferedKHeadPartBytes = 64; +__constant__ constexpr uint32_t cacheVTileSeqLen = 32; +#elif __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 870 || __CUDA_ARCH__ == 900 +constexpr uint32_t preferedKHeadPartBytes = 128; +__constant__ constexpr uint32_t cacheVTileSeqLen = 64; +#else +// Safe default for older or unknown architectures +constexpr uint32_t preferedKHeadPartBytes = 64; +__constant__ constexpr uint32_t cacheVTileSeqLen = 32; +#endif +#endif +constexpr uint32_t kHeadPartBytes = mha::min(preferedKHeadPartBytes, paddedCacheHeadBytes); +// constexpr uint32_t cacheElemsPerKHeadPart = exactDiv(kHeadPartBytes, cacheElemSize); + +constexpr bool persistentQ = paddedInputHeadBytes * ctaTile.y <= (16u << 10); +static_assert(persistentQ); +constexpr uint32_t qHeadPartBytes = persistentQ ? paddedInputHeadBytes : kHeadPartBytes; +[[maybe_unused]] constexpr uint32_t qHeadPartElems = exactDiv(qHeadPartBytes, inputElemSize); + +constexpr uint32_t nbPartsPerCacheKHead = exactDiv(paddedCacheHeadBytes, kHeadPartBytes); +[[maybe_unused]] constexpr uint32_t nbPartsPerInputKHead = exactDiv(paddedInputHeadBytes, kHeadPartBytes); +constexpr uint32_t nbPartsPerInputQHead = exactDiv(paddedInputHeadBytes, qHeadPartBytes); + +// false - each warp load V tiles independent of each other; true - all warps in a warp group load V tiles together. +// @fixme: when true, and nbVBuffers is only 2, we need to sync all warps in a group after finishing using a buffer and +// before refill it with prefetch data. We may need at least 3. +constexpr bool grpLoadV = GRP_LOAD_V; + +// number of shared memory buffers for latency hiding +constexpr uint32_t nbQBuffers = mha::min(nbPartsPerInputQHead, 2u); // for latency hiding +constexpr uint32_t nbKBuffers = 2; // for latency hiding +constexpr uint32_t nbVBuffers = 2; // @fixme: H100 SXM need more in-flight requests. may need to increase this. +constexpr uint32_t nbXBuffers = 1; + +__device__ inline uint3 getWarpIdx(Warp const& warp = this_warp()) { + return uint3{ctaShapeInWarps.x == 1 ? 0 : makeWarpUniform(warp, threadIdx.x / warp_size), + ctaShapeInWarps.y == 1 ? 0 : makeWarpUniform(warp, threadIdx.y), + ctaShapeInWarps.z == 1 ? 0 : makeWarpUniform(warp, threadIdx.z)}; +} + +__device__ inline uint32_t gemm1WarpGrpIdx(uint32_t warpIdxX) { + return gemm1NbWarpGrps == 1 ? 
0 : warpIdxX / gemm1WarpsPerGrp; +} + +__device__ inline uint32_t gemm1WarpIdxInGrp(uint32_t warpIdxX) { + return gemm1WarpsPerGrp == 1 ? 0 : (gemm1NbWarpGrps == 1 ? warpIdxX : warpIdxX % gemm1WarpsPerGrp); +} + +constexpr uint32_t instM = 16; +[[maybe_unused]] constexpr uint32_t instN = 8; +// constexpr uint32_t instK = 16; + +using QuadRegRowMax = QuadRegRowMaxT; // data is replicated across 4 threads in a MMA quad. +using ThrdRegRowMax = ThrdRegRowMaxT; // unlike QuadRegRowMax, not replicated. +using UniformRescaleMask = UniformRescaleMaskT; // uniform and stored in UR + +__device__ inline bool any(UniformRescaleMask const& x) { + uint32_t val = 0U; +#pragma unroll + for (uint32_t i = 0; i < x.size; i++) { + uint32_t word = x[i]; + constexpr uint32_t wordBits = 32; + if (warpTile.y % wordBits != 0 && i + 1 == x.size) { + constexpr uint32_t validBits = warpTile.y % wordBits; + word &= ((1U << validBits) - 1); + } + val |= word; + } + return val != 0; +} + +#ifndef NDEBUG +__device__ inline void printRowMax(ThrdRegRowMax const& src) { + for (uint32_t i = 0; i < warp_size * src.size; i++) { + if (laneId() == i % warp_size) { + printf("%f%s", src[i / warp_size], i == 31 ? "\n" : " "); + } + __syncwarp(); + } +} + +__device__ inline void printRowMax(QuadRegRowMax const& src) { + for (uint32_t i = 0; i < src.size / 4; i++) { + for (uint32_t j = 0; j < 8; j++) { + if (laneId() == 4 * j) { + for (uint32_t k = 0; k < 4; k++) { + printf("%f%s", src[i * 4 + k], i == 31 ? "\n" : " "); + } + } + __syncwarp(); + } + } +} +#endif + +struct alignas(16) SMemWarpRowMax { + __device__ inline float const& operator[](uint32_t idxRow) const { + assert(idxRow < ThrdRegRowMax::size * warp_size); + uint32_t const idxInstM8 = idxRow / quadPerWarp; + return data[ThrdRegRowMax::size == 1 ? 0 : idxInstM8 / 4][idxRow % quadPerWarp][idxInstM8 % 4]; + } + + __device__ inline float& operator[](uint32_t idxRow) { + return const_cast(static_cast(*this)[idxRow]); + } + + // When the data is in registers, it is replicated across the 4 threads of a quad.
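+  // loadToRegForQuad below pulls 4 row values per lane with a single 128-bit shared-memory access and is
+  // quad-replicated (all 4 lanes of a quad read the same slot), whereas loadToReg yields one value per
+  // lane without replication. The asVolatile paths presumably back the unfenced ctaRowMax exchange of
+  // CTA_ROW_MAX_BACKWARD_METHOD == 4.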
+ template + __device__ inline QuadRegRowMax const loadToRegForQuad(Warp const& warp) const { + uint32_t const idxQuad = laneId() / 4; + QuadRegRowMax result; +#pragma unroll + for (uint32_t i = 0; i < divUp(warpTile.y, quadPerWarp * 4); i++) { + auto const& src = data[i][idxQuad]; + auto& dst = reinterpret_cast(result[4 * i]); + if constexpr (asVolatile) { + asm volatile("ld.volatile.shared.v4.f32 {%0, %1, %2, %3}, [%4];\n" + : "=f"(dst[0]), "=f"(dst[1]), "=f"(dst[2]), "=f"(dst[3]) + : "l"(__cvta_generic_to_shared(&src))); + } else { + reinterpret_cast(dst) = reinterpret_cast(src); + } + } + return result; + } + + template + __device__ inline ThrdRegRowMax const loadToReg(Warp const& warp) const { + ThrdRegRowMax result; +#pragma unroll + for (uint32_t i = 0; i < result.size; i++) { + auto const& src = this->operator[](warp_size * i + laneId()); + float& dst = result[i]; + if constexpr (asVolatile) { + dst = static_cast(src); + // asm volatile("ld.volatile.shared.f32 %0, [%1];\n" + // : "=f"(dst) : "l"(__cvta_generic_to_shared(&src))); + } else { + dst = src; + } + } + return result; + } + + template + __device__ inline void storeFromReg(Warp const& warp, QuadRegRowMax const& regData) { + for (uint32_t i = 0; i < regData.size; i++) { + assert(regData[i] == __shfl_sync(0xFU << (laneId() / 4 * 4), regData[i], 0, 4)); + } + if (laneId() % 4 != 0) { + return; + } + uint32_t const idxQuad = laneId() / 4; +#pragma unroll + for (uint32_t i = 0; i < ThrdRegRowMax::size; i++) { + auto& dst = data[i][idxQuad]; + auto const& src = reinterpret_cast(regData[4 * i]); + if constexpr (asVolatile) { + asm volatile( + "st.volatile.shared.v4.f32 [%0], {%1, %2, %3, %4};\n" ::"l"(__cvta_generic_to_shared(&dst)), + "f"(src[0]), "f"(src[1]), "f"(src[2]), "f"(src[3])); + } else { + reinterpret_cast(dst) = reinterpret_cast(src); + } + } + } + + template + __device__ inline void storeFromReg(Warp const& warp, ThrdRegRowMax const& regData) { +#pragma unroll + for (uint32_t i = 0; i < ThrdRegRowMax::size; i++) { + auto& dst = this->operator[](warp_size * i + laneId()); + assert(!hasBankConflict(&dst)); + float const src = regData[i]; + if constexpr (asVolatile) { + static_cast(dst) = src; + } else { + dst = src; + } + } + } + + __device__ inline void atomicMaxUpdate(Warp const& warp, ThrdRegRowMax const& regData) { +#pragma unroll + for (uint32_t i = 0; i < ThrdRegRowMax::size; i++) { + auto& dst = this->operator[](warp_size * i + laneId()); + assert(!hasBankConflict(&dst)); + float const src = regData[i]; + atomicMax(&dst, src); + } + } + + float data[ThrdRegRowMax::size][quadPerWarp][4]; +}; + +// cacheVTileSeqLen may be smaller than x cols, so we need multiple v tiles per X tile. +constexpr uint32_t nbCacheVTilesPerXTile = exactDiv(warpTile.x, cacheVTileSeqLen); + +[[maybe_unused]] constexpr uint32_t nbWarpGrpsPerXTile = mha::min(nbCacheVTilesPerXTile, gemm1NbWarpGrps); + +#if USE_PAGED_KV_CACHE +constexpr uint32_t nbPagesPerWarpTile = (warpTile.x <= tokensPerPage ? 1U : exactDiv(warpTile.x, tokensPerPage)); +using KCachePageIndices = Vec; +constexpr uint32_t nbPagesPerVTile = (cacheVTileSeqLen <= tokensPerPage ? 
1 : exactDiv(cacheVTileSeqLen, tokensPerPage)); +using VCachePageIndices = Vec; +#endif + +static_assert(ctaShapeInWarps.y == 1); + +struct alignas(128) SharedMem { + using QSmemBuffer = Array2D; + using KSmemBuffer = Array2D; + using XSmemBuffer = Array2D; + using VSmemBuffer = Array2D; + + QSmemBuffer q[ctaShapeInWarps.y][nbQBuffers]; + KSmemBuffer k[ctaShapeInWarps.x][nbKBuffers]; + XSmemBuffer x[ctaShapeInWarps.y][ctaShapeInWarps.x]; + static_assert(nbXBuffers == 1); + VSmemBuffer v[gemm1NbWarpGrps][grpLoadV ? 1 : gemm1WarpsPerGrp][nbVBuffers]; + + SMemWarpRowMax warpRowMax[ctaShapeInWarps.y][ctaShapeInWarps.x]; // the max used when computing this->x + SMemWarpRowMax warpRowSum[ctaShapeInWarps.y][ctaShapeInWarps.x]; // the row sum of gemm0 output + +#if CTA_ROW_MAX_BACKWARD_METHOD == 1 || CTA_ROW_MAX_BACKWARD_METHOD == 2 || CTA_ROW_MAX_BACKWARD_METHOD == 3 + // protected with xFwdBarriers+xBwdBarriers for CTA_ROW_MAX_BACKWARD_METHOD 1 or 2, and with + // xFwdBarriers+ctaRowMaxBwdBarriers for 3. Cannot reuse warpRowMax because a gemm1 warp is not sure whether other + // gemm1 warps have finished using it, unless we want to pay extra sync. + SMemWarpRowMax ctaRowMax[ctaShapeInWarps.y][ctaShapeInWarps.x]; +#elif CTA_ROW_MAX_BACKWARD_METHOD == 4 + SMemWarpRowMax ctaRowMax[ctaShapeInWarps.y]; // just a hint, no strict protection required if you don't care about + // non-deterministic output (up to a small numeric error) +#endif + +#if BEAM_WIDTH > 1 + Vec gemm0CacheIndir[ctaShapeInWarps.x]; + Vec gemm1CacheIndir[grpLoadV ? gemm1NbWarpGrps : ctaShapeInWarps.x]; +#if USE_PAGED_KV_CACHE + Vec kCachePages[ctaShapeInWarps.x]; + Vec vCachePages[grpLoadV ? gemm1NbWarpGrps : ctaShapeInWarps.x]; +#endif +#endif + + using Barrier = CtaBarrier; + + Barrier qBarrier[ctaShapeInWarps.y]; + // Beside X buffers, also protects warpRowMax and warpRowSum. For CTA_ROW_MAX_BACKWARD_METHOD==1 or 2, also + // ctaRowMax. 
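+  // Roughly: the forward barrier signals that an X tile (plus warpRowMax/warpRowSum) is ready for the
+  // gemm1 warps, and the backward barrier signals that the buffer may be refilled, matching the
+  // xFwdBarriers/xBwdBarriers terminology in the CTA_ROW_MAX_BACKWARD_METHOD discussion above.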
+ CtaBarrierPair xBarriers[ctaShapeInWarps.y][ctaShapeInWarps.x]; +#if CTA_ROW_MAX_BACKWARD_METHOD == 3 + Barrier ctaRowMaxBwdBarriers[ctaShapeInWarps.y] + [ctaShapeInWarps.x]; // xFwdBarriers+ctaRowMaxBwdBarriers protects ctaRowMax +#endif + +#if GRP_LOAD_V + static constexpr uint32_t nbOtherBarriers = nbVBuffers * gemm1NbWarpGrps + gemm1NbWarpGrps; + Barrier otherBarriers[nbOtherBarriers]; +#endif + __device__ inline Barrier* vBarrier(uint32_t warpGrpIdx, uint32_t idxBuf) { +#if GRP_LOAD_V + return &reinterpret_cast(otherBarriers)[warpGrpIdx][idxBuf]; +#else + return nullptr; +#endif + } + + __device__ inline Barrier* warpGrpBar(uint32_t warpGrpIdx) { +#if GRP_LOAD_V + return &otherBarriers[nbVBuffers * gemm1NbWarpGrps + warpGrpIdx]; +#else + return nullptr; +#endif + } +}; + +CUBIN_EXPORT __device__ constexpr uint32_t smemSize = sizeof(SharedMem); +#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +static_assert(smemSize < kMAX_SMEM_SIZE); +#endif + +#if 0 +template +__device__ inline void smemRotateInplace(Warp const& Warp, Array2D& data, uint32_t idxPart, uint32_t idxToken) { + static_assert(inputSeqLen == 1); + constexpr uint32_t rowElems = inputElemsPerGrain * cols; + constexpr uint32_t nbParts = exactDiv(headElems, idxPart); + static_assert(nbParts % 2 == 0); + bool const isFirstHalf = (idxPart < nbParts / 2); + static_assert(mha::is_same_v, "not implemented"); + if constexpr (cols <= warp_size) { + static_assert(warp_size % cols == 0); + constexpr uint32_t thrdGrpSize = LdGrain::size * cols; + uint32_t const idxThrdGrp = laneId() / thrdGrpSize; + uint32_t const thrdGrpLane = laneId() % thrdGrpSize; + constexpr uint32_t nbThrdGrps = warp_size / thrdGrpSize; + static_assert(warp_size % thrdGrpSize == 0); + constexpr uint32_t nbElemsPerWord = exactDiv(sizeof(LdGrain::Elem), inputElemSize); + Vec cosAngles; + Vec sinAngles; +#pragma unroll + for (uint32_t i = 0; i < angles.size; i++) { + uint32_t const n = rowElems * (idxPart % (nbParts / 2)) + angles.size * thrdGrpLane + i; + float const angle = powf(1E-4f, n * (2.f / headElems)) * idxToken; + sincosf(angle, &sinAngles[i], &cosAngles[i]); + } + + constexpr uint32_t nbIters = exactDiv(rows, nbThrdGrps); +#pragma unroll + for (uint32_t i = 0; i < nbIters; i++) { + auto const word = data.template at(nbThrdGrps * i + idxThrdGrp, thrdGrpLane / LdGrain::size)[thrdGrpLane % LdGrain::size]; + float2 const val = __half22float2(reinterpret_cast(word)); + Vec result; +#pragma unroll + for (uint32_t j = 0; j < nbElemsPerWord; j++) { + if (isFirstHalf) { + result[j] = cosAngles[j] * ; + } + } + } + } + else { + static_assert(cols <= warp_size, "not implemented"); + } +} +#endif + +using WarpAcc = WarpAccT; + +#if SPEC_DEC +#define MMAS_N_PER_MASK 2 + +__device__ inline void applyMaskFromInput(Warp const& warp, WarpAcc& acc, MaskType const* mask, uint32_t rowOffset, + uint32_t nbValidCols, uint32_t qSeqLen, uint32_t actualQSeqLen, uint32_t headGrpSize +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + , + int32_t tok0WinBeg, uint32_t seqIter, uint32_t const cacheSeqLen, uint32_t const warpTileTokenBeg +#endif +) { + uint32_t const idxInQuad = laneId() % 4; + uint32_t const idxQuad = laneId() / 4; + // Packed mask is aligned with 32 bits (2 uint16_t). 
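+  // i.e. each query row carries one mask bit per query token, padded up to a 32-bit boundary, hence
+  // divUp(qSeqLen, 32) * 2 uint16_t words per row.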
+ uint32_t const nbPackedMasksPerRow = divUp(qSeqLen, 32u) * 2u; + uint16_t const* uint16Mask = reinterpret_cast(mask); + constexpr uint64_t fullMask = ~uint64_t{0}; +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + Range const tileRange = {warpTileTokenBeg, warpTileTokenBeg + warpTile.x}; + Range const maxMaskOutRange = {0, mha::max(0, tok0WinBeg) + (nbValidRows / MMAS_N_PER_MASK - 1)}; + bool const ctaNeedBegMask = tileRange.beg < maxMaskOutRange.end; + assert(ctaNeedBegMask == overlap(tileRange, maxMaskOutRange)); + int32_t const tok0NbMaskOut = int32_t(tok0WinBeg) - int32_t(warpTileTokenBeg); + uint32_t const nbSeqItersWithoutSpecDecMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x; + bool const ctaNeedSpecDecMask = (seqIter >= nbSeqItersWithoutSpecDecMask); +#else + constexpr bool ctaNeedBegMask = false; + bool const ctaNeedSpecDecMask = true; + int32_t const tok0NbMaskOut = -2147483648; +#endif + bool const needMask = ctaNeedBegMask || ctaNeedSpecDecMask; + + if (!needMask) { + return; + } +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < InstAcc::rows; i++) { + uint32_t const idxQTokInCta = (rowOffset + instM * m + idxQuad + i * 8) / headGrpSize; + uint32_t const tokenRow = min(idxQTokInCta, actualQSeqLen - 1); +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + int32_t const begNbMaskOut = tok0NbMaskOut + int32_t(idxQTokInCta); + uint64_t const begMask = (begNbMaskOut > 0 ? fullMask << begNbMaskOut : fullMask); +#else + uint64_t const begMask = fullMask; +#endif + +#pragma unroll + for (uint32_t mask_n = 0; mask_n < acc.cols / MMAS_N_PER_MASK; mask_n++) { + uint32_t const firstCol = instN * mask_n * MMAS_N_PER_MASK + InstAcc::cols * idxInQuad; + uint32_t const lastCol = firstCol + instN * (MMAS_N_PER_MASK - 1) + InstAcc::cols - 1; + uint32_t const maskPos0 = firstCol + actualQSeqLen < nbValidCols + ? 0u + : min(firstCol + actualQSeqLen - nbValidCols, actualQSeqLen - 1); + uint32_t const maskPos1 = lastCol + actualQSeqLen < nbValidCols + ? 0u + : min(lastCol + actualQSeqLen - nbValidCols, actualQSeqLen - 1); + uint32_t const maskPosStart = (maskPos0 / 16) * 16; + uint32_t packedMask = ~uint32_t{0}; + if (ctaNeedSpecDecMask) { + reinterpret_cast(&packedMask)[0] = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos0 / 16)]; + reinterpret_cast(&packedMask)[1] = uint16Mask[tokenRow * nbPackedMasksPerRow + (maskPos1 / 16)]; + } +#pragma unroll + for (uint32_t nj = 0; nj < MMAS_N_PER_MASK; nj++) { +#pragma unroll + for (uint32_t j = 0; j < InstAcc::cols; j++) { + uint32_t const n = (mask_n * MMAS_N_PER_MASK + nj); + uint32_t const col = instN * n + InstAcc::cols * idxInQuad + j; + // bool const maskFlag = col + qSeqLen < nbValidCols ? true : mask[tokenRow * qSeqLen + (col + + // qSeqLen - nbValidCols)]; + bool const maskFlag = col + actualQSeqLen < nbValidCols + ? true + : packedMask & (1u << ((col + actualQSeqLen - nbValidCols) - maskPosStart)); + + bool const begMaskFlag = ctaNeedBegMask ? (begMask & (1ULL << col)) : true; + + acc(m, n)(i, j) = maskFlag && begMaskFlag && col < nbValidCols ? 
acc(m, n)(i, j) : safeInitRowMax; + } + } + } + } + } +} +#endif + +__device__ inline QuadRegRowMax warpTileOnlineSoftmax(Warp const& warp, QuadRegRowMax const& rowMaxHint, WarpAcc& acc) { + QuadRegRowMax rowMax = rowMaxHint; +// compute per-thread row max +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < InstAcc::cols; j++) { +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < InstAcc::rows; i++) { + rowMax[m * InstAcc::rows + i] = fmaxf(rowMax[m * InstAcc::rows + i], acc(m, n)(i, j)); + } + } + } + } +// compute warp row max +#pragma unroll + for (uint32_t xorMask = 2; xorMask != 0; xorMask /= 2) { +#pragma unroll + for (uint32_t i = 0; i < rowMax.size; i++) { + rowMax[i] = fmaxf(rowMax[i], __shfl_xor_sync(~0U, rowMax[i], xorMask)); + } + } +// update acc and rowMax +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < InstAcc::rows; i++) { + float const maxVal = rowMax[m * InstAcc::rows + i]; + float const bias = maxVal * log2e; +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < InstAcc::cols; j++) { + float& elem = acc(m, n)(i, j); + assert(maxVal >= elem); + elem = exp2f(elem * log2e - bias); + } + } + } + } + return rowMax; +} + +using GemmOutRegTile = Array2D; + +__device__ inline GemmOutRegTile toFp16(WarpAcc const& acc) { + GemmOutRegTile dst; +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < InstAcc::rows; i++) { +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < InstAcc::cols; j += 2) { +#if INPUT_FP16 + dst(m * InstAcc::rows + i, (n * InstAcc::cols + j) / 2) = __floats2half2_rn(acc(m, n)(i, j), acc(m, n)(i, j + 1)); +#else + dst(m * InstAcc::rows + i, (n * InstAcc::cols + j) / 2) = __floats2bfloat162_rn(acc(m, n)(i, j), acc(m, n)(i, j + 1)); +#endif + } + } + } + } + return dst; +} + +__device__ inline WarpAcc toWarpAcc(GemmOutRegTile const& outTile) { + WarpAcc acc; +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < InstAcc::rows; i++) { +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < InstAcc::cols; j += 2) { +#if INPUT_FP16 + float2 const fp32Vals = __half22float2(outTile(m * InstAcc::rows + i, (n * InstAcc::cols + j) / 2)); +#else + float2 const fp32Vals = __bfloat1622float2(outTile(m * InstAcc::rows + i, (n * InstAcc::cols + j) / 2)); +#endif + acc(m, n)(i, j) = fp32Vals.x; + acc(m, n)(i, j + 1) = fp32Vals.y; + } + } + } + } + return acc; +} + +__device__ inline QuadRegRowMax computeRowSum(Warp const& warp, GemmOutRegTile const& src) { + Vec acc{}; +#if INPUT_FP16 + InputElem2 const b[2][1] = {__floats2half2_rn(1, 1), __floats2half2_rn(1, 1)}; +#else + InputElem2 const b[2][1] = {__floats2bfloat162_rn(1, 1), __floats2bfloat162_rn(1, 1)}; +#endif +#pragma unroll + for (uint32_t n = 0; n < exactDiv(GemmOutRegTile::cols, 2); n++) { +#pragma unroll + for (uint32_t m = 0; m < exactDiv(GemmOutRegTile::rows, 2); m++) { + InputElem2 const a[2 /*kEx*/][2 /*mEx*/] = {src(m * 2, n * 2), src(m * 2 + 1, n * 2), src(m * 2, n * 2 + 1), src(m * 2 + 1, n * 2 + 1)}; + mma(acc[m].data, reinterpret_cast(a), + reinterpret_cast(b)); + } + } + QuadRegRowMax rowSum; +#pragma unroll + for (uint32_t i = 0; i < acc.size; i++) { +#pragma unroll + for (uint32_t j = 0; j < InstAcc::rows; j++) { 
+ rowSum[i * InstAcc::rows + j] = acc[i](j, 0); +#pragma unroll + for (uint32_t k = 0; k < InstAcc::cols; k++) { + assert(acc[i](j, k) == acc[i](j, 0)); + } + } + rowSum[i * 2] = acc[i](0, 0); + rowSum[i * 2 + 1] = acc[i](1, 0); + } +// Sometimes there are errors in sum and they mismatch inside a quad. Force broadcast from lane 0 of each quad to +// eliminate mismatch. This has no visible impact on final result and can be removed. +#pragma unroll + for (uint32_t i = 0; i < QuadRegRowMax::size; i++) { + auto const lane0Val = __shfl_sync(0xFU << (laneId() / 4 * 4), rowSum[i], 0, 4); + // Disable the assert, sometimes it triggers because of different orders of accumulation. + // assert(fabs(rowSum[i] - lane0Val) < 1E-4f); + rowSum[i] = lane0Val; + } + return rowSum; +} + +__device__ inline void storeOrderedGemmOutTile(Warp const& warp, SharedMem::XSmemBuffer& dst, GemmOutRegTile const& src) { + static_assert(sizeof(dst) == sizeof(src) * warp_size); + uint32_t const lane = laneId(); +#if __CUDA_ARCH__ >= 900 + constexpr uint2 storeUnits = {4, 1}; // in 8x8 b16 matrices. + static_assert(storeUnits.x * storeUnits.y == 4); +#pragma unroll + for (uint32_t m = 0; m < exactDiv(dst.rows, 8 * storeUnits.y); m++) { +#pragma unroll + for (uint32_t n = 0; n < exactDiv(dst.cols * grainBytes / inputElemSize, 8 * storeUnits.x); n++) { + uint32_t const idxRowLocal = lane % 8; + uint32_t const flatIdxMatLocal = lane / 8; + uint2 const idxMatLocal = {flatIdxMatLocal % storeUnits.x, flatIdxMatLocal / storeUnits.x}; + LdGrain* const p = &dst.template at( + 8 * (storeUnits.y * m + idxMatLocal.y) + idxRowLocal, storeUnits.x * n + idxMatLocal.x); + + LdGrain data; +#pragma unroll + for (uint32_t i = 0; i < storeUnits.y; i++) { +#pragma unroll + for (uint32_t j = 0; j < storeUnits.x; j++) { + data[i * storeUnits.x + j] = reinterpret_cast(src(m * storeUnits.y + i, n * storeUnits.x + j)); + } + } + stmatrix_4x(warp, p, data); + } + } +#else +#pragma unroll + for (uint32_t m = 0; m < exactDiv(dst.rows, 8); m++) { +#pragma unroll + for (uint32_t n = 0; n < exactDiv(dst.cols * grainBytes / inputElemSize, 8); n++) { + uint32_t const idxRowLocal = laneId() / 4; + uint32_t const idxWordLocal = laneId() % 4; + dst.template at(8 * m + idxRowLocal, n)[idxWordLocal] = reinterpret_cast(src(m, n)); + } + } +#endif +} + +// Reorder to compensate the reorder caused by V cache load+conversion. +__device__ inline void reorderAndStoreGemmOutTile( + Warp const& warp, SharedMem::XSmemBuffer& dst, GemmOutRegTile const& src) { + static_assert(sizeof(dst) == sizeof(src) * warp_size); + uint32_t const lane = laneId(); +#pragma unroll + for (uint32_t m = 0; m < exactDiv(dst.rows, 8); m++) { +#pragma unroll + for (uint32_t n = 0; n < exactDiv(dst.cols * grainBytes / inputElemSize, 8 * 2); n++) { + uint32_t const idxRowLocal = laneId() / 4; + uint32_t const idxSegLocal = laneId() % 4; + Vec seg; +#pragma unroll + for (uint32_t e = 0; e < cvtExpansion; e++) { + seg[e] = src(m, n * cvtExpansion + e); + } + // reorder + // Ideally compiler should be able to fuse this into toFp16() and just reorder input registers of F2FP + // instructions. 
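+      // For cvtExpansion == 2 this interleaves {a.x, a.y, b.x, b.y} into {a.x, b.x, a.y, b.y}, i.e. the
+      // low halves of the two source pairs come first, then the high halves.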
+ Vec reorderedSeg; +#pragma unroll + for (uint32_t e = 0; e < cvtExpansion; e++) { + reorderedSeg[e] = seg[e].x; + reorderedSeg[cvtExpansion + e] = seg[e].y; + } + static_assert(cvtExpansion <= LdGrain::size); + constexpr uint32_t nbSegPerGrain = exactDiv(grainBytes, sizeof(seg)); + reinterpret_cast&>(dst.template at(8 * m + idxRowLocal, + n * cvtExpansion + idxSegLocal / nbSegPerGrain)[idxSegLocal % nbSegPerGrain * cvtExpansion]) = reinterpret_cast&>(reorderedSeg); + } + } +} + +__device__ inline void storeGemmOutTile( + Warp const& warp, SharedMem::XSmemBuffer& dst, GemmOutRegTile const& src, bool reorder) { + if (reorder) { + reorderAndStoreGemmOutTile(warp, dst, src); + } else { + storeOrderedGemmOutTile(warp, dst, src); + } +} + +__device__ inline GemmOutRegTile loadGemmOutTile(Warp const& warp, SharedMem::XSmemBuffer const& src) { + uint32_t const lane = laneId(); + GemmOutRegTile dst; + static_assert(sizeof(src) == sizeof(dst) * warp_size); +#if __CUDA_ARCH__ >= 900 + constexpr uint2 storeUnits = {4, 1}; // in 8x8 b16 matrices. + static_assert(storeUnits.x * storeUnits.y == 4); +#pragma unroll + for (uint32_t m = 0; m < exactDiv(SharedMem::XSmemBuffer::rows, 8 * storeUnits.y); m++) { +#pragma unroll + for (uint32_t n = 0; n < exactDiv(SharedMem::XSmemBuffer::cols * grainBytes / inputElemSize, 8 * storeUnits.x); + n++) { + uint32_t const idxRowLocal = lane % 8; + uint32_t const flatIdxMatLocal = lane / 8; + uint2 const idxMatLocal = {flatIdxMatLocal % storeUnits.x, flatIdxMatLocal / storeUnits.x}; + LdGrain const* const p = &src.template at( + 8 * (storeUnits.y * m + idxMatLocal.y) + idxRowLocal, storeUnits.x * n + idxMatLocal.x); + + LdGrain data = ldmatrix_4x(warp, p); +#pragma unroll + for (uint32_t i = 0; i < storeUnits.y; i++) { +#pragma unroll + for (uint32_t j = 0; j < storeUnits.x; j++) { + reinterpret_cast(dst(m * storeUnits.y + i, n * storeUnits.x + j)) = data[i * storeUnits.x + j]; + } + } + } + } +#else +#pragma unroll + for (uint32_t m = 0; m < exactDiv(SharedMem::XSmemBuffer::rows, 8); m++) { +#pragma unroll + for (uint32_t n = 0; n < exactDiv(SharedMem::XSmemBuffer::cols * grainBytes / inputElemSize, 8); n++) { + uint32_t const idxRowLocal = laneId() / 4; + uint32_t const idxWordLocal = laneId() % 4; + reinterpret_cast(dst(m, n)) = src.template at(8 * m + idxRowLocal, n)[idxWordLocal]; + } + } +#endif + return dst; +} +// only the first nbValidRows rows are copied, to allow padding. 
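+// Row m of the smem X tile maps to beam m / headGrpSize and head m % headGrpSize of the group in the
+// default path, or to token m / headGrpSize and head m % headGrpSize under SPEC_DEC; grains past
+// ioHeadBytes are skipped when the head is padded.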
+__device__ inline void copyOutputToGlobalMem(Warp const& warp, OutputHead* dst, uint32_t nbQHeads, +#if SPEC_DEC + uint32_t headGrpSize, uint32_t idxHeadGrpOffset, uint32_t nbValidHeadTokens, +#else + uint32_t idxHeadGrp, +#endif + uint2 dstOffset, SharedMem::XSmemBuffer const& src) { + static_assert(sizeof(PaddedInputHead) == grainBytes * SharedMem::XSmemBuffer::cols * gemm1WarpsPerGrp); +#if SPEC_DEC + static_assert(warpTile.y <= SharedMem::XSmemBuffer::rows); +#else + static_assert(nbValidRows <= SharedMem::XSmemBuffer::rows); +#endif + constexpr uint32_t nbIters = divUp(nbValidRows * SharedMem::XSmemBuffer::cols, warp_size); +#pragma unroll + for (uint32_t i = 0; i < nbIters; i++) { + uint32_t const flatIdx = warp_size * i + laneId(); + uint32_t const r = flatIdx / SharedMem::XSmemBuffer::cols; + uint32_t const c = flatIdx % SharedMem::XSmemBuffer::cols; + assert(r < SharedMem::XSmemBuffer::rows); + LdGrain const data = src.template at(r, c); + + uint32_t const m = dstOffset.y + r; + uint32_t const n = exactDiv(dstOffset.x, grainBytes / inputElemSize) + c; +#if SPEC_DEC + if (r >= nbValidHeadTokens) { +#else + if (nbValidRows * SharedMem::XSmemBuffer::cols % warp_size != 0 && m >= nbValidRows) { +#endif + break; + } + assert(m < nbValidRows); +#if SPEC_DEC + uint32_t const idxBeam = 0; + uint32_t const idxInGrp = m; + uint32_t const tokenIdx = idxInGrp / headGrpSize; + uint32_t const headIdx = idxInGrp % headGrpSize; + assert(idxBeam < beamWidth); + uint32_t const idxHead = idxHeadGrpOffset + tokenIdx * nbQHeads + headIdx; + assert(idxHead < nbValidHeadTokens * nbQHeads); +#else + uint32_t const idxBeam = m / headGrpSize; + uint32_t const idxInGrp = m % headGrpSize; + assert(idxBeam < beamWidth); + uint32_t const idxHead = headGrpSize * idxHeadGrp + idxInGrp; + assert(idxHead < nbQHeads); +#endif + assert(n < paddedInputHeadBytes / grainBytes); + if (!isHeadPadded || n < ioHeadBytes / grainBytes) { + auto const outVec = convert(reinterpret_cast const&>(data)); + reinterpret_cast, exactDiv(ioHeadBytes, grainBytes)>&>( + dst[nbQHeads * idxBeam + idxHead])[n] = outVec; + } + } +} + +// MMA instruction expansion in GEMM k-dim and m/n-dim, with b16 8x8 as baseline +template +struct InstInMat { + static constexpr uint32_t kEx = kEx_; + static constexpr uint32_t mnEx = mnEx_; + uint32_t data[kEx][mnEx]; +}; + +template +using InstInMatWTrans = InstInMat; + +//@fixme: for B-mat, use InstInMat<2, 1>[2] instead. + +// kEx is for srcCol and mnEx is for srcRow, before transpose. +// rowBeg/colBeg are in src indices +// note that grainBytes-byte swizzling per 128-byte or per row(>=128byte) is applied when loading to avoid bank +// conflict. transOuter: transpose InstInMat with 8x8 b16 matrices as elements unchanged. transInner: transpose the +// elements, i.e. the 8x8 b16 matrices. transOuter=true and transInner=false is for B matrix of 16816. It actually loads +// two 8x16 B matrices for two instructions. transOuter=false and transInner=false is for A matrix of 16816. +template +__device__ inline InstInMatWTrans loadInstInMat( + Warp const& warp, Array2D const& src, uint32_t rowOffset, uint32_t colOffset) { + static_assert(kEx * mnEx == 4, "implemented only for ldmatrix.x4 for now"); + using Dst = InstInMatWTrans; + assert(rowOffset % (8 * mnEx) == 0 && colOffset % kEx == 0); + uint32_t const idx = laneId() / 8; + uint32_t const idxKEx = idx / Dst::mnEx; + uint32_t const idxMNEx = idx % Dst::mnEx; + uint32_t const srcIdxKEx = (transOuter ? 
idxMNEx : idxKEx); + uint32_t const srcIdxMNEx = (transOuter ? idxKEx : idxMNEx); + + LdGrain const* const ptr = &src.template at(rowOffset + 8 * srcIdxMNEx + laneId() % 8, colOffset + srcIdxKEx); + + Vec const data = ldmatrix_4x(warp, ptr); + static_assert(sizeof(Dst) == sizeof(data)); + Dst dst; +#pragma unroll + for (int i = 0; i < data.size; i++) { + (&dst.data[0][0])[i] = data[i]; + } + return dst; +} + +template +using Array2DWTrans = Array2D; + +// src rows/cols are in src indices +// dst rows/cols are in InstInMatWTrans +// row is contiguous and gemm-K dim. +// kEx combines with dstCols and mnEx combines with dstRows. +template +__device__ inline Array2DWTrans, dstRows, dstCols, transArr2D> +loadMatrix(Warp const& warp, Array2D const& src, uint32_t rowBeg, uint32_t colBeg) { + assert(rowBeg % (8 * mnEx * dstRows) == 0 && colBeg % (kEx * dstCols) == 0); + Array2DWTrans, dstRows, dstCols, transArr2D> dst; +#pragma unroll + for (uint32_t i = 0; i < dstRows; i++) { +#pragma unroll + for (uint32_t j = 0; j < dstCols; j++) { + (transArr2D ? dst(j, i) : dst(i, j)) = loadInstInMat( + warp, src, rowBeg + (mnEx * 8) * i, colBeg + kEx * j); + } + } + return dst; +} + +// acc is used as both input and output +// qColBeg is in the unit of LdGrain +// using KElemType = int8_t; +template +__device__ inline void smemQKPartGemm( + Warp const& warp, WarpAcc& acc, SharedMem::QSmemBuffer const& q, uint32_t qColBeg, SharedMem::KSmemBuffer const& k) { + assert(qColBeg % (SharedMem::KSmemBuffer::cols) == 0); + constexpr uint32_t kEx = 2; + constexpr uint32_t mnEx = 2; + static_assert(mha::is_same_v || mha::is_same_v, "not implemented"); + static_assert((mha::is_same_v || mha::is_same_v || mha::is_same_v || mha::is_same_v), + "not implemented"); + constexpr uint32_t nbInstInMatPerSliceInGemmKDim = 1; + constexpr uint32_t kElemSize = sizeof(KElemType); + constexpr uint32_t elemsPerKHeadPart = exactDiv(kHeadPartBytes, kElemSize); + constexpr uint32_t gemmKSplit = exactDiv(elemsPerKHeadPart, 8 * kEx * nbInstInMatPerSliceInGemmKDim); + + // @fixme: check if compiler mixes LDS+HMMA and does prefetch properly. We are not doing prefetch explicitly. But we + // do fully unroll and expect compiler to do that for us. + constexpr uint32_t nbUnroll = cacheElemSize == 2 ? 
gemmKSplit : 2; +#pragma unroll(nbUnroll) + for (uint32_t s = 0; s < gemmKSplit; s++) { + // load q + constexpr uint32_t qSliceRows = exactDiv(warpTile.y, 8 * mnEx); // in InstInMat + constexpr uint32_t qSliceCols = nbInstInMatPerSliceInGemmKDim; + Array2D, qSliceRows, qSliceCols> const qSlice = loadMatrix( + warp, q, 0, qColBeg + kEx * qSliceCols * s); + // load k + constexpr uint32_t cvtExp = exactDiv(inputElemSize, kElemSize); + constexpr uint32_t mnExK = mnEx * cvtExp; + constexpr uint32_t kExK = exactDiv(kEx, cvtExp); + constexpr uint32_t kSliceRows = exactDiv(warpTile.x, 8 * mnExK); // in InstInMat + constexpr uint32_t kSliceCols = nbInstInMatPerSliceInGemmKDim; + Array2D, kSliceRows, kSliceCols> const kSliceOrig = loadMatrix(warp, k, 0, kExK * kSliceCols * s); + auto const kSlice = [&]() -> Array2D, kSliceRows, kSliceCols> { + if constexpr (mha::is_same_v) { + return kSliceOrig; + } else if constexpr ((mha::is_same_v || mha::is_same_v)) { + Array2D, kSliceRows, kSliceCols> ret; +#pragma unroll + for (uint32_t m = 0; m < kSliceRows; m++) { +#pragma unroll + for (uint32_t n = 0; n < kSliceCols; n++) { +#pragma unroll + for (uint32_t i = 0; i < mnExK; i++) { +#pragma unroll + for (uint32_t j = 0; j < kExK; j++) { + auto const data = convertKCacheWordToF16(kSliceOrig(m, n).data[i][j]); + ret(m, n).data[i][j * cvtExp] = data[0]; + ret(m, n).data[i][j * cvtExp + 1] = data[1]; + } + } + } + } + return ret; + } else { + assert(!"not implemented"); + trap(); + } + }(); +// compute +#pragma unroll + for (uint32_t i = 0; i < qSliceRows; i++) { +#pragma unroll + for (uint32_t j = 0; j < kSliceRows; j++) { + InstInMat const matrixA = qSlice(i, 0); + InstInMat const matrixB = kSlice(j, 0); +#pragma unroll + for (uint32_t n = 0; n < mnExK; n++) { + uint32_t const b[2][1] = {matrixB.data[n][0], matrixB.data[n][1]}; + mma(acc(i, j * mnExK + n).data, matrixA.data, b); + } + } + } + } +} + +// acc is used as both input and output +// v needs transpose +template +__device__ inline void smemXVPartGemm(Warp const& warp, WarpAcc& acc, bool skipXRowRescale, + UniformRescaleMask xRowNeedRescaleMask, ThrdRegRowMax xRowScales, SharedMem::XSmemBuffer const& x, + uint32_t idxVTilePerXTile, SharedMem::VSmemBuffer const& vt, uint32_t idxNSplit) { + static_assert(mha::is_same_v || mha::is_same_v, "not implemented"); + static_assert((mha::is_same_v || mha::is_same_v || mha::is_same_v || mha::is_same_v), + "not implemented"); + constexpr uint32_t kEx = 2; + constexpr uint32_t mnEx = 2; + constexpr uint32_t nbInstInMatPerSliceInGemmKDim = 1; + static_assert(SharedMem::XSmemBuffer::rows == 8 * InstAcc::rows * WarpAcc::rows); + static_assert( + grpLoadV || sizeof(SharedMem::VSmemBuffer::Elem) / cacheElemSize * SharedMem::VSmemBuffer::cols == warpTile.x); + static_assert( + !grpLoadV || sizeof(SharedMem::VSmemBuffer::Elem) / cacheElemSize * SharedMem::VSmemBuffer::cols == headElems); + if (grpLoadV) { + assert(idxNSplit < gemm1WarpsPerGrp); + } else { + assert(idxNSplit == 0); + } + constexpr uint32_t gemmKSplit = exactDiv(SharedMem::VSmemBuffer::rows, 8 * kEx * nbInstInMatPerSliceInGemmKDim); + + Vec xRowScalesQuad; + if (!enableMicroFastPath || !skipXRowRescale) { + assertWarpConverged(); +#if INPUT_FP16 + Vec const xRowScalesF16 = __float2half2_rn(xRowScales); +#else + Vec const xRowScalesF16 = __float2bfloat162_rn(xRowScales); +#endif + static_assert(sizeof(xRowScalesF16) == sizeof(ThrdRegRowMax)); + reinterpret_cast(xRowScalesQuad) = replicateForQuad(warp, reinterpret_cast(xRowScalesF16)); + } + +// @fixme: check 
if compiler mixes LDS+HMMA and does prefetch properly. We are not doing prefetch explicitly. But we do +// fully unroll and expect compiler to do that for us. +#pragma unroll + for (uint32_t s = 0; s < gemmKSplit; s++) { + // load x + constexpr uint32_t xSliceRows = exactDiv(warpTile.y, 8 * mnEx); // in InstInMat + constexpr uint32_t xSliceCols = nbInstInMatPerSliceInGemmKDim; + uint32_t const colBeg = SharedMem::XSmemBuffer::cols / nbCacheVTilesPerXTile * idxVTilePerXTile + exactDiv(inputElemSize * 8 * kEx * nbInstInMatPerSliceInGemmKDim, grainBytes) * s; + Array2D, xSliceRows, xSliceCols> xSlice = loadMatrix(warp, x, 0u, colBeg); + if (!enableMicroFastPath || !skipXRowRescale) { +#pragma unroll + for (uint32_t m = 0; m < xSliceRows; m++) { +#pragma unroll + for (uint32_t i = 0; i < mnEx; i++) { + uint32_t const r = m * mnEx + i; +#pragma unroll + for (uint32_t n = 0; n < xSliceCols; n++) { +#pragma unroll + for (uint32_t j = 0; j < kEx; j++) { + InputElem2& elem = reinterpret_cast(xSlice(m, n).data[j][i]); + elem = skipXRowRescale ? elem : elem * xRowScalesQuad[r]; + } + } + } + } + } + // load v slice. rows and cols here are before transpose + constexpr uint32_t mnExV = mnEx * cvtExpansion; + constexpr uint32_t vSliceCols = exactDiv(warpTile.x, 8 * mnExV); // in InstInMat + constexpr uint32_t vSliceRows = nbInstInMatPerSliceInGemmKDim; + uint32_t const rowBeg = 8 * kEx * nbInstInMatPerSliceInGemmKDim * s; + Array2D, vSliceCols, vSliceRows> const vSliceOrig = loadMatrix( + warp, vt, rowBeg, mnEx * vSliceCols * idxNSplit); + Array2D, vSliceCols, vSliceRows> const vSlice = [&]() { + if constexpr (mha::is_same_v) { + return vSliceOrig; + } else if constexpr ((mha::is_same_v || mha::is_same_v)) { + Array2D, vSliceCols, vSliceRows> ret; +#pragma unroll + for (uint32_t m = 0; m < ret.rows; m++) { +#pragma unroll + for (uint32_t n = 0; n < ret.cols; n++) { + auto const& src = vSliceOrig(m, n); + auto& dst = ret(m, n); +#pragma unroll + for (uint32_t i = 0; i < mnEx; i++) { +#pragma unroll + for (uint32_t j = 0; j < kEx; j++) { + auto const data = convertVCacheWordToF16(src.data[i][j]); +#pragma unroll + for (uint32_t e = 0; e < cvtExpansion; e++) { + dst.data[i * cvtExpansion + e][j] = data[e]; + } + } + } + } + } + return ret; + } else { + assert(!"not implemented"); + trap(); + } + }(); +// compute +#pragma unroll + for (uint32_t i = 0; i < xSliceRows; i++) { +#pragma unroll + for (uint32_t j = 0; j < vSliceCols; j++) { + auto const& vInMat = vSlice(j, 0); +#pragma unroll + for (uint32_t n = 0; n < mnExV; n++) { + mma(acc(i, j * mnExV + n).data, xSlice(i, 0).data, + reinterpret_cast(vInMat.data[n])); + } + } + } + } +} + +__device__ inline void pickAccRowsForBeamSearch(Warp const& warp, WarpAcc& dst, WarpAcc const& src, bool isCtxTile, + uint32_t idxBeam, void (*func)(float& d, float s)) { + uint32_t const idxQuad = laneId() / 4; + constexpr uint32_t nbQuads = warp_size / 4; +#pragma unroll + for (uint32_t m = 0; m < WarpAcc::rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < InstAcc::rows; i++) { +#pragma unroll + for (uint32_t n = 0; n < WarpAcc::cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < InstAcc::cols; j++) { + uint32_t const idxRow = instM * m + nbQuads * i + idxQuad; + if (isCtxTile || (idxRow >= headGrpSize * idxBeam && idxRow < headGrpSize * idxBeam + headGrpSize)) { + func(dst(m, n)(i, j), src(m, n)(i, j)); + } + } + } + } + } +} + +__device__ inline void rescaleAcc( + Warp const& warp, WarpAcc& acc, UniformRescaleMask const& rescaleMask, ThrdRegRowMax const& 
rowScales) { + static_assert(WarpAcc::rows * InstAcc::rows * 8 <= ThrdRegRowMax::size * warp_size); +// QuadRegRowMax const quadRowScales = replicateForQuad(warp, rowScales); +#pragma unroll + for (uint32_t m = 0; m < WarpAcc::rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < InstAcc::rows; i++) { + uint32_t const r = m * InstAcc::rows + i; // in 8-row unit. + bool const skip = enableMicroFastPath && ((rescaleMask[r / 4] & (0xFFU << 8 * r)) == 0); + if (skip) { // @fixme: do we need this? + continue; + } + // float const scale = quadRowScales[r]; // @fixme: see if this is faster than the line below. + float const scale = replicateValForQuad(warp, rowScales, r); +#pragma unroll + for (uint32_t n = 0; n < WarpAcc::cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < InstAcc::cols; j++) { + acc(m, n)(i, j) *= scale; + } + } + } + } +} + +__device__ inline void rescaleAcc(Warp const& warp, WarpAcc& acc, float scale) { +#pragma unroll + for (uint32_t m = 0; m < acc.rows; m++) { +#pragma unroll + for (uint32_t i = 0; i < InstAcc::rows; i++) { +#pragma unroll + for (uint32_t n = 0; n < acc.cols; n++) { +#pragma unroll + for (uint32_t j = 0; j < InstAcc::cols; j++) { + acc(m, n)(i, j) *= scale; + } + } + } + } +} + +template +__device__ inline void smemFp16ArraySum( + uint32_t idxWarp, Array2D& dst, Array2D const tiles[nbTiles]) { + constexpr uint32_t nbThrds = warp_size * nbWarps; + uint32_t const tid = warp_size * idxWarp + laneId(); + constexpr uint32_t nbGrains = SharedMem::XSmemBuffer::rows * SharedMem::XSmemBuffer::cols; + constexpr uint32_t nbGrainsPerThrd = exactDiv(nbGrains, nbThrds); + using AccType = mha::conditional_t; + +#pragma unroll + for (uint32_t i = 0; i < nbGrainsPerThrd; i++) { + Vec result; + result.fill(AccType{0, 0}); + uint32_t const idx = nbThrds * i + tid; +#pragma unroll + for (uint32_t j = 0; j < nbTiles; j++) { + auto const data = reinterpret_cast const(&)[nbGrains]>(tiles[j])[idx]; + if constexpr (useFp32Acc) { +#if INPUT_FP16 + result = addFloat2(result, __half22float2(data)); +#else + result = addFloat2(result, __bfloat1622float2(data)); +#endif + } else { + result = __hadd2_rn(result, data); + } + } + auto& dstGrain = reinterpret_cast(&)[nbGrains]>(dst)[idx]; + if constexpr (useFp32Acc) { +#if INPUT_FP16 + dstGrain = __float22half2_rn(result); +#else + PRAGMA_UNROLL_FP16_ONLY + for (uint32_t k = 0; k < LdGrain::size; ++k) { + dstGrain[k] = __floats2bfloat162_rn(result[k].x, result[k].y); + } +#endif + } else { + dstGrain = result; + } + } +} + +template +__device__ inline ThrdRegRowMax mergeRowMax( + Warp const& warp, TinyPtr const rowMaxBuffers, uint32_t nbSubSeqPerSeq) { + ThrdRegRowMax regBuffers[nbBuffers]; + auto load = [&](uint32_t n) { + assert(n < nbSubSeqPerSeq); + regBuffers[n % nbBuffers] = rowMaxBuffers[n].loadToReg(warp); + }; +#pragma unroll + for (uint32_t i = 0; i < nbBuffers; i++) { + if (i >= nbSubSeqPerSeq) { + break; + } + load(i); + } + ThrdRegRowMax mergedRowMax = regBuffers[0]; + for (uint32_t n = 0; n < divUp(nbSubSeqPerSeq, nbBuffers); n++) { +#pragma unroll + for (uint32_t i = 0; i < nbBuffers; i++) { + uint32_t const idx = nbBuffers * n + i; + if (idx >= nbSubSeqPerSeq) { + break; + } + mergedRowMax = fmaxf(mergedRowMax, regBuffers[i]); + uint32_t const idxNext = idx + nbBuffers; + if (idxNext < nbSubSeqPerSeq) { + load(idxNext); + } + } + } + return mergedRowMax; +} + +__device__ inline void addAttentionSinks( + ThrdRegRowMax& globalRowSum, ThrdRegRowMax const globalRowMax, float const* attentionSinks) { + for (uint32_t i = 0; 
i < globalRowSum.size; i++) { + uint32_t srcOffset = warp_size * i + laneId(); + if (srcOffset < headGrpSize) { + globalRowSum[i] += expf(attentionSinks[srcOffset] - globalRowMax[i]); + } + } +} + +#ifdef NDEBUG +__device__ __forceinline__ +#else +CUBIN_EXPORT __global__ +#endif + void + kernel_mha_impl( +#if SPEC_DEC + uint32_t const qSeqLen, uint32_t const nbKHeads, uint32_t const headGrpSize, + SeqLenDataType const* __restrict__ qCuSeqLens, // [nbReq + 1] +#else + uint32_t const nbKHeads, +#endif +#if SLIDING_WINDOW + uint32_t slidingWinSize, +#endif + float qScale, + OutputHead* __restrict__ const output, // [nbReq][beamWidth][nbQHeads] +#if LOW_PREC_OUTPUT + float const* rcpOutScale, +#endif + // NOTE: the input is actually Q buffer when integrated to TRT-LLM. + IOHead const* __restrict__ const q, // [nbReq][beamWidth][nbQHeads], +#if SPEC_DEC + MaskType const* __restrict__ mask, // [qSeqLen, divUp(qSeqLen, 32)]. +#endif + float const* attentionSinks, // [headGrpSize] +#ifdef NDEBUG + KVCacheList const& cacheList, +#if BEAM_WIDTH > 1 + BeamSearchParams const& beamSearchParams, +#endif +#else + KVCacheList const cacheList, +#if BEAM_WIDTH > 1 + BeamSearchParams const beamSearchParams, +#endif +#endif + uint32_t const batchSize, + float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. Used only for + // int8/fp8 KV cache. + uint32_t* __restrict__ semaphores = nullptr, void* __restrict__ scratch = nullptr) { + assert(allowMultiBlockMode || gridDim.x == 1); + bool const isMultiBlock = allowMultiBlockMode && (gridDim.x != 1); + uint32_t const nbSubSeqPerSeq = allowMultiBlockMode ? gridDim.x : 1; + uint32_t const idxSubSeqInSeq = allowMultiBlockMode ? blockIdx.x : 0; + assert(!isMultiBlock || (semaphores != nullptr && scratch != nullptr)); + + // gridDim: x - K/V sequence-dim split; y - number of K or V heads per token; z - number of requests + assert(gridDim.z == batchSize && gridDim.y == nbKHeads); + extern __shared__ char smemByteBuf[]; + SharedMem& smem = *reinterpret_cast(&smemByteBuf[0]); + + uint32_t const idxReq = blockIdx.z; +#if SPEC_DEC + // Variable query sequence length support. + bool const variableQSeqLen = qCuSeqLens != nullptr; + uint32_t const actualQSeqLen = variableQSeqLen ? uint32_t(qCuSeqLens[idxReq + 1] - qCuSeqLens[idxReq]) : qSeqLen; + // Same as idxReq * qSeqLen if all sequences all the same. + // Take different beams as different requests/sequences currently. + uint32_t const reqSeqOffset = variableQSeqLen ? uint32_t(qCuSeqLens[idxReq]) : (qSeqLen * idxReq); + + uint32_t const nbVHeads = nbKHeads; + uint32_t const nbQHeads = nbKHeads * headGrpSize; + uint32_t const nbQHeadTokens = nbQHeads * actualQSeqLen; + uint32_t const nbQKVHeads = nbQHeads + nbKHeads + nbVHeads; + + uint32_t const nbTokenBlocksPerGrp = gridDim.y / nbKHeads; + uint32_t const idxHeadGrp = blockIdx.y / nbTokenBlocksPerGrp; // inside one request + uint32_t const idxHeadTokenInGrp = (blockIdx.y % nbTokenBlocksPerGrp) * warpTile.y; + uint32_t const totalNbHeadTokensInGrp = actualQSeqLen * headGrpSize; + uint32_t const nbValidHeadTokens = idxHeadTokenInGrp > totalNbHeadTokensInGrp + ? 0u + : mha::min(totalNbHeadTokensInGrp - idxHeadTokenInGrp, rowsPerBlock); + // Shift the mask ptr by batch_idx. 
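+  // Each query token owns divUp(qSeqLen, 32) packed mask words (one bit per key/column position), so this advances the pointer to the first mask row of the current request.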
+ mask += reqSeqOffset * divUp(qSeqLen, 32u); +#else + uint32_t const nbQHeads = nbKHeads * headGrpSize; + + uint32_t const idxHeadGrp = blockIdx.y; // inside one request +#endif + + auto const ctaThrdId = threadIdx.x + warp_size * ctaShapeInWarps.x * (threadIdx.y + ctaShapeInWarps.y * threadIdx.z); + assert(blockDim.x == ctaShapeInWarps.x * warp_size && blockDim.y == ctaShapeInWarps.y && blockDim.z == ctaShapeInWarps.z); + auto const warp = this_warp(); + uint3 const warpIdx = getWarpIdx(warp); // @fixme: use BoundedVal + assert(warpIdx.x < ctaShapeInWarps.x && warpIdx.y < ctaShapeInWarps.y && warpIdx.z < ctaShapeInWarps.z); + uint32_t const flatWarpIdPerRow = warpIdx.z * ctaShapeInWarps.x + warpIdx.x; // per ctaShapeInWarps.y value + unused(flatWarpIdPerRow); + + // initialize shared memory + static_assert(persistentQ && ctaShapeInWarps.y == 1); + if (ctaThrdId < ctaShapeInWarps.y) { + init(&smem.qBarrier[ctaThrdId], warp_size * ctaShapeInWarps.x); // be sure to use .noinc + } + constexpr uint32_t cacheVTileSeqStride = cacheVTileSeqLen * gemm1NbWarpGrps; + constexpr uint32_t nbXTilesPerXIter = cacheVTileSeqStride < warpTile.x ? 1 : exactDiv(cacheVTileSeqStride, warpTile.x); + constexpr uint32_t nbXItersPerCtaTile = exactDiv(ctaShapeInWarps.x, nbXTilesPerXIter); + constexpr uint32_t nbVItersPerXIter = exactDiv(warpTile.x * nbXTilesPerXIter, cacheVTileSeqStride); + constexpr uint32_t nbWarpGrpsPerXTile = mha::min(nbCacheVTilesPerXTile, gemm1NbWarpGrps); + unused(nbWarpGrpsPerXTile); + static_assert(warpTile.x >= cacheVTileSeqLen, "not implemented yet"); + static_assert(ctaSize >= uint32_t(sizeof(smem.xBarriers) / sizeof(CtaBarrierPair))); + if (ctaThrdId < uint32_t(sizeof(smem.xBarriers) / sizeof(CtaBarrierPair))) { + (&smem.xBarriers[0][0])[ctaThrdId].initialize(warp_size, warp_size * gemm1WarpsPerGrp * nbWarpGrpsPerXTile); + } +#if CTA_ROW_MAX_BACKWARD_METHOD == 3 + static_assert(ctaSize >= sizeof(smem.ctaRowMaxBwdBarriers) / sizeof(SharedMem::Barrier)); + if (ctaThrdId < sizeof(smem.ctaRowMaxBwdBarriers) / sizeof(SharedMem::Barrier)) { + init(&smem.ctaRowMaxBwdBarriers[0][0] + ctaThrdId, warp_size); + } +#endif +#if CTA_ROW_MAX_BACKWARD_METHOD != 0 + static_assert(ctaSize >= sizeof(smem.ctaRowMax) / sizeof(float)); + if (ctaThrdId < sizeof(smem.ctaRowMax) / sizeof(float)) { + reinterpret_cast(&smem.ctaRowMax[0])[ctaThrdId] = SAFE_INIT_ROW_MAX; + } +#endif +#if GRP_LOAD_V + static_assert(ctaSize >= gemm1NbWarpGrps * nbVBuffers); + if (ctaThrdId < gemm1NbWarpGrps * nbVBuffers) { + init(smem.vBarrier(0, 0) + ctaThrdId, warp_size * gemm1WarpsPerGrp); + } + if (ctaThrdId < gemm1NbWarpGrps) { + init(smem.warpGrpBar(ctaThrdId), warp_size * gemm1WarpsPerGrp); + } +#endif + __syncthreads(); + +#if ENABLE_PDL + preExit(); + acqBulk(); +#endif + + constexpr bool qkSwizzle = true; + // load whole Q heads into shared memory +#if SPEC_DEC + if (warpIdx.z == 0) { + // map from idxQHead to idxHead in q input. + auto const localQHeadTokenIdxMap = [nbQHeads, headGrpSize, reqSeqOffset, idxReq, idxHeadTokenInGrp](uint32_t idxHeadTokenLocal) -> uint32_t { + assert(idxHeadTokenLocal < warpTile.y); // may be larger than nbValidRows, then the output does not matter. 
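+        // Warp-tile rows enumerate (token, head-in-group) pairs; the mapping below converts a row index into an offset in the flat [token][nbQHeads] Q layout (the head-group base offset is already folded into the src pointer).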
+ if constexpr (beamWidth == 1) { + idxHeadTokenLocal += idxHeadTokenInGrp; + uint32_t const tokenIdx = (idxHeadTokenLocal / headGrpSize); + uint32_t const headIdx = idxHeadTokenLocal % headGrpSize; + return tokenIdx * nbQHeads + headIdx; + } + }; + static_assert(nbValidRows <= warpTile.y); + auto const srcBase = q; + uint32_t const idxHeadTokenBeg = nbQHeads * reqSeqOffset + (idxHeadGrp * headGrpSize); + TinyPtr const src{srcBase, idxHeadTokenBeg}; + + bool const isFullTile = (nbValidHeadTokens == warpTile.y); + static_assert(nbQBuffers == 1); + if (isFullTile) { + copyHeadsAsync( + warpIdx.x, smem.q[warpIdx.y][0], src, nbValidHeadTokens, localQHeadTokenIdxMap); + } else { + copyHeadsAsync( + warpIdx.x, smem.q[warpIdx.y][0], src, nbValidHeadTokens, localQHeadTokenIdxMap); + } + + ldgsts::barArrive(smem.qBarrier[warpIdx.y], true); + } +#else + if (warpIdx.z == 0) { + // map from idxQHead to idxHead in q input. + auto const localQHeadIdxMap = [nbQHeads, idxReq, idxHeadGrp](uint32_t idxHeadLocal) -> uint32_t { + assert(idxHeadLocal < warpTile.y); // may be larger than nbValidRows, then the output does not matter. + if constexpr (beamWidth == 1) { + return idxHeadLocal; + } + uint32_t const idxBeam = idxHeadLocal / headGrpSize; + uint32_t const result = idxHeadLocal + idxBeam * (nbQHeads - headGrpSize); + uint32_t const idxQHeadInGrp = idxHeadLocal % headGrpSize; + uint32_t const ref = nbQHeads * idxBeam + idxQHeadInGrp; + assert(result == ref); + unused(ref); + return result; + }; + static_assert(nbValidRows <= warpTile.y); + auto const srcBase = q; + // NOTE: read from Q buffer directly. + uint32_t const idxHeadBeg = nbQHeads * beamWidth * idxReq + headGrpSize * idxHeadGrp; + TinyPtr const src{srcBase, idxHeadBeg}; + + constexpr bool isFullTile = (nbValidRows == warpTile.y); + static_assert(nbQBuffers == 1); + copyHeadsAsync( + warpIdx.x, smem.q[warpIdx.y][0], src, nbValidRows, localQHeadIdxMap); + ldgsts::barArrive(smem.qBarrier[warpIdx.y], true); + } +#endif + + uint32_t const cacheSeqLen = getCacheSeqLen(cacheList, idxReq); +#if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE + uint32_t const tok0SeqLen = cacheSeqLen - actualQSeqLen + 1 + idxHeadTokenInGrp; // ctaTokOffset; + int32_t const tok0WinBeg = int32_t(tok0SeqLen) - int32_t(slidingWinSize); + uint32_t const nbTotalSkipTokens = mha::max(0, tok0WinBeg); + +#elif SLIDING_WINDOW + bool const rtIsReallySliding = (cacheSeqLen > slidingWinSize); + assert(!SPEC_DEC || !rtIsReallySliding); + uint32_t const nbTotalSkipTokens = rtIsReallySliding ? cacheSeqLen - slidingWinSize : 0; +#else + constexpr bool rtIsReallySliding = false; + constexpr uint32_t nbTotalSkipTokens = 0; +#endif + uint32_t const nbSkipLeadingTiles = nbTotalSkipTokens / ctaTile.x; + uint32_t const tile0NbSkipTokens = nbTotalSkipTokens % ctaTile.x; + unused(tile0NbSkipTokens); +#if USE_PAGED_KV_CACHE + uint32_t const nbPages = divUp(cacheSeqLen, tokensPerPage); + constexpr uint32_t nbPagesPerCtaTile = exactDiv(ctaTile.x, tokensPerPage); +#endif + + uint32_t const nbSeqIters = useKVCache ? 
divUp(cacheSeqLen, ctaTile.x) : 0; +#if SLIDING_WINDOW && SPEC_DEC && !IS_SPEC_DEC_TREE + uint32_t const nbSeqItersWithoutMask = nbSkipLeadingTiles; +#elif SPEC_DEC + uint32_t const nbSeqItersWithoutMask = (cacheSeqLen - actualQSeqLen) / ctaTile.x; +#endif + + uint32_t const seqStrideIters = nbSubSeqPerSeq; + constexpr bool isKVCacheQuantized = (cacheElemSize < 2); + uint32_t const seqIterInit = nbSkipLeadingTiles + idxSubSeqInSeq; +#if BEAM_WIDTH > 1 + uint32_t const nbCtxCtaTiles = beamSearchParams.ctxLenList[idxReq * beamWidth] / ctaTile.x; +#endif + auto isConvergedTile = [&](uint32_t seqIter) { +#if BEAM_WIDTH == 1 + return true; +#else + return seqIter < nbCtxCtaTiles; +#endif + }; + if (warpIdx.z == 0) { + float const qkScale = qScale * (isKVCacheQuantized ? kvCacheScale[0] : 1.f); // qkScale is applied onto Q*K.T before softmax. + CircIdx idxCurrSMemKBuf{nbKBuffers - 1}; + auto const getSMemKTile = [&](uint32_t idx) -> SharedMem::KSmemBuffer& { return smem.k[warpIdx.x][idx]; }; +#if BEAM_WIDTH > 1 + auto loadCacheIndir = [&](uint32_t seqIter, uint32_t idxBeam) mutable { + auto& dst = smem.gemm0CacheIndir[warpIdx.x]; + uint32_t const offset = ctaTile.x * seqIter + warpTile.x * warpIdx.x; + loadIndicesForBeamSearchAsync<1, warpTile.x>( + 0, dst, beamSearchParams, idxReq, idxBeam, offset, cacheSeqLen); + }; + loadCacheIndir(seqIterInit, 0U); +#endif +#if USE_PAGED_KV_CACHE +#if BEAM_WIDTH == 1 + KCachePageIndices pageIdx = KCachePageIndices::filled(kBAD_PAGE_INDEX); +#endif + auto loadPages = [&](uint32_t idxPage) mutable { +#if BEAM_WIDTH == 1 + uint32_t const idxBeam = 0; + pageIdx = getPage(cacheList, true, idxReq, idxBeam, idxPage, nbPages); +#else + auto& dst = smem.kCachePages[warpIdx.x]; + loadPagesForBeamSearchAsync<1>(0U, dst, cacheList, true, idxReq, idxPage, nbPages); +#endif + }; + uint32_t idxPageBeg = nbPagesPerCtaTile * seqIterInit + warpIdx.x * warpTile.x / tokensPerPage; + loadPages(idxPageBeg); +#else + constexpr uint32_t idxBeamBase = 0U; + uint32_t const cacheKBaseBatch = cacheList.capacity * nbKHeads * (idxBeamBase + beamWidth * idxReq); + uint32_t const cacheKSeqBaseOffset = cacheList.isBSNH + ? (cacheKBaseBatch + idxHeadGrp) + : (cacheKBaseBatch + cacheList.capacity * idxHeadGrp); +#endif + auto loadKTilePart = [&](uint32_t seqIter, uint32_t idxBeam, uint32_t idxPart) mutable { + assert(idxBeam < beamWidth); + assert(seqIter % nbSubSeqPerSeq == seqIterInit % nbSubSeqPerSeq); + auto const idxNextSMemKBuf = idxCurrSMemKBuf.next(); + auto& dst = getSMemKTile(idxNextSMemKBuf); + uint32_t const dstHeadOffset = 0; + uint32_t const seqOffset = ctaTile.x * seqIter + warpTile.x * warpIdx.x; +#if USE_PAGED_KV_CACHE +#if PAGED_KV_CACHE_LAYOUT == 1 + uint32_t const idxHeadBeg = (seqOffset % tokensPerPage) * nbKHeads + idxHeadGrp; + +#else + uint32_t const idxHeadBeg = tokensPerPage * idxHeadGrp + seqOffset % tokensPerPage; +#endif +#if BEAM_WIDTH == 1 +#if PAGED_KV_CACHE_LAYOUT == 1 + HeadPtr const src{ + cacheList.kCacheVLLM, pageIdx, nbKHeads, idxHeadBeg}; +#else + HeadPtr const src{ + cacheList.pool, pageIdx, nbKHeads, idxHeadBeg}; +#endif +#else + IndexedHeadPtr const src{ + /*indices=*/smem.gemm0CacheIndir[warpIdx.x].data, +#if PAGED_KV_CACHE_LAYOUT == 1 + /*pool=*/cacheList.kCacheVLLM, +#else + /*pool=*/cacheList.pool, +#endif + /*pageIndices=*/smem.kCachePages[warpIdx.x].data, + /*nbKHeads=*/nbKHeads, + /*offset=*/idxHeadBeg}; +#endif +#else + uint32_t const idxHeadBeg = cacheList.isBSNH + ? 
(cacheKSeqBaseOffset + seqOffset * nbKHeads) + : (cacheKSeqBaseOffset + seqOffset); +#if BEAM_WIDTH == 1 + TinyPtr const src{cacheList.kData, idxHeadBeg}; +#else + IndexedHeadPtr const src{/*indices=*/smem.gemm0CacheIndir[warpIdx.x].data, + /*pointer=*/cacheList.data, + /*offset=*/idxHeadBeg, + /*beamStride=*/cacheList.capacity * nbKHeads * 2}; + // trap(); + // assert("not implemented"); +#endif +#endif + // if (threadIdx.x == dbgPrintTid) { + // printf("K: seqIter=%u, idxBeam=%u, idxPart=%u: pointers={%p, %p}, indices={", seqIter, idxBeam, + // idxPart, src.pointers[0], src.pointers[1]); uint32_t const nbHeadsAvail = mha::min((seqOffset < + // cacheSeqLen ? cacheSeqLen - seqOffset : 0U), warpTile.x); for (int i = 0; i < nbHeadsAvail; i++) { + // printf("%u, ", src.indices[i]); + // } + // printf("}\n"); + // } + bool const isFullTile = (seqIter + 1 < nbSeqIters); + if (isFullTile) { + copyPartialHeadsAsync( + warp, dst, dstHeadOffset, src, idxPart); + } else { + uint32_t const nbHeadsAvail = (seqOffset < cacheSeqLen ? cacheSeqLen - seqOffset + : 0U); // may also be full but it can be handled correctly anyway + copyPartialHeadsAsync( + warp, dst, dstHeadOffset, src, idxPart, nbHeadsAvail); + } +#if BEAM_WIDTH > 1 + // to make sure all threads has finished usage of cache indir and pages + __syncwarp(); +#endif + if (idxPart + 1 == nbPartsPerCacheKHead) { +#if USE_PAGED_KV_CACHE + bool const isForNextSeqIter = isConvergedTile(seqIter) || idxBeam == beamWidth - 1; + if (isForNextSeqIter) { + idxPageBeg += nbPagesPerCtaTile * nbSubSeqPerSeq; + loadPages(idxPageBeg); + } +#endif +#if BEAM_WIDTH > 1 + uint32_t idxBeamNext, seqIterDelta; + mha::tie(idxBeamNext, seqIterDelta) = isConvergedTile(seqIter) + ? mha::tuple(0U, 1U) + : carryLE(idxBeam + 1, 0); // optimize for context cache + loadCacheIndir(seqIter + seqStrideIters * seqIterDelta, idxBeamNext); +#endif + } + }; + +#if BEAM_WIDTH > 1 + ldgsts::commitGroup(); + ldgsts::waitGroup<0>(); + __syncwarp(); +#endif + loadKTilePart(seqIterInit, 0, 0); + ldgsts::commitGroup(); // @fixme: do prefetch for next iter tile if last part + idxCurrSMemKBuf++; + + auto& xBar = smem.xBarriers[warpIdx.y][warpIdx.x]; + bool xBarConsumedParityNext = false; + + bool qBarParityNext = false; + auto& qBar = smem.qBarrier[warpIdx.y]; + qBar.wait_parity(qBarParityNext); + qBarParityNext = !qBarParityNext; + constexpr bool reorderForKCache = (useKVCache && inputElemSize == 2 && cacheElemSize == 1); + if constexpr (reorderForKCache) { + reorder16bQHeadsToMatch8bKCache(warpIdx.x, smem.q[warpIdx.y][0]); + unused(qBar.arrive()); + qBar.wait_parity(qBarParityNext); + qBarParityNext = !qBarParityNext; + assertWarpConverged(); + } +#if CTA_ROW_MAX_BACKWARD_METHOD == 2 + ThrdRegRowMax initRowMax; + initRowMax.fill(safeInitRowMax); +#endif + for (uint32_t seqIter = seqIterInit; seqIter < nbSeqIters; seqIter += seqStrideIters) { +#if SHORT_SEQ_OPT + if (ctaTile.x * seqIter + warpTile.x * warpIdx.x >= cacheSeqLen) { + break; + } +#endif + auto runGemm0 = [&](auto elemK, uint32_t idxBeam) { + assert(idxBeam < (isConvergedTile(seqIter) ? 1U : beamWidth)); + using KElemType = mha::decay_t; + constexpr uint32_t elemsPerKHeadPart = exactDiv(kHeadPartBytes, sizeof(KElemType)); + constexpr uint32_t nbPartsPerKHead = exactDiv(headElems, elemsPerKHeadPart); + // the accumulator + WarpAcc acc{}; + constexpr uint32_t nbUnroll = (cacheElemSize == 2 ? 
nbPartsPerKHead : 1); +#pragma unroll(nbUnroll) + for (uint32_t p = 0; p < nbPartsPerKHead; p++) { + constexpr bool syncKTileEarly = (beamWidth > 1); // alternative is to use double buffer for cacheIndir and pages + if constexpr (syncKTileEarly) { + // synchronize gemm0CacheIndir for the next loadKTilePart. the last loaded K tile is also + // sync'ed at the same time. + ldgsts::waitGroup<0>(); + __syncwarp(); + } + // prefetch next part into shared memory + uint32_t idxPartNext, idxBeamNext, nNextBias; + mha::tie(idxPartNext, idxBeamNext, nNextBias) = isConvergedTile(seqIter) + ? carryLE(p + 1, idxBeam, 0U) + : carryLE(p + 1, idxBeam, 0U); + + loadKTilePart(seqIter + seqStrideIters * nNextBias, idxBeamNext, idxPartNext); + ldgsts::commitGroup(); + // @fixme: do L2 cache prefetch for next iter tile if last part + + // q is already synchronized + if constexpr (!syncKTileEarly) { + // synchronize k + ldgsts::waitGroup<1>(); + } + SharedMem::QSmemBuffer const& smemQ = smem.q[warpIdx.y][0]; + constexpr uint32_t qOffsetPerPart = exactDiv(elemsPerKHeadPart, inputElemsPerGrain); + uint32_t const smemQOffset = qOffsetPerPart * p; + SharedMem::KSmemBuffer const& smemKPart = getSMemKTile(idxCurrSMemKBuf); + // #ifndef NDEGBUG + // for (uint32_t i = 0; i < exactDiv(smemKPart.rows * smemKPart.cols, + // warp_size); i++) { + // uint32_t const idx = warp_size * i + laneId(); + // uint32_t const r = idx / smemKPart.cols; + // uint32_t const c = idx % smemKPart.cols; + + // assert(smemKPart(r, c) == ); + // } + // #endif + // do computation. + smemQKPartGemm(warp, acc, smemQ, smemQOffset, smemKPart); + idxCurrSMemKBuf++; + } + return acc; + }; + WarpAcc acc; + //@fixme: alternative is to use separate inner loop, which results in larger but maybe faster code. + for (uint32_t idxBeam = 0; idxBeam < (isConvergedTile(seqIter) ? 1U : beamWidth); idxBeam++) { + WarpAcc tmp; + if constexpr (mha::is_same_v) { + tmp = runGemm0(CacheElem{}, idxBeam); + } else { + tmp = runGemm0(CacheElem{}, idxBeam); + } + pickAccRowsForBeamSearch( + warp, acc, tmp, isConvergedTile(seqIter), idxBeam, [](float& d, float s) { d = s; }); + } + // apply qkScale + rescaleAcc(warp, acc, qkScale); +#if CTA_ROW_MAX_BACKWARD_METHOD == 0 + QuadRegRowMax initRowMaxQuad; + initRowMaxQuad.fill(safeInitRowMax); +#elif CTA_ROW_MAX_BACKWARD_METHOD == 1 + // load hint + xBar.consumed.wait_parity(getAndFlip(xBarConsumedParityNext)); + QuadRegRowMax initRowMaxQuad = smem.ctaRowMax[warpIdx.y][warpIdx.x].loadToRegForQuad(warp); +#elif CTA_ROW_MAX_BACKWARD_METHOD == 2 + QuadRegRowMax initRowMaxQuad = replicateForQuad(warp, initRowMax); +#elif CTA_ROW_MAX_BACKWARD_METHOD == 3 + // load hint + smem.ctaRowMaxBwdBarriers[warpIdx.y][warpIdx.x].wait_parity(xBarConsumedParityNext); + QuadRegRowMax initRowMaxQuad = smem.ctaRowMax[warpIdx.y][warpIdx.x].loadToRegForQuad(warp); +#elif CTA_ROW_MAX_BACKWARD_METHOD == 4 + // load hint + QuadRegRowMax initRowMaxQuad = smem.ctaRowMax[warpIdx.y].loadToRegForQuad(warp); +#endif + // masking + uint32_t const warpTileTokenBeg = ctaTile.x * seqIter + warpTile.x * warpIdx.x; +#if SPEC_DEC + if (seqIter >= nbSeqItersWithoutMask) { + uint32_t const nbValidCols = (warpTileTokenBeg < cacheSeqLen ? 
cacheSeqLen - warpTileTokenBeg : 0U); + applyMaskFromInput(warp, acc, mask, idxHeadTokenInGrp, nbValidCols, qSeqLen, actualQSeqLen, headGrpSize +#if SLIDING_WINDOW && !IS_SPEC_DEC_TREE + , + tok0WinBeg, seqIter, cacheSeqLen, warpTileTokenBeg +#endif + ); + } +#else + bool const isFirstIter = (seqIter == nbSkipLeadingTiles); + bool const needMaskLeading = (rtIsReallySliding && isFirstIter); + bool const isLastIter = (seqIter + 1 == nbSeqIters); + bool const needMaskTrailing = isLastIter && cacheSeqLen % ctaTile.x != 0; + if (needMaskLeading || needMaskTrailing) { + uint32_t const validTokenBeg = (!needMaskLeading || nbTotalSkipTokens < warpTileTokenBeg) + ? 0 + : nbTotalSkipTokens - warpTileTokenBeg; + uint32_t const validTokenEnd = (warpTileTokenBeg < cacheSeqLen ? cacheSeqLen - warpTileTokenBeg : 0U); + if (validTokenBeg > 0 || validTokenEnd < warpTile.x) { + applyMask(warp, acc, validTokenBeg, validTokenEnd); + } + } +#endif + + // find max and update acc into exp(acc-max). + QuadRegRowMax const regRowMax = warpTileOnlineSoftmax(warp, initRowMaxQuad, acc); + + // store result and max to shared memory. + GemmOutRegTile const fp16Acc = toFp16(acc); + QuadRegRowMax const regRowSum = computeRowSum(warp, fp16Acc); +#if CTA_ROW_MAX_BACKWARD_METHOD != 1 + xBar.consumed.wait_parity(getAndFlip(xBarConsumedParityNext)); +#if CTA_ROW_MAX_BACKWARD_METHOD == 2 + initRowMax = smem.ctaRowMax[warpIdx.y][warpIdx.x].loadToReg(warp); +#endif +#endif + storeOrderedGemmOutTile(warp, smem.x[warpIdx.y][warpIdx.x], fp16Acc); + smem.warpRowMax[warpIdx.y][warpIdx.x].storeFromReg(warp, regRowMax); + smem.warpRowSum[warpIdx.y][warpIdx.x].storeFromReg(warp, regRowSum); + unused(xBar.produced.arrive()); + } + } else { + assert(warpIdx.z == 1); +#if CTA_ROW_MAX_BACKWARD_METHOD == 3 + unused(smem.ctaRowMaxBwdBarriers[warpIdx.y][warpIdx.x].arrive()); +#endif + uint32_t const warpIdxInGrp = gemm1WarpIdxInGrp(warpIdx.x); // @fixme: use BoundedVal + uint32_t const warpGrpIdx = gemm1WarpGrpIdx(warpIdx.x); // @fixme: use BoundedVal + auto* const pWarpGrpBar = smem.warpGrpBar(warpGrpIdx); + ParityOrNone warpGrpBarParityNext{}; +#if BEAM_WIDTH > 1 + auto loadCacheIndir = [&](uint32_t seqIter, uint32_t xIter, uint32_t vIter, uint32_t idxBeam) mutable { + uint32_t const seqOffset = ctaTile.x * seqIter + warpTile.x * nbXTilesPerXIter * xIter + cacheVTileSeqStride * vIter + cacheVTileSeqLen * warpGrpIdx; + auto& dst = smem.gemm1CacheIndir[grpLoadV ? warpGrpIdx : warpIdx.x]; + loadIndicesForBeamSearchAsync( + grpLoadV ? warpIdxInGrp : 0U, dst, beamSearchParams, idxReq, idxBeam, seqOffset, cacheSeqLen); + }; + loadCacheIndir(seqIterInit, 0, 0, 0); +#endif + unused(smem.xBarriers[warpIdx.y][warpIdx.x].consumed.arrive(gemm1WarpsPerGrp * nbWarpGrpsPerXTile)); + CircIdx idxCurrSMemVBuf{nbVBuffers - 1}; + auto const getSmemVTile = [&](uint32_t idx) -> SharedMem::VSmemBuffer& { return smem.v[warpGrpIdx][grpLoadV ? 0 : warpIdxInGrp][idx]; }; + auto const getSmemVBar = [&](uint32_t idx) -> SharedMem::Barrier* { return smem.vBarrier(warpGrpIdx, idx); }; +#if USE_PAGED_KV_CACHE +#if BEAM_WIDTH == 1 + VCachePageIndices pageIdx = VCachePageIndices::filled(kBAD_PAGE_INDEX); +#endif + auto loadPages = [&](uint32_t idxPageBeg) mutable { +#if BEAM_WIDTH == 1 + uint32_t const idxBeam = 0; + pageIdx = getPage(cacheList, false, idxReq, idxBeam, idxPageBeg, nbPages); +#else + auto& dst = smem.vCachePages[grpLoadV ? warpGrpIdx : warpIdx.x]; + loadPagesForBeamSearchAsync( + grpLoadV ? 
warpIdxInGrp : 0U, dst, cacheList, false, idxReq, idxPageBeg, nbPages); +#endif + }; + uint32_t idxPageBeg = nbPagesPerCtaTile * seqIterInit + cacheVTileSeqLen * warpGrpIdx / tokensPerPage; + loadPages(idxPageBeg); +#else + uint32_t const idxBeamBase = 0; + uint32_t const cacheVBaseBatch = cacheList.capacity * nbKHeads * (idxBeamBase + beamWidth * idxReq); + uint32_t const cacheVSeqBaseOffset = cacheList.isBSNH + ? (cacheVBaseBatch + idxHeadGrp) + : (cacheVBaseBatch + cacheList.capacity * idxHeadGrp); +#endif + auto nextStep = [&](uint32_t seqIter, uint32_t xIter, uint32_t vIter, uint32_t idxBeam) { + uint32_t vIterNext, isNextBeam; + mha::tie(vIterNext, isNextBeam) = carryLE(vIter + 1, 0); + + uint32_t idxBeamNext, xIterNext, nNextBias; + mha::tie(idxBeamNext, xIterNext, nNextBias) = isConvergedTile(seqIter) + ? carryLE<1, nbXItersPerCtaTile>(idxBeam + isNextBeam, xIter, 0) + : carryLE(idxBeam + isNextBeam, xIter, 0); + + uint32_t const seqIterNext = seqIter + seqStrideIters * nNextBias; + return mha::tuple(seqIterNext, xIterNext, vIterNext, idxBeamNext); + }; + auto loadVTilePart = [&](uint32_t seqIter, uint32_t xIter, uint32_t vIter, + uint32_t idxBeam) mutable { // @fixme: merge three iteration parameters into idxVTileGlb. + assert(idxBeam < beamWidth); + assert(seqIter % nbSubSeqPerSeq == seqIterInit % nbSubSeqPerSeq); + auto const idxNextSMemVBuf = idxCurrSMemVBuf.next(); + auto& dst = getSmemVTile(idxNextSMemVBuf); + uint32_t const dstHeadOffset = 0; + constexpr bool vSwizzle = true; + + uint32_t const seqOffset = ctaTile.x * seqIter + warpTile.x * nbXTilesPerXIter * xIter + cacheVTileSeqStride * vIter + cacheVTileSeqLen * warpGrpIdx; +#if USE_PAGED_KV_CACHE +#if PAGED_KV_CACHE_LAYOUT == 1 + uint32_t const idxHeadBeg = (seqOffset % tokensPerPage) * nbKHeads + idxHeadGrp; + +#else + uint32_t const idxHeadBeg = tokensPerPage * idxHeadGrp + seqOffset % tokensPerPage; +#endif +#if BEAM_WIDTH == 1 +#if PAGED_KV_CACHE_LAYOUT == 1 + HeadPtr const src{ + cacheList.vCacheVLLM, pageIdx, nbKHeads, idxHeadBeg}; +#else + HeadPtr const src{ + cacheList.pool, pageIdx, nbKHeads, idxHeadBeg}; +#endif +#else + IndexedHeadPtr const src{ + /*indices=*/smem.gemm1CacheIndir[grpLoadV ? warpGrpIdx : warpIdx.x].data, +#if PAGED_KV_CACHE_LAYOUT == 1 + /*pool=*/cacheList.vCacheVLLM, +#else + /*pool=*/cacheList.pool, +#endif + /*pageIndices=*/smem.vCachePages[grpLoadV ? warpGrpIdx : warpIdx.x].data, + /*nbKHeads=*/nbKHeads, + /*offset=*/idxHeadBeg}; +#endif +#else + uint32_t const idxHeadBeg = cacheList.isBSNH + ? (cacheVSeqBaseOffset + seqOffset * nbKHeads) + : (cacheVSeqBaseOffset + seqOffset); +#if BEAM_WIDTH == 1 + TinyPtr const src{cacheList.vData, idxHeadBeg}; +#else + IndexedHeadPtr const src{ + /*indices=*/smem.gemm1CacheIndir[grpLoadV ? warpGrpIdx : warpIdx.x].data, + /*pointer=*/cacheList.data, + /*offset=*/idxHeadBeg, + /*beamStride=*/cacheList.capacity * nbKHeads * 2}; +#endif +#endif + // if (threadIdx.x == dbgPrintTid) { + // printf("V: seqIter=%u, xIter=%u, idxBeam=%u, vIter=%u: pointers={%p, %p}, indices={", seqIter, xIter, + // idxBeam, vIter, src.pointers[0], src.pointers[1]); uint32_t const nbHeadsAvail = mha::min((seqOffset + // < cacheSeqLen ? cacheSeqLen - seqOffset : 0U), cacheVTileSeqLen); for (int i = 0; i < nbHeadsAvail; + // i++) { + // printf("%u, ", src.indices[i]); + // } + // printf("}\n"); + // } + +#if GRP_LOAD_V + uint32_t const nbHeadsAvail = (seqIter + 1 < nbSeqIters) + ? cacheVTileSeqLen + : (seqOffset < cacheSeqLen ? 
cacheSeqLen - seqOffset + : 0U); // may also be full but it can be handled correctly anyway + copyHeadsAsync( + warpIdxInGrp, dst, src, nbHeadsAvail); +#else + uint32_t const nbHeadsAvail = (seqOffset < cacheSeqLen ? cacheSeqLen - seqOffset + : 0U); // may also be full but it can be handled correctly anyway + unused(nbHeadsAvail); + bool const isFullTile = (seqIter + 1 < nbSeqIters); + if (isFullTile) { + copyPartialHeadsAsync( + warp, dst, dstHeadOffset, src, warpIdxInGrp); + } else { + uint32_t const nbHeadsAvail = (seqOffset < cacheSeqLen ? cacheSeqLen - seqOffset + : 0U); // may also be full but it can be handled correctly anyway + copyPartialHeadsAsync( + warp, dst, dstHeadOffset, src, warpIdxInGrp, mha::min(nbHeadsAvail, cacheVTileSeqLen)); + } +#endif + +#if BEAM_WIDTH > 1 + // to make sure all threads has finished usage of cache indir and pages + unused(arrive(pWarpGrpBar)); + wait_parity(pWarpGrpBar, getAndFlip(warpGrpBarParityNext)); +#endif +#if USE_PAGED_KV_CACHE + constexpr uint32_t xIterSeqStride = cacheVTileSeqStride * nbVItersPerXIter; + if constexpr (xIterSeqStride <= tokensPerPage) { + uint32_t const nbXItersPerPage = exactDiv(tokensPerPage, xIterSeqStride); + assert(nbXItersPerPage <= nbXItersPerCtaTile); + if (xIter % nbXItersPerPage == nbXItersPerPage - 1 && vIter == nbVItersPerXIter - 1 && (idxBeam == beamWidth - 1 || isConvergedTile(seqIter))) { + auto const step = 1; // cacheVTileSeqLen * gemm1NbWarpGrps / tokensPerPage; + idxPageBeg += (idxPageBeg % nbPagesPerCtaTile == nbPagesPerCtaTile - 1 + ? nbPagesPerCtaTile * (nbSubSeqPerSeq - 1) + step + : step); + assert(beamWidth == 1 || cacheVTileSeqStride <= tokensPerPage && "todo: need to substrate from idxPageBeg for beam switching"); + loadPages(idxPageBeg); + } + } else { + assert(nbVItersPerXIter == 1); + if ((idxBeam == beamWidth - 1 || isConvergedTile(seqIter)) && vIter == nbVItersPerXIter - 1) { + auto const step = exactDiv(xIterSeqStride, tokensPerPage); + idxPageBeg += (idxPageBeg % nbPagesPerCtaTile + step >= nbPagesPerCtaTile + ? 
nbPagesPerCtaTile * (nbSubSeqPerSeq - 1) + step + : step); + loadPages(idxPageBeg); + } + } +#endif +#if BEAM_WIDTH > 1 + uint32_t seqIterNext, xIterNext, vIterNext, idxBeamNext; + mha::tie(seqIterNext, xIterNext, vIterNext, idxBeamNext) = nextStep(seqIter, xIter, vIter, idxBeam); + loadCacheIndir(seqIterNext, xIterNext, vIterNext, idxBeamNext); +#endif + }; + auto commitVTileLoad = [&](uint32_t idxVBar) { +#if GRP_LOAD_V + auto& bar = *getSmemVBar(idxVBar); + ldgsts::barArrive(bar, true); +#else + ldgsts::commitGroup(); +#endif + }; + auto syncVTileLoad = [&](uint32_t idxVBar, ParityOrNone parity, bool alreadyComplete) { +#if GRP_LOAD_V + if (alreadyComplete) { + return; + } + SharedMem::Barrier& bar = *getSmemVBar(idxVBar); + bar.wait_parity(parity); +#else + assert(!alreadyComplete); + ldgsts::waitGroup(); +#endif + }; + auto testVTileLoad = [&](uint32_t idxVBar, ParityOrNone parity) { return test_wait_parity(getSmemVBar(idxVBar), parity); }; + +#if BEAM_WIDTH > 1 + // synchronize first page/cacheIndir loading to shared memory + ldgsts::commitGroup(); + ldgsts::waitGroup<0>(); + unused(arrive(pWarpGrpBar)); + wait_parity(pWarpGrpBar, getAndFlip(warpGrpBarParityNext)); +#endif + + loadVTilePart(seqIterInit, 0, 0, 0); + commitVTileLoad(idxCurrSMemVBuf.next()); + idxCurrSMemVBuf++; + ParityOrNone vBarParity{}; + // @fixme: do prefetch for next iter tile if last part + + ThrdRegRowMax globalRowMax; + globalRowMax.fill(SAFE_INIT_ROW_MAX); + ThrdRegRowMax globalRowSum; + globalRowSum.fill(0); + // the accumulator + WarpAcc acc{}; + if (grpLoadV) { + unused(pWarpGrpBar->arrive()); + } + bool xBarProducedParityNext = false; + for (uint32_t seqIter = seqIterInit; seqIter < nbSeqIters; seqIter += seqStrideIters) { +#pragma unroll + for (uint32_t xIter = 0; xIter < nbXItersPerCtaTile; xIter++) { + uint32_t const idxXTile = xIter * nbXTilesPerXIter + warpGrpIdx / nbCacheVTilesPerXTile; + assert(idxXTile < ctaShapeInWarps.x); +#if SHORT_SEQ_OPT + if (ctaTile.x * seqIter + warpTile.x * idxXTile >= cacheSeqLen) { + break; + } +#endif + auto const& smemXTile = smem.x[warpIdx.y][idxXTile]; + auto& xBar = smem.xBarriers[warpIdx.y][idxXTile]; + ThrdRegRowMax xRowScales; + UniformRescaleMask xRowNeedRescaleMask; // expect storage in UR + bool skipXRowRescale; + for (uint32_t idxBeam = 0; idxBeam < (isConvergedTile(seqIter) ? 1U : beamWidth); idxBeam++) { +#pragma unroll + for (uint32_t vIter = 0; vIter < nbVItersPerXIter; vIter++) { + bool const vTestConsumed = test_wait_parity(pWarpGrpBar, warpGrpBarParityNext); + constexpr bool syncVTileEarly = (beamWidth > 1); // alternative is to use double buffer for cacheIndir and pages + bool vTestProduced = syncVTileEarly && testVTileLoad(idxCurrSMemVBuf, vBarParity); + auto isLastVBuf = [&] { return (idxCurrSMemVBuf == idxCurrSMemVBuf.nbBuffers - 1); }; + unused(isLastVBuf); + uint32_t const idxVTileInsideXIter = gemm1NbWarpGrps * vIter + warpGrpIdx; + uint32_t const idxVTile = idxVTileInsideXIter % nbCacheVTilesPerXTile; // inside XTile. + assert(idxVTile < nbCacheVTilesPerXTile); + uint32_t nNext, xIterNext, vIterNext, idxBeamNext; + mha::tie(nNext, xIterNext, vIterNext, idxBeamNext) = nextStep(seqIter, xIter, vIter, idxBeam); + if constexpr (syncVTileEarly) { + // sync early to make sure that cacheIndir and pages has been loaded. The last loaded V tile + // is also sync'ed at the same time. 
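+              // Editor note: the V-buffer barrier parity below flips only after the last buffer of the circular ring has been consumed, so one parity value covers a full pass over all nbVBuffers barriers.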
+ syncVTileLoad(idxCurrSMemVBuf, vBarParity, vTestProduced); + if (idxCurrSMemVBuf == idxCurrSMemVBuf.nbBuffers - 1) { + flip(vBarParity); + } + } + if (!vTestConsumed) { + wait_parity(pWarpGrpBar, warpGrpBarParityNext); + } + flip(warpGrpBarParityNext); + loadVTilePart(nNext, xIterNext, vIterNext, idxBeamNext); + commitVTileLoad(idxCurrSMemVBuf.next()); + // @fixme: do L2 cache prefetch for next iter tile + + if constexpr (!syncVTileEarly) { + vTestProduced = testVTileLoad(idxCurrSMemVBuf, vBarParity); + } + + if (idxBeam == 0 && vIter == 0) { + xBar.produced.wait_parity(xBarProducedParityNext); + auto const& smemRowMax = smem.warpRowMax[warpIdx.y][idxXTile]; + auto const& smemRowSum = smem.warpRowSum[warpIdx.y][idxXTile]; + // update globalRowMax + ThrdRegRowMax xTileRowMax; + ThrdRegRowMax xTileRowSum; + UniformRescaleMask needRescaleMask; +#pragma unroll + for (uint32_t i = 0; i < ThrdRegRowMax::size; i++) { + xTileRowMax[i] = smemRowMax[warp_size * i + laneId()]; + xTileRowSum[i] = smemRowSum[warp_size * i + laneId()]; + assert(__ballot_sync(~0U, laneId() == 0) == 1U); + assert(__ballot_sync(~0U, laneId() == 0) == 1U); + needRescaleMask[i] = __ballot_sync(~0U, xTileRowMax[i] != globalRowMax[i]); + } + bool const skipAllRescale = !any(needRescaleMask); + if (skipAllRescale) { + skipXRowRescale = true; +#if CTA_ROW_MAX_BACKWARD_METHOD == 3 + if (idxXTile == warpIdx.x) { + unused(smem.ctaRowMaxBwdBarriers[warpIdx.y][warpIdx.x].arrive()); + } +#endif + } else { + ThrdRegRowMax const globalRowMaxOld = globalRowMax; + UniformRescaleMask accRowNeedRescaleMask; +#pragma unroll + for (uint32_t i = 0; i < ThrdRegRowMax::size; i++) { + accRowNeedRescaleMask[i] = __ballot_sync(~0U, xTileRowMax[i] > globalRowMaxOld[i]); + xRowNeedRescaleMask[i] = (needRescaleMask[i] & ~accRowNeedRescaleMask[i]); + assert(xRowNeedRescaleMask[i] == __ballot_sync(~0U, xTileRowMax[i] < globalRowMaxOld[i])); + globalRowMax[i] = fmaxf(globalRowMaxOld[i], xTileRowMax[i]); + } + skipXRowRescale = !any(xRowNeedRescaleMask); + +#if CTA_ROW_MAX_BACKWARD_METHOD == 1 || CTA_ROW_MAX_BACKWARD_METHOD == 2 || CTA_ROW_MAX_BACKWARD_METHOD == 3 + // update smem.ctaRowMax. + if (idxXTile == warpIdx.x) { + smem.ctaRowMax[warpIdx.y][warpIdx.x].storeFromReg(warp, globalRowMax); +#if CTA_ROW_MAX_BACKWARD_METHOD == 3 + unused(smem.ctaRowMaxBwdBarriers[warpIdx.y][warpIdx.x].arrive()); +#endif + } +#elif CTA_ROW_MAX_BACKWARD_METHOD == 4 + // update smem.ctaRowMax. + // smem.ctaRowMax[warpIdx.y].storeFromReg(warp, globalRowMax); + smem.ctaRowMax[warpIdx.y].atomicMaxUpdate(warp, globalRowMax); +#endif + // update row sum and acc + if (!enableMicroFastPath || any(accRowNeedRescaleMask)) { + ThrdRegRowMax const accRowScales = expf(globalRowMaxOld - globalRowMax); + globalRowSum = globalRowSum * accRowScales; + // @fixme: when tmpAcc is used, this can be delayed. + rescaleAcc(warp, acc, accRowNeedRescaleMask, accRowScales); + } + if (!enableMicroFastPath || !skipXRowRescale) { + xRowScales = skipXRowRescale ? xRowScales : expf(xTileRowMax - globalRowMax); + xTileRowSum = skipXRowRescale ? 
xTileRowSum : xTileRowSum * xRowScales; + } + } + globalRowSum = globalRowSum + xTileRowSum; + } + if constexpr (!syncVTileEarly) { + syncVTileLoad(idxCurrSMemVBuf, vBarParity, vTestProduced); + if (idxCurrSMemVBuf == idxCurrSMemVBuf.nbBuffers - 1) { + flip(vBarParity); + } + } + auto const& smemVTile = getSmemVTile(idxCurrSMemVBuf); + // do computation from shared memory X and V tiles +#if BEAM_WIDTH == 1 + smemXVPartGemm(warp, acc, skipXRowRescale, xRowNeedRescaleMask, xRowScales, + smemXTile, idxVTile, smemVTile, grpLoadV ? warpIdxInGrp : 0); +#else + WarpAcc tmpAcc{}; + smemXVPartGemm(warp, tmpAcc, skipXRowRescale, xRowNeedRescaleMask, xRowScales, + smemXTile, idxVTile, smemVTile, grpLoadV ? warpIdxInGrp : 0); + pickAccRowsForBeamSearch( + warp, acc, tmpAcc, isConvergedTile(seqIter), idxBeam, [](float& d, float s) { d += s; }); +#endif + if (grpLoadV) { + unused(pWarpGrpBar->arrive()); + } + idxCurrSMemVBuf++; + } + } // idxBeam + xBar.consumed.arrive(); + } // xIter + flip(xBarProducedParityNext); + } // seqIter + + auto const fullRescaleMask = UniformRescaleMask::filled(~0U); + + constexpr bool needMergeGlobal = (gemm1NbWarpGrps > 1 && nbXTilesPerXIter > 1); + if constexpr (needMergeGlobal) { + assert(gemm1NbWarpGrps != 1); + __syncthreads(); + smem.warpRowMax[warpIdx.y][warpIdx.x].template storeFromReg(warp, globalRowMax); + smem.warpRowSum[warpIdx.y][warpIdx.x].template storeFromReg(warp, globalRowSum); + __syncthreads(); + for (uint32_t i = 1; i < nbXTilesPerXIter; i++) { // i = 0 is for self and we can skip + static_assert(nbXTilesPerXIter * nbWarpGrpsPerXTile == gemm1NbWarpGrps); + uint32_t const otherWarpGrpIdx = (warpGrpIdx + nbWarpGrpsPerXTile * i) % gemm1NbWarpGrps; + uint32_t const otherWarpIdx = warpIdxInGrp + gemm1WarpsPerGrp * otherWarpGrpIdx; +#ifndef NDEBUG + { + auto const v1 = smem.warpRowMax[warpIdx.y][otherWarpIdx].template loadToReg(warp); + auto const v2 = smem.warpRowMax[warpIdx.y][otherWarpIdx - warpIdxInGrp].template loadToReg(warp); +#pragma unroll + for (uint32_t k = 0; k < ThrdRegRowMax::size; k++) { + assert(__float_as_int(v1[k]) == __float_as_int(v2[k])); + } + } +#endif + auto const otherRowMax = smem.warpRowMax[warpIdx.y][otherWarpIdx].template loadToReg(warp); + auto const otherRowSum = smem.warpRowSum[warpIdx.y][otherWarpIdx].template loadToReg(warp); + auto const globalRowMaxNew = fmaxf(globalRowMax, otherRowMax); + auto const scaleForThis = expf(globalRowMax - globalRowMaxNew); + auto const scaleForOther = expf(otherRowMax - globalRowMaxNew); + rescaleAcc(warp, acc, fullRescaleMask, scaleForThis); + globalRowSum = globalRowSum * scaleForThis + otherRowSum * scaleForOther; + globalRowMax = globalRowMaxNew; + } + } + + float voScale = (isKVCacheQuantized ? kvCacheScale[0] : 1.F); + if (seqIterInit < nbSeqIters) { // otherwise rcpRowSum will be NAN. + // The attention sinks are moved to the multi-block reduction part if the multi-block is enabled. + if (!isMultiBlock && attentionSinks != nullptr) { + // Attention sinks are per head. 
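+      // addAttentionSinks only grows the softmax denominator: each sink adds exp(sink - rowMax) to the row sum while the accumulator (numerator) is left untouched.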
+ addAttentionSinks(globalRowSum, globalRowMax, attentionSinks + headGrpSize * idxHeadGrp); + } + ThrdRegRowMax const rcpRowSum = __frcp_rn(globalRowSum); +#if LOW_PREC_OUTPUT + voScale *= rcpOutScale[0]; +#endif + rescaleAcc(warp, acc, fullRescaleMask, rcpRowSum * ThrdRegRowMax::filled(voScale)); + } + GemmOutRegTile const outTile = toFp16(acc); + + auto mergeAndSaveOutTile = [&](GemmOutRegTile const& tile, bool reorder) { + if constexpr (gemm1NbWarpGrps == 1) { + // swizzle in shared memory and write output global memory + auto& outSwizzleBuffer = smem.x[warpIdx.y][warpIdx.x]; + __syncthreads(); + storeGemmOutTile(warp, outSwizzleBuffer, tile, reorder); + __syncwarp(); + return &outSwizzleBuffer; + } else { + __syncthreads(); + // store to shared memory, then merge groups. + using PostProcSMem = SharedMem::XSmemBuffer[ctaShapeInWarps.y][gemm1WarpsPerGrp][gemm1NbWarpGrps]; + static_assert(sizeof(PostProcSMem) <= smemSize); + SharedMem::XSmemBuffer(&postSMem)[gemm1NbWarpGrps] = reinterpret_cast(smem)[warpIdx.y][warpIdxInGrp]; + storeGemmOutTile(warp, postSMem[warpGrpIdx], tile, reorder); + __syncthreads(); + smemFp16ArraySum(warpGrpIdx, postSMem[0], postSMem); + __syncthreads(); + return &postSMem[0]; + } + }; + + // merge results from different warp groups + SharedMem::XSmemBuffer* smemOutTile = mergeAndSaveOutTile(outTile, inputElemSize == 2 && cacheElemSize == 1); + if (isMultiBlock) { + static_assert(ctaShapeInWarps.y == 1, "not implemented"); +#if SPEC_DEC + // Includes both kHeads and qTokens. + uint32_t const nbIndepHeadTokens = gridDim.y; + uint32_t const indepHeadTokenIdx = blockIdx.y; + uint32_t const nbSeq = nbIndepHeadTokens * batchSize; +#else + uint32_t const nbSeq = nbKHeads * batchSize; +#endif + uint32_t const nbSubSeq = nbSubSeqPerSeq * nbSeq; + MemSegmenter segmenter{scratch}; + +#if SPEC_DEC + uint32_t const idxSeq = nbIndepHeadTokens * idxReq + indepHeadTokenIdx; +#else + uint32_t const idxSeq = nbKHeads * idxReq + idxHeadGrp; +#endif + uint32_t const idxBufBase = nbSubSeqPerSeq * idxSeq; + uint32_t const idxBuf = idxBufBase + idxSubSeqInSeq; + // copy row max/sum + TinyPtr const rowMaxBuffers = segmenter.newSeg(nbSubSeq); + TinyPtr const rowSumBuffers = segmenter.newSeg(nbSubSeq); + if (warpGrpIdx == 0 && warpIdxInGrp == 0) { + rowMaxBuffers[idxBuf].storeFromReg(warp, globalRowMax); + rowSumBuffers[idxBuf].storeFromReg(warp, globalRowSum); + } + using ScratchBuf = Array2D; + TinyPtr> const scratchBuffers = segmenter.newSeg>(nbSubSeq); + // copy output to scratch + copyGrains( + warpGrpIdx, &scratchBuffers[idxBuf][warpIdxInGrp](0, 0), &(*smemOutTile)(0, 0)); + __syncthreads(); + constexpr uint32_t nbTileBuffers = 2; + + struct MultiBlockSMem { + bool isLastCta; + + struct MBBuf { + SMemWarpRowMax rowMax; + SMemWarpRowMax rowSum; + SharedMem::XSmemBuffer tiles[gemm1NbWarpGrps][gemm1WarpsPerGrp][nbTileBuffers]; + SMemWarpRowMax tileRowMax[gemm1NbWarpGrps][gemm1WarpsPerGrp][nbTileBuffers]; + SMemWarpRowMax tileRowSums[gemm1NbWarpGrps][gemm1WarpsPerGrp][nbTileBuffers]; + SMemWarpRowMax mergedRowSum[gemm1NbWarpGrps]; + }; + + MBBuf storage[ctaShapeInWarps.y]; + }; + + static_assert(sizeof(MultiBlockSMem) <= smemSize); + MultiBlockSMem& mbsmem = reinterpret_cast(smem); + // increase the semaphore by 1 + if (warpIdx.y == 0 && warpGrpIdx == 0 && warpIdxInGrp == 0 && laneId() == 0) { + uint32_t old; + uint32_t const lastOld = nbSubSeqPerSeq - 1; + asm volatile("atom.acq_rel.gpu.global.inc.u32 %0, [%1], %2;\n" + : "=r"(old) + : "l"(&semaphores[idxSeq]), "r"(lastOld)); + 
assert(old < nbSubSeqPerSeq); + mbsmem.isLastCta = (old == lastOld); + } + __syncthreads(); + + // merge if we are the last CTA. + bool const isLastCta = mbsmem.isLastCta; + if (isLastCta) { + MultiBlockSMem::MBBuf& mbbuf = mbsmem.storage[warpIdx.y]; + SMemWarpRowMax& smemRowMax = reinterpret_cast(smem); + // get row max. + if (warpIdx.x == 0) { + ThrdRegRowMax const mergedRowMax = mergeRowMax<8>(warp, rowMaxBuffers + idxBufBase, nbSubSeqPerSeq); + smemRowMax.storeFromReg(warp, mergedRowMax); + } + __syncthreads(); + ThrdRegRowMax const mergedRowMax = smemRowMax.loadToReg(warp); + + // rescale and accumulate + auto getTileBuf = [&](auto& buffers, uint32_t d) -> decltype(buffers[0][0][0])& { return buffers[warpGrpIdx][warpIdxInGrp][d]; }; + auto loadBufAsync = [&](uint32_t n) { + uint32_t const d = n / gemm1NbWarpGrps % nbTileBuffers; + SharedMem::XSmemBuffer& dstTile = getTileBuf(mbbuf.tiles, d); + SMemWarpRowMax& dstRowSum = getTileBuf(mbbuf.tileRowSums, d); + SMemWarpRowMax& dstRowMax = getTileBuf(mbbuf.tileRowMax, d); + copyGrains( + 0, &dstTile(0, 0), &scratchBuffers[idxBufBase + n][warpIdxInGrp](0, 0)); + constexpr uint32_t nbGrainsPerRowMaxBuf = exactDiv(sizeof(SMemWarpRowMax), grainBytes); + copyGrains(0, + reinterpret_cast(&dstRowSum), + reinterpret_cast(&rowSumBuffers[idxBufBase + n]), nbGrainsPerRowMaxBuf); + copyGrains(0, + reinterpret_cast(&dstRowMax), + reinterpret_cast(&rowMaxBuffers[idxBufBase + n]), nbGrainsPerRowMaxBuf); + }; + loadBufAsync(warpGrpIdx); + ldgsts::commitGroup(); + WarpAcc sumAcc{}; + ThrdRegRowMax partialMergedRowSum{}; + for (uint32_t n = warpGrpIdx; n < nbSubSeqPerSeq; n += gemm1NbWarpGrps) { + if (n + gemm1NbWarpGrps < nbSubSeqPerSeq) { + loadBufAsync(n + gemm1NbWarpGrps); + } + ldgsts::commitGroup(); + ldgsts::waitGroup<1>(); + uint32_t const d = n / gemm1NbWarpGrps % nbTileBuffers; + WarpAcc tile = toWarpAcc(loadGemmOutTile(warp, mbbuf.tiles[warpGrpIdx][warpIdxInGrp][d])); + ThrdRegRowMax const tileRowMax = getTileBuf(mbbuf.tileRowMax, d).loadToReg(warp); + ThrdRegRowMax const tileRowSum = getTileBuf(mbbuf.tileRowSums, d).loadToReg(warp); + ThrdRegRowMax const tileRowScales = expf(tileRowMax - mergedRowMax); + ThrdRegRowMax const scaledTileRowSum = tileRowSum * tileRowScales; + partialMergedRowSum = partialMergedRowSum + scaledTileRowSum; + assert(isfinite(partialMergedRowSum[0])); + rescaleAcc(warp, tile, fullRescaleMask, scaledTileRowSum); + sumAcc = sumAcc + tile; + } + + ThrdRegRowMax mergedRowSum{}; + if (gemm1NbWarpGrps == 1) { + mergedRowSum = partialMergedRowSum; + } else { + if (warpIdxInGrp == 0) { + mbbuf.mergedRowSum[warpGrpIdx].storeFromReg(warp, partialMergedRowSum); + } + __syncthreads(); +#ifndef NDEBUG +#pragma unroll + for (uint32_t k = 0; k < ThrdRegRowMax::size; k++) { + assert(__float_as_int(mbbuf.mergedRowSum[warpGrpIdx].loadToReg(warp)[k]) == __float_as_int(partialMergedRowSum[k])); + } + __syncthreads(); +#endif +#pragma unroll + for (uint32_t i = 0; i < gemm1NbWarpGrps; i++) { + mergedRowSum = mergedRowSum + mbbuf.mergedRowSum[i].loadToReg(warp); + assert(isfinite(mergedRowSum[0])); + } + } + if (attentionSinks != nullptr) { + // Attention sinks are per head. 
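+        // In the multi-block path the sink term is applied only here, after the per-CTA partial row sums have been merged, so it is counted exactly once per head.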
+ addAttentionSinks(mergedRowSum, mergedRowMax, attentionSinks + headGrpSize * idxHeadGrp); + } + __syncthreads(); + rescaleAcc(warp, sumAcc, fullRescaleMask, __frcp_rn(mergedRowSum)); + GemmOutRegTile const mergedOutTile = toFp16(sumAcc); + smemOutTile = mergeAndSaveOutTile(mergedOutTile, false); + } + } + if (warpGrpIdx == 0) { +#if SPEC_DEC + copyOutputToGlobalMem(warp, &output[reqSeqOffset * nbQHeads], nbQHeads, headGrpSize, + (idxHeadGrp * headGrpSize), nbValidHeadTokens, + uint2{warpTile.x * warpIdxInGrp, nbValidRows * warpIdx.y + idxHeadTokenInGrp}, *smemOutTile); +#else + copyOutputToGlobalMem(warp, &output[nbQHeads * beamWidth * idxReq], nbQHeads, idxHeadGrp, + uint2{warpTile.x * warpIdxInGrp, nbValidRows * warpIdx.y}, *smemOutTile); +#endif + } + } +} + +#if SPEC_DEC +#if __CUDA_ARCH__ == 900 && M_TILESIZE == 16 +constexpr uint32_t nbCtaPerSM = 2; +#else +constexpr uint32_t nbCtaPerSM = 1; +#endif +#else +#if __CUDA_ARCH__ == 900 +constexpr uint32_t nbCtaPerSM = 2; +#else +constexpr uint32_t nbCtaPerSM = 1; +#endif +#endif + +[[maybe_unused]] CUBIN_EXPORT __device__ constexpr XQAKernelType kernelType = XQAKernelType::kAMPERE_WARP_SPECIALIZED; + +#ifdef NDEBUG +CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha( +#if SPEC_DEC + uint32_t const qSeqLen, uint32_t const nbKHeads, uint32_t const headGrpSize, SeqLenDataType const* qCuSeqLens, +#else + uint32_t const nbKHeads, +#endif +#if SLIDING_WINDOW + uint32_t slidingWinSize, +#endif + float qScale, + OutputHead* __restrict__ const output, // [nbReq][beamWidth][nbQHeads] +#if LOW_PREC_OUTPUT + float const* rcpOutScale, +#endif + IOHead const* __restrict__ const q, // [nbReq][beamWidth][nbQHeads], +#if SPEC_DEC + MaskType const* __restrict__ mask, // [qSeqLen, divUp(qSeqLen, 32))] uint2 (each bit represents mask for one col + // position). +#endif + float const* attentionSinks, // [headGrpSize] + KVCacheList const cacheList, +#if BEAM_WIDTH > 1 + BeamSearchParams const beamSearchParams, +#endif + uint32_t const batchSize, + float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. Used only for + // int8/fp8 KV cache. 
+ uint32_t* __restrict__ semaphores = nullptr, void* __restrict__ scratch = nullptr) { +#if SPEC_DEC + kernel_mha_impl(qSeqLen, nbKHeads, headGrpSize, qCuSeqLens, +#else + kernel_mha_impl(nbKHeads, +#endif +#if SLIDING_WINDOW + slidingWinSize, +#endif + qScale, output, +#if LOW_PREC_OUTPUT + rcpOutScale, +#endif + q, +#if SPEC_DEC + mask, +#endif + attentionSinks, cacheList, +#if BEAM_WIDTH > 1 + beamSearchParams, +#endif + batchSize, kvCacheScale, semaphores, scratch); +} +#else +static constexpr auto kernel_mha = kernel_mha_impl; +#endif + +#ifndef GENERATE_CUBIN +uint32_t computeNbSubSeqPerSeqMHA(cudaDeviceProp const& prop, uint32_t batchSize, uint32_t nbKHeads, uint32_t maxSeqLen) { + if (!allowMultiBlockMode) { + return 1; + } + auto const env = std::getenv("XQA_NB_SUB_SEQ"); + if (env != nullptr) { + int32_t const val = std::stoi(env); + if (val > 0) { + return val; + } + } + return std::min( + std::max(1U, prop.multiProcessorCount / (batchSize * nbKHeads)), divUp(maxSeqLen, ctaTile.x)); +} + +void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads, +#if SLIDING_WINDOW + uint32_t slidingWinSize, +#endif + float qScale, OutputHead* output, +#if LOW_PREC_OUTPUT + float const* rcpOutScale, +#endif +#if USE_INPUT_KV + InputHead const* qkv, +#if ROPE_STYLE != 0 + Vec const* ropeCosSin, +#endif +#else + InputHead const* q, +#endif + float const* attentionSinks, // [headGrpSize] +#if USE_PAGED_KV_CACHE +#if PAGED_KV_CACHE_LAYOUT == 1 + GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, +#else + GMemCacheHead* pool, // global pool of pages +#endif + KVCachePageIndex const* + kvCachePageList, // device pointer. shape: KVCachePageIndex[batchSize][beamWidth][2][maxNbPagesPerSeq]. +#else + GMemKVCacheHead* kCacheData, + GMemKVCacheHead* vCacheData, + bool isBSNH, +#endif + uint32_t maxSeqLen, uint32_t const* seqLen, +#if BEAM_WIDTH > 1 + BeamSearchParams const& beamSearchParams, +#endif + uint32_t batchSize, + float const* __restrict__ kvCacheScale, // Device memory scalar. Same scale for K and V cache. Used only for + // int8/fp8 KV cache. +#if SPEC_DEC + SpecDecParams const& specDecParams, +#endif +#if SKIP_SOFTMAX_ATTN + float const skipSoftmaxThresholdScaleFactor, // for compatibility with mha_sm90.cu only +#if SKIP_SOFTMAX_ATTN_BLOCK_STATS + uint32_t* __restrict__ skippedBlockCount, // for compatibility with mha_sm90.cu only + uint32_t* __restrict__ totalBlockCount, // for compatibility with mha_sm90.cu only +#endif +#endif + uint32_t* semaphores, void* scratch, cudaStream_t stream) { +#if SPEC_DEC + auto const qSeqLen = specDecParams.qSeqLen; + auto const qCuSeqLens = specDecParams.qCuSeqLens; + auto const mask = specDecParams.mask; +#endif +#if USE_INPUT_KV + throw std::runtime_error("not implemented"); +#else + static uint32_t const hostSmemSize = [&]() { + uint32_t size; + checkCuda(cudaMemcpyFromSymbol(&size, smemSize, sizeof(smemSize))); + checkCuda(cudaFuncSetAttribute(kernel_mha, cudaFuncAttributeMaxDynamicSharedMemorySize, size)); + return size; + }(); + uint32_t const nbVHeads = nbKHeads; + unused(nbVHeads); + uint32_t const nbQHeads = nbKHeads * headGrpSize; + unused(nbQHeads); + + // const uint32_t nbSubSeqPerSeq = allowMultiBlockMode ? 
DBG_NB_CTAS_PER_SEQ : 1; + uint32_t const nbSubSeqPerSeq = computeNbSubSeqPerSeqMHA(prop, batchSize, nbKHeads, maxSeqLen); + // printf("DEBUG: launchMHA: batch=%u, nbKHeads=%u, maxSeq=%u, nbSubSeqPerSeq=%u\n", batchSize, nbKHeads, maxSeqLen, nbSubSeqPerSeq); + // gridDim.z == batchSize && gridDim.y == nbKHeads && gridDim.x == nbSubSeqPerSeq +#if SPEC_DEC + const uint32_t nbTokenBlocksPerGrp = divUp(qSeqLen * headGrpSize, rowsPerBlock); + dim3 const dimGrid{nbSubSeqPerSeq, nbKHeads * nbTokenBlocksPerGrp, batchSize}; +#else + dim3 const dimGrid{nbSubSeqPerSeq, nbKHeads, batchSize}; +#endif + dim3 const dimCta{warp_size * ctaShapeInWarps.x, ctaShapeInWarps.y, ctaShapeInWarps.z}; +#if defined(NDEBUG) || USE_PAGED_KV_CACHE + auto const launchCfg = makeLaunchConfig(dimGrid, dimCta, hostSmemSize, stream, ENABLE_PDL != 0); +#endif +#if USE_PAGED_KV_CACHE + uint32_t const maxNbPagesPerSeq = exactDiv(maxSeqLen, tokensPerPage); +#if PAGED_KV_CACHE_LAYOUT == 1 + KVCacheList const cacheList{kCacheVLLM, vCacheVLLM, kvCachePageList, seqLen, maxNbPagesPerSeq}; +#else + KVCacheList const cacheList{pool, kvCachePageList, seqLen, maxNbPagesPerSeq}; +#endif + cudaLaunchKernelEx(&launchCfg, kernel_mha, +#if SPEC_DEC + qSeqLen, nbKHeads, headGrpSize, qCuSeqLens, +#else + nbKHeads, +#endif +#if SLIDING_WINDOW + slidingWinSize, +#endif + qScale, output, +#if LOW_PREC_OUTPUT + rcpOutScale, +#endif + q, +#if SPEC_DEC + mask, +#endif + attentionSinks, cacheList, +#if BEAM_WIDTH > 1 + beamSearchParams, +#endif + batchSize, kvCacheScale, semaphores, scratch); +#else + KVCacheList const cacheList{kCacheData, vCacheData, seqLen, maxSeqLen, isBSNH, 1}; +#ifndef NDEBUG + kernel_mha<<>>( +#else + cudaLaunchKernelEx(&launchCfg, kernel_mha, +#endif +#if SPEC_DEC + qSeqLen, nbKHeads, headGrpSize, qCuSeqLens, +#else + nbKHeads, +#endif +#if SLIDING_WINDOW + slidingWinSize, +#endif + qScale, output, +#if LOW_PREC_OUTPUT + rcpOutScale, +#endif + q, +#if SPEC_DEC + mask, +#endif + attentionSinks, cacheList, +#if BEAM_WIDTH > 1 + beamSearchParams, +#endif + batchSize, kvCacheScale, semaphores, scratch); +#endif + checkCuda(cudaPeekAtLastError()); +#endif // USE_INPUT_KV +} +#endif +#endif + +__device__ __host__ inline size_t GetScratchSize(uint32_t nbSeq, uint32_t nbSubSeqPerSeq) { + uint32_t const nbSubSeq = nbSubSeqPerSeq * nbSeq; + size_t offset = 0; + + // 1. rowMax + offset = roundUp(offset, sizeof(SMemWarpRowMax)); + offset += sizeof(SMemWarpRowMax) * nbSubSeq; + + // 2. rowSum + offset = roundUp(offset, sizeof(SMemWarpRowMax)); + offset += sizeof(SMemWarpRowMax) * nbSubSeq; + + // 3. scratchBuffers + using ScratchBuf = Array2D; + using VecT = Vec; + + // size_t sem_size = roundUp(nbSeq * sizeof(uint32_t), 128); + // if (nbSubSeqPerSeq > 1) { + // printf("[MHA_IMPL] GetScratchSize: nbSeq=%u, nbSubSeqPerSeq=%u, sizeof(SMemWarpRowMax)=%zu, sizeof(VecT)=%zu, nbValidRows=%u, XS_cols=%u\n", + // nbSeq, nbSubSeqPerSeq, (size_t)sizeof(SMemWarpRowMax), (size_t)sizeof(VecT), (uint32_t)nbValidRows, (uint32_t)SharedMem::XSmemBuffer::cols); + // } + + offset = roundUp(offset, sizeof(VecT)); + offset += sizeof(VecT) * nbSubSeq; + + return offset; +} diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/mha_stdheaders.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/mha_stdheaders.cuh new file mode 100644 index 0000000000000..7b4fc3bdb6b39 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/mha_stdheaders.cuh @@ -0,0 +1,1105 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifndef GENERATE_CUBIN +#include +#include +#include +#include +#include +#include +#include +#endif + +#ifndef __CUDACC__ +#include +#endif + +#define HOST_DEVICE_FUNC __host__ __device__ +#define DEVICE_FUNC __device__ + +namespace mha { + +#ifndef GENERATE_CUBIN +template +using numeric_limits = std::numeric_limits; +using std::max; +using std::min; +#else + +using uint8_t = unsigned char; +using int8_t = signed char; +using uint16_t = unsigned short; +using uint32_t = unsigned int; +using int32_t = int; +using uint64_t = unsigned long long; +using uintptr_t = uint64_t; +static_assert(sizeof(uint8_t) == 1); +static_assert(sizeof(int8_t) == 1); +static_assert(sizeof(uint16_t) == 2); +static_assert(sizeof(uint32_t) == 4); +static_assert(sizeof(int32_t) == 4); +static_assert(sizeof(uint64_t) == 8); + +template +class numeric_limits; + +template <> +class numeric_limits { + public: + static constexpr int32_t max() noexcept { + return 0x7FFFFFFF; + } +}; + +template <> +class numeric_limits { + public: + static constexpr float lowest() noexcept { + return -3.40282347E+38F; + } +}; + +template +DEVICE_FUNC constexpr T const& max(T const& a, T const& b) { + return a > b ? a : b; +} + +template +DEVICE_FUNC constexpr T const& min(T const& a, T const& b) { + return a < b ? a : b; +} + +#endif + +#ifndef GENERATE_CUBIN +template +using conditional_t = std::conditional_t; + +template +using enable_if_t = typename std::enable_if::type; +#else + +// https://en.cppreference.com/w/cpp/types/conditional +template +struct conditional { + using type = T; +}; + +template +struct conditional { + using type = F; +}; + +template +using conditional_t = typename conditional::type; + +template +struct enable_if { +}; + +template +struct enable_if { + typedef T type; +}; + +template +using enable_if_t = typename enable_if::type; +#endif + +#ifndef GENERATE_CUBIN +using byte = std::byte; +#else +// https://en.cppreference.com/w/cpp/types/byte +enum class byte : unsigned char { +}; +#endif + +#ifndef GENERATE_CUBIN +using std::declval; +#else + +// https://en.cppreference.com/w/cpp/types/add_reference +namespace detail { +template +struct type_identity { + using type = T; +}; // or use std::type_identity (since C++20) + +template // Note that `cv void&` is a substitution failure +DEVICE_FUNC auto try_add_lvalue_reference(int) -> type_identity; +template // Handle T = cv void case +DEVICE_FUNC auto try_add_lvalue_reference(...) -> type_identity; + +template +DEVICE_FUNC auto try_add_rvalue_reference(int) -> type_identity; +template +DEVICE_FUNC auto try_add_rvalue_reference(...) 
-> type_identity; +} // namespace detail + +template +struct add_lvalue_reference : decltype(detail::try_add_lvalue_reference(0)) { +}; + +template +struct add_rvalue_reference : decltype(detail::try_add_rvalue_reference(0)) { +}; + +// https://en.cppreference.com/w/cpp/utility/declval +template +DEVICE_FUNC typename add_rvalue_reference::type declval() noexcept { + static_assert(false, "declval not allowed in an evaluated context"); +} +#endif + +#ifndef GENERATE_CUBIN +template +using array = std::array; +#else +// https://en.cppreference.com/w/cpp/container/array +template +struct array; +#endif + +#ifndef GENERATE_CUBIN +template +using is_same = std::is_same; +using std::is_same_v; +#else + +// https://en.cppreference.com/w/cpp/types/integral_constant +template +struct integral_constant { + static constexpr T value = v; + using value_type = T; + using type = integral_constant; // using injected-class-name + + DEVICE_FUNC constexpr operator value_type() const noexcept { + return value; + } + + DEVICE_FUNC constexpr value_type operator()() const noexcept { + return value; + } // since c++14 +}; + +using false_type = integral_constant; +using true_type = integral_constant; + +// https://en.cppreference.com/w/cpp/types/is_same +template +struct is_same : false_type { +}; + +template +struct is_same : true_type { +}; + +template +inline constexpr bool is_same_v = is_same::value; + +#endif + +#ifndef GENERATE_CUBIN + +using std::forward; +using std::is_empty; +using std::move; + +#else + +// /usr/include/c++/11/type_traits +template +struct is_empty : public integral_constant { +}; + +template +struct remove_reference { + typedef T type; +}; + +template +struct remove_reference { + typedef T type; +}; + +template +struct remove_reference { + typedef T type; +}; + +template +constexpr typename remove_reference::type&& move(T&& arg) { + return static_cast::type&&>(arg); +} + +template +constexpr T&& forward(typename remove_reference::type& param) { + return static_cast(param); +} + +#endif + +// https://gcc.gnu.org/onlinedocs/libstdc++/libstdc++-api-4.5/a01066_source.html +namespace libstdcpp { +// Adds a const reference to a non-reference type. +template +struct __add_c_ref { + typedef _Tp const& type; +}; + +template +struct __add_c_ref<_Tp&> { + typedef _Tp& type; +}; + +// Adds a reference to a non-reference type. 
+template +struct __add_ref { + typedef _Tp& type; +}; + +template +struct __add_ref<_Tp&> { + typedef _Tp& type; +}; + +template +struct _Head_base; + +template +struct _Head_base<_Idx, _Head, true> : public _Head { + DEVICE_FUNC _Head_base() + : _Head() { + } + + DEVICE_FUNC _Head_base(_Head const& __h) + : _Head(__h) { + } + + template + DEVICE_FUNC _Head_base(_UHead&& __h) + : _Head(forward<_UHead>(__h)) { + } + + DEVICE_FUNC _Head& _M_head() { + return *this; + } + + DEVICE_FUNC _Head const& _M_head() const { + return *this; + } + + DEVICE_FUNC void _M_swap_impl(_Head&) { /* no-op */ + } +}; + +template +struct _Head_base<_Idx, _Head, false> { + DEVICE_FUNC _Head_base() + : _M_head_impl() { + } + + DEVICE_FUNC _Head_base(_Head const& __h) + : _M_head_impl(__h) { + } + + template + DEVICE_FUNC _Head_base(_UHead&& __h) + : _M_head_impl(forward<_UHead>(__h)) { + } + + DEVICE_FUNC _Head& _M_head() { + return _M_head_impl; + } + + DEVICE_FUNC _Head const& _M_head() const { + return _M_head_impl; + } + + DEVICE_FUNC void _M_swap_impl(_Head& __h) { + using std::swap; + swap(__h, _M_head_impl); + } + + _Head _M_head_impl; +}; + +/** + * Contains the actual implementation of the @c tuple template, stored + * as a recursive inheritance hierarchy from the first element (most + * derived class) to the last (least derived class). The @c Idx + * parameter gives the 0-based index of the element stored at this + * point in the hierarchy; we use it to implement a constant-time + * get() operation. + */ +template +struct _Tuple_impl; + +/** + * Zero-element tuple implementation. This is the basis case for the + * inheritance recursion. + */ +template +struct _Tuple_impl<_Idx> { + protected: + DEVICE_FUNC void _M_swap_impl(_Tuple_impl&) { /* no-op */ + } +}; + +/** + * Recursive tuple implementation. Here we store the @c Head element + * and derive from a @c Tuple_impl containing the remaining elements + * (which contains the @c Tail). + */ +template +struct _Tuple_impl<_Idx, _Head, _Tail...> : public _Tuple_impl<_Idx + 1, _Tail...>, + private _Head_base<_Idx, _Head, is_empty<_Head>::value> { + typedef _Tuple_impl<_Idx + 1, _Tail...> _Inherited; + typedef _Head_base<_Idx, _Head, is_empty<_Head>::value> _Base; + + DEVICE_FUNC _Head& _M_head() { + return _Base::_M_head(); + } + + DEVICE_FUNC _Head const& _M_head() const { + return _Base::_M_head(); + } + + DEVICE_FUNC _Inherited& _M_tail() { + return *this; + } + + DEVICE_FUNC _Inherited const& _M_tail() const { + return *this; + } + + DEVICE_FUNC _Tuple_impl() + : _Inherited(), _Base() { + } + + explicit DEVICE_FUNC _Tuple_impl(_Head const& __head, _Tail const&... __tail) + : _Inherited(__tail...), _Base(__head) { + } + + template + explicit DEVICE_FUNC _Tuple_impl(_UHead&& __head, _UTail&&... 
__tail) + : _Inherited(forward<_UTail>(__tail)...), _Base(forward<_UHead>(__head)) { + } + + DEVICE_FUNC _Tuple_impl(_Tuple_impl const& __arg) + : _Inherited(__arg._M_tail()), _Base(__arg._M_head()) { + } + + DEVICE_FUNC _Tuple_impl(_Tuple_impl&& __arg) + : _Inherited(move(__arg._M_tail())), _Base(forward<_Head>(__arg._M_head())) { + } + + template + DEVICE_FUNC _Tuple_impl(_Tuple_impl<_Idx, _UElements...> const& __arg) + : _Inherited(__arg._M_tail()), _Base(__arg._M_head()) { + } + + template + DEVICE_FUNC _Tuple_impl(_Tuple_impl<_Idx, _UElements...>&& __arg) + : _Inherited(move(__arg._M_tail())), _Base(move(__arg._M_head())) { + } + + DEVICE_FUNC _Tuple_impl& operator=(_Tuple_impl const& __arg) { + _M_head() = __arg._M_head(); + _M_tail() = __arg._M_tail(); + return *this; + } + + DEVICE_FUNC _Tuple_impl& operator=(_Tuple_impl&& __arg) { + _M_head() = move(__arg._M_head()); + _M_tail() = move(__arg._M_tail()); + return *this; + } + + template + DEVICE_FUNC _Tuple_impl& operator=(_Tuple_impl<_Idx, _UElements...> const& __arg) { + _M_head() = __arg._M_head(); + _M_tail() = __arg._M_tail(); + return *this; + } + + template + DEVICE_FUNC _Tuple_impl& operator=(_Tuple_impl<_Idx, _UElements...>&& __arg) { + _M_head() = move(__arg._M_head()); + _M_tail() = move(__arg._M_tail()); + return *this; + } + + protected: + DEVICE_FUNC void _M_swap_impl(_Tuple_impl& __arg) { + _Base::_M_swap_impl(__arg._M_head()); + _Inherited::_M_swap_impl(__arg._M_tail()); + } +}; + +/// tuple +template +class tuple : public _Tuple_impl<0, _Elements...> { + typedef _Tuple_impl<0, _Elements...> _Inherited; + + public: + DEVICE_FUNC tuple() + : _Inherited() { + } + + explicit DEVICE_FUNC tuple(_Elements const&... __elements) + : _Inherited(__elements...) { + } + + template + explicit DEVICE_FUNC tuple(_UElements&&... __elements) + : _Inherited(forward<_UElements>(__elements)...) { + } + + DEVICE_FUNC tuple(tuple const& __arg) + : _Inherited(static_cast<_Inherited const&>(__arg)) { + } + + DEVICE_FUNC tuple(tuple&& __arg) + : _Inherited(static_cast<_Inherited&&>(__arg)) { + } + + template + DEVICE_FUNC tuple(tuple<_UElements...> const& __arg) + : _Inherited(static_cast<_Tuple_impl<0, _UElements...> const&>(__arg)) { + } + + template + DEVICE_FUNC tuple(tuple<_UElements...>&& __arg) + : _Inherited(static_cast<_Tuple_impl<0, _UElements...>&&>(__arg)) { + } + + // XXX http://gcc.gnu.org/ml/libstdc++/2008-02/msg00047.html + template + DEVICE_FUNC tuple(tuple<_UElements...>& __arg) + : _Inherited(static_cast<_Tuple_impl<0, _UElements...> const&>(__arg)) { + } + + DEVICE_FUNC tuple& operator=(tuple const& __arg) { + static_cast<_Inherited&>(*this) = __arg; + return *this; + } + + DEVICE_FUNC tuple& operator=(tuple&& __arg) { + static_cast<_Inherited&>(*this) = move(__arg); + return *this; + } + + template + DEVICE_FUNC tuple& operator=(tuple<_UElements...> const& __arg) { + static_cast<_Inherited&>(*this) = __arg; + return *this; + } + + template + DEVICE_FUNC tuple& operator=(tuple<_UElements...>&& __arg) { + static_cast<_Inherited&>(*this) = move(__arg); + return *this; + } + + void DEVICE_FUNC swap(tuple& __arg) { + _Inherited::_M_swap_impl(__arg); + } +}; + +template <> +class tuple<> { + public: + DEVICE_FUNC void swap(tuple&) { /* no-op */ + } +}; + +/// Gives the type of the ith element of a given tuple type. +template +struct tuple_element; + +/** + * Recursive case for tuple_element: strip off the first element in + * the tuple and retrieve the (i-1)th element of the remaining tuple. 
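 * For example, tuple_element<1, tuple<char, int, float>>::type resolves to int.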
+ */ +template +struct tuple_element<__i, tuple<_Head, _Tail...>> : tuple_element<__i - 1, tuple<_Tail...>> { +}; + +/** + * Basis case for tuple_element: The first element is the one we're seeking. + */ +template +struct tuple_element<0, tuple<_Head, _Tail...>> { + typedef _Head type; +}; + +/// Finds the size of a given tuple type. +template +struct tuple_size; + +/// class tuple_size +template +struct tuple_size> { + static const size_t value = sizeof...(_Elements); +}; + +template +const size_t tuple_size>::value; + +template +DEVICE_FUNC inline typename __add_ref<_Head>::type __get_helper(_Tuple_impl<__i, _Head, _Tail...>& __t) { + return __t._M_head(); +} + +template +DEVICE_FUNC inline typename __add_c_ref<_Head>::type __get_helper(_Tuple_impl<__i, _Head, _Tail...> const& __t) { + return __t._M_head(); +} + +// Return a reference (const reference) to the ith element of a tuple. +// Any const or non-const ref elements are returned with their original type. +template +DEVICE_FUNC inline typename __add_ref>::type>::type get( + tuple<_Elements...>& __t) { + return __get_helper<__i>(__t); +} + +template +DEVICE_FUNC inline typename __add_c_ref>::type>::type get( + tuple<_Elements...> const& __t) { + return __get_helper<__i>(__t); +} + +// This class helps construct the various comparison operations on tuples +template +struct __tuple_compare; + +template +struct __tuple_compare<0, __i, __j, _Tp, _Up> { + DEVICE_FUNC static bool __eq(_Tp const& __t, _Up const& __u) { + return (get<__i>(__t) == get<__i>(__u) && __tuple_compare<0, __i + 1, __j, _Tp, _Up>::__eq(__t, __u)); + } + + DEVICE_FUNC static bool __less(_Tp const& __t, _Up const& __u) { + return ((get<__i>(__t) < get<__i>(__u)) || !(get<__i>(__u) < get<__i>(__t)) && __tuple_compare<0, __i + 1, __j, _Tp, _Up>::__less(__t, __u)); + } +}; + +template +struct __tuple_compare<0, __i, __i, _Tp, _Up> { + static bool __eq(_Tp const&, _Up const&) { + return true; + } + + static bool __less(_Tp const&, _Up const&) { + return false; + } +}; + +template +DEVICE_FUNC bool operator==(tuple<_TElements...> const& __t, tuple<_UElements...> const& __u) { + typedef tuple<_TElements...> _Tp; + typedef tuple<_UElements...> _Up; + return (__tuple_compare::value - tuple_size<_Up>::value, 0, tuple_size<_Tp>::value, _Tp, _Up>::__eq( + __t, __u)); +} + +template +DEVICE_FUNC bool operator<(tuple<_TElements...> const& __t, tuple<_UElements...> const& __u) { + typedef tuple<_TElements...> _Tp; + typedef tuple<_UElements...> _Up; + return ( + __tuple_compare::value - tuple_size<_Up>::value, 0, tuple_size<_Tp>::value, _Tp, _Up>::__less( + __t, __u)); +} + +template +DEVICE_FUNC inline bool operator!=(tuple<_TElements...> const& __t, tuple<_UElements...> const& __u) { + return !(__t == __u); +} + +template +DEVICE_FUNC inline bool operator>(tuple<_TElements...> const& __t, tuple<_UElements...> const& __u) { + return __u < __t; +} + +template +DEVICE_FUNC inline bool operator<=(tuple<_TElements...> const& __t, tuple<_UElements...> const& __u) { + return !(__u < __t); +} + +template +DEVICE_FUNC inline bool operator>=(tuple<_TElements...> const& __t, tuple<_UElements...> const& __u) { + return !(__t < __u); +} + +template +struct __index_holder { +}; + +template +struct __index_holder_impl; + +template +struct __index_holder_impl<__i, __index_holder<_Indexes...>, _IdxHolder, _Elements...> { + typedef typename __index_holder_impl<__i + 1, __index_holder<_Indexes..., __i>, _Elements...>::type type; +}; + +template +struct __index_holder_impl<__i, 
__index_holder<_Indexes...>> { + typedef __index_holder<_Indexes...> type; +}; + +template +struct __make_index_holder : __index_holder_impl<0, __index_holder<>, _Elements...> { +}; + +template +DEVICE_FUNC inline tuple<_TElements..., _UElements...> __tuple_cat_helper(tuple<_TElements...> const& __t, + __index_holder<_TIdx...> const&, tuple<_UElements...> const& __u, __index_holder<_UIdx...> const&) { + return tuple<_TElements..., _UElements...>(get<_TIdx>(__t)..., get<_UIdx>(__u)...); +} + +template +DEVICE_FUNC inline tuple<_TElements..., _UElements...> __tuple_cat_helper(tuple<_TElements...>&& __t, + __index_holder<_TIdx...> const&, tuple<_UElements...> const& __u, __index_holder<_UIdx...> const&) { + return tuple<_TElements..., _UElements...>(move(get<_TIdx>(__t))..., get<_UIdx>(__u)...); +} + +template +DEVICE_FUNC inline tuple<_TElements..., _UElements...> __tuple_cat_helper(tuple<_TElements...> const& __t, + __index_holder<_TIdx...> const&, tuple<_UElements...>&& __u, __index_holder<_UIdx...> const&) { + return tuple<_TElements..., _UElements...>(get<_TIdx>(__t)..., move(get<_UIdx>(__u))...); +} + +template +DEVICE_FUNC inline tuple<_TElements..., _UElements...> __tuple_cat_helper(tuple<_TElements...>&& __t, + __index_holder<_TIdx...> const&, tuple<_UElements...>&& __u, __index_holder<_UIdx...> const&) { + return tuple<_TElements..., _UElements...>(move(get<_TIdx>(__t))..., move(get<_UIdx>(__u))...); +} + +template +DEVICE_FUNC inline tuple<_TElements..., _UElements...> tuple_cat( + tuple<_TElements...> const& __t, tuple<_UElements...> const& __u) { + return __tuple_cat_helper(__t, typename __make_index_holder<_TElements...>::type(), __u, + typename __make_index_holder<_UElements...>::type()); +} + +template +DEVICE_FUNC inline tuple<_TElements..., _UElements...> tuple_cat( + tuple<_TElements...>&& __t, tuple<_UElements...> const& __u) { + return __tuple_cat_helper(move(__t), typename __make_index_holder<_TElements...>::type(), __u, + typename __make_index_holder<_UElements...>::type()); +} + +template +DEVICE_FUNC inline tuple<_TElements..., _UElements...> tuple_cat( + tuple<_TElements...> const& __t, tuple<_UElements...>&& __u) { + return __tuple_cat_helper(__t, typename __make_index_holder<_TElements...>::type(), move(__u), + typename __make_index_holder<_UElements...>::type()); +} + +template +DEVICE_FUNC inline tuple<_TElements..., _UElements...> tuple_cat(tuple<_TElements...>&& __t, tuple<_UElements...>&& __u) { + return __tuple_cat_helper(move(__t), typename __make_index_holder<_TElements...>::type(), move(__u), + typename __make_index_holder<_UElements...>::type()); +} + +template +DEVICE_FUNC inline tuple<_Elements&...> tie(_Elements&... __args) { + return tuple<_Elements&...>(__args...); +} + +template +DEVICE_FUNC inline void swap(tuple<_Elements...>& __x, tuple<_Elements...>& __y) { + __x.swap(__y); +} + +// A class (and instance) which can be used in 'tie' when an element +// of a tuple is not required +struct _Swallow_assign { + template + DEVICE_FUNC _Swallow_assign& operator=(_Tp const&) { + return *this; + } +}; + +// TODO: Put this in some kind of shared file. 
+namespace { +_Swallow_assign ignore; +}; // anonymous namespace +} // namespace libstdcpp + +template +using tuple = libstdcpp::tuple; + +using libstdcpp::tie; +using libstdcpp::tuple_cat; + +#ifndef GENERATE_CUBIN +template +using remove_cv = std::remove_cv; +template +using remove_cv_t = typename std::remove_cv::type; +template +using decay = std::decay; +template +using decay_t = std::decay_t; +#else + +// https://en.cppreference.com/w/cpp/types/is_array +template +struct is_array : false_type { +}; + +template +struct is_array : true_type { +}; + +template +struct is_array : true_type { +}; + +// https://en.cppreference.com/w/cpp/types/remove_extent +template +struct remove_extent { + using type = T; +}; + +template +struct remove_extent { + using type = T; +}; + +template +struct remove_extent { + using type = T; +}; + +// https://en.cppreference.com/w/cpp/types/is_function +template +struct is_function : false_type { +}; + +// specialization for regular functions +template +struct is_function : true_type { +}; + +// specialization for variadic functions such as printf +template +struct is_function : true_type { +}; + +// specialization for function types that have cv-qualifiers +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +// specialization for function types that have ref-qualifiers +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +// specializations for noexcept versions of all the above (C++17 and later) +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +template 
+struct is_function : true_type { +}; + +template +struct is_function : true_type { +}; + +// https://en.cppreference.com/w/cpp/types/remove_cv +template +struct remove_cv { + typedef T type; +}; + +template +struct remove_cv { + typedef T type; +}; + +template +struct remove_cv { + typedef T type; +}; + +template +struct remove_cv { + typedef T type; +}; + +template +struct remove_const { + typedef T type; +}; + +template +struct remove_const { + typedef T type; +}; + +template +struct remove_volatile { + typedef T type; +}; + +template +struct remove_volatile { + typedef T type; +}; + +template +using remove_cv_t = typename remove_cv::type; + +// https://en.cppreference.com/w/cpp/types/add_pointer +namespace detail { +template +auto try_add_pointer(int) -> type_identity::type*>; // usual case + +template +auto try_add_pointer(...) -> type_identity; // unusual case (cannot form std::remove_reference::type*) +} // namespace detail + +template +struct add_pointer : decltype(detail::try_add_pointer(0)) { +}; + +// https://en.cppreference.com/w/cpp/types/decay +template +struct decay { + private: + typedef typename remove_reference::type U; + + public: + typedef typename conditional::value, typename add_pointer::type>::type, + typename conditional::value, typename add_pointer::type, + typename remove_cv::type>::type>::type type; +}; + +template +using decay_t = typename decay::type; +#endif + +#ifndef GENERATE_CUBIN +template +using is_void = std::is_void; +template +inline constexpr bool is_void_v = std::is_void_v; +#else +template +using is_void = is_same, void>; +template +inline constexpr bool is_void_v = is_void::value; +#endif + +#ifndef GENERATE_CUBIN +template +using pair = std::pair; +#else +template +struct pair { + T1 first; + T2 second; +}; +#endif + +} // namespace mha + +#if GENERATE_CUBIN +using uint8_t = mha::uint8_t; +using int8_t = mha::int8_t; +using uint16_t = mha::uint16_t; +using int32_t = mha::int32_t; +using uint32_t = mha::uint32_t; +using uint64_t = mha::uint64_t; +using uintptr_t = mha::uintptr_t; +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/mma.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/mma.cuh new file mode 100644 index 0000000000000..f947c519cfaa8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/mma.cuh @@ -0,0 +1,84 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include "cuda_hint.cuh" +#include "mha_stdheaders.cuh" +#ifndef __CUDACC__ +#include +#endif +#include +#include + +// for both a and b, outer-dim is gemm-K and inner-dim is gemm-M or gemm-N +// acc is used as both input and output. 
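// Register fragments follow the PTX mma.sync m16n8k16 layout: each thread of the warp holds
// 4 fp32 accumulator values (acc[2][2]) of the 16x8 D tile, 4 .b32 registers (a[2][2], two
// 16-bit values each) of the 16x16 A tile, and 2 .b32 registers (b[2][1]) of the 16x8 B tile.
// The e4m3 path uses m16n8k32, so each .b32 register packs four fp8 values instead of two
// 16-bit values.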
+template +__device__ inline void mma(float (&acc)[2][2], uint32_t const (&a)[2][2], uint32_t const (&b)[2][1]) { + static_assert(mha::is_same_v || mha::is_same_v || mha::is_same_v, + "not implemented"); + if constexpr (mha::is_same_v) { + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(acc[0][0]), "+f"(acc[0][1]), "+f"(acc[1][0]), "+f"(acc[1][1]) + : "r"(a[0][0]), "r"(a[0][1]), "r"(a[1][0]), "r"(a[1][1]), "r"(b[0][0]), "r"(b[1][0])); + } else if constexpr (mha::is_same_v) { + asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(acc[0][0]), "+f"(acc[0][1]), "+f"(acc[1][0]), "+f"(acc[1][1]) + : "r"(a[0][0]), "r"(a[0][1]), "r"(a[1][0]), "r"(a[1][1]), "r"(b[0][0]), "r"(b[1][0])); + } else if constexpr (mha::is_same_v) { + asm("mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5, %6, %7}, \n" + " {%8, %9}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(acc[0][0]), "+f"(acc[0][1]), "+f"(acc[1][0]), "+f"(acc[1][1]) + : "r"(a[0][0]), "r"(a[0][1]), "r"(a[1][0]), "r"(a[1][1]), "r"(b[0][0]), "r"(b[1][0])); + } else { + asm volatile("trap;"); + } +} + +__device__ inline void mmaF8_k16(float (&acc)[2][2], uint32_t const (&a)[2], uint32_t const b) { + asm("mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6}, \n" + " {%0, %1, %2, %3}; \n" + : "+f"(acc[0][0]), "+f"(acc[0][1]), "+f"(acc[1][0]), "+f"(acc[1][1]) + : "r"(a[0]), "r"(a[1]), "r"(b)); +} + +__device__ inline void mmaF8_k32_2inst(float (&acc)[2][2], uint32_t const (&a)[2][2], uint32_t const (&b)[2][1]) { + for (uint32_t i = 0; i < 2; i++) { + mmaF8_k16(acc, a[i], b[i][0]); + } +} + +struct mmaShape { + uint32_t m; + uint32_t n; + uint32_t k; +}; + +inline constexpr mmaShape qmmaShape = {16, 8, 32}; diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/platform.h b/onnxruntime/contrib_ops/cuda/bert/xqa/platform.h new file mode 100644 index 0000000000000..9d40bcc704c28 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/platform.h @@ -0,0 +1,25 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +// for IDE parser +#if defined(Q_CREATOR_RUN) || defined(__CLION_IDE__) || defined(__INTELLISENSE__) || defined(IN_KDEVELOP_PARSER) || defined(__JETBRAINS_IDE__) || defined(__CLANGD__) +#define IS_IN_IDE_PARSER 1 +#else +#define IS_IN_IDE_PARSER 0 +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/specDec.h b/onnxruntime/contrib_ops/cuda/bert/xqa/specDec.h new file mode 100644 index 0000000000000..d1942b5b46ba7 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/specDec.h @@ -0,0 +1,29 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "defines.h" +#if SPEC_DEC + +struct SpecDecParams { + uint32_t qSeqLen; + uint32_t const* qCuSeqLens; // [nbReq + 1] + MaskType const* mask; // [nbReq][qSeqLen][divUp(qSeqLen, 32)] or [qCuSeqLen[nbReq]][divUp(qSeqLen, 32)] +}; + +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/tma.h b/onnxruntime/contrib_ops/cuda/bert/xqa/tma.h new file mode 100644 index 0000000000000..2a4297281acc3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/tma.h @@ -0,0 +1,269 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include "cuda_hint.cuh" +#include "utils.h" +#ifndef GENERATE_CUBIN +#include +#include +#include +#endif +#include "barriers.cuh" + +enum class StateSpace { + kCONSTANT, + kPARAMETER, + kGENERIC +}; + +#ifdef GENERATE_CUBIN +#define CU_TENSOR_MAP_NUM_QWORDS 16 + +typedef struct CUtensorMap_st { +#if defined(__cplusplus) && (__cplusplus >= 201103L) + alignas(64) +#elif __STDC_VERSION__ >= 201112L + _Alignas(64) +#endif + uint64_t opaque[CU_TENSOR_MAP_NUM_QWORDS]; +} CUtensorMap; +#endif + +namespace tma { + +__device__ inline void loadLinearAsync(void* dst, void const* src, uint32_t nbBytes, CtaBarrier& bar) { + asm volatile("cp.async.bulk.shared::cta.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(src)), "r"(nbBytes), + "l"(__cvta_generic_to_shared(&bar)) + : "memory"); +} + +__device__ inline void prefetchLinear(void const* src, uint32_t nbBytes) { + asm volatile("cp.async.bulk.prefetch.L2.global [%0], %1;\n" ::"l"(reinterpret_cast(src)), "r"(nbBytes) + : "memory"); +} + +// dsr and &bar must be remote address generated by mapa and src must be local address +__device__ inline void sm2smCopyAsync(void* dst, void const* src, uint32_t nbBytes, CgaBarrier& bar) { + asm volatile("cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(src)), "r"(nbBytes), + "l"(__cvta_generic_to_shared(&bar)) + : "memory"); +} + +template +__device__ inline void loadAsync(void* dst, CUtensorMap const& tensorMap, DimsLE offset, CtaBarrier& bar) { + if constexpr (nbDims == 1) { + // nbDims==1 does not need tensormap and should just use cp.async.bulk + asm volatile( + "cp.async.bulk.tensor.1d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, {%2}], [%3];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), + "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else if constexpr (nbDims == 2) { + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, {%2, %3}], [%4];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), + "r"(offset[1]), "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else if constexpr (nbDims == 3) { + asm volatile( + "cp.async.bulk.tensor.3d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, {%2, %3, %4}], " + "[%5];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), + "r"(offset[1]), "r"(offset[2]), "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else if constexpr (nbDims == 4) { + asm volatile( + "cp.async.bulk.tensor.4d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, {%2, %3, %4, " + "%5}], [%6];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), + "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else if constexpr (nbDims == 5) { + asm volatile( + "cp.async.bulk.tensor.5d.shared::cta.global.mbarrier::complete_tx::bytes.tile [%0], [%1, {%2, %3, %4, %5, " + "%6}], [%7];\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), + "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else { + static_assert(nbDims >= 1 && nbDims <= 5); + } +} + 
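// ---------------------------------------------------------------------------
// Illustrative host-side sketch (reviewer addition, not part of this patch):
// the loadAsync/storeAsync wrappers in this header consume a CUtensorMap that
// is encoded on the host with the CUDA driver API. The helper below shows one
// way to do that for a row-major 2D fp16 tensor; the name, shape, box size and
// swizzle mode are hypothetical and only illustrate which argument goes where.
//
// #include <cuda.h>       // cuTensorMapEncodeTiled, CUtensorMap
// #include <cuda_fp16.h>  // half

inline CUtensorMap makeTensorMap2DSketch(half* devPtr, uint64_t rows, uint64_t cols) {
  CUtensorMap map{};
  cuuint64_t globalDim[2] = {cols, rows};              // innermost dimension first
  cuuint64_t globalStride[1] = {cols * sizeof(half)};  // bytes; rank-1 entries, multiple of 16
  cuuint32_t boxDim[2] = {64, 16};                     // SMEM tile copied per TMA operation
  cuuint32_t elemStride[2] = {1, 1};
  CUresult const rc = cuTensorMapEncodeTiled(&map, CU_TENSOR_MAP_DATA_TYPE_FLOAT16, /*tensorRank=*/2,
                                             devPtr, globalDim, globalStride, boxDim, elemStride,
                                             CU_TENSOR_MAP_INTERLEAVE_NONE, CU_TENSOR_MAP_SWIZZLE_128B,
                                             CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
                                             CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
  (void)rc;  // a real caller should check rc == CUDA_SUCCESS
  return map;
}

// The resulting map is typically passed to the kernel as a __grid_constant__
// parameter and consumed on the device via loadAsync<2>(smemDst, map, offset, bar).
// ---------------------------------------------------------------------------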
+template +__device__ inline void loadAsync( + void* dst, CUtensorMap const& tensorMap, DimsLE offset, CtaBarrier& bar, uint64_t cacheHint) { + if constexpr (nbDims == 1) { + // nbDims==1 does not need tensormap and should just use cp.async.bulk + asm volatile( + "cp.async.bulk.tensor.1d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_hint [%0], [%1, " + "{%2}], [%3], %4;\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), + "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint) + : "memory"); + } else if constexpr (nbDims == 2) { + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_hint [%0], [%1, " + "{%2, %3}], [%4], %5;\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), + "r"(offset[1]), "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint) + : "memory"); + } else if constexpr (nbDims == 3) { + asm volatile( + "cp.async.bulk.tensor.3d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_hint [%0], [%1, " + "{%2, %3, %4}], [%5], %6;\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), + "r"(offset[1]), "r"(offset[2]), "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint) + : "memory"); + } else if constexpr (nbDims == 4) { + asm volatile( + "cp.async.bulk.tensor.4d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_hint [%0], [%1, " + "{%2, %3, %4, %5}], [%6], %7;\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), + "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), "l"(__cvta_generic_to_shared(&bar)), "l"(cacheHint) + : "memory"); + } else if constexpr (nbDims == 5) { + asm volatile( + "cp.async.bulk.tensor.5d.shared::cta.global.mbarrier::complete_tx::bytes.tile.L2::cache_hint [%0], [%1, " + "{%2, %3, %4, %5, %6}], [%7], %8;\n" + : + : "l"(__cvta_generic_to_shared(dst)), "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), + "r"(offset[1]), "r"(offset[2]), "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(&bar)), + "l"(cacheHint) + : "memory"); + } else { + static_assert(nbDims >= 1 && nbDims <= 5); + } +} + +// shared::cta -> global +__device__ inline void store1DAsync(void* dst, void const* src, uint32_t nbBytes) { + asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;\n" + : + : "l"(reinterpret_cast(dst)), "l"(__cvta_generic_to_shared(src)), "r"(nbBytes)); +} + +template +__device__ inline void storeAsync(CUtensorMap const& tensorMap, DimsLE const& offset, void* src) { + if constexpr (nbDims == 1) { + // nbDims==1 does not need tensormap and should just use cp.async.bulk + asm volatile("cp.async.bulk.tensor.1d.global.shared::cta.bulk_group.tile [%0, {%1}], [%2];\n" + : + : "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), "l"(__cvta_generic_to_shared(src)) + : "memory"); + } else if constexpr (nbDims == 2) { + asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2}], [%3];\n" + : + : "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), "r"(offset[1]), + "l"(__cvta_generic_to_shared(src)) + : "memory"); + } else if constexpr (nbDims == 3) { + asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3}], [%4];\n" + : + : "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), + "l"(__cvta_generic_to_shared(src)) + : "memory"); + } else if constexpr (nbDims == 4) { + asm 
volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4}], [%5];\n" + : + : "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), + "r"(offset[3]), "l"(__cvta_generic_to_shared(src)) + : "memory"); + } else if constexpr (nbDims == 5) { + asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group.tile [%0, {%1, %2, %3, %4, %5}], [%6];\n" + : + : "l"(reinterpret_cast(&tensorMap)), "r"(offset[0]), "r"(offset[1]), "r"(offset[2]), + "r"(offset[3]), "r"(offset[4]), "l"(__cvta_generic_to_shared(src)) + : "memory"); + } else { + static_assert(nbDims >= 1 && nbDims <= 5); + } +} + +__device__ inline void setTensorMapGlbAddr(CUtensorMap& tensorMap, void* ptr) { + asm volatile("tensormap.replace.tile.global_address.global.b1024.b64 [%0], %1;\n" ::"l"(&tensorMap), "l"(ptr) + : "memory"); +} + +__device__ inline void commitGroup() { + asm volatile("cp.async.bulk.commit_group;\n" : : : "memory"); +} + +// wait until only targetNbInFlightGroups groups are still in-flight. +template +__device__ inline void waitGroup() { + asm volatile("cp.async.bulk.wait_group %0;\n" ::"n"(targetNbInFlightGroups) : "memory"); +} + +__device__ inline void prefetchTensorMap(CUtensorMap const& tensorMap, StateSpace loc = StateSpace::kGENERIC) { + assert(reinterpret_cast(&tensorMap) % alignof(CUtensorMap) == 0); + switch (loc) { + case StateSpace::kCONSTANT: + asm volatile("prefetch.const.tensormap [%0];\n" ::"l"(__cvta_generic_to_constant(&tensorMap)) : "memory"); + break; + case StateSpace::kPARAMETER: + asm volatile("prefetch.param.tensormap [%0];\n" ::"l"(__cvta_generic_to_grid_constant(&tensorMap)) : "memory"); + break; + case StateSpace::kGENERIC: + asm volatile("prefetch.tensormap [%0];\n" ::"l"(reinterpret_cast(&tensorMap)) : "memory"); + break; + default: + asm volatile("trap;\n"); + } +} + +template +__device__ inline void storeAsync(void* dst, T const& src, CgaBarrier& bar) { + constexpr uint32_t nbWords = exactDiv(sizeof(T), sizeof(uint32_t)); + Vec const& srcVec = reinterpret_cast const&>(src); + if constexpr (nbWords == 1) { + asm volatile("st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.u32 [%0], %1, [%2];\n" ::"l"( + __cvta_generic_to_shared(dst)), + "r"(srcVec[0]), "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else if constexpr (nbWords == 2) { + asm volatile("st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v2.u32 [%0], {%1, %2}, [%3];\n" ::"l"( + __cvta_generic_to_shared(dst)), + "r"(srcVec[0]), "r"(srcVec[1]), "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else if constexpr (nbWords == 4) { + asm volatile( + "st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.v4.u32 [%0], {%1, %2, %3, %4}, [%5];\n" ::"l"( + __cvta_generic_to_shared(dst)), + "r"(srcVec[0]), "r"(srcVec[1]), "r"(srcVec[2]), "r"(srcVec[3]), "l"(__cvta_generic_to_shared(&bar)) + : "memory"); + } else { + static_assert(nbWords == 1 || nbWords == 2 || nbWords == 4, "src size must be 4, 8 or 16 bytes"); + } +} + +} // namespace tma diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/utils.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/utils.cuh new file mode 100644 index 0000000000000..3335707511599 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/utils.cuh @@ -0,0 +1,907 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cuda_hint.cuh" + +#ifdef __CUDA_ARCH__ +#define XQA_UNROLL _Pragma("unroll") +#else +#define XQA_UNROLL +#endif +#include "utils.h" + +#ifndef GENERATE_CUBIN +#include +#else +#include "mha_stdheaders.cuh" +#endif + +#ifndef __CUDACC__ +#include +#endif +#include "barriers.cuh" +#include +#include +#include + +inline constexpr float log2e = 1.4426950408889634f; // std::log2(M_E) +// we used an optimization where exp(x-rowMax) is computed as: +/* bias = rowMax * log2e // shared for the whole row + exp(x-rowMax) = exp2f(x * log2e - bias) +*/ +// this reason, don't set safeInitRowMax with a huge absolute value. +// #define SAFE_INIT_ROW_MAX (-1e+5F) // moved to defines.h +inline constexpr int32_t kBAD_PAGE_INDEX = -1; +__constant__ constexpr float kE4M3_MAX = 448.F; + +#ifdef __CUDA_ARCH__ +#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 890 || __CUDA_ARCH__ == 1200 +constexpr uint32_t kMAX_SMEM_SIZE = (99u << 10); +#elif __CUDA_ARCH__ == 800 || __CUDA_ARCH__ == 870 +constexpr uint32_t kMAX_SMEM_SIZE = (163u << 10); +#elif __CUDA_ARCH__ == 900 +constexpr uint32_t kMAX_SMEM_SIZE = (227u << 10); +#else +constexpr uint32_t kMAX_SMEM_SIZE = (48u << 10); // Default for older architectures +#endif +#endif + +__device__ inline void assertWarpConverged() { + // assert(__activemask() == ~0U); +} + +#define DEFINE_VEC_BINARY_FUNC(func) \ + template \ + __device__ __host__ inline Vec(), mha::declval())), size> func( \ + Vec const& a, Vec const& b) { \ + Vec(), mha::declval())), size> result; \ + XQA_UNROLL for (uint32_t i = 0; i < size; i++) { \ + result[i] = func(a[i], b[i]); \ + } \ + return result; \ + } +DEFINE_VEC_BINARY_FUNC(max) +DEFINE_VEC_BINARY_FUNC(fmaxf) +DEFINE_VEC_BINARY_FUNC(__hadd2_rn) + +__device__ __host__ inline float2 addFloat2(float2 a, float2 b) { + return float2{a.x + b.x, a.y + b.y}; +} +DEFINE_VEC_BINARY_FUNC(addFloat2) +#undef DEFINE_VEC_BINARY_FUNC +#define DEFINE_VEC_BINARY_OP(op) \ + template \ + __device__ __host__ inline Vec() op mha::declval()), size> operator op( \ + Vec const& a, Vec const& b) { \ + Vec() op mha::declval()), size> result; \ + XQA_UNROLL for (uint32_t i = 0; i < size; i++) { \ + result[i] = a[i] op b[i]; \ + } \ + return result; \ + } \ + template \ + __device__ __host__ inline Vec() op mha::declval()), size> operator op( \ + Vec const& a, Scalar const& b) { \ + Vec() op mha::declval()), size> result; \ + XQA_UNROLL for (uint32_t i = 0; i < size; i++) { \ + result[i] = a[i] op b; \ + } \ + return result; \ + } \ + template \ + __device__ __host__ inline Vec() op mha::declval()), size> operator op( \ + Scalar const& a, Vec const& b) { \ + Vec() op mha::declval()), size> result; \ + XQA_UNROLL for (uint32_t i = 0; i < size; i++) { \ + result[i] = a op b[i]; \ + } \ + return result; \ + } +// Don't use DEFINE_VEC_BINARY_FUNC(operator+), as operator+(float, float) is undefined, +// and float will be converted into half to perform the 
operation, which results in much +// lower precision. It's a defect of C++ that operator+(1.F, 2.F) does not work! +DEFINE_VEC_BINARY_OP(+) +DEFINE_VEC_BINARY_OP(-) +DEFINE_VEC_BINARY_OP(*) +DEFINE_VEC_BINARY_OP(/) +DEFINE_VEC_BINARY_OP(==) +DEFINE_VEC_BINARY_OP(!=) +DEFINE_VEC_BINARY_OP(>) +DEFINE_VEC_BINARY_OP(<) +DEFINE_VEC_BINARY_OP(>=) +DEFINE_VEC_BINARY_OP(<=) +#undef DEFINE_VEC_BINARY_OP + +template +HOST_DEVICE_FUNC inline bool all(Vec const& src) { + bool ret = true; + XQA_UNROLL + for (uint32_t i = 0; i < size; i++) { + ret = ret && src[i]; + } + return ret; +} + +template +HOST_DEVICE_FUNC inline bool any(Vec const& src) { + bool ret = false; + XQA_UNROLL + for (uint32_t i = 0; i < size; i++) { + ret = ret || src[i]; + } + return ret; +} + +#define DEFINE_VEC_UNARY_OP(op) \ + template \ + __device__ __host__ inline Vec())), size> op(Vec const& a) { \ + Vec())), size> result; \ + XQA_UNROLL for (uint32_t i = 0; i < size; i++) { \ + result[i] = op(a[i]); \ + } \ + return result; \ + } +DEFINE_VEC_UNARY_OP(expf) +DEFINE_VEC_UNARY_OP(exp2f) +DEFINE_VEC_UNARY_OP(__float2bfloat162_rn) +DEFINE_VEC_UNARY_OP(__float2half2_rn) +DEFINE_VEC_UNARY_OP(__float22half2_rn) +DEFINE_VEC_UNARY_OP(__bfloat1622float2) +DEFINE_VEC_UNARY_OP(__half22float2) +DEFINE_VEC_UNARY_OP(__frcp_rn) +#undef DEFINE_VEC_UNARY_OP + +template +__device__ __host__ inline Vec convert(Vec const& src) { + if constexpr (mha::is_same_v, mha::decay_t>) { + return src; + } + Vec dst; + if constexpr (mha::is_same_v && mha::is_same_v) { + for (uint32_t i = 0; i < size - 1; i += 2) { + reinterpret_cast(dst[i]) = __half22float2(reinterpret_cast(src[i])); + } + if constexpr (size % 2 != 0) { + dst[size - 1] = Dst{src[size - 1]}; + } + } else if constexpr (mha::is_same_v && mha::is_same_v) { + for (uint32_t i = 0; i < size - 1; i += 2) { + reinterpret_cast(dst[i]) = __float22half2_rn(reinterpret_cast(src[i])); + } + if constexpr (size % 2 != 0) { + dst[size - 1] = Dst{src[size - 1]}; + } + } + if constexpr (mha::is_same_v && mha::is_same_v) { + for (uint32_t i = 0; i < size - 1; i += 2) { + reinterpret_cast(dst[i]) = __bfloat1622float2(reinterpret_cast<__nv_bfloat162 const&>(src[i])); + } + if constexpr (size % 2 != 0) { + dst[size - 1] = Dst{src[size - 1]}; + } + } else if constexpr (mha::is_same_v && mha::is_same_v) { + for (uint32_t i = 0; i < size - 1; i += 2) { + reinterpret_cast<__nv_bfloat162&>(dst[i]) = __float22bfloat162_rn(reinterpret_cast(src[i])); + } + if constexpr (size % 2 != 0) { + dst[size - 1] = Dst{src[size - 1]}; + } + } else if constexpr (mha::is_same_v && mha::is_same_v) { + for (uint32_t i = 0; i < size - 1; i += 2) { + reinterpret_cast(dst[i]) = float2(reinterpret_cast<__nv_fp8x2_e4m3 const&>(src[i])); + } + if constexpr (size % 2 != 0) { + dst[size - 1] = Dst{src[size - 1]}; + } + } else if constexpr (mha::is_same_v && mha::is_same_v) { + for (uint32_t i = 0; i < size - 1; i += 2) { + reinterpret_cast<__nv_fp8x2_e4m3&>(dst[i]) = __nv_fp8x2_e4m3{float2{src[i], src[i + 1]}}; + } + if constexpr (size % 2 != 0) { + dst[size - 1] = Dst{src[size - 1]}; + } + } else if constexpr (mha::is_same_v && mha::is_same_v) { + for (uint32_t i = 0; i < size - 1; i += 2) { + reinterpret_cast(dst[i]) = half2(reinterpret_cast<__nv_fp8x2_e4m3 const&>(src[i])); + } + if constexpr (size % 2 != 0) { + dst[size - 1] = Dst{src[size - 1]}; + } + } else if constexpr (mha::is_same_v && mha::is_same_v) { + for (uint32_t i = 0; i < size - 1; i += 2) { + reinterpret_cast<__nv_fp8x2_e4m3&>(dst[i]) = 
__nv_fp8x2_e4m3{reinterpret_cast(src[i])}; + } + if constexpr (size % 2 != 0) { + dst[size - 1] = Dst{src[size - 1]}; + } + } + // else if constexpr (mha::is_same_v && mha::is_same_v) { + // static_assert("not implemented"); + // } + else if constexpr (mha::is_same_v && mha::is_same_v) { + for (uint32_t i = 0; i < size - 1; i += 2) { + reinterpret_cast<__nv_fp8x2_e4m3&>(dst[i]) = __nv_fp8x2_e4m3{reinterpret_cast<__nv_bfloat162 const&>(src[i])}; + } + if constexpr (size % 2 != 0) { + dst[size - 1] = Dst{src[size - 1]}; + } + } else { + for (uint32_t i = 0; i < size; i++) { + dst[i] = Dst{src[i]}; + } + } + return dst; +} + +__device__ inline uint32_t laneId() { + uint32_t id; + asm("mov.u32 %0, %%laneid;\n" : "=r"(id)); + return id; +} + +__device__ inline uint32_t dynamicSmemSize() { + uint32_t size; + asm("mov.u32 %0, %%dynamic_smem_size;\n" : "=r"(size)); + return size; +} + +__device__ inline void trap() { + asm volatile("trap;\n"); +} + +inline constexpr uint32_t warp_size = 32; + +struct Warp { +}; + +__device__ inline Warp this_warp() { + return {}; +} + +// @fixme: check asm code to make sure UR is used and SHFL is not generated. +template +__device__ inline T makeWarpUniform(Warp const& warp, T const& val) { + T const val0 = __shfl_sync(~0U, val, 0); + assert(val == val0); + return val0; +} + +__device__ inline uint3 getWarpIdx(uint3 ctaShapeInWarps, Warp const& warp = this_warp()) { + assert(ctaShapeInWarps.x % 128 == 0); + return uint3{ctaShapeInWarps.x == 1 ? 0 : makeWarpUniform(warp, threadIdx.x / warp_size), + ctaShapeInWarps.y == 1 ? 0 : makeWarpUniform(warp, threadIdx.y), + ctaShapeInWarps.z == 1 ? 0 : makeWarpUniform(warp, threadIdx.z)}; +} + +constexpr uint32_t cacheLineSize = 128; + +template +__device__ __host__ inline void assertIsPowerOf2() { + static_assert((x & (x - 1)) == 0); +} + +template +__device__ inline bool hasBankConflict(T* p) { + static_assert(sizeof(T) % 4 == 0 && sizeof(T) <= 16 && alignof(T) == sizeof(T)); + constexpr uint32_t grpSize = 128 / sizeof(T); + const uint32_t grpMask = static_cast(((1ULL << grpSize) - 1ULL) << (laneId() / grpSize * grpSize)); + uint32_t const x = reinterpret_cast(p) / sizeof(T) % grpSize; + auto const match = __match_any_sync(grpMask, x); + bool const conflict = __popc(match) > 1; + if (grpSize <= 8 && conflict) { + char str[grpSize * 2 + 1] = {}; + for (uint32_t i = 0; i < grpSize; i++) { + str[i * 2] = __shfl_sync(grpMask, x, i, grpSize) + '0'; + str[i * 2 + 1] = ' '; + } + if (laneId() % grpSize == 0) { + printf("bank conflict (%u): %s\n", match, str); + } + } + return conflict; +} + +__device__ inline float atomicMax(float* addr, float value) { + float old; + old = (value >= 0) ? 
__int_as_float(atomicMax(reinterpret_cast(addr), __float_as_int(value))) + : __uint_as_float(atomicMin(reinterpret_cast(addr), __float_as_uint(value))); + return old; +} + +__device__ inline bool isInInt32Range(uint32_t x) { + return x <= static_cast(mha::numeric_limits::max()); +} + +// struct of arrays instead of array of structs for compact storage +template +struct CompactRangeList { + mha::array pointerList; + mha::array sizeList; + + struct Range { + Pointer const& data; + uint32_t const& size; + }; + + __device__ inline Range operator[](uint32_t i) const { + return Range{pointerList[i], sizeList[i]}; + } +}; + +// alignedForSwizzle is for case when you need to mix TMA+LDS/LDSM, or LDGSTS/STS/STSM+GMMA +template +struct alignas(mha::min(maxArrayAlign(rows_* cols_), cacheLineSize)) Array2D { + using Elem = T; + static constexpr uint32_t rows = rows_; + static constexpr uint32_t cols = cols_; + static constexpr uint32_t size = rows * cols; + static constexpr uint32_t rowBytes = sizeof(T) * cols; + + template + __device__ inline T const& at(uint32_t r, uint32_t c) const { + assert(r < rows && c < cols); + // two different swizzle styles +#if 1 + uint32_t const c_swizzled = [&] { + if constexpr (swizzle) { + static_assert(rowBytes % cacheLineSize == 0 || cacheLineSize % rowBytes == 0); + static constexpr uint32_t rowsPerSliding = exactDiv(cacheLineSize, rowBytes % cacheLineSize == 0 ? cacheLineSize : rowBytes % cacheLineSize); + constexpr uint32_t swizzleRowsRepeat = exactDiv(cacheLineSize, sizeof(Elem)); + auto const runtimeBaseOffset = static_cast(__cvta_generic_to_shared(this->data)) / rowBytes % rows; + uint32_t const baseOffset = alignedForSwizzle + ? 0 + : runtimeBaseOffset; // To match TMA when array is not aligned to pattern boundary + uint32_t const xorMask = alignedForSwizzle + ? BoundedVal{r} + .template divBy() + .template mod() + .get() + : (r + baseOffset) / rowsPerSliding % exactDiv(swizzleRowsRepeat, rowsPerSliding); + return c ^ xorMask; + } + return c; + }(); +#else + uint32_t const c_swizzled = swizzle ? 
(c + r / rowsPerSliding) % cols : c; +#endif + T const& ret = (&data[0][0])[r * cols + c_swizzled]; + assert(&data[r][c_swizzled] == &ret); + return ret; + } + + template + __device__ inline T& at(uint32_t r, uint32_t c) { + return const_cast(static_cast(this)->at(r, c)); + } + + __device__ inline T const& operator()(uint32_t r, uint32_t c) const { + return at(r, c); + } + + __device__ inline T& operator()(uint32_t r, uint32_t c) { + return at(r, c); + } + + template + __device__ inline Array2D& as() { + return reinterpret_cast&>(*this); + } + + __device__ inline void fill(T val) { + XQA_UNROLL + for (uint32_t i = 0; i < rows * cols; i++) { + (&data[0][0])[i] = val; + } + } + + __device__ inline static Array2D filled(T val) { + Array2D ret; + ret.fill(val); + return ret; + } + + T data[rows][cols]; +}; + +#define DEFINE_ARRAY2D_BINARY_OP(op) \ + template \ + __device__ __host__ inline Array2D() op mha::declval()), rows, cols> operator op( \ + Array2D const& a, Array2D const& b) { \ + Array2D() op mha::declval()), rows, cols> result; \ + XQA_UNROLL for (uint32_t i = 0; i < rows; i++) { \ + for (uint32_t j = 0; j < cols; j++) { \ + result(i, j) = a(i, j) op b(i, j); \ + } \ + } \ + return result; \ + } \ + template \ + __device__ __host__ inline Array2D() op mha::declval()), rows, cols> operator op( \ + Array2D const& a, Scalar const& b) { \ + Array2D() op mha::declval()), rows, cols> result; \ + XQA_UNROLL for (uint32_t i = 0; i < rows; i++) { \ + for (uint32_t j = 0; j < cols; j++) { \ + result(i, j) = a(i, j) op b; \ + } \ + } \ + return result; \ + } \ + template \ + __device__ __host__ inline Array2D() op mha::declval()), rows, cols> operator op( \ + Scalar const& a, Array2D const& b) { \ + Array2D() op mha::declval()), rows, cols> result; \ + XQA_UNROLL for (uint32_t i = 0; i < rows; i++) { \ + for (uint32_t j = 0; j < cols; j++) { \ + result(i, j) = a op b(i, j); \ + } \ + } \ + return result; \ + } +// Don't use DEFINE_VEC_BINARY_FUNC(operator+), as operator+(float, float) is undefined, +// and float will be converted into half to perform the operation, which results in much +// lower precision. It's a defect of C++ that operator+(1.F, 2.F) does not work! 
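// (Built-in operators on fundamental types are not addressable functions, so they cannot be
// passed to a *_FUNC-style macro; the *_OP macros expand the operator token directly instead.)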
+DEFINE_ARRAY2D_BINARY_OP(+) +DEFINE_ARRAY2D_BINARY_OP(-) +DEFINE_ARRAY2D_BINARY_OP(*) + +using LdGrain = Vec; +constexpr uint32_t grainBytes = sizeof(LdGrain); + +// wrapper for PTX ldmatrix +template +__device__ inline Vec ldmatrix(LdGrain const* row) { + assertWarpConverged(); + uint32_t a, b, c, d; + if constexpr (nbMat == 4) { + if (transpose) { + asm("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(a), "=r"(b), "=r"(c), "=r"(d) + : "l"(__cvta_generic_to_shared(row)) + : "memory"); + } else { + asm("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(a), "=r"(b), "=r"(c), "=r"(d) + : "l"(__cvta_generic_to_shared(row)) + : "memory"); + } +#if 0 + auto checkMat = [&](uint32_t val, uint32_t idxMat) -> Vec const& { + auto const v = (Vec const&)val; + uint32_t const lane = laneId(); + auto getRow = [&](uint32_t r) { + assert(r<8); + auto const ret = __shfl_sync(~0U, reinterpret_cast(row), 8*idxMat+r); + return *reinterpret_cast const*>(ret); + }; + auto checkEq = [](uint16_t x, uint16_t y) { + if (!(x==y)) { + printf("x=%u, y= %u\n", (unsigned)x, (unsigned)y); + } + }; + if (transpose) { + checkEq(v[0], getRow(lane % 4 * 2)[lane / 4]); + checkEq(v[1], getRow(lane % 4 * 2 + 1)[lane / 4]); + } + else { + checkEq(v[0], getRow(lane / 4)[lane % 4 * 2]); + checkEq(v[1], getRow(lane / 4)[lane % 4 * 2 + 1]); + } + }; + checkMat(a, 0); + checkMat(b, 1); + checkMat(c, 2); + checkMat(d, 3); +#endif + return Vec{a, b, c, d}; + } else if constexpr (nbMat == 2) { + if (transpose) { + asm("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n" + : "=r"(a), "=r"(b) + : "l"(__cvta_generic_to_shared(row)) + : "memory"); + } else { + asm("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n" + : "=r"(a), "=r"(b) + : "l"(__cvta_generic_to_shared(row)) + : "memory"); + } + return Vec{a, b}; + } else if constexpr (nbMat == 1) { + if (transpose) { + asm("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 %0, [%1];\n" + : "=r"(a) + : "l"(__cvta_generic_to_shared(row)) + : "memory"); + } else { + asm("ldmatrix.sync.aligned.m8n8.x2.shared.b16 %0, [%1];\n" + : "=r"(a) + : "l"(__cvta_generic_to_shared(row)) + : "memory"); + } + return Vec{a}; + } else { + static_assert(nbMat == 1 || nbMat == 2 || nbMat == 4); + } +} + +template +__device__ inline Vec ldmatrix_4x(Warp const& warp, LdGrain const* row) { + return ldmatrix(row); +} + +template +__device__ inline Vec ldmatrix_16x16_trans(LdGrain const* row) { + uint32_t a, b, c, d; + if constexpr (nbMat == 1) { + asm("ldmatrix.sync.aligned.m16n16.x1.trans.shared::cta.b8 {%0, %1}, [%2];\n" + : "=r"(a), "=r"(b) + : "l"(__cvta_generic_to_shared(row)) + : "memory"); + return Vec{a, b}; + } else if constexpr (nbMat == 2) { + asm("ldmatrix.sync.aligned.m16n16.x2.trans.shared::cta.b8 {%0, %1, %2, %3}, [%4];\n" + : "=r"(a), "=r"(b), "=r"(c), "=r"(d) + : "l"(__cvta_generic_to_shared(row)) + : "memory"); + return Vec{a, b, c, d}; + } else { + static_assert(nbMat == 1 || nbMat == 2); + } +} + +template +__device__ inline void stmatrix(LdGrain* row, Vec const& data) { +#if __CUDA_ARCH__ >= 900 + assertWarpConverged(); + if constexpr (nbMat == 4) { + if constexpr (transpose) { + asm("stmatrix.sync.aligned.m8n8.x4.trans.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"l"( + __cvta_generic_to_shared(row)), + "r"(data[0]), "r"(data[1]), "r"(data[2]), "r"(data[3]) + : "memory"); + } else { + asm("stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};\n" ::"l"( + __cvta_generic_to_shared(row)), + "r"(data[0]), 
"r"(data[1]), "r"(data[2]), "r"(data[3]) + : "memory"); + } + } else if constexpr (nbMat == 2) { + if constexpr (transpose) { + asm("stmatrix.sync.aligned.m8n8.x2.trans.shared.b16 [%0], {%1, %2};\n" ::"l"(__cvta_generic_to_shared(row)), + "r"(data[0]), "r"(data[1]) + : "memory"); + } else { + asm("stmatrix.sync.aligned.m8n8.x2.shared.b16 [%0], {%1, %2};\n" ::"l"(__cvta_generic_to_shared(row)), + "r"(data[0]), "r"(data[1]) + : "memory"); + } + } else if constexpr (nbMat == 1) { + if constexpr (transpose) { + asm("stmatrix.sync.aligned.m8n8.x1.trans.shared.b16 [%0], {%1};\n" ::"l"(__cvta_generic_to_shared(row)), + "r"(data[0]) + : "memory"); + } else { + asm("stmatrix.sync.aligned.m8n8.x1.shared.b16 [%0], {%1};\n" ::"l"(__cvta_generic_to_shared(row)), + "r"(data[0]) + : "memory"); + } + } else { + static_assert(nbMat == 1 || nbMat == 2 || nbMat == 4); + } +#else + trap(); +#endif +} + +template +__device__ inline void stmatrix_4x(Warp const& warp, LdGrain* row, Vec const& data) { + stmatrix(row, data); +} + +struct None { +}; + +template +using RealTypeOrNone = mha::conditional_t; + +template +struct MBarrierPair { + MBarrier produced; + MBarrier consumed; + + __device__ inline void initialize(uint32_t producedCount, uint32_t consumedCount) { + init(&produced, producedCount); + init(&consumed, consumedCount); + } +}; + +using CtaBarrierPair = MBarrierPair; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 +template +__device__ inline auto arrive_tx(MBarrier& bar, uint32_t txCount, uint32_t arriveCount = 1) { +#if USE_CUSTOM_BARRIER + return bar.arrive_tx(txCount, arriveCount); +#else + return cuda::device::barrier_arrive_tx(bar, arriveCount, txCount); +#endif +} + +template +__device__ inline void arrive_tx_and_wait(MBarrier& bar, uint32_t txCount, uint32_t arriveCount = 1) { + bar.wait(arrive_tx(bar, txCount, arriveCount)); +} +#endif + +template +__device__ inline mha::tuple carryLE(uint32_t i0, uint32_t iLast) { + return mha::tuple{i0 % bound0, iLast + i0 / bound0}; +} + +template +__device__ inline mha::tuple carryLE( + uint32_t i0, uint32_t i1, decltype(bounds)... 
i, uint32_t iLast) { + return mha::tuple_cat(mha::tuple(i0 % bound0), carryLE(i1 + i0 / bound0, i..., iLast)); +} + +__device__ __host__ inline void assertClose([[maybe_unused]] float a, [[maybe_unused]] float b, [[maybe_unused]] float threshold = 0.01f) { + assert(abs(a - b) < threshold); +} + +__device__ __host__ inline void assertClose([[maybe_unused]] half a, [[maybe_unused]] half b, [[maybe_unused]] float threshold = 0.01f) { + assertClose(__half2float(a), __half2float(b), threshold); +} + +template +__device__ inline Vec convertKCacheWordToF16(uint32_t i8data) { + static_assert(mha::is_same_v || mha::is_same_v, "not implemented"); + static_assert(sizeof(CacheElem) == 1); + Vec ret; +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) + if constexpr (mha::is_same_v && mha::is_same_v) { + uint16_t (&src)[2] = reinterpret_cast(i8data); + uint32_t (&dst)[2] = reinterpret_cast(ret); + asm("{\n" + "cvt.rn.f16x2.e4m3x2 %0, %2;\n" + "cvt.rn.f16x2.e4m3x2 %1, %3;\n" + "}" + : "=r"(dst[0]), "=r"(dst[1]) + : "h"(src[0]), "h"(src[1])); + return ret; + } +#endif + CacheElem const(&src)[4] = reinterpret_cast(i8data); + InputElem(&dst)[4] = reinterpret_cast(ret); + XQA_UNROLL + for (uint32_t i = 0; i < 4; i++) { + dst[i] = InputElem(src[i]); + } + return ret; +} + +template +__device__ inline Vec convertVCacheWordToF16(uint32_t i8data) { + static_assert(mha::is_same_v || mha::is_same_v, "not implemented"); + static_assert(sizeof(CacheElem) == 1); + Vec ret; +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) + if constexpr (mha::is_same_v && mha::is_same_v) { + uint32_t (&dst)[2] = reinterpret_cast(ret); + asm("{\n" + ".reg .b32 dst0;\n" + ".reg .b32 dst1;\n" + ".reg .b32 src;\n" + ".reg .b16 src0;\n" + ".reg .b16 src1;\n" + "prmt.b32 src, %2, 0x0, 0x3120;\n" + "mov.b32 {src0, src1}, src;\n" + "cvt.rn.f16x2.e4m3x2 %0, src0;\n" + "cvt.rn.f16x2.e4m3x2 %1, src1;\n" + "}" + : "=r"(dst[0]), "=r"(dst[1]) + : "r"(i8data)); + return ret; + } +#endif + CacheElem const(&src)[2][2] = reinterpret_cast(i8data); + InputElem(&dst)[2][2] = reinterpret_cast(ret); + XQA_UNROLL + for (uint32_t i = 0; i < 2; i++) { + XQA_UNROLL + for (uint32_t j = 0; j < 2; j++) { + dst[i][j] = InputElem(src[j][i]); + } + } + + return ret; +} + +struct PermuteOrder { + uint16_t x0 : 4; + uint16_t x1 : 4; + uint16_t x2 : 4; + uint16_t x3 : 4; +}; + +static_assert(sizeof(PermuteOrder) == 2); + +__device__ inline uint32_t prmt(uint32_t a, uint32_t b, PermuteOrder order) { + uint32_t d; + uint32_t const c = reinterpret_cast(order); + asm("prmt.b32 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + return d; +} + +__device__ inline uint32_t movmatrix(uint32_t src) { + uint32_t dst; + asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" : "=r"(dst) : "r"(src)); + return dst; +} + +__device__ inline bool warpElectSync() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + uint32_t pred = 0; + asm volatile( + "{\n" + " .reg .b32 d;\n" + " .reg .pred p;\n" + " elect.sync d|p, 0xFFFFFFFF;\n" + " selp.b32 %0, 1, 0, p;\n" + "}\n" + : "=r"(pred)); + return pred != 0; +#else + assert("not available"); + return false; +#endif +} + +__device__ inline void preExit() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + asm volatile("griddepcontrol.launch_dependents;\n"); +#endif +} + +__device__ inline void acqBulk() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + asm volatile("griddepcontrol.wait;\n"); +#endif +} + +__device__ inline uint3 nbClusters() { + uint3 id; + asm("mov.v4.u32 {%0, %1, %2, _}, %%nclusterid;\n" 
: "=r"(id.x), "=r"(id.y), "=r"(id.z)); + return id; +} + +__device__ inline uint3 clusterId() { + uint3 id; + asm("mov.v4.u32 {%0, %1, %2, _}, %%clusterid;\n" : "=r"(id.x), "=r"(id.y), "=r"(id.z)); + return id; +} + +__device__ inline uint32_t clusterCtaRank() { +#if __CUDA_ARCH__ >= 900 + uint32_t rank; + asm("mov.u32 %0, %%cluster_ctarank;\n" : "=r"(rank)); + return rank; +#else + return 0; +#endif +} + +__device__ inline uint3 clusterCtaId() { + uint3 id; + asm("mov.v4.u32 {%0, %1, %2, _}, %%cluster_ctaid;\n" : "=r"(id.x), "=r"(id.y), "=r"(id.z)); + return id; +} + +// src and return are both generic address +template +__device__ inline T* mapa(T* src, uint32_t clusterCtaRank) { + uint64_t dst; + asm volatile("mapa.u64 %0, %1, %2;\n" : "=l"(dst) : "l"(reinterpret_cast(src)), "r"(clusterCtaRank)); + return reinterpret_cast(dst); +} + +template +__device__ inline T& mapa(T& src, uint32_t clusterCtaRank) { + return *mapa(&src, clusterCtaRank); +} + +__device__ inline void clusterBarArrive() { + asm volatile("barrier.cluster.arrive.release.aligned;\n"); +} + +__device__ inline void clusterBarWait() { + asm volatile("barrier.cluster.wait.acquire.aligned;\n"); +} + +__device__ inline uint32_t clock32() { + uint32_t ret; + asm volatile("mov.u32 %0, %%clock;\n" : "=r"(ret)::"memory"); + return ret; +} + +template +struct BarWaiter { + MBarrierPair (*bars)[nbBufs]; + uint32_t idx; + uint32_t idxBuf; + bool skipBarWait = false; + + __device__ inline BarWaiter(MBarrierPair (&bars)[nbBufs], uint32_t idx) + : bars{&bars}, idx{idx}, idxBuf{idx % nbBufs} { + } + + __device__ inline bool testWait() { + bool const parity = toParity(idx); + skipBarWait = bar().produced.test_wait_parity(parity); + return skipBarWait; + } + + __device__ inline BarWaiter next(uint32_t step = 1) { + return BarWaiter{*bars, idx + step}; + } + + __device__ inline void wait() { + if (!skipBarWait) { + bar().produced.wait_parity(toParity(idx)); + } + } + + __device__ inline MBarrierPair& bar() { + return (*bars)[idxBuf]; + } + + __device__ inline void consumed() { + bar().consumed.arrive(); + } +}; + +class Timer { + public: + __device__ inline Timer() { + reset(); + } + + __device__ inline void print(char const* name = "unnamed", bool reset = false) { + auto const toc = clock32(); + printf("%s: %u (block={%u, %u, %u})\n", name, toc - mTic, blockIdx.x, blockIdx.y, blockIdx.z); + if (reset) { + this->reset(); + } + } + + __device__ inline void reset() { + mTic = clock32(); + } + + private: + uint32_t mTic; +}; + +// [beg, end) +struct Range { + uint32_t beg, end; +}; + +constexpr bool overlap(Range a, Range b) { + return a.beg < b.end && b.beg < a.end; +} diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/utils.h b/onnxruntime/contrib_ops/cuda/bert/xqa/utils.h new file mode 100644 index 0000000000000..bd618f171654b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/utils.h @@ -0,0 +1,323 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifndef GENERATE_CUBIN +#include +#include +#include +#include +#include +#include +#endif +#include "mha_stdheaders.cuh" + +#ifdef __CUDA_ARCH__ +#define XQA_UNROLL _Pragma("unroll") +#else +#define XQA_UNROLL +#endif + +template +HOST_DEVICE_FUNC constexpr inline void unused(T&& x) { + static_cast(x); +} + +#ifndef GENERATE_CUBIN +inline void checkCuda(cudaError_t err) { + if (err != cudaSuccess) { + printf("%s\n", cudaGetErrorName(err)); + throw std::runtime_error(cudaGetErrorName(err)); + } +} + +inline void checkCu(CUresult err) { + if (err != CUDA_SUCCESS) { + char const* str = nullptr; + if (cuGetErrorName(err, &str) != CUDA_SUCCESS) { + str = "A cuda driver API error happened, but we failed to query the error name\n"; + } + printf("%s\n", str); + throw std::runtime_error(str); + } +} +#endif + +HOST_DEVICE_FUNC constexpr inline uint32_t greatestPowerOf2Divisor(uint32_t x) { + return x & ~(x - 1); +} + +template +HOST_DEVICE_FUNC constexpr uint32_t maxArrayAlign(uint32_t size) { + return sizeof(T) * greatestPowerOf2Divisor(size); +} + +HOST_DEVICE_FUNC constexpr inline uint32_t exactDiv(uint32_t a, uint32_t b) { + assert(a % b == 0); + return a / b; +} + +template +HOST_DEVICE_FUNC constexpr inline T divUp(T a, T b) { + return (a + b - 1) / b; +} + +template +HOST_DEVICE_FUNC constexpr inline T roundUp(T a, T b) { + return divUp(a, b) * b; +} + +// upperBound is exclusive, i.e. range is [0, upperBound) +template +struct BoundedVal { + template + HOST_DEVICE_FUNC inline BoundedVal divBy() const { + assert(value < upperBound); + return {upperBound <= divisor ? 0 : value / divisor}; + } + + template + HOST_DEVICE_FUNC inline BoundedVal mod() const { + assert(value < upperBound); + return {upperBound <= divisor ? value : value % divisor}; + } + + HOST_DEVICE_FUNC inline bool operator<=(uint32_t rhs) const { + assert(value < upperBound); + return upperBound <= rhs || value <= rhs; + } + + HOST_DEVICE_FUNC inline uint32_t get() const { + assert(value < upperBound); + return upperBound == 1 ? 
0 : value; + } + + uint32_t value; +}; + +template +struct alignas(mha::max(alignof(T), mha::min(maxArrayAlign(size_), 16))) Vec { + using Elem = T; + static constexpr uint32_t size = size_; + Elem data[size]; + + HOST_DEVICE_FUNC inline void fill(T const& val) { + XQA_UNROLL + for (uint32_t i = 0; i < size; i++) { + data[i] = val; + } + } + + static HOST_DEVICE_FUNC inline Vec filled(T const& val) { + Vec ret; + ret.fill(val); + return ret; + } + + HOST_DEVICE_FUNC inline Elem const& operator[](uint32_t i) const { + assert(i < size); + return data[BoundedVal{i}.get()]; + } + + HOST_DEVICE_FUNC inline Elem& operator[](uint32_t i) { + assert(i < size); + return data[BoundedVal{i}.get()]; + } +}; + +template +struct CircIdx { + public: + static constexpr uint32_t nbBuffers = nbBuffers_; + static_assert(nbBuffers >= 1); + + __device__ inline CircIdx(uint32_t init) + : mIndex{init % nbBuffers} { + } + + __device__ inline operator uint32_t() const { + return mIndex; + } + + __device__ inline CircIdx operator+(uint32_t i) const { + return CircIdx{(mIndex + i) % nbBuffers}; + } + + __device__ inline CircIdx operator-(uint32_t i) const { + return CircIdx{(mIndex + (nbBuffers - 1) * i) % nbBuffers}; + } + + __device__ inline CircIdx next() const { + return *this + 1u; + } + + __device__ inline CircIdx& operator++() { + mIndex = next(); + return *this; + } + + __device__ inline CircIdx operator++(int) { + CircIdx old = *this; + operator++(); + return old; + } + + __device__ inline CircIdx prev() const { + return *this - 1u; + } + + __device__ inline CircIdx& operator--() { + mIndex = prev(); + return *this; + } + + __device__ inline CircIdx operator--(int) { + CircIdx old = *this; + operator--(); + return old; + } + + private: + uint32_t mIndex; +}; + +// base is usually in constant memory, so usually only require 1 register to store the offset. 
+template +struct TinyPtr { + T* base; // typically in constant memory or uniform registers + uint32_t offset; // may be non-uniform + + template + __device__ __host__ inline TinyPtr cast() const { + D* const p = reinterpret_cast(base); + assert(reinterpret_cast(p) % alignof(D) == 0); + if constexpr (mha::is_void_v) { + assert(offset == 0); + return TinyPtr{p, 0}; + } else if constexpr (sizeof(T) < sizeof(D)) { + return TinyPtr{p, exactDiv(offset, exactDiv(sizeof(D), sizeof(T)))}; + } else { + return TinyPtr{p, offset * exactDiv(sizeof(T), sizeof(D))}; + } + } + + __device__ __host__ inline T& operator*() const { + return base[offset]; + } + + __device__ __host__ inline TinyPtr operator+(uint32_t i) const { + return TinyPtr{base, offset + i}; + } + + __device__ __host__ inline T& operator[](uint32_t i) const { + return *(*this + i); + } + + __device__ __host__ inline operator T*() const { + return base + offset; + } +}; + +template +class Segmenter { + public: + HOST_DEVICE_FUNC Segmenter(uint32_t offset = 0) + : mNextOffset{offset} { + } + + // offset is in bytes + template + HOST_DEVICE_FUNC OffsetInt newSeg(uint32_t count = 1, uint32_t alignment = alignof(T)) { + mMaxAlignment = mha::max(mMaxAlignment, alignment); + OffsetInt const offset = roundUp(mNextOffset, alignment); + mNextOffset = offset + sizeof(T) * count; + return offset; + } + + HOST_DEVICE_FUNC OffsetInt getEndOffset() const { + return mNextOffset; + } + + HOST_DEVICE_FUNC uint32_t getMaxAlignment() const { + return mMaxAlignment; + } + + private: + OffsetInt mNextOffset; + uint32_t mMaxAlignment = 1; +}; + +template +using AddConst = mha::conditional_t; + +template +class MemSegmenter { + public: + HOST_DEVICE_FUNC MemSegmenter(AddConst* base, uint32_t offset = 0) + : mBase{static_cast*>(base)}, mSegmenter{offset} { + } + + // to use TinyPtr, alignment must be sizeof(T) + template + HOST_DEVICE_FUNC TinyPtr> newSeg(uint32_t count = 1, uint32_t alignment = sizeof(T)) { + assert(reinterpret_cast(mBase) % alignof(T) == 0); + OffsetInt const offset = mSegmenter.template newSeg(count, alignment); + return TinyPtr>{mBase, offset}.template cast>(); + } + + HOST_DEVICE_FUNC OffsetInt getEndOffset() const { + return mSegmenter.getEndOffset(); + } + + HOST_DEVICE_FUNC uint32_t getMaxAlignment() const { + return mSegmenter.getMaxAlignment(); + } + + private: + AddConst* mBase; + Segmenter mSegmenter; +}; + +// dims in little endian +template +struct DimsLE { + static constexpr uint32_t nbDims = nbDims_; + + __device__ __host__ inline uint32_t& operator[](uint32_t i) { + return d[i]; + } + + __device__ __host__ inline uint32_t const& operator[](uint32_t i) const { + return d[i]; + } + + uint32_t d[nbDims]; +}; + +// check if val is in range [lb, ub) +template +constexpr bool inRange(T val, T lb, T ub) { + return val >= lb && val < ub; +} + +// val is an optimized / pre-computed value, ref is the original value +template +HOST_DEVICE_FUNC constexpr inline T checkedVal(T val, T ref) { + assert(val == ref); + return val; +} diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh new file mode 100644 index 0000000000000..744739980759d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_impl_gen.cuh @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
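The Segmenter/MemSegmenter helpers introduced in utils.h above do pure offset bookkeeping: each newSeg call rounds the running offset up to the requested alignment and then advances it by sizeof(T) * count. A minimal standalone sketch of that arithmetic (not part of the patch; the names and sizes below are illustrative only):

#include <cstdint>
#include <cstdio>

// Same rounding rule as roundUp() in utils.h: next multiple of b at or above a.
static uint64_t roundUpTo(uint64_t a, uint64_t b) { return (a + b - 1) / b * b; }

int main() {
  uint64_t next = 0;
  // 16 uint32_t semaphores, natural 4-byte alignment
  uint64_t semOff = roundUpTo(next, 4);
  next = semOff + 16 * sizeof(uint32_t);         // next = 64
  // a scratch region forced to 128-byte alignment
  uint64_t scratchOff = roundUpTo(next, 128);    // scratchOff = 128
  next = scratchOff + 4096;                      // end offset = 4224
  std::printf("sem@%llu scratch@%llu end=%llu\n",
              (unsigned long long)semOff, (unsigned long long)scratchOff,
              (unsigned long long)next);
  return 0;
}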
+ +// Template for XQA Kernel Implementation +// Expected macros: +// NAMESPACE_NAME: Name of the namespace (e.g., grp8) +// GRP_SIZE: Integer value for HEAD_GRP_SIZE + +namespace NAMESPACE_NAME { +// Undefine dependent guard to allow header re-processing +#undef MHA_H_DEPENDENT + +// Define macro for mha_impl.cuh (which includes mha.h) +// We assume mha.h's dependent part relies on this macro +#define HEAD_GRP_SIZE GRP_SIZE + +// Include implementation (re-compiles kernel for this group size) +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 +#include "mha_impl.cuh" +#endif + +#undef HEAD_GRP_SIZE + +template +inline Status Launch( + [[maybe_unused]] const cudaDeviceProp& device_prop, + [[maybe_unused]] cudaStream_t stream, + [[maybe_unused]] const void* query, + [[maybe_unused]] const void* key_cache, + [[maybe_unused]] const void* value_cache, + [[maybe_unused]] void* output, + [[maybe_unused]] const int batch_size, + [[maybe_unused]] const int num_heads, + [[maybe_unused]] const int kv_num_heads, + [[maybe_unused]] const int head_size, + [[maybe_unused]] const int max_seq_len, + [[maybe_unused]] const float scale, + [[maybe_unused]] const bool is_bsnh, + [[maybe_unused]] const int* past_seq_lens, + [[maybe_unused]] const float* kv_cache_scale, + [[maybe_unused]] void* workspace, + [[maybe_unused]] size_t workspace_size) { +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + const InputHead* q_ptr = reinterpret_cast(query); + GMemKVCacheHead* k_ptr = reinterpret_cast(const_cast(key_cache)); + GMemKVCacheHead* v_ptr = reinterpret_cast(const_cast(value_cache)); + OutputHead* out_ptr = reinterpret_cast(output); + + uint32_t* semaphores = nullptr; + void* scratch = nullptr; + + if (workspace != nullptr) { + uint32_t nbSeq = static_cast(batch_size * kv_num_heads); + size_t semaphore_size = nbSeq * sizeof(uint32_t); + size_t padded_sem_size = roundUp(semaphore_size, 128); + + uint32_t nbSubSeqPerSeq = computeNbSubSeqPerSeqMHA( + device_prop, + static_cast(batch_size), + static_cast(kv_num_heads), + static_cast(max_seq_len)); + size_t required_scratch_size = NAMESPACE_NAME::GetScratchSize(nbSeq, nbSubSeqPerSeq); + size_t total_required = padded_sem_size + required_scratch_size; + + if (workspace_size < total_required) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA workspace size is too small. Expected at least ", total_required, ", but got ", workspace_size); + } + semaphores = reinterpret_cast(workspace); + scratch = reinterpret_cast(workspace) + padded_sem_size; + + // Initialize semaphores to 0 + cudaMemsetAsync(semaphores, 0, semaphore_size, stream); + } + + launchMHA( + device_prop, + static_cast(kv_num_heads), + scale, + out_ptr, + q_ptr, + nullptr, // attentionSinks + k_ptr, + v_ptr, + is_bsnh, + static_cast(max_seq_len), + reinterpret_cast(past_seq_lens), + static_cast(batch_size), + kv_cache_scale, // Pass kv_cache_scale for INT8 dequantization + semaphores, // semaphores + scratch, // scratch + stream); + return Status::OK(); +#else + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA is only supported on Ampere (SM80) or newer GPUs."); +#endif +} +} // namespace NAMESPACE_NAME diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader.h b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader.h new file mode 100644 index 0000000000000..8439c19687097 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader.h @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
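The Launch wrapper in xqa_impl_gen.cuh above carves the caller-provided workspace into a semaphore region (one uint32_t per batch_size * kv_num_heads sequence, padded to a 128-byte boundary) followed by the multi-block scratch buffer, and zero-initializes the semaphores with cudaMemsetAsync. A worked example of that split with illustrative numbers (not part of the patch):

// batch_size = 2, kv_num_heads = 8       ->  nbSeq = 16
// semaphore_size = 16 * sizeof(uint32_t) = 64 bytes (zeroed via cudaMemsetAsync)
// padded_sem_size = roundUp(64, 128)     = 128 bytes
// semaphores -> workspace, scratch -> workspace + 128 bytes
// Launch returns a FAIL Status if workspace_size < 128 + GetScratchSize(nbSeq, nbSubSeqPerSeq)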
+ +#pragma once + +#include +#include +#include "core/providers/cuda/cuda_common.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Quantization type for XQA +enum class XqaQuantType { + kNone = 0, // no quantization, use FP16/BF16 + kInt8 = 1, + kFp8 = 2 +}; + +// Wrapper for XQA MHA launch +// Only supports decoding (S=1) for now. +template +Status LaunchXQAKernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, // [B, NumHeads, HeadSize] + const void* key_cache, // [B, MaxSeqLen, NumKVHeads, HeadSize] (or BNSH, but XQA usually expects contiguous or paged) + const void* value_cache, // [B, MaxSeqLen, NumKVHeads, HeadSize] + void* output, // [B, NumHeads, HeadSize] + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, // Max sequence length of cache + const float scale, + const bool is_bsnh, // Layout of KV cache + const int* past_seq_lens, // Past sequence lengths [BatchSize] + const float* kv_cache_scale, // KV cache dequant scale (nullptr for FP16/BF16, per-tensor float for INT8) + const XqaQuantType kv_quant_type, + void* workspace = nullptr, // Scratch memory + size_t workspace_size = 0 // Size of scratch memory +); + +size_t GetXQAScratchSize( + const cudaDeviceProp& device_prop, + int batch_size, + int num_heads, + int kv_num_heads, + int head_size, + int max_seq_len, + XqaQuantType kv_quant_type, + bool is_bf16 = false); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16.cu new file mode 100644 index 0000000000000..4c6731b10fe77 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16.cu @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
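For orientation, a hypothetical caller of the xqa_loader.h interface declared above would size the workspace with GetXQAScratchSize and then launch the decode-time (S == 1) kernel. The sketch below is illustrative only: it assumes it sits inside a function returning Status, and that device_prop, stream, the tensor pointers, the shape scalars, and a device `workspace` buffer of at least ws_bytes bytes already exist in scope.

// Hypothetical caller sketch (not part of the patch).
size_t ws_bytes = GetXQAScratchSize(device_prop, batch_size, num_heads,
                                    kv_num_heads, head_size, max_seq_len,
                                    XqaQuantType::kNone, /*is_bf16*/ false);
// ... allocate at least ws_bytes of device memory into `workspace` ...
ORT_RETURN_IF_ERROR(LaunchXQAKernel<half>(
    device_prop, stream, query, key_cache, value_cache, output,
    batch_size, num_heads, kv_num_heads, head_size, max_seq_len,
    scale, /*is_bsnh*/ true, past_seq_lens,
    /*kv_cache_scale*/ nullptr, XqaQuantType::kNone,
    workspace, ws_bytes));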
+ +#include "xqa_loader.h" +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Forward declarations of instantiated kernels from H64, H128, and H256 namespaces +namespace H64 { +template +Status LaunchXQAKernelImpl( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + const XqaQuantType kv_quant_type, + void* workspace, + size_t workspace_size); +} // namespace H64 + +namespace H128 { +template +Status LaunchXQAKernelImpl( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + const XqaQuantType kv_quant_type, + void* workspace, + size_t workspace_size); +} // namespace H128 + +namespace H256 { +template +Status LaunchXQAKernelImpl( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + const XqaQuantType kv_quant_type, + void* workspace, + size_t workspace_size); +} // namespace H256 + +// Forward declaration for INT8 BF16 dispatcher +Status LaunchXQAInt8KernelBF16( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size); + +// ============================================================================ +// Specialization for BFloat16 +// ============================================================================ + +template <> +Status LaunchXQAKernel<__nv_bfloat16>( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + const XqaQuantType kv_quant_type, + void* workspace, + size_t workspace_size) { + // Dispatch to INT8 path if requested + if (kv_quant_type == XqaQuantType::kInt8) { + return LaunchXQAInt8KernelBF16(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + } + + // Dispatch based on head_size + if (head_size == 256) { + return H256::LaunchXQAKernelImpl<__nv_bfloat16>( + device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, + max_seq_len, scale, 
is_bsnh, past_seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size); + } else if (head_size == 128) { + return H128::LaunchXQAKernelImpl<__nv_bfloat16>( + device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, + max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size); + } else if (head_size == 64) { + return H64::LaunchXQAKernelImpl<__nv_bfloat16>( + device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, + max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA only supports head_size=64, 128, or 256. Input has ", head_size); + } +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_128.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_128.cu new file mode 100644 index 0000000000000..7572986d14632 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_128.cu @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define HEAD_ELEMS 128 +#define HEAD_DIM_NAMESPACE H128 + +#include "xqa_loader_bf16_impl.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Explicit instantiation for BFloat16 +template Status HEAD_DIM_NAMESPACE::LaunchXQAKernelImpl<__nv_bfloat16>( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + const XqaQuantType kv_quant_type, + void* workspace, + size_t workspace_size); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_256.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_256.cu new file mode 100644 index 0000000000000..2706a9de32b14 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_256.cu @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
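The per-head-size .cu files above follow a standard explicit-instantiation split: the template body lives in the shared implementation header, and each translation unit pins HEAD_ELEMS and HEAD_DIM_NAMESPACE before forcing exactly one concrete instantiation, so the head-size variants compile as independent units. A minimal sketch of the underlying C++ idiom, using a hypothetical function rather than the patch's actual code:

// Shared header (sketch): template defined once.
template <typename T> void Foo(const T* in, T* out, int n) {
  for (int i = 0; i < n; ++i) out[i] = in[i];
}

// Per-configuration .cu file (sketch): force exactly one instantiation here,
// so each configuration is compiled (and can be built in parallel) on its own.
template void Foo<float>(const float*, float*, int);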
+ +#define HEAD_ELEMS 256 +#define HEAD_DIM_NAMESPACE H256 + +#include "xqa_loader_bf16_impl.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Explicit instantiation for BFloat16 +template Status HEAD_DIM_NAMESPACE::LaunchXQAKernelImpl<__nv_bfloat16>( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + const XqaQuantType kv_quant_type, + void* workspace, + size_t workspace_size); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_64.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_64.cu new file mode 100644 index 0000000000000..7bd8897fdfd93 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_64.cu @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define HEAD_ELEMS 64 +#define HEAD_DIM_NAMESPACE H64 + +#include "xqa_loader_bf16_impl.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Explicit instantiation for BFloat16 +template Status HEAD_DIM_NAMESPACE::LaunchXQAKernelImpl<__nv_bfloat16>( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + const XqaQuantType kv_quant_type, + void* workspace, + size_t workspace_size); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh new file mode 100644 index 0000000000000..644dec2c67bbd --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_impl.cuh @@ -0,0 +1,196 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "xqa_loader.h" +#include + +// HEAD_ELEMS must be defined by the including file +#ifndef HEAD_ELEMS +#error "HEAD_ELEMS must be defined before including xqa_loader_bf16_impl.cuh" +#endif + +// HEAD_DIM_NAMESPACE must be defined by the including file +#ifndef HEAD_DIM_NAMESPACE +#error "HEAD_DIM_NAMESPACE must be defined before including xqa_loader_bf16_impl.cuh" +#endif + +// Define global constants +#define USE_PAGED_KV_CACHE 0 +#define TOKENS_PER_PAGE 0 +#define INPUT_FP16 0 // Set to 0 for BFloat16 +#define ALLOW_MULTI_BLOCK_MODE 1 + +#pragma nv_diag_suppress 177 +#pragma nv_diag_suppress 20012 + +// Include common headers once +#include "cuda_hint.cuh" +#include "mha.h" +// Include all helpers globally to ensure visibility +#include "ldgsts.cuh" +#include "mhaUtils.cuh" +#include "mha_components.cuh" +#include "mma.cuh" +#include "utils.cuh" +#include "hostUtils.h" + +// Undefine HEAD_GRP_SIZE and M_TILESIZE to allow re-definition in impl gen +#undef HEAD_GRP_SIZE +#undef M_TILESIZE + +namespace onnxruntime { +namespace contrib { +namespace cuda { +namespace HEAD_DIM_NAMESPACE { + +// ============================================================================ +// BF16 KV Cache Instantiations +// ============================================================================ + +#define NAMESPACE_NAME grp1_bf16 +#define GRP_SIZE 1 +#define M_TILESIZE 8 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp2_bf16 +#define GRP_SIZE 2 +#define M_TILESIZE 8 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp4_bf16 +#define GRP_SIZE 4 +#define M_TILESIZE 8 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp8_bf16 +#define GRP_SIZE 8 +#define M_TILESIZE 8 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp16_bf16 +#define GRP_SIZE 16 +#define M_TILESIZE 16 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp32_bf16 +#define GRP_SIZE 32 +#define M_TILESIZE 32 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +// Extern declarations for INT8 kernels with BF16 query (implemented in xqa_loader_bf16_int8.cu) +Status LaunchXQAInt8KernelBF16( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size); + +// ============================================================================ +// Specialization for BFloat16 +// ============================================================================ + +template +Status LaunchXQAKernelImpl( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + const XqaQuantType kv_quant_type, + void* workspace, + size_t 
workspace_size); + +template <> +Status LaunchXQAKernelImpl<__nv_bfloat16>( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + const XqaQuantType kv_quant_type, + void* workspace, + size_t workspace_size) { + // Head size check in global dispatcher + + // Dispatch to INT8 path if requested + if (kv_quant_type == XqaQuantType::kInt8) { + return LaunchXQAInt8KernelBF16(device_prop, stream, query, key_cache, value_cache, output, + batch_size, num_heads, kv_num_heads, head_size, max_seq_len, + scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, + workspace_size); + } + + int group_size = num_heads / kv_num_heads; + switch (group_size) { + case 1: + return grp1_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 2: + return grp2_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 4: + return grp4_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 8: + return grp8_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 16: + return grp16_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 32: + return grp32_bf16::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA supports group_size 1, 2, 4, 8, 16, 32. Input has ", group_size); + } +} + +} // namespace HEAD_DIM_NAMESPACE +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8.cu new file mode 100644 index 0000000000000..cbca83f27cb87 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8.cu @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
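For concreteness, the group_size dispatch above maps the GQA ratio directly onto the stamped-out per-group namespaces; the numbers below are illustrative:

// num_heads = 32, kv_num_heads = 8   ->  group_size = 32 / 8 = 4  ->  grp4_bf16::Launch
// num_heads = 32, kv_num_heads = 32  ->  group_size = 1           ->  grp1_bf16::Launch (plain MHA)
// num_heads = 40, kv_num_heads = 8   ->  group_size = 5           ->  FAIL Status (unsupported ratio)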
+ +#include "xqa_loader.h" +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Forward declarations of LaunchXQAIn8KernelBF16 from H64, H128, H256 namespaces +namespace H64 { +Status LaunchXQAInt8KernelBF16( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size); +} // namespace H64 + +namespace H128 { +Status LaunchXQAInt8KernelBF16( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size); +} // namespace H128 + +namespace H256 { +Status LaunchXQAInt8KernelBF16( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size); +} // namespace H256 + +// Dispatcher for INT8 BF16 query kernel based on head_size +Status LaunchXQAInt8KernelBF16( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size) { + if (head_size == 256) { + return H256::LaunchXQAInt8KernelBF16(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + } else if (head_size == 128) { + return H128::LaunchXQAInt8KernelBF16(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + } else if (head_size == 64) { + return H64::LaunchXQAInt8KernelBF16(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA INT8 BF16 only supports head_size=64, 128, or 256. Input has ", head_size); + } +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_128.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_128.cu new file mode 100644 index 0000000000000..afecf2136c95f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_128.cu @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License. + +#define HEAD_ELEMS 128 +#define HEAD_DIM_NAMESPACE H128 + +#include "xqa_loader_bf16_int8_impl.cuh" diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_256.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_256.cu new file mode 100644 index 0000000000000..d7af7744dbf42 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_256.cu @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define HEAD_ELEMS 256 +#define HEAD_DIM_NAMESPACE H256 + +#include "xqa_loader_bf16_int8_impl.cuh" diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_64.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_64.cu new file mode 100644 index 0000000000000..120e0339b3a0d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_64.cu @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define HEAD_ELEMS 64 +#define HEAD_DIM_NAMESPACE H64 + +#include "xqa_loader_bf16_int8_impl.cuh" diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_impl.cuh new file mode 100644 index 0000000000000..acec9aeed9973 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_bf16_int8_impl.cuh @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "xqa_loader.h" +#include + +// HEAD_ELEMS must be defined by the including file +#ifndef HEAD_ELEMS +#error "HEAD_ELEMS must be defined before including xqa_loader_bf16_int8_impl.cuh" +#endif + +// HEAD_DIM_NAMESPACE must be defined by the including file +#ifndef HEAD_DIM_NAMESPACE +#error "HEAD_DIM_NAMESPACE must be defined before including xqa_loader_bf16_int8_impl.cuh" +#endif + +// Define global constants for INT8 KV Cache with BF16 Query +#define CACHE_ELEM_ENUM 1 +#define USE_PAGED_KV_CACHE 0 +#define TOKENS_PER_PAGE 0 +#define INPUT_FP16 0 // Q is BF16 +#define ALLOW_MULTI_BLOCK_MODE 1 + +#pragma nv_diag_suppress 177 +#pragma nv_diag_suppress 20012 + +// Include common headers once +#include "cuda_hint.cuh" +#include "mha.h" +// Include all helpers globally to ensure visibility +#include "ldgsts.cuh" +#include "mhaUtils.cuh" +#include "mha_components.cuh" +#include "mma.cuh" +#include "utils.cuh" +#include "hostUtils.h" + +// Undefine HEAD_GRP_SIZE and M_TILESIZE to allow re-definition in impl gen +#undef HEAD_GRP_SIZE +#undef M_TILESIZE + +namespace onnxruntime { +namespace contrib { +namespace cuda { +namespace HEAD_DIM_NAMESPACE { + +// ============================================================================ +// INT8 KV Cache Instantiations for BF16 Query +// ============================================================================ + +#define NAMESPACE_NAME grp4_bf16_int8 +#define GRP_SIZE 4 +#define M_TILESIZE 8 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp8_bf16_int8 +#define GRP_SIZE 8 +#define M_TILESIZE 8 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp16_bf16_int8 +#define GRP_SIZE 16 +#define M_TILESIZE 16 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp32_bf16_int8 +#define GRP_SIZE 32 +#define M_TILESIZE 
32 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +Status LaunchXQAInt8KernelBF16( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size) { + int group_size = num_heads / kv_num_heads; + switch (group_size) { + case 4: + return grp4_bf16_int8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 8: + return grp8_bf16_int8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 16: + return grp16_bf16_int8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 32: + return grp32_bf16_int8::Launch<__nv_bfloat16>(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA INT8 only supports group_size 4, 8, 16, 32. Input has ", group_size); + } +} + +} // namespace HEAD_DIM_NAMESPACE +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16.cu new file mode 100644 index 0000000000000..37b974a8a3e60 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16.cu @@ -0,0 +1,196 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
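Note that the INT8 KV-cache instantiations above only stamp out group sizes 4, 8, 16, and 32, while the plain BF16/FP16 paths also cover 1 and 2; a GQA ratio of 1 or 2 combined with kv_quant_type == kInt8 therefore returns the "XQA INT8 only supports group_size 4, 8, 16, 32" status instead of dispatching to a kernel.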
+ +#include "xqa_loader.h" +#include "utils.h" +#include +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Forward declarations of instantiated kernels from H128 and H64 namespaces +namespace H128 { +template +Status LaunchXQAKernelImpl( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + const XqaQuantType kv_quant_type, + void* workspace, + size_t workspace_size); + +} // namespace H128 + +namespace H64 { +template +Status LaunchXQAKernelImpl( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + const XqaQuantType kv_quant_type, + void* workspace, + size_t workspace_size); + +} // namespace H64 + +namespace H256 { +template +Status LaunchXQAKernelImpl( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + const XqaQuantType kv_quant_type, + void* workspace, + size_t workspace_size); + +} // namespace H256 + +// Dispatcher Implementation + +template +Status LaunchXQAKernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + const XqaQuantType kv_quant_type, + void* workspace, + size_t workspace_size) { + if (device_prop.major < 8) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA is only supported on Ampere (SM80) or newer GPUs."); + } + + if (head_size == 256) { + return H256::LaunchXQAKernelImpl( + device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, + max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size); + } else if (head_size == 128) { + return H128::LaunchXQAKernelImpl( + device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, + max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size); + } else if (head_size == 64) { + return H64::LaunchXQAKernelImpl( + device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, + max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, kv_quant_type, workspace, workspace_size); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA only supports head_size=64, 128, or 256. 
Input has ", head_size); + } +} + +size_t GetXQAScratchSize( + const cudaDeviceProp& device_prop, + int batch_size, + int num_heads, + int kv_num_heads, + int head_size, + int max_seq_len, + [[maybe_unused]] XqaQuantType kv_quant_type, + [[maybe_unused]] bool is_bf16) { + if (device_prop.major < 8) { + return 0; + } + + uint32_t nbSeq = static_cast(batch_size * kv_num_heads); + // nbSubSeqPerSeq calculation matches computeNbSubSeqPerSeqMHA in mha_impl.cuh + // ctaTile.x is 256 for all current configurations + uint32_t nbSubSeqPerSeq = std::min( + std::max(1U, static_cast(device_prop.multiProcessorCount) / nbSeq), + (static_cast(max_seq_len) + 255) / 256); + uint32_t nbSubSeq = nbSeq * nbSubSeqPerSeq; + + int group_size = num_heads / kv_num_heads; + // M_TILESIZE: 8 for group_size <= 8, 16 for group_size <= 16, 32 for group_size <= 32 + int m_tilesize = (group_size <= 8) ? 8 : (group_size <= 16 ? 16 : 32); + + // sizeof(SMemWarpRowMax) is 128 (4 * 8 * 4) for all group sizes <= 32 + // sizeof(VecT) is head_size * m_tilesize * 2 (2 bytes per element for fp16/bf16 intermediate results) + size_t vec_size = static_cast(head_size) * m_tilesize * 2; + + size_t scratch_size = 0; + // 1. rowMax + scratch_size = roundUp(scratch_size, 128); + scratch_size += 128 * nbSubSeq; + // 2. rowSum + scratch_size = roundUp(scratch_size, 128); + scratch_size += 128 * nbSubSeq; + // 3. scratchBuffers + scratch_size = roundUp(scratch_size, vec_size); + scratch_size += vec_size * nbSubSeq; + + size_t semaphore_size = nbSeq * sizeof(uint32_t); + return roundUp(semaphore_size, 128) + scratch_size; +} + +// Instantiate template for half +template Status LaunchXQAKernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + const XqaQuantType kv_quant_type, + void* workspace, + size_t workspace_size); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_128.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_128.cu new file mode 100644 index 0000000000000..87304cfd1adc2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_128.cu @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
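To make the sizing formula in GetXQAScratchSize above concrete, here is a worked example with illustrative inputs (108 SMs as on A100, batch_size 2, 32 query heads, 8 KV heads, head_size 128, max_seq_len 4096):

// nbSeq           = 2 * 8 = 16
// nbSubSeqPerSeq  = min(max(1, 108 / 16), ceil(4096 / 256)) = min(6, 16) = 6
// nbSubSeq        = 16 * 6 = 96
// group_size      = 32 / 8 = 4   ->  m_tilesize = 8
// vec_size        = 128 * 8 * 2  = 2048 bytes
// rowMax          = 128 * 96     = 12288 bytes
// rowSum          = 128 * 96     = 12288 bytes  (running total 24576, already 2048-aligned)
// scratchBuffers  = 2048 * 96    = 196608 bytes (scratch total 221184)
// semaphores      = roundUp(16 * 4, 128) = 128 bytes
// returned size   = 128 + 221184 = 221312 bytes (about 216 KiB)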
+ +#define HEAD_ELEMS 128 +#define HEAD_DIM_NAMESPACE H128 + +#include "xqa_loader_fp16_impl.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Explicit instantiation +template Status HEAD_DIM_NAMESPACE::LaunchXQAKernelImpl( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + const XqaQuantType kv_quant_type, + void* workspace, + size_t workspace_size); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_256.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_256.cu new file mode 100644 index 0000000000000..3d070a87f87a8 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_256.cu @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define HEAD_ELEMS 256 +#define HEAD_DIM_NAMESPACE H256 + +#include "xqa_loader_fp16_impl.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Explicit instantiation +template Status HEAD_DIM_NAMESPACE::LaunchXQAKernelImpl( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + const XqaQuantType kv_quant_type, + void* workspace, + size_t workspace_size); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_64.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_64.cu new file mode 100644 index 0000000000000..1664122dbc6d3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_64.cu @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define HEAD_ELEMS 64 +#define HEAD_DIM_NAMESPACE H64 + +#include "xqa_loader_fp16_impl.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Explicit instantiation +template Status HEAD_DIM_NAMESPACE::LaunchXQAKernelImpl( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + const XqaQuantType kv_quant_type, + void* workspace, + size_t workspace_size); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh new file mode 100644 index 0000000000000..8ba0fe3b1ee0d --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_impl.cuh @@ -0,0 +1,179 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "xqa_loader.h" +#include + +// HEAD_ELEMS must be defined by the including file +#ifndef HEAD_ELEMS +#error "HEAD_ELEMS must be defined before including xqa_loader_fp16_impl.cuh" +#endif + +// HEAD_DIM_NAMESPACE must be defined by the including file +#ifndef HEAD_DIM_NAMESPACE +#error "HEAD_DIM_NAMESPACE must be defined before including xqa_loader_fp16_impl.cuh" +#endif + +// Define global constants based on macros +#define USE_PAGED_KV_CACHE 0 +#define TOKENS_PER_PAGE 0 +#define INPUT_FP16 1 +#define ALLOW_MULTI_BLOCK_MODE 1 + +#pragma nv_diag_suppress 177 +#pragma nv_diag_suppress 20012 + +// Include common headers once +#include "cuda_hint.cuh" +#include "mha.h" +// Include all helpers globally to ensure visibility +#include "ldgsts.cuh" +#include "mhaUtils.cuh" +#include "mha_components.cuh" +#include "mma.cuh" +#include "utils.cuh" +#include "hostUtils.h" + +// Undefine HEAD_GRP_SIZE and M_TILESIZE to allow re-definition in impl gen +#undef HEAD_GRP_SIZE +#undef M_TILESIZE + +namespace onnxruntime { +namespace contrib { +namespace cuda { +namespace HEAD_DIM_NAMESPACE { + +// ============================================================================ +// FP16 KV Cache Instantiations +// ============================================================================ + +#define NAMESPACE_NAME grp1_fp16 +#define GRP_SIZE 1 +#define M_TILESIZE 8 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp2_fp16 +#define GRP_SIZE 2 +#define M_TILESIZE 8 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp4_fp16 +#define GRP_SIZE 4 +#define M_TILESIZE 8 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp8_fp16 +#define GRP_SIZE 8 +#define M_TILESIZE 8 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp16_fp16 +#define GRP_SIZE 16 +#define M_TILESIZE 16 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp32_fp16 +#define GRP_SIZE 32 +#define M_TILESIZE 32 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +// Extern declarations for INT8 kernels (implemented in xqa_loader_fp16_int8_impl.cuh via instantiation) +Status LaunchXQAInt8Kernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size); + +// ============================================================================ +// Dispatcher Implementation +// ============================================================================ + +template +Status LaunchXQAKernelImpl( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + const XqaQuantType kv_quant_type, + void* workspace, + size_t 
workspace_size) { + // Head size check is done in global dispatcher + + // Dispatch to INT8 path if requested + if (kv_quant_type == XqaQuantType::kInt8) { + if constexpr (std::is_same<T, half>::value) { + return LaunchXQAInt8Kernel(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + } else { + // BF16 case is handled in xqa_loader_bf16.cu via specialization + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA INT8 path mismatch."); + } + } + + int group_size = num_heads / kv_num_heads; + switch (group_size) { + case 1: + return grp1_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 2: + return grp2_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 4: + return grp4_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 8: + return grp8_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 16: + return grp16_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 32: + return grp32_fp16::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA supports group_size 1, 2, 4, 8, 16, 32. Input has ", group_size); + } +} + +// Instantiate template for half + +} // namespace HEAD_DIM_NAMESPACE +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8.cu new file mode 100644 index 0000000000000..4855571c32a57 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8.cu @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License.
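The dispatcher above picks a pre-compiled kernel purely from the query-to-KV head ratio, the per-head-dimension loader files restrict head_size to 64, 128, or 256, and the INT8 dispatchers that follow narrow the group size further to 4-32. The purely illustrative check below summarizes the combinations compiled in by this change; the function name is an assumption, not part of the ONNX Runtime API.

```python
# Illustrative only: mirrors the switch statements in the XQA loader dispatchers.
XQA_HEAD_SIZES = {64, 128, 256}              # one translation unit per head dimension
XQA_FP16_GROUP_SIZES = {1, 2, 4, 8, 16, 32}
XQA_INT8_GROUP_SIZES = {4, 8, 16, 32}


def xqa_kernel_available(num_heads: int, kv_num_heads: int, head_size: int,
                         int8_kv_cache: bool) -> bool:
    if head_size not in XQA_HEAD_SIZES or num_heads % kv_num_heads != 0:
        return False
    group_size = num_heads // kv_num_heads
    allowed = XQA_INT8_GROUP_SIZES if int8_kv_cache else XQA_FP16_GROUP_SIZES
    return group_size in allowed


# Example: 32 query heads sharing 8 KV heads (group_size 4) with head_size 128.
assert xqa_kernel_available(32, 8, 128, int8_kv_cache=True)
```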
+ +#include "xqa_loader.h" +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Forward declarations of LaunchXQAInt8Kernel from H64, H128, H256 namespaces +namespace H64 { +Status LaunchXQAInt8Kernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size); + +} // namespace H64 + +namespace H128 { +Status LaunchXQAInt8Kernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size); + +} // namespace H128 + +namespace H256 { +Status LaunchXQAInt8Kernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size); + +} // namespace H256 + +// Dispatcher for INT8 FP16 query kernel based on head_size +Status LaunchXQAInt8Kernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size) { + if (head_size == 256) { + return H256::LaunchXQAInt8Kernel(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + } else if (head_size == 128) { + return H128::LaunchXQAInt8Kernel(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + } else if (head_size == 64) { + return H64::LaunchXQAInt8Kernel(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA INT8 only supports head_size=64, 128, or 256. Input has ", head_size); + } +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_128.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_128.cu new file mode 100644 index 0000000000000..eaca0a3bb2060 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_128.cu @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT License. + +#define HEAD_ELEMS 128 +#define HEAD_DIM_NAMESPACE H128 + +#include "xqa_loader_fp16_int8_impl.cuh" diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_256.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_256.cu new file mode 100644 index 0000000000000..5fdeb61f2e58a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_256.cu @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define HEAD_ELEMS 256 +#define HEAD_DIM_NAMESPACE H256 + +#include "xqa_loader_fp16_int8_impl.cuh" diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_64.cu b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_64.cu new file mode 100644 index 0000000000000..b648f917e0675 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_64.cu @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define HEAD_ELEMS 64 +#define HEAD_DIM_NAMESPACE H64 + +#include "xqa_loader_fp16_int8_impl.cuh" diff --git a/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_impl.cuh b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_impl.cuh new file mode 100644 index 0000000000000..f3a1fcd8a8e63 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/xqa/xqa_loader_fp16_int8_impl.cuh @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "xqa_loader.h" +#include + +// HEAD_ELEMS must be defined by the including file +#ifndef HEAD_ELEMS +#error "HEAD_ELEMS must be defined before including xqa_loader_fp16_int8_impl.cuh" +#endif + +// HEAD_DIM_NAMESPACE must be defined by the including file +#ifndef HEAD_DIM_NAMESPACE +#error "HEAD_DIM_NAMESPACE must be defined before including xqa_loader_fp16_int8_impl.cuh" +#endif + +// Define global constants for INT8 KV Cache +#define CACHE_ELEM_ENUM 1 +#define USE_PAGED_KV_CACHE 0 +#define TOKENS_PER_PAGE 0 +#define INPUT_FP16 1 // Q is FP16 +#define ALLOW_MULTI_BLOCK_MODE 1 + +#pragma nv_diag_suppress 177 +#pragma nv_diag_suppress 20012 + +// Include common headers once +#include "cuda_hint.cuh" +#include "mha.h" +// Include all helpers globally to ensure visibility +#include "ldgsts.cuh" +#include "mhaUtils.cuh" +#include "mha_components.cuh" +#include "mma.cuh" +#include "utils.cuh" +#include "hostUtils.h" + +// Undefine HEAD_GRP_SIZE and M_TILESIZE to allow re-definition in impl gen +#undef HEAD_GRP_SIZE +#undef M_TILESIZE + +namespace onnxruntime { +namespace contrib { +namespace cuda { +namespace HEAD_DIM_NAMESPACE { + +// ============================================================================ +// INT8 KV Cache Instantiations for FP16 Query +// ============================================================================ +#define NAMESPACE_NAME grp4_int8 +#define GRP_SIZE 4 +#define M_TILESIZE 8 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp8_int8 +#define GRP_SIZE 8 +#define M_TILESIZE 8 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp16_int8 +#define GRP_SIZE 16 +#define M_TILESIZE 16 +#include "xqa_impl_gen.cuh" +#undef NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +#define NAMESPACE_NAME grp32_int8 +#define GRP_SIZE 32 +#define M_TILESIZE 32 +#include "xqa_impl_gen.cuh" +#undef 
NAMESPACE_NAME +#undef GRP_SIZE +#undef M_TILESIZE + +Status LaunchXQAInt8Kernel( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + const void* query, + const void* key_cache, + const void* value_cache, + void* output, + const int batch_size, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int max_seq_len, + const float scale, + const bool is_bsnh, + const int* past_seq_lens, + const float* kv_cache_scale, + void* workspace, + size_t workspace_size) { + int group_size = num_heads / kv_num_heads; + switch (group_size) { + case 4: + return grp4_int8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 8: + return grp8_int8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 16: + return grp16_int8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + case 32: + return grp32_int8::Launch(device_prop, stream, query, key_cache, value_cache, output, batch_size, num_heads, kv_num_heads, head_size, max_seq_len, scale, is_bsnh, past_seq_lens, kv_cache_scale, workspace, workspace_size); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "XQA INT8 only supports group_size 4, 8, 16, 32. Input has ", group_size); + } +} + +} // namespace HEAD_DIM_NAMESPACE +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 0864e30831092..ab692e0549d6c 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -107,8 +107,14 @@ class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_MLFloat16, MultiHeadAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_BFloat16, MultiHeadAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16_float, MultiHeadAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16_BFloat16, MultiHeadAttention); -class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, GroupQueryAttention); -class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, GroupQueryAttention); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_MLFloat16, GroupQueryAttention); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16_BFloat16, GroupQueryAttention); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_int8_t, GroupQueryAttention); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16_int8_t, GroupQueryAttention); +#ifdef USE_INT4_KV_CACHE +class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_uint8_t, GroupQueryAttention); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16_uint8_t, GroupQueryAttention); +#endif class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, PagedAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, PagedAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, DecoderAttention); @@ -348,8 +354,14 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#ifdef USE_INT4_KV_CACHE + 
BuildKernelCreateInfo, + BuildKernelCreateInfo, +#endif BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h index 4fc0e7826b49c..fb237d8cb9e9a 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h @@ -103,7 +103,6 @@ struct WebgpuAttentionParameters { int num_splits_ = 0; // number of splits for splitkv int rotary_dim_ = 0; // rotary embedding dimension int local_window_size_ = 0; - bool kv_share_buffer_ = false; bool is_packed_qkv_ = false; bool is_subsequent_prompt_ = false; // indicates whether we have past context and seqlen > 1 bool is_first_prompt_ = false; // indicates whether this is first decoding step diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 2133bfed760e4..e6b0189b6ee53 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -207,7 +207,9 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& seqlen_k, total_seqlen_tensor, scale_, - softcap_)); + softcap_, + 0, + context.DeviceLimits().maxComputeInvocationsPerWorkgroup)); params.use_smooth_softmax = use_smooth_softmax_; params.rotary_interleaved = rotary_interleaved_; diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index deaa77a204a7f..6b314ed8714a5 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -235,7 +235,22 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte int past_key_index = -1, int use_max_past_present_buffer = -1, int output_qk_index = -1) { - ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); + // Type inference for outputs + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); // output + + if (ctx.getNumOutputs() >= 3) { // has present output + const auto* past_key_type = ctx.getInputType(past_key_index); + if (past_key_type != nullptr) { + // present_key and present_value have the same type as past_key/past_value. + // This allows them to be int8 or packed uint8 when quantization is enabled. + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, past_key_index, 1); // present_key + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, past_key_index + 1, 2); // present_value + } else { + // If no past state, present is the same type as query. 
+ ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 1); + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 2); + } + } int64_t kv_sequence_length = -1; if (hasInputShape(ctx, 0)) { @@ -280,12 +295,6 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte } if (ctx.getNumOutputs() >= 3) { // has present output - // copy the type from query to present key - ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 1); - - // copy the type from query to present value - ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 2); - int64_t total_sequence_length_value = 0; const auto* total_sequence_length_data = ctx.getInputData(6); if (total_sequence_length_data != nullptr) { @@ -342,7 +351,15 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte } if (output_qk_index >= 0) { - const bool did_supply_qk_buffer = ctx.hasOutput(output_qk_index); + // An output is considered "supplied" only if it's present AND has a meaningful type definition. + // An empty string placeholder for an optional output will not have a tensor type proto. + bool did_supply_qk_buffer = false; + if (ctx.hasOutput(output_qk_index)) { + // The output is considered "supplied" if it is present in the node. + // Note: TypeProto might not be fully populated yet during initial inference. + did_supply_qk_buffer = true; + } + const int64_t qk_output_type = getAttribute(ctx, "qk_output", static_cast(QKOutputType::NO_OUTPUT)); if (qk_output_type == static_cast(QKOutputType::NO_OUTPUT) && did_supply_qk_buffer) { @@ -1121,15 +1138,26 @@ ONNX_MS_OPERATOR_SET_SCHEMA( })); constexpr const char* GroupQueryAttention_ver1_doc = R"DOC( -Group Query Self/Cross Attention. +Group Query Self/Cross Attention with KV Cache Quantization Support. + +This operator implements causal grouped-query attention with past state (KV cache) support. +It also supports optional float8, int8 or int4 quantization for the KV cache to reduce memory footprint. + +**Cache Format:** +The past and present KV cache tensors are expected in a BNSH format: `(batch_size, num_heads, cache_sequence_length, head_size)`, where `cache_sequence_length` is the length of the cached key/value sequences, or the maximum sequence length when past and present buffer sharing is used. + +**Quantization:** +When quantization is enabled, `past_key` and `past_value` inputs can be of type `float8e4m3fn`, `uint8` or `int8`. The corresponding `k_scale` and `v_scale` tensors must be provided. +The operator will output `present_key` and `present_value` in same format as the `past_key` and `past_value`. -*Highly recommend using k-v cache share buffer for both CPU and CUDA. Enabled through IOBinding past and present kv. -Supports different number of heads for q and kv for CPU and CUDA. -Only supports causal and local attention. -Supports rotary position embedding for CPU and CUDA. -Supports packed input for CPU and CUDA. -Supports continuous decoding for batch_size == 1 for CPU and CUDA. +For 4-bit quantization, the data type is uint8 where each byte contains two 4-bit values. The bit width of quantized KV cache can be set using `kv_cache_bit_width` attribute. +The shapes of the k_scale, v_scale tensors shall be broadcastable to present_key shape. + +**Quantization Modes (`k_quant_type`, `v_quant_type` attributes):** +- **"NONE"**: No quantization. +- **"PER_TENSOR"**: A single scale for the entire tensor. Scale example shape: `[1]`. +- **"PER_CHANNEL"**: A scale for each channel. 
Scale example shape: `[1, num_heads_k, 1, head_size]`. )DOC"; ONNX_MS_OPERATOR_SET_SCHEMA( @@ -1166,6 +1194,9 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Output values of QK matrix multiplication before (1) or after (2) softmax normalization. Default value is 0 (don't output).", AttributeProto::INT, static_cast(QKOutputType::NO_OUTPUT)) + .Attr("k_quant_type", "Quantization type for K cache. One of 'NONE', 'PER_TENSOR', 'PER_CHANNEL'.", AttributeProto::STRING, std::string("NONE")) + .Attr("v_quant_type", "Quantization type for V cache. One of 'NONE', 'PER_TENSOR', 'PER_CHANNEL'.", AttributeProto::STRING, std::string("NONE")) + .Attr("kv_cache_bit_width", "Bit width of quantized KV cache. Supported values are 8 and 4.", AttributeProto::INT, OPTIONAL_VALUE) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape" @@ -1185,13 +1216,13 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "past_key", "past state key with support for format BNSH. When past_key uses same tensor as present_key" "(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.", - "T", + "T_CACHE", OpSchema::Optional) .Input(4, "past_value", "past state value with support for format BNSH. When past_value uses same tensor as present_value" "(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.", - "T", + "T_CACHE", OpSchema::Optional) .Input(5, "seqlens_k", @@ -1228,6 +1259,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.", "T", OpSchema::Optional) + .Input(12, "k_scale", "Scale tensor for past_key.", "T_KV_SCALE", OpSchema::Optional) + .Input(13, "v_scale", "Scale tensor for past_value.", "T_KV_SCALE", OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", @@ -1237,22 +1270,28 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "present state key with support for format BNSH. When past_key uses same tensor as present_key" "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", - "T") + "T_CACHE") .Output(2, "present_value", "present state value with support for format BNSH. When past_value uses same tensor as present_value" "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", - "T") + "T_CACHE") .Output(3, "output_qk", "Values of QK matrix multiplication, either before or after softmax normalization", "T", OpSchema::Optional) .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, "Constrain input and output to float tensors.") + .TypeConstraint("T_CACHE", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)", "tensor(uint8)", "tensor(int8)", "tensor(float8e4m3fn)"}, "Constrain KV cache types.") + .TypeConstraint("T_KV_SCALE", {"tensor(float)"}, "Constrain KV cache scale types.") .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to int tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { - GroupQueryAttentionTypeAndShapeInference(ctx, 3, 3); + // The 'output_qk' is an optional output at index 3. + // Pass its index to the shape inference logic only if the node instance actually has more than 3 outputs. + // Otherwise, pass -1 to signal that the optional output is not present and validation should be skipped. + int qk_output_index = ctx.getNumOutputs() > 3 ? 
3 : -1; + GroupQueryAttentionTypeAndShapeInference(ctx, 3, qk_output_index); })); constexpr const char* PagedAttention_ver1_doc = R"DOC( diff --git a/onnxruntime/test/python/transformers/benchmark_gqa.py b/onnxruntime/test/python/transformers/benchmark_gqa.py index 8126c61977db1..c44d6b606d3a2 100644 --- a/onnxruntime/test/python/transformers/benchmark_gqa.py +++ b/onnxruntime/test/python/transformers/benchmark_gqa.py @@ -7,11 +7,24 @@ Benchmark performance of GroupQueryAttention. """ +from dataclasses import dataclass + import torch -from test_sparse_attention import GroupQueryAttentionConfig, OrtGroupQueryAttention +from gqa_test_helper import GroupQueryAttentionConfig, OrtGroupQueryAttention + +try: + import triton +except ImportError: + triton = None + +@dataclass +class TestConfig: + test_int4: bool = False + test_int8: bool = False -def get_plot_algos(sm: int, local_window_size: int | None): + +def get_plot_algos(sm: int, local_window_size: int | None, config: TestConfig | None): # GQA with local windows only works in sm=8x if sm >= 80 and local_window_size: line_vals = ["ort_gqa", "ort_gqa_local", "ort_gqa_packed", "ort_gqa_local_packed"] @@ -22,6 +35,20 @@ def get_plot_algos(sm: int, local_window_size: int | None): line_names = ["ORT-GQA-Dense", "ORT-GQA-Dense-PackedQKV"] styles = [("red", "solid"), ("blue", "dashed")] + # Add quantized variants if requested + if sm >= 80 and config: + quant_vals = ["ort_gqa_int4", "ort_gqa_int8"] + quant_names = ["ORT-GQA-INT4", "ORT-GQA-INT8"] + quant_styles = [("purple", "dotted"), ("orange", "dashdot")] + if config.test_int4: + line_vals.extend(quant_vals[:1]) + line_names.extend(quant_names[:1]) + styles.extend(quant_styles[:1]) + if config.test_int8: + line_vals.extend(quant_vals[1:]) + line_names.extend(quant_names[1:]) + styles.extend(quant_styles[1:]) + return { "line_vals": line_vals, "line_names": line_names, @@ -39,15 +66,14 @@ def plot_prompt_performance( max_seq_len: int, local_window_size: int | None = None, use_smooth_softmax: bool = False, + config: TestConfig | None = None, dtype: str = "float16", ): - import triton # noqa: PLC0415 - - algos = get_plot_algos(sm, local_window_size) + algos = get_plot_algos(sm, local_window_size, config) configs = [ triton.testing.Benchmark( x_names=["sequence_length"], - x_vals=[2**i for i in range(4, 17) if 2**i <= max_seq_len], + x_vals=[2**i for i in range(6, 17) if 2**i <= max_seq_len], line_arg="provider", ylabel="ms", **algos, @@ -80,6 +106,17 @@ def benchmark( warmup = 15 repeat = 100 + # Determine quantization settings based on provider + k_quant_type = "NONE" + v_quant_type = "NONE" + kv_cache_type = "float16" if dtype == "float16" else "bfloat16" + if "_int4" in provider: + k_quant_type = v_quant_type = "PER_CHANNEL" + kv_cache_type = "int4" + elif "_int8" in provider: + k_quant_type = v_quant_type = "PER_TENSOR" + kv_cache_type = "int8" + config: GroupQueryAttentionConfig = GroupQueryAttentionConfig( batch_size=batch_size, sequence_length=sequence_length, @@ -93,6 +130,9 @@ def benchmark( device=device, dtype=torch.float16 if dtype == "float16" else torch.bfloat16, is_packed_qkv=provider in ["ort_gqa_packed", "ort_gqa_local_packed"], + k_quant_type=k_quant_type, + v_quant_type=v_quant_type, + kv_cache_type=kv_cache_type, ) obj = OrtGroupQueryAttention(config) @@ -113,15 +153,14 @@ def plot_token_performance( max_seq_len: int, local_window_size: int | None = None, use_smooth_softmax: bool = False, + config: TestConfig | None = None, dtype: str = "float16", ): - import triton # noqa: 
PLC0415 - - algos = get_plot_algos(sm, local_window_size) + algos = get_plot_algos(sm, local_window_size, config) configs = [ triton.testing.Benchmark( x_names=["past_sequence_length"], - x_vals=[2**i for i in range(4, 17) if 2**i < max_seq_len] + [max_seq_len - 1], + x_vals=[2**i for i in range(6, 17) if 2**i < max_seq_len] + [max_seq_len - 1], line_arg="provider", ylabel="ms", **algos, @@ -154,6 +193,19 @@ def benchmark( warmup = 15 repeat = 100 + # Determine quantization settings based on provider + k_quant_type = "NONE" + v_quant_type = "NONE" + kv_cache_type = "float16" if dtype == "float16" else "bfloat16" + share_kv_scale = False + if "_int4" in provider: + k_quant_type = v_quant_type = "PER_CHANNEL" + kv_cache_type = "int4" + elif "_int8" in provider: + k_quant_type = v_quant_type = "PER_TENSOR" + kv_cache_type = "int8" + share_kv_scale = True # XQA requires shared scale + config: GroupQueryAttentionConfig = GroupQueryAttentionConfig( batch_size=batch_size, sequence_length=1, @@ -168,6 +220,10 @@ def benchmark( use_smooth_softmax=use_smooth_softmax, device=device, dtype=torch.float16 if dtype == "float16" else torch.bfloat16, + k_quant_type=k_quant_type, + v_quant_type=v_quant_type, + kv_cache_type=kv_cache_type, + share_kv_scale=share_kv_scale, ) obj = OrtGroupQueryAttention(config) @@ -178,7 +234,9 @@ def benchmark( benchmark.run(save_path=".", print_data=True) -def run_performance_test(sm: int, fast: bool = False): +def run_performance_test( + sm: int, fast: bool = False, config: TestConfig | None = None, dtype: str = "float16", is_prompt: bool = True +): """ Run performance tests for prompt and token generation. @@ -202,7 +260,6 @@ def run_performance_test(sm: int, fast: bool = False): configures = configures[:1] batch_sizes = [1] if fast else [1, 4] smooth_softmax_options = [False] if fast else [False, True] - dtypes = ["float16", "bfloat16"] # Reduce max sequence length when GPU memory is not enough. 
threshold = 131072 if memory_in_gb > 24 else 65536 if memory_in_gb > 12 else 32768 @@ -210,7 +267,7 @@ def run_performance_test(sm: int, fast: bool = False): for num_heads, head_size, kv_num_heads, max_seq_len, local_window_size, model_name in configures: for batch_size in batch_sizes: for use_smooth_softmax in smooth_softmax_options: - for dtype in dtypes: + if is_prompt: plot_prompt_performance( sm=sm, batch_size=batch_size, @@ -221,8 +278,10 @@ def run_performance_test(sm: int, fast: bool = False): local_window_size=local_window_size, use_smooth_softmax=use_smooth_softmax, model_name=model_name, + config=config, dtype=dtype, ) + else: plot_token_performance( sm=sm, batch_size=batch_size, @@ -233,6 +292,7 @@ def run_performance_test(sm: int, fast: bool = False): local_window_size=local_window_size, use_smooth_softmax=use_smooth_softmax, model_name=model_name, + config=config, dtype=dtype, ) @@ -243,4 +303,8 @@ def run_performance_test(sm: int, fast: bool = False): s = torch.cuda.Stream() with torch.cuda.stream(s), torch.no_grad(): - run_performance_test(sm, fast=True) + config = TestConfig(test_int4=False, test_int8=True) + run_performance_test(sm, fast=True, config=config, dtype="float16", is_prompt=True) + run_performance_test(sm, fast=True, config=config, dtype="float16", is_prompt=False) + # run_performance_test(sm, fast=True, config=config, dtype="bfloat16", is_prompt=True) + # run_performance_test(sm, fast=True, config=config, dtype="bfloat16", is_prompt=False) diff --git a/onnxruntime/test/python/transformers/benchmark_gqa_windows.py b/onnxruntime/test/python/transformers/benchmark_gqa_windows.py index 46523672669b4..b92f255301dde 100644 --- a/onnxruntime/test/python/transformers/benchmark_gqa_windows.py +++ b/onnxruntime/test/python/transformers/benchmark_gqa_windows.py @@ -3,7 +3,7 @@ import time import torch -from test_sparse_attention import GroupQueryAttentionConfig, OrtGroupQueryAttention +from gqa_test_helper import GroupQueryAttentionConfig, OrtGroupQueryAttention def save_results(results, filename): diff --git a/onnxruntime/test/python/transformers/gqa_test_helper.py b/onnxruntime/test/python/transformers/gqa_test_helper.py new file mode 100644 index 0000000000000..cd34f4f420ad5 --- /dev/null +++ b/onnxruntime/test/python/transformers/gqa_test_helper.py @@ -0,0 +1,562 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +import math + +import numpy +import torch +from onnx import TensorProto, helper + +from onnxruntime import InferenceSession +from onnxruntime.transformers.io_binding_helper import CudaSession + +# --- Quantization Helpers (from test_gqa.py) --- + +ONNX_TENSOR_TYPE_MAP = { + "float32": TensorProto.FLOAT, + "float16": TensorProto.FLOAT16, + "bfloat16": TensorProto.BFLOAT16, + "int32": TensorProto.INT32, + "int8": TensorProto.INT8, + "int4": TensorProto.UINT8, +} + +TORCH_DTYPE_TO_ONNX_MAP = { + torch.float32: TensorProto.FLOAT, + torch.float16: TensorProto.FLOAT16, + torch.bfloat16: TensorProto.BFLOAT16, + torch.int32: TensorProto.INT32, + torch.int8: TensorProto.INT8, +} + +TORCH_DTYPE_MAP = { + "float32": torch.float32, + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "int8": torch.int8, + "int4": torch.uint8, +} + +NUMPY_DTYPE_MAP = { + "float32": numpy.float32, + "float16": numpy.float16, + "bfloat16": numpy.uint16, + "int8": numpy.int8, + "int4": numpy.uint8, +} + + +def get_q_range(q_type_str): + q_type_str = str(q_type_str) + if q_type_str.endswith("int8"): + return -128, 127 + if q_type_str.endswith("int4"): + return -8, 7 + raise ValueError(f"Unsupported quantization type for range: {q_type_str}") + + +def pack_int4(tensor_int8): + assert tensor_int8.shape[-1] % 2 == 0 + t_low = tensor_int8[..., 0::2] + 8 + t_high = tensor_int8[..., 1::2] + 8 + packed = (t_low & 0x0F) | (t_high << 4) + return packed.to(torch.uint8) + + +def unpack_int4(packed_tensor_uint8): + t_low = (packed_tensor_uint8 & 0x0F) - 8 + t_high = (packed_tensor_uint8 >> 4) - 8 + unpacked = torch.empty( + (*packed_tensor_uint8.shape[:-1], packed_tensor_uint8.shape[-1] * 2), + dtype=torch.int8, + device=packed_tensor_uint8.device, + ) + unpacked[..., 0::2] = t_low + unpacked[..., 1::2] = t_high + return unpacked + + +def compute_scale(tensor_float, quant_type, q_type_str): + if quant_type == "NONE": + return None + + qmin, qmax = get_q_range(q_type_str) + + if quant_type == "PER_TENSOR": + t_max = torch.max(torch.abs(tensor_float)) + scale = t_max / qmax if t_max > 1e-6 else torch.tensor(1.0, device=tensor_float.device, dtype=torch.float32) + return scale.unsqueeze(0).to(torch.float32) + + if quant_type == "PER_CHANNEL": + # Per-channel scale is computed independently for each channel across the batch and sequence length dimensions. 
+ t_max = torch.max(torch.abs(tensor_float), dim=2, keepdim=True)[0] + t_max = torch.max(t_max, dim=0, keepdim=True)[0] + scale = t_max / qmax + scale[scale < 1e-6] = 1.0 + return scale.to(torch.float32) + + raise ValueError(f"Unsupported quant_type: {quant_type}") + + +def dequantize_tensor(quantized_tensor, scale, quant_type, q_type_str): + if quant_type == "NONE": + return quantized_tensor + + # Ensure scale is on the same device as quantized_tensor + if isinstance(scale, torch.Tensor): + scale = scale.to(quantized_tensor.device) + + unpacked_tensor = quantized_tensor + q_type_str_s = str(q_type_str) + if q_type_str_s.endswith("int4"): + unpacked_tensor = unpack_int4(quantized_tensor) + + return unpacked_tensor.to(torch.float32) * scale + + +def quantize_tensor_with_scale(tensor_float, scale, quant_type, q_type_str): + """Quantizes a tensor using a provided scale.""" + if quant_type == "NONE": + return tensor_float + + qmin, qmax = get_q_range(q_type_str) + quantized = torch.clamp(torch.round(tensor_float / scale), qmin, qmax) + + q_type_str_s = str(q_type_str) + if q_type_str_s.endswith("int4"): + quantized = pack_int4(quantized.to(torch.int8)) + else: + target_dtype = TORCH_DTYPE_MAP[q_type_str] + quantized = quantized.to(target_dtype) + return quantized + + +# --- Classes moved from test_sparse_attention.py --- + + +class AttentionConfig: + def __init__( + self, + operator: str, + batch_size: int, + sequence_length: int, + max_sequence_length: int, + past_sequence_length: int, + num_heads: int, + kv_num_heads: int, + head_size: int, + softmax_scale: float | None, + do_rotary: bool, + rotary_interleaved: bool, + provider: str = "CUDAExecutionProvider", + device="cuda", + dtype=torch.float16, + share_buffer: bool = True, + is_packed_qkv: bool = False, + max_cache_sequence_length=None, + max_rotary_sequence_length=None, + use_smooth_softmax: bool = False, + ): + self.operator = operator + self.batch_size = batch_size + self.sequence_length = sequence_length + self.max_sequence_length = max_sequence_length + self.max_cache_sequence_length = max_cache_sequence_length or max_sequence_length + self.max_rotary_sequence_length = max_rotary_sequence_length or max_sequence_length + self.past_sequence_length = past_sequence_length + self.num_heads = num_heads + self.kv_num_heads = kv_num_heads + self.head_size = head_size + self.softmax_scale = softmax_scale or (1.0 / (head_size**0.5)) + + # Derived values + self.total_sequence_length = sequence_length + past_sequence_length + self.past_buffer_length = self.max_cache_sequence_length if share_buffer else past_sequence_length + self.present_buffer_length = ( + self.max_cache_sequence_length if share_buffer else (past_sequence_length + sequence_length) + ) + + self.do_rotary = do_rotary + self.rotary_interleaved = rotary_interleaved + + self.provider = provider + self.device = device + self.dtype = dtype + + self.share_buffer = share_buffer + self.is_packed_qkv = is_packed_qkv + + self.use_smooth_softmax = use_smooth_softmax + + def shape_dict(self): + shapes = { + "query": ( + self.batch_size, + self.sequence_length, + (self.num_heads + 2 * self.kv_num_heads) * self.head_size, + ), + "past_key": (self.batch_size, self.kv_num_heads, self.past_buffer_length, self.head_size), + "past_value": (self.batch_size, self.kv_num_heads, self.past_buffer_length, self.head_size), + "total_sequence_length": (1,), + "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "present_key": (self.batch_size, self.kv_num_heads, 
self.present_buffer_length, self.head_size), + "present_value": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size), + "cos_cache": (self.max_rotary_sequence_length, (math.floor(self.head_size / 16) * 16) // 2), + "sin_cache": (self.max_rotary_sequence_length, (math.floor(self.head_size / 16) * 16) // 2), + } + + if not self.is_packed_qkv: + shapes.update( + { + "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), + "key": (self.batch_size, self.sequence_length, self.kv_num_heads * self.head_size), + "value": (self.batch_size, self.sequence_length, self.kv_num_heads * self.head_size), + } + ) + return shapes + + def get_cos_sin_cache(self, dtype): + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * self.head_size) / 16) * 16 + angle = torch.rand(self.max_rotary_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi + cos = torch.cos(angle).to(dtype=dtype) + sin = torch.sin(angle).to(dtype=dtype) + return cos.to(device=self.device), sin.to(device=self.device) + + def random_inputs(self): + device = self.device + # ORT python I/O binding API supports bf16 via torch tensor. + dtype = self.dtype + + # Always use non-packed qkv to generate same inputs for Torch and ORT. + packed = self.is_packed_qkv # Save the original value. + self.is_packed_qkv = False + shape_dict = self.shape_dict() + self.is_packed_qkv = packed # Restore the original value. + torch.manual_seed(123) + + feeds = { + "query": torch.empty(shape_dict["query"], device=device, dtype=dtype).normal_(mean=0, std=0.1), + "key": torch.empty(shape_dict["key"], device=device, dtype=dtype).normal_(mean=0, std=0.1), + "value": torch.empty(shape_dict["value"], device=device, dtype=dtype).normal_(mean=0, std=0.1), + "past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_(mean=0, std=0.1), + "past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_(mean=0, std=0.1), + "total_sequence_length": torch.tensor([self.total_sequence_length], dtype=torch.int32), + } + + if packed: + query = feeds["query"].view(self.batch_size, self.sequence_length, self.num_heads, self.head_size) + key = feeds["key"].view(self.batch_size, self.sequence_length, self.kv_num_heads, self.head_size) + value = feeds["value"].view(self.batch_size, self.sequence_length, self.kv_num_heads, self.head_size) + feeds["query"] = torch.dstack((query, key, value)).reshape(self.batch_size, self.sequence_length, -1) + del feeds["key"] + del feeds["value"] + + if self.do_rotary: + cos_cache, sin_cache = self.get_cos_sin_cache(dtype) + feeds["cos_cache"] = cos_cache + feeds["sin_cache"] = sin_cache + + return feeds + + +class GroupQueryAttentionConfig(AttentionConfig): + def __init__( + self, + batch_size: int, + sequence_length: int, + max_sequence_length: int, + past_sequence_length: int, + num_heads: int, + kv_num_heads: int, + head_size: int, + softmax_scale=None, + do_rotary: bool = False, + rotary_interleaved: bool = False, + provider: str = "CUDAExecutionProvider", + device="cuda", + dtype=torch.float16, + local_window_size: int = -1, + attention_mask=None, + is_packed_qkv=False, + max_cache_sequence_length=None, + max_rotary_sequence_length=None, + use_smooth_softmax: bool = False, + k_quant_type: str = "NONE", + v_quant_type: str = "NONE", + kv_cache_type: str = "float16", + share_kv_scale: bool = False, + ): + super().__init__( + "GroupQueryAttention", + batch_size=batch_size, + sequence_length=sequence_length, + 
max_sequence_length=max_sequence_length, + past_sequence_length=past_sequence_length, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + softmax_scale=softmax_scale, + do_rotary=do_rotary, + rotary_interleaved=rotary_interleaved, + provider=provider, + device=device, + dtype=dtype, + is_packed_qkv=is_packed_qkv, + max_cache_sequence_length=max_cache_sequence_length, + max_rotary_sequence_length=max_rotary_sequence_length, + use_smooth_softmax=use_smooth_softmax, + ) + # local_window_size is for ORT only, not for Torch implementation. + self.local_window_size = local_window_size + + # attention mask is for Torch implementation only, not for ORT. + self.attention_mask = attention_mask + + # Quantization parameters + self.k_quant_type = k_quant_type + self.v_quant_type = v_quant_type + self.kv_cache_type = kv_cache_type + # Determine bit width from cache type if applicable + self.kv_cache_bit_width = 4 if kv_cache_type == "int4" else (8 if kv_cache_type == "int8" else 0) + self.share_kv_scale = share_kv_scale + + def shape_dict(self): + shapes = super().shape_dict() + shapes.update( + { + "seqlens_k": (self.batch_size,), + } + ) + # Note: We don't adjust shapes for int4 here because the parent's random_inputs + # creates float tensors first, then quantization will pack them + return shapes + + def random_inputs(self): + feeds = super().random_inputs() + k_seqlens = torch.ones((self.batch_size,), device=self.device, dtype=torch.int32) * self.total_sequence_length + feeds.update( + { + "seqlens_k": k_seqlens - 1, + } + ) + + # Generate quantized cache and scales if quantization is enabled + if self.k_quant_type != "NONE": + # Compute scales from the generated float cache + k_scale = compute_scale(feeds["past_key"], self.k_quant_type, self.kv_cache_type) + if self.share_kv_scale: + v_scale = k_scale + else: + v_scale = compute_scale(feeds["past_value"], self.v_quant_type, self.kv_cache_type) + + # Scale tensors must be float32 (required by GQA operator) + if k_scale is not None: + k_scale = k_scale.to(torch.float32) + feeds["k_scale"] = k_scale + if v_scale is not None: + v_scale = v_scale.to(torch.float32) + feeds["v_scale"] = v_scale + + # Quantize the cache tensors + feeds["past_key"] = quantize_tensor_with_scale( + feeds["past_key"], k_scale, self.k_quant_type, self.kv_cache_type + ) + feeds["past_value"] = quantize_tensor_with_scale( + feeds["past_value"], v_scale, self.v_quant_type, self.kv_cache_type + ) + + return feeds + + +def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): + assert config.dtype in [torch.float16, torch.float32, torch.bfloat16] + + if config.dtype == torch.float16: + float_type = TensorProto.FLOAT16 + elif config.dtype == torch.bfloat16: + float_type = TensorProto.BFLOAT16 + else: + float_type = TensorProto.FLOAT + + # Build input list for the GQA node + node_inputs = [ + "query", + "key" if not config.is_packed_qkv else "", + "value" if not config.is_packed_qkv else "", + "past_key", + "past_value", + "seqlens_k", + "total_sequence_length" if config.share_buffer else "", + "cos_cache" if config.do_rotary else "", + "sin_cache" if config.do_rotary else "", + "", # position_ids (optional, not used in benchmark) + "", # attention_bias (optional, not used in benchmark) + "", # head_sink (optional, not used in benchmark) + "k_scale" if config.k_quant_type != "NONE" else "", + "v_scale" if config.v_quant_type != "NONE" else "", + ] + # Remove trailing empty strings + while node_inputs and node_inputs[-1] == "": + 
node_inputs.pop() + + # Build attributes dictionary + node_attrs = { + "num_heads": config.num_heads, + "kv_num_heads": config.kv_num_heads, + "scale": config.softmax_scale, + "local_window_size": config.local_window_size, + "do_rotary": 1 if config.do_rotary else 0, + "rotary_interleaved": config.rotary_interleaved, + "smooth_softmax": 1 if config.use_smooth_softmax else 0, + "domain": "com.microsoft", + } + + # Add quantization attributes if enabled + if config.k_quant_type != "NONE": + node_attrs["k_quant_type"] = config.k_quant_type + node_attrs["v_quant_type"] = config.v_quant_type + node_attrs["kv_cache_bit_width"] = config.kv_cache_bit_width + + nodes = [ + helper.make_node( + "GroupQueryAttention", + node_inputs, + ["output", "present_key", "present_value"], + "GroupQueryAttention_0", + **node_attrs, + ), + ] + + shape_dict = config.shape_dict() + graph_input = [ + helper.make_tensor_value_info("query", float_type, list(shape_dict["query"])), + ] + + if not config.is_packed_qkv: + graph_input.extend( + [ + helper.make_tensor_value_info("key", float_type, list(shape_dict["key"])), + helper.make_tensor_value_info("value", float_type, list(shape_dict["value"])), + ] + ) + + # Determine cache tensor type based on quantization + # Note: INT8 uses INT8 type, INT4 uses UINT8 (for packing 2x4-bit values per byte) + cache_type = float_type + if config.kv_cache_type == "int4": + cache_type = TensorProto.UINT8 + elif config.kv_cache_type == "int8": + cache_type = TensorProto.INT8 + + # Compute actual cache shapes (packed for INT4) + past_key_shape = list(shape_dict["past_key"]) + past_value_shape = list(shape_dict["past_value"]) + present_key_shape = list(shape_dict["present_key"]) + present_value_shape = list(shape_dict["present_value"]) + + # For INT4, the last dimension is packed (2 values per byte) + if config.kv_cache_type == "int4": + past_key_shape[-1] = past_key_shape[-1] // 2 + past_value_shape[-1] = past_value_shape[-1] // 2 + present_key_shape[-1] = present_key_shape[-1] // 2 + present_value_shape[-1] = present_value_shape[-1] // 2 + + graph_input.extend( + [ + helper.make_tensor_value_info("past_key", cache_type, past_key_shape), + helper.make_tensor_value_info("past_value", cache_type, past_value_shape), + helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, list(shape_dict["seqlens_k"])), + helper.make_tensor_value_info( + "total_sequence_length", TensorProto.INT32, list(shape_dict["total_sequence_length"]) + ), + ] + ) + + if config.do_rotary: + graph_input += [ + helper.make_tensor_value_info("cos_cache", float_type, list(shape_dict["cos_cache"])), + helper.make_tensor_value_info("sin_cache", float_type, list(shape_dict["sin_cache"])), + ] + + # Add scale inputs for quantization + # Shape depends on quantization type: + # - PER_TENSOR: [1] + # - PER_CHANNEL: [1, kv_num_heads, 1, head_size] + # Note: k_scale and v_scale are always float32 regardless of the model's dtype + if config.k_quant_type != "NONE": + if config.k_quant_type == "PER_TENSOR": + k_scale_shape = [1] + else: # PER_CHANNEL + k_scale_shape = [1, config.kv_num_heads, 1, config.head_size] + graph_input.append(helper.make_tensor_value_info("k_scale", TensorProto.FLOAT, k_scale_shape)) + + if config.v_quant_type != "NONE": + if config.v_quant_type == "PER_TENSOR": + v_scale_shape = [1] + else: # PER_CHANNEL + v_scale_shape = [1, config.kv_num_heads, 1, config.head_size] + graph_input.append(helper.make_tensor_value_info("v_scale", TensorProto.FLOAT, v_scale_shape)) + + graph_output = [ + 
helper.make_tensor_value_info("output", float_type, list(shape_dict["output"])), + helper.make_tensor_value_info("present_key", cache_type, present_key_shape), + helper.make_tensor_value_info("present_value", cache_type, present_value_shape), + ] + + graph = helper.make_graph( + nodes, + "GroupQueryAttention_Graph", + graph_input, + graph_output, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def create_gqa_ort_session( + config: GroupQueryAttentionConfig, session_options=None, enable_cuda_graph=False +) -> CudaSession: + onnx_model_str = create_group_query_attention_onnx_model(config) + + if config.provider == "CUDAExecutionProvider": + device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index + provider_options = CudaSession.get_cuda_provider_options( + device_id, enable_cuda_graph=enable_cuda_graph, stream=torch.cuda.current_stream().cuda_stream + ) + providers = [(config.provider, provider_options), "CPUExecutionProvider"] + else: + providers = ["CPUExecutionProvider"] + + ort_session = InferenceSession(onnx_model_str, session_options, providers=providers) + # Note that CudaSession could work with both CUDA and CPU providers. + cuda_session = CudaSession(ort_session, config.device, enable_cuda_graph=enable_cuda_graph) + shape_dict = config.shape_dict() + cuda_session.allocate_buffers(shape_dict) + + buffer_sharing = {"past_key": "present_key", "past_value": "present_value"} + for input_name, output_name in buffer_sharing.items(): + cuda_session.set_buffer_sharing(input_name, output_name) + + return cuda_session + + +class OrtGroupQueryAttention: + """A wrapper of ORT GroupQueryAttention to test relevance and performance.""" + + def __init__(self, config: GroupQueryAttentionConfig): + self.session = create_gqa_ort_session(config) + + self.feed_dict = config.random_inputs() + + # ENABLE_DEBUG is not defined in this module, so we assume False or pass it as arg if needed. + # But looking at original code, it was a global. Since this is a helper, we might skip the debug print or make it optional. + # For strict refactoring, I'll remove the debug print block or comment it out unless I import ENABLE_DEBUG. + # I'll check if ENABLE_DEBUG was used in the class. It was. + # I'll skip it for now to avoid dependency on global var. + + def infer(self): + return self.session.infer(self.feed_dict) diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index e800c22f92efb..5cbba989a4dbd 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -9,25 +9,33 @@ # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 # ------------------------------------------------------------------------- +import gc import math import os import platform import random import unittest +from copy import deepcopy from dataclasses import dataclass import numpy import torch from einops import rearrange, repeat + +# --- ONNX and Torch/Numpy Dtype Mappings --- +from gqa_test_helper import ( + ONNX_TENSOR_TYPE_MAP, + TORCH_DTYPE_MAP, + compute_scale, + dequantize_tensor, + quantize_tensor_with_scale, +) from onnx import TensorProto, helper +from packaging import version from parameterized import parameterized -from onnxruntime import ( - InferenceSession, - SessionOptions, - get_available_providers, - get_build_info, -) +from onnxruntime import InferenceSession, SessionOptions, get_available_providers, get_build_info +from onnxruntime import __version__ as ort_version # Set seed for reproducibility torch.manual_seed(0) @@ -47,51 +55,16 @@ # When quick build is used, flash attention only supports head_size=128 quick_build = ", quick-build=" in get_build_info() -enable_debug_print = quick_build +has_int4_kv_cache = ", int4-kv-cache=" in get_build_info() -enable_deterministic_check = True +enable_debug_print = False -enable_quantized_kv_tests = True +enable_deterministic_check = True # ################################################################################################# # Configuration and Helper Classes # ################################################################################################# -# --- ONNX and Torch/Numpy Dtype Mappings --- -ONNX_TENSOR_TYPE_MAP = { - "float32": TensorProto.FLOAT, - "float16": TensorProto.FLOAT16, - "bfloat16": TensorProto.BFLOAT16, - "int32": TensorProto.INT32, - "int8": TensorProto.INT8, - "int4": TensorProto.UINT8, -} - -TORCH_DTYPE_TO_ONNX_MAP = { - torch.float32: TensorProto.FLOAT, - torch.float16: TensorProto.FLOAT16, - torch.bfloat16: TensorProto.BFLOAT16, - torch.int32: TensorProto.INT32, - torch.int8: TensorProto.INT8, -} - -TORCH_DTYPE_MAP = { - "float32": torch.float32, - "float16": torch.float16, - "bfloat16": torch.bfloat16, - "int8": torch.int8, - "int4": torch.uint8, -} - -NUMPY_DTYPE_MAP = { - "float32": numpy.float32, - "float16": numpy.float16, - "bfloat16": numpy.uint16, - "int8": numpy.int8, - "int4": numpy.uint8, -} - - @dataclass class GQAConfig: batch_size: int @@ -112,10 +85,16 @@ class GQAConfig: has_head_sink: bool = False kv_cache_type: str = "" share_buffer: bool = True + share_kv_scale: bool = False has_position_ids: bool = False has_attention_bias: bool = False + # Quantization parameters + k_quant_type: str = "NONE" + v_quant_type: str = "NONE" + kv_cache_bit_width: int = 0 + # ################################################################################################# # Rotary Embedding Implementations (CPU and CUDA) @@ -238,12 +217,20 @@ def create_gqa_node_and_io( if output_qk > 0: outputs.append("output_qk") + # Ensure kv_cache_bit_width is set correctly based on cache type if not provided + bit_width = config.kv_cache_bit_width + if bit_width == 0: + if config.kv_cache_type == "int4": + bit_width = 4 + elif config.kv_cache_type == "int8": + bit_width = 8 + inputs = [ "query", "key" if not config.packed else "", "value" if not config.packed else "", - "past_key" if is_past or share_buffer else "", - "past_value" if is_past or share_buffer else "", + "past_key" if is_past or share_buffer or config.k_quant_type != "NONE" else "", + "past_value" if 
is_past or share_buffer or config.k_quant_type != "NONE" else "", "seqlens_k", "total_sequence_length", "cos_cache" if config.rotary else "", @@ -251,12 +238,26 @@ def create_gqa_node_and_io( "position_ids" if config.has_position_ids else "", "attention_bias" if config.has_attention_bias else "", "head_sink" if config.has_head_sink else "", + "k_scale" if config.k_quant_type != "NONE" else "", + "k_scale" + if config.share_kv_scale and config.k_quant_type != "NONE" + else ("v_scale" if config.v_quant_type != "NONE" else ""), ] # Remove trailing empty strings while inputs and inputs[-1] == "": inputs.pop() + quantization_attributes = ( + { + "k_quant_type": config.k_quant_type, + "v_quant_type": config.v_quant_type, + "kv_cache_bit_width": bit_width, + } + if config.k_quant_type != "NONE" + else {} + ) + node = helper.make_node( op_type="GroupQueryAttention", inputs=inputs, @@ -270,6 +271,7 @@ def create_gqa_node_and_io( softcap=config.softcap, smooth_softmax=1 if config.use_smooth_softmax else 0, qk_output=output_qk, + **quantization_attributes, domain="com.microsoft", ) @@ -284,11 +286,7 @@ def create_gqa_node_and_io( helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, [config.batch_size]), helper.make_tensor_value_info("total_sequence_length", TensorProto.INT32, [1]), ] - - if isinstance(config.kv_cache_type, torch.dtype): - cache_ort_type = TORCH_DTYPE_TO_ONNX_MAP[config.kv_cache_type] - else: - cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] + cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] if not config.packed: graph_input.extend( @@ -306,14 +304,22 @@ def create_gqa_node_and_io( ] ) - if is_past or share_buffer: + if is_past or share_buffer or config.k_quant_type != "NONE": k_shape = [config.batch_size, config.kv_num_heads, past_kv_seqlen, config.head_size] + if config.kv_cache_type == "int4": + k_shape[-1] //= 2 graph_input.extend( [ helper.make_tensor_value_info("past_key", cache_ort_type, k_shape), helper.make_tensor_value_info("past_value", cache_ort_type, k_shape), ] ) + if config.k_quant_type != "NONE": + # Scales are always float32 + graph_input.append(helper.make_tensor_value_info("k_scale", TensorProto.FLOAT, None)) + if config.v_quant_type != "NONE" and not config.share_kv_scale: + graph_input.append(helper.make_tensor_value_info("v_scale", TensorProto.FLOAT, None)) + if config.rotary: rotary_dim = (math.floor(config.head_size / 16) * 16) // 2 cache_seq_len = config.buffer_sequence_length @@ -342,6 +348,8 @@ def create_gqa_node_and_io( # --- Graph Outputs --- output_k_shape = [config.batch_size, config.kv_num_heads, present_kv_seqlen, config.head_size] + if config.kv_cache_type == "int4": + output_k_shape[-1] //= 2 graph_output = [ helper.make_tensor_value_info( @@ -423,6 +431,8 @@ def gqa_prompt_func( position_ids, attention_bias, head_sink, + k_scale, + v_scale, ep, device, share_buffer=True, @@ -457,15 +467,17 @@ def gqa_prompt_func( bind_tensor(io_binding, "key", new_k, device, ort_type) bind_tensor(io_binding, "value", new_v, device, ort_type) - # 3. Bind 'past_key', 'past_value' - if share_buffer: + # 3. 
Bind 'past_key', 'past_value' (if share_buffer or quantized) + if share_buffer or config.k_quant_type != "NONE": # cache_ort_type corresponds to config.kv_cache_type - if isinstance(config.kv_cache_type, torch.dtype): - cache_ort_type = TORCH_DTYPE_TO_ONNX_MAP[config.kv_cache_type] - else: + cache_ort_type = ort_type + if config.kv_cache_type: cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] - k_to_bind = k if share_buffer else k[:, :, :0, :] - v_to_bind = v if share_buffer else v[:, :, :0, :] + + # Use full buffer if sharing, otherwise empty tensor for prompt phase + k_to_bind = k if share_buffer else k[:, :, :0, :].contiguous() + v_to_bind = v if share_buffer else v[:, :, :0, :].contiguous() + bind_tensor(io_binding, "past_key", k_to_bind, device, cache_ort_type) bind_tensor(io_binding, "past_value", v_to_bind, device, cache_ort_type) @@ -494,7 +506,22 @@ def gqa_prompt_func( if config.has_head_sink and head_sink is not None: bind_tensor(io_binding, "head_sink", head_sink, device, ort_type) - # Bind Outputs + # 6. Quantization scales + if k_scale is not None: + k_scale_ort_type = TensorProto.FLOAT + if k_scale.dtype != torch.float32: + k_scale = k_scale.to(torch.float32) + k_scale = k_scale.contiguous() + bind_tensor(io_binding, "k_scale", k_scale, device, k_scale_ort_type) + if v_scale is not None: + v_scale_ort_type = TensorProto.FLOAT + if v_scale.dtype != torch.float32: + v_scale = v_scale.to(torch.float32) + v_scale = v_scale.contiguous() + if not config.share_kv_scale: + bind_tensor(io_binding, "v_scale", v_scale, device, v_scale_ort_type) + + # 7. Bind Outputs # output shape calculation hidden_size = config.num_heads * config.head_size @@ -517,19 +544,18 @@ def gqa_prompt_func( present_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] + # Update present shape when kv cache has quantization (int4 packs 2 values) + if config.kv_cache_bit_width == 4: + present_dims[-1] //= 2 + # Determine dtype for cache tensors cache_dtype = out_dtype cache_ort_type = ort_type - if isinstance(config.kv_cache_type, torch.dtype): - is_valid_type = config.kv_cache_type in TORCH_DTYPE_TO_ONNX_MAP - else: - is_valid_type = config.kv_cache_type in ONNX_TENSOR_TYPE_MAP + if config.kv_cache_type in ONNX_TENSOR_TYPE_MAP: + cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] - if is_valid_type: - if isinstance(config.kv_cache_type, torch.dtype): - cache_ort_type = TORCH_DTYPE_TO_ONNX_MAP[config.kv_cache_type] - else: - cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] + if config.kv_cache_type in TORCH_DTYPE_MAP: + cache_dtype = TORCH_DTYPE_MAP[config.kv_cache_type] if share_buffer: # We bind output to the input buffer 'k' / 'v' (in-place update) @@ -544,7 +570,9 @@ def gqa_prompt_func( bind_output_tensor(io_binding, "present_key", present_k, device, cache_ort_type) bind_output_tensor(io_binding, "present_value", present_v, device, cache_ort_type) + io_binding.synchronize_inputs() ort_session.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() return out_torch, present_k, present_v @@ -562,6 +590,8 @@ def gqa_past_func( position_ids, attention_bias, head_sink, + k_scale, + v_scale, ep, device, share_buffer=True, @@ -582,6 +612,7 @@ def gqa_past_func( new_v = torch.reshape(new_v, (config.batch_size, config.q_sequence_length, -1)) sess_options = SessionOptions() + # sess_options.log_severity_level = 0 ort_session = InferenceSession(onnx_model_str, sess_options, providers=[ep]) io_binding = ort_session.io_binding() @@ -600,10 +631,7 @@ 
def gqa_past_func( # 3. Bind 'past_key', 'past_value' # These are required inputs for past_func # cache_ort_type corresponds to config.kv_cache_type - if isinstance(config.kv_cache_type, torch.dtype): - cache_ort_type = TORCH_DTYPE_TO_ONNX_MAP[config.kv_cache_type] - else: - cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] + cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] if share_buffer: # If sharing buffer, we bind 'past_key' to the large buffer 'k' @@ -639,6 +667,25 @@ def gqa_past_func( if config.has_head_sink and head_sink is not None: bind_tensor(io_binding, "head_sink", head_sink, device, ort_type) + # 6. Quantization + if k_scale is not None: + k_scale_ort_type = TensorProto.FLOAT + if k_scale.dtype != torch.float32: + k_scale = k_scale.to(torch.float32) + k_scale = k_scale.contiguous() + bind_tensor(io_binding, "k_scale", k_scale, device, k_scale_ort_type) + if v_scale is not None: + v_scale_ort_type = TensorProto.FLOAT + if v_scale.dtype != torch.float32: + v_scale = v_scale.to(torch.float32) + v_scale = v_scale.contiguous() + # Even if share_kv_scale is True, the node might have two scale inputs named "k_scale" and "v_scale" + # depending on the graph creation logic. We should bind "v_scale" if it's expected by the graph. + # In create_gqa_node_and_io, if share_kv_scale is True, Input 13 is named "k_scale". + # But if it's False, it's named "v_scale". + if not config.share_kv_scale: + bind_tensor(io_binding, "v_scale", v_scale, device, v_scale_ort_type) + # 7. Outputs # output shape calculation hidden_size = config.num_heads * config.head_size @@ -662,19 +709,15 @@ def gqa_past_func( present_seqlen = total_seq_len # For past_func, total seq len is accumulated present_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] + if config.kv_cache_bit_width == 4: + present_dims[-1] //= 2 cache_dtype = out_dtype cache_ort_type = ort_type - if isinstance(config.kv_cache_type, torch.dtype): - is_valid_type = config.kv_cache_type in TORCH_DTYPE_TO_ONNX_MAP - else: - is_valid_type = config.kv_cache_type in ONNX_TENSOR_TYPE_MAP - - if is_valid_type: - if isinstance(config.kv_cache_type, torch.dtype): - cache_ort_type = TORCH_DTYPE_TO_ONNX_MAP[config.kv_cache_type] - else: - cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] + if config.kv_cache_type in ONNX_TENSOR_TYPE_MAP: + cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] + if config.kv_cache_type in TORCH_DTYPE_MAP: + cache_dtype = TORCH_DTYPE_MAP[config.kv_cache_type] if share_buffer: # In-place update to k/v buffers @@ -688,7 +731,9 @@ def gqa_past_func( bind_output_tensor(io_binding, "present_key", present_k, device, cache_ort_type) bind_output_tensor(io_binding, "present_value", present_v, device, cache_ort_type) + io_binding.synchronize_inputs() ort_session.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() return out_torch, present_k, present_v @@ -795,6 +840,30 @@ def attention_ref( # ################################################################################################# # Parity Check (Core Test Logic) # ################################################################################################# +def get_static_scale(config: GQAConfig, device, torch_type, std): + """Generates calibration data and computes the static quantization scale.""" + calibration_batch_size = 1 + calibration_sequence_length = 1024 + calibration_data_k = ( + torch.randn( + calibration_batch_size, + config.kv_num_heads, + calibration_sequence_length, + 
config.head_size, + device=device, + dtype=torch_type, + ) + * std + ) + calibration_data_v = torch.randn_like(calibration_data_k) * std + + # TODO: handle config.share_kv_scale here. + k_scale = compute_scale(calibration_data_k, config.k_quant_type, config.kv_cache_type) + if config.share_kv_scale: + v_scale = k_scale + else: + v_scale = compute_scale(calibration_data_v, config.v_quant_type, config.kv_cache_type) + return k_scale, v_scale def parity_check_gqa_prompt( @@ -822,16 +891,17 @@ def parity_check_gqa_prompt( ) # Initialize the KV cache to zeros since no past context in prompt testing. - k = ( - torch.zeros( - config.batch_size, - config.kv_num_heads, - config.buffer_sequence_length, - config.head_size, - device=device, - dtype=torch_type, - ) - * std + cache_dtype = torch_type + if config.kv_cache_type: + cache_dtype = TORCH_DTYPE_MAP[config.kv_cache_type] + + k = torch.zeros( + config.batch_size, + config.kv_num_heads, + config.buffer_sequence_length, + config.head_size if config.kv_cache_bit_width != 4 else config.head_size // 2, + device=device, + dtype=cache_dtype, ) v = torch.zeros_like(k) @@ -848,6 +918,12 @@ def parity_check_gqa_prompt( ) new_v = torch.randn_like(new_k) * std + k_scale, v_scale = get_static_scale(config, device, torch_type, std) + if k_scale is not None: + k_scale = k_scale.to(torch_type) + if v_scale is not None: + v_scale = v_scale.to(torch_type) + head_sink = torch.rand(config.num_heads, dtype=torch_type, device=device) if config.has_head_sink else None window_size = (-1, -1) if config.local_window_size > 0: @@ -856,10 +932,24 @@ def parity_check_gqa_prompt( window_size = (-1, 0) # --- PyTorch Reference Path --- - # Transpose BNSH cache to BSNH format for reference implementation - k_cache_ref = k.clone().transpose(1, 2) - v_cache_ref = v.clone().transpose(1, 2) - + if config.kv_cache_bit_width == 4 or config.kv_cache_type == "int8": + k_ref_dequant = dequantize_tensor(k, k_scale, config.k_quant_type, config.kv_cache_type) + v_ref_dequant = dequantize_tensor(v, v_scale, config.v_quant_type, config.kv_cache_type) + else: + k_ref_dequant = dequantize_tensor( + quantize_tensor_with_scale(k, k_scale, config.k_quant_type, config.kv_cache_type), + k_scale, + config.k_quant_type, + config.kv_cache_type, + ) + v_ref_dequant = dequantize_tensor( + quantize_tensor_with_scale(v, v_scale, config.v_quant_type, config.kv_cache_type), + v_scale, + config.v_quant_type, + config.kv_cache_type, + ) + k_cache_ref = k_ref_dequant.clone().transpose(1, 2) + v_cache_ref = v_ref_dequant.clone().transpose(1, 2) cache_seqlens = torch.full((config.batch_size,), config.kv_sequence_length, device=device, dtype=torch.int32) rotary_seqlens = torch.zeros(config.batch_size, device=device, dtype=torch.long) @@ -894,20 +984,37 @@ def parity_check_gqa_prompt( kv_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = arange < kv_seqlens_expanded - # Explicitly cast the source tensor to the destination's dtype before assignment. - source_k = rearrange(k_ro, "b s ... -> (b s) ...") - k_cache_ref[update_mask] = source_k.to(k_cache_ref.dtype) - - source_v = rearrange(new_v, "b s ... 
-> (b s) ...") - v_cache_ref[update_mask] = source_v.to(v_cache_ref.dtype) - - key_padding_mask = arange < kv_seqlens_expanded + k_to_cache = k_ro + v_to_cache = new_v + if config.kv_cache_type != "none": + k_scale_bsnh = k_scale + v_scale_bsnh = v_scale + if config.k_quant_type == "PER_CHANNEL" and k_scale is not None: + k_scale_bsnh = k_scale.transpose(1, 2) # (1, H, 1, D) -> (1, 1, H, D) + if config.v_quant_type == "PER_CHANNEL" and v_scale is not None: + v_scale_bsnh = v_scale.transpose(1, 2) # (1, H, 1, D) -> (1, 1, H, D) + + k_to_cache = dequantize_tensor( + quantize_tensor_with_scale(k_ro, k_scale_bsnh, config.k_quant_type, config.kv_cache_type), + k_scale_bsnh, + config.k_quant_type, + config.kv_cache_type, + ).to(torch_type) + v_to_cache = dequantize_tensor( + quantize_tensor_with_scale(new_v, v_scale_bsnh, config.v_quant_type, config.kv_cache_type), + v_scale_bsnh, + config.v_quant_type, + config.kv_cache_type, + ).to(torch_type) + + k_cache_ref[update_mask] = rearrange(k_to_cache, "b s ... -> (b s) ...").to(k_cache_ref.dtype) + v_cache_ref[update_mask] = rearrange(v_to_cache, "b s ... -> (b s) ...").to(v_cache_ref.dtype) out_ref, _ = attention_ref( q=q_ro, - k=k_cache_ref, - v=v_cache_ref, - key_padding_mask=key_padding_mask, + k=k_ro, + v=new_v, + key_padding_mask=None, attention_bias=attention_bias, causal=True, window_size=window_size, @@ -918,63 +1025,32 @@ def parity_check_gqa_prompt( out_ref_np = out_ref.to(torch.float32).detach().cpu().numpy() # --- ONNX Runtime Path --- - q_ort, k_ort, v_ort, new_k_ort, new_v_ort = q, k, v, new_k, new_v + q_ort, new_k_ort, new_v_ort = q, new_k, new_v if config.packed: q_ort = torch.cat([q, new_k, new_v], dim=2) new_k_ort, new_v_ort = None, None - # seqlens_k for GQA op is past_seq_len + seq_len - 1 ort_seqlens = cache_seqlens - 1 - num_runs = 2 if enable_deterministic_check else 1 - for i in range(num_runs): - out, present_k, present_v = gqa_prompt_func( - q=q_ort, - k=k_ort, - v=v_ort, - config=config, - new_k=new_k_ort, - new_v=new_v_ort, - cos=cos, - sin=sin, - seqlens_k=ort_seqlens, - position_ids=position_ids, - attention_bias=attention_bias, - head_sink=head_sink, - ep=ep, - device=device, - share_buffer=config.share_buffer, - ort_type=ort_type, - ) - if i == 0: - first_out = out.clone() - first_present_k = present_k.clone() if present_k is not None else None - first_present_v = present_v.clone() if present_v is not None else None - else: - if present_k is not None: - try: - torch.testing.assert_close( - present_k, first_present_k, rtol=0, atol=0, msg="present_k mismatch between two runs" - ) - except AssertionError as e: - print(e) - raise e - if present_v is not None: - try: - torch.testing.assert_close( - present_v, first_present_v, rtol=0, atol=0, msg="present_v mismatch between two runs" - ) - except AssertionError as e: - print(e) - raise e - try: - torch.testing.assert_close(out, first_out, rtol=0, atol=0, msg="Output mismatch between two runs") - except AssertionError as e: - max_diff = (out - first_out).abs().max().item() - print(f"Output mismatch max diff: {max_diff}") - with open("/tmp/gqa_diff_info.txt", "w") as f: - f.write(f"Max Diff: {max_diff}\n") - print(e) - raise e + out, present_k, present_v = gqa_prompt_func( + q=q_ort, + k=k, + v=v, + config=config, + new_k=new_k_ort, + new_v=new_v_ort, + cos=cos, + sin=sin, + seqlens_k=ort_seqlens, + position_ids=position_ids, + attention_bias=attention_bias, + head_sink=head_sink, + k_scale=k_scale, + v_scale=v_scale, + ep=ep, + device=device, + 
share_buffer=config.share_buffer, + ort_type=ort_type, + ) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out_np = out.to(torch.float32).detach().cpu().numpy() @@ -1001,14 +1077,91 @@ def parity_check_gqa_prompt( k_cache_ref_np = k_cache_ref_np[:, :, : config.kv_sequence_length, :] v_cache_ref_np = v_cache_ref_np[:, :, : config.kv_sequence_length, :] - print_diff_statistics(torch.tensor(present_k_np - k_cache_ref_np), "present_k") - numpy.testing.assert_allclose(present_k_np, k_cache_ref_np, rtol=rtol, atol=atol) - print_diff_statistics(torch.tensor(present_v_np - v_cache_ref_np), "present_v") - numpy.testing.assert_allclose(present_v_np, v_cache_ref_np, rtol=rtol, atol=atol) + if config.k_quant_type == "NONE": + numpy.testing.assert_allclose(present_k_np, k_cache_ref_np, rtol=rtol, atol=atol) + numpy.testing.assert_allclose(present_v_np, v_cache_ref_np, rtol=rtol, atol=atol) print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol, atol=atol) + # Compare quantized cache with proper masking per batch + if config.k_quant_type != "NONE": + # Convert numpy array to torch tensor with correct dtype + if isinstance(present_k, torch.Tensor): + present_k_torch = present_k.to(device) + # If tensor is int8/uint8, it should be preserved. + else: + if config.kv_cache_type == "int4": + # For int4, present_k is uint8 packed data + present_k_torch = torch.from_numpy(present_k).to(device) + elif config.kv_cache_type == "int8": + # For int8, present_k is int8 data + present_k_torch = torch.from_numpy(present_k.astype(numpy.int8)).to(device) + else: + present_k_torch = torch.from_numpy(present_k).to(device) + + present_k_dequant = ( + dequantize_tensor(present_k_torch, k_scale, config.k_quant_type, config.kv_cache_type) + .detach() + .cpu() + .numpy() + ) + + # Mask the reference cache to only valid regions + k_cache_ref_masked = k_cache_ref.transpose(1, 2).clone() + arange = torch.arange(config.buffer_sequence_length, device=device).unsqueeze(0).unsqueeze(0).unsqueeze(-1) + cache_seqlens_expanded = cache_seqlens.unsqueeze(1).unsqueeze(1).unsqueeze(-1) + mask = arange >= cache_seqlens_expanded + k_cache_ref_masked[mask.expand_as(k_cache_ref_masked)] = 0 + k_cache_ref_dequant = k_cache_ref_masked.cpu().numpy() + + for b in range(config.batch_size): + valid_len = cache_seqlens[b].item() + print_diff_statistics( + torch.tensor(present_k_dequant[b, :, :valid_len, :] - k_cache_ref_dequant[b, :, :valid_len, :]), + f"present_k[{b}]", + ) + numpy.testing.assert_allclose( + present_k_dequant[b, :, :valid_len, :], k_cache_ref_dequant[b, :, :valid_len, :], rtol=rtol, atol=atol + ) + + if config.v_quant_type != "NONE": + # Convert numpy array to torch tensor with correct dtype + if isinstance(present_v, torch.Tensor): + present_v_torch = present_v.to(device) + else: + if config.kv_cache_type == "int4": + present_v_torch = torch.from_numpy(present_v).to(device) + elif config.kv_cache_type == "int8": + present_v_torch = torch.from_numpy(present_v.astype(numpy.int8)).to(device) + else: + present_v_torch = torch.from_numpy(present_v).to(device) + + present_v_dequant = ( + dequantize_tensor(present_v_torch, v_scale, config.v_quant_type, config.kv_cache_type) + .detach() + .cpu() + .numpy() + ) + + # Mask the reference cache to only valid regions + v_cache_ref_masked = v_cache_ref.transpose(1, 2).clone() + arange = torch.arange(config.buffer_sequence_length, 
device=device).unsqueeze(0).unsqueeze(0).unsqueeze(-1) + cache_seqlens_expanded = cache_seqlens.unsqueeze(1).unsqueeze(1).unsqueeze(-1) + mask = arange >= cache_seqlens_expanded + v_cache_ref_masked[mask.expand_as(v_cache_ref_masked)] = 0 + v_cache_ref_dequant = v_cache_ref_masked.cpu().numpy() + + for b in range(config.batch_size): + valid_len = cache_seqlens[b].item() + print_diff_statistics( + torch.tensor(present_v_dequant[b, :, :valid_len, :] - v_cache_ref_dequant[b, :, :valid_len, :]), + f"present_v[{b}]", + ) + numpy.testing.assert_allclose( + present_v_dequant[b, :, :valid_len, :], v_cache_ref_dequant[b, :, :valid_len, :], rtol=rtol, atol=atol + ) + def parity_check_gqa_past( config: GQAConfig, @@ -1053,13 +1206,16 @@ def parity_check_gqa_past( ) v = torch.randn_like(k) * std - # past cache sequence length is in [1, past_kv_sequence_length] + # Random past sequence lengths. This tests paddings in decoding. + # Use a separate generator to ensure deterministic behavior independent of prior RNG state. + cache_seqlens_gen = torch.Generator(device=device).manual_seed(42) cache_seqlens = torch.randint( 1, config.past_kv_sequence_length + 1, (config.batch_size,), device=device, dtype=torch.long, + generator=cache_seqlens_gen, ) for i in range(config.batch_size): @@ -1086,10 +1242,29 @@ def parity_check_gqa_past( elif causal: window_size = (-1, 0) + k_scale, v_scale = get_static_scale(config, device, torch_type, std) + if k_scale is not None: + k_scale = k_scale.to(torch_type) + if v_scale is not None: + v_scale = v_scale.to(torch_type) + # --- PyTorch Reference Path --- # Transpose BNSH cache to BSNH format for reference implementation - k_cache_ref = k.clone().transpose(1, 2) - v_cache_ref = v.clone().transpose(1, 2) + + k_ref_dequant = dequantize_tensor( + quantize_tensor_with_scale(k, k_scale, config.k_quant_type, config.kv_cache_type), + k_scale, + config.k_quant_type, + config.kv_cache_type, + ) + v_ref_dequant = dequantize_tensor( + quantize_tensor_with_scale(v, v_scale, config.v_quant_type, config.kv_cache_type), + v_scale, + config.v_quant_type, + config.kv_cache_type, + ) + k_cache_ref = k_ref_dequant.clone().transpose(1, 2) + v_cache_ref = v_ref_dequant.clone().transpose(1, 2) cos, sin, q_ro, k_ro = None, None, q, new_k if config.rotary: @@ -1117,8 +1292,32 @@ def parity_check_gqa_past( update_mask = torch.logical_and( cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.q_sequence_length ) - k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...").to(k_cache_ref.dtype) - v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...").to(v_cache_ref.dtype) + + k_to_cache = k_ro + v_to_cache = new_v + if config.kv_cache_type != "none": + k_scale_bsnh = k_scale + v_scale_bsnh = v_scale + if config.k_quant_type == "PER_CHANNEL" and k_scale is not None: + k_scale_bsnh = k_scale.transpose(1, 2) # (1, H, 1, D) -> (1, 1, H, D) + if config.v_quant_type == "PER_CHANNEL" and v_scale is not None: + v_scale_bsnh = v_scale.transpose(1, 2) # (1, H, 1, D) -> (1, 1, H, D) + + k_to_cache = dequantize_tensor( + quantize_tensor_with_scale(k_ro, k_scale_bsnh, config.k_quant_type, config.kv_cache_type), + k_scale_bsnh, + config.k_quant_type, + config.kv_cache_type, + ).to(torch_type) + v_to_cache = dequantize_tensor( + quantize_tensor_with_scale(new_v, v_scale_bsnh, config.v_quant_type, config.kv_cache_type), + v_scale_bsnh, + config.v_quant_type, + config.kv_cache_type, + ).to(torch_type) + + k_cache_ref[update_mask] = rearrange(k_to_cache, "b s ... 
-> (b s) ...").to(k_cache_ref.dtype) + v_cache_ref[update_mask] = rearrange(v_to_cache, "b s ... -> (b s) ...").to(v_cache_ref.dtype) key_padding_mask = arange < cache_seqlens_expanded + config.q_sequence_length out_ref, _ = attention_ref( @@ -1142,41 +1341,41 @@ def parity_check_gqa_past( q_ort = torch.cat([q, new_k, new_v], dim=2) new_k_ort, new_v_ort = None, None + # Quantize k and v for ORT when using quantized KV cache + # Quantize k and v for ORT when using quantized KV cache + k_ort = k + v_ort = v + if config.kv_cache_type in ["int8", "int4"]: + # NOTE: Quantize returns tensor with kv_cache_type (int8) + k_ort = quantize_tensor_with_scale(k, k_scale, config.k_quant_type, config.kv_cache_type) + v_ort = quantize_tensor_with_scale(v, v_scale, config.v_quant_type, config.kv_cache_type) + + # Ensure they are contiguous for binding + k_ort = k_ort.contiguous() + v_ort = v_ort.contiguous() + ort_seqlens = cache_seqlens + config.q_sequence_length - 1 - num_runs = 2 if enable_deterministic_check else 1 - for i in range(num_runs): - out, present_k, present_v = gqa_past_func( - q=q_ort, - k=k, - v=v, - config=config, - new_k=new_k_ort, - new_v=new_v_ort, - cos=cos, - sin=sin, - seqlens_k=ort_seqlens.int(), - position_ids=position_ids, - attention_bias=attention_bias, - head_sink=head_sink, - ep=ep, - device=device, - share_buffer=config.share_buffer, - ort_type=ort_type, - ) - if i == 0: - first_out = out.clone() - first_present_k = present_k.clone() if present_k is not None else None - first_present_v = present_v.clone() if present_v is not None else None - else: - torch.testing.assert_close(out, first_out, rtol=0, atol=0, msg="Output mismatch between two runs") - if present_k is not None: - torch.testing.assert_close( - present_k, first_present_k, rtol=0, atol=0, msg="present_k mismatch between two runs" - ) - if present_v is not None: - torch.testing.assert_close( - present_v, first_present_v, rtol=0, atol=0, msg="present_v mismatch between two runs" - ) + + out, present_k, present_v = gqa_past_func( + q=q_ort, + k=k_ort, + v=v_ort, + config=config, + new_k=new_k_ort, + new_v=new_v_ort, + cos=cos, + sin=sin, + seqlens_k=ort_seqlens.int(), + position_ids=position_ids, + attention_bias=attention_bias, + head_sink=head_sink, + k_scale=k_scale, + v_scale=v_scale, + ep=ep, + device=device, + share_buffer=config.share_buffer, + ort_type=ort_type, + ) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out_np = out.to(torch.float32).detach().cpu().numpy() @@ -1188,26 +1387,106 @@ def parity_check_gqa_past( raise RuntimeError("Output is all zeros") # --- Comparison --- - # Compare KV cache - # Transpose reference back to BNSH to match ORT output - k_cache_ref_np = k_cache_ref.transpose(1, 2).to(torch.float32).detach().cpu().numpy() - v_cache_ref_np = v_cache_ref.transpose(1, 2).to(torch.float32).detach().cpu().numpy() - present_k_np = present_k.to(torch.float32).detach().cpu().numpy() - present_v_np = present_v.to(torch.float32).detach().cpu().numpy() - - if not config.share_buffer: - total_len = config.past_kv_sequence_length + config.q_sequence_length - k_cache_ref_np = k_cache_ref_np[:, :, :total_len, :] - v_cache_ref_np = v_cache_ref_np[:, :, :total_len, :] - - print_diff_statistics(torch.tensor(present_k_np - k_cache_ref_np), "present_k") - numpy.testing.assert_allclose(present_k_np, k_cache_ref_np, rtol=rtol, atol=atol) - print_diff_statistics(torch.tensor(present_v_np - v_cache_ref_np), "present_v") - 
numpy.testing.assert_allclose(present_v_np, v_cache_ref_np, rtol=rtol, atol=atol) + if config.k_quant_type == "NONE" and config.v_quant_type == "NONE": + # Compare KV cache + # Transpose reference back to BNSH to match ORT output + k_cache_ref_np = k_cache_ref.transpose(1, 2).to(torch.float32).detach().cpu().numpy() + v_cache_ref_np = v_cache_ref.transpose(1, 2).to(torch.float32).detach().cpu().numpy() + present_k_np = present_k.to(torch.float32).detach().cpu().numpy() + present_v_np = present_v.to(torch.float32).detach().cpu().numpy() + + if not config.share_buffer: + total_len = config.past_kv_sequence_length + config.q_sequence_length + k_cache_ref_np = k_cache_ref_np[:, :, :total_len, :] + v_cache_ref_np = v_cache_ref_np[:, :, :total_len, :] + + numpy.testing.assert_allclose(present_k_np, k_cache_ref_np, rtol=rtol, atol=atol) + numpy.testing.assert_allclose(present_v_np, v_cache_ref_np, rtol=rtol, atol=atol) print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol, atol=atol) + # Compare quantized cache with proper masking per batch + if config.k_quant_type != "NONE": + if isinstance(present_k, torch.Tensor): + present_k_torch = present_k.to(device) + else: + if config.kv_cache_type == "int4": + present_k_torch = torch.from_numpy(present_k).to(device) + elif config.kv_cache_type == "int8": + present_k_torch = torch.from_numpy(present_k.astype(numpy.int8)).to(device) + else: + present_k_torch = torch.from_numpy(present_k).to(device) + + present_k_dequant = ( + dequantize_tensor(present_k_torch, k_scale, config.k_quant_type, config.kv_cache_type) + .detach() + .cpu() + .numpy() + ) + + # Mask the reference cache to only valid regions + k_cache_ref_masked = k_cache_ref.transpose(1, 2).clone() + total_seqlens = cache_seqlens + config.q_sequence_length + arange = torch.arange(config.buffer_sequence_length, device=device).unsqueeze(0).unsqueeze(0).unsqueeze(-1) + total_seqlens_expanded = total_seqlens.unsqueeze(1).unsqueeze(1).unsqueeze(-1) + mask = arange >= total_seqlens_expanded + k_cache_ref_masked[mask.expand_as(k_cache_ref_masked)] = 0 + k_cache_ref_dequant = k_cache_ref_masked.cpu().numpy() + + for b in range(config.batch_size): + valid_len = (cache_seqlens[b] + config.q_sequence_length).item() + print_diff_statistics( + torch.tensor(present_k_dequant[b, :, :valid_len, :] - k_cache_ref_dequant[b, :, :valid_len, :]), + f"present_k[{b}]", + ) + numpy.testing.assert_allclose( + present_k_dequant[b, :, :valid_len, :], + k_cache_ref_dequant[b, :, :valid_len, :], + rtol=rtol, + atol=atol, + ) + + if config.v_quant_type != "NONE": + if isinstance(present_v, torch.Tensor): + present_v_torch = present_v.to(device) + else: + if config.kv_cache_type == "int4": + present_v_torch = torch.from_numpy(present_v).to(device) + elif config.kv_cache_type == "int8": + present_v_torch = torch.from_numpy(present_v.astype(numpy.int8)).to(device) + else: + present_v_torch = torch.from_numpy(present_v).to(device) + + present_v_dequant = ( + dequantize_tensor(present_v_torch, v_scale, config.v_quant_type, config.kv_cache_type) + .detach() + .cpu() + .numpy() + ) + + # Mask the reference cache to only valid regions + v_cache_ref_masked = v_cache_ref.transpose(1, 2).clone() + total_seqlens = cache_seqlens + config.q_sequence_length + arange = torch.arange(config.buffer_sequence_length, device=device).unsqueeze(0).unsqueeze(0).unsqueeze(-1) + total_seqlens_expanded = total_seqlens.unsqueeze(1).unsqueeze(1).unsqueeze(-1) + mask = arange >= 
total_seqlens_expanded + v_cache_ref_masked[mask.expand_as(v_cache_ref_masked)] = 0 + v_cache_ref_dequant = v_cache_ref_masked.cpu().numpy() + + for b in range(config.batch_size): + valid_len = (cache_seqlens[b] + config.q_sequence_length).item() + print_diff_statistics( + torch.tensor(present_v_dequant[b, :, :valid_len, :] - v_cache_ref_dequant[b, :, :valid_len, :]), + f"present_v[{b}]", + ) + numpy.testing.assert_allclose( + present_v_dequant[b, :, :valid_len, :], + v_cache_ref_dequant[b, :, :valid_len, :], + rtol=rtol, + atol=atol, + ) + def parity_test_gqa_padding_prompt(): device = "cuda" @@ -1297,6 +1576,8 @@ def parity_test_gqa_padding_prompt(): position_ids=None, attention_bias=None, head_sink=None, + k_scale=None, + v_scale=None, ep="CUDAExecutionProvider", device=device, share_buffer=config.share_buffer, @@ -1449,7 +1730,7 @@ def gqa_cuda_prompt_test_cases(allow_head_sink: bool = True, allow_local: bool = b = batches[combo_index % len(batches)] sq, skv = seqs[combo_index % len(seqs)] n, n2 = heads[combo_index % len(heads)] - lws_opts = [-1, random.randint(1, skv)] if allow_local else [-1] + lws_opts = [-1, max(1, skv // 2)] if allow_local else [-1] lws = lws_opts[combo_index % len(lws_opts)] softcap = softcap_opts[combo_index % len(softcap_opts)] use_smooth_softmax, has_head_sink = smmoth_softmax__head_sink[ @@ -1527,7 +1808,7 @@ def gqa_cuda_past_test_cases( b = 1 # Force batch=1 for subsequent prompt n, n2 = heads[combo_index % len(heads)] - lws_opts = [-1, random.randint(1, s2)] if allow_local else [-1] + lws_opts = [-1, max(1, s2 // 2)] if allow_local else [-1] lws = lws_opts[combo_index % len(lws_opts)] softcap = softcap_opts[combo_index % len(softcap_opts)] use_smooth_softmax, has_head_sink = smmoth_softmax__head_sink[ @@ -1563,6 +1844,40 @@ def gqa_cuda_past_test_cases( yield name, config +def gqa_cuda_quantized_test_cases(is_past: bool): + base_cases = ( + gqa_cuda_past_test_cases(allow_local=True, enforce_share_buffer=True) + if is_past + else gqa_cuda_prompt_test_cases(allow_local=True) + ) + + for name, config in base_cases: + for kv_type in ["int8", "int4"] if has_int4_kv_cache else ["int8"]: + for quant_mode in ["PER_TENSOR", "PER_CHANNEL"]: + share_scales_options = [False] + if quant_mode == "PER_TENSOR" and kv_type == "int8": + share_scales_options = [True] + + for share_scales in share_scales_options: + q_config = deepcopy(config) + q_config.k_quant_type = quant_mode + q_config.v_quant_type = quant_mode + q_config.kv_cache_type = kv_type + q_config.share_kv_scale = share_scales + + if kv_type == "int4": + if q_config.head_size % 2 != 0: + continue + q_config.kv_cache_bit_width = 4 + elif kv_type == "int8": + q_config.kv_cache_bit_width = 8 + + q_name = f"{name}_quant_{kv_type}_{quant_mode}" + if share_scales: + q_name += "_shared" + yield q_name, q_config + + # ################################################################################################# # Unit Test Classes # ################################################################################################# @@ -1579,18 +1894,36 @@ def has_cuda_device(min_capability: int = 80): return major * 10 + minor >= min_capability -def has_flash_attention(): - return has_cuda_device(80) +def has_flash_attention(bf16=False): + if not has_cuda_device(80): + return False + if bf16: + return torch.cuda.is_bf16_supported() + return True + + +rtol = {"fp16": 5e-3, "bf16": 5e-2, "int8_fp16": 5e-2, "int4_fp16": 5e-2, "int8_bf16": 5e-2, "int4_bf16": 5e-2} +atol = {"fp16": 5e-3, "bf16": 1e-2, "int8_fp16": 1e-1, 
"int4_fp16": 1e-1, "int8_bf16": 2e-1, "int4_bf16": 2e-1} -rtol = {"fp16": 5e-3, "bf16": 5e-2} -atol = {"fp16": 5e-3, "bf16": 1e-2} +def has_quantized_kv_cache(): + return version.parse(ort_version) >= version.parse("1.24.0") @unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") class TestFlashGQA(unittest.TestCase): + def tearDown(self): + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + gc.collect() + @parameterized.expand(gqa_cuda_prompt_test_cases()) def test_gqa_prompt_flash_attention(self, name, config): + if enable_debug_print: + print("-" * 20) + print(f"test_case: {name}\n{config}") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" parity_check_gqa_prompt( config=config, @@ -1605,6 +1938,10 @@ def test_gqa_prompt_flash_attention(self, name, config): @parameterized.expand(gqa_cuda_past_test_cases()) def test_gqa_past_flash_attention(self, name, config): + if enable_debug_print: + print("-" * 20) + print(f"test_case: {name}\n{config}") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" parity_check_gqa_past( config=config, @@ -1618,13 +1955,23 @@ def test_gqa_past_flash_attention(self, name, config): ) -@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") +@unittest.skipIf(not has_flash_attention(bf16=True), "Flash Attention is not available, skipping tests.") class TestFlashGQABF16(unittest.TestCase): + def tearDown(self): + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + gc.collect() + @parameterized.expand(gqa_cuda_prompt_test_cases()) def test_gqa_prompt_flash_attention_bf16(self, name, config): if not torch.cuda.is_bf16_supported(): self.skipTest("BFloat16 not supported on this device") + if enable_debug_print: + print("-" * 20) + print(f"test_case: {name}\n{config}") + config.kv_cache_type = "bfloat16" os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" parity_check_gqa_prompt( @@ -1643,6 +1990,10 @@ def test_gqa_past_flash_attention_bf16(self, name, config): if not torch.cuda.is_bf16_supported(): self.skipTest("BFloat16 not supported on this device") + if enable_debug_print: + print("-" * 20) + print(f"test_case: {name}\n{config}") + config.kv_cache_type = "bfloat16" os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" parity_check_gqa_past( @@ -1657,10 +2008,74 @@ def test_gqa_past_flash_attention_bf16(self, name, config): ) +@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") +class TestFlashGQABF16QuantizedKV(unittest.TestCase): + def manual_seed(self): + # Reset random seeds before each test to ensure test isolation + torch.manual_seed(0) + random.seed(69) + numpy.random.seed(42) + + def setUp(self): + self.manual_seed() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + def tearDown(self): + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + gc.collect() + + @parameterized.expand(gqa_cuda_quantized_test_cases(is_past=False)) + def test_gqa_quantized_prompt_bf16(self, name, config): + if enable_debug_print: + print("-" * 20) + print(f"test_case: {name}\n{config}") + + self.manual_seed() + + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol[f"{config.kv_cache_type}_bf16"], + atol=atol[f"{config.kv_cache_type}_bf16"], + ) + + 
@parameterized.expand(gqa_cuda_quantized_test_cases(is_past=True)) + def test_gqa_quantized_past_bf16(self, name, config): + if enable_debug_print: + print("-" * 20) + print(f"test_case: {name}\n{config}") + + self.manual_seed() + + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.bfloat16, + ort_type=TensorProto.BFLOAT16, + causal=True, + rtol=rtol[f"{config.kv_cache_type}_bf16"], + atol=atol[f"{config.kv_cache_type}_bf16"], + ) + + @unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") class TestMemoryEfficientGQA(unittest.TestCase): @parameterized.expand(gqa_cuda_prompt_test_cases(allow_head_sink=False)) def test_gqa_prompt_memory_efficient(self, name, config): + if enable_debug_print: + print("-" * 20) + print(f"test_case: {name}\n{config}") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" parity_check_gqa_prompt( config=config, @@ -1675,6 +2090,10 @@ def test_gqa_prompt_memory_efficient(self, name, config): @parameterized.expand(gqa_cuda_past_test_cases(allow_head_sink=False)) def test_gqa_past_memory_efficient(self, name, config): + if enable_debug_print: + print("-" * 20) + print(f"test_case: {name}\n{config}") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" parity_check_gqa_past( config=config, @@ -1692,6 +2111,10 @@ def test_gqa_past_memory_efficient(self, name, config): class TestBF16MemoryEfficientGQA(unittest.TestCase): @parameterized.expand(gqa_cuda_past_test_cases(allow_head_sink=False)) def test_gqa_past_memory_efficient_bf16(self, name, config): + if enable_debug_print: + print("-" * 20) + print(f"test_case: {name}\n{config}") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" parity_check_gqa_past( config=config, @@ -1708,6 +2131,10 @@ def test_gqa_past_memory_efficient_bf16(self, name, config): @unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") class TestFlashGQAPaddingPrompt(unittest.TestCase): def test_gqa_padding_prompt_flash_attention(self): + if enable_debug_print: + print("-" * 20) + print("test_case: test_gqa_padding_prompt_flash_attention") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" parity_test_gqa_padding_prompt() @@ -1715,20 +2142,24 @@ def test_gqa_padding_prompt_flash_attention(self): @unittest.skipIf(not has_cuda_device(53), "Memory Efficient Attention is not available, skipping tests.") class TestMemoryEfficientGQAPaddingPrompt(unittest.TestCase): def test_gqa_padding_prompt_memory_efficient_attention(self): + if enable_debug_print: + print("-" * 20) + print("test_case: test_gqa_padding_prompt_memory_efficient_attention") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" parity_test_gqa_padding_prompt() -@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") -class TestFusedKernelParity(unittest.TestCase): - """Tests that verify fused kernels produce the same results as unfused kernels.""" +# ################################################################################################# +# Fused Kernel Parity Tests (ORT_DISABLE_FUSED_KV and ORT_DISABLE_FLASH_DECODE) +# ################################################################################################# - def test_flash_decode_parity(self): - """Test ORT_DISABLE_FLASH_DECODE: fast decode vs standard path.""" - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - # Decoding config (seq_len=1, share_buffer=True) - config = GQAConfig( +def 
fused_kernel_test_cases(): + """Test cases specifically for fused vs unfused kernel parity.""" + configs = [ + # Decoding with RoPE and shared buffer + GQAConfig( batch_size=2, q_sequence_length=1, kv_sequence_length=1, @@ -1740,40 +2171,103 @@ def test_flash_decode_parity(self): rotary=True, packed=False, share_buffer=True, - ) + ), + # Packed QKV decoding with RoPE + GQAConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + num_heads=8, + kv_num_heads=2, + head_size=128, + past_kv_sequence_length=64, + buffer_sequence_length=128, + rotary=True, + packed=True, + share_buffer=True, + ), + # Subsequent prompt with RoPE + GQAConfig( + batch_size=1, + q_sequence_length=4, + kv_sequence_length=4, + num_heads=8, + kv_num_heads=4, + head_size=128, + past_kv_sequence_length=32, + buffer_sequence_length=64, + rotary=True, + packed=False, + share_buffer=True, + ), + ] + for i, config in enumerate(configs): + yield f"fused_config_{i}", config + + +def gqa_xqa_test_cases(): + # Decoding config (seq_len=1, share_buffer=True) + # Testing different group sizes and query types + for torch_type, ort_type in [(torch.float16, TensorProto.FLOAT16), (torch.bfloat16, TensorProto.BFLOAT16)]: + for group_size in [4, 8, 16, 32]: + for past_kv_sequence_length in [1, 4]: + for rotary in [False, True]: + for packed in [False, True]: + for head_size in [256, 128, 64]: + kv_num_heads = 4 + num_heads = kv_num_heads * group_size + config = GQAConfig( + batch_size=2, + q_sequence_length=1, + kv_sequence_length=1, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + past_kv_sequence_length=past_kv_sequence_length, + buffer_sequence_length=past_kv_sequence_length + 128, + rotary=rotary, + packed=packed, + share_buffer=True, + k_quant_type="PER_TENSOR", + v_quant_type="PER_TENSOR", + kv_cache_type="int8", + share_kv_scale=True, + ) + type_str = "bf16" if torch_type == torch.bfloat16 else "fp16" + rot_str = "rot" if rotary else "norot" + pkd_str = "pkd" if packed else "sep" + name = f"{type_str}_g_{group_size}_h{head_size}_past{past_kv_sequence_length}_{rot_str}_{pkd_str}" + yield name, config, torch_type, ort_type - # Run with flash decode enabled (default) - if "ORT_DISABLE_FLASH_DECODE" in os.environ: - del os.environ["ORT_DISABLE_FLASH_DECODE"] - parity_check_gqa_past( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.float16, - ort_type=TensorProto.FLOAT16, - causal=True, - rtol=rtol["fp16"], - atol=atol["fp16"], - ) +@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") +class TestXQAQuantizedParity(unittest.TestCase): + """Tests that verify fused kernels produce the same results as unfused kernels.""" + + def tearDown(self): + """Clear CUDA cache after each test to prevent memory corruption in batch runs.""" + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + gc.collect() - # Run with flash decode disabled - os.environ["ORT_DISABLE_FLASH_DECODE"] = "1" + @parameterized.expand(gqa_xqa_test_cases()) + def test_xqa_quantized_parity(self, name, config, torch_type, ort_type): + """Test XQA per-tensor INT8 quantized parity.""" + os.environ["ORT_ENABLE_XQA"] = "1" parity_check_gqa_past( config=config, ep="CUDAExecutionProvider", device="cuda", - torch_type=torch.float16, - ort_type=TensorProto.FLOAT16, + torch_type=torch_type, + ort_type=ort_type, causal=True, - rtol=rtol["fp16"], - atol=atol["fp16"], + rtol=rtol["int8_bf16"] if torch_type == torch.bfloat16 else 
rtol["int8_fp16"], + atol=atol["int8_bf16"] if torch_type == torch.bfloat16 else atol["int8_fp16"], + std=0.1, ) - # Clean up - del os.environ["ORT_DISABLE_FLASH_DECODE"] - @unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") class TestGQARegressions(unittest.TestCase): @@ -1819,6 +2313,48 @@ def test_gqa_rope_separate_qkv_bug(self): std=1.0, ) + def test_gqa_int8_large_seq_batch4(self): + """ + Regression test for batch_size=4 + max_seq_len=8192 + int8 KV cache crash. + This reproduces a CUDA illegal memory access due to scratch size under-allocation. + """ + if "CUDAExecutionProvider" not in get_available_providers(): + self.skipTest("CUDA required") + + # Config that triggers the crash: batch=4, large max_seq_len, int8 kv + config = GQAConfig( + batch_size=4, + num_heads=32, + kv_num_heads=8, + head_size=128, + q_sequence_length=1, + kv_sequence_length=1, + past_kv_sequence_length=8191, + buffer_sequence_length=8192, + rotary=True, + rotary_interleaved=False, + k_quant_type="PER_TENSOR", + v_quant_type="PER_TENSOR", + kv_cache_type="int8", + share_buffer=True, + share_kv_scale=True, + ) + + torch_type = torch.float16 + ort_type = TensorProto.FLOAT16 + device = "cuda" + + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device=device, + torch_type=torch_type, + ort_type=ort_type, + causal=True, + rtol=5e-2, + atol=5e-2, + ) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 2901ef7005926..19968db98edd7 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -199,7 +199,7 @@ def create_group_query_attention_graph_prompt( smooth_softmax=1 if use_smooth_softmax else 0, qk_output=config.qk_output.value, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, - # kv_share_buffer=1 if share_buffer else 0, + # past_present_share_buffer=1 if share_buffer else 0, domain="com.microsoft", ), ] @@ -442,7 +442,7 @@ def create_group_query_attention_graph_past( smooth_softmax=1 if use_smooth_softmax else 0, qk_output=config.qk_output.value, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, - # kv_share_buffer=1 if share_buffer else 0, + # past_present_share_buffer=1 if share_buffer else 0, domain="com.microsoft", ), ] @@ -916,7 +916,9 @@ def gqa_prompt_func( ort_outputs["output_qk"] = OrtValue.ortvalue_from_numpy(output_qk.detach().cpu().numpy(), "cpu", 0) io_binding.bind_ortvalue_output("output_qk", ort_outputs["output_qk"]) + io_binding.synchronize_inputs() ort_session.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() out_qk = None if config.qk_output != QKOutputType.NO_OUTPUT: @@ -1083,7 +1085,9 @@ def gqa_past_func( ort_outputs["output_qk"] = OrtValue.ortvalue_from_numpy(output_qk.detach().cpu().numpy(), "cpu", 0) io_binding.bind_ortvalue_output("output_qk", ort_outputs["output_qk"]) + io_binding.synchronize_inputs() ort_session.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() out_qk = None if config.qk_output != QKOutputType.NO_OUTPUT: diff --git a/onnxruntime/test/python/transformers/test_sparse_attention.py b/onnxruntime/test/python/transformers/test_sparse_attention.py index e328212e97b97..68b10f9fc4064 100644 --- a/onnxruntime/test/python/transformers/test_sparse_attention.py +++ b/onnxruntime/test/python/transformers/test_sparse_attention.py @@ -7,7 +7,6 @@ Parity test and benchmark 
performance of SparseAttention. Requires Nvidia GPU of Compute Capability 7.5 or above. """ -import math import os import unittest @@ -20,203 +19,26 @@ from onnxruntime import InferenceSession, SessionOptions, get_available_providers from onnxruntime.transformers.io_binding_helper import CudaSession -ENABLE_DEBUG = False - - -class AttentionConfig: - def __init__( - self, - operator: str, - batch_size: int, - sequence_length: int, - max_sequence_length: int, - past_sequence_length: int, - num_heads: int, - kv_num_heads: int, - head_size: int, - softmax_scale: float | None, - do_rotary: bool, - rotary_interleaved: bool, - provider: str = "CUDAExecutionProvider", - device="cuda", - dtype=torch.float16, - share_buffer: bool = True, - is_packed_qkv: bool = False, - max_cache_sequence_length=None, - max_rotary_sequence_length=None, - use_smooth_softmax: bool = False, - ): - self.operator = operator - self.batch_size = batch_size - self.sequence_length = sequence_length - self.max_sequence_length = max_sequence_length - self.max_cache_sequence_length = max_cache_sequence_length or max_sequence_length - self.max_rotary_sequence_length = max_rotary_sequence_length or max_sequence_length - self.past_sequence_length = past_sequence_length - self.num_heads = num_heads - self.kv_num_heads = kv_num_heads - self.head_size = head_size - self.softmax_scale = softmax_scale or (1.0 / (head_size**0.5)) - - # Derived values - self.total_sequence_length = sequence_length + past_sequence_length - self.past_buffer_length = self.max_cache_sequence_length if share_buffer else past_sequence_length - self.present_buffer_length = ( - self.max_cache_sequence_length if share_buffer else (past_sequence_length + sequence_length) - ) - - self.do_rotary = do_rotary - self.rotary_interleaved = rotary_interleaved - - self.provider = provider - self.device = device - self.dtype = dtype - - self.share_buffer = share_buffer - self.is_packed_qkv = is_packed_qkv - - self.use_smooth_softmax = use_smooth_softmax - - def shape_dict(self): - shapes = { - "query": ( - self.batch_size, - self.sequence_length, - (self.num_heads + 2 * self.kv_num_heads) * self.head_size, - ), - "past_key": (self.batch_size, self.kv_num_heads, self.past_buffer_length, self.head_size), - "past_value": (self.batch_size, self.kv_num_heads, self.past_buffer_length, self.head_size), - "total_sequence_length": (1,), - "output": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - "present_key": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size), - "present_value": (self.batch_size, self.kv_num_heads, self.present_buffer_length, self.head_size), - "cos_cache": (self.max_rotary_sequence_length, (math.floor(self.head_size / 16) * 16) // 2), - "sin_cache": (self.max_rotary_sequence_length, (math.floor(self.head_size / 16) * 16) // 2), - } - - if not self.is_packed_qkv: - shapes.update( - { - "query": (self.batch_size, self.sequence_length, self.num_heads * self.head_size), - "key": (self.batch_size, self.sequence_length, self.kv_num_heads * self.head_size), - "value": (self.batch_size, self.sequence_length, self.kv_num_heads * self.head_size), - } - ) - return shapes - - def get_cos_sin_cache(self, dtype): - rotary_fraction = 1.0 - rotary_dim = math.floor(int(rotary_fraction * self.head_size) / 16) * 16 - angle = torch.rand(self.max_rotary_sequence_length, rotary_dim // 2, device="cpu") * 2 * math.pi - cos = torch.cos(angle).to(dtype=dtype) - sin = torch.sin(angle).to(dtype=dtype) - return 
cos.to(device=self.device), sin.to(device=self.device) - - def random_inputs(self): - device = self.device - # ORT python I/O binding API supports bf16 via torch tensor. - dtype = self.dtype - - # Always use non-packed qkv to generate same inputs for Torch and ORT. - packed = self.is_packed_qkv # Save the original value. - self.is_packed_qkv = False - shape_dict = self.shape_dict() - self.is_packed_qkv = packed # Restore the original value. - torch.manual_seed(123) - - feeds = { - "query": torch.empty(shape_dict["query"], device=device, dtype=dtype).normal_(mean=0, std=0.1), - "key": torch.empty(shape_dict["key"], device=device, dtype=dtype).normal_(mean=0, std=0.1), - "value": torch.empty(shape_dict["value"], device=device, dtype=dtype).normal_(mean=0, std=0.1), - "past_key": torch.empty(shape_dict["past_key"], device=device, dtype=dtype).normal_(mean=0, std=0.1), - "past_value": torch.empty(shape_dict["past_value"], device=device, dtype=dtype).normal_(mean=0, std=0.1), - "total_sequence_length": torch.tensor([self.total_sequence_length], dtype=torch.int32), - } - - if packed: - query = feeds["query"].view(self.batch_size, self.sequence_length, self.num_heads, self.head_size) - key = feeds["key"].view(self.batch_size, self.sequence_length, self.kv_num_heads, self.head_size) - value = feeds["value"].view(self.batch_size, self.sequence_length, self.kv_num_heads, self.head_size) - feeds["query"] = torch.dstack((query, key, value)).reshape(self.batch_size, self.sequence_length, -1) - del feeds["key"] - del feeds["value"] - - if self.do_rotary: - cos_cache, sin_cache = self.get_cos_sin_cache(dtype) - feeds["cos_cache"] = cos_cache - feeds["sin_cache"] = sin_cache - - return feeds - - -class GroupQueryAttentionConfig(AttentionConfig): - def __init__( - self, - batch_size: int, - sequence_length: int, - max_sequence_length: int, - past_sequence_length: int, - num_heads: int, - kv_num_heads: int, - head_size: int, - softmax_scale=None, - do_rotary: bool = False, - rotary_interleaved: bool = False, - provider: str = "CUDAExecutionProvider", - device="cuda", - dtype=torch.float16, - local_window_size: int = -1, - attention_mask=None, - is_packed_qkv=False, - max_cache_sequence_length=None, - max_rotary_sequence_length=None, - use_smooth_softmax: bool = False, - ): - super().__init__( - "GroupQueryAttention", - batch_size=batch_size, - sequence_length=sequence_length, - max_sequence_length=max_sequence_length, - past_sequence_length=past_sequence_length, - num_heads=num_heads, - kv_num_heads=kv_num_heads, - head_size=head_size, - softmax_scale=softmax_scale, - do_rotary=do_rotary, - rotary_interleaved=rotary_interleaved, - provider=provider, - device=device, - dtype=dtype, - is_packed_qkv=is_packed_qkv, - max_cache_sequence_length=max_cache_sequence_length, - max_rotary_sequence_length=max_rotary_sequence_length, - use_smooth_softmax=use_smooth_softmax, - ) - # local_window_size is for ORT only, not for Torch implementation. - self.local_window_size = local_window_size - - # attention mask is for Torch implementation only, not for ORT. 
- self.attention_mask = attention_mask +try: + from gqa_test_helper import ( + AttentionConfig, + GroupQueryAttentionConfig, + OrtGroupQueryAttention, + ) +except ImportError: + import sys + + sys.path.insert(0, os.path.dirname(__file__)) + from gqa_test_helper import ( + AttentionConfig, + GroupQueryAttentionConfig, + OrtGroupQueryAttention, + ) - def shape_dict(self): - shapes = super().shape_dict() - shapes.update( - { - "seqlens_k": (self.batch_size,), - } - ) - return shapes +ENABLE_DEBUG = False - def random_inputs(self): - feeds = super().random_inputs() - k_seqlens = torch.ones((self.batch_size,), device=self.device, dtype=torch.int32) * self.total_sequence_length - feeds.update( - { - "seqlens_k": k_seqlens - 1, - } - ) - return feeds +# AttentionConfig and GroupQueryAttentionConfig moved to gqa_test_helper class SparseAttentionConfig(AttentionConfig): @@ -510,108 +332,7 @@ def create_sparse_attention_onnx_model(config: SparseAttentionConfig): return model.SerializeToString() -def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): - assert config.dtype in [torch.float16, torch.float32, torch.bfloat16] - - if config.dtype == torch.float16: - float_type = TensorProto.FLOAT16 - elif config.dtype == torch.bfloat16: - float_type = TensorProto.BFLOAT16 - else: - float_type = TensorProto.FLOAT - - # Build input list for the GQA node - node_inputs = [ - "query", - "key" if not config.is_packed_qkv else "", - "value" if not config.is_packed_qkv else "", - "past_key", - "past_value", - "seqlens_k", - "total_sequence_length" if config.share_buffer else "", - "cos_cache" if config.do_rotary else "", - "sin_cache" if config.do_rotary else "", - "", # position_ids (optional, not used in benchmark) - "", # attention_bias (optional, not used in benchmark) - "", # head_sink (optional, not used in benchmark) - ] - # Remove trailing empty strings - while node_inputs and node_inputs[-1] == "": - node_inputs.pop() - - # Build attributes dictionary - node_attrs = { - "num_heads": config.num_heads, - "kv_num_heads": config.kv_num_heads, - "scale": config.softmax_scale, - "local_window_size": config.local_window_size, - "do_rotary": 1 if config.do_rotary else 0, - "rotary_interleaved": config.rotary_interleaved, - "smooth_softmax": 1 if config.use_smooth_softmax else 0, - "domain": "com.microsoft", - } - - nodes = [ - helper.make_node( - "GroupQueryAttention", - node_inputs, - ["output", "present_key", "present_value"], - "GroupQueryAttention_0", - **node_attrs, - ), - ] - - shape_dict = config.shape_dict() - graph_input = [ - helper.make_tensor_value_info("query", float_type, list(shape_dict["query"])), - ] - - if not config.is_packed_qkv: - graph_input.extend( - [ - helper.make_tensor_value_info("key", float_type, list(shape_dict["key"])), - helper.make_tensor_value_info("value", float_type, list(shape_dict["value"])), - ] - ) - - cache_type = float_type - past_key_shape = list(shape_dict["past_key"]) - past_value_shape = list(shape_dict["past_value"]) - present_key_shape = list(shape_dict["present_key"]) - present_value_shape = list(shape_dict["present_value"]) - - graph_input.extend( - [ - helper.make_tensor_value_info("past_key", cache_type, past_key_shape), - helper.make_tensor_value_info("past_value", cache_type, past_value_shape), - helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, list(shape_dict["seqlens_k"])), - helper.make_tensor_value_info( - "total_sequence_length", TensorProto.INT32, list(shape_dict["total_sequence_length"]) - ), - ] - ) - - if 
config.do_rotary: - graph_input += [ - helper.make_tensor_value_info("cos_cache", float_type, list(shape_dict["cos_cache"])), - helper.make_tensor_value_info("sin_cache", float_type, list(shape_dict["sin_cache"])), - ] - - graph_output = [ - helper.make_tensor_value_info("output", float_type, list(shape_dict["output"])), - helper.make_tensor_value_info("present_key", cache_type, present_key_shape), - helper.make_tensor_value_info("present_value", cache_type, present_value_shape), - ] - - graph = helper.make_graph( - nodes, - "GroupQueryAttention_Graph", - graph_input, - graph_output, - ) - - model = helper.make_model(graph) - return model.SerializeToString() +# create_group_query_attention_onnx_model moved to gqa_test_helper def create_session(onnx_model_str, cuda_provider_options=None) -> InferenceSession: @@ -648,8 +369,7 @@ def group_query_attention_reference( attn = attn.masked_fill((1 - mask).bool(), float("-inf")) if config.use_smooth_softmax: - head_sink = None - attn = smooth_softmax_ref(attn, head_sink) + attn = smooth_softmax_ref(attn, head_sink=None) else: attn = attn.softmax(-1) @@ -731,13 +451,11 @@ def infer(self): ) -def create_ort_session( - config: SparseAttentionConfig | GroupQueryAttentionConfig, session_options=None, enable_cuda_graph=False -) -> CudaSession: +def create_ort_session(config: SparseAttentionConfig, session_options=None, enable_cuda_graph=False) -> CudaSession: if isinstance(config, SparseAttentionConfig): onnx_model_str = create_sparse_attention_onnx_model(config) else: - onnx_model_str = create_group_query_attention_onnx_model(config) + raise ValueError("Only SparseAttentionConfig is supported directly here.") if config.provider == "CUDAExecutionProvider": device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index @@ -761,32 +479,7 @@ def create_ort_session( return cuda_session -class OrtGroupQueryAttention: - """A wrapper of ORT GroupQueryAttention to test relevance and performance.""" - - def __init__(self, config: GroupQueryAttentionConfig): - self.session = create_ort_session(config) - - self.feed_dict = config.random_inputs() - - if ENABLE_DEBUG and not config.is_packed_qkv: - query = self.feed_dict["query"].view( - config.batch_size, config.sequence_length, config.num_heads, config.head_size - ) - key = self.feed_dict["key"].view( - config.batch_size, config.sequence_length, config.kv_num_heads, config.head_size - ) - value = self.feed_dict["value"].view( - config.batch_size, config.sequence_length, config.kv_num_heads, config.head_size - ) - print(vars(config)) - print("query(BSNH, GQA)", query) - print("key(BSNH, GQA)", key) - print("value(BSNH, GQA)", value) - print("seqlens_k (BSNH, GQA)", self.feed_dict["seqlens_k"]) - - def infer(self): - return self.session.infer(self.feed_dict) +# OrtGroupQueryAttention moved to gqa_test_helper class OrtSparseAttention: