Skip to content

Commit 683c5ac

Browse files
authored
spec : disacard last drafted token with low prob (ggml-org#22506)
1 parent b1d5f5b commit 683c5ac

2 files changed

Lines changed: 7 additions & 7 deletions

File tree

common/speculative.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ struct common_speculative_state_draft : public common_speculative_state {
467467

468468
prompt_dft.push_back(id_last);
469469

470-
LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
470+
//LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str());
471471

472472
int ret = llama_decode(ctx_dft, batch);
473473
if (ret != 0 && ret != 1) {
@@ -495,14 +495,14 @@ struct common_speculative_state_draft : public common_speculative_state {
495495

496496
common_sampler_accept(smpl, id, true);
497497

498-
result.push_back(id);
499-
500-
if (sparams.n_max <= (int) result.size()) {
498+
// only collect very high-confidence draft tokens
499+
if (cur_p->data[0].p < sparams.p_min) {
501500
break;
502501
}
503502

504-
// only collect very high-confidence draft tokens
505-
if (cur_p->data[0].p < sparams.p_min) {
503+
result.push_back(id);
504+
505+
if (sparams.n_max <= (int) result.size()) {
506506
break;
507507
}
508508

tools/server/server-context.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ struct server_slot {
354354

355355
// generate a new draft
356356
spec_draft = common_speculative_draft(spec.get(), params_spec, tokens, sampled);
357+
n_draft_total += spec_draft.size();
357358

358359
if (spec_draft.size() > (size_t) n_draft_max) {
359360
SLT_WRN(*this, "draft size %d exceeds max %d, truncating\n", (int) spec_draft.size(), n_draft_max);
@@ -3019,7 +3020,6 @@ struct server_context_impl {
30193020

30203021
// update how many tokens out of those tested were accepted
30213022
slot.n_draft_accepted += ids.size() - 1;
3022-
slot.n_draft_total += n_draft;
30233023

30243024
// add accepted tokens to the prompt
30253025
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft);

0 commit comments

Comments
 (0)