55
66namespace gdn {
77static constexpr int sub_group_size = 32 ;
8- template <typename T, int k_bucket_size>
8+ template <typename T, typename StateT, int k_bucket_size>
99struct gated_delta_rule_kernel {
1010 public:
1111 static constexpr int group_size = 256 ;
@@ -23,7 +23,7 @@ struct gated_delta_rule_kernel {
2323 const T* a,
2424 const T* A_log,
2525 const T* dt_bias,
26- T * ssm_state,
26+ StateT * ssm_state,
2727 const int ssm_state_stride_0,
2828 const int * query_start_loc,
2929 const int * cache_indices,
@@ -102,7 +102,7 @@ struct gated_delta_rule_kernel {
102102 float k_local[k_bucket_size];
103103 float v_local[v_dim_per_sg];
104104
105- T * ssm_state_ptr =
105+ StateT * ssm_state_ptr =
106106 ssm_state +
107107 static_cast <int64_t >(cache_indices[batch_id]) * ssm_state_stride_0;
108108
@@ -112,10 +112,11 @@ struct gated_delta_rule_kernel {
112112 for (int j = 0 ; j < v_dim_per_sg; ++j) {
113113#pragma unroll
114114 for (int i = 0 ; i < k_bucket_size; ++i) {
115- state_local[j * k_bucket_size + i] = ssm_state_ptr
116- [num_v_heads_id * head_k_dim * head_v_dim +
117- (k_bucket_size * sg_local_id + i) +
118- (head_v_dim_id + j) * head_k_dim];
115+ state_local[j * k_bucket_size + i] =
116+ static_cast <float >(ssm_state_ptr
117+ [num_v_heads_id * head_k_dim * head_v_dim +
118+ (k_bucket_size * sg_local_id + i) +
119+ (head_v_dim_id + j) * head_k_dim]);
119120 }
120121 }
121122 } else {
@@ -237,7 +238,7 @@ struct gated_delta_rule_kernel {
237238 [num_v_heads_id * head_k_dim * head_v_dim +
238239 (k_bucket_size * sg_local_id + i) +
239240 (head_v_dim_id + j) * head_k_dim] =
240- state_local[j * k_bucket_size + i];
241+ static_cast <StateT>( state_local[j * k_bucket_size + i]) ;
241242 }
242243 }
243244 }
@@ -251,7 +252,7 @@ struct gated_delta_rule_kernel {
251252 const T* a;
252253 const T* A_log;
253254 const T* dt_bias;
254- T * ssm_state;
255+ StateT * ssm_state;
255256 const int ssm_state_stride_0;
256257 const int * query_start_loc;
257258 const int * cache_indices;
@@ -264,7 +265,7 @@ struct gated_delta_rule_kernel {
264265 const int head_v_dim;
265266};
266267
267- template <typename T, int k_bucket_size>
268+ template <typename T, typename StateT, int k_bucket_size>
268269void kernel_launcher (
269270 sycl::queue& queue,
270271 T* core_attn_out,
@@ -275,7 +276,7 @@ void kernel_launcher(
275276 const T* a,
276277 const T* A_log,
277278 const T* dt_bias,
278- T * ssm_state,
279+ StateT * ssm_state,
279280 const int ssm_state_stride_0,
280281 const int * query_start_loc,
281282 const int * cache_indices,
@@ -286,7 +287,7 @@ void kernel_launcher(
286287 const int head_k_dim,
287288 const int num_v_heads,
288289 const int head_v_dim) {
289- using KERNEL = gated_delta_rule_kernel<T, k_bucket_size>;
290+ using KERNEL = gated_delta_rule_kernel<T, StateT, k_bucket_size>;
290291 auto range = KERNEL::get_nd_range (batch_size, num_v_heads, head_v_dim);
291292 assert (head_v_dim % KERNEL::v_dim_per_group == 0 );
292293 queue.submit ([&](sycl::handler& cgh) {
@@ -351,8 +352,8 @@ void gated_delta_rule(
351352 TORCH_CHECK (head_k_dim % sub_group_size == 0 );
352353 const int k_bucket_size = head_k_dim / sub_group_size;
353354
354- #define KERNEL_LAUNCHER (scalar_t, k_bucket_size ) \
355- kernel_launcher<scalar_t , k_bucket_size>( \
355+ #define KERNEL_LAUNCHER (scalar_t, state_scalar_t, k_bucket_size ) \
356+ kernel_launcher<scalar_t , state_scalar_t , k_bucket_size>( \
356357 queue, \
357358 reinterpret_cast <scalar_t *>(core_attn_out.data_ptr ()), \
358359 reinterpret_cast <scalar_t *>(q.data_ptr ()), \
@@ -362,7 +363,7 @@ void gated_delta_rule(
362363 reinterpret_cast <scalar_t *>(a.data_ptr ()), \
363364 reinterpret_cast <scalar_t *>(A_log.data_ptr ()), \
364365 reinterpret_cast <scalar_t *>(dt_bias.data_ptr ()), \
365- reinterpret_cast <scalar_t *>(ssm_state.data_ptr ()), \
366+ reinterpret_cast <state_scalar_t *>(ssm_state.data_ptr ()), \
366367 ssm_state_stride_0, \
367368 reinterpret_cast <int *>(query_start_loc.data_ptr ()), \
368369 reinterpret_cast <int *>(cache_indices.data_ptr ()), \
@@ -376,34 +377,54 @@ void gated_delta_rule(
376377 num_v_heads, \
377378 head_v_dim);
378379
// Expands to a switch that instantiates KERNEL_LAUNCHER with the compile-time
// bucket size matching the runtime `k_bucket_size` argument. Only 1, 2, 4 and
// 8 are instantiated; since k_bucket_size == head_k_dim / sub_group_size and
// sub_group_size == 32, that corresponds to head_k_dim of 32, 64, 128 or 256.
// Any other value fails the TORCH_CHECK with a message naming the rejected
// value instead of an anonymous "false" failure.
// KERNEL_LAUNCHER's expansion already ends in ';', so none is added here.
#define BUCKET_DISPATCH(scalar_t, state_scalar_t, k_bucket_size)            \
  switch (k_bucket_size) {                                                  \
    case 1:                                                                 \
      KERNEL_LAUNCHER(scalar_t, state_scalar_t, 1)                          \
      break;                                                                \
    case 2:                                                                 \
      KERNEL_LAUNCHER(scalar_t, state_scalar_t, 2)                          \
      break;                                                                \
    case 4:                                                                 \
      KERNEL_LAUNCHER(scalar_t, state_scalar_t, 4)                          \
      break;                                                                \
    case 8:                                                                 \
      KERNEL_LAUNCHER(scalar_t, state_scalar_t, 8)                          \
      break;                                                                \
    default:                                                                \
      TORCH_CHECK(                                                          \
          false,                                                            \
          "gated_delta_rule: unsupported k_bucket_size ",                   \
          k_bucket_size,                                                    \
          "; expected 1, 2, 4 or 8 (head_k_dim of 32, 64, 128 or 256)");    \
  }
396397
// Second-level dispatch on the element type of `ssm_state`. For the given
// activation type `scalar_t`, selects `state_scalar_t` from the tensor's
// runtime dtype (float, bfloat16 or half) and forwards to BUCKET_DISPATCH
// with the enclosing scope's `k_bucket_size`. Any other dtype is rejected
// through TORCH_CHECK. Wrapped in do/while(0) so the invocation is written
// as a single statement followed by ';'.
// BUCKET_DISPATCH expands to a complete switch statement, so each case needs
// its own trailing `break` to leave the outer switch.
#define DISPATCH_STATE_DTYPE(scalar_t)                                      \
  do {                                                                      \
    switch (ssm_state.scalar_type()) {                                      \
      case at::kFloat: {                                                    \
        using state_scalar_t = float;                                       \
        BUCKET_DISPATCH(scalar_t, state_scalar_t, k_bucket_size)            \
        break;                                                              \
      }                                                                     \
      case at::kBFloat16: {                                                 \
        using state_scalar_t = sycl::ext::oneapi::bfloat16;                 \
        BUCKET_DISPATCH(scalar_t, state_scalar_t, k_bucket_size)            \
        break;                                                              \
      }                                                                     \
      case at::kHalf: {                                                     \
        using state_scalar_t = sycl::half;                                  \
        BUCKET_DISPATCH(scalar_t, state_scalar_t, k_bucket_size)            \
        break;                                                              \
      }                                                                     \
      default:                                                              \
        TORCH_CHECK(                                                        \
            false,                                                          \
            " ssm_state dtype must be float32/float16/bfloat16, but got ",  \
            ssm_state.scalar_type());                                       \
    }                                                                       \
  } while (0)
416+
397417 if (core_attn_out.scalar_type () == at::kBFloat16 ) {
398418 using scalar_t = sycl::ext::oneapi::bfloat16;
399- BUCKET_DISPATCH (scalar_t , k_bucket_size)
419+ DISPATCH_STATE_DTYPE (scalar_t );
400420 } else if (core_attn_out.scalar_type () == at::kHalf ) {
401421 using scalar_t = sycl::half;
402- BUCKET_DISPATCH (scalar_t , k_bucket_size)
422+ DISPATCH_STATE_DTYPE (scalar_t );
403423 } else {
404424 using scalar_t = float ;
405- BUCKET_DISPATCH (scalar_t , k_bucket_size)
425+ DISPATCH_STATE_DTYPE (scalar_t );
406426 }
427+ #undef DISPATCH_STATE_DTYPE
407428#undef BUCKET_DISPATCH
408429#undef KERNEL_LAUNCHER
409430}
0 commit comments