Skip to content

Commit c3aceb6

Browse files
authored
[Models][OP][Optimization] Support DeepSeek-v3.2 model, integrate DSA & Indexer architecture with FlashMLA/DeepGEMM (#6689)
* Support DeepSeek-v3.2 model, integrate DSA & Indexer architecture with FlashMLA/DeepGEMM
1 parent 25c4793 commit c3aceb6

File tree

22 files changed

+8021
-142
lines changed

22 files changed

+8021
-142
lines changed

custom_ops/gpu_ops/append_attn/ds_mla_cache_kernel.cu

Lines changed: 616 additions & 0 deletions
Large diffs are not rendered by default.

custom_ops/gpu_ops/append_attn/ds_mla_cache_kernel.cuh

Lines changed: 548 additions & 0 deletions
Large diffs are not rendered by default.

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1133,6 +1133,47 @@ std::vector<paddle::Tensor> get_attn_mask_q(
11331133
const paddle::optional<paddle::Tensor>& attn_mask_kv,
11341134
const int kv_token_num);
11351135

1136+
// --- DeepSeek-V3.2 DSA / Indexer custom-op prototypes (implementations live
// --- in the corresponding .cu translation units). ---

// In-place ragged top-k: for each ragged row of `input`, writes the indices
// of the `top_k` largest entries into `output_indices`; `offsets`/`lengths`
// describe the ragged layout. `seq_len_decoder`/`batch_id_per_token` are
// optional decode-phase metadata and `maybe_row_states_buffer` an optional
// reusable per-row scratch buffer.
// NOTE(review): semantics inferred from parameter names — confirm against the
// kernel implementation.
void RadixTopkRaggedTransform(
    paddle::Tensor& input,
    paddle::Tensor& output_indices,
    const paddle::Tensor& offsets,
    paddle::Tensor& lengths,
    paddle::optional<paddle::Tensor>& seq_len_decoder,
    paddle::optional<paddle::Tensor>& batch_id_per_token,
    paddle::optional<paddle::Tensor>& maybe_row_states_buffer,
    int top_k,
    int q_num_heads = 0);

// Writes MLA KV entries (`kv_nope` + `kv_pe`) into the paged `kv_cache`
// using `slot_mapping`/`block_tables`, optionally quantizing per
// `cache_quant_type_str`/`scale`. Presumably prefill vs. decode behavior is
// switched by `is_prefill` — verify against ds_mla_cache_kernel.cu.
std::vector<paddle::Tensor> DSMLAWriteCacheKernel(
    const paddle::Tensor& kv_nope,
    const paddle::Tensor& kv_pe,
    const paddle::Tensor& kv_cache,
    const paddle::Tensor& slot_mapping,
    const paddle::Tensor& seq_lens,
    const paddle::Tensor& seq_lens_decoder,
    const paddle::Tensor& batch_id_per_token,
    const paddle::Tensor& cu_seqlens_q,
    const paddle::Tensor& block_tables,
    const paddle::optional<paddle::Tensor>& kv_signal_data,
    const paddle::optional<paddle::Tensor>& scale,
    const std::string& cache_quant_type_str,
    const int max_seq_len,
    const bool is_prefill);

// Quantizes indexer keys `k` in blocks of `quant_block_size` and stores them
// into `kv_cache` at positions given by `slot_mapping`; `scale_fmt` selects
// the on-cache scale encoding (format values defined by the kernel).
std::vector<paddle::Tensor> IndexerKQuantAndCacheKernel(
    const paddle::Tensor& k,
    const paddle::Tensor& kv_cache,
    const paddle::Tensor& slot_mapping,
    const int64_t quant_block_size,
    const std::string& scale_fmt);

// Context-parallel gather: reads quantized indexer keys (and their scales)
// back out of the paged `kv_cache` into the contiguous `dst_k`/`dst_scale`
// buffers, addressed via `block_table` and `cu_seq_lens`.
std::vector<paddle::Tensor> CpGatherIndexerKQuantCacheKernel(
    const paddle::Tensor& kv_cache,
    paddle::Tensor& dst_k,
    paddle::Tensor& dst_scale,
    const paddle::Tensor& block_table,
    const paddle::Tensor& cu_seq_lens);
1176+
11361177
PYBIND11_MODULE(fastdeploy_ops, m) {
11371178
m.def("get_expert_token_num",
11381179
&GetExpertTokenNum,
@@ -1736,4 +1777,18 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
17361777
m.def("custom_numpy_to_tensor",
17371778
&CustomNumpyToTensor,
17381779
"custom_numpy_to_tensor function");
1780+
1781+
m.def("radix_topk_ragged_transform",
1782+
&RadixTopkRaggedTransform,
1783+
"radix_topk_ragged_transform function");
1784+
1785+
m.def("dsk_attn_write_cache", &DSMLAWriteCacheKernel, "dsk_attn_write_cache");
1786+
1787+
m.def("indexer_k_quant_and_cache",
1788+
&IndexerKQuantAndCacheKernel,
1789+
"indexer_k_quant_and_cache");
1790+
1791+
m.def("cp_gather_indexer_k_quant_cache",
1792+
&CpGatherIndexerKQuantCacheKernel,
1793+
"cp_gather_indexer_k_quant_cache");
17391794
}

custom_ops/gpu_ops/helper.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,8 @@ inline const char *getEnvVar(const char *varName) {
662662

663663
inline bool checkAttentionBackend() {
664664
const char *backend = getEnvVar("FD_ATTENTION_BACKEND");
665-
if (backend && std::strcmp(backend, "MLA_ATTN") == 0) {
665+
if (backend && (std::strcmp(backend, "MLA_ATTN") == 0 ||
666+
std::strcmp(backend, "DSA_ATTN") == 0)) {
666667
return true;
667668
}
668669
return false;
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
/*
2+
* Copyright (c) 2024 by FlashInfer team.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#ifndef FLASHINFER_EXCEPTION_H_
17+
#define FLASHINFER_EXCEPTION_H_
18+
19+
#include <exception>
20+
#include <iostream>
21+
#include <sstream>
22+
23+
// Throws flashinfer::Error annotated with the current function, file and
// line. `message` must be convertible to std::string (the Error constructor
// takes a const std::string&).
#define FLASHINFER_ERROR(message) \
  throw flashinfer::Error(__FUNCTION__, __FILE__, __LINE__, message)
25+
26+
// Streams zero or more values into `oss`, separating consecutive values with
// a single space (no leading or trailing separator). With no values this is
// a no-op.
inline void write_to_stream(std::ostringstream& oss) { (void)oss; }

// Streams `head`, then recurses on the remaining values; `if constexpr`
// drops both the separator and the recursive call once `head` is the last
// argument, so the single-value case needs no dedicated overload.
template <typename Head, typename... Tail>
void write_to_stream(std::ostringstream& oss, Head&& head, Tail&&... tail) {
  oss << std::forward<Head>(head);
  if constexpr (sizeof...(tail) > 0) {
    oss << " ";
    write_to_stream(oss, std::forward<Tail>(tail)...);
  }
}
41+
42+
// Helper macro to handle empty __VA_ARGS__: throws via FLASHINFER_ERROR when
// `condition` is false.
// BUGFIX: wrapped in do/while(0) so the expansion is a single statement —
// the previous bare `if (...) { ... }` form would silently steal a following
// `else` when used in an unbraced if/else body. The sibling macros
// (FLASHINFER_CHECK / FLASHINFER_WARN) already use this idiom.
#define FLASHINFER_CHECK_IMPL(condition, message) \
  do {                                            \
    if (!(condition)) {                           \
      FLASHINFER_ERROR(message);                  \
    }                                             \
  } while (0)
47+
48+
// Main macro that handles both cases: if `condition` is false, formats the
// optional extra arguments (space-separated, via write_to_stream) into the
// error message and throws flashinfer::Error with function/file/line context.
// With no extra arguments the message falls back to
// "Check failed: <condition text>". `##__VA_ARGS__` is the GNU/Clang/MSVC
// extension that swallows the trailing comma when the va-args are empty.
// (Comments must stay above the #define — a // inside the macro would
// comment out the line-continuation backslash.)
#define FLASHINFER_CHECK(condition, ...)    \
  do {                                      \
    if (!(condition)) {                     \
      std::ostringstream oss;               \
      write_to_stream(oss, ##__VA_ARGS__);  \
      std::string msg = oss.str();          \
      if (msg.empty()) {                    \
        msg = "Check failed: " #condition;  \
      }                                     \
      FLASHINFER_ERROR(msg);                \
    }                                       \
  } while (0)
61+
62+
// Warning macro — non-fatal counterpart of FLASHINFER_CHECK: formats the
// arguments the same way (space-separated via write_to_stream, defaulting to
// "Warning triggered" when empty) and prints a flashinfer::Warning with
// function/file/line context to stderr instead of throwing.
#define FLASHINFER_WARN(...)                                            \
  do {                                                                  \
    std::ostringstream oss;                                             \
    write_to_stream(oss, ##__VA_ARGS__);                                \
    std::string msg = oss.str();                                        \
    if (msg.empty()) {                                                  \
      msg = "Warning triggered";                                        \
    }                                                                   \
    flashinfer::Warning(__FUNCTION__, __FILE__, __LINE__, msg).emit();  \
  } while (0)
73+
74+
namespace flashinfer {
75+
// Exception type thrown by FLASHINFER_ERROR / FLASHINFER_CHECK. Carries a
// pre-formatted "Error in function '<func>' at <file>:<line>: <message>"
// string so what() can stay noexcept.
class Error : public std::exception {
 private:
  std::string message_;

 public:
  // Builds the full diagnostic string eagerly at construction time.
  Error(const std::string& func,
        const std::string& file,
        int line,
        const std::string& message)
      : message_("Error in function '" + func + "' at " + file + ":" +
                 std::to_string(line) + ": " + message) {}

  const char* what() const noexcept override { return message_.c_str(); }
};
94+
95+
// Non-fatal diagnostic used by FLASHINFER_WARN. Pre-formats
// "Warning in function '<func>' at <file>:<line>: <message>" at construction;
// emit() writes it to stderr.
class Warning {
 private:
  std::string message_;

 public:
  Warning(const std::string& func,
          const std::string& file,
          int line,
          const std::string& message)
      : message_("Warning in function '" + func + "' at " + file + ":" +
                 std::to_string(line) + ": " + message) {}

  // Prints the message plus a newline to std::cerr (std::endl also flushes).
  void emit() const { std::cerr << message_ << std::endl; }
};
112+
113+
} // namespace flashinfer
114+
115+
#endif // FLASHINFER_EXCEPTION_H_
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
2+
#include "indexer_topk.cuh"
3+
4+
#include <cuda_bf16.h>
5+
6+
#include "paddle/extension.h"
7+
8+
#include "paddle/phi/api/ext/op_meta_info.h"
9+
#include "paddle/utils/optional.h"
10+
11+
#include "append_attn/mem_util.cuh"
12+
#include "append_attn/mma_tensor_op.cuh"
13+
#include "append_attn/utils.cuh"
14+
#include "helper.h"
15+
16+
// using namespace flashinfer;
17+
#ifndef PD_BUILD_STATIC_OP
18+
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
19+
#endif
20+
21+
// Dtype-dispatch shim for radix_topk_ragged_transform: maps the compile-time
// paddle::DataType tag `T` to its device element type and forwards every
// argument unchanged to flashinfer::sampling::TopKRaggedTransformDispatch.
// Returns the cudaError_t reported by that dispatch (caller is responsible
// for checking it).
template <paddle::DataType T>
cudaError_t DispatchTopK(paddle::Tensor& input,
                         paddle::Tensor& output_indices,
                         const paddle::Tensor& offsets,
                         paddle::Tensor& lengths,
                         uint32_t num_rows,
                         const int32_t* seq_len_decoder,
                         const int32_t* batch_id_per_token,
                         uint32_t top_k,
                         uint32_t q_num_heads,
                         uint32_t max_len,
                         flashinfer::sampling::RadixRowState* row_states_ptr,
                         cudaStream_t stream) {
  // PDTraits maps the dtype tag to both the device-side element type
  // (DataType_) and the host-side paddle storage type (data_t).
  typedef PDTraits<T> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  cudaError_t status;
  status =
      flashinfer::sampling::TopKRaggedTransformDispatch<DataType_, int32_t>(
          reinterpret_cast<DataType_*>(input.data<data_t>()),
          static_cast<int32_t*>(output_indices.data<int32_t>()),
          static_cast<const int32_t*>(offsets.data<int32_t>()),
          static_cast<int32_t*>(lengths.data<int32_t>()),
          num_rows,
          seq_len_decoder,
          batch_id_per_token,
          static_cast<uint32_t>(top_k),
          static_cast<uint32_t>(q_num_heads),
          max_len,
          row_states_ptr,
          stream);
  return status;
}
55+
56+
// Custom-op entry point for radix_topk_ragged_transform.
//
// For each ragged row of `input` (shape (num_rows, max_len); bfloat16 or
// float32), selects the top-k entries and writes the results into the
// caller-provided `output_indices` / `lengths` tensors in place.
// `seq_len_decoder` / `batch_id_per_token` are optional int32 decode-phase
// metadata; `maybe_row_states_buffer` is an optional raw-byte scratch buffer
// reinterpreted as per-row RadixRowState.
// Throws (via PD_THROW / PD_CHECK) on unsupported dtypes or CUDA launch
// failure.
void RadixTopkRaggedTransform(
    paddle::Tensor& input,
    paddle::Tensor& output_indices,
    const paddle::Tensor& offsets,
    paddle::Tensor& lengths,
    paddle::optional<paddle::Tensor>& seq_len_decoder,
    paddle::optional<paddle::Tensor>& batch_id_per_token,
    paddle::optional<paddle::Tensor>& maybe_row_states_buffer,
    int top_k,
    int q_num_heads = 0) {
  const unsigned int num_rows = input.dims()[0];  // input: (num_rows, max_len)
  const unsigned int max_len = input.dims()[1];

  // BUGFIX: the stream was previously cached in a function-local `static`,
  // which pinned every subsequent call to the stream of the *first*
  // invocation — a wrong-stream launch (and potential race) whenever a later
  // input lives on a different stream. Read it fresh on each call.
  cudaStream_t stream = input.stream();
  const auto input_dtype = input.dtype();

  // Optional reusable per-row radix state (opaque byte buffer supplied by
  // the caller, reinterpreted as RadixRowState).
  flashinfer::sampling::RadixRowState* row_states_ptr = nullptr;
  if (maybe_row_states_buffer) {
    auto& row_states_tensor = maybe_row_states_buffer.get();
    row_states_ptr = reinterpret_cast<flashinfer::sampling::RadixRowState*>(
        row_states_tensor.data<uint8_t>());
  }

  // Optional decode-phase metadata pointers (null when absent).
  const int32_t* seq_len_ptr = nullptr;
  if (seq_len_decoder) {
    seq_len_ptr =
        static_cast<const int32_t*>(seq_len_decoder.get().data<int32_t>());
  }
  const int32_t* batch_id_per_token_ptr = nullptr;
  if (batch_id_per_token) {
    batch_id_per_token_ptr =
        static_cast<const int32_t*>(batch_id_per_token.get().data<int32_t>());
  }

  cudaError_t status = cudaSuccess;
  if (input_dtype == paddle::DataType::BFLOAT16) {
    status = DispatchTopK<paddle::DataType::BFLOAT16>(input,
                                                      output_indices,
                                                      offsets,
                                                      lengths,
                                                      num_rows,
                                                      seq_len_ptr,
                                                      batch_id_per_token_ptr,
                                                      top_k,
                                                      q_num_heads,
                                                      max_len,
                                                      row_states_ptr,
                                                      stream);
  } else if (input_dtype == paddle::DataType::FLOAT32) {
    status = DispatchTopK<paddle::DataType::FLOAT32>(input,
                                                     output_indices,
                                                     offsets,
                                                     lengths,
                                                     num_rows,
                                                     seq_len_ptr,
                                                     batch_id_per_token_ptr,
                                                     top_k,
                                                     q_num_heads,
                                                     max_len,
                                                     row_states_ptr,
                                                     stream);
  } else {
    // BUGFIX: unsupported dtypes previously fell through silently (leaving
    // `status` uninitialized and the outputs untouched); fail loudly instead.
    PD_THROW(
        "radix_topk_ragged_transform only supports bfloat16/float32 input.");
  }
  // BUGFIX: the kernel launch status was previously discarded.
  PD_CHECK(status == cudaSuccess,
           "radix_topk_ragged_transform dispatch failed: ",
           cudaGetErrorString(status));
}
134+
135+
// Registers the op with paddle's custom-op machinery. `output_indices` and
// `lengths` are declared as *inputs* even though the kernel writes them in
// place (the C++ function takes them by non-const reference), so the op
// declares no outputs. The Attrs order must match the trailing scalar
// parameters of RadixTopkRaggedTransform (top_k, q_num_heads).
PD_BUILD_STATIC_OP(radix_topk_ragged_transform)
    .Inputs({"input",
             "output_indices",
             "offsets",
             "lengths",
             paddle::Optional("seq_len_decoder"),
             paddle::Optional("batch_id_per_token"),
             paddle::Optional("maybe_row_states_buffer")})
    .Attrs({"top_k : int", "q_num_heads : int"})
    .SetKernelFn(PD_KERNEL(RadixTopkRaggedTransform));

0 commit comments

Comments
 (0)