@@ -43,6 +43,21 @@ class CompressionLifecycle:
43
43
initialized_ : bool = False
44
44
finalized : bool = False
45
45
46
+ # event order validation
47
+ _last_event_type : Optional [EventType ] = EventType .BATCH_END
48
+ _event_order : List [EventType ] = field (
49
+ default_factory = lambda : [
50
+ EventType .BATCH_START ,
51
+ EventType .LOSS_CALCULATED ,
52
+ EventType .OPTIM_PRE_STEP ,
53
+ EventType .OPTIM_POST_STEP ,
54
+ EventType .BATCH_END ,
55
+ ]
56
+ )
57
+
58
+ # track global step in training (could be epoch/batch)
59
+ global_step : int = 0
60
+
46
61
def reset (self ):
47
62
"""
48
63
Reset the compression lifecycle, finalizing any active modifiers
@@ -134,7 +149,9 @@ def finalize(self, **kwargs) -> List[Any]:
134
149
135
150
return mod_data
136
151
137
- def event (self , event_type : EventType , ** kwargs ) -> List [Any ]:
152
+ def event (
153
+ self , event_type : EventType , global_step : Optional [int ] = 0 , ** kwargs
154
+ ) -> List [Any ]:
138
155
"""
139
156
Handle a compression event.
140
157
@@ -164,6 +181,12 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:
164
181
f"Use the corresponding method instead."
165
182
)
166
183
184
+ if not self ._validate_event_order (event_type ):
185
+ raise ValueError (
186
+ f"Lifecycle events must appear following order: { self ._event_order } . "
187
+ f"Instead, { self ._last_event_type } was called before { event_type } "
188
+ )
189
+
167
190
if event_type == EventType .LOSS_CALCULATED and (
168
191
"loss" not in kwargs or kwargs ["loss" ] is None
169
192
):
@@ -172,7 +195,11 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:
172
195
173
196
logger .debug ("Handling event: {}" , event_type )
174
197
175
- event = Event (event_type = event_type )
198
+ # update global step
199
+ if global_step is not None :
200
+ self .global_step = global_step
201
+
202
+ event = Event (type_ = event_type )
176
203
mod_data = []
177
204
for mod in self .modifiers :
178
205
data = mod .update_event (state = self .state , event = event , ** kwargs )
@@ -186,6 +213,23 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:
186
213
187
214
return mod_data
188
215
216
+ def _validate_event_order (self , event_type : EventType ) -> bool :
217
+ if event_type not in self ._event_order :
218
+ # for unhandled events, do not save last event
219
+ return True
220
+
221
+ if event_type == EventType .BATCH_START :
222
+ valid = self ._last_event_type != EventType .BATCH_START
223
+
224
+ else :
225
+ last_event_index = self ._event_order .index (self ._last_event_type )
226
+ curr_event_index = self ._event_order .index (event_type )
227
+ valid = last_event_index <= curr_event_index
228
+
229
+ if valid :
230
+ self ._last_event_type = event_type
231
+ return valid
232
+
189
233
def _set_model_layer_prefix (self ):
190
234
compiled_recipe = self .recipe_container .compiled_recipe
191
235
if (
0 commit comments