Skip to content

Commit 16333ec

Browse files
committed
Support f32 dtype ssm_state in gdn kernel for qwen3.5
Signed-off-by: yangqun <qun.yang@intel.com>
1 parent cd49ea7 commit 16333ec

File tree

3 files changed

+132
-91
lines changed

3 files changed

+132
-91
lines changed

csrc/xpu/gdn_attn/gated_delta_rule.hpp

Lines changed: 55 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
namespace gdn {
77
static constexpr int sub_group_size = 32;
8-
template <typename T, int k_bucket_size>
8+
template <typename T, typename StateT, int k_bucket_size>
99
struct gated_delta_rule_kernel {
1010
public:
1111
static constexpr int group_size = 256;
@@ -23,7 +23,7 @@ struct gated_delta_rule_kernel {
2323
const T* a,
2424
const T* A_log,
2525
const T* dt_bias,
26-
T* ssm_state,
26+
StateT* ssm_state,
2727
const int ssm_state_stride_0,
2828
const int* query_start_loc,
2929
const int* cache_indices,
@@ -102,7 +102,7 @@ struct gated_delta_rule_kernel {
102102
float k_local[k_bucket_size];
103103
float v_local[v_dim_per_sg];
104104

105-
T* ssm_state_ptr =
105+
StateT* ssm_state_ptr =
106106
ssm_state +
107107
static_cast<int64_t>(cache_indices[batch_id]) * ssm_state_stride_0;
108108

@@ -112,10 +112,11 @@ struct gated_delta_rule_kernel {
112112
for (int j = 0; j < v_dim_per_sg; ++j) {
113113
#pragma unroll
114114
for (int i = 0; i < k_bucket_size; ++i) {
115-
state_local[j * k_bucket_size + i] = ssm_state_ptr
116-
[num_v_heads_id * head_k_dim * head_v_dim +
117-
(k_bucket_size * sg_local_id + i) +
118-
(head_v_dim_id + j) * head_k_dim];
115+
state_local[j * k_bucket_size + i] =
116+
static_cast<float>(ssm_state_ptr
117+
[num_v_heads_id * head_k_dim * head_v_dim +
118+
(k_bucket_size * sg_local_id + i) +
119+
(head_v_dim_id + j) * head_k_dim]);
119120
}
120121
}
121122
} else {
@@ -237,7 +238,7 @@ struct gated_delta_rule_kernel {
237238
[num_v_heads_id * head_k_dim * head_v_dim +
238239
(k_bucket_size * sg_local_id + i) +
239240
(head_v_dim_id + j) * head_k_dim] =
240-
state_local[j * k_bucket_size + i];
241+
static_cast<StateT>(state_local[j * k_bucket_size + i]);
241242
}
242243
}
243244
}
@@ -251,7 +252,7 @@ struct gated_delta_rule_kernel {
251252
const T* a;
252253
const T* A_log;
253254
const T* dt_bias;
254-
T* ssm_state;
255+
StateT* ssm_state;
255256
const int ssm_state_stride_0;
256257
const int* query_start_loc;
257258
const int* cache_indices;
@@ -264,7 +265,7 @@ struct gated_delta_rule_kernel {
264265
const int head_v_dim;
265266
};
266267

267-
template <typename T, int k_bucket_size>
268+
template <typename T, typename StateT, int k_bucket_size>
268269
void kernel_launcher(
269270
sycl::queue& queue,
270271
T* core_attn_out,
@@ -275,7 +276,7 @@ void kernel_launcher(
275276
const T* a,
276277
const T* A_log,
277278
const T* dt_bias,
278-
T* ssm_state,
279+
StateT* ssm_state,
279280
const int ssm_state_stride_0,
280281
const int* query_start_loc,
281282
const int* cache_indices,
@@ -286,7 +287,7 @@ void kernel_launcher(
286287
const int head_k_dim,
287288
const int num_v_heads,
288289
const int head_v_dim) {
289-
using KERNEL = gated_delta_rule_kernel<T, k_bucket_size>;
290+
using KERNEL = gated_delta_rule_kernel<T, StateT, k_bucket_size>;
290291
auto range = KERNEL::get_nd_range(batch_size, num_v_heads, head_v_dim);
291292
assert(head_v_dim % KERNEL::v_dim_per_group == 0);
292293
queue.submit([&](sycl::handler& cgh) {
@@ -351,8 +352,8 @@ void gated_delta_rule(
351352
TORCH_CHECK(head_k_dim % sub_group_size == 0);
352353
const int k_bucket_size = head_k_dim / sub_group_size;
353354

354-
#define KERNEL_LAUNCHER(scalar_t, k_bucket_size) \
355-
kernel_launcher<scalar_t, k_bucket_size>( \
355+
#define KERNEL_LAUNCHER(scalar_t, state_scalar_t, k_bucket_size) \
356+
kernel_launcher<scalar_t, state_scalar_t, k_bucket_size>( \
356357
queue, \
357358
reinterpret_cast<scalar_t*>(core_attn_out.data_ptr()), \
358359
reinterpret_cast<scalar_t*>(q.data_ptr()), \
@@ -362,7 +363,7 @@ void gated_delta_rule(
362363
reinterpret_cast<scalar_t*>(a.data_ptr()), \
363364
reinterpret_cast<scalar_t*>(A_log.data_ptr()), \
364365
reinterpret_cast<scalar_t*>(dt_bias.data_ptr()), \
365-
reinterpret_cast<scalar_t*>(ssm_state.data_ptr()), \
366+
reinterpret_cast<state_scalar_t*>(ssm_state.data_ptr()), \
366367
ssm_state_stride_0, \
367368
reinterpret_cast<int*>(query_start_loc.data_ptr()), \
368369
reinterpret_cast<int*>(cache_indices.data_ptr()), \
@@ -376,34 +377,54 @@ void gated_delta_rule(
376377
num_v_heads, \
377378
head_v_dim);
378379

379-
#define BUCKET_DISPATCH(scalar_t, k_bucket_size) \
380-
switch (k_bucket_size) { \
381-
case 1: \
382-
KERNEL_LAUNCHER(scalar_t, 1) \
383-
break; \
384-
case 2: \
385-
KERNEL_LAUNCHER(scalar_t, 2) \
386-
break; \
387-
case 4: \
388-
KERNEL_LAUNCHER(scalar_t, 4) \
389-
break; \
390-
case 8: \
391-
KERNEL_LAUNCHER(scalar_t, 8) \
392-
break; \
393-
default: \
394-
TORCH_CHECK(false); \
380+
#define BUCKET_DISPATCH(scalar_t, state_scalar_t, k_bucket_size) \
381+
switch (k_bucket_size) { \
382+
case 1: \
383+
KERNEL_LAUNCHER(scalar_t, state_scalar_t, 1) \
384+
break; \
385+
case 2: \
386+
KERNEL_LAUNCHER(scalar_t, state_scalar_t, 2) \
387+
break; \
388+
case 4: \
389+
KERNEL_LAUNCHER(scalar_t, state_scalar_t, 4) \
390+
break; \
391+
case 8: \
392+
KERNEL_LAUNCHER(scalar_t, state_scalar_t, 8) \
393+
break; \
394+
default: \
395+
TORCH_CHECK(false); \
395396
}
396397

398+
#define DISPATCH_STATE_DTYPE(scalar_t) \
399+
do { \
400+
if (ssm_state.scalar_type() == at::kFloat) { \
401+
using state_scalar_t = float; \
402+
BUCKET_DISPATCH(scalar_t, state_scalar_t, k_bucket_size) \
403+
} else if (ssm_state.scalar_type() == at::kBFloat16) { \
404+
using state_scalar_t = sycl::ext::oneapi::bfloat16; \
405+
BUCKET_DISPATCH(scalar_t, state_scalar_t, k_bucket_size) \
406+
} else if (ssm_state.scalar_type() == at::kHalf) { \
407+
using state_scalar_t = sycl::half; \
408+
BUCKET_DISPATCH(scalar_t, state_scalar_t, k_bucket_size) \
409+
} else { \
410+
TORCH_CHECK( \
411+
false, \
412+
"ssm_state dtype must be float32/float16/bfloat16, but got ", \
413+
ssm_state.scalar_type()); \
414+
} \
415+
} while (0)
416+
397417
if (core_attn_out.scalar_type() == at::kBFloat16) {
398418
using scalar_t = sycl::ext::oneapi::bfloat16;
399-
BUCKET_DISPATCH(scalar_t, k_bucket_size)
419+
DISPATCH_STATE_DTYPE(scalar_t);
400420
} else if (core_attn_out.scalar_type() == at::kHalf) {
401421
using scalar_t = sycl::half;
402-
BUCKET_DISPATCH(scalar_t, k_bucket_size)
422+
DISPATCH_STATE_DTYPE(scalar_t);
403423
} else {
404424
using scalar_t = float;
405-
BUCKET_DISPATCH(scalar_t, k_bucket_size)
425+
DISPATCH_STATE_DTYPE(scalar_t);
406426
}
427+
#undef DISPATCH_STATE_DTYPE
407428
#undef BUCKET_DISPATCH
408429
#undef KERNEL_LAUNCHER
409430
}

0 commit comments

Comments
 (0)