base.py · 144 lines (121 loc) · 5.02 KB
import warnings
from typing import Any, Dict, List, Union

from pydantic import field_validator

from llmcompressor.core import Event, EventType, ModelParameterizedLayer, State
from llmcompressor.modifiers import Modifier
from llmcompressor.modifiers.pruning.helpers import (
    PruningCreateSettings,
    PruningSchedulerFactory,
    SchedulerCalculationType,
)
from llmcompressor.modifiers.pruning.utils.pytorch import (
    LayerParamMasking,
    MaskCreatorType,
    PruningMaskCreatorArgs,
    PruningMaskFactory,
)
from llmcompressor.utils.pytorch.module import build_parameterized_layers

__all__ = ["MagnitudePruningModifier"]


class MagnitudePruningModifier(Modifier, LayerParamMasking):
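    """
    Magnitude pruning modifier: gradually prunes the targeted parameters toward
    ``final_sparsity`` by masking the weights with the smallest absolute
    magnitude. The target sparsity is interpolated from ``init_sparsity`` to
    ``final_sparsity`` by the configured ``update_scheduler`` ("cubic" by
    default), masks are recomputed whenever the scheduled sparsity changes at a
    batch start, and all masks are removed when the modifier is finalized.
    """
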
    # user-facing configuration
    targets: Union[str, List[str]]
    init_sparsity: float
    final_sparsity: float
    update_scheduler: str = "cubic"
    scheduler_args: Dict[str, Any] = {}
    mask_structure: str = "unstructured"
    leave_enabled: bool = False
    apply_globally: bool = False

    # internal state, populated over the modifier lifecycle
    parameterized_layers_: Dict[str, ModelParameterizedLayer] = None
    _save_masks: bool = False
    _use_hooks: bool = False
    scheduler_function_: SchedulerCalculationType = None
    mask_creator_function_: MaskCreatorType = None
    current_sparsity_: float = None

    @field_validator("leave_enabled")
    def validate_leave_enabled(value: bool) -> bool:
        if value:
            warnings.warn(
                "MagnitudePruningModifier.leave_enabled has been deprecated "
                "and will be set to False.",
                DeprecationWarning,
            )
        return False
    def on_initialize(self, state: State, **kwargs) -> bool:
        if self.apply_globally:
            raise NotImplementedError("global pruning not implemented yet for PyTorch")

        if "save_masks" in kwargs:
            self._save_masks = kwargs["save_masks"]
        if "use_hooks" in kwargs:
            self._use_hooks = kwargs["use_hooks"]

        if not state.model:
            return False

        # build the scheduler that maps training events to a target sparsity
        self.scheduler_function_ = PruningSchedulerFactory.create_scheduler(
            self.update_scheduler,
            PruningCreateSettings(
                self.start,
                self.end,
                self.update,
                self.init_sparsity,
                self.final_sparsity,
                self.scheduler_args,
            ),
        )
        self.mask_creator_function_ = PruningMaskFactory.create_mask_creator(
            self.mask_structure
        )

        # resolve the targeted layers/parameters and attach a mask to each
        self.parameterized_layers_ = build_parameterized_layers(
            state.model, self.targets
        )
        for layer_param_name, parameterized_layer in self.parameterized_layers_.items():
            self.add_mask(
                layer_param_name,
                parameterized_layer,
                persistent=self._save_masks,
                add_hooks=self._use_hooks,
            )

        return True
    def on_finalize(self, state: State, **kwargs) -> bool:
        for layer_param_name, _ in self.parameterized_layers_.items():
            self.remove_mask(layer_param_name)

        return True
    def on_start(self, state: State, event: Event, **kwargs):
        # compute the initial target sparsity and create magnitude-based masks
        sparsity = self.scheduler_function_(event, state)
        self.current_sparsity_ = sparsity

        for layer_param_name, parameterized_layer in self.parameterized_layers_.items():
            mask = self.mask_creator_function_(
                PruningMaskCreatorArgs(
                    parameter=parameterized_layer.param,
                    sparsity=sparsity,
                    scores=parameterized_layer.param.data.abs(),
                )
            )
            self.update_mask(layer_param_name, mask)

        self.enable_masks()
    def on_update(self, state: State, event: Event, **kwargs):
        if event.type_ == EventType.BATCH_START:
            # recompute masks only when the scheduled sparsity has changed
            sparsity = self.scheduler_function_(event, state)
            if sparsity != self.current_sparsity_:
                self.current_sparsity_ = sparsity

                for (
                    layer_param_name,
                    parameterized_layer,
                ) in self.parameterized_layers_.items():
                    mask = self.mask_creator_function_(
                        PruningMaskCreatorArgs(
                            parameter=parameterized_layer.param,
                            sparsity=sparsity,
                            scores=parameterized_layer.param.data.abs(),
                        )
                    )
                    self.update_mask(layer_param_name, mask)
        else:
            # re-apply masks around optimizer steps when hooks are not used
            self._update_masks(event)
    def on_end(self, state: State, event: Event, **kwargs):
        self.disable_masks()
    def _update_masks(self, event: Event):
        if event.type_ == EventType.OPTIM_PRE_STEP and not self._use_hooks:
            # zero out gradients of pruned weights before the optimizer step
            for layer_param_name, _ in self.parameterized_layers_.items():
                self.apply_mask_gradient(layer_param_name)
        elif event.type_ == EventType.OPTIM_POST_STEP and not self._use_hooks:
            # re-zero any pruned weights the optimizer step may have changed
            for layer_param_name, _ in self.parameterized_layers_.items():
                self.apply_mask_weight(layer_param_name)
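
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of this module): one way the modifier
# above could be configured. `start`, `end`, and `update` are scheduling
# fields inherited from `Modifier` (referenced via `self.start` / `self.end` /
# `self.update` in `on_initialize`); the target pattern and concrete values
# below are assumptions chosen only to show the shape of a configuration,
# not recommended settings.
#
#     modifier = MagnitudePruningModifier(
#         targets=["re:.*weight"],      # hypothetical target pattern
#         init_sparsity=0.05,
#         final_sparsity=0.9,
#         start=0.0,
#         end=2.0,
#         update=0.05,
#         update_scheduler="cubic",
#         mask_structure="unstructured",
#     )
# ---------------------------------------------------------------------------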