Commit 1817f4a

[Reland] Attention(23) CUDA (#27030)
## Reland reason

Reland of #26466. The previous PR was reverted because it failed the following tests:

1. Windows GPU CUDA CI Pipeline Test Job
2. Windows GPU TensorRT CI Pipeline Test Job

This PR includes the correct [fix](cc7a947).

---

## Description

This pull request introduces significant improvements and expanded support for multi-head attention kernels in ONNX Runtime, particularly focusing on supporting both 3D (`BSNH`) and 4D (`BNSH`) QKV input formats. The changes enhance the flexibility, correctness, and maintainability of attention operations across the CPU and CUDA implementations.

### Expanded QKV Input Format Support

* Added support for the 4D QKV input format (`Q_K_V_BNSH`) in CUDA attention kernels, including proper handling of the cases with and without past/present state, and enforcement that bias is not supported for this format. This includes logic to avoid unnecessary transposes and to write outputs directly when possible. [[1]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R264-R265) [[2]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R343-R354) [[3]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R388-L388) [[4]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R426-R435) [[5]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11L673-R716) [[6]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11R747-R748) [[7]](diffhunk://#diff-25a30e78aab7a4cdd1d6ba9f3576fc36b79dd3404225d77ea2ee0018490a83eaL775-R791)

### Kernel and Operator Documentation Updates

* Updated `OperatorKernels.md` to document the new `Attention` operator inputs and outputs for both 3D and 4D formats, specifying the supported tensor types for each input.

### Correctness and Consistency Fixes

* Fixed the computation of causal attention indices in CUDA softmax kernels by clarifying and correcting the offset calculation for causal masking. [[1]](diffhunk://#diff-5367f3a93f596de362b09239a92fd1199b3c62fdded9e790810c80526ff9ec9bL168-R168) [[2]](diffhunk://#diff-5367f3a93f596de362b09239a92fd1199b3c62fdded9e790810c80526ff9ec9bL244-R244) [[3]](diffhunk://#diff-5367f3a93f596de362b09239a92fd1199b3c62fdded9e790810c80526ff9ec9bL336-R336) [[4]](diffhunk://#diff-5367f3a93f596de362b09239a92fd1199b3c62fdded9e790810c80526ff9ec9bL442-R442)
* Updated the workspace allocation logic for QKV preparation to ensure correct workspace usage for the new formats.

### Attention Parameter and Helper Refactoring

* Added an `is_output_bnsh` field to `AttentionParameters` to indicate the output format, and updated the logic to use it for output placement and transposition decisions. [[1]](diffhunk://#diff-e742290164e1e1fa0152840db2a1b83354e153153df19a2762b58655e49b7f9bR37) [[2]](diffhunk://#diff-25a30e78aab7a4cdd1d6ba9f3576fc36b79dd3404225d77ea2ee0018490a83eaL775-R791)
* Refactored the CPU attention implementation to use the new `attention_helper` namespace for output mode enums and output shape computation, improving code clarity and maintainability. [[1]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7R5) [[2]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L118-R125) [[3]](diffhunk://#diff-e692b5c865c4874e51982867901cd514e68cf38dd435c00fe505f34f93956fe7L143-R149)

### Minor Cleanups

* Removed outdated asserts and improved debug output strings for QKV preparation functions to clarify format and state handling. [[1]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11L254) [[2]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11L363) [[3]](diffhunk://#diff-64c7062a412bd7e91378e5c40574de5a1bf63f42ec4cf7d2d23e812fde5bcd11L673-R716)

These changes collectively improve the flexibility, correctness, and maintainability of attention kernel implementations in ONNX Runtime, especially for advanced transformer models and large language model workloads.

**NOT supported in this PR**

- Boolean mask
- GQA
- Softcap
- Softmax precision
- qk_output_mode other than -1 and 0
- **is_causal=True && q_sequence_length != kv_sequence_length**
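
To make the two input layouts concrete, here is a minimal, self-contained sketch (illustration only, not ONNX Runtime code; `IndexBSNH`/`IndexBNSH` are hypothetical helpers) showing how the same logical element maps to different flat offsets under the 3D-style `BSNH` ordering and the 4D `BNSH` ordering; this is why the `BSNH` path needs transposes around the per-head GEMMs while the `BNSH` path can skip them.

```cpp
#include <cstddef>
#include <cstdio>

// Layout: [batch, sequence, num_heads, head_size]
static std::size_t IndexBSNH(int b, int s, int n, int h, int S, int N, int H) {
  return ((static_cast<std::size_t>(b) * S + s) * N + n) * H + h;
}

// Layout: [batch, num_heads, sequence, head_size]
static std::size_t IndexBNSH(int b, int s, int n, int h, int S, int N, int H) {
  return ((static_cast<std::size_t>(b) * N + n) * S + s) * H + h;
}

int main() {
  const int S = 4, N = 2, H = 3;  // sequence length, heads, head size (batch = 1)
  // The same logical element (b=0, s=1, n=1, h=2) sits at different flat offsets
  // in the two layouts, so converting between them requires a transpose.
  std::printf("BSNH offset: %zu\n", IndexBSNH(0, 1, 1, 2, S, N, H));  // 11
  std::printf("BNSH offset: %zu\n", IndexBNSH(0, 1, 1, 2, S, N, H));  // 17
  return 0;
}
```
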
1 parent 36017ad commit 1817f4a

File tree

15 files changed: +745, -396 lines

docs/OperatorKernels.md

Lines changed: 1 addition & 0 deletions
@@ -653,6 +653,7 @@ Do not modify directly.*
 |ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
 |||12|**T** = tensor(double), tensor(float), tensor(float16)|
 |||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
+|Attention|*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *in* nonpad_kv_seqlen:**tensor(int64)**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**<br><br>or<br><br>*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**|23+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **U** = tensor(bfloat16), tensor(bool), tensor(float), tensor(float16)|
 |AveragePool|*in* X:**T**<br> *out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
 |||[19, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
 |||[11, 18]|**T** = tensor(double), tensor(float), tensor(float16)|

onnxruntime/contrib_ops/cpu/bert/attention_parameters.h

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ struct AttentionParameters {
   float mask_filter_value;
   float scale;
   bool use_tf32;
+  bool is_output_bnsh = false;  // whether the output format is BNSH
   AttentionMaskType mask_type;
   AttentionQkvFormat qkv_format;
 };
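
For orientation, here is a hedged sketch of how the new flag could be populated and consumed; the `ConfigureFormats` helper and its call site are assumptions for illustration, not the actual ONNX Runtime code.

```cpp
#include <cassert>

// Simplified stand-ins for the real ORT types; only the fields relevant here.
enum class QkvFormat { Q_K_V_BSNH, Q_K_V_BNSH };

struct Params {
  QkvFormat qkv_format = QkvFormat::Q_K_V_BSNH;
  bool is_output_bnsh = false;  // whether the output format is BNSH
};

// Hypothetical helper: when the 4D (BNSH) input path is detected, mark the output
// as BNSH as well so downstream kernels can skip the final BNSH -> BSNH transpose.
void ConfigureFormats(Params& p, int query_rank) {
  if (query_rank == 4) {
    p.qkv_format = QkvFormat::Q_K_V_BNSH;
    p.is_output_bnsh = true;
  } else {
    assert(query_rank == 3);
    p.qkv_format = QkvFormat::Q_K_V_BSNH;
    p.is_output_bnsh = false;
  }
}

int main() {
  Params p;
  ConfigureFormats(p, 4);  // 4D input: output stays in BNSH
  return p.is_output_bnsh ? 0 : 1;
}
```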

onnxruntime/contrib_ops/cuda/bert/attention_impl.cu

Lines changed: 12 additions & 9 deletions
@@ -772,20 +772,23 @@ Status UnfusedAttention(
 
   DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length);
 
-  // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v
-  T* temp_output = data.q;
+  // compute R*V (as V*R), and store in output or temp workspace depending on whether transpose is needed
+  // For 4D input (BNSH), write directly to output. For 3D input (BSNH), write to temp then transpose.
+  T* temp_output = parameters.is_output_bnsh ? data.output : data.q;
   CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
       cublas, CUBLAS_OP_N, CUBLAS_OP_N,
       v_head_size, sequence_length, total_sequence_length,
       &one, data.v, v_head_size, present_size_per_batch_v,
       scratch2, total_sequence_length, sequence_length * total_sequence_length,
       &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32));
 
-  // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
-  Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
-                                 device_prop.maxThreadsPerBlock, false, temp_output, data.output);
+  if (!parameters.is_output_bnsh) {
+    // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
+    ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
+                                       device_prop.maxThreadsPerBlock, false, temp_output, data.output));
+  }
   DUMP_TENSOR_D("Attention Output", data.output, batch_size, sequence_length, num_heads, v_head_size);
-  return result;
+  return Status::OK();
 }
 
 template <typename T>
@@ -960,7 +963,7 @@ Status QkvToContext(
   auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
   const int max_threads_per_block = device_prop.maxThreadsPerBlock;
   const int batch_size = parameters.batch_size;
-  const int sequence_length = parameters.sequence_length;
+  const int kv_sequence_length = parameters.kv_sequence_length;
   const int total_sequence_length = parameters.total_sequence_length;
   const int num_heads = parameters.num_heads;
   const int qk_head_size = parameters.head_size;
@@ -981,12 +984,12 @@
 
   if (!parameters.past_present_share_buffer) {
     ORT_RETURN_IF_ERROR(ConcatPastToPresent<T>(batch_size, num_heads, qk_head_size, v_head_size,
-                                               sequence_length, total_sequence_length,
+                                               kv_sequence_length, total_sequence_length,
                                                stream, max_threads_per_block, data));
 
   } else { // past_present_share_buffer
     ORT_RETURN_IF_ERROR(PastPresentBufferShare<T>(batch_size, num_heads, qk_head_size, v_head_size,
-                                                  sequence_length, fused_runner,
+                                                  kv_sequence_length, fused_runner,
                                                   parameters, data, stream, max_threads_per_block));
   }
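
The first hunk above is the core of the change. A condensed sketch of the decision it implements (simplified stand-in types, not the real kernel): the batched GEMM result goes straight into the output buffer when the output is already expected in BNSH, and otherwise goes to a scratch buffer and is transposed to BSNH afterwards.

```cpp
struct OutputCtx {
  bool is_output_bnsh = false;
  float* output = nullptr;   // final output buffer
  float* scratch = nullptr;  // temporary workspace (the real kernel reuses Q's space)
};

// Hypothetical stand-ins for the batched GEMM and transpose launches.
void BatchedGemmProducingBnsh(float* /*dst*/) {}
void TransposeBnshToBsnh(const float* /*src*/, float* /*dst*/) {}

void ComputeAttentionOutput(OutputCtx& ctx) {
  float* gemm_dst = ctx.is_output_bnsh ? ctx.output : ctx.scratch;
  BatchedGemmProducingBnsh(gemm_dst);
  if (!ctx.is_output_bnsh) {
    // Only the 3D (BSNH) output path pays for the extra transpose.
    TransposeBnshToBsnh(ctx.scratch, ctx.output);
  }
}

int main() {
  float out[4] = {}, tmp[4] = {};
  OutputCtx ctx{/*is_output_bnsh=*/true, out, tmp};
  ComputeAttentionOutput(ctx);  // writes directly to `out`, no transpose
  return 0;
}
```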

onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu

Lines changed: 144 additions & 106 deletions
@@ -251,7 +251,6 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
                              AttentionData<T>& data,
                              cudaStream_t stream,
                              int max_threads_per_block) {
-  assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
   assert(data.query != nullptr);
   assert(data.key != nullptr);
   assert(data.value != nullptr);
@@ -262,82 +261,96 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
   assert(!parameters.is_unidirectional);
   assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_NoPast(data));
 
-  const int batch_size = parameters.batch_size;
-  const int sequence_length = parameters.sequence_length;
-  const int kv_sequence_length = parameters.kv_sequence_length;
-  const int num_heads = parameters.num_heads;
-  const int qk_head_size = parameters.head_size;
-  const int v_head_size = parameters.v_head_size;
-
-  if (data.fused_cross_attention_kernel != nullptr) {
-    assert(qk_head_size == v_head_size);
-    assert(data.attention_bias == nullptr);
-    assert(data.mask_index == nullptr);
-    assert(parameters.hidden_size == parameters.v_hidden_size);
-
-    // For fused cross attention, besides adding bias, K and V needed to be packed:
-    // Key (BxSxNxH), Value (BxSxNxH) => Q (BxSxNxH), K (BxSxNx2xH)
-    LaunchAddBiasTransposeTrt(
-        stream, max_threads_per_block,
-        batch_size, sequence_length,
-        num_heads, qk_head_size,
-        data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length);
-    data.v = nullptr;
-    data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
-  } else if (data.use_memory_efficient_attention ||
-             data.use_flash_attention ||
-             data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
-    if (data.bias != nullptr) {
-      LaunchAddBias(stream, max_threads_per_block,
-                    batch_size, sequence_length, kv_sequence_length,
-                    num_heads, qk_head_size, v_head_size,
-                    data.bias, data.query, data.key, data.value, data.q, data.k, data.v);
-    } else {
-      data.q = const_cast<T*>(data.query);
-      data.k = const_cast<T*>(data.key);
-      data.v = const_cast<T*>(data.value);
-    }
+  if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) {
+    // 3D inputs in BSNH format (will be transposed)
+    const int batch_size = parameters.batch_size;
+    const int sequence_length = parameters.sequence_length;
+    const int kv_sequence_length = parameters.kv_sequence_length;
+    const int num_heads = parameters.num_heads;
+    const int qk_head_size = parameters.head_size;
+    const int v_head_size = parameters.v_head_size;
+
+    if (data.fused_cross_attention_kernel != nullptr) {
+      assert(qk_head_size == v_head_size);
+      assert(data.attention_bias == nullptr);
+      assert(data.mask_index == nullptr);
+      assert(parameters.hidden_size == parameters.v_hidden_size);
+
+      // For fused cross attention, besides adding bias, K and V needed to be packed:
+      // Key (BxSxNxH), Value (BxSxNxH) => Q (BxSxNxH), K (BxSxNx2xH)
+      LaunchAddBiasTransposeTrt(
+          stream, max_threads_per_block,
+          batch_size, sequence_length,
+          num_heads, qk_head_size,
+          data.bias, data.query, data.key, data.value, data.q, true, kv_sequence_length);
+      data.v = nullptr;
+      data.qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
+    } else if (data.use_memory_efficient_attention ||
+               data.use_flash_attention ||
+               data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
+      if (data.bias != nullptr) {
+        LaunchAddBias(stream, max_threads_per_block,
+                      batch_size, sequence_length, kv_sequence_length,
+                      num_heads, qk_head_size, v_head_size,
+                      data.bias, data.query, data.key, data.value, data.q, data.k, data.v);
+      } else {
+        data.q = const_cast<T*>(data.query);
+        data.k = const_cast<T*>(data.key);
+        data.v = const_cast<T*>(data.value);
+      }
 
-    data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
-  } else if (data.fused_runner != nullptr) {
-    assert(qk_head_size == v_head_size);
-    assert(data.attention_bias == nullptr);
+      data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
+    } else if (data.fused_runner != nullptr) {
+      assert(qk_head_size == v_head_size);
+      assert(data.attention_bias == nullptr);
 
-    // Query (BxSxNxH), Key (BxSxNxH), Value (BxSxNxH) => Q: BxSxNx(H + H + H)
-    LaunchAddBiasTransposeTrt(
-        stream, max_threads_per_block,
-        batch_size, sequence_length,
-        num_heads, qk_head_size,
-        data.bias, data.query, data.key, data.value, data.q, false, kv_sequence_length);
-    data.k = nullptr;
-    data.v = nullptr;
+      // Query (BxSxNxH), Key (BxSxNxH), Value (BxSxNxH) => Q: BxSxNx(H + H + H)
+      LaunchAddBiasTransposeTrt(
+          stream, max_threads_per_block,
+          batch_size, sequence_length,
+          num_heads, qk_head_size,
+          data.bias, data.query, data.key, data.value, data.q, false, kv_sequence_length);
+      data.k = nullptr;
+      data.v = nullptr;
 
-    data.qkv_format = AttentionQkvFormat::QKV_BSN3H;
-  } else { // unfused kernel
+      data.qkv_format = AttentionQkvFormat::QKV_BSN3H;
+    } else { // unfused kernel
+      assert(data.IsUnfused());
+      // Query (BxSxNxH) => Q (BxNxSxH)
+      constexpr int format = 0;
+      LaunchAddBiasTranspose<T>(
+          stream, 1, format, max_threads_per_block,
+          batch_size, sequence_length, num_heads, qk_head_size,
+          data.query, data.bias, data.q,
+          true, -1);
+
+      // Key (BxLxNxH) => K (BxNxLxH)
+      LaunchAddBiasTranspose<T>(
+          stream, 1, format, max_threads_per_block,
+          batch_size, kv_sequence_length, num_heads, qk_head_size,
+          data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, data.k,
+          true, -1);
+
+      // Value (BxLxNxH_v) => K (BxNxLxH_v)
+      LaunchAddBiasTranspose<T>(
+          stream, 1, format, max_threads_per_block,
+          batch_size, kv_sequence_length, num_heads, v_head_size,
+          data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, data.v,
+          true, -1);
+
+      data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+    }
+  } else if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) {
+    // Currently, 4D inputs are only supported in unfused kernel for Attention-23.
     assert(data.IsUnfused());
-    // Query (BxSxNxH) => Q (BxNxSxH)
-    constexpr int format = 0;
-    LaunchAddBiasTranspose<T>(
-        stream, 1, format, max_threads_per_block,
-        batch_size, sequence_length, num_heads, qk_head_size,
-        data.query, data.bias, data.q,
-        true, -1);
-
-    // Key (BxLxNxH) => K (BxNxLxH)
-    LaunchAddBiasTranspose<T>(
-        stream, 1, format, max_threads_per_block,
-        batch_size, kv_sequence_length, num_heads, qk_head_size,
-        data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, data.k,
-        true, -1);
-
-    // Value (BxLxNxH_v) => K (BxNxLxH_v)
-    LaunchAddBiasTranspose<T>(
-        stream, 1, format, max_threads_per_block,
-        batch_size, kv_sequence_length, num_heads, v_head_size,
-        data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, data.v,
-        true, -1);
-
-    data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+    // Attention-23 does not support bias with Q_K_V_BNSH format.
+    assert(data.bias == nullptr);
+    // No need to transpose since QKV is already in BNSH format.
+    data.q = const_cast<T*>(data.query);
+    data.k = const_cast<T*>(data.key);
+    data.v = const_cast<T*>(data.value);
+  } else {
+    ORT_THROW("Unsupported QKV format: ", parameters.qkv_format);
   }
 
   return Status::OK();
@@ -360,7 +373,6 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters,
                                       AttentionData<T>& data,
                                       cudaStream_t stream,
                                       int max_threads_per_block) {
-  assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
   assert(data.query != nullptr);
   assert(data.key != nullptr);
   assert(data.value != nullptr);
@@ -373,42 +385,53 @@ Status PrepareQkv_MHA_WithPast_NoBias(contrib::AttentionParameters& parameters,
          data.past_key != nullptr && data.past_value != nullptr);
   assert(data.has_qkv_workspace == !NoQkvWorkspace_MHA_WithPast_NoBias(data));
 
-  const int batch_size = parameters.batch_size;
-  const int sequence_length = parameters.sequence_length;
-  const int kv_sequence_length = parameters.kv_sequence_length;
-  const int num_heads = parameters.num_heads;
-  const int qk_head_size = parameters.head_size;
-  const int v_head_size = parameters.v_head_size;
-
   // When there is no past state and there is present state, we output K and V directly to present state.
   if (data.past_key == nullptr && data.present_key != nullptr) {
     data.k = data.present_key;
     data.v = data.present_value;
   }
+  if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH) {
+    // 3D inputs in BSNH format (will be transposed)
+    const int batch_size = parameters.batch_size;
+    const int sequence_length = parameters.sequence_length;
+    const int kv_sequence_length = parameters.kv_sequence_length;
+    const int num_heads = parameters.num_heads;
+    const int qk_head_size = parameters.head_size;
+    const int v_head_size = parameters.v_head_size;
+
+    if (data.use_memory_efficient_attention ||
+        data.use_flash_attention ||
+        data.use_lean_attention ||
+        data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
+      // Use oiginal Query (BSNH) since there is no bias.
+      data.q = const_cast<T*>(data.query);
 
-  if (data.use_memory_efficient_attention ||
-      data.use_flash_attention ||
-      data.use_lean_attention ||
-      data.kernel_type == AttentionKernelType::AttentionKernel_CudnnFlashAttention) {
-    // Use oiginal Query (BSNH) since there is no bias.
-    data.q = const_cast<T*>(data.query);
-
-    // Key (BxLxNxH) => K (BxNxLxH)
-    ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
-                                       max_threads_per_block, false, data.key, data.k));
-    // Value (BxLxNxH) => V (BxNxLxH)
-    ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
-                                       max_threads_per_block, false, data.value, data.v));
-    data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
-  } else { // unfused kernel
+      // Key (BxLxNxH) => K (BxNxLxH)
+      ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
+                                         max_threads_per_block, false, data.key, data.k));
+      // Value (BxLxNxH) => V (BxNxLxH)
+      ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
+                                         max_threads_per_block, false, data.value, data.v));
+      data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH;
+    } else { // unfused kernel
+      assert(data.IsUnfused());
+      ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
+                                         max_threads_per_block, false, data.query, data.q));
+      ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
+                                         max_threads_per_block, false, data.key, data.k));
+      ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
+                                         max_threads_per_block, false, data.value, data.v));
+      data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+    }
+  } else if (parameters.qkv_format == AttentionQkvFormat::Q_K_V_BNSH) {
+    // Currently, 4D inputs are only supported in unfused kernel for Attention-23.
     assert(data.IsUnfused());
-    ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
-                                       max_threads_per_block, false, data.query, data.q));
-    ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
-                                       max_threads_per_block, false, data.key, data.k));
-    ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
-                                       max_threads_per_block, false, data.value, data.v));
-    data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+    // No need to transpose since QKV is already in BNSH format.
+    data.q = const_cast<T*>(data.query);
+    data.k = const_cast<T*>(data.key);
+    data.v = const_cast<T*>(data.value);
+  } else {
+    ORT_THROW("Unsupported QKV format: ", parameters.qkv_format);
   }
 
   return Status::OK();
@@ -670,14 +693,27 @@ Status PrepareQkv_MultiHeadAttention(contrib::AttentionParameters& parameters,
     case AttentionQkvFormat::Q_K_V_BSNH:
       if (data.past_key != nullptr || data.present_key != nullptr) {
        if (data.bias == nullptr) {
-          DUMP_STRING("PrepareQkv_MHA_WithPast_NoBias");
+          DUMP_STRING("PrepareQkv(3D)_MHA_WithPast_NoBias");
           ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_NoBias(parameters, data, stream, max_threads_per_block));
         } else {
-          DUMP_STRING("PrepareQkv_MHA_WithPast_Bias");
+          DUMP_STRING("PrepareQkv(3D)_MHA_WithPast_Bias");
           ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_Bias(parameters, data, stream, max_threads_per_block));
         }
       } else { // no past state
-        DUMP_STRING("PrepareQkv_MHA_NoPast");
+        DUMP_STRING("PrepareQkv(3D)_MHA_NoPast");
+        ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NoPast(parameters, data, stream, max_threads_per_block));
+      }
+      break;
+    case AttentionQkvFormat::Q_K_V_BNSH:
+      if (data.past_key != nullptr || data.present_key != nullptr) {
+        if (data.bias == nullptr) {
+          DUMP_STRING("PrepareQkv(4D)_MHA_WithPast_NoBias");
+          ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast_NoBias(parameters, data, stream, max_threads_per_block));
+        } else {
+          ORT_THROW("Q_K_V_BNSH format with bias is not supported.");
+        }
+      } else { // no past state
+        DUMP_STRING("PrepareQkv(4D)_MHA_NoPast");
         ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NoPast(parameters, data, stream, max_threads_per_block));
       }
       break;
@@ -708,6 +744,8 @@ bool NoQkvWorkspace(contrib::AttentionParameters& parameters, AttentionData<T>&
       } else { // no past state
         return NoQkvWorkspace_MHA_NoPast(data);
       }
+    case AttentionQkvFormat::Q_K_V_BNSH:
+      return false; // currently no scenario needs no workspace
     default:
       ORT_THROW("Unsupported QKV format: ", parameters.qkv_format);
   }
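
A compact sketch of the QKV-preparation dispatch these hunks introduce (simplified types, not the real `PrepareQkv_*` functions): 3D `BSNH` inputs still take the bias/transpose paths, while 4D `BNSH` inputs must not carry a bias and are consumed as-is by the unfused kernel.

```cpp
#include <cassert>
#include <stdexcept>

enum class Format { Q_K_V_BSNH, Q_K_V_BNSH };

struct Qkv {
  const float* query; const float* key; const float* value;
  const float* bias;  // not supported for the BNSH path
};

// Hypothetical simplification of the dispatch; q/k/v are what the unfused kernel consumes.
void PrepareQkvSketch(Format format, const Qkv& in, const float*& q, const float*& k, const float*& v) {
  if (format == Format::Q_K_V_BSNH) {
    // 3D path: add bias (if present) and transpose each of Q/K/V from BSNH to BNSH
    // into workspace buffers (elided here).
  } else if (format == Format::Q_K_V_BNSH) {
    assert(in.bias == nullptr);  // bias is rejected for 4D inputs
    // 4D path: tensors are already BNSH, so they are used directly without a transpose.
    q = in.query;
    k = in.key;
    v = in.value;
  } else {
    throw std::runtime_error("Unsupported QKV format");
  }
}

int main() {
  float qb[1]{}, kb[1]{}, vb[1]{};
  Qkv in{qb, kb, vb, nullptr};
  const float* q = nullptr; const float* k = nullptr; const float* v = nullptr;
  PrepareQkvSketch(Format::Q_K_V_BNSH, in, q, k, v);  // 4D path: pointers alias the inputs
  return (q == qb && k == kb && v == vb) ? 0 : 1;
}
```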
