Skip to content

Commit c21d2df

Browse files
authored
Merge branch 'main' into moe-alltoall-topk-22
2 parents 7b1262c + 77a179f commit c21d2df

File tree

4 files changed

+413
-144
lines changed

4 files changed

+413
-144
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,8 @@ High-Performance GPU Kernels for Inference
6969
| Ada Lovelace | SM 8.9 | L4, L40, RTX 40 series |
7070
| Hopper | SM 9.0 | H100, H200 |
7171
| Blackwell | SM 10.0, 10.3 | B200, B300 |
72-
| Blackwell | SM 12.0, 12.1 | RTX 50 series, DGX Spark, Jetson Thor |
72+
| Blackwell | SM 11.0 | Jetson Thor |
73+
| Blackwell | SM 12.0, 12.1 | RTX 50 series, DGX Spark |
7374

7475
> **Note:** Not all features are supported across all compute capabilities.
7576

flashinfer/page.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -178,12 +178,13 @@ def get_batch_indices_positions(
178178
dtype = torch.int32
179179

180180
if batch_indices is None:
181-
batch_indices = torch.empty((nnz,), device=device, dtype=dtype)
181+
batch_indices = torch.full((nnz,), -1, device=device, dtype=dtype)
182182
else:
183183
check_shape_dtype_device(batch_indices, (nnz,), dtype, device, "batch_indices")
184+
batch_indices.fill_(-1)
184185

185186
if positions is None:
186-
positions = torch.empty((nnz,), device=device, dtype=dtype)
187+
positions = torch.zeros((nnz,), device=device, dtype=dtype)
187188
else:
188189
check_shape_dtype_device(positions, (nnz,), dtype, device, "positions")
189190

include/flashinfer/pos_enc.cuh

Lines changed: 147 additions & 141 deletions
Original file line number | Diff line number | Diff line change
@@ -859,169 +859,175 @@ __global__ void RopeQuantizeAppendPagedKVCacheKernel(
859859
const uint32_t idx = bx * bdy + ty;
860860
const RoPEIdType pos = pos_ids[idx];
861861

862-
// Compute page location for this token
863-
uint32_t page_iter, entry_idx;
864-
paged_kv_like.page_size.divmod(
865-
paged_kv_like.indptr[batch_indices[idx]] * paged_kv_like.page_size + positions[idx],
866-
page_iter, entry_idx);
867-
868-
const int half_rope_dim = rope_dim / 2;
869-
// Load cos/sin for RoPE processing blocks only
870-
if ((tx * vec_size < rope_dim) && (by < k_rope_end)) {
871-
int sin_offset = rope_dim / 2;
872-
int vec_idx;
873-
if constexpr (interleave) {
874-
vec_idx = (tx * vec_size) / 2; // Force integer division
875-
} else {
876-
vec_idx = (tx * vec_size) % half_rope_dim;
862+
// skip padding tokens with batch_indices < 0
863+
if (batch_indices[idx] >= 0) {
864+
// Compute page location for this token
865+
uint32_t page_iter, entry_idx;
866+
paged_kv_like.page_size.divmod(
867+
paged_kv_like.indptr[batch_indices[idx]] * paged_kv_like.page_size + positions[idx],
868+
page_iter, entry_idx);
869+
870+
const int half_rope_dim = rope_dim / 2;
871+
// Load cos/sin for RoPE processing blocks only
872+
if ((tx * vec_size < rope_dim) && (by < k_rope_end)) {
873+
int sin_offset = rope_dim / 2;
874+
int vec_idx;
875+
if constexpr (interleave) {
876+
vec_idx = (tx * vec_size) / 2; // Force integer division
877+
} else {
878+
vec_idx = (tx * vec_size) % half_rope_dim;
879+
}
880+
cos.load(cos_sin_cache + (pos * rope_dim) + vec_idx);
881+
sin.load(cos_sin_cache + (pos * rope_dim) + (sin_offset + vec_idx));
877882
}
878-
cos.load(cos_sin_cache + (pos * rope_dim) + vec_idx);
879-
sin.load(cos_sin_cache + (pos * rope_dim) + (sin_offset + vec_idx));
880-
}
881883

882-
if (by < q_rope_end) {
883-
// ============ Q RoPE processing ============
884-
uint32_t q_head_idx = by / rope_chunks;
885-
uint32_t rope_chunk_idx = by % rope_chunks;
886-
uint32_t elem_offset = rope_chunk_idx * rope_chunk_size;
884+
if (by < q_rope_end) {
885+
// ============ Q RoPE processing ============
886+
uint32_t q_head_idx = by / rope_chunks;
887+
uint32_t rope_chunk_idx = by % rope_chunks;
888+
uint32_t elem_offset = rope_chunk_idx * rope_chunk_size;
887889

888-
DType* q_rope_in_ptr =
889-
q_rope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_in_stride_n,
890-
q_rope_in_stride_h);
891-
QuantType* q_rope_out_ptr =
892-
q_rope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_out_stride_n,
893-
q_rope_out_stride_h);
890+
DType* q_rope_in_ptr =
891+
q_rope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_in_stride_n,
892+
q_rope_in_stride_h);
893+
QuantType* q_rope_out_ptr =
894+
q_rope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_rope_out_stride_n,
895+
q_rope_out_stride_h);
894896

895-
vec_t<float, vec_size> q_rope_vec;
896-
if constexpr (interleave) {
897-
q_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(
898-
q_rope_in_ptr, cos, sin, rope_dim);
899-
} else {
900-
q_rope_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_rope_in_ptr, cos, sin, rope_dim);
901-
}
897+
vec_t<float, vec_size> q_rope_vec;
898+
if constexpr (interleave) {
899+
q_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(
900+
q_rope_in_ptr, cos, sin, rope_dim);
901+
} else {
902+
q_rope_vec =
903+
vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_rope_in_ptr, cos, sin, rope_dim);
904+
}
902905
#pragma unroll
903-
for (uint32_t i = 0; i < vec_size; ++i) {
904-
q_rope_vec[i] = q_rope_vec[i] * quant_scale_q;
905-
}
906-
q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size);
907-
908-
} else if (by < k_rope_end) {
909-
// ============ K RoPE processing & Cache Append ============
910-
uint32_t k_head_idx = (by - q_rope_end) / rope_chunks;
911-
uint32_t rope_chunk_idx = (by - q_rope_end) % rope_chunks;
912-
uint32_t elem_offset = rope_chunk_idx * rope_chunk_size;
913-
914-
DType* k_rope_in_ptr;
915-
if constexpr (IS_MLA) {
916-
// MLA: 2D K
917-
k_rope_in_ptr = k_rope_in + idx * k_rope_in_stride + elem_offset;
918-
} else {
919-
// GQA/MHA: 3D K
920-
k_rope_in_ptr = k_rope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset,
921-
k_rope_in_stride, k_rope_in_stride_h);
922-
}
906+
for (uint32_t i = 0; i < vec_size; ++i) {
907+
q_rope_vec[i] = q_rope_vec[i] * quant_scale_q;
908+
}
909+
q_rope_vec.cast_store(q_rope_out_ptr + tx * vec_size);
910+
911+
} else if (by < k_rope_end) {
912+
// ============ K RoPE processing & Cache Append ============
913+
uint32_t k_head_idx = (by - q_rope_end) / rope_chunks;
914+
uint32_t rope_chunk_idx = (by - q_rope_end) % rope_chunks;
915+
uint32_t elem_offset = rope_chunk_idx * rope_chunk_size;
916+
917+
DType* k_rope_in_ptr;
918+
if constexpr (IS_MLA) {
919+
// MLA: 2D K
920+
k_rope_in_ptr = k_rope_in + idx * k_rope_in_stride + elem_offset;
921+
} else {
922+
// GQA/MHA: 3D K
923+
k_rope_in_ptr = k_rope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset,
924+
k_rope_in_stride, k_rope_in_stride_h);
925+
}
923926

924-
vec_t<float, vec_size> k_rope_vec;
925-
if constexpr (interleave) {
926-
k_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(
927-
k_rope_in_ptr, cos, sin, rope_dim);
928-
} else {
929-
k_rope_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_rope_in_ptr, cos, sin, rope_dim);
930-
}
927+
vec_t<float, vec_size> k_rope_vec;
928+
if constexpr (interleave) {
929+
k_rope_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(
930+
k_rope_in_ptr, cos, sin, rope_dim);
931+
} else {
932+
k_rope_vec =
933+
vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_rope_in_ptr, cos, sin, rope_dim);
934+
}
931935
#pragma unroll
932-
for (uint32_t i = 0; i < vec_size; ++i) {
933-
k_rope_vec[i] = k_rope_vec[i] * quant_scale_kv;
934-
}
936+
for (uint32_t i = 0; i < vec_size; ++i) {
937+
k_rope_vec[i] = k_rope_vec[i] * quant_scale_kv;
938+
}
935939

936-
if constexpr (IS_MLA) {
937-
QuantType* kpe_ptr =
938-
paged_kv_like.get_kpe_ptr(page_iter, entry_idx, elem_offset + tx * vec_size);
939-
k_rope_vec.cast_store(kpe_ptr);
940-
} else {
941-
QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx, tx * vec_size);
942-
k_rope_vec.cast_store(k_ptr);
943-
}
940+
if constexpr (IS_MLA) {
941+
QuantType* kpe_ptr =
942+
paged_kv_like.get_kpe_ptr(page_iter, entry_idx, elem_offset + tx * vec_size);
943+
k_rope_vec.cast_store(kpe_ptr);
944+
} else {
945+
QuantType* k_ptr =
946+
paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx, tx * vec_size);
947+
k_rope_vec.cast_store(k_ptr);
948+
}
944949

945-
} else if (by < k_nope_end) {
946-
// ============ K Non-RoPE processing & Cache Append ============
947-
uint32_t k_head_idx = (by - k_rope_end) / no_rope_chunks;
948-
uint32_t nope_chunk_idx = (by - k_rope_end) % no_rope_chunks;
949-
uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;
950+
} else if (by < k_nope_end) {
951+
// ============ K Non-RoPE processing & Cache Append ============
952+
uint32_t k_head_idx = (by - k_rope_end) / no_rope_chunks;
953+
uint32_t nope_chunk_idx = (by - k_rope_end) % no_rope_chunks;
954+
uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;
950955

951-
DType* k_nope_in_ptr;
952-
if constexpr (IS_MLA) {
953-
k_nope_in_ptr = k_nope_in + idx * k_nope_in_stride + elem_offset;
954-
} else {
955-
k_nope_in_ptr = k_nope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset,
956-
k_nope_in_stride, k_nope_in_stride_h);
957-
}
956+
DType* k_nope_in_ptr;
957+
if constexpr (IS_MLA) {
958+
k_nope_in_ptr = k_nope_in + idx * k_nope_in_stride + elem_offset;
959+
} else {
960+
k_nope_in_ptr = k_nope_in + get_elem_offset_impl(idx, k_head_idx, elem_offset,
961+
k_nope_in_stride, k_nope_in_stride_h);
962+
}
958963

959-
vec_t<float, vec_size> k_nope_vec;
960-
k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
964+
vec_t<float, vec_size> k_nope_vec;
965+
k_nope_vec.cast_load(k_nope_in_ptr + tx * vec_size);
961966
#pragma unroll
962-
for (uint32_t i = 0; i < vec_size; ++i) {
963-
k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv;
964-
}
967+
for (uint32_t i = 0; i < vec_size; ++i) {
968+
k_nope_vec[i] = k_nope_vec[i] * quant_scale_kv;
969+
}
965970

966-
if constexpr (IS_MLA) {
967-
QuantType* ckv_ptr =
968-
paged_kv_like.get_ckv_ptr(page_iter, entry_idx, elem_offset + tx * vec_size);
969-
k_nope_vec.cast_store(ckv_ptr);
970-
} else {
971-
QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx,
972-
rope_dim + elem_offset + tx * vec_size);
973-
k_nope_vec.cast_store(k_ptr);
974-
}
971+
if constexpr (IS_MLA) {
972+
QuantType* ckv_ptr =
973+
paged_kv_like.get_ckv_ptr(page_iter, entry_idx, elem_offset + tx * vec_size);
974+
k_nope_vec.cast_store(ckv_ptr);
975+
} else {
976+
QuantType* k_ptr = paged_kv_like.get_k_ptr(page_iter, k_head_idx, entry_idx,
977+
rope_dim + elem_offset + tx * vec_size);
978+
k_nope_vec.cast_store(k_ptr);
979+
}
975980

976-
} else if (by < k_nope_end + (IS_MLA ? 0u : num_kv_heads)) {
977-
// ============ V processing & Cache Append (GQA/MHA only) ============
978-
if constexpr (!IS_MLA) {
979-
uint32_t kv_head_idx = by - k_nope_end;
980-
DType* v_in_ptr =
981-
v_in + get_elem_offset_impl(idx, kv_head_idx, 0, v_in_stride, v_in_stride_h);
982-
// Cover the full head dimension (rope_dim + no_rope_dim) in chunks of rope_chunk_size
983-
uint32_t head_dim_total = rope_dim + no_rope_dim;
984-
uint32_t v_chunks = (head_dim_total + rope_chunk_size - 1) / rope_chunk_size;
981+
} else if (by < k_nope_end + (IS_MLA ? 0u : num_kv_heads)) {
982+
// ============ V processing & Cache Append (GQA/MHA only) ============
983+
if constexpr (!IS_MLA) {
984+
uint32_t kv_head_idx = by - k_nope_end;
985+
DType* v_in_ptr =
986+
v_in + get_elem_offset_impl(idx, kv_head_idx, 0, v_in_stride, v_in_stride_h);
987+
// Cover the full head dimension (rope_dim + no_rope_dim) in chunks of rope_chunk_size
988+
uint32_t head_dim_total = rope_dim + no_rope_dim;
989+
uint32_t v_chunks = (head_dim_total + rope_chunk_size - 1) / rope_chunk_size;
985990
#pragma unroll 1
986-
for (uint32_t j = 0; j < v_chunks; ++j) {
987-
uint32_t v_elem_offset = j * rope_chunk_size;
988-
if (v_elem_offset + tx * vec_size < head_dim_total) {
989-
vec_t<float, vec_size> v_vec;
990-
v_vec.cast_load(v_in_ptr + v_elem_offset + tx * vec_size);
991+
for (uint32_t j = 0; j < v_chunks; ++j) {
992+
uint32_t v_elem_offset = j * rope_chunk_size;
993+
if (v_elem_offset + tx * vec_size < head_dim_total) {
994+
vec_t<float, vec_size> v_vec;
995+
v_vec.cast_load(v_in_ptr + v_elem_offset + tx * vec_size);
991996
#pragma unroll
992-
for (uint32_t i = 0; i < vec_size; ++i) {
993-
v_vec[i] = v_vec[i] * quant_scale_kv;
997+
for (uint32_t i = 0; i < vec_size; ++i) {
998+
v_vec[i] = v_vec[i] * quant_scale_kv;
999+
}
1000+
QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx,
1001+
v_elem_offset + tx * vec_size);
1002+
v_vec.cast_store(v_ptr);
9941003
}
995-
QuantType* v_ptr = paged_kv_like.get_v_ptr(page_iter, kv_head_idx, entry_idx,
996-
v_elem_offset + tx * vec_size);
997-
v_vec.cast_store(v_ptr);
9981004
}
9991005
}
1000-
}
10011006

1002-
} else {
1003-
// ============ Q Non-RoPE processing ============
1004-
// MLA has no V section, so Q-nope starts immediately after K-nope.
1005-
// GQA/MHA has a V section of length num_kv_heads blocks.
1006-
uint32_t q_nope_start = k_nope_end + (IS_MLA ? 0u : num_kv_heads);
1007-
uint32_t q_head_idx = (by - q_nope_start) / no_rope_chunks;
1008-
uint32_t nope_chunk_idx = (by - q_nope_start) % no_rope_chunks;
1009-
uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;
1010-
1011-
DType* q_nope_in_ptr =
1012-
q_nope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_in_stride_n,
1013-
q_nope_in_stride_h);
1014-
QuantType* q_nope_out_ptr =
1015-
q_nope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_out_stride_n,
1016-
q_nope_out_stride_h);
1017-
1018-
vec_t<float, vec_size> q_nope_vec;
1019-
q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
1007+
} else {
1008+
// ============ Q Non-RoPE processing ============
1009+
// MLA has no V section, so Q-nope starts immediately after K-nope.
1010+
// GQA/MHA has a V section of length num_kv_heads blocks.
1011+
uint32_t q_nope_start = k_nope_end + (IS_MLA ? 0u : num_kv_heads);
1012+
uint32_t q_head_idx = (by - q_nope_start) / no_rope_chunks;
1013+
uint32_t nope_chunk_idx = (by - q_nope_start) % no_rope_chunks;
1014+
uint32_t elem_offset = nope_chunk_idx * rope_chunk_size;
1015+
1016+
DType* q_nope_in_ptr =
1017+
q_nope_in + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_in_stride_n,
1018+
q_nope_in_stride_h);
1019+
QuantType* q_nope_out_ptr =
1020+
q_nope_out + get_elem_offset_impl(idx, q_head_idx, elem_offset, q_nope_out_stride_n,
1021+
q_nope_out_stride_h);
1022+
1023+
vec_t<float, vec_size> q_nope_vec;
1024+
q_nope_vec.cast_load(q_nope_in_ptr + tx * vec_size);
10201025
#pragma unroll
1021-
for (uint32_t i = 0; i < vec_size; ++i) {
1022-
q_nope_vec[i] = q_nope_vec[i] * quant_scale_q;
1026+
for (uint32_t i = 0; i < vec_size; ++i) {
1027+
q_nope_vec[i] = q_nope_vec[i] * quant_scale_q;
1028+
}
1029+
q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
10231030
}
1024-
q_nope_vec.cast_store(q_nope_out_ptr + tx * vec_size);
10251031
}
10261032
}
10271033
#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))

0 commit comments

Comments
 (0)