
Commit 0377ffd

add validation, global step

Signed-off-by: Kyle Sayers <[email protected]>

1 parent 75657a5 commit 0377ffd

3 files changed (+53 -32)

src/llmcompressor/core/lifecycle.py (+46 -2)
@@ -43,6 +43,21 @@ class CompressionLifecycle:
     initialized_: bool = False
     finalized: bool = False
 
+    # event order validation
+    _last_event_type: Optional[EventType] = EventType.BATCH_END
+    _event_order: List[EventType] = field(
+        default_factory=lambda: [
+            EventType.BATCH_START,
+            EventType.LOSS_CALCULATED,
+            EventType.OPTIM_PRE_STEP,
+            EventType.OPTIM_POST_STEP,
+            EventType.BATCH_END,
+        ]
+    )
+
+    # track global step in training (could be epoch/batch)
+    global_step: int = 0
+
     def reset(self):
         """
         Reset the compression lifecycle, finalizing any active modifiers
@@ -134,7 +149,9 @@ def finalize(self, **kwargs) -> List[Any]:
 
         return mod_data
 
-    def event(self, event_type: EventType, **kwargs) -> List[Any]:
+    def event(
+        self, event_type: EventType, global_step: Optional[int] = 0, **kwargs
+    ) -> List[Any]:
         """
         Handle a compression event.
 
@@ -164,6 +181,12 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:
                 f"Use the corresponding method instead."
             )
 
+        if not self._validate_event_order(event_type):
+            raise ValueError(
+                f"Lifecycle events must appear following order: {self._event_order}. "
+                f"Instead, {self._last_event_type} was called before {event_type}"
+            )
+
         if event_type == EventType.LOSS_CALCULATED and (
             "loss" not in kwargs or kwargs["loss"] is None
         ):
@@ -172,7 +195,11 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:
 
         logger.debug("Handling event: {}", event_type)
 
-        event = Event(event_type=event_type)
+        # update global step
+        if global_step is not None:
+            self.global_step = global_step
+
+        event = Event(type_=event_type)
         mod_data = []
         for mod in self.modifiers:
             data = mod.update_event(state=self.state, event=event, **kwargs)
@@ -186,6 +213,23 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]:
 
         return mod_data
 
+    def _validate_event_order(self, event_type: EventType) -> bool:
+        if event_type not in self._event_order:
+            # for unhandled events, do not save last event
+            return True
+
+        if event_type == EventType.BATCH_START:
+            valid = self._last_event_type != EventType.BATCH_START
+
+        else:
+            last_event_index = self._event_order.index(self._last_event_type)
+            curr_event_index = self._event_order.index(event_type)
+            valid = last_event_index <= curr_event_index
+
+        if valid:
+            self._last_event_type = event_type
+        return valid
+
     def _set_model_layer_prefix(self):
         compiled_recipe = self.recipe_container.compiled_recipe
         if (
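
The validation added here allows repeated events (equal index, e.g. LOSS_CALCULATED under gradient accumulation) and forward skips within a batch, but rejects backwards transitions and a doubled BATCH_START. A minimal standalone sketch of the same logic, with EventType reduced to the handled subset and names otherwise illustrative:

    from enum import Enum, auto
    from typing import List


    class EventType(Enum):
        # assumed subset of llmcompressor's EventType, for illustration only
        BATCH_START = auto()
        LOSS_CALCULATED = auto()
        OPTIM_PRE_STEP = auto()
        OPTIM_POST_STEP = auto()
        BATCH_END = auto()


    class OrderValidator:
        """Standalone mirror of CompressionLifecycle._validate_event_order."""

        order: List[EventType] = [
            EventType.BATCH_START,
            EventType.LOSS_CALCULATED,
            EventType.OPTIM_PRE_STEP,
            EventType.OPTIM_POST_STEP,
            EventType.BATCH_END,
        ]

        def __init__(self):
            # BATCH_END is the initial state, so the first BATCH_START passes
            self.last = EventType.BATCH_END

        def validate(self, event_type: EventType) -> bool:
            if event_type not in self.order:
                return True  # unhandled events pass through, state unchanged
            if event_type == EventType.BATCH_START:
                valid = self.last != EventType.BATCH_START
            else:
                # repeats and forward skips pass; moving backwards does not
                valid = self.order.index(self.last) <= self.order.index(event_type)
            if valid:
                self.last = event_type
            return valid


    v = OrderValidator()
    assert v.validate(EventType.BATCH_START)          # first batch opens normally
    assert not v.validate(EventType.BATCH_START)      # BATCH_START twice in a row
    assert v.validate(EventType.LOSS_CALCULATED)      # forward progress
    assert v.validate(EventType.OPTIM_POST_STEP)      # skipping OPTIM_PRE_STEP passes
    assert not v.validate(EventType.OPTIM_PRE_STEP)   # backwards move is rejected
    assert v.validate(EventType.BATCH_END)            # batch closes, ready for next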

src/llmcompressor/core/session.py (+6 -29)
@@ -223,57 +223,34 @@ def get_serialized_recipe(self) -> Optional[str]:
 
     def _log_model_info(self):
         # Log model level logs if cadence reached
-        event_lifecycle = self._lifecycle.event_lifecycle
-        if event_lifecycle is None:
-            # event lifecycle not available
-            # when recipe is not provided
-            return
-
-        epoch = event_lifecycle.current_index
+        current_index = self._lifecycle.global_step
 
         if (
             should_log_model_info(
                 model=self.state.model,
                 loggers=self.state.loggers,
-                current_log_step=epoch,
+                current_log_step=current_index,
                 last_log_step=self.state._last_log_step,
             )
             and self.state.loggers.frequency_manager.is_epoch_frequency_manager
         ):
             log_model_info(
                 state=self.state,
-                current_log_step=epoch,
+                current_log_step=current_index,
             )
             # update last log epoch
-            self.state.loggers.log_written(epoch)
+            self.state.loggers.log_written(current_index)
 
     def _log_loss(self, event_type: EventType, loss: Any):
         if event_type != EventType.LOSS_CALCULATED:
             # only log loss when loss is calculated
             return
-        event_lifecycle = self._lifecycle.event_lifecycle
 
-        if event_lifecycle is None:
-            # event lifecycle not available
-            # when recipe is not provided
-            return
-
-        epoch = event_lifecycle.current_index
-        if self.state.loggers.frequency_manager.is_optim_frequency_manager:
-            # log integer step for optimizer frequency manager
-            current_step = int(
-                self.state.loggers.epoch_to_step(
-                    epoch=epoch,
-                    steps_per_epoch=len(self.state.data.train),
-                )
-            )
-        else:
-            # log float epoch for epoch frequency manager
-            current_step = epoch
+        current_index = self._lifecycle.global_step
 
         # always log loss if available
         if loss is not None:
             loss = loss if isinstance(loss, dict) else {"loss": loss}
             self.state.loggers.metric.log_scalars(
-                tag="Loss", values=loss, step=current_step
+                tag="Loss", values=loss, step=current_index
            )
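
With global_step owned by the lifecycle, both logging paths read the same counter: the None-check for a missing recipe and the epoch_to_step conversion for the optim frequency manager both disappear. A minimal sketch of the resulting shape, using hypothetical stand-in classes rather than the real Session/State types:

    from dataclasses import dataclass, field


    @dataclass
    class Lifecycle:
        # stand-in for CompressionLifecycle: event() keeps this current
        global_step: int = 0


    @dataclass
    class Session:
        lifecycle: Lifecycle = field(default_factory=Lifecycle)

        def log_loss(self, loss: float) -> None:
            # one source of truth, valid even when no recipe is provided;
            # previously this was derived from event_lifecycle.current_index
            print(f"Loss @ step {self.lifecycle.global_step}: {loss}")


    session = Session()
    session.lifecycle.global_step = 100
    session.log_loss(0.42)  # -> Loss @ step 100: 0.42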

src/llmcompressor/transformers/finetune/session_mixin.py (+1 -1)
@@ -260,7 +260,7 @@ def training_step(
         """
         self._check_super_defined("training_step")
 
-        callbacks.batch_start(batch_data=inputs)
+        callbacks.batch_start(batch_data=inputs, global_step=self.state.epoch)
         model_outputs = super().training_step(
             model=model, inputs=inputs, num_items_in_batch=num_items_in_batch
         )
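
The trainer now threads its progress counter into the lifecycle on every batch; the mixin uses the HF Trainer's fractional epoch as the global step. A hedged sketch of an equivalent manual loop, where the event helpers on callbacks mirror the session API exercised above and the import path, model, optimizer, and dataloader are illustrative assumptions:

    from llmcompressor.core import callbacks  # import path assumed


    def train_epoch(model, optimizer, dataloader, epoch: int) -> None:
        for i, batch in enumerate(dataloader):
            # any monotonically increasing counter works as global_step;
            # the mixin above passes the trainer's fractional epoch instead
            step = epoch * len(dataloader) + i
            callbacks.batch_start(batch_data=batch, global_step=step)
            loss = model(batch)
            callbacks.loss_calculated(loss=loss)
            callbacks.optim_pre_step()
            optimizer.step()
            callbacks.optim_post_step()
            callbacks.batch_end()  # sequence satisfies the new order validation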
