
Commit b0f3758

feat: add XAttention support for Qwen3 generative recommendation.
1 parent 2ed5d70 · commit b0f3758


43 files changed (+2818, -37 lines)
Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://github.com/jd-opensource/xllm/blob/main/LICENSE
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *==============================================================================*/
+
+#pragma once
+
+namespace xllm {
+
+namespace detail {
+// Thread-local flag indicating whether the current thread is inside an ACL
+// graph capture region.
+inline thread_local bool g_in_acl_graph_capture = false;
+}  // namespace detail
+
+// Returns true if the current thread is inside ACL graph capture.
+inline bool in_acl_graph_capture() { return detail::g_in_acl_graph_capture; }
+
+// RAII guard that marks the current thread as inside ACL graph capture for
+// its lifetime.
+class AclGraphCaptureGuard {
+ public:
+  AclGraphCaptureGuard() { detail::g_in_acl_graph_capture = true; }
+  ~AclGraphCaptureGuard() { detail::g_in_acl_graph_capture = false; }
+};
+
+}  // namespace xllm
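For orientation, here is a minimal usage sketch of the guard above. The function name and body are hypothetical; only AclGraphCaptureGuard and in_acl_graph_capture() come from this header.

#include <cassert>
// #include "acl_graph_capture_guard.h"  // hypothetical path to the header above

void capture_graph_region() {
  // Sets the thread-local flag for the duration of this scope.
  xllm::AclGraphCaptureGuard guard;
  // ... record kernels into the ACL graph; code running on this thread can
  // branch on the flag, e.g. to skip host-side synchronization while capturing.
  assert(xllm::in_acl_graph_capture());
}  // the guard's destructor clears the flag

Note the guard is not nesting-safe: an inner guard's destructor clears the flag while an outer guard is still alive. Saving and restoring the previous flag value in the constructor and destructor would make it reentrant.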

xllm/core/common/global_flags.cpp

Lines changed: 9 additions & 0 deletions
@@ -406,6 +406,8 @@ DEFINE_int64(buffer_size_per_seq,
              0,
              "Buffer size per sequence in bytes, default 0.");
 
+DEFINE_int64(max_token_per_req, 1024, "Max tokens per request, default 1024.");
+
 // --- beam search config ---
 
 DEFINE_bool(enable_beam_search_kernel,
@@ -489,3 +491,10 @@ DEFINE_string(
     "ATB",
     "NPU kernel backend. Supported options: ATB, TORCH. Default is ATB.");
 #endif
+DEFINE_int32(beam_width, 1, "Beam width for beam search.");
+
+// --- multi-step decode config ---
+DEFINE_int32(max_decode_rounds,
+             0,
+             "Maximum number of decode rounds for multi-step decoding. 0 means "
+             "disabled.");

xllm/core/common/global_flags.h

Lines changed: 5 additions & 0 deletions
@@ -201,6 +201,10 @@ DECLARE_int32(max_global_tpot_ms);
 
 DECLARE_int32(max_requests_per_batch);
 
+DECLARE_int32(max_decode_rounds);
+
+DECLARE_int32(beam_width);
+
 DECLARE_bool(enable_continuous_kvcache);
 
 DECLARE_int64(phy_page_granularity_size);
@@ -240,3 +244,4 @@ DECLARE_bool(enable_constrained_decoding);
 #if defined(USE_NPU)
 DECLARE_string(npu_kernel_backend);
 #endif
+DECLARE_int64(max_token_per_req);

xllm/core/distributed_runtime/llm_engine.cpp

Lines changed: 93 additions & 11 deletions
@@ -263,6 +263,9 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
     int index_n_head = 1;
     index_slot_size = dtype_size * index_n_head * args_.index_head_dim();
   }
+  if (FLAGS_max_decode_rounds > 0) {
+    slot_size *= FLAGS_max_decode_rounds;
+  }
   kv_cache_cap.slot_size = slot_size;
   kv_cache_cap.index_slot_size = index_slot_size;
   kv_cache_cap.n_layers = args_.n_layers();
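The hunk above multiplies every KV-cache slot by the round count, which directly shrinks the number of slots a fixed memory budget can hold. A back-of-envelope sketch, with illustrative numbers and an assumed K-plus-V-per-slot layout (neither is taken from this commit):

#include <cstdint>

// Assumed layout: one slot holds K and V for one token across all local KV
// heads, so slot = 2 * n_kv_heads * head_dim * dtype_size bytes.
int64_t slot_bytes(int64_t n_kv_heads, int64_t head_dim, int64_t dtype_size,
                   int64_t max_decode_rounds) {
  int64_t slot = 2 * n_kv_heads * head_dim * dtype_size;
  if (max_decode_rounds > 0) slot *= max_decode_rounds;  // as in the hunk above
  return slot;
}
// e.g. slot_bytes(8, 128, 2, 16) == 65536: slots are 16x larger, so the same
// memory budget yields 16x fewer slots (kv_cache_cap.n_blocks shrinks).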
@@ -309,10 +312,23 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {
     kv_cache_shape.emplace_back(std::vector<int64_t>{
         kv_cache_cap.n_blocks, block_size, 1, args_.qk_rope_head_dim()});
   } else {
-    kv_cache_shape.emplace_back(std::vector<int64_t>{
-        kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_});
-    kv_cache_shape.emplace_back(std::vector<int64_t>{
-        kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_});
+    if (FLAGS_max_decode_rounds > 0) {
+      kv_cache_shape.emplace_back(std::vector<int64_t>{kv_cache_cap.n_blocks,
+                                                       block_size,
+                                                       n_local_kv_heads_,
+                                                       FLAGS_max_decode_rounds,
+                                                       head_dim_});
+      kv_cache_shape.emplace_back(std::vector<int64_t>{kv_cache_cap.n_blocks,
+                                                       block_size,
+                                                       n_local_kv_heads_,
+                                                       FLAGS_max_decode_rounds,
+                                                       head_dim_});
+    } else {
+      kv_cache_shape.emplace_back(std::vector<int64_t>{
+          kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_});
+      kv_cache_shape.emplace_back(std::vector<int64_t>{
+          kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_});
+    }
   }
   if (enable_lighting_indexer) {
     kv_cache_shape.emplace_back(std::vector<int64_t>{
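So the multi-round path adds a per-round axis to both the K and V cache tensors. A minimal restatement of the two layouts (the interpretation of the new axis is inferred from the flag name):

#include <cstdint>
#include <vector>

std::vector<int64_t> kv_cache_shape_for(int64_t n_blocks, int64_t block_size,
                                        int64_t n_kv_heads, int64_t head_dim,
                                        int64_t max_decode_rounds) {
  if (max_decode_rounds > 0) {
    // [n_blocks, block_size, n_kv_heads, rounds, head_dim]: each
    // (block, slot, head) holds one KV vector per decode round.
    return {n_blocks, block_size, n_kv_heads, max_decode_rounds, head_dim};
  }
  // Default: one KV vector per (block, slot, head).
  return {n_blocks, block_size, n_kv_heads, head_dim};
}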
@@ -340,10 +356,17 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {
 
   // initialize block manager
   BlockManagerPool::Options options;
+  // Simplify block management when max_decode_rounds is in use: run without
+  // prefix cache, KV-cache store, and host blocks.
+  bool enable_prefix_cache =
+      options_.enable_prefix_cache() && FLAGS_max_decode_rounds == 0;
+  bool enable_kvcache_store =
+      options_.enable_kvcache_store() && FLAGS_max_decode_rounds == 0;
+  auto host_blocks_factor =
+      FLAGS_max_decode_rounds == 0 ? options_.host_blocks_factor() : 0.0;
   options.num_blocks(kv_cache_cap.n_blocks)
       .block_size(block_size)
-      .host_num_blocks(kv_cache_cap.n_blocks * options_.host_blocks_factor())
-      .enable_prefix_cache(options_.enable_prefix_cache())
+      .host_num_blocks(kv_cache_cap.n_blocks * host_blocks_factor)
+      .enable_prefix_cache(enable_prefix_cache)
       .enable_disagg_pd(options_.enable_disagg_pd())
       .enable_cache_upload(options_.enable_cache_upload())
-      .enable_kvcache_store(options_.enable_kvcache_store());
+      .enable_kvcache_store(enable_kvcache_store);
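The diff does not say why these features are forced off; a plausible reading (an assumption, not stated in the commit) is that prefix caching, host block offload, and the KV-cache store all assume the default per-token block layout, which the per-round layout above breaks. Condensed, the gating is:

#include <cstdint>

// Standalone restatement of the gating logic above.
struct BlockManagerToggles {
  bool enable_prefix_cache;
  bool enable_kvcache_store;
  double host_blocks_factor;
};

BlockManagerToggles gate_for_multi_round(bool prefix_cache, bool kvcache_store,
                                         double host_factor,
                                         int32_t max_decode_rounds) {
  const bool multi_round = max_decode_rounds > 0;
  return {prefix_cache && !multi_round,
          kvcache_store && !multi_round,
          multi_round ? 0.0 : host_factor};
}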
@@ -694,11 +717,59 @@ bool LLMEngine::unlink_cluster(const std::vector<uint64_t>& cluster_ids,
   return true;
 }
 
+ForwardOutput LLMEngine::step_multi_round(std::vector<Batch>& batch) {
+  Timer timer;
+  DCHECK(dp_size_ == batch.size())
+      << "Split DP batch failed with dp_size as " << dp_size_
+      << " and actual batch size as " << batch.size() << ".";
+  auto batched_raw_forward_inputs = prepare_inputs(batch);
+  DCHECK(dp_size_ == batched_raw_forward_inputs.size())
+      << "The processed raw forward inputs size "
+      << batched_raw_forward_inputs.size() << " is not equal to dp size "
+      << dp_size_ << ".";
+  std::vector<folly::SemiFuture<std::optional<RawForwardOutput>>> futures;
+  futures.reserve(worker_clients_num_);
+  for (auto worker_rank = 0; worker_rank < worker_clients_num_; ++worker_rank) {
+    auto dp_rank = worker_rank / dp_local_tp_size_;
+    futures.emplace_back(worker_clients_[worker_rank]->step_async(
+        batched_raw_forward_inputs[dp_rank]));
+  }
+  auto results = folly::collectAll(futures).get();
+  size_t dp_rank = 0;
+  for (auto worker_rank = 0; worker_rank < worker_clients_num_;
+       worker_rank += dp_local_tp_size_) {
+    auto result = results[worker_rank].value();
+    if (result.has_value()) {
+      if (result.value().outputs.empty() && layer_forward_interrupted_) {
+        throw ForwardInterruptedException();
+      }
+      auto& raw = result.value();
+      if (!raw.beam_sequence_group.empty()) {
+        batch[dp_rank].process_beam_sequence_group(raw);
+      } else {
+        batch[dp_rank].process_decode_beam_search_output(raw, false);
+      }
+    } else {
+      LOG(FATAL) << "Failed to execute model, result has no value";
+    }
+    ++dp_rank;
+  }
+  COUNTER_ADD(engine_latency_seconds, timer.elapsed_seconds());
+  // finish all sequences in the batch
+  for (auto& b : batch) {
+    b.finish();
+  }
+  return {};
+}
+
 ForwardOutput LLMEngine::step(std::vector<Batch>& batch) {
   if (worker_clients_.empty()) {
     // empty worker, return
     return {};
   }
+  if (FLAGS_max_decode_rounds > 0) {
+    return step_multi_round(batch);
+  }
   Timer timer;
   DCHECK(dp_size_ == batch.size())
       << "Split DP batch failed with dp_size as " << dp_size_
@@ -735,9 +806,14 @@ ForwardOutput LLMEngine::step(std::vector<Batch>& batch) {
       if (result.value().outputs.empty() && layer_forward_interrupted_) {
         throw ForwardInterruptedException();
       }
-      // if src_seq_idxes is not empty, skip sample output processing and
-      // process beam search output instead
-      if (result.value().src_seq_idxes.size() == 0) {
+      // If both src_seq_idxes and out_tokens are populated, this step used
+      // the beam search kernel; otherwise, fall back to normal sample output
+      // processing. Note that proto serialization always fills src_seq_idxes
+      // with a fallback [0..num_seqs) when it is undefined, so we must also
+      // check out_tokens to distinguish real beam-kernel outputs.
+      if (result.value().src_seq_idxes.empty() ||
+          result.value().out_tokens.empty() ||
+          !FLAGS_enable_beam_search_kernel) {
         // set second input param enable_schedule_overlap to false,
         // if it's not enabled, process_sample_output will append the real
         // token, if it's enabled, this false here will append the fake token in
@@ -868,8 +944,14 @@ std::vector<RawForwardInput> LLMEngine::prepare_inputs(
 
   // build model input for every single micro batch
   for (auto dp_rank = 0; dp_rank < dp_size_; ++dp_rank) {
-    batched_inputs.emplace_back(std::move(
-        batch[dp_rank].prepare_forward_input(args_, threadpool_.get())));
+    if (FLAGS_max_decode_rounds > 0) {
+      batched_inputs.emplace_back(
+          std::move(batch[dp_rank].prepare_multi_step_forward_input(
+              args_, threadpool_.get())));
+    } else {
+      batched_inputs.emplace_back(std::move(
+          batch[dp_rank].prepare_forward_input(args_, threadpool_.get())));
+    }
     dp_global_token_nums[dp_rank] =
         batched_inputs[dp_rank].flatten_tokens_vec.size();
     global_empty_kv_cache =

xllm/core/distributed_runtime/llm_engine.h

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@ class LLMEngine : public Engine {
   virtual ~LLMEngine() = default;
 
   ForwardOutput step(std::vector<Batch>& batch) override;
+  ForwardOutput step_multi_round(std::vector<Batch>& batch);
 
   const runtime::Options& options() const { return options_; }
xllm/core/distributed_runtime/llm_master.cpp

Lines changed: 12 additions & 3 deletions
@@ -98,7 +98,13 @@ LLMMaster::LLMMaster(const Options& options)
       .max_global_tpot_ms(options_.max_global_tpot_ms())
       .server_idx(options_.server_idx())
       .prefetch_timeout(options_.prefetch_timeout());
-  scheduler_ = create_continuous_scheduler(engine_.get(), scheduler_options);
+  if (FLAGS_max_decode_rounds > 0) {
+    // When using multi-round decode, use FixedStepsScheduler to run
+    // fixed-step scheduling for both prefill and decode.
+    scheduler_ = create_fixed_steps_scheduler(engine_.get(), scheduler_options);
+  } else {
+    scheduler_ = create_continuous_scheduler(engine_.get(), scheduler_options);
+  }
 
   if (options_.enable_service_routing()) {
     auto& instance_info = scheduler_->get_instance_info();
@@ -318,6 +324,9 @@ std::shared_ptr<Request> LLMMaster::generate_request(
   }
 
   uint32_t max_tokens = sp.max_tokens;
+  if (FLAGS_max_decode_rounds > 0) {
+    max_tokens = FLAGS_max_decode_rounds;
+  }
   if (max_tokens == 0) {
     const uint32_t kDefaultMaxTokens = 5120;
     max_tokens = kDefaultMaxTokens;
@@ -369,12 +378,12 @@ std::shared_ptr<Request> LLMMaster::generate_request(
       stop_sequences.push_back(std::move(tmp_tokens));
     }
   }
-
+  // Multi-round decode runs a fixed number of rounds, so EOS is ignored.
+  bool ignore_eos = FLAGS_max_decode_rounds > 0 || sp.ignore_eos;
   StoppingChecker stopping_checker(
       max_tokens,
       max_context_len - options_.num_speculative_tokens(),
       model_args_.eos_token_id(),
-      sp.ignore_eos,
+      ignore_eos,
       std::move(stop_tokens),
       std::move(stop_sequences));

xllm/core/distributed_runtime/master.cpp

Lines changed: 8 additions & 1 deletion
@@ -57,6 +57,13 @@ Master::Master(const Options& options, EngineType type) : options_(options) {
   brpc::FLAGS_graceful_quit_on_sigterm = true;
   brpc::FLAGS_graceful_quit_on_sighup = true;
 
+  auto opt_block_size = options.block_size();
+  if (FLAGS_max_decode_rounds > 0) {
+    CHECK(FLAGS_beam_width > 0)
+        << "beam_width must be greater than 0 when max_decode_rounds > 0";
+    options_.block_size(FLAGS_beam_width);
+    opt_block_size = options_.block_size();
+  }
 #if defined(USE_NPU)
   if (options.rank_tablefile().has_value()) {
     FLAGS_rank_tablefile = options.rank_tablefile().value();
@@ -112,7 +119,7 @@ Master::Master(const Options& options, EngineType type) : options_(options) {
   eng_options.model_path(options_.model_path())
       .devices(devices)
       .backend(options.backend())
-      .block_size(options.block_size())
+      .block_size(opt_block_size)
       .max_cache_size(options.max_cache_size())
       .max_memory_utilization(options.max_memory_utilization())
       .enable_prefix_cache(options.enable_prefix_cache())
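Net effect: under multi-round decode the KV-cache block size is repurposed to equal the beam width, so one block holds one slot per beam and a beam group advances block by block. A worked size estimate under that reading (an assumption; all numbers illustrative):

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t beam_width = 4;          // becomes block_size
  const int64_t max_decode_rounds = 16;  // per-slot round axis (see llm_engine.cpp)
  const int64_t n_kv_heads = 8, head_dim = 128, dtype_size = 2;  // fp16

  // Bytes in one cache block under the multi-round layout, K and V combined:
  const int64_t block_bytes = 2 * beam_width * n_kv_heads * max_decode_rounds *
                              head_dim * dtype_size;
  std::printf("one block = %lld bytes\n",
              static_cast<long long>(block_bytes));  // 262144, i.e. 256 KiB
  return 0;
}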
