@@ -161,6 +161,10 @@ struct common_speculative_impl {
161161
162162 virtual void accept (llama_seq_id seq_id, uint16_t n_accepted, bool is_other) = 0;
163163
164+ // eagle3: deferred-boundary g_embd stash for checkpoints (default: none)
165+ virtual bool get_deferred_boundary (llama_seq_id /* seq_id*/ , std::vector<float > & /* g_out*/ ) const { return false ; }
166+ virtual void set_deferred_boundary (llama_seq_id /* seq_id*/ , llama_pos /* pos*/ , const std::vector<float > & /* g*/ ) {}
167+
164168 // true if this implementation requires the target context to extract post-norm embeddings
165169 virtual bool need_embd () const = 0;
166170
@@ -841,6 +845,35 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
841845 (size_t ) n_embd_dec * sizeof (float ));
842846 }
843847
848+ // we only need to stash the deferred boundary's g_embd row for recurrent/hybrid targets:
849+ // their single-position checkpoints drop it on restore
850+ bool need_boundary_stash () const {
851+ const llama_model * model_tgt = llama_get_model (params.ctx_tgt );
852+ return llama_model_is_recurrent (model_tgt) || llama_model_is_hybrid (model_tgt);
853+ }
854+
855+ bool get_deferred_boundary (llama_seq_id seq_id, std::vector<float > & g_out) const override {
856+ if (!need_boundary_stash ()) {
857+ return false ;
858+ }
859+ if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq || pending_pos_last[seq_id] < 0 ) {
860+ return false ;
861+ }
862+ g_out = pending_g_last[seq_id];
863+ return true ;
864+ }
865+
866+ void set_deferred_boundary (llama_seq_id seq_id, llama_pos pos, const std::vector<float > & g) override {
867+ if (!need_boundary_stash ()) {
868+ return ;
869+ }
870+ if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq || (int32_t ) g.size () != n_embd_dec) {
871+ return ;
872+ }
873+ pending_pos_last[seq_id] = pos;
874+ pending_g_last[seq_id] = g;
875+ }
876+
844877 bool need_embd () const override {
845878 return false ;
846879 }
@@ -2118,6 +2151,30 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u
21182151 }
21192152}
21202153
2154+ bool common_speculative_get_deferred_boundary (common_speculative * spec, llama_seq_id seq_id, std::vector<float > & g_out) {
2155+ if (spec == nullptr ) {
2156+ return false ;
2157+ }
2158+
2159+ for (auto & impl : spec->impls ) {
2160+ if (impl->get_deferred_boundary (seq_id, g_out)) {
2161+ return true ;
2162+ }
2163+ }
2164+
2165+ return false ;
2166+ }
2167+
2168+ void common_speculative_set_deferred_boundary (common_speculative * spec, llama_seq_id seq_id, llama_pos pos, const std::vector<float > & g) {
2169+ if (spec == nullptr ) {
2170+ return ;
2171+ }
2172+
2173+ for (auto & impl : spec->impls ) {
2174+ impl->set_deferred_boundary (seq_id, pos, g);
2175+ }
2176+ }
2177+
21212178void common_speculative_print_stats (const common_speculative * spec) {
21222179 if (spec == nullptr ) {
21232180 return ;
0 commit comments