Skip to content

Commit 2365315

Browse files
authored
kv-cache : SWA checkpoints store only non-masked cells (#23981)
1 parent f7a0777 commit 2365315

1 file changed

Lines changed: 14 additions & 2 deletions

File tree

src/llama-kv-cache.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1876,7 +1876,19 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, lla
18761876
uint32_t cell_range_begin = cells.size();
18771877

18781878
for (uint32_t i = 0; i < cells.size(); ++i) {
1879-
if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) {
1879+
bool add_cell = true;
1880+
1881+
add_cell = add_cell && !cells.is_empty(i);
1882+
add_cell = add_cell && (seq_id == -1 || cells.seq_has(i, seq_id));
1883+
1884+
// check the cell is not SWA-masked
1885+
if (add_cell && seq_id != -1) {
1886+
const bool is_masked = llama_hparams::is_masked_swa(n_swa, swa_type, cells.pos_get(i), cells.seq_pos_max(seq_id));
1887+
1888+
add_cell = !is_masked;
1889+
}
1890+
1891+
if (add_cell) {
18801892
++cell_count;
18811893
if (cell_range_begin == cells.size()) {
18821894
cell_range_begin = i;
@@ -2129,7 +2141,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32
21292141

21302142
sinfo = find_slot(ubatch, false);
21312143
if (sinfo.empty()) {
2132-
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
2144+
LLAMA_LOG_ERROR("%s: failed to find %d available cells in kv cache\n", __func__, cell_count);
21332145
return false;
21342146
}
21352147

0 commit comments

Comments
 (0)