37 changes: 37 additions & 0 deletions xllm/core/common/acl_graph_runtime.h
@@ -0,0 +1,37 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://github.com/jd-opensource/xllm/blob/main/LICENSE
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*==============================================================================*/

#pragma once

namespace xllm {

namespace detail {
// Thread-local flag indicating whether current thread is inside ACL graph
// capture region.
inline thread_local bool g_in_acl_graph_capture = false;
} // namespace detail

// Return true if current thread is inside ACL graph capture.
inline bool in_acl_graph_capture() { return detail::g_in_acl_graph_capture; }

// RAII guard that marks current thread as inside ACL graph capture for its
// lifetime.
class AclGraphCaptureGuard {
public:
AclGraphCaptureGuard() { detail::g_in_acl_graph_capture = true; }
~AclGraphCaptureGuard() { detail::g_in_acl_graph_capture = false; }
};

} // namespace xllm
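A minimal usage sketch of the new guard, not part of the diff; the include path is taken from this PR and the call site is hypothetical:

#include <cassert>

#include "xllm/core/common/acl_graph_runtime.h"

void capture_example() {
  assert(!xllm::in_acl_graph_capture());
  {
    xllm::AclGraphCaptureGuard guard;  // sets the thread-local flag
    assert(xllm::in_acl_graph_capture());
    // ... record the ops to be captured into the ACL graph here ...
  }  // flag is cleared when the guard goes out of scope
  assert(!xllm::in_acl_graph_capture());
}

Note that the guard is not nesting-safe: the destructor clears the flag unconditionally, so an inner guard ends the capture state for an enclosing one.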
9 changes: 9 additions & 0 deletions xllm/core/common/global_flags.cpp
@@ -406,6 +406,8 @@ DEFINE_int64(buffer_size_per_seq,
0,
"Buffer size per sequence in bytes, default 0.");

DEFINE_int64(max_token_per_req, 1024, "Max tokens per request, default 1024.");

// --- beam search config ---

DEFINE_bool(enable_beam_search_kernel,
@@ -489,3 +491,10 @@ DEFINE_string(
"ATB",
"NPU kernel backend. Supported options: ATB, TORCH. Default is ATB.");
#endif
DEFINE_int32(beam_width, 1, "Beam width for beam search.");

// --- multi-step decode config ---
DEFINE_int32(max_decode_rounds,
0,
"Maximum number of decode rounds for multi-step decoding. 0 means "
"disabled.");
5 changes: 5 additions & 0 deletions xllm/core/common/global_flags.h
@@ -201,6 +201,10 @@ DECLARE_int32(max_global_tpot_ms);

DECLARE_int32(max_requests_per_batch);

DECLARE_int32(max_decode_rounds);

DECLARE_int32(beam_width);

DECLARE_bool(enable_continuous_kvcache);

DECLARE_int64(phy_page_granularity_size);
@@ -240,3 +244,4 @@ DECLARE_bool(enable_constrained_decoding);
#if defined(USE_NPU)
DECLARE_string(npu_kernel_backend);
#endif
DECLARE_int64(max_token_per_req);
106 changes: 94 additions & 12 deletions xllm/core/distributed_runtime/llm_engine.cpp
@@ -263,6 +263,9 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
int index_n_head = 1;
index_slot_size = dtype_size * index_n_head * args_.index_head_dim();
}
if (FLAGS_max_decode_rounds > 0) {
slot_size *= FLAGS_max_decode_rounds;
}
kv_cache_cap.slot_size = slot_size;
kv_cache_cap.index_slot_size = index_slot_size;
kv_cache_cap.n_layers = args_.n_layers();
@@ -309,10 +312,23 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {
kv_cache_shape.emplace_back(std::vector<int64_t>{
kv_cache_cap.n_blocks, block_size, 1, args_.qk_rope_head_dim()});
} else {
kv_cache_shape.emplace_back(std::vector<int64_t>{
kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_});
kv_cache_shape.emplace_back(std::vector<int64_t>{
kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_});
if (FLAGS_max_decode_rounds > 0) {
kv_cache_shape.emplace_back(std::vector<int64_t>{kv_cache_cap.n_blocks,
block_size,
n_local_kv_heads_,
FLAGS_max_decode_rounds,
head_dim_});
kv_cache_shape.emplace_back(std::vector<int64_t>{kv_cache_cap.n_blocks,
block_size,
n_local_kv_heads_,
FLAGS_max_decode_rounds,
head_dim_});
} else {
kv_cache_shape.emplace_back(std::vector<int64_t>{
kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_});
kv_cache_shape.emplace_back(std::vector<int64_t>{
kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_});
}
}
if (enable_lighting_indexer) {
kv_cache_shape.emplace_back(std::vector<int64_t>{
@@ -340,13 +356,20 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {

// initialize block manager
BlockManagerPool::Options options;
// Simplify block management when multi-step decode (max_decode_rounds) is enabled.
bool enable_prefix_cache =
options_.enable_prefix_cache() && FLAGS_max_decode_rounds == 0;
bool enable_kvcache_store =
options_.enable_kvcache_store() && FLAGS_max_decode_rounds == 0;
auto host_blocks_factor =
FLAGS_max_decode_rounds == 0 ? options_.host_blocks_factor() : 0.0;
options.num_blocks(kv_cache_cap.n_blocks)
.block_size(block_size)
.host_num_blocks(kv_cache_cap.n_blocks * options_.host_blocks_factor())
.enable_prefix_cache(options_.enable_prefix_cache())
.host_num_blocks(kv_cache_cap.n_blocks * host_blocks_factor)
.enable_prefix_cache(enable_prefix_cache)
.enable_disagg_pd(options_.enable_disagg_pd())
.enable_cache_upload(options_.enable_cache_upload())
.enable_kvcache_store(options_.enable_kvcache_store());
.enable_kvcache_store(enable_kvcache_store);
if (options_.host_blocks_factor() > 1.0 || options_.enable_kvcache_store()) {
kv_cache_manager_ =
std::make_unique<HierarchyBlockManagerPool>(options, this, dp_size_);
@@ -694,11 +717,59 @@ bool LLMEngine::unlink_cluster(const std::vector<uint64_t>& cluster_ids,
return true;
}

ForwardOutput LLMEngine::step_multi_round(std::vector<Batch>& batch) {
Timer timer;
CHECK(dp_size_ == batch.size())
<< "Split DP batch failed with dp_size as " << dp_size_
<< " and actual batch size as " << batch.size() << ".";
auto batched_raw_forward_inputs = prepare_inputs(batch);
CHECK(dp_size_ == batched_raw_forward_inputs.size())
<< "The processed raw forward inputs size "
<< batched_raw_forward_inputs.size() << " is not equal to dp size "
<< dp_size_ << ".";
std::vector<folly::SemiFuture<std::optional<RawForwardOutput>>> futures;
futures.reserve(worker_clients_num_);
for (auto worker_rank = 0; worker_rank < worker_clients_num_; ++worker_rank) {
auto dp_rank = worker_rank / dp_local_tp_size_;
futures.emplace_back(worker_clients_[worker_rank]->step_async(
batched_raw_forward_inputs[dp_rank]));
}
auto results = folly::collectAll(futures).get();
size_t dp_rank = 0;
for (auto worker_rank = 0; worker_rank < worker_clients_num_;
worker_rank += dp_local_tp_size_) {
auto result = results[worker_rank].value();
if (result.has_value()) {
if (result.value().outputs.empty() && layer_forward_interrupted_) {
throw ForwardInterruptedException();
}
auto& raw = result.value();
if (!raw.beam_sequence_group.empty()) {
batch[dp_rank].process_beam_sequence_group(raw);
} else {
batch[dp_rank].process_decode_beam_search_output(raw, false);
}
} else {
LOG(FATAL) << "Failed to execute model, result has no value";
}
++dp_rank;
}
COUNTER_ADD(engine_latency_seconds, timer.elapsed_seconds());
// finish all sequences in the batch
for (auto& b : batch) {
b.finish();
}
return {};
}
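The collection loop above consumes one result per data-parallel group and skips the remaining tensor-parallel workers in that group (whose outputs are presumably identical). A standalone sketch of the rank mapping, with illustrative values only:

#include <cstdio>

int main() {
  const int worker_clients_num = 8;  // assumed: 2 DP groups x 4 TP workers each
  const int dp_local_tp_size = 4;
  int dp_rank = 0;
  for (int worker_rank = 0; worker_rank < worker_clients_num;
       worker_rank += dp_local_tp_size) {
    // visits worker 0 for dp_rank 0 and worker 4 for dp_rank 1
    std::printf("dp_rank %d reads the result of worker %d\n", dp_rank,
                worker_rank);
    ++dp_rank;
  }
  return 0;
}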

ForwardOutput LLMEngine::step(std::vector<Batch>& batch) {
if (worker_clients_.empty()) {
// empty worker, return
return {};
}
if (FLAGS_max_decode_rounds > 0) {
return step_multi_round(batch);
}
Timer timer;
DCHECK(dp_size_ == batch.size())
<< "Split DP batch failed with dp_size as " << dp_size_
@@ -735,9 +806,14 @@ ForwardOutput LLMEngine::step(std::vector<Batch>& batch) {
if (result.value().outputs.empty() && layer_forward_interrupted_) {
throw ForwardInterruptedException();
}
// if src_seq_idxes is not empty, skip sample output processing and
// process beam search output instead
if (result.value().src_seq_idxes.size() == 0) {
// If both src_seq_idxes and out_tokens are populated, this step used
// the beam search kernel; otherwise, fall back to normal sample output
// processing. Note that proto serialization always fills src_seq_idxes
// with a fallback [0..num_seqs) when it is undefined, so we must also
// check out_tokens to distinguish real beam-kernel outputs.
if (result.value().src_seq_idxes.empty() ||
result.value().out_tokens.empty() ||
!FLAGS_enable_beam_search_kernel) {
// set second input param enable_schedule_overlap to false,
// if it's not enabled, process_sample_output will append the real
// token, if it's enabled, this false here will append the fake token in
@@ -868,8 +944,12 @@ std::vector<RawForwardInput> LLMEngine::prepare_inputs(

// build model input for every single micro batch
for (auto dp_rank = 0; dp_rank < dp_size_; ++dp_rank) {
batched_inputs.emplace_back(std::move(
batch[dp_rank].prepare_forward_input(args_, threadpool_.get())));
if (FLAGS_max_decode_rounds > 0) {
batched_inputs.emplace_back(
std::move(batch[dp_rank].prepare_multi_step_forward_input(
args_, threadpool_.get())));
} else {
batched_inputs.emplace_back(std::move(
batch[dp_rank].prepare_forward_input(args_, threadpool_.get())));
}
dp_global_token_nums[dp_rank] =
batched_inputs[dp_rank].flatten_tokens_vec.size();
global_empty_kv_cache =
1 change: 1 addition & 0 deletions xllm/core/distributed_runtime/llm_engine.h
@@ -50,6 +50,7 @@ class LLMEngine : public Engine {
virtual ~LLMEngine() = default;

ForwardOutput step(std::vector<Batch>& batch) override;
ForwardOutput step_multi_round(std::vector<Batch>& batch);

const runtime::Options& options() const { return options_; }

7 changes: 5 additions & 2 deletions xllm/core/distributed_runtime/llm_master.cpp
@@ -318,6 +318,9 @@ std::shared_ptr<Request> LLMMaster::generate_request(
}

uint32_t max_tokens = sp.max_tokens;
if (FLAGS_max_decode_rounds > 0) {
max_tokens = FLAGS_max_decode_rounds;
}
if (max_tokens == 0) {
const uint32_t kDefaultMaxTokens = 5120;
max_tokens = kDefaultMaxTokens;
@@ -369,12 +372,12 @@ std::shared_ptr<Request> LLMMaster::generate_request(
stop_sequences.push_back(std::move(tmp_tokens));
}
}

bool ignore_eos = FLAGS_max_decode_rounds > 0 ? true : sp.ignore_eos;
StoppingChecker stopping_checker(
max_tokens,
max_context_len - options_.num_speculative_tokens(),
model_args_.eos_token_id(),
sp.ignore_eos,
ignore_eos,
std::move(stop_tokens),
std::move(stop_sequences));

9 changes: 8 additions & 1 deletion xllm/core/distributed_runtime/master.cpp
@@ -57,6 +57,13 @@ Master::Master(const Options& options, EngineType type) : options_(options) {
brpc::FLAGS_graceful_quit_on_sigterm = true;
brpc::FLAGS_graceful_quit_on_sighup = true;

auto opt_block_size = options.block_size();
if (FLAGS_max_decode_rounds > 0) {
CHECK(FLAGS_beam_width > 0)
<< "beam_width must be greater than 0 when max_decode_rounds > 0";
options_.block_size(FLAGS_beam_width);
opt_block_size = options_.block_size();
}
#if defined(USE_NPU)
if (options.rank_tablefile().has_value()) {
FLAGS_rank_tablefile = options.rank_tablefile().value();
@@ -112,7 +119,7 @@ Master::Master(const Options& options, EngineType type) : options_(options) {
eng_options.model_path(options_.model_path())
.devices(devices)
.backend(options.backend())
.block_size(options.block_size())
.block_size(opt_block_size)
.max_cache_size(options.max_cache_size())
.max_memory_utilization(options.max_memory_utilization())
.enable_prefix_cache(options.enable_prefix_cache())
55 changes: 52 additions & 3 deletions xllm/core/distributed_runtime/worker_service.cpp
@@ -74,7 +74,9 @@ void WorkerService::step(ForwardInput& fwd_input,
int32_t& prepared_layer_id,
torch::Tensor& src_seq_idxes,
torch::Tensor& out_tokens,
torch::Tensor& out_logprobs) {
torch::Tensor& out_logprobs,
std::vector<int32_t>* beam_group_flat,
bool* has_beam_group) {
// execute model
auto future = worker_->step_async(fwd_input);

@@ -131,6 +133,10 @@
true);
}
auto ret = stream_->synchronize();
if (beam_group_flat && has_beam_group) {
FillBeamGroup(
forward_outputs.value(), *beam_group_flat, *has_beam_group);
}
}
}
} else {
@@ -527,6 +533,23 @@ void WorkerService::UnlinkCluster(::google::protobuf::RpcController* controller,
return;
}

void WorkerService::FillBeamGroup(const ForwardOutput& out,
std::vector<int32_t>& beam_group_flat,
bool& has_beam_group) {
if (FLAGS_max_decode_rounds <= 0) {
return;
}
const auto& bsg =
safe_to(out.beam_sequence_group, torch::kCPU, /*pin_memory=*/true);
if (!bsg.defined()) {
return;
}
auto flat = bsg.flatten();
beam_group_flat.assign(flat.data_ptr<int32_t>(),
flat.data_ptr<int32_t>() + flat.numel());
has_beam_group = true;
}
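The shape of beam_sequence_group is not shown in this diff; assuming a 2-D int32 tensor such as [num_seqs, beam_width], the helper above serializes it row-major, roughly as in this standalone sketch:

#include <torch/torch.h>

#include <vector>

int main() {
  // Hypothetical grouping tensor; values are illustrative only.
  auto bsg = torch::tensor({{0, 1}, {2, 3}}, torch::kInt32);
  auto flat = bsg.flatten();
  std::vector<int32_t> beam_group_flat(
      flat.data_ptr<int32_t>(), flat.data_ptr<int32_t>() + flat.numel());
  // beam_group_flat == {0, 1, 2, 3}; this flat layout is what gets copied
  // into the beam_sequence_group field of the proto output.
  return 0;
}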

void WorkerService::ExecuteModel(::google::protobuf::RpcController* controller,
const proto::ForwardInput* pb_forward_input,
proto::ForwardOutput* pb_forward_output,
@@ -554,6 +577,9 @@ void WorkerService::ExecuteModel(::google::protobuf::RpcController* controller,
torch::Tensor out_tokens;
torch::Tensor out_logprobs;

std::vector<int32_t> beam_group_flat;
bool has_beam_group = false;

step(forward_input,
next_tokens,
logprobs,
@@ -564,8 +590,10 @@ void WorkerService::ExecuteModel(::google::protobuf::RpcController* controller,
prepared_layer_id,
src_seq_idxes,
out_tokens,
out_logprobs);
// convert to proto output
out_logprobs,
&beam_group_flat,
&has_beam_group);

forward_output_to_proto(next_tokens,
logprobs,
top_tokens,
Expand All @@ -577,6 +605,12 @@ void WorkerService::ExecuteModel(::google::protobuf::RpcController* controller,
out_tokens,
out_logprobs,
pb_forward_output);

if (has_beam_group && FLAGS_max_decode_rounds > 0) {
ADD_VECTOR_TO_PROTO(pb_forward_output->mutable_beam_sequence_group(),
beam_group_flat);
}

COUNTER_ADD(worker_service_latency_seconds, timer.elapsed_seconds());
});
}
@@ -642,6 +676,21 @@ void WorkerService::GetLastStepResult(
out_tokens,
out_logprobs,
pb_forward_output);
// append batch-level beam output
if (FLAGS_max_decode_rounds > 0) {
const auto& bsg =
safe_to(forward_outputs.value().beam_sequence_group,
torch::kCPU,
true);
if (bsg.defined()) {
auto flat = bsg.flatten();
std::vector<int32_t> flat_vec(
flat.data_ptr<int32_t>(),
flat.data_ptr<int32_t>() + flat.numel());
ADD_VECTOR_TO_PROTO(
pb_forward_output->mutable_beam_sequence_group(), flat_vec);
}
}
}
}
});