Skip to content

Commit 60e0e85

Browse files
authored
support qwen3.5 input layout (vllm-project#190)
* support qwen3.5 input layout
  Signed-off-by: mayuyuace <qiming1.zhang@intel.com>
* rename pytest param name and add assert to v_heads
  Signed-off-by: mayuyuace <qiming1.zhang@intel.com>
* skip long seqlen for now
  Signed-off-by: mayuyuace <qiming1.zhang@intel.com>
* add pytest.skip
  Signed-off-by: mayuyuace <qiming1.zhang@intel.com>
---------
Signed-off-by: mayuyuace <qiming1.zhang@intel.com>
1 parent 9d3225e commit 60e0e85

6 files changed

Lines changed: 218 additions & 96 deletions

File tree

csrc/xpu/gdn_attn/causal_conv1d.hpp

mode changed: 100644 → 100755
Lines changed: 80 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
namespace gdn {
99

10-
template <typename T, int Width>
10+
template <typename T, int Width, bool ReorderInput>
1111
struct causal_conv1d_kernel {
1212
public:
1313
static constexpr int sub_group_size = 32;
@@ -105,16 +105,29 @@ struct causal_conv1d_kernel {
105105
int qkvz_dim_id = qkvz_elems_id % qkvz_dim;
106106

107107
// reorder b,a
108-
if (qkvz_dim_id < (num_v_heads / num_k_heads)) {
109-
int step =
110-
token_id * num_v_heads + k_heads_id * num_v_heads / num_k_heads;
111-
const int ba_elems_per_item =
112-
sycl::min(elems_per_item, num_v_heads / num_k_heads);
108+
if constexpr (ReorderInput) {
109+
if (qkvz_elems_id < num_v_heads) {
110+
int step = token_id * num_v_heads;
113111
#pragma unroll
114-
for (int e = 0; e < ba_elems_per_item; ++e) {
115-
b_out[step + qkvz_dim_id + e] = mixed_ba[step * 2 + qkvz_dim_id + e];
116-
a_out[step + qkvz_dim_id + e] =
117-
mixed_ba[step * 2 + num_v_heads / num_k_heads + qkvz_dim_id + e];
112+
for (int e = 0; e < elems_per_item; ++e) {
113+
b_out[step + qkvz_elems_id + e] =
114+
mixed_ba[step * 2 + qkvz_elems_id + e];
115+
a_out[step + qkvz_elems_id + e] =
116+
mixed_ba[step * 2 + num_v_heads + qkvz_dim_id + e];
117+
}
118+
}
119+
} else {
120+
if (qkvz_dim_id < (num_v_heads / num_k_heads)) {
121+
int step =
122+
token_id * num_v_heads + k_heads_id * num_v_heads / num_k_heads;
123+
const int ba_elems_per_item =
124+
sycl::min(elems_per_item, num_v_heads / num_k_heads);
125+
#pragma unroll
126+
for (int e = 0; e < ba_elems_per_item; ++e) {
127+
b_out[step + qkvz_dim_id + e] = mixed_ba[step * 2 + qkvz_dim_id + e];
128+
a_out[step + qkvz_dim_id + e] =
129+
mixed_ba[step * 2 + num_v_heads / num_k_heads + qkvz_dim_id + e];
130+
}
118131
}
119132
}
120133

@@ -138,19 +151,37 @@ struct causal_conv1d_kernel {
138151
return;
139152
}
140153

154+
int mixed_qkvz_id = qkvz_elems_id;
155+
141156
bool is_q = false;
142157
bool is_k = false;
143158
bool is_v = false;
144159
bool is_z = false;
145160

146161
if (qkvz_dim_id < q_dim) {
147162
is_q = true;
163+
if constexpr (ReorderInput) {
164+
mixed_qkvz_id = k_heads_id * k_dim + qkvz_dim_id;
165+
}
148166
} else if (qkvz_dim_id < q_dim + k_dim) {
149167
is_k = true;
168+
if constexpr (ReorderInput) {
169+
mixed_qkvz_id = num_k_heads * head_k_dim + k_heads_id * k_dim +
170+
qkvz_dim_id - (q_dim);
171+
}
150172
} else if (qkvz_dim_id < q_dim + k_dim + v_dim) {
151173
is_v = true;
174+
if constexpr (ReorderInput) {
175+
mixed_qkvz_id = 2 * num_k_heads * head_k_dim + k_heads_id * v_dim +
176+
qkvz_dim_id - (q_dim + k_dim);
177+
}
152178
} else {
153179
is_z = true;
180+
if constexpr (ReorderInput) {
181+
mixed_qkvz_id = 2 * num_k_heads * head_k_dim +
182+
num_v_heads * head_v_dim + k_heads_id * z_dim +
183+
qkvz_dim_id - (q_dim + k_dim + v_dim);
184+
}
154185
}
155186

156187
// reorder z
@@ -160,7 +191,7 @@ struct causal_conv1d_kernel {
160191
#pragma unroll
161192
for (int e = 0; e < elems_per_item; ++e) {
162193
z_out[token_id * num_k_heads * z_dim + z_elems_id + e] =
163-
mixed_qkvz[token_id * qkvz_elems + qkvz_elems_id + e];
194+
mixed_qkvz[token_id * qkvz_elems + mixed_qkvz_id + e];
164195
}
165196
return;
166197
}
@@ -224,7 +255,7 @@ struct causal_conv1d_kernel {
224255
#pragma unroll
225256
for (int e = 0; e < elems_per_item; ++e) {
226257
local_input[Width * e + states_load_len + i] = mixed_qkvz
227-
[(token_id - input_load_len + 1 + i) * qkvz_elems + qkvz_elems_id +
258+
[(token_id - input_load_len + 1 + i) * qkvz_elems + mixed_qkvz_id +
228259
e];
229260
}
230261
}
@@ -416,7 +447,7 @@ struct update_states_kernel {
416447
const int batch_size;
417448
};
418449

419-
template <typename T, int Width>
450+
template <typename T, int Width, bool ReorderInput>
420451
void kernel_launcher(
421452
sycl::queue& queue,
422453
T* q_out,
@@ -447,9 +478,10 @@ void kernel_launcher(
447478
const int& conv_elems,
448479
const int& num_prefills,
449480
const int& num_decodes) {
450-
using KERNEL_MAIN = causal_conv1d_kernel<T, Width>;
481+
using KERNEL_MAIN = causal_conv1d_kernel<T, Width, ReorderInput>;
451482
auto range_main = KERNEL_MAIN::get_nd_range(num_actual_tokens, qkvz_elems);
452483
assert(head_k_dim % KERNEL_MAIN::elems_per_item == 0);
484+
assert(num_v_heads % KERNEL_MAIN::elems_per_item == 0);
453485
queue.submit([&](sycl::handler& cgh) {
454486
KERNEL_MAIN task(
455487
q_out,
@@ -528,7 +560,8 @@ void causal_conv1d(
528560
const ActMode& act_mode, // silu or swish
529561
const int& pad_slot_id, // -1
530562
const int num_prefills,
531-
const int num_decodes) {
563+
const int num_decodes,
564+
const bool reorder_input) {
532565
if (num_prefills == 0 && num_decodes == 0) {
533566
return;
534567
}
@@ -550,8 +583,8 @@ void causal_conv1d(
550583
{batch_size, width - 1, conv_elems},
551584
torch::dtype(dtype).device(device).requires_grad(false));
552585

553-
#define KERNEL_LAUNCHER(scalar_t, width) \
554-
kernel_launcher<scalar_t, width>( \
586+
#define KERNEL_LAUNCHER(scalar_t, width, reorder_input) \
587+
kernel_launcher<scalar_t, width, reorder_input>( \
555588
queue, \
556589
reinterpret_cast<scalar_t*>(q_out.data_ptr()), \
557590
reinterpret_cast<scalar_t*>(k_out.data_ptr()), \
@@ -586,37 +619,45 @@ void causal_conv1d(
586619
num_prefills, \
587620
num_decodes);
588621

589-
#define WIDTH_DISPATCH(scalar_t, width) \
590-
switch (width) { \
591-
case 1: \
592-
KERNEL_LAUNCHER(scalar_t, 1) \
593-
break; \
594-
case 2: \
595-
KERNEL_LAUNCHER(scalar_t, 2) \
596-
break; \
597-
case 3: \
598-
KERNEL_LAUNCHER(scalar_t, 3) \
599-
break; \
600-
case 4: \
601-
KERNEL_LAUNCHER(scalar_t, 4) \
602-
break; \
603-
case 5: \
604-
KERNEL_LAUNCHER(scalar_t, 5) \
605-
break; \
606-
default: \
607-
break; \
622+
#define WIDTH_DISPATCH(scalar_t, width, reorder_input) \
623+
switch (width) { \
624+
case 1: \
625+
KERNEL_LAUNCHER(scalar_t, 1, reorder_input) \
626+
break; \
627+
case 2: \
628+
KERNEL_LAUNCHER(scalar_t, 2, reorder_input) \
629+
break; \
630+
case 3: \
631+
KERNEL_LAUNCHER(scalar_t, 3, reorder_input) \
632+
break; \
633+
case 4: \
634+
KERNEL_LAUNCHER(scalar_t, 4, reorder_input) \
635+
break; \
636+
case 5: \
637+
KERNEL_LAUNCHER(scalar_t, 5, reorder_input) \
638+
break; \
639+
default: \
640+
break; \
641+
}
642+
643+
#define SPLIT_DISPATCH(scalar_t, width, reorder_input) \
644+
if (reorder_input) { \
645+
WIDTH_DISPATCH(scalar_t, width, true) \
646+
} else { \
647+
WIDTH_DISPATCH(scalar_t, width, false) \
608648
}
609649

610650
if (mixed_qkvz.scalar_type() == at::kBFloat16) {
611651
using scalar_t = sycl::ext::oneapi::bfloat16;
612-
WIDTH_DISPATCH(scalar_t, width)
652+
SPLIT_DISPATCH(scalar_t, width, reorder_input)
613653
} else if (mixed_qkvz.scalar_type() == at::kHalf) {
614654
using scalar_t = sycl::half;
615-
WIDTH_DISPATCH(scalar_t, width)
655+
SPLIT_DISPATCH(scalar_t, width, reorder_input)
616656
} else {
617657
using scalar_t = float;
618-
WIDTH_DISPATCH(scalar_t, width)
658+
SPLIT_DISPATCH(scalar_t, width, reorder_input)
619659
}
660+
#undef SPLIT_DISPATCH
620661
#undef WIDTH_DISPATCH
621662
#undef KERNEL_LAUNCHER
622663
}

csrc/xpu/gdn_attn/gdn_attn_interface.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ void gdn_attention(
4747
const torch::Tensor& non_spec_query_start_loc, // [batch_size + 1]
4848
const torch::Tensor& non_spec_state_indices_tensor, // [batch_size]
4949
const int64_t num_actual_tokens,
50-
const int64_t tp_size) {
50+
const int64_t tp_size,
51+
const bool reorder_input) {
5152
TORCH_CHECK(
5253
core_attn_out.is_contiguous(), "core_attn_out must be contiguous");
5354
TORCH_CHECK(z.is_contiguous(), "z must be contiguous");
@@ -144,7 +145,8 @@ void gdn_attention(
144145
act_mode, \
145146
pad_slot_id, \
146147
num_prefills, \
147-
num_decodes); \
148+
num_decodes, \
149+
reorder_input); \
148150
gdn::gated_delta_rule( \
149151
queue, \
150152
core_attn_out, \
@@ -203,7 +205,8 @@ void gdn_attention(
203205
act_mode,
204206
pad_slot_id,
205207
num_prefills,
206-
num_decodes);
208+
num_decodes,
209+
reorder_input);
207210

208211
chunk_gated_delta_rule_xe2(
209212
queue,

0 commit comments

Comments
 (0)