From 2be5e0bb07e8eb0aa5612de76846dcd23411ce55 Mon Sep 17 00:00:00 2001 From: Garibaldi Pineda-Garcia Date: Sat, 3 May 2025 18:44:01 +0100 Subject: [PATCH 1/4] check param is of nn.Parameter type This came up for an LSTM module, the bias parameter is a boolean which breaks later stages of the pruning. --- src/lightning/pytorch/callbacks/pruning.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/callbacks/pruning.py b/src/lightning/pytorch/callbacks/pruning.py index 1517ef6920b0d..91ea64c93ef3f 100644 --- a/src/lightning/pytorch/callbacks/pruning.py +++ b/src/lightning/pytorch/callbacks/pruning.py @@ -458,7 +458,8 @@ def sanitize_parameters_to_prune( if not parameters_to_prune: parameters_to_prune = [ - (m, p) for p in parameters for m in current_modules if getattr(m, p, None) is not None + (m, p) for p in parameters for m in current_modules + if getattr(m, p, None) is not None and isinstance(getattr(m, p, None), nn.Parameter) ] elif ( isinstance(parameters_to_prune, (list, tuple)) From f9674f5478a67532b50d22c66e688016b1a12455 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 3 May 2025 17:51:10 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/callbacks/pruning.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/callbacks/pruning.py b/src/lightning/pytorch/callbacks/pruning.py index 91ea64c93ef3f..fb888aa087ccf 100644 --- a/src/lightning/pytorch/callbacks/pruning.py +++ b/src/lightning/pytorch/callbacks/pruning.py @@ -458,7 +458,9 @@ def sanitize_parameters_to_prune( if not parameters_to_prune: parameters_to_prune = [ - (m, p) for p in parameters for m in current_modules + (m, p) + for p in parameters + for m in current_modules if getattr(m, p, None) is not None and isinstance(getattr(m, p, None), nn.Parameter) ] elif ( 
From ba5aef494e1ab353ccf5f22e078c9d6013fce220 Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Sun, 15 Jun 2025 18:57:17 +0530 Subject: [PATCH 3/4] Update test_pruning.py --- tests/tests_pytorch/callbacks/test_pruning.py | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/tests/tests_pytorch/callbacks/test_pruning.py b/tests/tests_pytorch/callbacks/test_pruning.py index d70ab68b78b32..a3d6caedcdc88 100644 --- a/tests/tests_pytorch/callbacks/test_pruning.py +++ b/tests/tests_pytorch/callbacks/test_pruning.py @@ -338,3 +338,80 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint): assert not hasattr(model.layer.mlp_3, "weight_orig") model = TestModel.load_from_checkpoint(trainer.checkpoint_callback.last_model_path) assert not hasattr(model.layer.mlp_3, "weight_orig") + + +def test_sanitize_parameters_explicit_check(): + """Test the sanitize_parameters_to_prune method with various attribute types.""" + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(5, 5)) + self.bias = nn.Parameter(torch.randn(5)) + self.some_bool = True + self.some_tensor = torch.randn(3, 3) # Regular tensor, not parameter + self.some_string = "test" + self.some_none = None + + class TestModel(BoringModel): + def __init__(self): + super().__init__() + self.test_module = TestModule() + + model = TestModel() + + parameters_to_prune = ModelPruning.sanitize_parameters_to_prune( + model, + parameters_to_prune=(), + parameter_names=["weight", "bias", "some_bool", "some_tensor", "some_string", "some_none"], + ) + + param_names_found = set() + for module, param_name in parameters_to_prune: + param = getattr(module, param_name) + assert isinstance(param, nn.Parameter), f"Expected Parameter, got {type(param)}" + param_names_found.add(param_name) + + assert "weight" in param_names_found + assert "bias" in param_names_found + assert "some_bool" not in param_names_found + assert "some_tensor" not in param_names_found 
+ assert "some_string" not in param_names_found + assert "some_none" not in param_names_found + + +def test_original_issue_reproduction(): + """Issue: https://github.com/Lightning-AI/pytorch-lightning/issues/10835.""" + + class ProblematicModel(BoringModel): + def __init__(self): + super().__init__() + self.layer = Sequential( + OrderedDict([ + ("mlp_1", nn.Linear(32, 32)), + ("mlp_2", nn.Linear(32, 2)), + ]) + ) + # Add boolean attributes that would cause the original error + self.layer.mlp_1.training = True + self.layer.mlp_2.requires_grad = True + + model = ProblematicModel() + + try: + parameters_to_prune = ModelPruning.sanitize_parameters_to_prune( + model, parameters_to_prune=(), parameter_names=["weight", "bias", "training", "requires_grad"] + ) + + for module, param_name in parameters_to_prune: + param = getattr(module, param_name) + assert isinstance(param, nn.Parameter), f"Non-parameter found: {type(param)}" + + success = True + except AttributeError as e: + if "'bool' object has no attribute 'is_cuda'" in str(e): + success = False # Original bug still present + else: + raise # Different error + + assert success, "The fix for issue #10835 is not working correctly" From 2ddda83ea3c304a0607d46bc2209682f885671d1 Mon Sep 17 00:00:00 2001 From: Rittik Panda Date: Wed, 18 Jun 2025 20:49:09 +0530 Subject: [PATCH 4/4] Update test_pruning.py --- tests/tests_pytorch/callbacks/test_pruning.py | 22 +++++-------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/tests/tests_pytorch/callbacks/test_pruning.py b/tests/tests_pytorch/callbacks/test_pruning.py index a3d6caedcdc88..6efe9b9992d00 100644 --- a/tests/tests_pytorch/callbacks/test_pruning.py +++ b/tests/tests_pytorch/callbacks/test_pruning.py @@ -398,20 +398,10 @@ def __init__(self): model = ProblematicModel() - try: - parameters_to_prune = ModelPruning.sanitize_parameters_to_prune( - model, parameters_to_prune=(), parameter_names=["weight", "bias", "training", "requires_grad"] - ) - - for 
module, param_name in parameters_to_prune: - param = getattr(module, param_name) - assert isinstance(param, nn.Parameter), f"Non-parameter found: {type(param)}" - - success = True - except AttributeError as e: - if "'bool' object has no attribute 'is_cuda'" in str(e): - success = False # Original bug still present - else: - raise # Different error + parameters_to_prune = ModelPruning.sanitize_parameters_to_prune( + model, parameters_to_prune=(), parameter_names=["weight", "bias", "training", "requires_grad"] + ) - assert success, "The fix for issue #10835 is not working correctly" + for module, param_name in parameters_to_prune: + param = getattr(module, param_name) + assert isinstance(param, nn.Parameter), f"Non-parameter found: {type(param)}"