Skip to content

Commit 59d8402

Browse files
aldehir and pwilkin authored
common : inhibit lazy grammar sampler while reasoning is active (ggml-org#20970)
* common : inhibit grammar while reasoning budget is active
* cont : update force_pos in accept
* cont : fix tests
* cont : tweak should apply logic
* cont : return early not using grammar sampler
* Add tests
* cont : prevent backend sampling when reasoning budget enabled
* cont : fix typo

---------

Co-authored-by: Piotr Wilkin <piotr.wilkin@syndatis.com>
1 parent ff934e2 commit 59d8402

8 files changed

Lines changed: 287 additions & 98 deletions

File tree

common/reasoning-budget.cpp

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -115,9 +115,11 @@ static void common_reasoning_budget_accept(struct llama_sampler * smpl, llama_to
115115
break;
116116
}
117117
case REASONING_BUDGET_FORCING:
118-
// force_pos is advanced in apply(), not here.
119-
// This ensures the first forced token isn't skipped when the sampler
120-
// is initialized directly in FORCING state (e.g. COUNTING + budget=0)
118+
ctx->force_pos++;
119+
if (ctx->force_pos >= ctx->forced_tokens.size()) {
120+
ctx->state = REASONING_BUDGET_DONE;
121+
LOG_INF("reasoning-budget: forced sequence complete, done\n");
122+
}
121123
break;
122124
case REASONING_BUDGET_DONE:
123125
break;
@@ -144,14 +146,6 @@ static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_tok
144146
cur_p->data[i].logit = -INFINITY;
145147
}
146148
}
147-
148-
// advance to next forced token (done here rather than in accept so that
149-
// the first forced token isn't skipped when starting in FORCING state)
150-
ctx->force_pos++;
151-
if (ctx->force_pos >= ctx->forced_tokens.size()) {
152-
ctx->state = REASONING_BUDGET_DONE;
153-
LOG_INF("reasoning-budget: forced sequence complete, done\n");
154-
}
155149
}
156150

157151
static void common_reasoning_budget_reset(struct llama_sampler * smpl) {
@@ -261,3 +255,10 @@ struct llama_sampler * common_reasoning_budget_init(
261255
common_reasoning_budget_state initial_state) {
262256
return common_reasoning_budget_init_state(vocab, start_tokens, end_tokens, forced_tokens, budget, initial_state);
263257
}
258+
259+
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl) {
260+
if (!smpl) {
261+
return REASONING_BUDGET_IDLE;
262+
}
263+
return ((const common_reasoning_budget_ctx *)smpl->ctx)->state;
264+
}

common/reasoning-budget.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -51,3 +51,5 @@ struct llama_sampler * common_reasoning_budget_init(
5151
const std::vector<llama_token> & forced_tokens,
5252
int32_t budget,
5353
common_reasoning_budget_state initial_state);
54+
55+
common_reasoning_budget_state common_reasoning_budget_get_state(const struct llama_sampler * smpl);

common/sampling.cpp

Lines changed: 46 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77

88
#include <algorithm>
99
#include <cctype>
10+
#include <climits>
1011
#include <cmath>
1112
#include <cstring>
1213
#include <unordered_map>
@@ -109,6 +110,7 @@ struct common_sampler {
109110
common_params_sampling params;
110111

111112
struct llama_sampler * grmr;
113+
struct llama_sampler * rbudget;
112114
struct llama_sampler * chain;
113115

114116
ring_buffer<llama_token> prev;
@@ -188,6 +190,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
188190
lparams.no_perf = params.no_perf;
189191

190192
llama_sampler * grmr = nullptr;
193+
llama_sampler * rbudget = nullptr;
191194
llama_sampler * chain = llama_sampler_chain_init(lparams);
192195

193196
std::vector<llama_sampler *> samplers;
@@ -270,7 +273,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
270273
}
271274
}
272275

273-
if (grmr) {
276+
if (grmr && !params.grammar_lazy) {
274277
try {
275278
for (const auto & token : prefill_tokens) {
276279
llama_sampler_accept(grmr, token);
@@ -284,15 +287,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
284287
}
285288
}
286289

287-
// reasoning budget sampler — added first so it can force tokens before other samplers
288-
if (params.reasoning_budget_tokens >= 0 && !params.reasoning_budget_forced.empty()) {
289-
samplers.push_back(common_reasoning_budget_init(
290+
// reasoning budget sampler
291+
if (!params.reasoning_budget_start.empty() && !params.reasoning_budget_end.empty()) {
292+
rbudget = common_reasoning_budget_init(
290293
vocab,
291294
params.reasoning_budget_start,
292295
params.reasoning_budget_end,
293296
params.reasoning_budget_forced,
294-
params.reasoning_budget_tokens,
295-
prefill_tokens));
297+
params.reasoning_budget_tokens < 0 ? INT_MAX : params.reasoning_budget_tokens,
298+
prefill_tokens);
296299
}
297300

298301
if (params.has_logit_bias()) {
@@ -383,6 +386,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
383386
auto * result = new common_sampler {
384387
/* .params = */ params,
385388
/* .grmr = */ grmr,
389+
/* .rbudget = */ rbudget,
386390
/* .chain = */ chain,
387391
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
388392
/* .cur = */ {},
@@ -398,18 +402,39 @@ void common_sampler_free(struct common_sampler * gsmpl) {
398402
}
399403

400404
llama_sampler_free(gsmpl->grmr);
405+
llama_sampler_free(gsmpl->rbudget);
401406
llama_sampler_free(gsmpl->chain);
402407

403408
delete gsmpl;
404409
}
405410

411+
static bool grammar_should_apply(struct common_sampler * gsmpl) {
412+
if (!gsmpl->grmr) {
413+
return false;
414+
}
415+
if (!gsmpl->rbudget) {
416+
return true;
417+
}
418+
if (gsmpl->params.grammar_lazy) {
419+
// if grammar is lazy, only apply when reasoning budget is not active
420+
const auto state = common_reasoning_budget_get_state(gsmpl->rbudget);
421+
return state == REASONING_BUDGET_IDLE || state == REASONING_BUDGET_DONE;
422+
}
423+
return true;
424+
}
425+
406426
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
407427
if (!gsmpl) {
408428
return;
409429
}
410430

411431
const auto tm = gsmpl->tm();
412432

433+
// grammar_should_apply() checks the reasoning budget state, so calculate this before we accept
434+
accept_grammar = accept_grammar && grammar_should_apply(gsmpl);
435+
436+
llama_sampler_accept(gsmpl->rbudget, token);
437+
413438
if (gsmpl->grmr && accept_grammar) {
414439
llama_sampler_accept(gsmpl->grmr, token);
415440
}
@@ -431,6 +456,7 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
431456
return new common_sampler {
432457
/* .params = */ gsmpl->params,
433458
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
459+
/* .rbudget = */ llama_sampler_clone(gsmpl->rbudget),
434460
/* .chain = */ llama_sampler_clone(gsmpl->chain),
435461
/* .prev = */ gsmpl->prev,
436462
/* .cur = */ gsmpl->cur,
@@ -500,6 +526,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
500526
llama_token id = LLAMA_TOKEN_NULL;
501527

502528
auto & grmr = gsmpl->grmr;
529+
auto & rbudget = gsmpl->rbudget;
503530
auto & chain = gsmpl->chain;
504531
auto & cur_p = gsmpl->cur_p; // initialized by set_logits
505532

@@ -511,7 +538,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
511538
if (id != LLAMA_TOKEN_NULL) {
512539
LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
513540

514-
GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
541+
GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
542+
GGML_ASSERT(!gsmpl->rbudget && "using reasoning budget in combination with backend sampling is not supported");
515543

516544
// TODO: simplify
517545
gsmpl->cur.resize(1);
@@ -524,15 +552,18 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
524552

525553
gsmpl->set_logits(ctx, idx);
526554

527-
if (grammar_first) {
555+
// apply reasoning budget first
556+
llama_sampler_apply(rbudget, &cur_p);
557+
558+
if (grammar_first && grammar_should_apply(gsmpl)) {
528559
llama_sampler_apply(grmr, &cur_p);
529560
}
530561

531562
llama_sampler_apply(chain, &cur_p);
532563

533564
id = cur_p.data[cur_p.selected].id;
534565

535-
if (grammar_first) {
566+
if (grammar_first || !grammar_should_apply(gsmpl)) {
536567
return id;
537568
}
538569

@@ -553,7 +584,12 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
553584
// if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
554585
gsmpl->set_logits(ctx, idx);
555586

556-
llama_sampler_apply(grmr, &cur_p);
587+
llama_sampler_apply(rbudget, &cur_p);
588+
589+
if (grammar_should_apply(gsmpl)) {
590+
llama_sampler_apply(grmr, &cur_p);
591+
}
592+
557593
llama_sampler_apply(chain, &cur_p);
558594

559595
GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

0 commit comments

Comments
 (0)