Skip to content

Commit acca004

Browse files
committed
start on initialize
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 60a536a commit acca004

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

src/llmcompressor/modifiers/modifier.py

+6
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,12 @@ def initialize(self, state: State, **kwargs):
8989

9090
self.initialized_ = self.on_initialize(state=state, **kwargs)
9191

92+
# trigger start
93+
fake_start_event = Event(type_=EventType.BATCH_START, global_step=0)
94+
if self.should_start(fake_start_event):
95+
self.on_start(state, fake_start_event, **kwargs)
96+
self.started_ = True
97+
9298
def finalize(self, state: State, **kwargs):
9399
"""
94100
Finalize the modifier for the given model and state.

tests/llmcompressor/pytorch/modifiers/pruning/constant/test_pytorch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_constant_pruning_modifier_e2e(model, optimizer):
8181
end=1,
8282
update=0.5,
8383
)
84-
modifier.initialize(state, start=0)
84+
modifier.initialize(state)
8585

8686
# check mask is added and has correct sparsity
8787

0 commit comments

Comments
 (0)