
Commit 2a59554

Authored by Kyle Sayers
[Callbacks] Remove MagnitudePruningModifier.leave_enabled (#1198)
## Purpose ##
* Simplify the modifier lifecycle by removing the ability for modifiers to affect the model after the modifier's `end` event
* This allows the `on_event` method to be removed in a future change

## Background ##
* The `leave_enabled` option was originally intended as a shortcut to simplify recipes that used magnitude pruning during iterative pruning and then needed the masks to stay enabled during stabilization SFT
* This change makes such recipes clearer by requiring a `ConstantPruningModifier` after the `MagnitudePruningModifier` becomes inactive

## Changes ##
* Remove `MagnitudePruningModifier.leave_enabled` with a deprecation warning

Signed-off-by: Kyle Sayers <[email protected]>
1 parent 4607036 commit 2a59554
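In practice, a recipe that previously set `leave_enabled: True` should now follow the magnitude pruning stage with a `ConstantPruningModifier` that holds the masks fixed through stabilization SFT. Below is a minimal sketch of that pattern; the stage names and the `start`/`end`/`init_sparsity`/`final_sparsity`/`targets` fields are assumptions about the recipe schema, not taken from this commit.

```python
# Hypothetical recipe sketch: keep masks constant after magnitude pruning ends,
# instead of relying on the removed `leave_enabled` flag. Stage names and the
# start/end/init_sparsity/final_sparsity/targets fields are assumptions about
# the recipe schema and may differ in your llmcompressor version.
recipe = """
pruning_stage:
  pruning_modifiers:
    MagnitudePruningModifier:
      start: 0
      end: 5
      init_sparsity: 0.0
      final_sparsity: 0.5
      mask_structure: "unstructured"
      targets: ["Linear"]
stabilization_stage:
  pruning_modifiers:
    ConstantPruningModifier:
      start: 5  # takes over exactly when the magnitude pruning modifier ends
      targets: ["Linear"]
"""
```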

File tree

1 file changed: +15 -10 lines

  • src/llmcompressor/modifiers/pruning/magnitude/base.py


src/llmcompressor/modifiers/pruning/magnitude/base.py

+15 -10

@@ -1,5 +1,8 @@
+import warnings
 from typing import Any, Dict, List, Union
 
+from pydantic import field_validator
+
 from llmcompressor.core import Event, EventType, ModelParameterizedLayer, State
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.pruning.helpers import (
@@ -25,7 +28,7 @@ class MagnitudePruningModifier(Modifier, LayerParamMasking):
     update_scheduler: str = "cubic"
     scheduler_args: Dict[str, Any] = {}
     mask_structure: str = "unstructured"
-    leave_enabled: bool = True
+    leave_enabled: bool = False
     apply_globally: bool = False
 
     parameterized_layers_: Dict[str, ModelParameterizedLayer] = None
@@ -35,6 +38,14 @@ class MagnitudePruningModifier(Modifier, LayerParamMasking):
     mask_creator_function_: MaskCreatorType = None
     current_sparsity_: float = None
 
+    @field_validator("leave_enabled")
+    def validate_leave_enabled(value: bool) -> bool:
+        warnings.warn(
+            "MagnitudePruningModifier.leave_enable has been deprecated",
+            DeprecationWarning,
+        )
+        return False
+
     def on_initialize(self, state: State, **kwargs) -> bool:
         if self.apply_globally:
             raise NotImplementedError("global pruning not implemented yet for PyTorch")
@@ -75,9 +86,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         return True
 
     def on_finalize(self, state: State, **kwargs) -> bool:
-        if not self.leave_enabled:
-            for layer_param_name, _ in self.parameterized_layers_.items():
-                self.remove_mask(layer_param_name)
+        for layer_param_name, _ in self.parameterized_layers_.items():
+            self.remove_mask(layer_param_name)
 
         return True
 
@@ -119,12 +129,7 @@ def on_update(self, state: State, event: Event, **kwargs):
             self._update_masks(event)
 
     def on_end(self, state: State, event: Event, **kwargs):
-        if not self.leave_enabled:
-            self.disable_masks()
-
-    def on_event(self, state: State, event: Event, **kwargs):
-        if event.current_index >= self.end and self.leave_enabled:
-            self._update_masks(event)
+        self.disable_masks()
 
     def _update_masks(self, event: Event):
         if event.type_ == EventType.OPTIM_PRE_STEP and not self._use_hooks:
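With the new validator, any recipe that still passes `leave_enabled` emits a `DeprecationWarning` and has the value coerced to `False`. The following standalone pydantic sketch mirrors that pattern; `ExampleModifier` is a stand-in model, not llmcompressor's actual class, so it runs without the rest of the library.

```python
# Standalone sketch of the deprecation pattern added in this commit: a pydantic
# field_validator that warns and coerces the deprecated field back to False.
# `ExampleModifier` is a stand-in model, not llmcompressor's actual class.
import warnings

from pydantic import BaseModel, field_validator


class ExampleModifier(BaseModel):
    leave_enabled: bool = False

    @field_validator("leave_enabled")
    @classmethod
    def validate_leave_enabled(cls, value: bool) -> bool:
        warnings.warn(
            "ExampleModifier.leave_enabled has been deprecated",
            DeprecationWarning,
        )
        return False


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    modifier = ExampleModifier(leave_enabled=True)

assert modifier.leave_enabled is False  # the validator always returns False
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```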
