@@ -859,169 +859,175 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel(
   const uint32_t idx = bx * bdy + ty;
   const RoPEIdType pos = pos_ids[idx];
 
-  // Compute page location for this token
-  uint32_t page_iter, entry_idx;
-  paged_kv_like.page_size.divmod(
-      paged_kv_like.indptr[batch_indices[idx]] * paged_kv_like.page_size + positions[idx],
-      page_iter, entry_idx);
-
-  const int half_rope_dim = rope_dim / 2;
-  // Load cos/sin for RoPE processing blocks only
-  if ((tx * vec_size < rope_dim) && (by < k_rope_end)) {
-    int sin_offset = rope_dim / 2;
-    int vec_idx;
-    if constexpr (interleave) {
-      vec_idx = (tx * vec_size) / 2;  // Force integer division
-    } else {
-      vec_idx = (tx * vec_size) % half_rope_dim;
+  // skip padding tokens with batch_indices < 0
+  if (batch_indices[idx] >= 0) {
+    // Compute page location for this token
+    uint32_t page_iter, entry_idx;
+    paged_kv_like.page_size.divmod(
+        paged_kv_like.indptr[batch_indices[idx]] * paged_kv_like.page_size + positions[idx],
+        page_iter, entry_idx);
+
+    const int half_rope_dim = rope_dim / 2;
+    // Load cos/sin for RoPE processing blocks only
+    if ((tx * vec_size < rope_dim) && (by < k_rope_end)) {
+      int sin_offset = rope_dim / 2;
+      int vec_idx;
+      if constexpr (interleave) {
+        vec_idx = (tx * vec_size) / 2;  // Force integer division
+      } else {
+        vec_idx = (tx * vec_size) % half_rope_dim;
+      }
+      cos.load(cos_sin_cache + (pos * rope_dim) + vec_idx);
+      sin.load(cos_sin_cache + (pos * rope_dim) + (sin_offset + vec_idx));
     }
-    cos.load(cos_sin_cache + (pos * rope_dim) + vec_idx);
-    sin.load(cos_sin_cache + (pos * rope_dim) + (sin_offset + vec_idx));
-  }
 
-  if (by < q_rope_end) {
-    // ============ Q RoPE processing ============
-    uint32_t q_head_idx = by / rope_chunks;
-    uint32_t rope_chunk_idx = by % rope_chunks;
-    uint32_t elem_offset = rope_chunk_idx * rope_chunk_size;
+    if (by < q_rope_end) {
+      // ============ Q RoPE processing ============
+      uint32_t q_head_idx = by / rope_chunks;
+      uint32_t rope_chunk_idx = by % rope_chunks;
+      uint32_t elem_offset = rope_chunk_idx * rope_chunk_size;
 
-    DType* q_rope_in_ptr =
-        q_rope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_in_stride_n,
-                                         q_rope_in_stride_h);
-    QuantType* q_rope_out_ptr =
-        q_rope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_out_stride_n,
-                                          q_rope_out_stride_h);
+      DType* q_rope_in_ptr =
+          q_rope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_in_stride_n,
+                                           q_rope_in_stride_h);
+      QuantType* q_rope_out_ptr =
+          q_rope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_out_stride_n,
+                                            q_rope_out_stride_h);
 
-    vec_t<float, vec_size> q_rope_vec;
-    if constexpr (interleave) {
-      q_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(
-          q_rope_in_ptr, cos, sin, rope_dim);
-    } else {
-      q_rope_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_rope_in_ptr, cos, sin, rope_dim);
-    }
+      vec_t<float, vec_size> q_rope_vec;
+      if constexpr (interleave) {
+        q_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(
+            q_rope_in_ptr, cos, sin, rope_dim);
+      } else {
+        q_rope_vec =
+            vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_rope_in_ptr, cos, sin, rope_dim);
+      }
 #pragma unroll
-    for (uint32_t i = 0; i < vec_size; ++i) {
-      q_rope_vec[i] = q_rope_vec[i] * quant_scale_q;
-    }
-    q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size);
-
-  } else if (by < k_rope_end) {
-    // ============ K RoPE processing & Cache Append ============
-    uint32_t k_head_idx = (by - q_rope_end) / rope_chunks;
-    uint32_t rope_chunk_idx = (by - q_rope_end) % rope_chunks;
-    uint32_t elem_offset = rope_chunk_idx * rope_chunk_size;
-
-    DType* k_rope_in_ptr;
-    if constexpr (IS_MLA) {
-      // MLA: 2D K
-      k_rope_in_ptr = k_rope_in + idx * k_rope_in_stride + elem_offset;
-    } else {
-      // GQA/MHA: 3D K
-      k_rope_in_ptr = k_rope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset,
-                                                       k_rope_in_stride, k_rope_in_stride_h);
-    }
+      for (uint32_t i = 0; i < vec_size; ++i) {
+        q_rope_vec[i] = q_rope_vec[i] * quant_scale_q;
+      }
+      q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size);
+
+    } else if (by < k_rope_end) {
+      // ============ K RoPE processing & Cache Append ============
+      uint32_t k_head_idx = (by - q_rope_end) / rope_chunks;
+      uint32_t rope_chunk_idx = (by - q_rope_end) % rope_chunks;
+      uint32_t elem_offset = rope_chunk_idx * rope_chunk_size;
+
+      DType* k_rope_in_ptr;
+      if constexpr (IS_MLA) {
+        // MLA: 2D K
+        k_rope_in_ptr = k_rope_in + idx * k_rope_in_stride + elem_offset;
+      } else {
+        // GQA/MHA: 3D K
+        k_rope_in_ptr = k_rope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset,
+                                                         k_rope_in_stride, k_rope_in_stride_h);
+      }
 
-    vec_t<float, vec_size> k_rope_vec;
-    if constexpr (interleave) {
-      k_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(
-          k_rope_in_ptr, cos, sin, rope_dim);
-    } else {
-      k_rope_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_rope_in_ptr, cos, sin, rope_dim);
-    }
+      vec_t<float, vec_size> k_rope_vec;
+      if constexpr (interleave) {
+        k_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(
+            k_rope_in_ptr, cos, sin, rope_dim);
+      } else {
+        k_rope_vec =
+            vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_rope_in_ptr, cos, sin, rope_dim);
+      }
 #pragma unroll
-    for (uint32_t i = 0; i < vec_size; ++i) {
-      k_rope_vec[i] = k_rope_vec[i] * quant_scale_kv;
-    }
+      for (uint32_t i = 0; i < vec_size; ++i) {
+        k_rope_vec[i] = k_rope_vec[i] * quant_scale_kv;
+      }
 
-    if constexpr (IS_MLA) {
-      QuantType* kpe_ptr =
-          paged_kv_like.get_kpe_ptr(page_iter, entry_idx, elem_offset + tx * vec_size);
-      k_rope_vec.cast_store(kpe_ptr);
-    } else {
-      QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx, tx * vec_size);
-      k_rope_vec.cast_store(k_ptr);
-    }
+      if constexpr (IS_MLA) {
+        QuantType* kpe_ptr =
+            paged_kv_like.get_kpe_ptr(page_iter, entry_idx, elem_offset + tx * vec_size);
+        k_rope_vec.cast_store(kpe_ptr);
+      } else {
+        QuantType* k_ptr =
+            paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx, tx * vec_size);
+        k_rope_vec.cast_store(k_ptr);
+      }
 
-  } else if (by < k_nope_end) {
-    // ============ K Non-RoPE processing & Cache Append ============
-    uint32_t k_head_idx = (by - k_rope_end) / no_rope_chunks;
-    uint32_t nope_chunk_idx = (by - k_rope_end) % no_rope_chunks;
-    uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;
+    } else if (by < k_nope_end) {
+      // ============ K Non-RoPE processing & Cache Append ============
+      uint32_t k_head_idx = (by - k_rope_end) / no_rope_chunks;
+      uint32_t nope_chunk_idx = (by - k_rope_end) % no_rope_chunks;
+      uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;
 
-    DType* k_nope_in_ptr;
-    if constexpr (IS_MLA) {
-      k_nope_in_ptr = k_nope_in + idx * k_nope_in_stride + elem_offset;
-    } else {
-      k_nope_in_ptr = k_nope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset,
-                                                       k_nope_in_stride, k_nope_in_stride_h);
-    }
+      DType* k_nope_in_ptr;
+      if constexpr (IS_MLA) {
+        k_nope_in_ptr = k_nope_in + idx * k_nope_in_stride + elem_offset;
+      } else {
+        k_nope_in_ptr = k_nope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset,
+                                                         k_nope_in_stride, k_nope_in_stride_h);
+      }
 
-    vec_t<float, vec_size> k_nope_vec;
-    k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
+      vec_t<float, vec_size> k_nope_vec;
+      k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
 #pragma unroll
-    for (uint32_t i = 0; i < vec_size; ++i) {
-      k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv;
-    }
+      for (uint32_t i = 0; i < vec_size; ++i) {
+        k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv;
+      }
 
-    if constexpr (IS_MLA) {
-      QuantType* ckv_ptr =
-          paged_kv_like.get_ckv_ptr(page_iter, entry_idx, elem_offset + tx * vec_size);
-      k_nope_vec.cast_store(ckv_ptr);
-    } else {
-      QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx,
-                                                 rope_dim + elem_offset + tx * vec_size);
-      k_nope_vec.cast_store(k_ptr);
-    }
+      if constexpr (IS_MLA) {
+        QuantType* ckv_ptr =
+            paged_kv_like.get_ckv_ptr(page_iter, entry_idx, elem_offset + tx * vec_size);
+        k_nope_vec.cast_store(ckv_ptr);
+      } else {
+        QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx,
+                                                   rope_dim + elem_offset + tx * vec_size);
+        k_nope_vec.cast_store(k_ptr);
+      }
 
-  } else if (by < k_nope_end + (IS_MLA ? 0u : num_kv_heads)) {
-    // ============ V processing & Cache Append (GQA/MHA only) ============
-    if constexpr (!IS_MLA) {
-      uint32_t kv_head_idx = by - k_nope_end;
-      DType* v_in_ptr =
-          v_in + get_elem_offset_impl(idx, kv_head_idx, 0, v_in_stride, v_in_stride_h);
-      // Cover the full head dimension (rope_dim + no_rope_dim) in chunks of rope_chunk_size
-      uint32_t head_dim_total = rope_dim + no_rope_dim;
-      uint32_t v_chunks = (head_dim_total + rope_chunk_size - 1) / rope_chunk_size;
+    } else if (by < k_nope_end + (IS_MLA ? 0u : num_kv_heads)) {
+      // ============ V processing & Cache Append (GQA/MHA only) ============
+      if constexpr (!IS_MLA) {
+        uint32_t kv_head_idx = by - k_nope_end;
+        DType* v_in_ptr =
+            v_in + get_elem_offset_impl(idx, kv_head_idx, 0, v_in_stride, v_in_stride_h);
+        // Cover the full head dimension (rope_dim + no_rope_dim) in chunks of rope_chunk_size
+        uint32_t head_dim_total = rope_dim + no_rope_dim;
+        uint32_t v_chunks = (head_dim_total + rope_chunk_size - 1) / rope_chunk_size;
 #pragma unroll 1
-      for (uint32_t j = 0; j < v_chunks; ++j) {
-        uint32_t v_elem_offset = j * rope_chunk_size;
-        if (v_elem_offset + tx * vec_size < head_dim_total) {
-          vec_t<float, vec_size> v_vec;
-          v_vec.cast_load(v_in_ptr + v_elem_offset + tx * vec_size);
+        for (uint32_t j = 0; j < v_chunks; ++j) {
+          uint32_t v_elem_offset = j * rope_chunk_size;
+          if (v_elem_offset + tx * vec_size < head_dim_total) {
+            vec_t<float, vec_size> v_vec;
+            v_vec.cast_load(v_in_ptr + v_elem_offset + tx * vec_size);
 #pragma unroll
-          for (uint32_t i = 0; i < vec_size; ++i) {
-            v_vec[i] = v_vec[i] * quant_scale_kv;
+            for (uint32_t i = 0; i < vec_size; ++i) {
+              v_vec[i] = v_vec[i] * quant_scale_kv;
+            }
+            QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx,
+                                                       v_elem_offset + tx * vec_size);
+            v_vec.cast_store(v_ptr);
           }
-          QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx,
-                                                     v_elem_offset + tx * vec_size);
-          v_vec.cast_store(v_ptr);
         }
       }
-    }
 
-  } else {
-    // ============ Q Non-RoPE processing ============
-    // MLA has no V section, so Q-nope starts immediately after K-nope.
-    // GQA/MHA has a V section of length num_kv_heads blocks.
-    uint32_t q_nope_start = k_nope_end + (IS_MLA ? 0u : num_kv_heads);
-    uint32_t q_head_idx = (by - q_nope_start) / no_rope_chunks;
-    uint32_t nope_chunk_idx = (by - q_nope_start) % no_rope_chunks;
-    uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;
-
-    DType* q_nope_in_ptr =
-        q_nope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_in_stride_n,
-                                         q_nope_in_stride_h);
-    QuantType* q_nope_out_ptr =
-        q_nope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_out_stride_n,
-                                          q_nope_out_stride_h);
-
-    vec_t<float, vec_size> q_nope_vec;
-    q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
+    } else {
+      // ============ Q Non-RoPE processing ============
+      // MLA has no V section, so Q-nope starts immediately after K-nope.
+      // GQA/MHA has a V section of length num_kv_heads blocks.
+      uint32_t q_nope_start = k_nope_end + (IS_MLA ? 0u : num_kv_heads);
+      uint32_t q_head_idx = (by - q_nope_start) / no_rope_chunks;
+      uint32_t nope_chunk_idx = (by - q_nope_start) % no_rope_chunks;
+      uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;
+
+      DType* q_nope_in_ptr =
+          q_nope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_in_stride_n,
+                                           q_nope_in_stride_h);
+      QuantType* q_nope_out_ptr =
+          q_nope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_out_stride_n,
+                                            q_nope_out_stride_h);
+
+      vec_t<float, vec_size> q_nope_vec;
+      q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
 #pragma unroll
-    for (uint32_t i = 0; i < vec_size; ++i) {
-      q_nope_vec[i] = q_nope_vec[i] * quant_scale_q;
+      for (uint32_t i = 0; i < vec_size; ++i) {
+        q_nope_vec[i] = q_nope_vec[i] * quant_scale_q;
+      }
+      q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
     }
-    q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
   }
 }
 #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
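
Note on the new guard: with fixed-capacity launches (for example, buffers padded to a constant size for CUDA graph replay), trailing slots carry no real token, and the `batch_indices[idx] >= 0` check turns every section of the kernel into a no-op for those slots. Below is a minimal host-side sketch of the convention this implies; the helper name and the CUDA-graph motivation are assumptions for illustration, not part of this PR.

#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical helper (illustrative only): fill a fixed-capacity launch
// buffer, marking unused trailing slots with -1 so the kernel skips them.
std::vector<int32_t> pad_batch_indices(const std::vector<int32_t>& real_indices,
                                       size_t capacity) {
  std::vector<int32_t> padded(capacity, -1);  // -1 marks a padding token
  std::copy(real_indices.begin(), real_indices.end(), padded.begin());
  return padded;
}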
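
For readers new to the paged-KV addressing at the top of the guarded block: `divmod` splits a token's absolute cache slot into a page index (`page_iter`) and an offset within that page (`entry_idx`). A scalar sketch of the same arithmetic, assuming (as the kernel's usage suggests) that `indptr[b]` gives the offset of request `b`'s first page in the page table:

#include <cstdint>

// Scalar version of the kernel's page-location math (illustrative only).
void locate_page(uint32_t indptr_b,   // first-page offset of this request
                 uint32_t page_size,  // entries per page
                 uint32_t position,   // token position within the request
                 uint32_t& page_iter, uint32_t& entry_idx) {
  const uint32_t absolute_slot = indptr_b * page_size + position;
  page_iter = absolute_slot / page_size;  // which page holds the token
  entry_idx = absolute_slot % page_size;  // slot within that page
}

For example, with page_size = 16, indptr_b = 3, and position = 20, the token lands at page 4, entry 4 (3 * 16 + 20 = 68 = 4 * 16 + 4).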