@@ -263,6 +263,9 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
     int index_n_head = 1;
     index_slot_size = dtype_size * index_n_head * args_.index_head_dim();
   }
+  if (FLAGS_max_decode_rounds > 0) {
+    slot_size *= FLAGS_max_decode_rounds;
+  }
   kv_cache_cap.slot_size = slot_size;
   kv_cache_cap.index_slot_size = index_slot_size;
   kv_cache_cap.n_layers = args_.n_layers();
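
The hunk above scales the per-slot byte budget so that one cache slot can hold one K/V entry per decode round. A minimal, standalone sketch of that arithmetic (the names mirror the surrounding code; the concrete values are illustrative, not taken from the engine):

    #include <cstdint>
    #include <iostream>

    int main() {
      // Illustrative values; the engine derives these from the model args.
      const int64_t dtype_size = 2;         // e.g. fp16/bf16 element size
      const int64_t n_kv_heads = 8;         // stands in for n_local_kv_heads_
      const int64_t head_dim = 128;         // stands in for head_dim_
      const int64_t max_decode_rounds = 4;  // stands in for FLAGS_max_decode_rounds

      // Per-token K/V bytes for a single decode round.
      int64_t slot_size = dtype_size * n_kv_heads * head_dim;
      // Multi-round decode keeps one entry per round in the same slot, so the
      // per-slot budget grows linearly with the round count, as in
      // estimate_kv_cache_capacity() above.
      if (max_decode_rounds > 0) {
        slot_size *= max_decode_rounds;
      }
      std::cout << "bytes per slot: " << slot_size << "\n";  // 2*8*128*4 = 8192
      return 0;
    }
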
@@ -309,10 +312,23 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {
     kv_cache_shape.emplace_back(std::vector<int64_t>{
         kv_cache_cap.n_blocks, block_size, 1, args_.qk_rope_head_dim()});
   } else {
-    kv_cache_shape.emplace_back(std::vector<int64_t>{
-        kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_});
-    kv_cache_shape.emplace_back(std::vector<int64_t>{
-        kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_});
+    if (FLAGS_max_decode_rounds > 0) {
+      kv_cache_shape.emplace_back(std::vector<int64_t>{kv_cache_cap.n_blocks,
+                                                       block_size,
+                                                       n_local_kv_heads_,
+                                                       FLAGS_max_decode_rounds,
+                                                       head_dim_});
+      kv_cache_shape.emplace_back(std::vector<int64_t>{kv_cache_cap.n_blocks,
+                                                       block_size,
+                                                       n_local_kv_heads_,
+                                                       FLAGS_max_decode_rounds,
+                                                       head_dim_});
+    } else {
+      kv_cache_shape.emplace_back(std::vector<int64_t>{
+          kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_});
+      kv_cache_shape.emplace_back(std::vector<int64_t>{
+          kv_cache_cap.n_blocks, block_size, n_local_kv_heads_, head_dim_});
+    }
   }
   if (enable_lighting_indexer) {
     kv_cache_shape.emplace_back(std::vector<int64_t>{
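
The else branch above widens the usual four-dimensional K/V cache layout [n_blocks, block_size, n_kv_heads, head_dim] with a per-round axis when multi-round decode is on. A hypothetical shape builder that mirrors the branching (a sketch only, not the engine's API):

    #include <cstdint>
    #include <vector>

    std::vector<int64_t> kv_cache_shape(int64_t n_blocks,
                                        int64_t block_size,
                                        int64_t n_kv_heads,
                                        int64_t head_dim,
                                        int64_t max_decode_rounds) {
      if (max_decode_rounds > 0) {
        // Extra axis between the head and head-dim axes: one K (or V) entry
        // per decode round, matching the shapes emplaced in the diff.
        return {n_blocks, block_size, n_kv_heads, max_decode_rounds, head_dim};
      }
      return {n_blocks, block_size, n_kv_heads, head_dim};
    }

The shape is emplaced twice in each branch, presumably because K and V are kept as separate cache tensors with identical layouts.
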
@@ -340,10 +356,17 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {

   // initialize block manager
   BlockManagerPool::Options options;
+  // Simplify block management when max_decode_rounds is in use.
+  bool enable_prefix_cache =
+      options_.enable_prefix_cache() && FLAGS_max_decode_rounds == 0;
+  bool enable_kvcache_store =
+      options_.enable_kvcache_store() && FLAGS_max_decode_rounds == 0;
+  auto host_blocks_factor =
+      FLAGS_max_decode_rounds == 0 ? options_.host_blocks_factor() : 0.0;
   options.num_blocks(kv_cache_cap.n_blocks)
       .block_size(block_size)
-      .host_num_blocks(kv_cache_cap.n_blocks * options_.host_blocks_factor())
-      .enable_prefix_cache(options_.enable_prefix_cache())
+      .host_num_blocks(kv_cache_cap.n_blocks * host_blocks_factor)
+      .enable_prefix_cache(enable_prefix_cache)
       .enable_disagg_pd(options_.enable_disagg_pd())
       .enable_cache_upload(options_.enable_cache_upload())
-      .enable_kvcache_store(options_.enable_kvcache_store());
+      .enable_kvcache_store(enable_kvcache_store);
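
(Note the last line is changed to pass the computed local; in the original hunk `enable_kvcache_store` was computed but never used, so the store would have stayed enabled in multi-round mode.) All three locals follow one pattern: a feature stays effective only while FLAGS_max_decode_rounds is zero. A compact standalone rendering of that gating (EffectiveOptions is a stand-in struct, not BlockManagerPool::Options):

    #include <cstdint>

    struct EffectiveOptions {
      bool enable_prefix_cache;
      bool enable_kvcache_store;
      double host_blocks_factor;
    };

    // Prefix caching, the KV-cache store, and host block offloading are all
    // forced off while multi-round decode is active, per the diff's own
    // comment about simplifying block management; everything else passes
    // through unchanged.
    EffectiveOptions gate_for_multi_round(bool prefix_cache,
                                          bool kvcache_store,
                                          double host_blocks_factor,
                                          int32_t max_decode_rounds) {
      const bool multi_round = max_decode_rounds > 0;
      return {prefix_cache && !multi_round,
              kvcache_store && !multi_round,
              multi_round ? 0.0 : host_blocks_factor};
    }
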
@@ -694,11 +717,59 @@ bool LLMEngine::unlink_cluster(const std::vector<uint64_t>& cluster_ids,
   return true;
 }

+ForwardOutput LLMEngine::step_multi_round(std::vector<Batch>& batch) {
+  Timer timer;
+  DCHECK(dp_size_ == batch.size())
+      << "Split DP batch failed with dp_size as " << dp_size_
+      << " and actual batch size as " << batch.size() << ".";
+  auto batched_raw_forward_inputs = prepare_inputs(batch);
+  DCHECK(dp_size_ == batched_raw_forward_inputs.size())
+      << "The processed raw forward inputs size "
+      << batched_raw_forward_inputs.size() << " is not equal to dp size "
+      << dp_size_ << ".";
+  std::vector<folly::SemiFuture<std::optional<RawForwardOutput>>> futures;
+  futures.reserve(worker_clients_num_);
+  for (auto worker_rank = 0; worker_rank < worker_clients_num_; ++worker_rank) {
+    auto dp_rank = worker_rank / dp_local_tp_size_;
+    futures.emplace_back(worker_clients_[worker_rank]->step_async(
+        batched_raw_forward_inputs[dp_rank]));
+  }
+  auto results = folly::collectAll(futures).get();
+  size_t dp_rank = 0;
+  for (auto worker_rank = 0; worker_rank < worker_clients_num_;
+       worker_rank += dp_local_tp_size_) {
+    auto result = results[worker_rank].value();
+    if (result.has_value()) {
+      if (result.value().outputs.empty() && layer_forward_interrupted_) {
+        throw ForwardInterruptedException();
+      }
+      auto& raw = result.value();
+      if (!raw.beam_sequence_group.empty()) {
+        batch[dp_rank].process_beam_sequence_group(raw);
+      } else {
+        batch[dp_rank].process_decode_beam_search_output(raw, false);
+      }
+    } else {
+      LOG(FATAL) << "Failed to execute model, result has no value";
+    }
+    ++dp_rank;
+  }
+  COUNTER_ADD(engine_latency_seconds, timer.elapsed_seconds());
+  // finish all sequences in the batch
+  for (auto& b : batch) {
+    b.finish();
+  }
+  return {};
+}
+
 ForwardOutput LLMEngine::step(std::vector<Batch>& batch) {
   if (worker_clients_.empty()) {
     // empty worker, return
     return {};
   }
+  if (FLAGS_max_decode_rounds > 0) {
+    return step_multi_round(batch);
+  }
   Timer timer;
   DCHECK(dp_size_ == batch.size())
       << "Split DP batch failed with dp_size as " << dp_size_
@@ -735,9 +806,14 @@ ForwardOutput LLMEngine::step(std::vector<Batch>& batch) {
       if (result.value().outputs.empty() && layer_forward_interrupted_) {
         throw ForwardInterruptedException();
       }
-      // if src_seq_idxes is not empty, skip sample output processing and
-      // process beam search output instead
-      if (result.value().src_seq_idxes.size() == 0) {
+      // If both src_seq_idxes and out_tokens are populated, this step used
+      // the beam search kernel; otherwise, fall back to normal sample output
+      // processing. Note that proto serialization always fills src_seq_idxes
+      // with a fallback [0..num_seqs) when it is undefined, so we must also
+      // check out_tokens to distinguish real beam-kernel outputs.
+      if (result.value().src_seq_idxes.empty() ||
+          result.value().out_tokens.empty() ||
+          !FLAGS_enable_beam_search_kernel) {
         // set second input param enable_schedule_overlap to false,
         // if it's not enabled, process_sample_output will append the real
         // token, if it's enabled, this false here will append the fake token in
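
The rewritten condition is effectively a three-way guard: the beam-search-kernel path runs only when the kernel is enabled and the output carries both src_seq_idxes and out_tokens. A sketch of the predicate with the RawForwardOutput fields reduced to booleans (a hypothetical helper, named here only for illustration):

    // True when the result should take the normal sample-output path rather
    // than the beam-search-kernel path. src_seq_idxes alone is not a reliable
    // signal: per the diff's comment, proto deserialization fills it with a
    // fallback [0..num_seqs) even when the kernel did not run, so out_tokens
    // must be checked as well.
    bool use_sample_output_path(bool has_src_seq_idxes,
                                bool has_out_tokens,
                                bool beam_kernel_enabled) {
      return !has_src_seq_idxes || !has_out_tokens || !beam_kernel_enabled;
    }
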
@@ -868,8 +944,14 @@ std::vector<RawForwardInput> LLMEngine::prepare_inputs(

   // build model input for every single micro batch
   for (auto dp_rank = 0; dp_rank < dp_size_; ++dp_rank) {
-    batched_inputs.emplace_back(std::move(
-        batch[dp_rank].prepare_forward_input(args_, threadpool_.get())));
+    if (FLAGS_max_decode_rounds > 0) {
+      batched_inputs.emplace_back(
+          std::move(batch[dp_rank].prepare_multi_step_forward_input(
+              args_, threadpool_.get())));
+    } else {
+      batched_inputs.emplace_back(std::move(
+          batch[dp_rank].prepare_forward_input(args_, threadpool_.get())));
+    }
     dp_global_token_nums[dp_rank] =
         batched_inputs[dp_rank].flatten_tokens_vec.size();
     global_empty_kv_cache =
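
prepare_inputs() now picks the input builder per micro batch using the same flag that routes step(). A self-contained reduction of that dispatch (Batch and RawForwardInput are stand-in types here, not the engine's):

    #include <string>
    #include <vector>

    struct RawForwardInput { std::string kind; };
    struct Batch {
      RawForwardInput prepare_forward_input() const { return {"single-step"}; }
      RawForwardInput prepare_multi_step_forward_input() const {
        return {"multi-step"};
      }
    };

    std::vector<RawForwardInput> prepare_inputs(std::vector<Batch>& batch,
                                                int max_decode_rounds) {
      std::vector<RawForwardInput> batched_inputs;
      batched_inputs.reserve(batch.size());
      for (auto& b : batch) {
        // One builder per mode, selected by the same flag that gates step().
        batched_inputs.push_back(max_decode_rounds > 0
                                     ? b.prepare_multi_step_forward_input()
                                     : b.prepare_forward_input());
      }
      return batched_inputs;
    }
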