
Commit 38d3e28

refactor: refactor codes based on the review comments.
1 parent 40ea638 commit 38d3e28

File tree: 5 files changed (+71, -40 lines)

xllm/core/distributed_runtime/llm_engine.cpp (17 additions, 4 deletions)

@@ -911,9 +911,21 @@ std::vector<RawForwardInput> LLMEngine::prepare_inputs(
   std::vector<int32_t> dp_is_decode(dp_size_, 0);
   bool global_empty_kv_cache = true;
 
-  // flags to detect mixed usage across DP ranks
+  // Flags to detect mixed forward type usage across data parallel ranks.
+  // These flags are set during the loop below to track whether different ranks
+  // have different forward types, which requires setting the global forward
+  // type to MIXED to ensure consistent processing across all ranks.
+
+  // Indicates if at least one DP rank has a DECODE forward type.
   bool has_decode = false;
-  bool has_prefill = false;  // Includes PREFILL and CHUNKED_PREFILL
+  // Indicates if at least one DP rank has a PREFILL or CHUNKED_PREFILL forward
+  // type (processing multiple tokens in parallel, typically used for initial
+  // prompt processing or chunked prompt handling).
+  bool has_prefill = false;
+  // Indicates if at least one DP rank already has a MIXED forward type
+  // (contains both decode and prefill operations within the same batch). If
+  // true, the global forward type must be set to MIXED regardless of other
+  // flags.
   bool has_mixed = false;
 
   // NOTE: when enable dp, we need to check the forward type of each batch
@@ -960,8 +972,9 @@ std::vector<RawForwardInput> LLMEngine::prepare_inputs(
     // If not mixed, use the detected uniform type
     global_forward_type = representative_type;
   } else {
-    // All empty
-    global_forward_type = BatchForwardType::EMPTY;
+    // this should never happen
+    LOG(FATAL)
+        << "All batch forward type are empty, which should never happen.";
   }
 
   // eplb related
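The resolution step that consumes these flags sits between the two hunks and is not shown in the commit. As a rough standalone sketch (not the actual xllm code; the enum values are inferred from the comments in the diff), the selection logic amounts to:

#include <cstdlib>
#include <vector>

// Simplified stand-in for xllm's BatchForwardType; the DECODE/PREFILL/
// CHUNKED_PREFILL/MIXED values are assumed from the comments above.
enum class BatchForwardType { EMPTY, DECODE, PREFILL, CHUNKED_PREFILL, MIXED };

// Collapse per-DP-rank forward types into one global type. Any rank that is
// already MIXED, or any disagreement between decode and prefill ranks,
// forces MIXED so that all ranks take the same code path.
BatchForwardType resolve_global_forward_type(
    const std::vector<BatchForwardType>& per_rank) {
  bool has_decode = false, has_prefill = false, has_mixed = false;
  BatchForwardType representative = BatchForwardType::EMPTY;
  for (BatchForwardType t : per_rank) {
    has_decode |= (t == BatchForwardType::DECODE);
    has_prefill |= (t == BatchForwardType::PREFILL ||
                    t == BatchForwardType::CHUNKED_PREFILL);
    has_mixed |= (t == BatchForwardType::MIXED);
    if (t != BatchForwardType::EMPTY) representative = t;
  }
  if (has_mixed || (has_decode && has_prefill)) return BatchForwardType::MIXED;
  if (representative != BatchForwardType::EMPTY) return representative;
  std::abort();  // mirrors the new LOG(FATAL): all-empty should never happen
}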

xllm/core/layers/common/tests/indexer_tests.cpp (15 additions, 15 deletions)

@@ -416,16 +416,16 @@ TEST_F(IndexerTest, Bfloat16PrefillVerifyPrecision) {
   run_indexer_test(batch_size, max_query_len, is_prefill);
 
   // Verify output shapes
-  ASSERT_EQ(new_block_tables.sizes().size(), 2)
+  CHECK_EQ(new_block_tables.sizes().size(), 2)
       << "new_block_tables should be 2D tensor";
-  ASSERT_EQ(new_context_lens.sizes().size(), 1)
+  CHECK_EQ(new_context_lens.sizes().size(), 1)
       << "new_context_lens should be 1D tensor";
-  ASSERT_EQ(new_block_tables.size(0), num_tokens) << "Batch size should match";
-  ASSERT_EQ(new_block_tables.size(1), index_topk) << "Top-k should match";
+  CHECK_EQ(new_block_tables.size(0), num_tokens) << "Batch size should match";
+  CHECK_EQ(new_block_tables.size(1), index_topk) << "Top-k should match";
 
   // Verify that the first value in new_block_tables is 1 (calculated via vLLM
   // MLU)
-  ASSERT_EQ(new_block_tables.index({0, 0}).item<int64_t>(), 1)
+  EXPECT_EQ(new_block_tables.index({0, 0}).item<int64_t>(), 1)
       << "The first value in new_block_tables should be 1";
 
   // Test bfloat16 mode (non-quantized) - prefill phase
@@ -439,16 +439,16 @@ TEST_F(IndexerTest, Bfloat16PrefillVerifyPrecision) {
   run_indexer_test(batch_size, max_query_len, is_prefill);
 
   // Verify output shapes
-  ASSERT_EQ(new_block_tables.sizes().size(), 2)
+  CHECK_EQ(new_block_tables.sizes().size(), 2)
       << "new_block_tables should be 2D tensor";
-  ASSERT_EQ(new_context_lens.sizes().size(), 1)
+  CHECK_EQ(new_context_lens.sizes().size(), 1)
       << "new_context_lens should be 1D tensor";
-  ASSERT_EQ(new_block_tables.size(0), num_tokens) << "Batch size should match";
-  ASSERT_EQ(new_block_tables.size(1), index_topk) << "Top-k should match";
+  CHECK_EQ(new_block_tables.size(0), num_tokens) << "Batch size should match";
+  CHECK_EQ(new_block_tables.size(1), index_topk) << "Top-k should match";
 
   // Verify that the first value in new_block_tables is 1 (calculated via vLLM
   // MLU)
-  ASSERT_EQ(new_block_tables.index({0, 0}).item<int64_t>(), 1)
+  EXPECT_EQ(new_block_tables.index({0, 0}).item<int64_t>(), 1)
       << "The first value in new_block_tables should be 1";
 }
 
@@ -566,9 +566,9 @@ TEST_F(IndexerTest, Bfloat16ChunkedPrefillVerifyPrecision) {
 
   // Validations
   // Shape Verification
-  ASSERT_EQ(new_block_tables.dim(), 2);
-  ASSERT_EQ(new_block_tables.size(0), num_new_tokens);  // [batch * current_len]
-  ASSERT_EQ(new_block_tables.size(1), index_topk);
+  CHECK_EQ(new_block_tables.dim(), 2);
+  CHECK_EQ(new_block_tables.size(0), num_new_tokens);  // [batch * current_len]
+  CHECK_EQ(new_block_tables.size(1), index_topk);
 
   // Value Verification
   auto top1_indices = new_block_tables.index({torch::indexing::Slice(), 0})
@@ -582,9 +582,9 @@ TEST_F(IndexerTest, Bfloat16ChunkedPrefillVerifyPrecision) {
   // The expected value is calculated via vLLM MLU
   int64_t expected_sum = 12288;
   int64_t expected_max = 192;
-  ASSERT_EQ(top1_sum, expected_sum)
+  EXPECT_EQ(top1_sum, expected_sum)
       << "top-1 block index sum does not match ground truth";
-  ASSERT_EQ(top1_max, expected_max)
+  EXPECT_EQ(top1_max, expected_max)
       << "top-1 block index max does not match ground truth";
 }
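A plausible reading of the assertion swap (the commit does not spell it out): shape mismatches mean the test setup itself is broken, so those checks now use glog's CHECK_EQ, which aborts the whole process, while value mismatches against ground truth are ordinary test failures, so they use gtest's non-fatal EXPECT_EQ and the remaining assertions still run. A minimal sketch contrasting the three macros:

#include <glog/logging.h>
#include <gtest/gtest.h>

TEST(AssertionMacros, FailureSemantics) {
  int x = 1;
  // glog: on failure, logs FATAL and aborts the entire test binary.
  CHECK_EQ(x, 1) << "fixture invariant violated";
  // gtest non-fatal: marks the test FAILED but execution continues.
  EXPECT_EQ(x, 1) << "value differs from ground truth";
  // gtest fatal: returns from the current test function on failure.
  ASSERT_EQ(x, 1) << "cannot continue without this";
}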

xllm/models/llm/deepseek_v2.h (10 additions, 19 deletions)

@@ -72,24 +72,6 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
   auto model_args = context.get_model_args();
   auto parallel_args = context.get_parallel_args();
 
-  // Check if prefix cache or chunked prefill is enabled for unsupported
-  // models
-  const std::string& model_type = model_args.model_type();
-  // deepseek_v32 has index_n_heads > 0 (default 64), while deepseek_v3 has 0
-  bool is_deepseek_v32 =
-      model_type == "deepseek_v3" && model_args.index_n_heads() > 0;
-  if (model_type == "deepseek_v2" ||
-      (model_type == "deepseek_v3" && !is_deepseek_v32)) {
-    // Note: Only deepseek_v32 supports prefix cache and chunked prefill at
-    // present.
-    CHECK(!FLAGS_enable_prefix_cache)
-        << "deepseek_v2 and deepseek_v3 have not supported "
-           "enable_prefix_cache yet. Please disable it.";
-    CHECK(!FLAGS_enable_chunked_prefill)
-        << "deepseek_v2 and deepseek_v3 have not supported "
-           "enable_chunked_prefill yet. Please disable it.";
-  }
-
   blocks_ = register_module("layers", torch::nn::ModuleList());
   layers_.reserve(model_args.n_layers());
 
@@ -194,7 +176,16 @@ class DeepseekV2ForCausalLMImpl
     : public LlmForCausalLMImplBase<DeepseekV2Model> {
  public:
   DeepseekV2ForCausalLMImpl(const ModelContext& context)
-      : LlmForCausalLMImplBase<DeepseekV2Model>(context) {}
+      : LlmForCausalLMImplBase<DeepseekV2Model>(context) {
+    // Check if prefix cache or chunked prefill is enabled for unsupported
+    // models
+    CHECK(!FLAGS_enable_prefix_cache)
+        << "deepseek_v2 have not supported "
+           "enable_prefix_cache yet. Please disable it.";
+    CHECK(!FLAGS_enable_chunked_prefill)
+        << "deepseek_v2 have not supported "
+           "enable_chunked_prefill yet. Please disable it.";
+  }
 };
 TORCH_MODULE(DeepseekV2ForCausalLM);
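With the check moved out of the shared DeepseekV2ModelImpl, the model-type sniffing via index_n_heads() disappears and each ForCausalLM wrapper states its own constraints at construction time. The gating itself is the usual gflags-plus-glog idiom; a self-contained sketch (the DEFINE_bool lines stand in for xllm's real flag definitions, which live elsewhere in the codebase):

#include <string>

#include <gflags/gflags.h>
#include <glog/logging.h>

// Stand-in definitions so the sketch compiles on its own; in xllm these
// flags are defined once in a shared location.
DEFINE_bool(enable_prefix_cache, false, "Enable prefix KV-cache reuse.");
DEFINE_bool(enable_chunked_prefill, false, "Enable chunked prefill.");

// Fail fast at construction time with an actionable message rather than
// producing wrong results later in the run.
void check_unsupported_features(const std::string& model_type) {
  CHECK(!FLAGS_enable_prefix_cache)
      << model_type << " does not support enable_prefix_cache yet. "
      << "Please disable it.";
  CHECK(!FLAGS_enable_chunked_prefill)
      << model_type << " does not support enable_chunked_prefill yet. "
      << "Please disable it.";
}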

xllm/models/llm/deepseek_v3.h (19 additions, 1 deletion)

@@ -18,8 +18,26 @@ limitations under the License.
 #include "deepseek_v2.h"
 
 namespace xllm {
+
+class DeepseekV3ForCausalLMImpl
+    : public LlmForCausalLMImplBase<DeepseekV2Model> {
+ public:
+  DeepseekV3ForCausalLMImpl(const ModelContext& context)
+      : LlmForCausalLMImplBase<DeepseekV2Model>(context) {
+    // Check if prefix cache or chunked prefill is enabled for unsupported
+    // models
+    CHECK(!FLAGS_enable_prefix_cache)
+        << "deepseek_v3 have not supported "
+           "enable_prefix_cache yet. Please disable it.";
+    CHECK(!FLAGS_enable_chunked_prefill)
+        << "deepseek_v3 have not supported "
+           "enable_chunked_prefill yet. Please disable it.";
+  }
+};
+TORCH_MODULE(DeepseekV3ForCausalLM);
+
 // register the causal model
-REGISTER_CAUSAL_MODEL(deepseek_v3, DeepseekV2ForCausalLM);
+REGISTER_CAUSAL_MODEL(deepseek_v3, DeepseekV3ForCausalLM);
 // register the model args
 // example config:
 // https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/config.json

xllm/models/llm/deepseek_v32.h (10 additions, 1 deletion)

@@ -18,8 +18,17 @@ limitations under the License.
 #include "deepseek_v2.h"
 
 namespace xllm {
+
+class DeepseekV32ForCausalLMImpl
+    : public LlmForCausalLMImplBase<DeepseekV2Model> {
+ public:
+  DeepseekV32ForCausalLMImpl(const ModelContext& context)
+      : LlmForCausalLMImplBase<DeepseekV2Model>(context) {}
+};
+TORCH_MODULE(DeepseekV32ForCausalLM);
+
 // register the causal model
-REGISTER_CAUSAL_MODEL(deepseek_v32, DeepseekV2ForCausalLM);
+REGISTER_CAUSAL_MODEL(deepseek_v32, DeepseekV32ForCausalLM);
 // register the model args
 // example config:
 // https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/config.json
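The point of giving deepseek_v3 and deepseek_v32 their own ForCausalLM classes is that REGISTER_CAUSAL_MODEL can now bind each model type to a class carrying its own constructor-time checks. The macro's expansion is not part of this commit; a generic sketch of the factory-registration pattern it presumably follows (all names below are illustrative, not the real xllm API):

#include <functional>
#include <map>
#include <memory>
#include <string>

// Illustrative stand-ins for the real xllm types.
struct ModelContext {};
struct CausalLM {
  virtual ~CausalLM() = default;
};

using ModelFactory =
    std::function<std::unique_ptr<CausalLM>(const ModelContext&)>;

std::map<std::string, ModelFactory>& model_registry() {
  static std::map<std::string, ModelFactory> registry;
  return registry;
}

// A macro like REGISTER_CAUSAL_MODEL typically stores a factory under the
// model-type key at static-initialization time, so checkpoint loading can
// look up "deepseek_v32" and construct the class with the right checks.
#define REGISTER_CAUSAL_MODEL_SKETCH(key, Impl)              \
  static const bool registered_##key = [] {                  \
    model_registry()[#key] = [](const ModelContext& ctx) {   \
      return std::unique_ptr<CausalLM>(new Impl(ctx));       \
    };                                                       \
    return true;                                             \
  }();

struct DeepseekV32Sketch : CausalLM {
  explicit DeepseekV32Sketch(const ModelContext&) {}
};
REGISTER_CAUSAL_MODEL_SKETCH(deepseek_v32, DeepseekV32Sketch)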
