Skip to content

Commit aa55191

Browse files
[NPU][EAGLE3] Fix issue when Draft model generates EOS token (openvinotoolkit#3293)
<!-- Keep your pull requests (PRs) as atomic as possible. That increases the likelihood that an individual PR won't be stuck because of adjacent problems, merge conflicts, or code review. Your merged PR is going to appear in the automatically generated release notes on GitHub. So the clearer the title the better. --> ## Description **Background** Eagle3 uses speculative decoding, where a draft model generates candidate tokens that are validated by a target model. During this process, one critical issue existed with sequence state management: when the draft model generates an `EOS` token, the `Sampler` immediately marks the sequence as `FINISHED`. This causes `get_running_sequences()` to return an empty list, leading to assertion failures in subsequent iterations if the target model rejects the `EOS`. **Solution** 1. Restore the running state after a draft `EOS`: when the draft model generates `EOS`, explicitly restore the sequence to the `RUNNING` state after sampling. 2. Terminate early on a draft `EOS`: break out of the draft generation loop immediately upon detecting `EOS`. <!-- Jira ticket number (e.g., 123). Delete if there's no ticket. --> [CVS-180738](https://jira.devtools.intel.com/browse/CVS-180738) ## Checklist: - [x] This PR follows GenAI Contributing guidelines. <!-- Always follow https://github.com/openvinotoolkit/openvino.genai?tab=contributing-ov-file#contributing. If there are deviations, explain what and why. --> - [x] Tests have been updated or added to cover the new code. <!-- Specify exactly which tests were added or updated. If the change isn't maintenance related, update the tests at https://github.com/openvinotoolkit/openvino.genai/tree/master/tests or explain in the description why the tests don't need an update. --> - [x] This PR fully addresses the ticket. <!--- If not, explain clearly what is covered and what is not. If follow-up pull requests are needed, specify in the description. --> - [x] I have made corresponding changes to the documentation. 
<!-- Run github.com/\<username>/openvino.genai/actions/workflows/deploy_gh_pages.yml on your fork with your branch as a parameter to deploy a test version with the updated content. Replace this comment with the link to the built docs. If the documentation is updated in a separate PR, clearly specify it. -->
1 parent 67e8335 commit aa55191

File tree

3 files changed

+78

-38

lines changed

src/cpp/src/speculative_decoding/stateful/eagle3_strategy.cpp

Lines changed: 68 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -195,13 +195,8 @@ std::vector<int64_t> Eagle3InferWrapperBase::sample_tokens(const ov::Tensor& log
195195
auto sequence_group = get_sequence_group();
196196
OPENVINO_ASSERT(sequence_group, "SequenceGroup not initialized");
197197

198-
OPENVINO_ASSERT(get_running_sequence_count() == 1,
199-
"Eagle3 currently only supports single sequence, got ",
200-
get_running_sequence_count(),
201-
" sequences");
202-
203198
auto current_seq = get_current_sequence();
204-
OPENVINO_ASSERT(current_seq, "No running sequence at index 0");
199+
OPENVINO_ASSERT(current_seq, "No sequence at index 0");
205200

206201
const size_t prev_generated_len = current_seq->get_generated_len();
207202
const size_t logits_seq_len = shape[1];
@@ -312,9 +307,9 @@ void Eagle3TargetWrapper::initialize_sequence(const ov::Tensor& input_ids, const
312307
TokenIds prompt_ids(ids_data, ids_data + seq_len);
313308
m_sequence_group = std::make_shared<SequenceGroup>(0, prompt_ids, config, 0);
314309

315-
OPENVINO_ASSERT(get_running_sequence_count() == 1,
310+
OPENVINO_ASSERT(m_sequence_group->num_total_seqs() == 1,
316311
"Expected single sequence after initialization, got ",
317-
get_running_sequence_count());
312+
m_sequence_group->num_total_seqs());
318313
}
319314

320315
InferenceOutput Eagle3TargetWrapper::infer(const ov::Tensor& input_ids,
@@ -378,9 +373,9 @@ void Eagle3DraftWrapper::initialize_sequence(const ov::Tensor& input_ids, const
378373
TokenIds draft_prompt_ids(ids_data + 1, ids_data + total_len);
379374
m_sequence_group = std::make_shared<SequenceGroup>(1, draft_prompt_ids, config, 0);
380375

381-
OPENVINO_ASSERT(get_running_sequence_count() == 1,
376+
OPENVINO_ASSERT(m_sequence_group->num_total_seqs() == 1,
382377
"Expected single sequence after initialization, got ",
383-
get_running_sequence_count());
378+
m_sequence_group->num_total_seqs());
384379
}
385380

386381
InferenceOutput Eagle3DraftWrapper::infer(const ov::Tensor& input_ids,
@@ -578,15 +573,18 @@ EncodedResults StatefulEagle3LLMPipeline::generate_tokens(const EncodedInputs& i
578573
while (!eos_reached && generated_tokens < config.max_new_tokens &&
579574
m_target->get_sequence_length() < m_prompt_length + config.max_new_tokens &&
580575
streaming_status == ov::genai::StreamingStatus::RUNNING) {
581-
auto result = run_speculative_iteration(input_token_count, static_cast<int64_t>(config.eos_token_id));
576+
auto result = run_speculative_iteration(input_token_count,
577+
static_cast<int64_t>(config.eos_token_id),
578+
generated_tokens,
579+
config.max_new_tokens);
582580

583581
streaming_status = stream_generated_tokens(streamer_ptr, result.validated_tokens);
584582

585583
// Update statistics
586584
total_draft_generated += m_draft_iterations;
587585
total_draft_accepted += result.accepted_tokens_count;
586+
generated_tokens += result.validated_tokens.size();
588587
eos_reached = result.eos_reached;
589-
generated_tokens++;
590588

591589
// Prepare for next iteration (hidden states are stored in sequence)
592590
input_token_count = result.next_window_size;
@@ -639,11 +637,13 @@ EncodedResults StatefulEagle3LLMPipeline::generate_tokens(const EncodedInputs& i
639637

640638
StatefulEagle3LLMPipeline::SpeculativeResult StatefulEagle3LLMPipeline::run_speculative_iteration(
641639
size_t input_token_count,
642-
int64_t eos_token_id) {
640+
int64_t eos_token_id,
641+
size_t current_generated_tokens,
642+
size_t max_new_tokens) {
643643
SpeculativeResult result;
644644

645-
OPENVINO_ASSERT(m_target->get_running_sequence_count() == 1 && m_draft->get_running_sequence_count() == 1,
646-
"Eagle3 speculative iteration requires single sequence per model");
645+
OPENVINO_ASSERT(m_target->get_sequence_group() && m_draft->get_sequence_group(),
646+
"Eagle3 speculative iteration requires initialized sequence groups");
647647

648648
auto target_hidden_states = m_target->get_current_sequence()->get_hidden_state();
649649
OPENVINO_ASSERT(target_hidden_states && target_hidden_states.get_size() > 0,
@@ -670,8 +670,20 @@ StatefulEagle3LLMPipeline::SpeculativeResult StatefulEagle3LLMPipeline::run_spec
670670
// Append first token to target model (draft model already has it from sampler)
671671
m_target->append_tokens({first_draft_token});
672672

673+
// Check if first draft token is EOS - if so, no need to generate more draft tokens
674+
bool draft_eos_reached = (first_draft_token == eos_token_id);
675+
676+
// IMPORTANT: If draft generated EOS, sampler will mark the sequence as FINISHED.
677+
// However, we need to keep the draft sequence in RUNNING state because:
678+
// 1. Target model may reject this EOS during validation
679+
// 2. Next iteration needs draft sequence to be accessible via get_running_sequences()
680+
// Only target model's EOS decision should truly end the generation.
681+
if (draft_eos_reached) {
682+
m_draft->get_current_sequence()->set_status(SequenceStatus::RUNNING);
683+
}
684+
673685
// Step 2: Generate additional draft tokens using internal hidden states
674-
for (size_t i = 1; i < m_draft_iterations; ++i) {
686+
for (size_t i = 1; i < m_draft_iterations && !draft_eos_reached; ++i) {
675687
InferContext more_ctx;
676688
more_ctx.input_token_count = 1;
677689
more_ctx.use_target_hidden = false;
@@ -685,16 +697,24 @@ StatefulEagle3LLMPipeline::SpeculativeResult StatefulEagle3LLMPipeline::run_spec
685697
// During validation, target model will retrieve tokens from its own sequence
686698
// so we need to speculatively add draft predictions here
687699
m_target->append_tokens({draft_token});
700+
701+
if (draft_token == eos_token_id) {
702+
draft_eos_reached = true;
703+
// Keep draft sequence in RUNNING state (same reason as above)
704+
m_draft->get_current_sequence()->set_status(SequenceStatus::RUNNING);
705+
}
688706
}
689707

690708
// Step 3: Validate draft tokens with target model
691709

692-
const size_t validation_window_size = m_draft_iterations + 1;
710+
// Validation window is based on actual draft tokens generated (may be less than m_draft_iterations if EOS hit)
711+
const size_t actual_draft_tokens = draft_candidates.size();
712+
const size_t validation_window_size = actual_draft_tokens + 1;
693713

694714
InferContext val_ctx;
695715
val_ctx.input_token_count = validation_window_size;
696716
val_ctx.sample_count = validation_window_size;
697-
val_ctx.num_tokens_to_validate = m_draft_iterations;
717+
val_ctx.num_tokens_to_validate = actual_draft_tokens;
698718
auto val_result = m_target->forward(val_ctx);
699719

700720
// Sampler validates draft tokens and returns accepted + new sampled token
@@ -703,8 +723,34 @@ StatefulEagle3LLMPipeline::SpeculativeResult StatefulEagle3LLMPipeline::run_spec
703723
// Result: [accepted_draft_tokens..., new_sampled_token]
704724
const size_t accepted_count = validated_tokens.size() - 1;
705725
const int64_t target_predicted_token = validated_tokens.back();
706-
const size_t tokens_to_remove = m_draft_iterations - accepted_count;
707-
const size_t total_accepted_tokens = validated_tokens.size();
726+
size_t tokens_to_remove = actual_draft_tokens - accepted_count;
727+
size_t total_accepted_tokens = validated_tokens.size();
728+
729+
// Check if accepting all validated tokens would exceed max_new_tokens
730+
size_t tokens_after_accept = current_generated_tokens + validated_tokens.size();
731+
if (tokens_after_accept > max_new_tokens) {
732+
// Truncate to exactly max_new_tokens
733+
size_t excess_tokens = tokens_after_accept - max_new_tokens;
734+
OPENVINO_ASSERT(excess_tokens < validated_tokens.size(),
735+
"excess_tokens (",
736+
excess_tokens,
737+
") must be less than validated_tokens.size() (",
738+
validated_tokens.size(),
739+
")");
740+
size_t tokens_to_keep = validated_tokens.size() - excess_tokens;
741+
742+
validated_tokens.resize(tokens_to_keep);
743+
total_accepted_tokens = tokens_to_keep;
744+
745+
m_target->truncate_sequence(m_prompt_length + max_new_tokens);
746+
747+
// Adjust metrics to reflect actual tokens kept after truncation
748+
auto& target_batch_sizes = m_target->get_raw_perf_metrics().m_batch_sizes;
749+
OPENVINO_ASSERT(!target_batch_sizes.empty(), "batch_sizes should have been recorded by sampler");
750+
target_batch_sizes.back() = tokens_to_keep;
751+
752+
tokens_to_remove = actual_draft_tokens - (tokens_to_keep - 1); // -1 for the new target token
753+
}
708754

709755
// Step 4: Synchronize sequences and KV cache
710756
// Target model's sequence is already updated by Sampler
@@ -732,8 +778,8 @@ StatefulEagle3LLMPipeline::SpeculativeResult StatefulEagle3LLMPipeline::run_spec
732778
auto next_hidden = ov::Tensor(current_hidden, start_coord, end_coord);
733779
m_target->get_current_sequence()->update_hidden_state(next_hidden);
734780

735-
result.accepted_tokens_count = accepted_count;
736-
result.next_window_size = accepted_count + 1;
781+
result.accepted_tokens_count = total_accepted_tokens - 1;
782+
result.next_window_size = total_accepted_tokens;
737783
result.validated_tokens = std::move(validated_tokens);
738784
result.eos_reached = (target_predicted_token == eos_token_id);
739785

src/cpp/src/speculative_decoding/stateful/eagle3_strategy.hpp

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -86,26 +86,12 @@ class Eagle3InferWrapperBase {
8686
return m_sequence_group;
8787
}
8888

89-
void set_sequence_group(SequenceGroup::Ptr sequence_group) {
90-
m_sequence_group = sequence_group;
91-
if (m_sequence_group) {
92-
OPENVINO_ASSERT(get_running_sequence_count() == 1,
93-
"Eagle3 only supports single sequence, got ",
94-
get_running_sequence_count());
95-
}
96-
}
97-
98-
/// @brief Returns number of running sequences in the group
99-
size_t get_running_sequence_count() const {
100-
return m_sequence_group ? m_sequence_group->get_running_sequences().size() : 0;
101-
}
102-
10389
/// @brief Returns sequence at given index with bounds checking
10490
/// @param index Sequence index (0 for top-1)
10591
/// @return Sequence pointer or nullptr if index out of bounds
10692
Sequence::Ptr get_sequence(size_t index) const {
10793
if (m_sequence_group) {
108-
auto sequences = m_sequence_group->get_running_sequences();
94+
const auto& sequences = m_sequence_group->get_sequences();
10995
if (index < sequences.size()) {
11096
return sequences[index];
11197
}
@@ -236,7 +222,10 @@ class StatefulEagle3LLMPipeline : public StatefulSpeculativePipelineBase {
236222
std::vector<int64_t> validated_tokens;
237223
};
238224

239-
SpeculativeResult run_speculative_iteration(size_t token_count, int64_t eos_token_id);
225+
SpeculativeResult run_speculative_iteration(size_t token_count,
226+
int64_t eos_token_id,
227+
size_t current_generated_tokens,
228+
size_t max_new_tokens);
240229

241230
std::unique_ptr<Eagle3DraftWrapper> m_draft;
242231
std::unique_ptr<Eagle3TargetWrapper> m_target;

tests/python_tests/test_stateful_speculative_decoding.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,11 @@ def add(a, b):
211211
212212
Question: Can you please add 2 and 3
213213
A:""",
214+
),
215+
(
216+
"Qwen/Qwen3-1.7B",
217+
"AngelSlim/Qwen3-1.7B_eagle3",
218+
"What is the capital of Ireland?/no_think",
214219
)
215220
]
216221

0 commit comments

Comments (0)