77
88namespace gdn {
99
10- template <typename T, int Width>
10+ template <typename T, int Width, bool ReorderInput >
1111struct causal_conv1d_kernel {
1212 public:
1313 static constexpr int sub_group_size = 32 ;
@@ -105,16 +105,29 @@ struct causal_conv1d_kernel {
105105 int qkvz_dim_id = qkvz_elems_id % qkvz_dim;
106106
107107 // reorder b,a
108- if (qkvz_dim_id < (num_v_heads / num_k_heads)) {
109- int step =
110- token_id * num_v_heads + k_heads_id * num_v_heads / num_k_heads;
111- const int ba_elems_per_item =
112- sycl::min (elems_per_item, num_v_heads / num_k_heads);
108+ if constexpr (ReorderInput) {
109+ if (qkvz_elems_id < num_v_heads) {
110+ int step = token_id * num_v_heads;
113111#pragma unroll
114- for (int e = 0 ; e < ba_elems_per_item; ++e) {
115- b_out[step + qkvz_dim_id + e] = mixed_ba[step * 2 + qkvz_dim_id + e];
116- a_out[step + qkvz_dim_id + e] =
117- mixed_ba[step * 2 + num_v_heads / num_k_heads + qkvz_dim_id + e];
112+ for (int e = 0 ; e < elems_per_item; ++e) {
113+ b_out[step + qkvz_elems_id + e] =
114+ mixed_ba[step * 2 + qkvz_elems_id + e];
115+ a_out[step + qkvz_elems_id + e] =
116+ mixed_ba[step * 2 + num_v_heads + qkvz_dim_id + e];
117+ }
118+ }
119+ } else {
120+ if (qkvz_dim_id < (num_v_heads / num_k_heads)) {
121+ int step =
122+ token_id * num_v_heads + k_heads_id * num_v_heads / num_k_heads;
123+ const int ba_elems_per_item =
124+ sycl::min (elems_per_item, num_v_heads / num_k_heads);
125+ #pragma unroll
126+ for (int e = 0 ; e < ba_elems_per_item; ++e) {
127+ b_out[step + qkvz_dim_id + e] = mixed_ba[step * 2 + qkvz_dim_id + e];
128+ a_out[step + qkvz_dim_id + e] =
129+ mixed_ba[step * 2 + num_v_heads / num_k_heads + qkvz_dim_id + e];
130+ }
118131 }
119132 }
120133
@@ -138,19 +151,37 @@ struct causal_conv1d_kernel {
138151 return ;
139152 }
140153
154+ int mixed_qkvz_id = qkvz_elems_id;
155+
141156 bool is_q = false ;
142157 bool is_k = false ;
143158 bool is_v = false ;
144159 bool is_z = false ;
145160
146161 if (qkvz_dim_id < q_dim) {
147162 is_q = true ;
163+ if constexpr (ReorderInput) {
164+ mixed_qkvz_id = k_heads_id * k_dim + qkvz_dim_id;
165+ }
148166 } else if (qkvz_dim_id < q_dim + k_dim) {
149167 is_k = true ;
168+ if constexpr (ReorderInput) {
169+ mixed_qkvz_id = num_k_heads * head_k_dim + k_heads_id * k_dim +
170+ qkvz_dim_id - (q_dim);
171+ }
150172 } else if (qkvz_dim_id < q_dim + k_dim + v_dim) {
151173 is_v = true ;
174+ if constexpr (ReorderInput) {
175+ mixed_qkvz_id = 2 * num_k_heads * head_k_dim + k_heads_id * v_dim +
176+ qkvz_dim_id - (q_dim + k_dim);
177+ }
152178 } else {
153179 is_z = true ;
180+ if constexpr (ReorderInput) {
181+ mixed_qkvz_id = 2 * num_k_heads * head_k_dim +
182+ num_v_heads * head_v_dim + k_heads_id * z_dim +
183+ qkvz_dim_id - (q_dim + k_dim + v_dim);
184+ }
154185 }
155186
156187 // reorder z
@@ -160,7 +191,7 @@ struct causal_conv1d_kernel {
160191#pragma unroll
161192 for (int e = 0 ; e < elems_per_item; ++e) {
162193 z_out[token_id * num_k_heads * z_dim + z_elems_id + e] =
163- mixed_qkvz[token_id * qkvz_elems + qkvz_elems_id + e];
194+ mixed_qkvz[token_id * qkvz_elems + mixed_qkvz_id + e];
164195 }
165196 return ;
166197 }
@@ -224,7 +255,7 @@ struct causal_conv1d_kernel {
224255#pragma unroll
225256 for (int e = 0 ; e < elems_per_item; ++e) {
226257 local_input[Width * e + states_load_len + i] = mixed_qkvz
227- [(token_id - input_load_len + 1 + i) * qkvz_elems + qkvz_elems_id +
258+ [(token_id - input_load_len + 1 + i) * qkvz_elems + mixed_qkvz_id +
228259 e];
229260 }
230261 }
@@ -416,7 +447,7 @@ struct update_states_kernel {
416447 const int batch_size;
417448};
418449
419- template <typename T, int Width>
450+ template <typename T, int Width, bool ReorderInput >
420451void kernel_launcher (
421452 sycl::queue& queue,
422453 T* q_out,
@@ -447,9 +478,10 @@ void kernel_launcher(
447478 const int & conv_elems,
448479 const int & num_prefills,
449480 const int & num_decodes) {
450- using KERNEL_MAIN = causal_conv1d_kernel<T, Width>;
481+ using KERNEL_MAIN = causal_conv1d_kernel<T, Width, ReorderInput >;
451482 auto range_main = KERNEL_MAIN::get_nd_range (num_actual_tokens, qkvz_elems);
452483 assert (head_k_dim % KERNEL_MAIN::elems_per_item == 0 );
484+ assert (num_v_heads % KERNEL_MAIN::elems_per_item == 0 );
453485 queue.submit ([&](sycl::handler& cgh) {
454486 KERNEL_MAIN task (
455487 q_out,
@@ -528,7 +560,8 @@ void causal_conv1d(
528560 const ActMode& act_mode, // silu or swish
529561 const int & pad_slot_id, // -1
530562 const int num_prefills,
531- const int num_decodes) {
563+ const int num_decodes,
564+ const bool reorder_input) {
532565 if (num_prefills == 0 && num_decodes == 0 ) {
533566 return ;
534567 }
@@ -550,8 +583,8 @@ void causal_conv1d(
550583 {batch_size, width - 1 , conv_elems},
551584 torch::dtype (dtype).device (device).requires_grad (false ));
552585
553- #define KERNEL_LAUNCHER (scalar_t, width ) \
554- kernel_launcher<scalar_t , width>( \
586+ #define KERNEL_LAUNCHER (scalar_t, width, reorder_input ) \
587+ kernel_launcher<scalar_t , width, reorder_input>( \
555588 queue, \
556589 reinterpret_cast <scalar_t *>(q_out.data_ptr ()), \
557590 reinterpret_cast <scalar_t *>(k_out.data_ptr ()), \
@@ -586,37 +619,45 @@ void causal_conv1d(
586619 num_prefills, \
587620 num_decodes);
588621
589- #define WIDTH_DISPATCH (scalar_t, width ) \
590- switch (width) { \
591- case 1 : \
592- KERNEL_LAUNCHER (scalar_t , 1 ) \
593- break ; \
594- case 2 : \
595- KERNEL_LAUNCHER (scalar_t , 2 ) \
596- break ; \
597- case 3 : \
598- KERNEL_LAUNCHER (scalar_t , 3 ) \
599- break ; \
600- case 4 : \
601- KERNEL_LAUNCHER (scalar_t , 4 ) \
602- break ; \
603- case 5 : \
604- KERNEL_LAUNCHER (scalar_t , 5 ) \
605- break ; \
606- default : \
607- break ; \
622+ #define WIDTH_DISPATCH (scalar_t, width, reorder_input ) \
623+ switch (width) { \
624+ case 1 : \
625+ KERNEL_LAUNCHER (scalar_t , 1 , reorder_input) \
626+ break ; \
627+ case 2 : \
628+ KERNEL_LAUNCHER (scalar_t , 2 , reorder_input) \
629+ break ; \
630+ case 3 : \
631+ KERNEL_LAUNCHER (scalar_t , 3 , reorder_input) \
632+ break ; \
633+ case 4 : \
634+ KERNEL_LAUNCHER (scalar_t , 4 , reorder_input) \
635+ break ; \
636+ case 5 : \
637+ KERNEL_LAUNCHER (scalar_t , 5 , reorder_input) \
638+ break ; \
639+ default : \
640+ break ; \
641+ }
642+
643+ #define SPLIT_DISPATCH (scalar_t, width, reorder_input ) \
644+ if (reorder_input) { \
645+ WIDTH_DISPATCH (scalar_t , width, true ) \
646+ } else { \
647+ WIDTH_DISPATCH (scalar_t , width, false ) \
608648 }
609649
610650 if (mixed_qkvz.scalar_type () == at::kBFloat16 ) {
611651 using scalar_t = sycl::ext::oneapi::bfloat16;
612- WIDTH_DISPATCH (scalar_t , width)
652+ SPLIT_DISPATCH (scalar_t , width, reorder_input )
613653 } else if (mixed_qkvz.scalar_type () == at::kHalf ) {
614654 using scalar_t = sycl::half;
615- WIDTH_DISPATCH (scalar_t , width)
655+ SPLIT_DISPATCH (scalar_t , width, reorder_input )
616656 } else {
617657 using scalar_t = float ;
618- WIDTH_DISPATCH (scalar_t , width)
658+ SPLIT_DISPATCH (scalar_t , width, reorder_input )
619659 }
660+ #undef SPLIT_DISPATCH
620661#undef WIDTH_DISPATCH
621662#undef KERNEL_LAUNCHER
622663}
0 commit comments