CUDA_VISIBLE_DEVICES="1" PYTORCH_ALLOC_CONF=expandable_segments:True KEYSVALS_LOG_DIR="/home/ubuntu/out/finetune/ml_ws/lora/qwen2_5_0_5b/debug3_no_off/logs"; python3 keys_values/__main__.py finetune_long_full Qwen/Qwen2.5-0.5B --out_dir /home/ubuntu/out/finetune/ml_ws/lora/qwen2_5_0_5b/debug3_no_off --data LongBenchV2 --data.max_seq_length 150000 --data.metadata_dir /home/ubuntu/out/finetune/data --precision bf16-true --kv_cache.name h2o-torch-quantized8 --kv_cache.cache_length 16384 --kv_cache.chunk_size 1024 --verbose some --grad.layers_per_cell 1 --train.save_interval 10 --train.micro_batch_size 4 --train.global_batch_size 4 --eval.interval 10 --eval.micro_batch_size 4 --head_model seq_classification_on_logits --eval.initial_validation False --data.trainloader_longest_first True --grad.chunks_per_cell_multiplier 4 --grad.use_old_cache True
[...]
Running backward pass over 24 rows of cells, 24 layers, using activation checkpointing
4it [00:25, 6.25s/it]
Traceback (most recent call last):
File "/home/ubuntu/sync/keys_values/keys_values/__main__.py", line 144, in <module>
main()
File "/home/ubuntu/sync/keys_values/keys_values/__main__.py", line 140, in main
auto_cli(PARSER_DATA)
File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/jsonargparse/_cli.py", line 129, in auto_cli
return _run_component(component, init.get(subcommand))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/jsonargparse/_cli.py", line 227, in _run_component
return component(**cfg)
^^^^^^^^^^^^^^^^
File "/home/ubuntu/sync/keys_values/keys_values/finetune/longcontext_full.py", line 290, in setup
setup_internal(
File "/home/ubuntu/sync/keys_values/keys_values/finetune/longcontext_full.py", line 479, in setup_internal
fabric.launch(
File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/lightning/fabric/fabric.py", line 1010, in launch
return self._wrap_and_launch(function, self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/lightning/fabric/fabric.py", line 1121, in _wrap_and_launch
return to_run(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/lightning/fabric/fabric.py", line 1126, in _wrap_with_setup
return to_run(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/sync/keys_values/keys_values/finetune/longcontext_full.py", line 763, in main
token_counts = fit(
^^^^
File "/home/ubuntu/sync/keys_values/keys_values/finetune/longcontext_full.py", line 1344, in fit
fabric.backward(loss)
File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/lightning/fabric/fabric.py", line 523, in backward
self._strategy.backward(tensor, module, *args, **kwargs)
File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/lightning/fabric/strategies/strategy.py", line 192, in backward
self.precision.backward(tensor, module, *args, **kwargs)
File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/lightning/fabric/plugins/precision/precision.py", line 107, in backward
tensor.backward(*args, **kwargs)
File "/home/ubuntu/sync/keys_values/keys_values/kvcache/gradient/main.py", line 122, in backward
self._model.backward()
File "/home/ubuntu/sync/keys_values/keys_values/kvcache/gradient/main.py", line 517, in backward
self._backward_accumulate_gradients()
File "/home/ubuntu/sync/keys_values/keys_values/kvcache/gradient/main.py", line 867, in _backward_accumulate_gradients
self._backward_accumulate_gradients_nocheck(count)
File "/home/ubuntu/sync/keys_values/keys_values/kvcache/gradient/main.py", line 1125, in _backward_accumulate_gradients_nocheck
self.accumulator.run(
File "/home/ubuntu/sync/keys_values/keys_values/kvcache/gradient/accumulate.py", line 523, in run
scalar_output.backward()
File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/torch/_tensor.py", line 630, in backward
torch.autograd.backward(
File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/torch/autograd/__init__.py", line 364, in backward
_engine_run_backward(
File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/torch/autograd/function.py", line 317, in apply
return user_fn(self, *args)
^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/sync/keys_values/keys_values/kvcache/gradient/sdpa_op.py", line 636, in backward
) = ctx.saved_tensors
^^^^^^^^^^^^^^^^^
File "/home/ubuntu/sync/keys_values/keys_values/kvcache/gradient/accumulate.py", line 497, in <lambda>
lambda x: self.autograd_hooks.unpack_hook(x),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/sync/keys_values/keys_values/kvcache/gradient/autograd_hooks.py", line 780, in unpack_hook
x = self._unpack_from_annotation(annotation)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/sync/keys_values/keys_values/kvcache/gradient/autograd_hooks.py", line 847, in _unpack_from_annotation
raise ValueError(
ValueError: Annotation scatter-key (19,12): (1, 2, 16384, 64): final chunk_idx = 14, must be in [12, 13]
Describe the bug
A setup for which the new training replay cache works fine (with both FlexAttention and query-padded SDPA) fails with the old replay cache, based on the `sdpa_op.py` fused operator.
To reproduce
Run this:
Get: