Skip to content

Commit 03ed64a

Browse files
Fix _deepcopy_for_pruning to patch __dict__ not _parameters
The actual non-leaf tensor after pruning is module.weight stored in module.__dict__ by the forward pre-hook (weight_orig * weight_mask), not in module._parameters. Patching _parameters had two problems: 1. Did not fix the RuntimeError (wrong dict being patched) 2. Caused mypy errors: assigning Tensor to Parameter | None Switch to iterating module.__dict__ instead, which correctly captures the hook-written non-leaf weight attribute.
1 parent 498e11f commit 03ed64a

1 file changed

Lines changed: 18 additions & 15 deletions

File tree

src/lightning/pytorch/callbacks/pruning.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -277,28 +277,31 @@ def make_pruning_permanent(self, module: nn.Module) -> None:
277277

278278
@staticmethod
279279
def _deepcopy_for_pruning(module: nn.Module) -> nn.Module:
280-
"""Deep-copy a module safely when parameters may be non-leaf tensors.
280+
"""Deep-copy a module that may contain non-leaf tensors.
281281
282-
After a pruning pass with ``use_lottery_ticket_hypothesis=True``, the
283-
module's parameters are rewritten via ``_copy_param`` (``dst.data =
284-
src.data``). This makes them non-leaf tensors, and a plain
285-
``deepcopy`` raises ``RuntimeError: Only Tensors created explicitly by
286-
the user (graph leaves) support the deepcopy protocol``.
282+
PyTorch pruning hooks write the masked weight (e.g.
283+
``weight_orig * weight_mask``) back onto the module as a plain
284+
``__dict__`` attribute (``module.weight``) after each forward pass.
285+
That stored value is a non-leaf tensor, so a bare ``deepcopy`` raises::
287286
288-
This helper temporarily replaces every non-leaf parameter with a
289-
detached clone, performs the deep-copy, then restores the originals so
290-
the live model is unchanged.
287+
RuntimeError: Only Tensors created explicitly by the user
288+
(graph leaves) support the deepcopy protocol at the moment.
289+
290+
This helper temporarily replaces any non-leaf tensor found in
291+
``module.__dict__`` with a detached leaf clone, performs the
292+
deep-copy, then restores the originals so the live module is
293+
unchanged.
291294
"""
292295
non_leaf: dict[str, Tensor] = {}
293-
for param_name, param in list(module._parameters.items()):
294-
if param is not None and not param.is_leaf:
295-
non_leaf[param_name] = param
296-
module._parameters[param_name] = param.detach().clone()
296+
for attr_name, attr_val in list(module.__dict__.items()):
297+
if isinstance(attr_val, Tensor) and not attr_val.is_leaf:
298+
non_leaf[attr_name] = attr_val
299+
module.__dict__[attr_name] = attr_val.detach().clone()
297300
try:
298301
return deepcopy(module)
299302
finally:
300-
for param_name, original in non_leaf.items():
301-
module._parameters[param_name] = original
303+
for attr_name, original in non_leaf.items():
304+
module.__dict__[attr_name] = original
302305

303306
@staticmethod
304307
def _copy_param(new: nn.Module, old: nn.Module, name: str) -> None:

0 commit comments

Comments
 (0)