Fix bug with autograd hooks and old training replay cache #52

Description

@mseeger

Describe the bug
A setup that works fine with the new training replay cache (with both FlexAttention and query-padded SDPA) fails with the old replay cache, which is based on the fused operator in sdpa_op.py.

To reproduce

Run this:

CUDA_VISIBLE_DEVICES="1" PYTORCH_ALLOC_CONF=expandable_segments:True KEYSVALS_LOG_DIR="/home/ubuntu/out/finetune/ml_ws/lora/qwen2_5_0_5b/debug3_no_off/logs" python3 keys_values/__main__.py finetune_long_full Qwen/Qwen2.5-0.5B --out_dir /home/ubuntu/out/finetune/ml_ws/lora/qwen2_5_0_5b/debug3_no_off --data LongBenchV2 --data.max_seq_length 150000 --data.metadata_dir /home/ubuntu/out/finetune/data --precision bf16-true --kv_cache.name h2o-torch-quantized8 --kv_cache.cache_length 16384 --kv_cache.chunk_size 1024 --verbose some --grad.layers_per_cell 1 --train.save_interval 10 --train.micro_batch_size 4 --train.global_batch_size 4 --eval.interval 10 --eval.micro_batch_size 4 --head_model seq_classification_on_logits --eval.initial_validation False --data.trainloader_longest_first True --grad.chunks_per_cell_multiplier 4 --grad.use_old_cache True

Get:

[...]

Running backward pass over 24 rows of cells, 24 layers, using activation checkpointing
4it [00:25,  6.25s/it]
Traceback (most recent call last):
  File "/home/ubuntu/sync/keys_values/keys_values/__main__.py", line 144, in <module>
    main()
  File "/home/ubuntu/sync/keys_values/keys_values/__main__.py", line 140, in main
    auto_cli(PARSER_DATA)
  File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/jsonargparse/_cli.py", line 129, in auto_cli
    return _run_component(component, init.get(subcommand))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/jsonargparse/_cli.py", line 227, in _run_component
    return component(**cfg)
           ^^^^^^^^^^^^^^^^
  File "/home/ubuntu/sync/keys_values/keys_values/finetune/longcontext_full.py", line 290, in setup
    setup_internal(
  File "/home/ubuntu/sync/keys_values/keys_values/finetune/longcontext_full.py", line 479, in setup_internal
    fabric.launch(
  File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/lightning/fabric/fabric.py", line 1010, in launch
    return self._wrap_and_launch(function, self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/lightning/fabric/fabric.py", line 1121, in _wrap_and_launch
    return to_run(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/lightning/fabric/fabric.py", line 1126, in _wrap_with_setup
    return to_run(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/sync/keys_values/keys_values/finetune/longcontext_full.py", line 763, in main
    token_counts = fit(
                   ^^^^
  File "/home/ubuntu/sync/keys_values/keys_values/finetune/longcontext_full.py", line 1344, in fit
    fabric.backward(loss)
  File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/lightning/fabric/fabric.py", line 523, in backward
    self._strategy.backward(tensor, module, *args, **kwargs)
  File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/lightning/fabric/strategies/strategy.py", line 192, in backward
    self.precision.backward(tensor, module, *args, **kwargs)
  File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/lightning/fabric/plugins/precision/precision.py", line 107, in backward
    tensor.backward(*args, **kwargs)
  File "/home/ubuntu/sync/keys_values/keys_values/kvcache/gradient/main.py", line 122, in backward
    self._model.backward()
  File "/home/ubuntu/sync/keys_values/keys_values/kvcache/gradient/main.py", line 517, in backward
    self._backward_accumulate_gradients()
  File "/home/ubuntu/sync/keys_values/keys_values/kvcache/gradient/main.py", line 867, in _backward_accumulate_gradients
    self._backward_accumulate_gradients_nocheck(count)
  File "/home/ubuntu/sync/keys_values/keys_values/kvcache/gradient/main.py", line 1125, in _backward_accumulate_gradients_nocheck
    self.accumulator.run(
  File "/home/ubuntu/sync/keys_values/keys_values/kvcache/gradient/accumulate.py", line 523, in run
    scalar_output.backward()
  File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/torch/_tensor.py", line 630, in backward
    torch.autograd.backward(
  File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/torch/autograd/__init__.py", line 364, in backward
    _engine_run_backward(
  File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/torch/autograd/graph.py", line 865, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/virtenvs/keysvals/lib/python3.12/site-packages/torch/autograd/function.py", line 317, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/sync/keys_values/keys_values/kvcache/gradient/sdpa_op.py", line 636, in backward
    ) = ctx.saved_tensors
        ^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/sync/keys_values/keys_values/kvcache/gradient/accumulate.py", line 497, in <lambda>
    lambda x: self.autograd_hooks.unpack_hook(x),
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/sync/keys_values/keys_values/kvcache/gradient/autograd_hooks.py", line 780, in unpack_hook
    x = self._unpack_from_annotation(annotation)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/sync/keys_values/keys_values/kvcache/gradient/autograd_hooks.py", line 847, in _unpack_from_annotation
    raise ValueError(
ValueError: Annotation scatter-key (19,12): (1, 2, 16384, 64): final chunk_idx = 14, must be in [12, 13]
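
Note on reading the traceback: the jump from `ctx.saved_tensors` in sdpa_op.py into `unpack_hook` is expected behavior of PyTorch saved-tensor hooks, since accessing `ctx.saved_tensors` in a custom `backward` runs the registered unpack hook. Below is a minimal sketch of that mechanism using `torch.autograd.graph.saved_tensors_hooks`. It is not the keys_values code: `Square`, `pack_hook`, `unpack_hook`, and `calls` are hypothetical names for illustration only, and the real hooks in autograd_hooks.py do considerably more.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

calls = []  # hypothetical log of hook invocations, for illustration only


def pack_hook(t):
    # Runs during forward, when a tensor is saved for backward.
    calls.append("pack")
    return t


def unpack_hook(packed):
    # Runs during backward, when the saved tensor is retrieved.
    # An exception raised here surfaces at the `ctx.saved_tensors` access.
    calls.append("unpack")
    return packed


class Square(torch.autograd.Function):
    # Toy custom operator (assumption: stands in for a fused op).

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors  # triggers unpack_hook
        return 2 * x * grad_out


x = torch.tensor([3.0], requires_grad=True)
with saved_tensors_hooks(pack_hook, unpack_hook):
    y = Square.apply(x)
y.sum().backward()
```

So the `ValueError` above is raised by the unpack logic while the fused operator's `backward` retrieves its saved tensors, not by the operator's own gradient computation.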

Labels: bug