
Commit 1a36ef2

Merge branch 'upstream' into concedo_experimental

# Conflicts:
#	tests/test-backend-ops.cpp

2 parents: 30c74d5 + 8abcc70

4 files changed: 27 additions & 36 deletions

File tree:

  common/ngram-map.cpp
  common/ngram-map.h
  common/speculative.cpp
  src/models/qwen3next.cpp

common/ngram-map.cpp

Lines changed: 3 additions & 12 deletions

@@ -47,21 +47,15 @@ static std::string common_tokens_to_str(const llama_tokens & inp, size_t start,
  * @return Vector of draft tokens, empty if no matching pattern is found
  */
 llama_tokens common_ngram_simple_draft(
-        common_ngram_simple_state & state,
+        const common_ngram_simple_config & config,
         const llama_tokens & tokens, llama_token sampled) {
 
     // Simple implementation of self-speculative decoding without a draft model.
     //
     const size_t cur_len = tokens.size();
-    // Only check every check_rate tokens to save compute
-    // i.e., perform check if (cur_len - idx_last_check) >= check_rate
-    if (state.idx_last_check + state.config.check_rate > cur_len) {
-        llama_tokens draft_tokens;
-        return draft_tokens;
-    }
 
-    size_t n_draft_min = state.config.size_ngram; // size of n-gram to lookup in token history
-    size_t n_draft_max = state.config.size_mgram; // the m-gram following the found n-gram is used for draft
+    const size_t n_draft_min = config.size_ngram; // size of n-gram to lookup in token history
+    const size_t n_draft_max = config.size_mgram; // the m-gram following the found n-gram is used for draft
 
     // vector for tokens we want to verify.
     // return empty vector if there is no match.
@@ -80,9 +74,6 @@ llama_tokens common_ngram_simple_draft(
     }
     pattern.push_back(sampled); // add the last token to the pattern
 
-    // We do a search in the token history.
-    state.idx_last_check = cur_len;
-
     size_t match_pos = 0; // we ignore position 0, position 0 == no match
     // search backwards, but skip the current match (we are currently there)
     for (size_t j = cur_len - n_draft_min - 1; j > 0; --j) {
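
With this change, common_ngram_simple_draft no longer carries mutable state: it takes a const common_ngram_simple_config & and is a pure function of the token history, while the check-rate throttling moves to the caller (see common/speculative.cpp below). The underlying lookup is unchanged: take the n-gram ending at the current position, search the history backwards for an earlier occurrence, and if one is found, propose the tokens that followed it as the draft. A minimal standalone sketch of that idea (simplified names, not the actual implementation):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

using token = int32_t;

// Find the most recent earlier occurrence of the trailing n-gram and
// return up to m tokens that followed it, to be verified as a draft.
static std::vector<token> simple_ngram_draft(const std::vector<token> & tokens, size_t n, size_t m) {
    const size_t cur_len = tokens.size();
    if (cur_len < n + 1) {
        return {};
    }
    const std::vector<token> pattern(tokens.end() - n, tokens.end());
    // search backwards; position 0 means "no match", as in the real code
    for (size_t j = cur_len - n - 1; j > 0; --j) {
        if (std::equal(pattern.begin(), pattern.end(), tokens.begin() + j - 1)) {
            const size_t beg = j - 1 + n;                  // first token after the match
            const size_t end = std::min(beg + m, cur_len); // clamp to the history
            return std::vector<token>(tokens.begin() + beg, tokens.begin() + end);
        }
    }
    return {};
}

int main() {
    //                           A  B  C  D  E  A  B  C
    const std::vector<token> t = {1, 2, 3, 4, 5, 1, 2, 3};
    // the trailing 3-gram (A B C) also occurs at the start, so the two
    // tokens that followed it there (D E) become the draft
    for (token tok : simple_ngram_draft(t, /*n=*/3, /*m=*/2)) {
        printf("%d ", tok); // prints: 4 5
    }
    printf("\n");
}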

common/ngram-map.h

Lines changed: 1 addition & 15 deletions

@@ -27,23 +27,9 @@ struct common_ngram_simple_config {
     uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
 };
 
-// current state (and config) of n-gram simple.
-struct common_ngram_simple_state {
-    common_ngram_simple_config config;
-
-    size_t idx_last_check = 0; // index of last check in context history (mutable)
-
-    common_ngram_simple_state(const common_ngram_simple_config & config)
-        : config(config) {}
-};
-
 // Searches for a n-gram in the history and checks whether a draft sequence should be generated.
-// state: the ngram simple state to search in.
-// inp: the tokens generated so far.
-// sampled: the token that was just sampled.
-// draft: vector to store the draft tokens, initially empty.
 llama_tokens common_ngram_simple_draft(
-        common_ngram_simple_state & state,
+        const common_ngram_simple_config & config,
         const llama_tokens & tokens, llama_token sampled);
 
 

common/speculative.cpp

Lines changed: 14 additions & 6 deletions

@@ -463,12 +463,14 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
 
 // state of self-speculation (simple implementation, not ngram-map)
 struct common_speculative_state_ngram_simple : public common_speculative_state {
-    common_ngram_simple_state state;
+    common_ngram_simple_config config;
+
+    uint16_t check_id = 0; // used to control the frequency of generating drafts
 
     common_speculative_state_ngram_simple(
             enum common_speculative_type type,
-            common_ngram_simple_state state)
-        : common_speculative_state(type), state(state) {}
+            common_ngram_simple_config config)
+        : common_speculative_state(type), config(config) {}
 
     void begin(const llama_tokens & prompt) override {
         GGML_UNUSED(prompt);
@@ -479,7 +481,13 @@ struct common_speculative_state_ngram_simple : public common_speculative_state {
             const llama_tokens & prompt_tgt,
             llama_token id_last,
             llama_tokens & result) override {
-        result = common_ngram_simple_draft(state, prompt_tgt, id_last);
+        ++check_id;
+        if (check_id < config.check_rate) {
+            return;
+        }
+        check_id = 0;
+
+        result = common_ngram_simple_draft(config, prompt_tgt, id_last);
         GGML_UNUSED(params);
     }
 
@@ -889,14 +897,14 @@ common_speculative * common_speculative_init(
         uint16_t mgram_size_value = ngram_map.size_value;
         uint16_t check_rate = ngram_map.check_rate;
 
-        auto config_simple = common_ngram_simple_config{
+        auto config_simple = common_ngram_simple_config {
             /* .size_ngram = */ ngram_size_key,
             /* .size_mgram = */ mgram_size_value,
             /* .check_rate = */ check_rate
         };
         auto state = std::make_unique<common_speculative_state_ngram_simple>(
             /* .type = */ config.type,
-            /* .state = */ common_ngram_simple_state(config_simple)
+            /* .state = */ config_simple
        );
        impls.push_back(std::move(state));
        break;
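
The throttling that used to live inside common_ngram_simple_draft (via state.idx_last_check) is now a plain call counter in the speculative-state wrapper: the draft search runs only on every check_rate-th call to draft(), and all other calls return with an empty result. The two schemes are roughly equivalent as long as one token is sampled per draft call; the counter version just avoids threading mutable state through the n-gram code. A standalone sketch of the gating pattern (names hypothetical):

#include <cstdint>
#include <cstdio>

// The expensive draft search runs only once every `check_rate` calls.
struct draft_gate {
    uint16_t check_rate;
    uint16_t check_id = 0; // counts calls since the last search

    bool should_check() {
        ++check_id;
        if (check_id < check_rate) {
            return false; // skip: leave the draft empty this time
        }
        check_id = 0;     // reset, then allow the search
        return true;
    }
};

int main() {
    draft_gate gate = { /*check_rate=*/4 };
    for (int call = 1; call <= 10; ++call) {
        printf("call %2d: %s\n", call, gate.should_check() ? "search" : "skip");
    }
    // the search fires on calls 4 and 8: one attempt per check_rate calls
}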

src/models/qwen3next.cpp

Lines changed: 9 additions & 3 deletions

@@ -265,9 +265,15 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
     cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs)
 
     ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff);
-    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp);
+    ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp,
+            1, chunk_size, n_chunks, g_diff_exp->ne[3]);
+
+    ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t);
     cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs)
 
+    ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff));
+    cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs)
+
 
     // state to be updated per chunk
     ggml_tensor * new_state = state; // ggml_dup(ctx0, state);
@@ -322,9 +328,9 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chu
             : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2);
 
         // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new
-        ggml_tensor * k_gdiff = ggml_cont(ctx0, get_slice_2d(ctx0, key_gdiff, chunk));
+        ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk);
         //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why?
-        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, k_gdiff)));
+        ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t);
 
         // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew
         ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk));