Commit 943dc2b

Small change
1 parent 7c6c5db commit 943dc2b

2 files changed

Lines changed: 12 additions & 4 deletions

README.md

Lines changed: 12 additions & 3 deletions
@@ -392,8 +392,8 @@ token position (in the complete sequence batch) for which `keys[b, h, j, :]`,
 but also to create the causal attention masks for multi-head self attention.
 In other words, we do not maintain keys and values as block-sparse tensors, but
 as standard dense tensors: this is simple and allows us to use normal `PyTorch`
-operators. `token_pos` matterns only when creating attention masks. Moreover,
-we use `torch.gather` to extract information for a slot, and `torch.scatter`
+operators. `token_pos` matters only when creating attention masks. Moreover,
+we use `torch.gather` to extract information for slots, and `torch.scatter`
 to write information for new tokens into the cache.

 For the CLI, a cache is identified by `kv_cache.name`, which can be a string
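
Note on the hunk above: the README text describes reading cache slots with `torch.gather` and writing new tokens with `torch.scatter` on dense tensors. A minimal sketch of that pattern follows. The tensor sizes, the `expand_index` helper, and the chosen slot indices are hypothetical and only serve the illustration; the repository's actual cache classes are not shown here and may differ.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
batch, heads, cache_length, head_dim = 1, 2, 4, 3

# Dense cache buffer: keys[b, h, j, :] holds the key stored in slot j.
keys = torch.arange(batch * heads * cache_length * head_dim, dtype=torch.float32)
keys = keys.reshape(batch, heads, cache_length, head_dim)


def expand_index(slots: torch.Tensor, head_dim: int) -> torch.Tensor:
    """Broadcast slot indices of shape (batch, heads, n) over the head dimension."""
    return slots.unsqueeze(-1).expand(-1, -1, -1, head_dim)


# Read: extract slots 3 and 1 for every head (gather along the slot axis).
read_slots = torch.tensor([3, 1]).expand(batch, heads, -1)
gathered = torch.gather(keys, 2, expand_index(read_slots, head_dim))

# Write: place the keys of two new tokens into slots 0 and 2 (scatter).
new_keys = torch.full((batch, heads, 2, head_dim), -1.0)
write_slots = torch.tensor([0, 2]).expand(batch, heads, -1)
updated = torch.scatter(keys, 2, expand_index(write_slots, head_dim), new_keys)
```

The out-of-place `torch.scatter` returns a new tensor and leaves `keys` untouched, which keeps the sketch easy to check; an in-place variant (`scatter_`) would be the choice if the buffer should be updated directly.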
@@ -656,7 +656,7 @@ To be precise, gradients are computed in two phases:
 forward over chunks to store KV cache checkpoints on CPU. Then, we loop
 backwards over cells, running `torch.autograd` to accumulate gradients.

-One more idea is important. The larger cells are the faster our method runs,
+Two more ideas are important. The larger cells are the faster our method runs,
 because `torch.autograd` is best run as few times as possible on larger graphs.
 However, `autograd` stores tensors in its compute graph which are needed during
 the backward pass, which quickly fills up GPU memory. The largest such nodes
@@ -679,6 +679,15 @@ lengths for a cell should be approximately `cache_length`. With this convention,
 the size of tensors stored in the `autograd` graph scales with `cache_length`
 rather than `chunk_size`, so becomes comparable to KV cache size.

+Second, when using `torch.nn.functional.scaled_dot_product_attention` as
+operator, we find that this creates several large arrays in the `autograd` graph.
+To get around this, we implemented our own `PyTorch` operator
+[KVCacheScatterUpdateAndSDPAFunction](./keys_values/kvcache/gradient/sdpa_op.py#474).
+for SDPA fused with `torch.scatter` KV cache update. Its `backward` requires naive
+blockwise SDPA. We are working on a CUDA version for this fused SDPA operator,
+which will speed up computations without sacrificing memory efficiency (like
+PyTorch SDPA does).
+
 Important arguments for gradient computations are:

 * `--grad.layers_per_cell`: Second phase GPU memory requirements depend
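
Note on the hunk above: the added README paragraph mentions "naive blockwise SDPA" without saying how the blocks are formed. The sketch below is one plausible reading, assuming blocks over the query axis and no attention mask. The function name `blockwise_sdpa` and the `block_size` parameter are hypothetical. This is not the fused `KVCacheScatterUpdateAndSDPAFunction` from the commit and it does not implement a custom `backward`; it only shows that computing attention one query block at a time gives the same result as the PyTorch operator while limiting the attention weights held at any moment to `block_size` rows.

```python
import math

import torch
import torch.nn.functional as F


def blockwise_sdpa(query, key, value, block_size):
    """Unmasked SDPA computed one block of queries at a time.

    Only a (block_size, kv_len) slice of attention weights exists at once.
    """
    scale = 1.0 / math.sqrt(query.shape[-1])
    outputs = []
    for start in range(0, query.shape[-2], block_size):
        q_block = query[..., start:start + block_size, :]
        scores = torch.matmul(q_block, key.transpose(-2, -1)) * scale
        weights = torch.softmax(scores, dim=-1)
        outputs.append(torch.matmul(weights, value))
    return torch.cat(outputs, dim=-2)


torch.manual_seed(0)
q = torch.randn(1, 2, 8, 4)  # (batch, heads, query_len, head_dim)
k = torch.randn(1, 2, 6, 4)  # (batch, heads, kv_len, head_dim)
v = torch.randn(1, 2, 6, 4)

out = blockwise_sdpa(q, k, v, block_size=3)
reference = F.scaled_dot_product_attention(q, k, v)
```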

keys_values/data/longbench_v2.py

Lines changed: 0 additions & 1 deletion
@@ -139,7 +139,6 @@ class LongBenchV2(DataModule):
 each record

 """
-
 def __init__(
 self,
 mask_prompt: bool = True,
