Fix _deepcopy_for_pruning to patch __dict__ not _parameters
After pruning, the non-leaf tensor is module.weight, which the forward
pre-hook computes as weight_orig * weight_mask and stores in
module.__dict__, not in module._parameters. Patching _parameters had
two problems:
1. It did not fix the RuntimeError (the wrong dict was being patched).
2. It caused mypy errors by assigning a Tensor to Parameter | None.
Switch to iterating module.__dict__ instead, which correctly captures
the hook-written non-leaf weight attribute.
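
A minimal sketch of the behaviour described above, assuming a plain
nn.Linear layer pruned with torch.nn.utils.prune.l1_unstructured. The
helper detach_non_leaf_attrs is a hypothetical name used only for
illustration; it is not the _deepcopy_for_pruning code from this commit.

```python
import copy

import torch
from torch import nn
from torch.nn.utils import prune

layer = nn.Linear(4, 2)
prune.l1_unstructured(layer, name="weight", amount=0.5)
# The forward pre-hook recomputes weight = weight_orig * weight_mask.
layer(torch.randn(1, 4))

# After pruning, "weight" is a plain attribute, not a registered parameter.
in_params = "weight" in layer._parameters
in_dict = "weight" in layer.__dict__
is_leaf = layer.__dict__["weight"].is_leaf


def detach_non_leaf_attrs(module: nn.Module) -> list[str]:
    """Hypothetical helper: replace non-leaf tensors found directly in
    module.__dict__ with detached copies so the module can be deep-copied."""
    patched = []
    for name, value in list(module.__dict__.items()):
        if isinstance(value, torch.Tensor) and not value.is_leaf:
            module.__dict__[name] = value.detach()
            patched.append(name)
    return patched


patched = detach_non_leaf_attrs(layer)
clone = copy.deepcopy(layer)  # the detached weight is now a leaf tensor
```

Because the hook-written attribute lives in module.__dict__, the loop
never touches _parameters, so no Tensor is assigned where a
Parameter | None is expected.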