21 changes: 18 additions & 3 deletions cmake/CMakeLists.txt
@@ -103,6 +103,7 @@ cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention ke
option(onnxruntime_USE_LEAN_ATTENTION "Build lean attention kernel for scaled dot product attention" OFF)
cmake_dependent_option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
option(onnxruntime_USE_FPA_INTB_GEMM "Build FpA IntB gemm cuda kernels" OFF)
option(onnxruntime_USE_INT4_KV_CACHE "Build cuda kernels for int4 kv cache" OFF)
option(onnxruntime_QUICK_BUILD "Speed up build by skipping some kernels for faster development" OFF)

option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)
@@ -125,6 +126,7 @@ option(onnxruntime_DONT_VECTORIZE "Do not vectorize operations in Eigen" OFF)

option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF)
option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump debug information about node inputs and outputs when executing the model." OFF)
option(onnxruntime_DUMP_TENSOR "Dump tensor inside kernel." OFF)
cmake_dependent_option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS_ENABLE_DUMP_TO_SQLDB "Build dump debug information about node inputs and outputs with support for sql database." OFF "onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS" OFF)

# When loading a delay loaded DLL, Windows searches the main EXE's folder first.
@@ -627,7 +629,6 @@ else()
check_cxx_compiler_flag(-Wparentheses HAS_PARENTHESES)
check_cxx_compiler_flag(-Wshorten-64-to-32 HAS_SHORTEN_64_TO_32)
check_cxx_compiler_flag(-Wstrict-aliasing HAS_STRICT_ALIASING)
check_nvcc_compiler_flag(-Wstrict-aliasing NVCC_HAS_STRICT_ALIASING)
check_cxx_compiler_flag(-Wstringop-overflow HAS_STRINGOP_OVERFLOW)
check_cxx_compiler_flag(-Wtautological-pointer-compare HAS_TAUTOLOGICAL_POINTER_COMPARE)
check_cxx_compiler_flag(-Wundefined-var-template HAS_UNDEFINED_VAR_TEMPLATE)
@@ -774,8 +775,13 @@ if (onnxruntime_USE_CUDA)
endif()

if (onnxruntime_QUICK_BUILD)
message( STATUS "Quick build mode: Flash attention limited to fp16 only")
list(APPEND ORT_PROVIDER_FLAGS -DORT_QUICK_BUILD=1)
message( STATUS "Quick build mode: Flash attention limited to head dimension 128 only")
list(APPEND ORT_PROVIDER_FLAGS -DORT_QUICK_BUILD=1)
endif()

if (onnxruntime_USE_INT4_KV_CACHE)
message( STATUS "Enable int4 kv cache for CUDA EP")
list(APPEND ORT_PROVIDER_FLAGS -DUSE_INT4_KV_CACHE=1)
endif()
endif()
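
For reference, the two provider flags appended above surface as preprocessor definitions in provider sources. A minimal illustrative guard (not taken from this PR; the guarded code is hypothetical) might look like:

```cpp
// Illustrative only: shows how sources can be gated on the defines added above.
#ifdef USE_INT4_KV_CACHE
// int4 KV-cache CUDA kernels are compiled in.
#endif

#ifdef ORT_QUICK_BUILD
// Reduced kernel set is built (e.g., flash attention limited to head dimension 128).
#endif
```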

@@ -1433,6 +1439,9 @@ if (Git_FOUND)
if (onnxruntime_QUICK_BUILD)
string(APPEND ORT_BUILD_INFO "quick-build=1, ")
endif()
if (onnxruntime_USE_INT4_KV_CACHE)
string(APPEND ORT_BUILD_INFO "int4-kv-cache=1, ")
endif()
endif()
string(APPEND ORT_BUILD_INFO "build type=${CMAKE_BUILD_TYPE}")
configure_file(onnxruntime_config.h.in ${CMAKE_CURRENT_BINARY_DIR}/onnxruntime_config.h)
@@ -1446,6 +1455,8 @@ if (onnxruntime_USE_CUDA)
find_package(CUDAToolkit REQUIRED)

if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.8)
add_definitions("-DENABLE_BF16")
message(STATUS "CUDA Toolkit version is greater or equal than 11.8, enable -DENABLE_BF16 flag")
add_definitions("-DENABLE_FP8")
message(STATUS "CUDA Toolkit version is greater or equal than 11.8, enable -DENABLE_FP8 flag")
endif()
@@ -1779,6 +1790,10 @@ if (onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS)
add_compile_definitions(DEBUG_NODE_INPUTS_OUTPUTS)
endif()

if (onnxruntime_DUMP_TENSOR)
add_compile_definitions(DUMP_TENSOR_LEVEL=1)
endif()
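
As a rough illustration, a level-gated dump helper that is only compiled in when `DUMP_TENSOR_LEVEL` is at least 1 could look like the sketch below (the helper name is hypothetical; the actual dump utilities are not part of this diff):

```cpp
#include <cstdio>

// Hypothetical helper: prints tensor values only when DUMP_TENSOR_LEVEL >= 1.
#if defined(DUMP_TENSOR_LEVEL) && DUMP_TENSOR_LEVEL >= 1
inline void DumpTensorValues(const char* name, const float* data, int count) {
  std::printf("%s:", name);
  for (int i = 0; i < count; ++i) {
    std::printf(" %g", data[i]);
  }
  std::printf("\n");
}
#else
inline void DumpTensorValues(const char*, const float*, int) {}
#endif
```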

if (onnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS)
if (NOT CMAKE_SYSTEM_NAME STREQUAL "Linux")
message(FATAL_ERROR "External custom operator schemas feature is only supported on Linux")
2 changes: 0 additions & 2 deletions cmake/onnxruntime_providers_cpu.cmake
@@ -28,12 +28,10 @@ file(GLOB_RECURSE onnxruntime_cuda_contrib_ops_cu_srcs CONFIGURE_DEPENDS
# Quick build mode: Filter flash attention kernels for faster development iteration.
# - We keep only hdim128 flash attention kernels in quick build mode.
# - All other listed head dimensions are excluded (e.g., 32, 64, 96, 192, 256).
# - This regex matches both `flash_fwd_hdim*` and `flash_fwd_split_hdim*` kernels.
# If new head dimensions are added or removed, update this list to match the supported set.
if(onnxruntime_QUICK_BUILD)
message(STATUS "Quick build mode enabled: Only building hdim128 fp16 flash attention kernels")
list(FILTER onnxruntime_cuda_contrib_ops_cu_srcs EXCLUDE REGEX "flash_fwd.*hdim(32|64|96|192|256)")
list(FILTER onnxruntime_cuda_contrib_ops_cu_srcs EXCLUDE REGEX "flash_fwd.*_bf16")
endif()
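
The exclusion regex above can be sanity-checked outside CMake; the sketch below applies the same pattern to a few hypothetical kernel file names (CMake's regex flavor differs slightly from ECMAScript, but this alternation behaves the same way):

```cpp
#include <iostream>
#include <regex>
#include <string>
#include <vector>

// Hypothetical file names; only hdim128 kernels survive the head-dimension filter.
int main() {
  const std::regex exclude("flash_fwd.*hdim(32|64|96|192|256)");
  const std::vector<std::string> files = {
      "flash_fwd_hdim128_fp16_sm80.cu",       // kept: hdim128 is not in the excluded set
      "flash_fwd_split_hdim64_fp16_sm80.cu",  // excluded
      "flash_fwd_hdim256_bf16_sm80.cu",       // excluded
  };
  for (const auto& f : files) {
    std::cout << f << (std::regex_search(f, exclude) ? " -> excluded" : " -> kept") << "\n";
  }
  return 0;
}
```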

file(GLOB_RECURSE onnxruntime_js_contrib_ops_cc_srcs CONFIGURE_DEPENDS
49 changes: 37 additions & 12 deletions docs/ContribOperators.md
@@ -2520,15 +2520,26 @@ This version of the operator has been available since version 1 of the 'com.micr

### <a name="com.microsoft.GroupQueryAttention"></a><a name="com.microsoft.groupqueryattention">**com.microsoft.GroupQueryAttention**</a>

Group Query Self/Cross Attention.
Group Query Self/Cross Attention with KV Cache Quantization Support.

*Highly recommend using k-v cache share buffer for both CPU and CUDA. Enabled through IOBinding past and present kv.
Supports different number of heads for q and kv for CPU and CUDA.
Only supports causal and local attention.
Supports rotary position embedding for CPU and CUDA.
Supports packed input for CPU and CUDA.
Supports continuous decoding for batch_size == 1 for CPU and CUDA.
This operator implements causal grouped-query attention with past state (KV cache) support.
It also supports optional float8, int8, or int4 quantization of the KV cache to reduce its memory footprint.

**Cache Format:**
The past and present KV cache tensors are expected in a BNSH format: `(batch_size, num_heads, cache_sequence_length, head_size)`, where `cache_sequence_length` is the length of the cached key/value sequences, or the maximum sequence length when past and present buffer sharing is used.

**Quantization:**
When quantization is enabled, the `past_key` and `past_value` inputs can be of type `float8e4m3fn`, `uint8`, or `int8`, and the corresponding `k_scale` and `v_scale` tensors must be provided.
The operator outputs `present_key` and `present_value` in the same format as `past_key` and `past_value`.

For 4-bit quantization, the data type is `uint8`, where each byte packs two 4-bit values. The bit width of the quantized KV cache can be set with the `kv_cache_bit_width` attribute.

The shapes of the `k_scale` and `v_scale` tensors must be broadcastable to the shape of `present_key` (see the dequantization sketch after the list of quantization modes below).

**Quantization Modes (`k_quant_type`, `v_quant_type` attributes):**
- **"NONE"**: No quantization.
- **"PER_TENSOR"**: A single scale for the entire tensor. Scale example shape: `[1]`.
- **"PER_CHANNEL"**: A scale for each channel. Scale example shape: `[1, num_heads_k, 1, head_size]`.

#### Version

@@ -2539,6 +2550,10 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>do_rotary</tt> : int</dt>
<dd>Whether to use rotary position embedding. Default value is 0.</dd>
<dt><tt>k_quant_type</tt> : string</dt>
<dd>Quantization type for K cache. One of 'NONE', 'PER_TENSOR', 'PER_CHANNEL'.</dd>
<dt><tt>kv_cache_bit_width</tt> : int</dt>
<dd>Bit width of quantized KV cache. Supported values are 8 and 4.</dd>
<dt><tt>kv_num_heads</tt> : int (required)</dt>
<dd>Number of attention heads for k and v</dd>
<dt><tt>local_window_size</tt> : int</dt>
@@ -2555,9 +2570,11 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Use a smooth factor in softmax.</dd>
<dt><tt>softcap</tt> : float</dt>
<dd>Softcap value for attention weights. Default value is 0.</dd>
<dt><tt>v_quant_type</tt> : string</dt>
<dd>Quantization type for V cache. One of 'NONE', 'PER_TENSOR', 'PER_CHANNEL'.</dd>
</dl>

#### Inputs (7 - 12)
#### Inputs (7 - 14)

<dl>
<dt><tt>query</tt> : T</dt>
@@ -2566,9 +2583,9 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Key with shape (batch_size, kv_sequence_length, kv_hidden_size) </dd>
<dt><tt>value</tt> (optional) : T</dt>
<dd>Value with shape (batch_size, kv_sequence_length, kv_hidden_size)</dd>
<dt><tt>past_key</tt> (optional) : T</dt>
<dt><tt>past_key</tt> (optional) : T_CACHE</dt>
<dd>past state key with support for format BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
<dt><tt>past_value</tt> (optional) : T</dt>
<dt><tt>past_value</tt> (optional) : T_CACHE</dt>
<dd>past state value with support for format BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
<dt><tt>seqlens_k</tt> : M</dt>
<dd>1D Tensor of shape (batch_size). Equivalent to (total_sequence_lengths - 1).</dd>
@@ -2584,16 +2601,20 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)</dd>
<dt><tt>head_sink</tt> (optional) : T</dt>
<dd>1D tensor with shape (num_heads). Each head has a smooth factor adding to the denominator of softmax.</dd>
<dt><tt>k_scale</tt> (optional) : T_KV_SCALE</dt>
<dd>Scale tensor for past_key.</dd>
<dt><tt>v_scale</tt> (optional) : T_KV_SCALE</dt>
<dd>Scale tensor for past_value.</dd>
</dl>

#### Outputs (3 - 4)

<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, hidden_size)</dd>
<dt><tt>present_key</tt> : T</dt>
<dt><tt>present_key</tt> : T_CACHE</dt>
<dd>present state key with support for format BNSH. When past_key uses same tensor as present_key(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
<dt><tt>present_value</tt> : T</dt>
<dt><tt>present_value</tt> : T_CACHE</dt>
<dd>present state value with support for format BNSH. When past_value uses same tensor as present_value(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +kv_sequence_length.</dd>
<dt><tt>output_qk</tt> (optional) : T</dt>
<dd>Values of QK matrix multiplication, either before or after softmax normalization</dd>
@@ -2604,6 +2625,10 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>T</tt> : tensor(float16), tensor(bfloat16), tensor(float)</dt>
<dd>Constrain input and output to float tensors.</dd>
<dt><tt>T_CACHE</tt> : tensor(float), tensor(float16), tensor(bfloat16), tensor(uint8), tensor(int8), tensor(float8e4m3fn)</dt>
<dd>Constrain KV cache types.</dd>
<dt><tt>T_KV_SCALE</tt> : tensor(float)</dt>
<dd>Constrain KV cache scale types.</dd>
<dt><tt>M</tt> : tensor(int32)</dt>
<dd>Constrain mask to int tensor.</dd>
</dl>
6 changes: 3 additions & 3 deletions docs/OperatorKernels.md
@@ -577,7 +577,7 @@ Do not modify directly.*
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
@@ -1003,7 +1003,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(bfloat16), tensor(float16)<br/> **T_CACHE** = tensor(bfloat16), tensor(float16), tensor(int8)<br/> **T_KV_SCALE** = tensor(float)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -1484,7 +1484,7 @@ Do not modify directly.*
|FusedMatMulActivation|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**M** = tensor(float), tensor(float16)<br/> **T** = tensor(float), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T_CACHE**<br> *in* past_value:**T_CACHE**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *in* position_ids:**tensor(int64)**<br> *in* attention_bias:**T**<br> *in* head_sink:**T**<br> *in* k_scale:**T_KV_SCALE**<br> *in* v_scale:**T_KV_SCALE**<br> *out* output:**T**<br> *out* present_key:**T_CACHE**<br> *out* present_value:**T_CACHE**<br> *out* output_qk:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* attention_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**<br> *out* qk:**QK**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
7 changes: 7 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -59,6 +59,13 @@ enum class QKOutputType : int {
AFTER_SOFTMAX = 2
};

// Enum to define quantization granularity.
enum class KVQuantizationType : int {
NONE = 0,
PER_TENSOR = 1,
PER_CHANNEL = 2,
};

constexpr bool LAYOUT_BSNH = false;
constexpr bool LAYOUT_BNSH = true;

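A small sketch of how the string-valued `k_quant_type` / `v_quant_type` attributes documented above could map onto this enum; the parsing helper name is hypothetical and the actual attribute handling is not shown in this diff:

```cpp
#include <string>

// Mirrors the enum added to attention_common.h above.
enum class KVQuantizationType : int { NONE = 0, PER_TENSOR = 1, PER_CHANNEL = 2 };

// Hypothetical helper mapping the documented attribute values to the enum.
inline KVQuantizationType ParseKVQuantizationType(const std::string& value) {
  if (value == "PER_TENSOR") return KVQuantizationType::PER_TENSOR;
  if (value == "PER_CHANNEL") return KVQuantizationType::PER_CHANNEL;
  return KVQuantizationType::NONE;  // "NONE" and unrecognized values.
}
```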
5 changes: 5 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_parameters.h
@@ -96,6 +96,11 @@ struct GroupQueryAttentionParameters : AttentionParameters {
AttentionQkvFormat past_kv_format;
int zeros_count;
int* zero_ptr;

// Quantization parameters for KV cache
KVQuantizationType k_quant_type = KVQuantizationType::NONE;
KVQuantizationType v_quant_type = KVQuantizationType::NONE;
int kv_cache_bit_width = 0;
};

// Parameters deduced from node attributes and inputs/outputs.
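For a sense of how `kv_cache_bit_width` affects storage, a sketch of the byte count of a quantized BNSH cache buffer follows (assumption: width 8 means one byte per element and width 4 means two packed values per byte, as in the operator documentation above; the element count is assumed even for the 4-bit case):

```cpp
#include <cstddef>

// Bytes needed to store a quantized BNSH KV cache buffer.
inline size_t QuantizedKVCacheBytes(size_t batch, size_t num_heads_k,
                                    size_t cache_seq_len, size_t head_size,
                                    int kv_cache_bit_width) {
  const size_t elements = batch * num_heads_k * cache_seq_len * head_size;
  // 8-bit: one element per byte; 4-bit: two elements per byte.
  return (elements * static_cast<size_t>(kv_cache_bit_width)) / 8;
}
```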
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -70,7 +70,8 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
seqlens_k,
total_seqlen_tensor,
scale_,
softcap_));
softcap_,
0));

ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids,
attention_bias,