@@ -392,8 +392,8 @@ token position (in the complete sequence batch) for which `keys[b, h, j, :]`,
 but also to create the causal attention masks for multi-head self attention.
 In other words, we do not maintain keys and values as block-sparse tensors, but
 as standard dense tensors: this is simple and allows us to use normal `PyTorch`
-operators. `token_pos` matterns only when creating attention masks. Moreover,
-we use `torch.gather` to extract information for a slot, and `torch.scatter`
+operators. `token_pos` matters only when creating attention masks. Moreover,
+we use `torch.gather` to extract information for slots, and `torch.scatter`
 to write information for new tokens into the cache.

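The dense-cache access pattern described in this hunk can be illustrated with a minimal sketch. The shapes, the slot choice, and every variable name other than `keys` and `token_pos` are assumptions made for illustration; this is not the repository's code.

```python
import torch

# Hypothetical sizes: batch, KV heads, cache slots, head dimension
batch, n_heads, cache_length, head_dim = 2, 3, 8, 4

# Dense key buffer; keys[b, h, j, :] is the key stored in slot j
keys = torch.zeros(batch, n_heads, cache_length, head_dim)
# token_pos[b, h, j] is the sequence position of the token held in slot j
token_pos = torch.arange(cache_length).expand(batch, n_heads, cache_length).clone()

# Two new tokens (positions 8 and 9) overwrite slots 1 and 5 (assumed choice)
num_new = 2
slots = torch.tensor([1, 5]).expand(batch, n_heads, num_new)
new_pos = torch.tensor([8, 9]).expand(batch, n_heads, num_new)
new_keys = torch.randn(batch, n_heads, num_new, head_dim)

# scatter/gather along the slot dimension need an index with as many
# dimensions as the tensor being indexed
index = slots.unsqueeze(-1).expand(-1, -1, -1, head_dim)
keys = keys.scatter(2, index, new_keys)           # write new tokens
token_pos = token_pos.scatter(2, slots, new_pos)  # record their positions

read_keys = keys.gather(2, index)                 # read the same slots back

# token_pos only matters for masking: a query at position 8 may attend to
# every slot whose token position is at most 8
causal_mask = token_pos <= 8
```

Because slots can be overwritten in any order, the mask is derived from `token_pos` rather than from slot indices.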
 For the CLI, a cache is identified by `kv_cache.name`, which can be a string
@@ -656,7 +656,7 @@ To be precise, gradients are computed in two phases:
  forward over chunks to store KV cache checkpoints on CPU. Then, we loop
  backwards over cells, running `torch.autograd` to accumulate gradients.

-One more idea is important. The larger cells are the faster our method runs,
+Two more ideas are important. The larger cells are the faster our method runs,
 because `torch.autograd` is best run as few times as possible on larger graphs.
 However, `autograd` stores tensors in its compute graph which are needed during
 the backward pass, which quickly fills up GPU memory. The largest such nodes
@@ -679,6 +679,15 @@ lengths for a cell should be approximately `cache_length`. With this convention,
 the size of tensors stored in the `autograd` graph scales with `cache_length`
 rather than `chunk_size`, so becomes comparable to KV cache size.

+Second, when using `torch.nn.functional.scaled_dot_product_attention` as
+operator, we find that this creates several large arrays in the `autograd` graph.
+To get around this, we implemented our own `PyTorch` operator
+[KVCacheScatterUpdateAndSDPAFunction](./keys_values/kvcache/gradient/sdpa_op.py#474).
+for SDPA fused with `torch.scatter` KV cache update. Its `backward` requires naive
+blockwise SDPA. We are working on a CUDA version for this fused SDPA operator,
+which will speed up computations without sacrificing memory efficiency (like
+PyTorch SDPA does).
+
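The idea of a custom operator whose `backward` recomputes attention blockwise can be sketched as follows. This is a simplified illustration under stated assumptions, not the repository's `KVCacheScatterUpdateAndSDPAFunction`: the class name `RecomputeSDPA` and the `block_size` argument are hypothetical, there is no attention mask, and the fused `torch.scatter` cache update is left out. Only the inputs are saved for the backward pass, and attention weights are rebuilt one block of queries at a time.

```python
import torch
import torch.nn.functional as F


class RecomputeSDPA(torch.autograd.Function):
    """Hypothetical sketch: SDPA that stores only its inputs for backward."""

    @staticmethod
    def forward(ctx, query, key, value, block_size):
        # Only the inputs are kept; no attention weights enter the graph
        ctx.save_for_backward(query, key, value)
        ctx.block_size = block_size
        return F.scaled_dot_product_attention(query, key, value)

    @staticmethod
    def backward(ctx, grad_out):
        query, key, value = ctx.saved_tensors
        scale = query.shape[-1] ** -0.5
        grad_q = torch.zeros_like(query)
        grad_k = torch.zeros_like(key)
        grad_v = torch.zeros_like(value)
        # Naive attention, recomputed one block of queries at a time, so the
        # temporary weight matrix has at most block_size rows
        for start in range(0, query.shape[-2], ctx.block_size):
            end = start + ctx.block_size
            q_blk = query[..., start:end, :]
            g_blk = grad_out[..., start:end, :]
            probs = torch.softmax(q_blk @ key.transpose(-1, -2) * scale, dim=-1)
            grad_v += probs.transpose(-1, -2) @ g_blk
            grad_probs = g_blk @ value.transpose(-1, -2)
            # Backward of a row-wise softmax
            grad_scores = probs * (
                grad_probs - (grad_probs * probs).sum(dim=-1, keepdim=True)
            )
            grad_q[..., start:end, :] = grad_scores @ key * scale
            grad_k += grad_scores.transpose(-1, -2) @ q_blk * scale
        # No gradient for the integer block_size argument
        return grad_q, grad_k, grad_v, None
```

It would be called as `RecomputeSDPA.apply(query, key, value, block_size)`. A smaller `block_size` lowers the peak size of temporaries in `backward` at the cost of more loop iterations.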
 Important arguments for gradient computations are:

 * `--grad.layers_per_cell`: Second phase GPU memory requirements depend