Skip to content

Commit 64ce5fc

Browse files
committed
Better approach when the SWA window is exceeded: simply refill the window. This is not 100% correct, but it is good enough for fast-forward users. Disable FF or increase the window if it is not good enough.
1 parent fa3f86e commit 64ce5fc

3 files changed

Lines changed: 26 additions & 14 deletions

File tree

gpttype_adapter.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4245,7 +4245,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
42454245
{
42464246
if(kcpp_data->use_fastforward)
42474247
{
4248-
ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, false, true, 0);
4248+
ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, false, true, 0, 0);
42494249
}
42504250
}
42514251
if(is_recurrent)
@@ -4297,17 +4297,18 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
42974297
bool triggerff = kcpp_data->use_fastforward;
42984298
if(!blank_prompt) //special case for blank prompts, no fast forward or shifts
42994299
{
4300+
int ff_swa_retain_amount = 0; //a hack for SWA to improve coherency for illegal rewinds
43004301
if(triggerff && !kcpp_data->swa_full && (file_format == FileFormat::GGUF_GENERIC))
43014302
{
43024303
const int swa_pos_min = llama_memory_seq_pos_min(llama_get_memory(llama_ctx_v4), 0); //this is the furthest back we can rewind to.
43034304
int goal_npast = ComputeSharedPrefixLength(current_context_tokens,embd_inp); //this is where we want to rewind to.
43044305
goal_npast -= 4;
43054306
goal_npast = goal_npast < 0 ? 0 : goal_npast;
43064307
if (swa_pos_min < 0 || goal_npast <= swa_pos_min) {
4307-
triggerff = false;
4308+
ff_swa_retain_amount = kcpp_active_swa_size;
43084309
if (debugmode==1 && !is_quiet)
43094310
{
4310-
printf("\nNote: Context cannot be reused (Desired n_past=%d, SWA lowest n_past=%d), doing a full reprocess... to avoid this, disable SWA or increase SWA padding)\n", goal_npast, swa_pos_min);
4311+
printf("\nNote: SWA context cannot be reused (Desired n_past=%d, SWA lowest n_past=%d), to avoid this, disable SWA or increase SWA padding), output may degrade.\n", goal_npast, swa_pos_min);
43114312
}
43124313
}
43134314
}
@@ -4318,7 +4319,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
43184319
}
43194320
if(triggerff)
43204321
{
4321-
ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, triggersc, false, 4);
4322+
ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, triggersc, false, 4, ff_swa_retain_amount);
43224323
}
43234324
}
43244325
if(file_format == FileFormat::GGUF_GENERIC)

model_adapter.cpp

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -468,15 +468,15 @@ std::string gguf_get_model_arch(const std::string & gguf_filename)
468468
return longest;
469469
}
470470

471-
void ContextFastForward(std::vector<int> &current_context_tokens, std::vector<int> &embd_inp,
472-
int &n_past, std::vector<int> &last_n_tokens, const int nctx, std::vector<int> &smartcontext,
473-
bool useSmartContext, const bool requireFullSubset, const int minimum_to_proceed)
474-
{
475-
const int SCCtxLenThreshold = nctx * 0.8; //how much context length must be reach to trigger smartcontext
476-
const int SCInpLenThreshold = nctx * 0.6; //how big must the input array be to trigger smartcontext
477-
const int SCPastLenThreshold = nctx * 0.5; //how wide of a gap between the fast forwarded past and the present to trigger smart context
478-
const float SCTruncationRatio = 0.5; //ratio for how many tokens to fast forward
479-
const int SCTokThreshold = 32 + (nctx*0.05); //how many tokens of similarity triggers smartcontext
471+
void ContextFastForward(std::vector<int> &current_context_tokens, std::vector<int> &embd_inp,
472+
int &n_past, std::vector<int> &last_n_tokens, const int nctx, std::vector<int> &smartcontext,
473+
bool useSmartContext, const bool requireFullSubset, const int minimum_to_proceed, const int minimum_input_to_keep)
474+
{
475+
const int SCCtxLenThreshold = nctx * 0.8; //how much context length must be reach to trigger smartcontext
476+
const int SCInpLenThreshold = nctx * 0.6; //how big must the input array be to trigger smartcontext
477+
const int SCPastLenThreshold = nctx * 0.5; //how wide of a gap between the fast forwarded past and the present to trigger smart context
478+
const float SCTruncationRatio = 0.5; //ratio for how many tokens to fast forward
479+
const int SCTokThreshold = 32 + (nctx*0.05); //how many tokens of similarity triggers smartcontext
480480

481481

482482
//fast forward the past based on identical tokens, stop once a divergence is noted
@@ -532,6 +532,17 @@ std::string gguf_get_model_arch(const std::string & gguf_filename)
532532
fastforwardok = false;
533533
}
534534

535+
//cap n_past at (embd_inp_len - minimum_input_to_keep) so that at least minimum_input_to_keep input tokens are not fast forwarded; if that leaves n_past <= 0, fast forward is disabled
536+
if (minimum_input_to_keep > 0 && n_past > embd_inp_len - minimum_input_to_keep)
537+
{
538+
int max_allowed_past = std::max(0, embd_inp_len - minimum_input_to_keep);
539+
n_past = max_allowed_past;
540+
if(n_past<=0)
541+
{
542+
fastforwardok = false;
543+
}
544+
}
545+
535546
if(fastforwardok)
536547
{
537548
last_n_tokens.erase(last_n_tokens.begin(), last_n_tokens.begin() + n_past);

model_adapter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ bool ArrStartWith(const std::vector<int> targetArray, const std::vector<int> sea
117117
int ArrFindIndexOf(const std::vector<int> targetArray, const std::vector<int> searchSeq);
118118

119119
FileFormat check_file_format(const std::string & fname, FileFormatExtraMeta * fileformatmeta);
120-
void ContextFastForward(std::vector<int> &current_context_tokens, std::vector<int> &embd_inp, int &n_past, std::vector<int> &last_n_tokens, const int nctx, std::vector<int> &smartcontext, const bool useSmartContext, const bool requireFullSubset, const int minimum_to_proceed);
120+
void ContextFastForward(std::vector<int> &current_context_tokens, std::vector<int> &embd_inp, int &n_past, std::vector<int> &last_n_tokens, const int nctx, std::vector<int> &smartcontext, const bool useSmartContext, const bool requireFullSubset, const int minimum_to_proceed, const int minimum_input_to_keep);
121121
bool gguf_tensor_exists(const std::string & filename, std::string tensor_name, bool exactmatch);
122122
std::string gguf_get_model_arch(const std::string & filename);
123123

0 commit comments

Comments
 (0)