77
88#include < algorithm>
99#include < cctype>
10+ #include < climits>
1011#include < cmath>
1112#include < cstring>
1213#include < unordered_map>
@@ -109,6 +110,7 @@ struct common_sampler {
109110 common_params_sampling params;
110111
111112 struct llama_sampler * grmr;
113+ struct llama_sampler * rbudget;
112114 struct llama_sampler * chain;
113115
114116 ring_buffer<llama_token> prev;
@@ -188,6 +190,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
188190 lparams.no_perf = params.no_perf ;
189191
190192 llama_sampler * grmr = nullptr ;
193+ llama_sampler * rbudget = nullptr ;
191194 llama_sampler * chain = llama_sampler_chain_init (lparams);
192195
193196 std::vector<llama_sampler *> samplers;
@@ -270,7 +273,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
270273 }
271274 }
272275
273- if (grmr) {
276+ if (grmr && !params. grammar_lazy ) {
274277 try {
275278 for (const auto & token : prefill_tokens) {
276279 llama_sampler_accept (grmr, token);
@@ -284,15 +287,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
284287 }
285288 }
286289
287- // reasoning budget sampler — added first so it can force tokens before other samplers
288- if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced .empty ()) {
289- samplers. push_back ( common_reasoning_budget_init (
290+ // reasoning budget sampler
291+ if (! params.reasoning_budget_start . empty () && !params.reasoning_budget_end .empty ()) {
292+ rbudget = common_reasoning_budget_init (
290293 vocab,
291294 params.reasoning_budget_start ,
292295 params.reasoning_budget_end ,
293296 params.reasoning_budget_forced ,
294- params.reasoning_budget_tokens ,
295- prefill_tokens)) ;
297+ params.reasoning_budget_tokens < 0 ? INT_MAX : params. reasoning_budget_tokens ,
298+ prefill_tokens);
296299 }
297300
298301 if (params.has_logit_bias ()) {
@@ -383,6 +386,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
383386 auto * result = new common_sampler {
384387 /* .params = */ params,
385388 /* .grmr = */ grmr,
389+ /* .rbudget = */ rbudget,
386390 /* .chain = */ chain,
387391 /* .prev = */ ring_buffer<llama_token>(std::max (32 , params.n_prev )),
388392 /* .cur = */ {},
@@ -398,18 +402,39 @@ void common_sampler_free(struct common_sampler * gsmpl) {
398402 }
399403
400404 llama_sampler_free (gsmpl->grmr );
405+ llama_sampler_free (gsmpl->rbudget );
401406 llama_sampler_free (gsmpl->chain );
402407
403408 delete gsmpl;
404409}
405410
411+ static bool grammar_should_apply (struct common_sampler * gsmpl) {
412+ if (!gsmpl->grmr ) {
413+ return false ;
414+ }
415+ if (!gsmpl->rbudget ) {
416+ return true ;
417+ }
418+ if (gsmpl->params .grammar_lazy ) {
419+ // if grammar is lazy, only apply when reasoning budget is not active
420+ const auto state = common_reasoning_budget_get_state (gsmpl->rbudget );
421+ return state == REASONING_BUDGET_IDLE || state == REASONING_BUDGET_DONE;
422+ }
423+ return true ;
424+ }
425+
406426void common_sampler_accept (struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
407427 if (!gsmpl) {
408428 return ;
409429 }
410430
411431 const auto tm = gsmpl->tm ();
412432
433+ // grammar_should_apply() checks the reasoning budget state, so calculate this before we accept
434+ accept_grammar = accept_grammar && grammar_should_apply (gsmpl);
435+
436+ llama_sampler_accept (gsmpl->rbudget , token);
437+
413438 if (gsmpl->grmr && accept_grammar) {
414439 llama_sampler_accept (gsmpl->grmr , token);
415440 }
@@ -431,6 +456,7 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
431456 return new common_sampler {
432457 /* .params = */ gsmpl->params ,
433458 /* .grmr = */ llama_sampler_clone (gsmpl->grmr ),
459+ /* .rbudget = */ llama_sampler_clone (gsmpl->rbudget ),
434460 /* .chain = */ llama_sampler_clone (gsmpl->chain ),
435461 /* .prev = */ gsmpl->prev ,
436462 /* .cur = */ gsmpl->cur ,
@@ -500,6 +526,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
500526 llama_token id = LLAMA_TOKEN_NULL;
501527
502528 auto & grmr = gsmpl->grmr ;
529+ auto & rbudget = gsmpl->rbudget ;
503530 auto & chain = gsmpl->chain ;
504531 auto & cur_p = gsmpl->cur_p ; // initialized by set_logits
505532
@@ -511,7 +538,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
511538 if (id != LLAMA_TOKEN_NULL) {
512539 LOG_DBG (" %s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n " , __func__, id);
513540
514- GGML_ASSERT (!gsmpl->grmr && " using grammar in combination with backend sampling is not supported" );
541+ GGML_ASSERT (!gsmpl->grmr && " using grammar in combination with backend sampling is not supported" );
542+ GGML_ASSERT (!gsmpl->rbudget && " using reasoning budget in combination with backend sampling is not supported" );
515543
516544 // TODO: simplify
517545 gsmpl->cur .resize (1 );
@@ -524,15 +552,18 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
524552
525553 gsmpl->set_logits (ctx, idx);
526554
527- if (grammar_first) {
555+ // apply reasoning budget first
556+ llama_sampler_apply (rbudget, &cur_p);
557+
558+ if (grammar_first && grammar_should_apply (gsmpl)) {
528559 llama_sampler_apply (grmr, &cur_p);
529560 }
530561
531562 llama_sampler_apply (chain, &cur_p);
532563
533564 id = cur_p.data [cur_p.selected ].id ;
534565
535- if (grammar_first) {
566+ if (grammar_first || ! grammar_should_apply (gsmpl) ) {
536567 return id;
537568 }
538569
@@ -553,7 +584,12 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
553584 // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
554585 gsmpl->set_logits (ctx, idx);
555586
556- llama_sampler_apply (grmr, &cur_p);
587+ llama_sampler_apply (rbudget, &cur_p);
588+
589+ if (grammar_should_apply (gsmpl)) {
590+ llama_sampler_apply (grmr, &cur_p);
591+ }
592+
557593 llama_sampler_apply (chain, &cur_p);
558594
559595 GGML_ASSERT (cur_p.selected != -1 && " no selected token during sampling - check your sampling configuration" );
0 commit comments