Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion src/lightning/pytorch/callbacks/pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,35 @@ def make_pruning_permanent(self, module: nn.Module) -> None:
hook.remove(module)
del module._forward_pre_hooks[k]

@staticmethod
def _deepcopy_for_pruning(module: nn.Module) -> nn.Module:
    """Deep-copy a module whose attributes may include non-leaf tensors.

    PyTorch pruning hooks store the masked weight (e.g.
    ``weight_orig * weight_mask``) on the module as a plain ``__dict__``
    attribute (``module.weight``). That value is the result of an autograd
    operation, i.e. a non-leaf tensor, so a bare ``deepcopy`` raises::

        RuntimeError: Only Tensors created explicitly by the user
        (graph leaves) support the deepcopy protocol at the moment.

    ``deepcopy`` recurses into child modules, so every module reachable
    through ``module.modules()`` is inspected, not just ``module`` itself.
    Each non-leaf tensor found in a ``__dict__`` is temporarily replaced
    by a detached (leaf) alias, the deep-copy is taken, and the original
    tensors are put back so the live module is left unchanged.

    Args:
        module: The module to copy. It is mutated only for the duration
            of the call.

    Returns:
        An independent deep copy of ``module`` in which the formerly
        non-leaf attributes are detached leaf tensors.

    """
    # One (owner ``__dict__``, attribute name, original tensor) entry per swap.
    swapped: list[tuple[dict, str, Tensor]] = []
    try:
        for submodule in module.modules():
            attrs = submodule.__dict__
            # Snapshot the items: values are rebound while scanning.
            for attr_name, attr_val in list(attrs.items()):
                if isinstance(attr_val, Tensor) and not attr_val.is_leaf:
                    # Record before rebinding so ``finally`` can always undo it.
                    swapped.append((attrs, attr_name, attr_val))
                    # ``detach()`` yields a leaf sharing the same storage;
                    # ``deepcopy`` below makes the actual independent copy,
                    # so an extra ``clone()`` would only waste memory.
                    attrs[attr_name] = attr_val.detach()
        return deepcopy(module)
    finally:
        # Runs on success and on failure (including a failure mid-swap).
        for attrs, attr_name, original in swapped:
            attrs[attr_name] = original

@staticmethod
def _copy_param(new: nn.Module, old: nn.Module, name: str) -> None:
# Check if the parameter has been pruned (has _orig suffix)
Expand Down Expand Up @@ -376,12 +405,18 @@ def setup(self, trainer: "pl.Trainer", pl_module: LightningModule, stage: str) -
self._parameters_to_prune = self.filter_parameters_to_prune(parameters_to_prune)

if self._use_lottery_ticket_hypothesis:
# Release references from any previous run so their tensors can be
# garbage-collected before we allocate the new copies.
self._original_layers = None
# group modules by id. Each entry has a copy of the initial data
# and a list of the associated parameter names to prune
self._original_layers = {}
for i, (module, name) in enumerate(self._parameters_to_prune):
id_ = id(module)
self._original_layers.setdefault(id_, _LayerRef(data=deepcopy(module), names=[]))
# Use the detach-safe helper so that iterative pruning (where
# parameters may already be non-leaf tensors from a previous
# pruning cycle) does not raise a RuntimeError.
self._original_layers.setdefault(id_, _LayerRef(data=self._deepcopy_for_pruning(module), names=[]))
self._original_layers[id_]["names"].append((i, name))

def _run_pruning(self, current_epoch: int) -> None:
Expand Down
36 changes: 36 additions & 0 deletions tests/tests_pytorch/callbacks/test_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,3 +551,39 @@ def forward(self, x):
expected_pruned_count = int(expected_total_params * pruning_amount)
pruned_tolerance = max(1, int(expected_total_params * 0.05))
assert abs(pruned_count - expected_pruned_count) <= pruned_tolerance


def test_iterative_pruning_no_runtime_error(tmp_path):
    """A single lottery-ticket ``ModelPruning`` callback must survive several consecutive ``trainer.fit()`` calls.

    The failure guarded against is ``RuntimeError: Only Tensors created explicitly by the user (graph leaves)
    support the deepcopy protocol``, raised when the callback is reused on an already-pruned model.

    Regression test for https://github.com/Lightning-AI/pytorch-lightning/issues/8542

    """
    seed_everything(42)

    model = BoringModel()
    callback = ModelPruning(
        "l1_unstructured",
        amount=0.2,
        use_global_unstructured=True,
        use_lottery_ticket_hypothesis=True,
        make_pruning_permanent=False,
    )

    # Settings shared by every run: a tiny, quiet, CPU-only fit.
    trainer_kwargs = {
        "default_root_dir": tmp_path,
        "accelerator": "cpu",
        "max_epochs": 1,
        "limit_train_batches": 2,
        "limit_val_batches": 1,
        "logger": False,
        "enable_checkpointing": False,
        "enable_model_summary": False,
        "enable_progress_bar": False,
    }

    num_fit_calls = 3
    for _ in range(num_fit_calls):
        # A fresh callbacks list per Trainer, but always the same callback instance.
        trainer = Trainer(callbacks=[callback], **trainer_kwargs)
        # The test passes as long as no fit call raises.
        trainer.fit(model)