Skip to content

Commit 4c745d9

Browse files
committed
move gqa shader assembly to .wgsl for readability
1 parent 06fe9a4 commit 4c745d9

File tree

7 files changed

+429
-320
lines changed

7 files changed

+429
-320
lines changed

onnxruntime/contrib_ops/webgpu/bert/attention.cc

Lines changed: 34 additions & 314 deletions
Large diffs are not rendered by default.

onnxruntime/contrib_ops/webgpu/bert/attention.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ class SplitPackedQKVProgram final : public Program<SplitPackedQKVProgram> {
4646
class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
4747
public:
4848
AttentionProbsProgram(const std::string& kernel_name, bool feed_past_key, bool has_present_key,
49-
bool has_attention_bias, int tile_size, int components, bool is_first_prompt, bool has_seqlen_k = false, bool past_present_share_buffer = false, bool is_unidirectional = false)
50-
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), has_seqlen_k_(has_seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt), is_unidirectional_(is_unidirectional) {
49+
bool has_attention_bias, int tile_size, int components, bool has_seqlen_k = false, bool past_present_share_buffer = false, bool is_unidirectional = false)
50+
: Program{kernel_name}, feed_past_key_(feed_past_key), has_present_key_(has_present_key), has_attention_bias_(has_attention_bias), tile_size_(tile_size), components_(components), has_seqlen_k_(has_seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_unidirectional_(is_unidirectional) {
5151
}
5252

5353
Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -78,7 +78,6 @@ class AttentionProbsProgram final : public Program<AttentionProbsProgram> {
7878
int components_;
7979
bool has_seqlen_k_;
8080
bool past_present_share_buffer_;
81-
bool is_first_prompt_;
8281
bool is_unidirectional_;
8382
};
8483

@@ -110,8 +109,8 @@ class InPlaceSoftmaxProgram final : public Program<InPlaceSoftmaxProgram> {
110109

111110
class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
112111
public:
113-
VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, bool is_first_prompt, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
114-
: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer), is_first_prompt_(is_first_prompt) {
112+
VxAttentionScoreProgram(const std::string& kernel_name, bool feed_past_value, bool has_present_value, int tile_size, const Tensor* seqlen_k = nullptr, bool past_present_share_buffer = false)
113+
: Program{kernel_name}, feed_past_value_(feed_past_value), has_present_value_(has_present_value), tile_size_(tile_size), seqlen_k_(seqlen_k), past_present_share_buffer_(past_present_share_buffer) {
115114
}
116115

117116
Status GenerateShaderCode(ShaderHelper& sh) const override;
@@ -138,7 +137,6 @@ class VxAttentionScoreProgram final : public Program<VxAttentionScoreProgram> {
138137
int tile_size_;
139138
const Tensor* seqlen_k_;
140139
bool past_present_share_buffer_;
141-
bool is_first_prompt_;
142140
};
143141

144142
class Attention final : public WebGpuKernel, public onnxruntime::contrib::AttentionBase {
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Computes one TILE_SIZE x TILE_SIZE tile of the attention probability matrix
// per workgroup: output = Q * K^T scaled by uniforms.alpha, optionally plus an
// attention bias, with causal masking when is_unidirectional is set. Past keys
// are read from (and new keys written to) the present-key cache as configured
// by the #param flags below.
// NOTE(review): #param / #if / $MAIN are directives of the ORT WebGPU shader
// template preprocessor, not plain WGSL; q_value_t / key_value_t /
// output_value_t and TILE_SIZE are presumably injected by the shader helper —
// confirm against the C++ side.

#param components
#param feed_past_key
#param has_present_key
#param has_attention_bias
#param has_seqlen_k
#param past_present_share_buffer
#param is_unidirectional
#param tile_size_param

// Workgroup-shared tiles for the current Q and K sub-blocks.
var<workgroup> tileQ: array<q_value_t, tile_size_param * tile_size_param>;
var<workgroup> tileK: array<key_value_t, tile_size_param * tile_size_param>;
// f32 accumulator type matching the vectorization width (components).
#if components == 4
alias f32_val_t = vec4<f32>;
#elif components == 2
alias f32_val_t = vec2<f32>;
#else
alias f32_val_t = f32;
#endif

#if has_attention_bias
// Reads attention_bias[batch, head, q, k] with broadcasting on the first two
// dimensions (size-1 dims are indexed at 0).
fn loadAttentionBias(batch_idx: u32, head_idx: u32, q_idx: u32, k_idx: u32) -> output_value_t {
  // Handle broadcasting: if dimension size is 1, use index 0
  let bias_batch_idx = select(batch_idx, 0u, batch_idx >= uniforms.attn_bias_dim0);
  let bias_head_idx = select(head_idx, 0u, head_idx >= uniforms.attn_bias_dim1);
  // Calculate flat offset with broadcasting applied
  // attention_bias shape: [attn_bias_dim0, attn_bias_dim1, sequence_length, total_sequence_length]
  let offset = bias_batch_idx * uniforms.attn_bias_dim1 * uniforms.M * uniforms.N +
               bias_head_idx * uniforms.M * uniforms.N +
               q_idx * uniforms.N +
               k_idx;
  return attention_bias[offset];
}
#endif

$MAIN {
  // x holds the N and y holds the M
  // Decode the flat workgroup index into (batch*head, m-tile, n-tile).
  let m = u32(workgroup_idx / uniforms.num_total_seq_length_tile) % uniforms.num_seq_length_tile * TILE_SIZE;
  let n = (workgroup_idx % uniforms.num_total_seq_length_tile) * TILE_SIZE;
  let batch_head_idx = u32(workgroup_idx / (uniforms.num_total_seq_length_tile * uniforms.num_seq_length_tile));
  let batch_idx = batch_head_idx / uniforms.num_heads;
  let qOffset = batch_head_idx * uniforms.M * uniforms.K + m * uniforms.K;
  let sequence_length = uniforms.M;
  var total_sequence_length = uniforms.N;
#if has_seqlen_k
  // Per-batch effective total sequence length comes from the seqlen_k tensor.
  total_sequence_length = u32(seqlen_k[batch_idx]) + 1;
  var past_sequence_length: u32 = select(total_sequence_length - sequence_length, 0u, uniforms.is_first_prompt > 0);
#else
  let past_sequence_length = uniforms.past_sequence_length;
#endif
  // K (and present-K) rows are shared across n_reps query heads (GQA).
  let kOffset = (batch_head_idx / uniforms.n_reps) * uniforms.kv_sequence_length * uniforms.K;
#if has_present_key
  let presentKeyOffset = (batch_head_idx / uniforms.n_reps) * uniforms.present_sequence_length * uniforms.K;
#endif

  var value = f32_val_t(0);
  // March over the K (head-size) dimension one tile at a time.
  for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
    // Stage the Q sub-block into workgroup memory.
    if (m + local_id.y < uniforms.M && w + local_id.x < uniforms.K) {
      tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * uniforms.K + w + local_id.x];
    }
    // Stage the K sub-block, sourcing past rows from the cache and fresh rows
    // from the key input, depending on the cache configuration.
    if (n + local_id.y < uniforms.N && w + local_id.x < uniforms.K) {
      var idx = TILE_SIZE * local_id.y + local_id.x;
#if (feed_past_key && has_present_key) || (past_present_share_buffer)
      if (n + local_id.y < past_sequence_length) {
        let pastKeyOffset = (batch_head_idx / uniforms.n_reps) * uniforms.past_sequence_length * uniforms.K;
#if past_present_share_buffer
        // Shared buffer: past rows already live in present_key.
        tileK[idx] = present_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
#else
        tileK[idx] = past_key[pastKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
#endif
      } else if (n + local_id.y - past_sequence_length < uniforms.kv_sequence_length) {
        tileK[idx] = key[kOffset + (n + local_id.y - past_sequence_length) * uniforms.K + w + local_id.x];
      }
#else
      if (n + local_id.y < uniforms.kv_sequence_length) {
        tileK[idx] = key[kOffset + (n + local_id.y) * uniforms.K + w + local_id.x];
      }
#endif

#if has_present_key
      // Persist the freshly loaded K rows into the present-key cache. When the
      // buffer is shared, past rows are already in place and are skipped.
#if past_present_share_buffer
      if (n + local_id.y >= past_sequence_length && n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {
#else
      if (n + local_id.y < uniforms.kv_sequence_length + past_sequence_length) {
#endif
        present_key[presentKeyOffset + (n + local_id.y) * uniforms.K + w + local_id.x] = tileK[idx];
      }
#endif
    }
    // All tile loads must complete before any thread reads the tiles.
    workgroupBarrier();
    // Partial dot product of row (local_id.y) of Q with row (local_id.x) of K.
    for (var k: u32 = 0u; k < TILE_SIZE && w+k < uniforms.K; k++) {
      value += f32_val_t(tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k]);
    }
    // Keep the tiles stable until every thread is done with this step.
    workgroupBarrier();
  }

  if (m + local_id.y < uniforms.M && n + local_id.x < total_sequence_length) {
    let headOffset = batch_head_idx * uniforms.M * uniforms.N;
    let outputIdx = headOffset + (m + local_id.y) * uniforms.N + n + local_id.x;
    let head_idx = batch_head_idx % uniforms.num_heads;
    // Horizontal reduction of the vectorized accumulator to a scalar.
#if components == 4
    var sum: f32 = value.x + value.y + value.z + value.w;
#elif components == 2
    var sum: f32 = value.x + value.y;
#else
    var sum: f32 = value;
#endif

#if is_unidirectional
    // Apply causal masking for unidirectional attention
    let query_pos = m + local_id.y + past_sequence_length;
    let key_pos = n + local_id.x;
    if (key_pos > query_pos) {
      sum = -3.4028234663852886e+38; // Set to very negative value for masking
    }
#endif

#if has_attention_bias
    output[outputIdx] = output_value_t(sum * uniforms.alpha) + loadAttentionBias(batch_idx, head_idx, m + local_id.y, n + local_id.x);
#else
    output[outputIdx] = output_value_t(sum * uniforms.alpha);
#endif
  }
} // MAIN
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// In-place softmax over each row of the attention-score tensor x, one row per
// workgroup. Supports smooth softmax, a per-head "sink" logit, a sliding
// (local) attention window, and variable per-batch sequence lengths via
// seqlen_k. Each thread handles uniforms.elements_per_thread vectorized
// elements; max and sum are reduced across the workgroup through the shared
// arrays below.
// NOTE(review): #param / #if / $MAIN are ORT WebGPU shader-template
// directives; x_value_t / x_element_t are presumably injected by the shader
// helper — confirm against the C++ side.

#param components
#param work_group_size
#param use_smooth_softmax
#param has_seqlen_k
#param has_head_sink
#param has_sliding_window

// f32 accumulator type matching the vectorization width (components).
#if components == 4
alias f32_val_t = vec4<f32>;
#elif components == 2
alias f32_val_t = vec2<f32>;
#else
alias f32_val_t = f32;
#endif

// Per-thread partial results for the two workgroup reductions (max, then sum).
var<workgroup> thread_max: array<f32, work_group_size>;
var<workgroup> thread_sum: array<f32, work_group_size>;

$MAIN {
  let sequence_length = uniforms.sequence_length;
  let batch_idx = u32(workgroup_idx / sequence_length) / uniforms.num_heads;
  let head_idx = u32(workgroup_idx / sequence_length) % uniforms.num_heads;
  var total_sequence_length = uniforms.total_sequence_length_comp * components;
#if has_seqlen_k
  // Per-batch effective length comes from the seqlen_k tensor.
  total_sequence_length = u32(seqlen_k[batch_idx]) + 1;
  var past_sequence_length: u32 = select(total_sequence_length - sequence_length, 0u, uniforms.is_first_prompt > 0);
#else
  let past_sequence_length = uniforms.past_sequence_length;
#endif
#if has_seqlen_k
  // Causal horizon for this row: past tokens plus this query position plus 1.
  let seq_causal_length = past_sequence_length + workgroup_idx % sequence_length + 1;
#else
  let seq_causal_length = uniforms.total_sequence_length_comp;
#endif
  let local_offset = local_idx * uniforms.elements_per_thread;
  let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;

#if has_sliding_window
  // Sliding window
  let should_apply_local_window = uniforms.local_window_size >= 0 && seq_causal_length > uniforms.local_window_size;
  let start_offset = select(0, seq_causal_length - uniforms.local_window_size, should_apply_local_window);
  let effective_seq_length = select(seq_causal_length, uniforms.local_window_size, should_apply_local_window);
#else
  // No sliding window: we keep the code for sliding window in the shader but
  // using const for start_offset and should_apply_local_window will make the compiler optimize it out.
  const start_offset = 0;
  const should_apply_local_window = false;
  let effective_seq_length = seq_causal_length;
#endif

  // Pass 1: each thread computes the max over its slice of the row.
  var thread_max_vector = f32_val_t(-3.4028234663852886e+38f);
  for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {
    let actual_pos = local_offset + i + start_offset;
    if (!should_apply_local_window || actual_pos < seq_causal_length) {
      thread_max_vector = max(f32_val_t(x[offset + i + start_offset]), thread_max_vector);
    }
  }
  // Reduce the vectorized per-thread max to a scalar.
#if components == 4
  thread_max[local_idx] = max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w));
#elif components == 2
  thread_max[local_idx] = max(thread_max_vector.x, thread_max_vector.y);
#else
  thread_max[local_idx] = thread_max_vector;
#endif
  workgroupBarrier();

  // Seed the row max: the sink logit, 0 for smooth softmax, or -FLT_MAX.
#if has_head_sink
  // Handle head sink
  let sink_value: f32 = f32(head_sink[head_idx]);
  var max_value = sink_value;
#elif use_smooth_softmax
  var max_value: f32 = 0.0;
#else
  var max_value = f32(-3.4028234663852886e+38f);
#endif

  // Every thread reduces the shared maxima (redundant but synchronization-free).
  for (var i = 0u; i < work_group_size; i++) {
    max_value = max(thread_max[i], max_value);
  }
  // Pass 2: each thread accumulates exp(x - max) over its slice.
  var sum_vector = f32_val_t(0);
  for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {
    let actual_pos = local_offset + i + start_offset;
    if (!should_apply_local_window || actual_pos < seq_causal_length) {
      sum_vector += exp(f32_val_t(x[offset + i + start_offset]) - max_value);
    }
  }
#if components == 4
  thread_sum[local_idx] = sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w;
#elif components == 2
  thread_sum[local_idx] = sum_vector.x + sum_vector.y;
#else
  thread_sum[local_idx] = sum_vector;
#endif
  workgroupBarrier();
  var sum: f32 = 0;
  for (var i = 0u; i < work_group_size; i++) {
    sum += thread_sum[i];
  }

  // The sink / smooth-softmax term contributes to the denominator only.
#if has_head_sink
  sum += exp(sink_value - max_value);
#elif use_smooth_softmax
  sum += exp(-max_value);
#endif

  if (sum == 0) {
    // Degenerate row (all terms underflowed): fall back to a uniform distribution.
    for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {
      let actual_pos = local_offset + i + start_offset;
      if (actual_pos < seq_causal_length) {
        x[offset + i + start_offset] = x_value_t(x_element_t(1.0)/x_element_t(effective_seq_length));
      }
    }
  } else {
    // Normalize in place: x = exp(x - max) / sum.
    for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {
      let actual_pos = local_offset + i + start_offset;
      let pos = offset + i + start_offset;
      if (!should_apply_local_window || actual_pos < seq_causal_length) {
        var f32input = f32_val_t(x[pos]);
        x[pos] = x_value_t(exp(f32input - max_value) / sum);
      }
    }
  }

  // zero out elements outside the sliding window
  if (should_apply_local_window) {
    for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {
      let global_pos = i + local_offset;
      if (global_pos < start_offset) {
        x[offset + i] = x_value_t(x_element_t(0));
      }
    }
  }

#if has_seqlen_k
  // Zero the tail beyond the causal horizon so downstream V-weighting ignores it.
  for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length_comp; total_seq_id++) {
    x[offset + total_seq_id] = x_value_t(x_element_t(0));
  }
#endif
} // MAIN
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Splits a packed QKV tensor into separate query, key and value outputs.
// The last dimension of packed_qkv is laid out as
// [hidden_size | kv_hidden_size | kv_hidden_size]; each invocation copies one
// element to whichever output its channel index falls into.

#use guardAgainstOutOfBoundsWorkgroupSizes
#use .offsetToIndices .setByIndices .getByOffset

$MAIN {
  guardAgainstOutOfBoundsWorkgroupSizes(uniforms.input_size);
  let idx = packed_qkv.offsetToIndices(global_idx);
  let b = idx[0];
  let s = idx[1];
  let channel = idx[2];
  let element = packed_qkv.getByOffset(global_idx);
  // Section boundaries within the packed last dimension.
  let q_end = uniforms.hidden_size;
  let k_end = uniforms.hidden_size + uniforms.kv_hidden_size;
  if (channel < q_end) {
    query.setByIndices(vec3<u32>(b, s, channel), element);
  } else if (channel < k_end) {
    key.setByIndices(vec3<u32>(b, s, channel - q_end), element);
  } else {
    val.setByIndices(vec3<u32>(b, s, channel - k_end), element);
  }
} // MAIN
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Gathers qkv_input into qkv_output using per-dimension strides
// (batch_offset, head_offset, sequence_offset), optionally adding a bias
// whose slice is selected by uniforms.bias_offset.

#param has_bias

#use guardAgainstOutOfBoundsWorkgroupSizes
#use .offsetToIndices

$MAIN {
  guardAgainstOutOfBoundsWorkgroupSizes(uniforms.data_size);
  let out_idx = qkv_output.offsetToIndices(global_idx);
  // Map the (dim0, dim1, dim2, dim3) output coordinate to the flat input
  // offset via the stride uniforms.
  let batch_part = out_idx[0] * uniforms.batch_offset;
  let head_part = out_idx[1] * uniforms.head_offset;
  let seq_part = out_idx[2] * uniforms.sequence_offset;
  let src_offset = batch_part + head_part + seq_part + out_idx[3];
#if has_bias
  // Bias repeats every sequence_offset elements, shifted by bias_offset.
  let bias_idx = (src_offset % uniforms.sequence_offset) + uniforms.bias_offset;
  qkv_output[global_idx] = qkv_input[src_offset] + bias[bias_idx];
#else
  qkv_output[global_idx] = qkv_input[src_offset];
#endif
} // MAIN

0 commit comments

Comments
 (0)