Fix ModelPruning iterative reuse RuntimeError on non-leaf tensors #21717

Open
Priyanshu-byte-coder wants to merge 3 commits into Lightning-AI:master from Priyanshu-byte-coder:fix/model-pruning-iterative-deepcopy-error

Conversation

@Priyanshu-byte-coder Priyanshu-byte-coder commented May 17, 2026

What does this PR fix?

Fixes #8542.

When ModelPruning(use_lottery_ticket_hypothesis=True) is reused across multiple trainer.fit() calls (iterative pruning), the second and subsequent calls to setup() raise:

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment

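Root cause: After the first training pass, apply_lottery_ticket_hypothesis() calls _copy_param(), which does dst.data = src.data.to(dst.device), and the pruned module is left holding a non-leaf tensor. When setup() runs again on the next trainer.fit(), deepcopy(module) fails on that tensor. (The follow-up commit in this PR locates the non-leaf tensor in module.__dict__, where the pruning forward pre-hook writes it, rather than in module._parameters.)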

Changes

src/lightning/pytorch/callbacks/pruning.py

  • Added _deepcopy_for_pruning(module) static helper: temporarily replaces non-leaf parameters with detach().clone() copies, deepcopies, then restores originals
  • Updated setup() to use _deepcopy_for_pruning instead of bare deepcopy(module)
  • In setup(), set self._original_layers = None before re-populating to release previous-run tensor references
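The effect of the last bullet can be illustrated without PyTorch. The following is only a sketch: FakeLayerCopy and FakeCallback are hypothetical stand-ins, not Lightning classes, and the behaviour shown relies on CPython's reference counting. It checks whether the previous run's copy is still alive at the moment the new copy is allocated.

```python
import weakref


class FakeLayerCopy:
    """Hypothetical stand-in for a deep-copied layer holding large tensors."""


class FakeCallback:
    """Hypothetical stand-in for a callback that stores layer copies in setup()."""

    def __init__(self):
        self._original_layers = None
        self.old_alive_during_alloc = None

    def setup(self, release_first):
        # Watch the previous run's copy without keeping it alive ourselves.
        previous = self._original_layers
        watcher = weakref.ref(previous["layer"]) if previous else None
        del previous

        if release_first:
            # Drop the old references before building the new copies.
            self._original_layers = None

        new_layers = {"layer": FakeLayerCopy()}
        # Record whether the old copy still existed while the new one was built.
        self.old_alive_during_alloc = watcher is not None and watcher() is not None
        self._original_layers = new_layers


def old_copy_alive_on_second_setup(release_first):
    callback = FakeCallback()
    callback.setup(release_first)
    callback.setup(release_first)
    return callback.old_alive_during_alloc
```

With release_first=True the old copy is gone before the new one is created, so both never coexist; with release_first=False they overlap briefly, which is what raises peak memory when the copies are large.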

tests/tests_pytorch/callbacks/test_pruning.py

  • Added test_iterative_pruning_no_runtime_error: runs 3 consecutive trainer.fit() calls with the same pruning callback and verifies no RuntimeError is raised

Reproduction

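from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelPruning
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
pruning_callback = ModelPruning("l1_unstructured", use_lottery_ticket_hypothesis=True, amount=0.2)
for _ in range(2):
    # A fresh Trainer per run means setup() is called again on the same callback.
    trainer = Trainer(max_epochs=1, accelerator="cpu", callbacks=[pruning_callback])
    trainer.fit(model)  # RuntimeError on second iteration (before fix)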

📚 Documentation preview 📚: https://pytorch-lightning--21717.org.readthedocs.build/en/21717/

When ModelPruning with use_lottery_ticket_hypothesis=True is reused
across multiple trainer.fit() calls, setup() is called again each run.
After the first pass, _copy_param sets dst.data = src.data, making the
parameters non-leaf tensors. The subsequent deepcopy(module) then raises:

  RuntimeError: Only Tensors created explicitly by the user (graph
  leaves) support the deepcopy protocol at the moment

Fix by adding _deepcopy_for_pruning(), a helper that temporarily
replaces non-leaf parameters with detached clones before deepcopy and
restores the originals afterward. Also explicitly set _original_layers
to None before re-populating in setup() so the previous run's tensor
references are released before new copies are allocated.

Fixes Lightning-AI#8542

github-actions bot added the pl (Generic label for PyTorch Lightning package) label May 17, 2026

The actual non-leaf tensor after pruning is module.weight stored in
module.__dict__ by the forward pre-hook (weight_orig * weight_mask),
not in module._parameters. Patching _parameters had two problems:

1. Did not fix the RuntimeError (wrong dict being patched)
2. Caused mypy errors: assigning Tensor to Parameter | None

Switch to iterating module.__dict__ instead, which correctly captures
the hook-written non-leaf weight attribute.
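
The situation described in this commit message can be reproduced with plain PyTorch, outside Lightning. This is a minimal sketch, assuming torch.nn.utils.prune behaves as in current PyTorch releases; the layer sizes and pruning amount are arbitrary.

```python
import copy

import torch
from torch.nn.utils import prune

layer = torch.nn.Linear(4, 2)
prune.l1_unstructured(layer, name="weight", amount=0.5)

# After pruning, the trainable parameter is `weight_orig`; `weight` is a
# derived tensor (weight_orig * weight_mask) stored as a plain attribute.
in_parameters = "weight" in layer._parameters
in_dict = "weight" in layer.__dict__
is_leaf = layer.weight.is_leaf

# deepcopy walks the module's __dict__ and hits the non-leaf tensor.
try:
    copy.deepcopy(layer)
    deepcopy_error = None
except RuntimeError as err:
    deepcopy_error = str(err)
```

Here in_parameters is False, in_dict is True, and is_leaf is False, which is why a patch applied to module._parameters never touches the tensor that deepcopy rejects.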

Priyanshu-byte-coder force-pushed the fix/model-pruning-iterative-deepcopy-error branch from 3de4112 to 03ed64a May 17, 2026 11:02

@Priyanshu-byte-coder
Author

Pinging for review: all CI checks are green (triage, pre-commit, GitGuardian, and ReadTheDocs all pass).

@justusschock @tchaton could one of you take a look when you get a chance? The fix is in _deepcopy_for_pruning in pruning.py. It patches module.__dict__ instead of module._parameters to handle the non-leaf tensor that PyTorch's pruning forward pre-hook stores there.
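
For readers without the diff at hand, the patch-and-restore approach described above might look roughly like the sketch below. It is not the PR's actual code: deepcopy_with_dict_patch is a hypothetical name, and the traversal over module.modules() is an assumption about how submodules are covered.

```python
import copy

import torch
from torch.nn.utils import prune


def deepcopy_with_dict_patch(module):
    """Hypothetical helper: deepcopy a module holding non-leaf tensors in __dict__."""
    swapped = []
    for sub in module.modules():
        for name, value in list(sub.__dict__.items()):
            if isinstance(value, torch.Tensor) and not value.is_leaf:
                # Remember the original and substitute a detached (leaf) clone.
                swapped.append((sub, name, value))
                sub.__dict__[name] = value.detach().clone()
    try:
        return copy.deepcopy(module)
    finally:
        # Restore the originals even if deepcopy raises.
        for sub, name, value in swapped:
            sub.__dict__[name] = value


model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.Linear(3, 2))
prune.l1_unstructured(model[0], name="weight", amount=0.5)

original_weight = model[0].weight
clone = deepcopy_with_dict_patch(model)
```

The try/finally keeps the source module unchanged: after the call, model[0].weight is still the same non-leaf tensor object, while the copy holds an independent leaf tensor with the same values.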

Development

Successfully merging this pull request may close these issues.

Pruning callback causes GPU memory leak when used iteratively