@@ -418,6 +418,9 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
418418
419419 std::vector<common_sampler_ptr> smpls;
420420
421+ // backend sampler chain per seq (indexed by seq_id), attached to ctx_dft; nullptr if backend sampling is disabled or attaching failed
422+ std::vector<llama_sampler *> backend_chains;
423+
421424 int32_t n_embd_dec = 0 ; // draft hidden size
422425 int32_t n_embd_enc = 0 ; // target_layer_ids_n * target_hidden_size
423426 int32_t n_embd_tgt = 0 ; // target model hidden size
@@ -443,7 +446,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
443446 , params(params.draft)
444447 {
445448 LOG_INF (" %s: adding speculative implementation 'draft-eagle3'\n " , __func__);
446- LOG_INF (" %s: - n_max=%d, n_min=%d, p_min=%f\n " , __func__, params.draft .n_max , params.draft .n_min , params.draft .p_min );
449+ LOG_INF (" %s: - n_max=%d, n_min=%d, p_min=%f, backend_sampling=%d \n " , __func__, params.draft .n_max , params.draft .n_min , params.draft .p_min , ( int ) params. draft . backend_sampling );
447450
448451 auto * ctx_tgt = this ->params .ctx_tgt ;
449452 auto * ctx_dft = this ->params .ctx_dft ;
@@ -478,6 +481,22 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
478481 s.reset (common_sampler_init (llama_get_model (ctx_dft), sparams));
479482 }
480483
484+ // offload draft sampling to the backend (per seq; a seq whose chain cannot be attached keeps a nullptr entry and uses the CPU sampler)
485+ backend_chains.assign (n_seq, nullptr );
486+ if (this ->params .backend_sampling ) {
487+ for (llama_seq_id seq_id = 0 ; seq_id < (llama_seq_id) n_seq; ++seq_id) {
488+ llama_sampler * chain = llama_sampler_chain_init (llama_sampler_chain_default_params ());
489+ llama_sampler_chain_add (chain, llama_sampler_init_top_k (10 ));
490+
491+ if (!llama_set_sampler (ctx_dft, seq_id, chain)) {
492+ LOG_WRN (" %s: backend offload failed for seq_id=%d; using CPU sampler\n " , __func__, (int ) seq_id);
493+ llama_sampler_free (chain);
494+ chain = nullptr ;
495+ }
496+ backend_chains[seq_id] = chain;
497+ }
498+ }
499+
481500 // turn on extraction of the target layers' input embeddings
482501 for (uint32_t k = 0 ; k < target_layer_ids_n; ++k) {
483502 llama_set_embeddings_layer_inp (ctx_tgt, (uint32_t ) target_layer_ids[k], true );
@@ -496,6 +515,18 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
496515 }
497516
498517 ~common_speculative_impl_draft_eagle3 () override {
518+ auto * ctx_dft = this ->params .ctx_dft ;
519+ for (llama_seq_id seq_id = 0 ; seq_id < (llama_seq_id) backend_chains.size (); ++seq_id) {
520+ if (backend_chains[seq_id] == nullptr ) {
521+ continue ;
522+ }
523+ if (ctx_dft) {
524+ llama_set_sampler (ctx_dft, seq_id, nullptr );
525+ }
526+ llama_sampler_free (backend_chains[seq_id]);
527+ }
528+ backend_chains.clear ();
529+
499530 if (batch.token != nullptr ) {
500531 free (batch.token );
501532 batch.token = nullptr ;
0 commit comments