348 changes: 34 additions & 314 deletions onnxruntime/contrib_ops/webgpu/bert/attention.cc

Large diffs are not rendered by default.

10 changes: 4 additions & 6 deletions onnxruntime/contrib_ops/webgpu/bert/attention.h
@@ -46,8 +46,8 @@
class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
public:
AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
bool has_attention_bias, int tile_size, int components, bool is_first_prompt, bool has_seqlen_k = false, bool past_present_share_buffer = false, bool is_unidirectional = false)
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), has_seqlen_k_(has_seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_unidirectional_(is_unidirectional) {
bool has_attention_bias, int tile_size, int components, bool has_seqlen_k = false, bool past_present_share_buffer = false, bool is_unidirectional = false)
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), has_seqlen_k_(has_seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_unidirectional_(is_unidirectional) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -78,7 +78,6 @@
int components_;
bool has_seqlen_k_;
bool past_present_share_buffer_;
bool is_first_prompt_;
bool is_unidirectional_;
};

@@ -110,8 +109,8 @@

class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
public:
VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)

Check warning on line 112 in onnxruntime/contrib_ops/webgpu/bert/attention.h (GitHub Actions / Optional Lint C++, via reviewdog): [cpplint] Add #include <string> for string [build/include_what_you_use] [4]
: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -138,7 +137,6 @@
int tile_size_;
const Tensor* seqlen_k_;
bool past_present_share_buffer_;
bool is_first_prompt_;
};

class Attention final : public WebGpuKernel, public onnxruntime::contrib::AttentionBase {
126 changes: 126 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/attention_probs.wgsl.template
@@ -0,0 +1,126 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#param components
#param feed_past_key
#param has_present_key
#param has_attention_bias
#param has_seqlen_k
#param past_present_share_buffer
#param is_unidirectional
#param tile_size_param

var<workgroup> tileQ: array<q_value_t, tile_size_param * tile_size_param>;
var<workgroup> tileK: array<key_value_t, tile_size_param * tile_size_param>;
#if components == 4
alias f32_val_t = vec4<f32>;
#elif components == 2
alias f32_val_t = vec2<f32>;
#else
alias f32_val_t = f32;
#endif

#if has_attention_bias
fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32) -> output_value_t {
// Handle broadcasting: if dimension size is 1, use index 0
let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0);
let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1);
// Calculate flat offset with broadcasting applied
// attention_bias shape: [attn_bias_dim0, attn_bias_dim1, sequence_length, total_sequence_length]
let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.M * uniforms.N +
bias_head_idx * uniforms.M * uniforms.N +
q_idx * uniforms.N +
k_idx;
return attention_bias[offset];
}
#endif

$MAIN {
// x holds the N and y holds the M
let m = u32(workgroup_idx / uniforms.num_total_seq_length_tile) % uniforms.num_seq_length_tile * TILE_SIZE;
let n = (workgroup_idx % uniforms.num_total_seq_length_tile) * TILE_SIZE;
let batch_head_idx = u32(workgroup_idx / (uniforms.num_total_seq_length_tile * uniforms.num_seq_length_tile));
let batch_idx = batch_head_idx / uniforms.num_heads;
let qOffset = batch_head_idx * uniforms.M * uniforms.K + m * uniforms.K;
let sequence_length = uniforms.M;
var total_sequence_length = uniforms.N;
#if has_seqlen_k
total_sequence_length = u32(seqlen_k[batch_idx]) + 1;
var past_sequence_length: u32 = select(total_sequence_length - sequence_length, 0u, uniforms.is_first_prompt > 0);
#else
let past_sequence_length = uniforms.past_sequence_length;
#endif
let kOffset = (batch_head_idx / uniforms.n_reps) * uniforms.kv_sequence_length * uniforms.K;
#if has_present_key
let presentKeyOffset = (batch_head_idx / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.K;
#endif

var value = f32_val_t(0);
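// Tiled Q * K^T: each iteration stages a TILE_SIZE x TILE_SIZE block of Q and K
// in workgroup memory, then every thread accumulates the partial dot product for
// its (m + local_id.y, n + local_id.x) output element.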
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
if (m + local_id.y < uniforms.M && w + local_id.x < uniforms.K) {
tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];
}
if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {
var idx = TILE_SIZE * local_id.y + local_id.x;
#if (feed_past_key && has_present_key) || (past_present_share_buffer)
if (n + local_id.y < past_sequence_length) {
let pastKeyOffset = (batch_head_idx / uniforms.n_reps) * uniforms.past_sequence_length * uniforms.K;
#if past_present_share_buffer
tileK[idx] = present_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
#else
tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
#endif
} else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {
tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];
}
#else
if (n + local_id.y < uniforms.kv_sequence_length) {
tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
}
#endif

#if has_present_key
#if past_present_share_buffer
if (n + local_id.y >= past_sequence_length && n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {
#else
if (n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {
#endif
present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];
}
#endif
}
workgroupBarrier();
for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {
value += f32_val_t(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);
}
workgroupBarrier();
}

if (m + local_id.y < uniforms.M && n + local_id.x < total_sequence_length) {
let headOffset = batch_head_idx * uniforms.M * uniforms.N;
let outputIdx = headOffset + (m + local_id.y) * uniforms.N + n + local_id.x;
let head_idx = batch_head_idx % uniforms.num_heads;
#if components == 4
var sum: f32 = value.x + value.y + value.z + value.w;
#elif components == 2
var sum: f32 = value.x + value.y;
#else
var sum: f32 = value;
#endif

#if is_unidirectional
// Apply causal masking for unidirectional attention
let query_pos = m + local_id.y + past_sequence_length;
let key_pos = n + local_id.x;
if (key_pos > query_pos) {
sum = -3.4028234663852886e+38; // Set to very negative value for masking
}
#endif

#if has_attention_bias
output[outputIdx] = output_value_t(sum * uniforms.alpha) + loadAttentionBias(batch_idx, head_idx, m + local_id.y, n + local_id.x);
#else
output[outputIdx] = output_value_t(sum * uniforms.alpha);
#endif
}
} // MAIN
142 changes: 142 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/inplace_softmax.wgsl.template
@@ -0,0 +1,142 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#param components
#param work_group_size
#param use_smooth_softmax
#param has_seqlen_k
#param has_head_sink
#param has_sliding_window

#if components == 4
alias f32_val_t = vec4<f32>;
#elif components == 2
alias f32_val_t = vec2<f32>;
#else
alias f32_val_t = f32;
#endif

var<workgroup> thread_max: array<f32, work_group_size>;
var<workgroup> thread_sum: array<f32, work_group_size>;

$MAIN {
let sequence_length = uniforms.sequence_length;
let batch_idx = u32(workgroup_idx / sequence_length) / uniforms.num_heads;
let head_idx = u32(workgroup_idx / sequence_length) % uniforms.num_heads;
var total_sequence_length = uniforms.total_sequence_length_comp * components;
#if has_seqlen_k
total_sequence_length = u32(seqlen_k[batch_idx]) + 1;
var past_sequence_length: u32 = select(total_sequence_length - sequence_length, 0u, uniforms.is_first_prompt > 0);
#else
let past_sequence_length = uniforms.past_sequence_length;
#endif
#if has_seqlen_k
let seq_causal_length = past_sequence_length + workgroup_idx % sequence_length + 1;
#else
let seq_causal_length = uniforms.total_sequence_length_comp;
#endif
let local_offset = local_idx * uniforms.elements_per_thread;
let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;
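// Each thread handles uniforms.elements_per_thread consecutive elements of one
// attention-score row; per-thread max and sum are reduced across the workgroup
// to compute a numerically stable softmax in place.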

#if has_sliding_window
// Sliding window
let should_apply_local_window = uniforms.local_window_size >= 0 && seq_causal_length > uniforms.local_window_size;
let start_offset = select(0, seq_causal_length - uniforms.local_window_size, should_apply_local_window);
let effective_seq_length = select(seq_causal_length, uniforms.local_window_size, should_apply_local_window);
#else
// No sliding window: the sliding-window code stays in the shader, but declaring
// start_offset and should_apply_local_window as const lets the compiler optimize it out.
const start_offset = 0;
const should_apply_local_window = false;
let effective_seq_length = seq_causal_length;
#endif

var thread_max_vector = f32_val_t(-3.4028234663852886e+38f);
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {
let actual_pos = local_offset + i + start_offset;
if (!should_apply_local_window || actual_pos < seq_causal_length) {
thread_max_vector = max(f32_val_t(x[offset + i + start_offset]), thread_max_vector);
}
}
#if components == 4
thread_max[local_idx] = max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w));
#elif components == 2
thread_max[local_idx] = max(thread_max_vector.x, thread_max_vector.y);
#else
thread_max[local_idx] = thread_max_vector;
#endif
workgroupBarrier();

#if has_head_sink
// Handle head sink
let sink_value: f32 = f32(head_sink[head_idx]);
var max_value = sink_value;
#elif use_smooth_softmax
var max_value: f32 = 0.0;
#else
var max_value = f32(-3.4028234663852886e+38f);
#endif

for (var i = 0u; i < work_group_size; i++) {
max_value = max(thread_max[i], max_value);
}
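// Second pass: accumulate exp(x - max_value) over this thread's slice of the row.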
var sum_vector = f32_val_t(0);
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {
let actual_pos = local_offset + i + start_offset;
if (!should_apply_local_window || actual_pos < seq_causal_length) {
sum_vector += exp(f32_val_t(x[offset + i + start_offset]) - max_value);
}
}
#if components == 4
thread_sum[local_idx] = sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w;
#elif components == 2
thread_sum[local_idx] = sum_vector.x + sum_vector.y;
#else
thread_sum[local_idx] = sum_vector;
#endif
workgroupBarrier();
var sum: f32 = 0;
for (var i = 0u; i < work_group_size; i++) {
sum += thread_sum[i];
}

#if has_head_sink
sum += exp(sink_value - max_value);
#elif use_smooth_softmax
sum += exp(-max_value);
#endif

if (sum == 0) {
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {
let actual_pos = local_offset + i + start_offset;
if (actual_pos < seq_causal_length) {
x[offset + i + start_offset] = x_value_t(x_element_t(1.0)/x_element_t(effective_seq_length));
}
}
} else {
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {
let actual_pos = local_offset + i + start_offset;
let pos = offset + i + start_offset;
if (!should_apply_local_window || actual_pos < seq_causal_length) {
var f32input = f32_val_t(x[pos]);
x[pos] = x_value_t(exp(f32input - max_value) / sum);
}
}
}

// zero out elements outside the sliding window
if (should_apply_local_window) {
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
let global_pos = i + local_offset;
if (global_pos < start_offset) {
x[offset + i] = x_value_t(x_element_t(0));
}
}
}

#if has_seqlen_k
for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length_comp; total_seq_id++) {
x[offset + total_seq_id] = x_value_t(x_element_t(0));
}
#endif
} // MAIN
23 changes: 23 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/split_packed_qkv.wgsl.template
@@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#use guardAgainstOutOfBoundsWorkgroupSizes
#use .offsetToIndices .setByIndices .getByOffset

$MAIN {
guardAgainstOutOfBoundsWorkgroupSizes(uniforms.input_size);
let packed_qkv_indices = packed_qkv.offsetToIndices(global_idx);
let batch = packed_qkv_indices[0];
let seq = packed_qkv_indices[1];
let d = packed_qkv_indices[2];
let input_data = packed_qkv.getByOffset(global_idx);
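// The last dimension packs query (hidden_size), key (kv_hidden_size), and value
// (kv_hidden_size) contiguously; route each element to the matching output tensor.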
if (d < uniforms.hidden_size) {
query.setByIndices(vec3<u32>(batch, seq, d), input_data);
} else if (d < (uniforms.hidden_size + uniforms.kv_hidden_size)) {
let kd = d - uniforms.hidden_size;
key.setByIndices(vec3<u32>(batch, seq, kd), input_data);
} else {
let vd = d - uniforms.hidden_size - uniforms.kv_hidden_size;
val.setByIndices(vec3<u32>(batch, seq, vd), input_data);
}
} // MAIN
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#param has_bias

#use guardAgainstOutOfBoundsWorkgroupSizes
#use .offsetToIndices

$MAIN {
guardAgainstOutOfBoundsWorkgroupSizes(uniforms.data_size);
let output_indices = qkv_output.offsetToIndices(global_idx);
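// Gather from qkv_input using per-dimension strides (batch_offset, head_offset,
// sequence_offset), optionally adding the bias slice selected by bias_offset.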
let input_offset_idx = output_indices[0] * uniforms.batch_offset + output_indices[1] *
uniforms.head_offset + output_indices[2] * uniforms.sequence_offset + output_indices[3];
#if has_bias
let bias_offset_idx = (input_offset_idx % uniforms.sequence_offset) + uniforms.bias_offset;
qkv_output[global_idx] = qkv_input[input_offset_idx] + bias[bias_offset_idx];
#else
qkv_output[global_idx] = qkv_input[input_offset_idx];
#endif
} // MAIN