From 2be5e0bb07e8eb0aa5612de76846dcd23411ce55 Mon Sep 17 00:00:00 2001
From: Garibaldi Pineda-Garcia
Date: Sat, 3 May 2025 18:44:01 +0100
Subject: [PATCH 1/4] check param is of nn.Parameter type
This came up for an LSTM module, the bias parameter is a boolean which later stages of the pruning.
---
src/lightning/pytorch/callbacks/pruning.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/src/lightning/pytorch/callbacks/pruning.py b/src/lightning/pytorch/callbacks/pruning.py
index 1517ef6920b0d..91ea64c93ef3f 100644
--- a/src/lightning/pytorch/callbacks/pruning.py
+++ b/src/lightning/pytorch/callbacks/pruning.py
@@ -458,7 +458,8 @@ def sanitize_parameters_to_prune(
if not parameters_to_prune:
parameters_to_prune = [
- (m, p) for p in parameters for m in current_modules if getattr(m, p, None) is not None
+ (m, p) for p in parameters for m in current_modules
+ if getattr(m, p, None) is not None and isinstance(getattr(m, p, None), nn.Parameter)
]
elif (
isinstance(parameters_to_prune, (list, tuple))
From f9674f5478a67532b50d22c66e688016b1a12455 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
<66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sat, 3 May 2025 17:51:10 +0000
Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---
src/lightning/pytorch/callbacks/pruning.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/src/lightning/pytorch/callbacks/pruning.py b/src/lightning/pytorch/callbacks/pruning.py
index 91ea64c93ef3f..fb888aa087ccf 100644
--- a/src/lightning/pytorch/callbacks/pruning.py
+++ b/src/lightning/pytorch/callbacks/pruning.py
@@ -458,7 +458,9 @@ def sanitize_parameters_to_prune(
if not parameters_to_prune:
parameters_to_prune = [
- (m, p) for p in parameters for m in current_modules
+ (m, p)
+ for p in parameters
+ for m in current_modules
if getattr(m, p, None) is not None and isinstance(getattr(m, p, None), nn.Parameter)
]
elif (
From ba5aef494e1ab353ccf5f22e078c9d6013fce220 Mon Sep 17 00:00:00 2001
From: Rittik Panda
Date: Sun, 15 Jun 2025 18:57:17 +0530
Subject: [PATCH 3/4] Update test_pruning.py
---
tests/tests_pytorch/callbacks/test_pruning.py | 77 +++++++++++++++++++
1 file changed, 77 insertions(+)
diff --git a/tests/tests_pytorch/callbacks/test_pruning.py b/tests/tests_pytorch/callbacks/test_pruning.py
index d70ab68b78b32..a3d6caedcdc88 100644
--- a/tests/tests_pytorch/callbacks/test_pruning.py
+++ b/tests/tests_pytorch/callbacks/test_pruning.py
@@ -338,3 +338,80 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
assert not hasattr(model.layer.mlp_3, "weight_orig")
model = TestModel.load_from_checkpoint(trainer.checkpoint_callback.last_model_path)
assert not hasattr(model.layer.mlp_3, "weight_orig")
+
+
+def test_sanitize_parameters_explicit_check():
+ """Test the sanitize_parameters_to_prune method with various attribute types."""
+
+ class TestModule(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.weight = nn.Parameter(torch.randn(5, 5))
+ self.bias = nn.Parameter(torch.randn(5))
+ self.some_bool = True
+ self.some_tensor = torch.randn(3, 3) # Regular tensor, not parameter
+ self.some_string = "test"
+ self.some_none = None
+
+ class TestModel(BoringModel):
+ def __init__(self):
+ super().__init__()
+ self.test_module = TestModule()
+
+ model = TestModel()
+
+ parameters_to_prune = ModelPruning.sanitize_parameters_to_prune(
+ model,
+ parameters_to_prune=(),
+ parameter_names=["weight", "bias", "some_bool", "some_tensor", "some_string", "some_none"],
+ )
+
+ param_names_found = set()
+ for module, param_name in parameters_to_prune:
+ param = getattr(module, param_name)
+ assert isinstance(param, nn.Parameter), f"Expected Parameter, got {type(param)}"
+ param_names_found.add(param_name)
+
+ assert "weight" in param_names_found
+ assert "bias" in param_names_found
+ assert "some_bool" not in param_names_found
+ assert "some_tensor" not in param_names_found
+ assert "some_string" not in param_names_found
+ assert "some_none" not in param_names_found
+
+
+def test_original_issue_reproduction():
+ """Issue: https://github.com/Lightning-AI/pytorch-lightning/issues/10835."""
+
+ class ProblematicModel(BoringModel):
+ def __init__(self):
+ super().__init__()
+ self.layer = Sequential(
+ OrderedDict([
+ ("mlp_1", nn.Linear(32, 32)),
+ ("mlp_2", nn.Linear(32, 2)),
+ ])
+ )
+ # Add boolean attributes that would cause the original error
+ self.layer.mlp_1.training = True
+ self.layer.mlp_2.requires_grad = True
+
+ model = ProblematicModel()
+
+ try:
+ parameters_to_prune = ModelPruning.sanitize_parameters_to_prune(
+ model, parameters_to_prune=(), parameter_names=["weight", "bias", "training", "requires_grad"]
+ )
+
+ for module, param_name in parameters_to_prune:
+ param = getattr(module, param_name)
+ assert isinstance(param, nn.Parameter), f"Non-parameter found: {type(param)}"
+
+ success = True
+ except AttributeError as e:
+ if "'bool' object has no attribute 'is_cuda'" in str(e):
+ success = False # Original bug still present
+ else:
+ raise # Different error
+
+ assert success, "The fix for issue #10835 is not working correctly"
From 2ddda83ea3c304a0607d46bc2209682f885671d1 Mon Sep 17 00:00:00 2001
From: Rittik Panda
Date: Wed, 18 Jun 2025 20:49:09 +0530
Subject: [PATCH 4/4] Update test_pruning.py
---
tests/tests_pytorch/callbacks/test_pruning.py | 22 +++++--------------
1 file changed, 6 insertions(+), 16 deletions(-)
diff --git a/tests/tests_pytorch/callbacks/test_pruning.py b/tests/tests_pytorch/callbacks/test_pruning.py
index a3d6caedcdc88..6efe9b9992d00 100644
--- a/tests/tests_pytorch/callbacks/test_pruning.py
+++ b/tests/tests_pytorch/callbacks/test_pruning.py
@@ -398,20 +398,10 @@ def __init__(self):
model = ProblematicModel()
- try:
- parameters_to_prune = ModelPruning.sanitize_parameters_to_prune(
- model, parameters_to_prune=(), parameter_names=["weight", "bias", "training", "requires_grad"]
- )
-
- for module, param_name in parameters_to_prune:
- param = getattr(module, param_name)
- assert isinstance(param, nn.Parameter), f"Non-parameter found: {type(param)}"
-
- success = True
- except AttributeError as e:
- if "'bool' object has no attribute 'is_cuda'" in str(e):
- success = False # Original bug still present
- else:
- raise # Different error
+ parameters_to_prune = ModelPruning.sanitize_parameters_to_prune(
+ model, parameters_to_prune=(), parameter_names=["weight", "bias", "training", "requires_grad"]
+ )
- assert success, "The fix for issue #10835 is not working correctly"
+ for module, param_name in parameters_to_prune:
+ param = getattr(module, param_name)
+ assert isinstance(param, nn.Parameter), f"Non-parameter found: {type(param)}"