Skip to content

Commit fd50e23

Browse files
committed
eagle3: Add deferred boundary checkpoints restore support for hybrid models
1 parent e7177ed commit fd50e23

5 files changed

Lines changed: 72 additions & 13 deletions

File tree

common/common.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2034,7 +2034,7 @@ bool common_prompt_batch_decode(
20342034
}
20352035

20362036
size_t common_prompt_checkpoint::size() const {
2037-
return data_tgt.size() + data_dft.size();
2037+
return data_tgt.size() + data_dft.size() + data_dft_boundary_g_embd.size() * sizeof(float);
20382038
}
20392039

20402040
bool common_prompt_checkpoint::empty() const {
@@ -2049,6 +2049,7 @@ void common_prompt_checkpoint::clear() {
20492049

20502050
data_tgt.clear();
20512051
data_dft.clear();
2052+
data_dft_boundary_g_embd.clear();
20522053
}
20532054

20542055
void common_prompt_checkpoint::update_pos(
@@ -2138,4 +2139,5 @@ void common_prompt_checkpoint::clear_tgt() {
21382139

21392140
void common_prompt_checkpoint::clear_dft() {
21402141
data_dft.clear();
2142+
data_dft_boundary_g_embd.clear();
21412143
}

common/common.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ struct common_params_speculative {
363363

364364
uint32_t need_n_rs_seq() const {
365365
bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) {
366-
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP;
366+
return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3;
367367
});
368368

369369
return needs_rs_seq ? draft.n_max : 0u;
@@ -1064,6 +1064,9 @@ struct common_prompt_checkpoint {
10641064
std::vector<uint8_t> data_tgt;
10651065
std::vector<uint8_t> data_dft;
10661066

1067+
// eagle3: deferred-boundary g_embd row stashed with the checkpoint
1068+
std::vector<float> data_dft_boundary_g_embd;
1069+
10671070
size_t size() const;
10681071

10691072
bool empty() const;

common/speculative.cpp

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,10 @@ struct common_speculative_impl {
161161

162162
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) = 0;
163163

164+
// eagle3: deferred-boundary g_embd stash for checkpoints (default: none)
165+
virtual bool get_deferred_boundary(llama_seq_id /*seq_id*/, std::vector<float> & /*g_out*/) const { return false; }
166+
virtual void set_deferred_boundary(llama_seq_id /*seq_id*/, llama_pos /*pos*/, const std::vector<float> & /*g*/) {}
167+
164168
// true if this implementation requires the target context to extract post-norm embeddings
165169
virtual bool need_embd() const = 0;
166170

@@ -841,6 +845,35 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
841845
(size_t) n_embd_dec * sizeof(float));
842846
}
843847

848+
// we only need to stash the deferred boundary's g_embd row for recurrent/hybrid targets:
849+
// their single-position checkpoints drop it on restore
850+
bool need_boundary_stash() const {
851+
const llama_model * model_tgt = llama_get_model(params.ctx_tgt);
852+
return llama_model_is_recurrent(model_tgt) || llama_model_is_hybrid(model_tgt);
853+
}
854+
855+
bool get_deferred_boundary(llama_seq_id seq_id, std::vector<float> & g_out) const override {
856+
if (!need_boundary_stash()) {
857+
return false;
858+
}
859+
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq || pending_pos_last[seq_id] < 0) {
860+
return false;
861+
}
862+
g_out = pending_g_last[seq_id];
863+
return true;
864+
}
865+
866+
void set_deferred_boundary(llama_seq_id seq_id, llama_pos pos, const std::vector<float> & g) override {
867+
if (!need_boundary_stash()) {
868+
return;
869+
}
870+
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq || (int32_t) g.size() != n_embd_dec) {
871+
return;
872+
}
873+
pending_pos_last[seq_id] = pos;
874+
pending_g_last[seq_id] = g;
875+
}
876+
844877
bool need_embd() const override {
845878
return false;
846879
}
@@ -2118,6 +2151,30 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u
21182151
}
21192152
}
21202153

2154+
bool common_speculative_get_deferred_boundary(common_speculative * spec, llama_seq_id seq_id, std::vector<float> & g_out) {
2155+
if (spec == nullptr) {
2156+
return false;
2157+
}
2158+
2159+
for (auto & impl : spec->impls) {
2160+
if (impl->get_deferred_boundary(seq_id, g_out)) {
2161+
return true;
2162+
}
2163+
}
2164+
2165+
return false;
2166+
}
2167+
2168+
void common_speculative_set_deferred_boundary(common_speculative * spec, llama_seq_id seq_id, llama_pos pos, const std::vector<float> & g) {
2169+
if (spec == nullptr) {
2170+
return;
2171+
}
2172+
2173+
for (auto & impl : spec->impls) {
2174+
impl->set_deferred_boundary(seq_id, pos, g);
2175+
}
2176+
}
2177+
21212178
void common_speculative_print_stats(const common_speculative * spec) {
21222179
if (spec == nullptr) {
21232180
return;

common/speculative.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@ void common_speculative_draft(common_speculative * spec);
6868
// informs the speculative context that n_accepted tokens were accepted by the target model
6969
void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted);
7070

71+
// eagle3: deferred-boundary g_embd stash for checkpoints (no-op for other draft types)
72+
bool common_speculative_get_deferred_boundary(common_speculative * spec, llama_seq_id seq_id, std::vector<float> & g_out);
73+
void common_speculative_set_deferred_boundary(common_speculative * spec, llama_seq_id seq_id, llama_pos boundary_pos, const std::vector<float> & g);
74+
7175
// print statistics about the speculative decoding
7276
void common_speculative_print_stats(const common_speculative * spec);
7377

tools/server/server-context.cpp

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2154,6 +2154,8 @@ struct server_context_impl {
21542154

21552155
cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
21562156
cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
2157+
// stash the draft's deferred boundary with the checkpoint (only eagle3 needs it; no-op otherwise)
2158+
common_speculative_get_deferred_boundary(spec.get(), slot.id, cur.data_dft_boundary_g_embd);
21572159

21582160
SLT_INF(slot,
21592161
"created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n",
@@ -2974,21 +2976,12 @@ struct server_context_impl {
29742976

29752977
bool do_reset = it == slot.prompt.checkpoints.rend();
29762978

2977-
// eagle3 draft is one position behind the target due to deferred boundary), so it
2978-
// can't resume from a checkpoint restored on a recurrent/hybrid target; re-process fully instead.
2979-
const bool spec_eagle3 = std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(),
2980-
COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3) != params_base.speculative.types.end();
2981-
if (!do_reset && spec_eagle3 &&
2982-
(ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL ||
2983-
ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS)) {
2984-
SLT_WRN(slot, "%s", "eagle3 draft cannot resume from a recurrent/hybrid checkpoint, forcing full re-processing\n");
2985-
do_reset = true;
2986-
}
2987-
29882979
if (!do_reset) {
29892980
// restore the context checkpoint
29902981
it->load_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
29912982
it->load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY);
2983+
// restore the draft's deferred boundary (only eagle3 needs it; no-op otherwise)
2984+
common_speculative_set_deferred_boundary(spec.get(), slot.id, it->pos_max, it->data_dft_boundary_g_embd);
29922985

29932986
pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max));
29942987
n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens);

0 commit comments

Comments
 (0)