Skip to content

Commit a182490

Browse files
authored
spec: add backend sampling support for eagle3 (#24655)
1 parent 32120c1 commit a182490

1 file changed

Lines changed: 32 additions & 1 deletion

File tree

common/speculative.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,9 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
418418

419419
std::vector<common_sampler_ptr> smpls;
420420

421+
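// one backend sampler chain per sequence, attached to ctx_dft and freed in the destructor; an entry is nullptr when backend sampling is disabled or offload failed for that sequence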
422+
std::vector<llama_sampler *> backend_chains;
423+
421424
int32_t n_embd_dec = 0; // draft hidden size
422425
int32_t n_embd_enc = 0; // target_layer_ids_n * target_hidden_size
423426
int32_t n_embd_tgt = 0; // target model hidden size
@@ -443,7 +446,7 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
443446
, params(params.draft)
444447
{
445448
LOG_INF("%s: adding speculative implementation 'draft-eagle3'\n", __func__);
446-
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f\n", __func__, params.draft.n_max, params.draft.n_min, params.draft.p_min);
449+
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f, backend_sampling=%d\n", __func__, params.draft.n_max, params.draft.n_min, params.draft.p_min, (int) params.draft.backend_sampling);
447450

448451
auto * ctx_tgt = this->params.ctx_tgt;
449452
auto * ctx_dft = this->params.ctx_dft;
@@ -478,6 +481,22 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
478481
s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
479482
}
480483

484+
// offload draft sampling to the backend
485+
backend_chains.assign(n_seq, nullptr);
486+
if (this->params.backend_sampling) {
487+
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
488+
llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
489+
llama_sampler_chain_add(chain, llama_sampler_init_top_k(10));
490+
491+
if (!llama_set_sampler(ctx_dft, seq_id, chain)) {
492+
LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id);
493+
llama_sampler_free(chain);
494+
chain = nullptr;
495+
}
496+
backend_chains[seq_id] = chain;
497+
}
498+
}
499+
481500
// turn on extraction of the target layers' input embeddings
482501
for (uint32_t k = 0; k < target_layer_ids_n; ++k) {
483502
llama_set_embeddings_layer_inp(ctx_tgt, (uint32_t) target_layer_ids[k], true);
@@ -496,6 +515,18 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
496515
}
497516

498517
~common_speculative_impl_draft_eagle3() override {
518+
auto * ctx_dft = this->params.ctx_dft;
519+
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) backend_chains.size(); ++seq_id) {
520+
if (backend_chains[seq_id] == nullptr) {
521+
continue;
522+
}
523+
if (ctx_dft) {
524+
llama_set_sampler(ctx_dft, seq_id, nullptr);
525+
}
526+
llama_sampler_free(backend_chains[seq_id]);
527+
}
528+
backend_chains.clear();
529+
499530
if (batch.token != nullptr) {
500531
free(batch.token);
501532
batch.token = nullptr;

0 commit comments

Comments
 (0)