diff --git a/src/llmcompressor/core/__init__.py b/src/llmcompressor/core/__init__.py index 85a074869..3a78cf70e 100644 --- a/src/llmcompressor/core/__init__.py +++ b/src/llmcompressor/core/__init__.py @@ -1,10 +1,4 @@ -from llmcompressor.core.events import ( - CallbacksEventLifecycle, - Event, - EventLifecycle, - EventType, - OptimizerEventLifecycle, -) +from llmcompressor.core.events import Event, EventType from llmcompressor.core.lifecycle import CompressionLifecycle from llmcompressor.core.model_layer import ModelParameterizedLayer from llmcompressor.core.session import CompressionSession @@ -20,9 +14,6 @@ __all__ = [ "Event", "EventType", - "EventLifecycle", - "CallbacksEventLifecycle", - "OptimizerEventLifecycle", "State", "Data", "Hardware", diff --git a/src/llmcompressor/core/events/__init__.py b/src/llmcompressor/core/events/__init__.py index e2fdb5b2d..9b51efb32 100644 --- a/src/llmcompressor/core/events/__init__.py +++ b/src/llmcompressor/core/events/__init__.py @@ -8,14 +8,5 @@ """ from .event import Event, EventType -from .event_lifecycle import EventLifecycle -from .lifecycle_callbacks import CallbacksEventLifecycle -from .lifecycle_optimizer import OptimizerEventLifecycle -__all__ = [ - "Event", - "EventType", - "EventLifecycle", - "CallbacksEventLifecycle", - "OptimizerEventLifecycle", -] +__all__ = ["Event", "EventType"] diff --git a/src/llmcompressor/core/events/event_lifecycle.py b/src/llmcompressor/core/events/event_lifecycle.py deleted file mode 100644 index c0bbcad30..000000000 --- a/src/llmcompressor/core/events/event_lifecycle.py +++ /dev/null @@ -1,171 +0,0 @@ -""" -Module for defining and managing the event lifecycle in the LLM Compressor. - -This module provides an abstract base class for defining event lifecycles -and methods for retrieving events based on their type, -managing step and batch counts, and triggering events. -""" - -from abc import ABC, abstractmethod -from typing import List - -from loguru import logger - -from llmcompressor.core.events.event import Event, EventType - -__all__ = [ - "EventLifecycle", -] - - -class EventLifecycle(ABC, Event): - """ - A lifecycle for events to be used in a LLMCompressor session. - Provides base utilities and defines the contract that - all inheritors must follow. - - The order in which the events are called is determined by - the inheritors of this class. 
- - The expected lifecycle is as follows with optional gradient accumulation: - for gradient_batches in training loop: - for batch in gradient_batches or [gradient_batches]: - BATCH_START - LOSS_CALCULATED - - if not last batch: - BATCH_END - else: - OPTIM_PRE_STEP - OPTIM_POST_STEP - BATCH_END - For older flows where the optimizer is wrapped and invocations_per_step > 1: - for gradient_batches in training loop: - for invocation in range(invocations_per_step): - for batch in gradient_batches or [gradient_batches]: - BATCH_START - LOSS_CALCULATED - - if not last batch or not last invocation: - BATCH_END - else: - OPTIM_PRE_STEP - OPTIM_POST_STEP - BATCH_END - - :param type_first: The first event type to be called - :type type_first: EventType - :param start: The start event to base the lifecycle off of - :type start: Event - """ - - def __init__(self, type_first: EventType, start: Event): - logger.debug( - "Initializing EventLifecycle with type_first={} and start={}", - type_first, - start, - ) - self.type_first = type_first - self.last_type = None - - self.steps_per_epoch = start.steps_per_epoch - self.batches_per_step = start.batches_per_step - self.invocations_per_step = start.invocations_per_step - self.global_step = start.global_step - self.global_batch = start.global_batch - - def events_from_type(self, type_: EventType) -> List[Event]: - """ - Get the list of events for a given type. - - :param type_: The event type to get the events for - :type type_: EventType - :return: The list of events for the given type - :rtype: List[Event] - :raises ValueError: If the event type is invalid - """ - logger.debug("Fetching events from type: {}", type_) - if type_ == EventType.BATCH_START: - return self.batch_start_events() - if type_ == EventType.LOSS_CALCULATED: - return self.loss_calculated_events() - if type_ == EventType.OPTIM_PRE_STEP: - return self.optim_pre_step_events() - if type_ == EventType.OPTIM_POST_STEP: - return self.optim_post_step_events() - if type_ == EventType.BATCH_END: - return self.batch_end_events() - logger.error("Invalid event type: {}", type_) - raise ValueError(f"Invalid event type {type_}") - - def check_batches_per_step_count(self, increment: bool) -> bool: - """ - Check if the batch count is at the step or step invocation count. - If batches_per_step is None or < 2, always returns True. - - If invocations_per_step is > 1, - then returns True for batches matching the invocation. - Check check_invocations_per_step_count for the invocation count. - - :param increment: Whether to increment the batch count - :type increment: bool - :return: True if the batch count is at the step count, False otherwise - :rtype: bool - """ - compare_batch = self.global_batch + 1 - at_step = ( - self.batches_per_step is None - or self.batches_per_step < 2 - or (compare_batch % self.batches_per_step == 0) - ) - if increment: - self.global_batch = compare_batch - - return at_step - - def check_invocations_per_step_count(self, increment: bool) -> bool: - """ - Check if the invocation count is at the step count. - If invocations_per_step is None or < 2, always returns True. 
- - :param increment: Whether to increment the step count - :type increment: bool - :return: True if the invocation count is at the step count, False otherwise - :rtype: bool - """ - compare_step = self.global_step + 1 - at_step = ( - self.invocations_per_step is None - or self.invocations_per_step < 2 - or (compare_step % self.invocations_per_step == 0) - ) - - if increment: - self.global_step = compare_step - - return at_step - - @abstractmethod - def batch_start_events(self) -> List[Event]: - """Return the list of events to be called for the batch start.""" - raise NotImplementedError() - - @abstractmethod - def loss_calculated_events(self) -> List[Event]: - """Return the list of events to be called for the loss calculated.""" - raise NotImplementedError() - - @abstractmethod - def optim_pre_step_events(self) -> List[Event]: - """Return the list of events to be called for the optim pre step.""" - raise NotImplementedError() - - @abstractmethod - def optim_post_step_events(self) -> List[Event]: - """Return the list of events to be called for the optim post step.""" - raise NotImplementedError() - - @abstractmethod - def batch_end_events(self) -> List[Event]: - """Return the list of events to be called for the batch end.""" - raise NotImplementedError() diff --git a/src/llmcompressor/core/events/lifecycle_callbacks.py b/src/llmcompressor/core/events/lifecycle_callbacks.py deleted file mode 100644 index 6abad1c15..000000000 --- a/src/llmcompressor/core/events/lifecycle_callbacks.py +++ /dev/null @@ -1,191 +0,0 @@ -""" -Module for defining and managing callback event lifecycles in the LLM Compressor. - -This module provides a class for defining event lifecycles when callbacks -are used to communicate the state of the training pipeline. -""" - -from typing import List - -from loguru import logger - -from llmcompressor.core.events.event import Event, EventType -from llmcompressor.core.events.event_lifecycle import EventLifecycle - -__all__ = [ - "CallbacksEventLifecycle", -] - - -class CallbacksEventLifecycle(EventLifecycle): - """ - An event lifecycle for when callbacks are used to communicate the - state of the training pipeline. - - The expected lifecycle is as follows with optional gradient accumulation: - for gradient_batches in training loop: - for batch in gradient_batches or [gradient_batches]: - batch_start() -> [BATCH_START] - loss_calculated() -> [LOSS_CALCULATED] - - if not last batch: - batch_end -> [BATCH_END] - else: - optim_pre_step() -> [OPTIM_PRE_STEP] - optim_post_step() -> [OPTIM_POST_STEP] - batch_end -> [BATCH_END] - - Which gives the following logic: - - BATCH_START: must be called first or after OPTIM_POST_STEP - - LOSS_CALCULATED: must be called after BATCH_START and - before BATCH_END or OPTIM_POST_STEP - - OPTIM_PRE_STEP: must be called after LOSS_CALCULATED - and before OPTIM_POST_STEP - - OPTIM_POST_STEP: must be called after OPTIM_PRE_STEP and before BATCH_END - - BATCH_END: must be called after LOSS_CALCULATED or OPTIM_POST_STEP - """ - - def __init__(self, type_first: EventType, start: Event): - """ - Initialize the CallbacksEventLifecycle. - - :param type_first: The first event type to be called - :type type_first: EventType - :param start: The start event to base the lifecycle off of - :type start: Event - """ - super().__init__(type_first=type_first, start=start) - self.skip_post_step = False - - def batch_start_events(self) -> List[Event]: - """ - Return the list of events to be called for the batch start. 
- - :return: The list of events to be called for the batch start - :rtype: List[Event] - :raises ValueError: If batch start is not called first or if it - is not called after batch end - """ - if self.type_first != EventType.BATCH_START: - logger.error("batch start must be called first for callbacks") - raise ValueError("batch start must be called first for callbacks") - - if self.last_type not in {None, EventType.BATCH_END, EventType.OPTIM_POST_STEP}: - logger.error("batch start must be called after batch end") - raise ValueError("batch start must be called after batch end") - - self.last_type = EventType.BATCH_START - step_ready = self.check_batches_per_step_count(increment=True) - logger.debug( - "Batch start event processed with step_ready={}, " - "global_step={}, and global_batch={}", - step_ready, - self.global_step, - self.global_batch, - ) - - return [self.new_instance(type_=EventType.BATCH_START)] - - def loss_calculated_events(self) -> List[Event]: - """ - Return the list of events to be called for the loss calculated. - - :return: The list of events to be called for the loss calculated - :rtype: List[Event] - :raises ValueError: If loss calculated is not called after batch start - """ - if self.last_type != EventType.BATCH_START: - logger.error("loss calculated must be called after batch start") - raise ValueError("loss calculated must be called after batch start") - - self.last_type = EventType.LOSS_CALCULATED - logger.debug( - "Loss calculated event processed with global_batch={} and global_step={}", - self.global_batch, - self.global_step, - ) - - return [self.new_instance(type_=EventType.LOSS_CALCULATED)] - - def optim_pre_step_events(self) -> List[Event]: - """ - Return the list of events to be called for the optim pre step. - - :return: The list of events to be called for the optim pre step - :rtype: List[Event] - :raises ValueError: If optim pre step is not called after batch start - or loss calculated - """ - if self.last_type not in {EventType.LOSS_CALCULATED}: - logger.error("optim pre step must be called after loss calculated") - raise ValueError("optim pre step must be called after loss calculated") - - self.last_type = EventType.OPTIM_PRE_STEP - at_invocation = self.check_invocations_per_step_count(increment=True) - logger.debug( - "Optim pre step event processed with at_invocation={}, " - "global_step={}, and global_batch={}", - at_invocation, - self.global_step, - self.global_batch, - ) - - if not at_invocation: - self.skip_post_step = True - return [] - else: - self.skip_post_step = False - return [self.new_instance(type_=EventType.OPTIM_PRE_STEP)] - - def optim_post_step_events(self) -> List[Event]: - """ - Return the list of events to be called for the optim post step. - - :return: The list of events to be called for the optim post step - :rtype: List[Event] - :raises ValueError: If optim post step is not called after optim pre step - """ - if self.last_type != EventType.OPTIM_PRE_STEP: - logger.error("optim post step must be called after optim pre step") - raise ValueError("optim post step must be called after optim pre step") - - self.last_type = EventType.OPTIM_POST_STEP - logger.debug( - "Optim post step event processed with global_batch={} and global_step={}", - self.global_batch, - self.global_step, - ) - - if self.skip_post_step: - return [] - else: - return [self.new_instance(type_=EventType.OPTIM_POST_STEP)] - - def batch_end_events(self) -> List[Event]: - """ - Return the list of events to be called for the batch end. 
- - :return: The list of events to be called for the batch end - :rtype: List[Event] - :raises ValueError: If batch end is not called after optim post step, - loss calculated, or batch start - """ - if self.last_type not in { - EventType.OPTIM_POST_STEP, - EventType.LOSS_CALCULATED, - }: - logger.error( - "batch end must be called after loss calculated or optim post step" - ) - raise ValueError( - "batch end must be called after loss calculated or optim post step" - ) - - self.last_type = EventType.BATCH_END - logger.debug( - "Batch end event processed with global_batch={} and global_step={}", - self.global_batch, - self.global_step, - ) - - return [self.new_instance(type_=EventType.BATCH_END)] diff --git a/src/llmcompressor/core/events/lifecycle_optimizer.py b/src/llmcompressor/core/events/lifecycle_optimizer.py deleted file mode 100644 index d1df99f53..000000000 --- a/src/llmcompressor/core/events/lifecycle_optimizer.py +++ /dev/null @@ -1,266 +0,0 @@ -""" -Module for defining and managing optimizer event lifecycles in the LLM Compressor. - -This module provides a class for defining event lifecycles when the optimizer is wrapped -to invoke the event lifecycle and no callbacks are used. -""" - -from typing import List - -from loguru import logger - -from llmcompressor.core.events.event import Event, EventType -from llmcompressor.core.events.event_lifecycle import EventLifecycle - -__all__ = [ - "OptimizerEventLifecycle", -] - - -class OptimizerEventLifecycle(EventLifecycle): - """ - An event lifecycle for when the optimizer is wrapped to invoke the event lifecycle - and no callbacks are used. - - For all flows with the OptimizerEventLifecycle, the optimizer is wrapped - to trigger around the step function on the optimizer. - loss_calculated is optional, but if used it must be consistently - called before optimizer steps. - - The expected lifecycle is as follows with no gradient accumulation: - for batch in training loop: - if loss callbacks: - loss_calculated() -> [BATCH_START, LOSS_CALCULATED] - optim.step() -> [BATCH_START, OPTIM_PRE_STEP, - OPTIM_POST_STEP, BATCH_END] - else: - optim.step() -> [BATCH_START, OPTIM_PRE_STEP, - OPTIM_POST_STEP, BATCH_END] - For gradient accumulation: - for gradient_batches in training loop: - for batch in gradient_batches: - if loss callbacks: - if not last batch: - loss_calculated() -> [BATCH_START, LOSS_CALCULATED, BATCH_END] - else: - loss_calculated() -> [BATCH_START, LOSS_CALCULATED] - optim.step() -> [OPTIM_PRE_STEP, OPTIM_POST_STEP, BATCH_END] - else: - if last batch: - optim.step() -> [BATCH_START, OPTIM_PRE_STEP, - OPTIM_POST_STEP, BATCH_END] - For older amp scale flows that use invocations_per_step > 1: - for batch in training loop: - for invocation in range(invocations_per_step): - if not last invocation: - if loss callbacks: - loss_calculated() -> [BATCH_START, LOSS_CALCULATED] - optim.step() -> [BATCH_END] - else: - optim.step() -> [BATCH_START, BATCH_END] - else: - if loss callbacks: - loss_calculated() -> [BATCH_START, LOSS_CALCULATED] - optim.step() -> [OPTIM_PRE_STEP, OPTIM_POST_STEP, BATCH_END] - else: - optim.step() -> [BATCH_START, OPTIM_PRE_STEP, - OPTIM_POST_STEP, BATCH_END] - - - batch_start: must not be invoked, auto triggered - from loss calculated if that is called, otherwise from pre_step - - loss_calculated: optional pathway and invoked through callbacks. - It must be called as the first event if used, - and after optim post step or batch end for subsequent calls. 
- - batch_end: must not be invoked, auto triggered from optim_post_step - - optim_pre_step: must be called before optim_post_step - - optim_post_step: must be called only once after optim_pre_step - """ - - def __init__(self, type_first: EventType, start: Event): - """ - Initialize the OptimizerEventLifecycle with the first type and start event. - - :param type_first: The first event type to be called - :type type_first: EventType - :param start: The start event to base the lifecycle off of - :type start: Event - """ - super().__init__(type_first=type_first, start=start) - self.skip_post_step = False - - def batch_start_events(self) -> List[Event]: - """ - Raises a ValueError as this method should not be called. - - :raises ValueError: If invoked as this should not be called - """ - logger.error("batch start should not be invoked when only wrapped optim") - raise ValueError("batch start should not be invoked when only wrapped optim") - - def loss_calculated_events(self) -> List[Event]: - """ - Return the list of events to be called for the loss calculated. - - :return: The list of events to be called for the loss calculated - :rtype: List[Event] - :raises ValueError: If invoked before loss calculation - """ - if self.type_first != EventType.LOSS_CALCULATED: - logger.error("loss calculated must be called first for wrapped optim") - raise ValueError("loss calculated must be called first for wrapped optim") - - if self.last_type not in { - EventType.LOSS_CALCULATED, - EventType.OPTIM_POST_STEP, - None, - }: - logger.error( - "loss calculated must be called after batch end or optim post step" - ) - raise ValueError( - "loss calculated must be called after batch end or optim post step" - ) - - self.last_type = EventType.LOSS_CALCULATED - - if not self.check_batches_per_step_count(increment=True): - logger.debug( - "Loss calculated event processed, " - "step not ready at global_step: {} and global_batch: {}", - self.global_step, - self.global_batch, - ) - return [ - self.new_instance(type_=EventType.BATCH_START), - self.new_instance(type_=EventType.LOSS_CALCULATED), - self.new_instance(type_=EventType.BATCH_END), - ] - else: - logger.debug( - "Loss calculated event processed, " - "step ready at global_step: {} and global_batch: {}", - self.global_step, - self.global_batch, - ) - return [ - self.new_instance(type_=EventType.BATCH_START), - self.new_instance(type_=EventType.LOSS_CALCULATED), - ] - - def optim_pre_step_events(self) -> List[Event]: - """ - Return the list of events to be called for the optim pre step. 
- - :return: The list of events to be called for the optim pre step - :rtype: List[Event] - :raises ValueError: If optim pre step is not called before optim post step - or after loss calculated - """ - if self.type_first == EventType.LOSS_CALCULATED: - # handle loss calculated case where gradient accumulation - # is automatically handled by the loss callbacks - - if self.last_type != EventType.LOSS_CALCULATED: - logger.error("optim pre step must be called after loss calculated") - raise ValueError("optim pre step must be called after loss calculated") - - self.last_type = EventType.OPTIM_PRE_STEP - - if not self.check_invocations_per_step_count(increment=False): - logger.debug( - "Optim pre step event processed, " - "but invocations not ready at global_step: {} and global_batch: {}", - self.global_step, - self.global_batch, - ) - self.skip_post_step = True - return [] - else: - logger.debug( - "Optim pre step event processed, " - "invocations ready at global_step: {} and global_batch: {}", - self.global_step, - self.global_batch, - ) - self.skip_post_step = False - return [self.new_instance(type_=EventType.OPTIM_PRE_STEP)] - - # handle no callbacks case to emulate batch events for gradient accumulation - if self.last_type not in {EventType.OPTIM_POST_STEP, None}: - logger.error( - "optim pre step must be called at the start or after optim post step" - ) - raise ValueError( - "optim pre step must be called at the start or after optim post step" - ) - - batch_events = [ - self.new_instance(type_=EventType.BATCH_START), - ] - while not self.check_batches_per_step_count(increment=True): - batch_events.append(self.new_instance(type_=EventType.BATCH_END)) - batch_events.append(self.new_instance(type_=EventType.BATCH_START)) - - self.last_type = EventType.OPTIM_PRE_STEP - - if not self.check_invocations_per_step_count(increment=False): - logger.debug( - "Optim pre step event processed, " - "but invocations not ready at global_step: {} and global_batch: {}", - self.global_step, - self.global_batch, - ) - self.skip_post_step = True - return batch_events - else: - logger.debug( - "Optim pre step event processed, " - "invocations ready at global_step: {} and global_batch: {}", - self.global_step, - self.global_batch, - ) - self.skip_post_step = False - return batch_events + [self.new_instance(type_=EventType.OPTIM_PRE_STEP)] - - def optim_post_step_events(self) -> List[Event]: - """ - Return the list of events to be called for the optim post step. - - :return: The list of events to be called for the optim post step - :rtype: List[Event] - :raises ValueError: If optim post step is not called after optim pre step - """ - if self.last_type != EventType.OPTIM_PRE_STEP: - logger.error("optim post step must be called after optim pre step") - raise ValueError("optim post step must be called after optim pre step") - - self.last_type = EventType.OPTIM_POST_STEP - - if self.skip_post_step: - logger.debug( - "Skipping optim post step event at global_step: " - "{} and global_batch: {}", - self.global_step, - self.global_batch, - ) - return [self.new_instance(type_=EventType.BATCH_END)] - else: - logger.debug( - "Optim post step event processed at global_step: " - "{} and global_batch: {}", - self.global_step, - self.global_batch, - ) - return [ - self.new_instance(type_=EventType.OPTIM_POST_STEP), - self.new_instance(type_=EventType.BATCH_END), - ] - - def batch_end_events(self) -> List[Event]: - """ - Raises a ValueError as this method should not be called. 
- - :raises ValueError: If invoked as this should not be called - """ - logger.error("batch end should not be invoked when only wrapped optim") - raise ValueError("batch end should not be invoked when only wrapped optim") diff --git a/src/llmcompressor/core/lifecycle.py b/src/llmcompressor/core/lifecycle.py index e69882800..ff91c70c8 100644 --- a/src/llmcompressor/core/lifecycle.py +++ b/src/llmcompressor/core/lifecycle.py @@ -10,12 +10,7 @@ from loguru import logger -from llmcompressor.core.events import ( - CallbacksEventLifecycle, - EventLifecycle, - EventType, - OptimizerEventLifecycle, -) +from llmcompressor.core.events import Event, EventType from llmcompressor.core.state import State from llmcompressor.modifiers import StageModifiers from llmcompressor.recipe import ( @@ -39,18 +34,30 @@ class CompressionLifecycle: :type recipe_container: RecipeContainer :param modifiers: The list of stage modifiers :type modifiers: List[StageModifiers] - :param event_lifecycle: The event lifecycle manager - :type event_lifecycle: Optional[EventLifecycle] """ state: State = field(default_factory=State) recipe_container: RecipeContainer = field(default_factory=RecipeContainer) modifiers: List[StageModifiers] = field(default_factory=list) - event_lifecycle: Optional[EventLifecycle] = None initialized_: bool = False finalized: bool = False + # event order validation + _last_event_type: Optional[EventType] = EventType.BATCH_END + _event_order: List[EventType] = field( + default_factory=lambda: [ + EventType.BATCH_START, + EventType.LOSS_CALCULATED, + EventType.OPTIM_PRE_STEP, + EventType.OPTIM_POST_STEP, + EventType.BATCH_END, + ] + ) + + # track global step in training (could be epoch/batch) + global_step: int = 0 + def reset(self): """ Reset the compression lifecycle, finalizing any active modifiers @@ -142,7 +149,9 @@ def finalize(self, **kwargs) -> List[Any]: return mod_data - def event(self, event_type: EventType, **kwargs) -> List[Any]: + def event( + self, event_type: EventType, global_step: Optional[int] = 0, **kwargs + ) -> List[Any]: """ Handle a compression event. @@ -172,6 +181,12 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]: f"Use the corresponding method instead." ) + if not self._validate_event_order(event_type): + raise ValueError( + f"Lifecycle events must appear following order: {self._event_order}. 
" + f"Instead, {self._last_event_type} was called before {event_type}" + ) + if event_type == EventType.LOSS_CALCULATED and ( "loss" not in kwargs or kwargs["loss"] is None ): @@ -179,84 +194,41 @@ def event(self, event_type: EventType, **kwargs) -> List[Any]: raise ValueError("Loss must be provided for loss calculated event") logger.debug("Handling event: {}", event_type) - self._check_setup_event_lifecycle(event_type) - event = None - mod_data = [] - for event in self.event_lifecycle.events_from_type(event_type): - if self.state.start_event is None: - self.state.start_event = event + # update global step + if global_step is not None: + self.global_step = global_step - for mod in self.modifiers: - data = mod.update_event(state=self.state, event=event, **kwargs) - logger.debug("Updated event with modifier: {}", mod) - if data is not None: - mod_data.append(data) + event = Event(type_=event_type) + mod_data = [] + for mod in self.modifiers: + data = mod.update_event(state=self.state, event=event, **kwargs) + logger.debug("Updated event with modifier: {}", mod) + if data is not None: + mod_data.append(data) assert ( event is not None ), f"Event lifecycle did not return an event for {event_type}" - self.state.last_event = event return mod_data - def _check_setup_event_lifecycle(self, event_type: EventType): - if self.event_lifecycle is not None: - return - - if ( - self.state is None - or self.state.model is None - or self.state.start_event is None - or self.recipe_container.compiled_recipe is None - ): - logger.error("Cannot invoke event before recipe, model, and start are set") - raise ValueError( - "Cannot invoke event before recipe, model, and start are set" - ) - - if not self.state.compression_ready: - logger.error("Cannot invoke event before recipe, model, and start are set") - raise ValueError( - "Cannot invoke event before recipe, model, and start are set" - ) + def _validate_event_order(self, event_type: EventType) -> bool: + if event_type not in self._event_order: + # for unhandled events, do not save last event + return True - logger.debug("Setting up event lifecycle for event type: {}", event_type) - - for mod in self.modifiers: - logger.debug("Checking if modifier is initialized: {}", mod) - mod.check_initialized() - - # first check for creation of a callbacks event lifecycle - # must start with BATCH_START event if event_type == EventType.BATCH_START: - self.event_lifecycle = CallbacksEventLifecycle( - type_first=EventType.BATCH_START, start=self.state.start_event - ) - elif ( - event_type == EventType.LOSS_CALCULATED - or event_type == EventType.OPTIM_PRE_STEP - ): - self.event_lifecycle = OptimizerEventLifecycle( - type_first=event_type, start=self.state.start_event - ) + valid = self._last_event_type != EventType.BATCH_START + else: - logger.error( - "Invalid event type for initializing event lifecycle: " - "{}. Must be BATCH_START, LOSS_CALCULATED, or OPTIM_PRE_STEP", - event_type, - ) - raise ValueError( - f"Invalid event type for initializing event lifecycle: " - f"{event_type}. 
Must be BATCH_START, LOSS_CALCULATED, or OPTIM_PRE_STEP" - ) + last_event_index = self._event_order.index(self._last_event_type) + curr_event_index = self._event_order.index(event_type) + valid = last_event_index <= curr_event_index - logger.info( - "Event lifecycle for compression lifecycle created: " - "{} with start event type: {}", - self.event_lifecycle, - event_type, - ) + if valid: + self._last_event_type = event_type + return valid def _set_model_layer_prefix(self): compiled_recipe = self.recipe_container.compiled_recipe diff --git a/src/llmcompressor/core/session.py b/src/llmcompressor/core/session.py index f028510bc..f94c51cee 100644 --- a/src/llmcompressor/core/session.py +++ b/src/llmcompressor/core/session.py @@ -223,57 +223,34 @@ def get_serialized_recipe(self) -> Optional[str]: def _log_model_info(self): # Log model level logs if cadence reached - event_lifecycle = self._lifecycle.event_lifecycle - if event_lifecycle is None: - # event lifecycle not available - # when recipe is not provided - return - - epoch = event_lifecycle.current_index + current_index = self._lifecycle.global_step if ( should_log_model_info( model=self.state.model, loggers=self.state.loggers, - current_log_step=epoch, + current_log_step=current_index, last_log_step=self.state._last_log_step, ) and self.state.loggers.frequency_manager.is_epoch_frequency_manager ): log_model_info( state=self.state, - current_log_step=epoch, + current_log_step=current_index, ) # update last log epoch - self.state.loggers.log_written(epoch) + self.state.loggers.log_written(current_index) def _log_loss(self, event_type: EventType, loss: Any): if event_type != EventType.LOSS_CALCULATED: # only log loss when loss is calculated return - event_lifecycle = self._lifecycle.event_lifecycle - if event_lifecycle is None: - # event lifecycle not available - # when recipe is not provided - return - - epoch = event_lifecycle.current_index - if self.state.loggers.frequency_manager.is_optim_frequency_manager: - # log integer step for optimizer frequency manager - current_step = int( - self.state.loggers.epoch_to_step( - epoch=epoch, - steps_per_epoch=len(self.state.data.train), - ) - ) - else: - # log float epoch for epoch frequency manager - current_step = epoch + current_index = self._lifecycle.global_step # always log loss if available if loss is not None: loss = loss if isinstance(loss, dict) else {"loss": loss} self.state.loggers.metric.log_scalars( - tag="Loss", values=loss, step=current_step + tag="Loss", values=loss, step=current_index ) diff --git a/src/llmcompressor/core/state.py b/src/llmcompressor/core/state.py index 23b150284..42e43e134 100644 --- a/src/llmcompressor/core/state.py +++ b/src/llmcompressor/core/state.py @@ -11,7 +11,6 @@ from loguru import logger -from llmcompressor.core.events import Event from llmcompressor.metrics import BaseLogger, LoggerManager __all__ = ["State", "Data", "Hardware", "ModifiedState"] @@ -94,10 +93,6 @@ class State: :type data: Data :param hardware: Hardware instance holding info about the target hardware being used :type hardware: Hardware - :param start_event: The start event to begin compression - :type start_event: Event - :param last_event: The last compression event that occurred - :type last_event: Event :param loggers: LoggerManager instance holding all the loggers to log :type loggers: Optional[LoggerManager] :param model_log_cadence: The cadence to log model information w.r.t epochs. 
@@ -113,8 +108,6 @@ class State: batch_data: Any = None data: Data = field(default_factory=Data) hardware: Hardware = field(default_factory=Hardware) - start_event: Optional[Event] = None - last_event: Optional[Event] = None loggers: Optional[LoggerManager] = None model_log_cadence: Optional[float] = None _last_log_step: Union[float, int, None] = None @@ -226,20 +219,6 @@ def update( if "device" in kwargs: self.hardware.device = kwargs["device"] - if ( - start is not None - or steps_per_epoch is not None - or batches_per_step is not None - ): - if self.start_event is None: - self.start_event = Event() - if start is not None: - self.start_event.current_index = start - if steps_per_epoch is not None: - self.start_event.steps_per_epoch = steps_per_epoch - if batches_per_step is not None: - self.start_event.batches_per_step = batches_per_step - loggers = loggers or [] if isinstance(loggers, list): loggers = LoggerManager(loggers) diff --git a/src/llmcompressor/modifiers/interface.py b/src/llmcompressor/modifiers/interface.py index f1c73c54b..da54dce47 100644 --- a/src/llmcompressor/modifiers/interface.py +++ b/src/llmcompressor/modifiers/interface.py @@ -27,14 +27,6 @@ def finalized(self) -> bool: """ raise NotImplementedError() - @abstractmethod - def check_initialized(self): - """ - Check if the modifier has been initialized and - raise an error if not - """ - raise NotImplementedError() - @abstractmethod def calculate_start(self) -> float: """ diff --git a/src/llmcompressor/modifiers/modifier.py b/src/llmcompressor/modifiers/modifier.py index 4092cc3de..38911b590 100644 --- a/src/llmcompressor/modifiers/modifier.py +++ b/src/llmcompressor/modifiers/modifier.py @@ -15,6 +15,13 @@ class Modifier(ModifierInterface, HooksMixin): Modifiers are used to modify the training process for a model. Defines base attributes and methods available to all modifiers + Lifecycle: + 1. initialize + 2. on_event -> + * on_start if self.start <= event.current_index + * on_end if self.end >= event.current_index + 5. finalize + :param index: The index of the modifier in the list of modifiers for the model :param group: The group name for the modifier @@ -48,13 +55,6 @@ def finalized(self) -> bool: """ return self.finalized_ - def check_initialized(self): - """ - :raises RuntimeError: if the modifier has not been initialized - """ - if not self.initialized_: - raise RuntimeError("modifier has not been initialized") - def calculate_start(self) -> float: """ Calculate and return the start epoch for the modifier. 
@@ -78,37 +78,21 @@ def initialize(self, state: State, **kwargs): :param kwargs: Additional arguments for initializing the modifier """ if self.initialized_: - return + raise RuntimeError( + "Cannot initialize a modifier that has already been initialized" + ) if self.finalized_: - raise RuntimeError("cannot initialize a finalized modifier") - - if state.start_event is None: - return - - # ignore modifier structure initialized from one-shot - if state.start_event.current_index >= 0 and self.calculate_start() < 0: - return - - # if modifier should have ended by current index, don't initialize - if ( - self.calculate_end() >= 0 - and state.start_event.current_index >= self.calculate_end() - ): - return - - initialized = self.on_initialize(state=state, **kwargs) - - if not isinstance(initialized, bool): - raise ValueError( - "on_initialize must return a boolean value; " - "True for success, False for not initialized" + raise RuntimeError( + "Cannot initialize a modifier that has already been finalized" ) - self.initialized_ = initialized + self.initialized_ = self.on_initialize(state=state, **kwargs) - if self.should_start(state.start_event): - self.on_start(state, state.start_event, **kwargs) + # trigger start + fake_start_event = Event(type_=EventType.BATCH_START, global_step=0) + if self.should_start(fake_start_event): + self.on_start(state, fake_start_event, **kwargs) self.started_ = True def finalize(self, state: State, **kwargs): @@ -125,15 +109,8 @@ def finalize(self, state: State, **kwargs): if not self.initialized_: raise RuntimeError("cannot finalize an uninitialized modifier") - finalized = self.on_finalize(state=state, **kwargs) - - if not isinstance(finalized, bool): - raise ValueError( - "on_finalize must return a boolean value; " - "True for success, False for not finalized" - ) - - self.finalized_ = finalized + # TODO: all finalization should succeed + self.finalized_ = self.on_finalize(state=state, **kwargs) def update_event(self, state: State, event: Event, **kwargs): """ @@ -148,10 +125,10 @@ def update_event(self, state: State, event: Event, **kwargs): :param kwargs: Additional arguments for updating the modifier """ if not self.initialized_: - return + raise RuntimeError("Cannot update an uninitialized modifier") if self.finalized_: - raise RuntimeError("cannot update a finalized modifier") + raise RuntimeError("Cannot update a finalized modifier") self.on_event(state, event, **kwargs) diff --git a/src/llmcompressor/modifiers/pruning/constant/base.py b/src/llmcompressor/modifiers/pruning/constant/base.py index 9ac07168c..929ee5a5d 100644 --- a/src/llmcompressor/modifiers/pruning/constant/base.py +++ b/src/llmcompressor/modifiers/pruning/constant/base.py @@ -26,7 +26,7 @@ def on_initialize(self, state: State, **kwargs) -> bool: if "use_hooks" in kwargs: self._use_hooks = kwargs["use_hooks"] - if not state.model or not state.start_event: + if not state.model: return False self.parameterized_layers_ = get_layers_params(self.targets, state.model) diff --git a/src/llmcompressor/modifiers/pruning/magnitude/base.py b/src/llmcompressor/modifiers/pruning/magnitude/base.py index fb0fa1817..1a218d0e3 100644 --- a/src/llmcompressor/modifiers/pruning/magnitude/base.py +++ b/src/llmcompressor/modifiers/pruning/magnitude/base.py @@ -55,7 +55,7 @@ def on_initialize(self, state: State, **kwargs) -> bool: if "use_hooks" in kwargs: self._use_hooks = kwargs["use_hooks"] - if not state.model or not state.start_event: + if not state.model: return False self.scheduler_function_ = 
PruningSchedulerFactory.create_scheduler( diff --git a/src/llmcompressor/modifiers/stage.py b/src/llmcompressor/modifiers/stage.py index fe773bcb5..75a11ffc5 100644 --- a/src/llmcompressor/modifiers/stage.py +++ b/src/llmcompressor/modifiers/stage.py @@ -1,6 +1,5 @@ from typing import List, Optional -from loguru import logger from pydantic import BaseModel, Field from llmcompressor.core.events import Event @@ -50,24 +49,6 @@ def unique_id(self) -> str: """ return self.group + "_" + str(self.index) - def check_initialized(self): - """ - Check if all of the stage modifiers have been initialized, and log a warning - if not. This warning is expected when loading an input recipe during finetuning - """ - - at_least_one_initialized = False - for modifier in self.modifiers: - if modifier.initialized: - at_least_one_initialized = True - if not at_least_one_initialized: - modifier_names = [type(mod).__name__ for mod in self.modifiers] - logger.warning( - f"Found no initialized modifiers in stage {self.group}. " - "Found the following uninitialized modifiers: " - f"{modifier_names}" - ) - def calculate_start(self) -> float: """ :return: The minimum start time of all the stage modifiers diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py index 20d9ae510..67eac59b4 100644 --- a/src/llmcompressor/transformers/finetune/session_mixin.py +++ b/src/llmcompressor/transformers/finetune/session_mixin.py @@ -258,7 +258,7 @@ def training_step( """ self._check_super_defined("training_step") - callbacks.batch_start(batch_data=inputs) + callbacks.batch_start(batch_data=inputs, global_step=self.state.epoch) model_outputs = super().training_step( model=model, inputs=inputs, num_items_in_batch=num_items_in_batch ) diff --git a/tests/llmcompressor/modifiers/conf.py b/tests/llmcompressor/modifiers/conf.py index 0a910788c..19050e0c0 100644 --- a/tests/llmcompressor/modifiers/conf.py +++ b/tests/llmcompressor/modifiers/conf.py @@ -2,9 +2,7 @@ from torch.utils.data import DataLoader -from llmcompressor.core import State -from llmcompressor.core.events import EventType -from llmcompressor.core.lifecycle import CallbacksEventLifecycle +from llmcompressor.core import Event, EventType, State from llmcompressor.modifiers.factory import ModifierFactory @@ -31,14 +29,9 @@ def __init__( calib_data=DataLoader(MagicMock(__len__=lambda _: 0, column_names=[])), ) - self.event_lifecycle = CallbacksEventLifecycle( - type_first=EventType.BATCH_START, start=self.state.start_event - ) - def update_modifier(self, modifier, event_type): - events = self.event_lifecycle.events_from_type(event_type) - for event in events: - modifier.update_event(self.state, event=event) + event = Event(event_type=event_type) + modifier.update_event(self.state, event=event) def get_state(self): return self.state diff --git a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py index 0c5ad534d..ab63a5414 100644 --- a/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/pruning/sparsegpt/test_pytorch.py @@ -53,7 +53,7 @@ def test_successful_layerwise_recipe(self): sparsity=sparsities, block_size=128, targets=targets ) testing_harness = LifecyleTestingHarness(model=LinearNet(), start=-1) - modifier.initialize(testing_harness.get_state()) + modifier.initialize(testing_harness.get_state()) # falls back to basic pipeline model = 
testing_harness.state.model num_hooks = len(modifier._hooks) diff --git a/tests/llmcompressor/transformers/compression/recipes/sparse_24_fp8.yaml b/tests/llmcompressor/transformers/compression/recipes/sparse_24_fp8.yaml index 0d3c8bad5..d00b36085 100644 --- a/tests/llmcompressor/transformers/compression/recipes/sparse_24_fp8.yaml +++ b/tests/llmcompressor/transformers/compression/recipes/sparse_24_fp8.yaml @@ -23,16 +23,4 @@ quant_stage: strategy: token dynamic: true symmetric: true - targets: ["Linear"] - pruning_modifiers: - ConstantPruningModifier: - targets: [ - 're:.*q_proj.weight', - 're:.*k_proj.weight', - 're:.*v_proj.weight', - 're:.*o_proj.weight', - 're:.*gate_proj.weight', - 're:.*up_proj.weight', - 're:.*down_proj.weight', - ] - start: 0 \ No newline at end of file + targets: ["Linear"] \ No newline at end of file diff --git a/tests/unit/core/events/test_event_lifecycle.py b/tests/unit/core/events/test_event_lifecycle.py deleted file mode 100644 index 7a154d60e..000000000 --- a/tests/unit/core/events/test_event_lifecycle.py +++ /dev/null @@ -1,179 +0,0 @@ -from typing import List - -import pytest - -from llmcompressor.core import Event, EventLifecycle, EventType - - -class DummyEventLifecycle(EventLifecycle): - def batch_start_events(self) -> List[Event]: - return [Event(type_=EventType.BATCH_START)] - - def loss_calculated_events(self) -> List[Event]: - return [Event(type_=EventType.LOSS_CALCULATED)] - - def optim_pre_step_events(self) -> List[Event]: - return [Event(type_=EventType.OPTIM_PRE_STEP)] - - def optim_post_step_events(self) -> List[Event]: - return [Event(type_=EventType.OPTIM_POST_STEP)] - - def batch_end_events(self) -> List[Event]: - return [Event(type_=EventType.BATCH_END)] - - -@pytest.mark.smoke -def test_event_lifecycle_initialization(): - lifecycle = DummyEventLifecycle( - type_first=EventType.BATCH_START, - start=Event( - steps_per_epoch=10, batches_per_step=2, global_step=0, global_batch=0 - ), - ) - - assert lifecycle.type_first == EventType.BATCH_START - assert lifecycle.last_type is None - assert lifecycle.steps_per_epoch == 10 - assert lifecycle.batches_per_step == 2 - assert lifecycle.invocations_per_step == 1 - assert lifecycle.global_step == 0 - assert lifecycle.global_batch == 0 - - -@pytest.mark.smoke -def test_check_step_batches_count(): - lifecycle = DummyEventLifecycle( - type_first=EventType.BATCH_START, - start=Event( - steps_per_epoch=10, batches_per_step=2, global_step=0, global_batch=0 - ), - ) - - assert lifecycle.check_batches_per_step_count(increment=True) is False - assert lifecycle.global_batch == 1 - - assert lifecycle.check_batches_per_step_count(increment=False) is True - assert lifecycle.global_batch == 1 - - assert lifecycle.check_batches_per_step_count(increment=True) is True - assert lifecycle.global_batch == 2 - - -@pytest.mark.smoke -def test_check_default_step_invocations_count(): - lifecycle = DummyEventLifecycle( - type_first=EventType.BATCH_START, - start=Event( - steps_per_epoch=10, batches_per_step=2, global_step=0, global_batch=0 - ), - ) - - assert lifecycle.check_invocations_per_step_count(increment=True) is True - assert lifecycle.global_step == 1 - - assert lifecycle.check_invocations_per_step_count(increment=False) is True - assert lifecycle.global_step == 1 - - assert lifecycle.check_invocations_per_step_count(increment=True) is True - assert lifecycle.global_step == 2 - - -@pytest.mark.smoke -def test_events_from_type(): - lifecycle = DummyEventLifecycle( - type_first=EventType.BATCH_START, - start=Event( - 
steps_per_epoch=10, batches_per_step=2, global_step=0, global_batch=0 - ), - ) - - events = lifecycle.events_from_type(EventType.BATCH_START) - assert len(events) == 1 - assert events[0].type_ == EventType.BATCH_START - - events = lifecycle.events_from_type(EventType.LOSS_CALCULATED) - assert len(events) == 1 - assert events[0].type_ == EventType.LOSS_CALCULATED - - events = lifecycle.events_from_type(EventType.OPTIM_PRE_STEP) - assert len(events) == 1 - assert events[0].type_ == EventType.OPTIM_PRE_STEP - - events = lifecycle.events_from_type(EventType.OPTIM_POST_STEP) - assert len(events) == 1 - assert events[0].type_ == EventType.OPTIM_POST_STEP - - events = lifecycle.events_from_type(EventType.BATCH_END) - assert len(events) == 1 - assert events[0].type_ == EventType.BATCH_END - - -@pytest.mark.regression -def test_check_step_invocations_count(): - lifecycle = DummyEventLifecycle( - type_first=EventType.BATCH_START, - start=Event( - steps_per_epoch=10, - batches_per_step=2, - global_step=0, - global_batch=0, - invocations_per_step=2, - ), - ) - - assert lifecycle.check_invocations_per_step_count(increment=True) is False - assert lifecycle.global_step == 1 - - assert lifecycle.check_invocations_per_step_count(increment=False) is True - assert lifecycle.global_step == 1 - - assert lifecycle.check_invocations_per_step_count(increment=True) is True - assert lifecycle.global_step == 2 - - -@pytest.mark.regression -def test_not_implemented_errors(): - with pytest.raises(TypeError): - - class TestIncompleteEventLifecycle(EventLifecycle): - pass - - TestIncompleteEventLifecycle(None, None) - - class IncompleteEventLifecycle(EventLifecycle): - def batch_start_events(self) -> List[Event]: - return super().batch_start_events() - - def loss_calculated_events(self) -> List[Event]: - return super().loss_calculated_events() - - def optim_pre_step_events(self) -> List[Event]: - return super().optim_pre_step_events() - - def optim_post_step_events(self) -> List[Event]: - return super().optim_post_step_events() - - def batch_end_events(self) -> List[Event]: - return super().batch_end_events() - - start_event = Event( - steps_per_epoch=10, batches_per_step=2, global_step=0, global_batch=0 - ) - lifecycle = IncompleteEventLifecycle( - type_first=EventType.BATCH_START, start=start_event - ) - - with pytest.raises(NotImplementedError): - lifecycle.batch_start_events() - - with pytest.raises(NotImplementedError): - lifecycle.loss_calculated_events() - - with pytest.raises(NotImplementedError): - lifecycle.optim_pre_step_events() - - with pytest.raises(NotImplementedError): - lifecycle.optim_post_step_events() - - with pytest.raises(NotImplementedError): - lifecycle.batch_end_events() diff --git a/tests/unit/core/events/test_lifecycle_callbacks.py b/tests/unit/core/events/test_lifecycle_callbacks.py deleted file mode 100644 index 884414015..000000000 --- a/tests/unit/core/events/test_lifecycle_callbacks.py +++ /dev/null @@ -1,201 +0,0 @@ -import pytest - -from llmcompressor.core.events.event import Event, EventType -from llmcompressor.core.events.lifecycle_callbacks import CallbacksEventLifecycle - - -@pytest.mark.smoke -def test_initialization(): - start_event = Event( - steps_per_epoch=10, batches_per_step=1, global_step=0, global_batch=0 - ) - lifecycle = CallbacksEventLifecycle( - type_first=EventType.BATCH_START, start=start_event - ) - - assert lifecycle.type_first == EventType.BATCH_START - assert lifecycle.steps_per_epoch == 10 - assert lifecycle.batches_per_step == 1 - assert lifecycle.global_step == 
0 - assert lifecycle.global_batch == 0 - - -@pytest.mark.regression -def test_lifecycle(): - start_event = Event( - steps_per_epoch=10, batches_per_step=1, global_step=0, global_batch=0 - ) - lifecycle = CallbacksEventLifecycle( - type_first=EventType.BATCH_START, start=start_event - ) - - events = lifecycle.batch_start_events() - assert events[0].type_ == EventType.BATCH_START - - events = lifecycle.loss_calculated_events() - assert events[0].type_ == EventType.LOSS_CALCULATED - - events = lifecycle.optim_pre_step_events() - assert events[0].type_ == EventType.OPTIM_PRE_STEP - - events = lifecycle.optim_post_step_events() - assert events[0].type_ == EventType.OPTIM_POST_STEP - - events = lifecycle.batch_end_events() - assert events[0].type_ == EventType.BATCH_END - - -@pytest.mark.sanity -def test_batch_start_events(): - start_event = Event( - steps_per_epoch=10, batches_per_step=1, global_step=0, global_batch=0 - ) - lifecycle = CallbacksEventLifecycle( - type_first=EventType.BATCH_START, start=start_event - ) - - events = lifecycle.batch_start_events() - assert len(events) == 1 - assert events[0].type_ == EventType.BATCH_START - - with pytest.raises(ValueError): - lifecycle.batch_start_events() - - -@pytest.mark.sanity -def test_loss_calculated_events(): - start_event = Event( - steps_per_epoch=10, batches_per_step=1, global_step=0, global_batch=0 - ) - lifecycle = CallbacksEventLifecycle( - type_first=EventType.BATCH_START, start=start_event - ) - - lifecycle.batch_start_events() - events = lifecycle.loss_calculated_events() - assert len(events) == 1 - assert events[0].type_ == EventType.LOSS_CALCULATED - - with pytest.raises(ValueError): - lifecycle.loss_calculated_events() - - -@pytest.mark.sanity -def test_optim_pre_step_events(): - start_event = Event( - steps_per_epoch=10, batches_per_step=1, global_step=0, global_batch=0 - ) - lifecycle = CallbacksEventLifecycle( - type_first=EventType.BATCH_START, start=start_event - ) - - lifecycle.batch_start_events() - lifecycle.loss_calculated_events() - events = lifecycle.optim_pre_step_events() - assert len(events) == 1 - assert events[0].type_ == EventType.OPTIM_PRE_STEP - - with pytest.raises(ValueError): - lifecycle.optim_pre_step_events() - - -@pytest.mark.sanity -def test_optim_post_step_events(): - start_event = Event( - steps_per_epoch=10, batches_per_step=1, global_step=0, global_batch=0 - ) - lifecycle = CallbacksEventLifecycle( - type_first=EventType.BATCH_START, start=start_event - ) - - lifecycle.batch_start_events() - lifecycle.loss_calculated_events() - lifecycle.optim_pre_step_events() - events = lifecycle.optim_post_step_events() - assert len(events) == 1 - assert events[0].type_ == EventType.OPTIM_POST_STEP - - with pytest.raises(ValueError): - lifecycle.optim_post_step_events() - - -@pytest.mark.sanity -def test_batch_end_events(): - start_event = Event( - steps_per_epoch=10, batches_per_step=1, global_step=0, global_batch=0 - ) - lifecycle = CallbacksEventLifecycle( - type_first=EventType.BATCH_START, start=start_event - ) - - lifecycle.batch_start_events() - lifecycle.loss_calculated_events() - lifecycle.optim_pre_step_events() - lifecycle.optim_post_step_events() - events = lifecycle.batch_end_events() - assert len(events) == 1 - assert events[0].type_ == EventType.BATCH_END - - with pytest.raises(ValueError): - lifecycle.batch_end_events() - - -@pytest.mark.regression -def test_lifecycle_gradient_accumulation(): - start_event = Event( - steps_per_epoch=10, batches_per_step=2, global_step=0, global_batch=0 - ) - 
lifecycle = CallbacksEventLifecycle( - type_first=EventType.BATCH_START, start=start_event - ) - - for index in range(2): - events = lifecycle.batch_start_events() - assert events[0].type_ == EventType.BATCH_START - - events = lifecycle.loss_calculated_events() - assert events[0].type_ == EventType.LOSS_CALCULATED - - if index == 0: - events = lifecycle.batch_end_events() - assert events[0].type_ == EventType.BATCH_END - - events = lifecycle.optim_pre_step_events() - assert events[0].type_ == EventType.OPTIM_PRE_STEP - - events = lifecycle.optim_post_step_events() - assert events[0].type_ == EventType.OPTIM_POST_STEP - - events = lifecycle.batch_end_events() - assert events[0].type_ == EventType.BATCH_END - - -@pytest.mark.regression -def test_invalid_event_order(): - start_event = Event( - steps_per_epoch=10, batches_per_step=1, global_step=0, global_batch=0 - ) - lifecycle = CallbacksEventLifecycle( - type_first=EventType.BATCH_START, start=start_event - ) - - with pytest.raises(ValueError): - lifecycle.loss_calculated_events() - - lifecycle.batch_start_events() - - with pytest.raises(ValueError): - lifecycle.optim_pre_step_events() - - lifecycle.loss_calculated_events() - - with pytest.raises(ValueError): - lifecycle.optim_post_step_events() - - lifecycle.optim_pre_step_events() - - with pytest.raises(ValueError): - lifecycle.batch_end_events() - - lifecycle.optim_post_step_events() - lifecycle.batch_end_events() diff --git a/tests/unit/core/events/test_lifecycle_optimizer.py b/tests/unit/core/events/test_lifecycle_optimizer.py deleted file mode 100644 index 23c10867a..000000000 --- a/tests/unit/core/events/test_lifecycle_optimizer.py +++ /dev/null @@ -1,186 +0,0 @@ -import pytest - -from llmcompressor.core.events.event import Event, EventType -from llmcompressor.core.events.lifecycle_optimizer import OptimizerEventLifecycle - - -@pytest.mark.smoke -def test_initialization(): - start_event = Event( - steps_per_epoch=10, batches_per_step=1, global_step=0, global_batch=0 - ) - lifecycle = OptimizerEventLifecycle( - type_first=EventType.LOSS_CALCULATED, start=start_event - ) - - assert lifecycle.type_first == EventType.LOSS_CALCULATED - assert lifecycle.steps_per_epoch == 10 - assert lifecycle.batches_per_step == 1 - assert lifecycle.global_step == 0 - assert lifecycle.global_batch == 0 - - -@pytest.mark.smoke -def test_lifecycle(): - start_event = Event( - steps_per_epoch=10, batches_per_step=1, global_step=0, global_batch=0 - ) - lifecycle = OptimizerEventLifecycle( - type_first=EventType.OPTIM_PRE_STEP, start=start_event - ) - - events = lifecycle.optim_pre_step_events() - assert events[0].type_ == EventType.BATCH_START - assert events[1].type_ == EventType.OPTIM_PRE_STEP - - events = lifecycle.optim_post_step_events() - assert events[0].type_ == EventType.OPTIM_POST_STEP - assert events[1].type_ == EventType.BATCH_END - - -@pytest.mark.smoke -def test_lifecycle_with_loss(): - start_event = Event( - steps_per_epoch=10, batches_per_step=1, global_step=0, global_batch=0 - ) - lifecycle = OptimizerEventLifecycle( - type_first=EventType.LOSS_CALCULATED, start=start_event - ) - - events = lifecycle.loss_calculated_events() - assert events[0].type_ == EventType.BATCH_START - assert events[1].type_ == EventType.LOSS_CALCULATED - - events = lifecycle.optim_pre_step_events() - assert events[0].type_ == EventType.OPTIM_PRE_STEP - - events = lifecycle.optim_post_step_events() - assert events[0].type_ == EventType.OPTIM_POST_STEP - assert events[1].type_ == EventType.BATCH_END - - 
-@pytest.mark.sanity -def test_loss_calculated_events(): - start_event = Event( - steps_per_epoch=10, batches_per_step=1, global_step=0, global_batch=0 - ) - lifecycle = OptimizerEventLifecycle( - type_first=EventType.LOSS_CALCULATED, start=start_event - ) - - events = lifecycle.loss_calculated_events() - assert len(events) == 2 - assert events[0].type_ == EventType.BATCH_START - assert events[1].type_ == EventType.LOSS_CALCULATED - - -@pytest.mark.sanity -def test_optim_pre_step_events(): - start_event = Event( - steps_per_epoch=10, batches_per_step=1, global_step=0, global_batch=0 - ) - lifecycle = OptimizerEventLifecycle( - type_first=EventType.LOSS_CALCULATED, start=start_event - ) - - lifecycle.loss_calculated_events() - events = lifecycle.optim_pre_step_events() - assert len(events) == 1 - assert events[0].type_ == EventType.OPTIM_PRE_STEP - - with pytest.raises(ValueError): - lifecycle.optim_pre_step_events() - - -@pytest.mark.sanity -def test_optim_post_step_events(): - start_event = Event( - steps_per_epoch=10, batches_per_step=1, global_step=0, global_batch=0 - ) - lifecycle = OptimizerEventLifecycle( - type_first=EventType.LOSS_CALCULATED, start=start_event - ) - - lifecycle.loss_calculated_events() - lifecycle.optim_pre_step_events() - events = lifecycle.optim_post_step_events() - assert len(events) == 2 - assert events[0].type_ == EventType.OPTIM_POST_STEP - assert events[1].type_ == EventType.BATCH_END - - with pytest.raises(ValueError): - lifecycle.optim_post_step_events() - - -@pytest.mark.regression -def test_lifecycle_gradient_accumulation(): - start_event = Event( - steps_per_epoch=10, batches_per_step=2, global_step=0, global_batch=0 - ) - lifecycle = OptimizerEventLifecycle( - type_first=EventType.OPTIM_PRE_STEP, start=start_event - ) - - events = lifecycle.optim_pre_step_events() - assert events[0].type_ == EventType.BATCH_START - assert events[1].type_ == EventType.BATCH_END - assert events[2].type_ == EventType.BATCH_START - assert events[3].type_ == EventType.OPTIM_PRE_STEP - - events = lifecycle.optim_post_step_events() - assert events[0].type_ == EventType.OPTIM_POST_STEP - assert events[1].type_ == EventType.BATCH_END - - -@pytest.mark.regression -def test_lifecycle_gradient_accumulation_loss_calculated(): - start_event = Event( - steps_per_epoch=10, batches_per_step=2, global_step=0, global_batch=0 - ) - lifecycle = OptimizerEventLifecycle( - type_first=EventType.LOSS_CALCULATED, start=start_event - ) - - for index in range(2): - events = lifecycle.loss_calculated_events() - - assert events[0].type_ == EventType.BATCH_START - assert events[1].type_ == EventType.LOSS_CALCULATED - - if index == 0: - assert events[2].type_ == EventType.BATCH_END - - events = lifecycle.optim_pre_step_events() - assert events[0].type_ == EventType.OPTIM_PRE_STEP - - events = lifecycle.optim_post_step_events() - assert events[0].type_ == EventType.OPTIM_POST_STEP - assert events[1].type_ == EventType.BATCH_END - - -@pytest.mark.regression -def test_invalid_event_order(): - start_event = Event( - steps_per_epoch=10, batches_per_step=1, global_step=0, global_batch=0 - ) - lifecycle = OptimizerEventLifecycle( - type_first=EventType.LOSS_CALCULATED, start=start_event - ) - - with pytest.raises(ValueError): - lifecycle.optim_pre_step_events() - - events = lifecycle.loss_calculated_events() - assert events - - with pytest.raises(ValueError): - lifecycle.optim_post_step_events() - - events = lifecycle.optim_pre_step_events() - assert events - - with pytest.raises(ValueError): - 
lifecycle.batch_end_events() - - events = lifecycle.optim_post_step_events() - assert events diff --git a/tests/unit/core/test_state.py b/tests/unit/core/test_state.py index 3f7f992dc..203e2a664 100644 --- a/tests/unit/core/test_state.py +++ b/tests/unit/core/test_state.py @@ -15,8 +15,6 @@ def test_state_initialization(): assert state.batch_data is None assert state.data == Data() assert state.hardware == Hardware() - assert state.start_event is None - assert state.last_event is None assert state.loggers is None assert state.model_log_cadence is None assert state._last_log_step is None @@ -62,8 +60,6 @@ def test_state_update(): assert state.data.test == "new_test_data" assert state.data.calib == "new_calib_data" assert state.hardware.device == "cpu" - assert state.start_event.current_index == 1.0 - assert state.start_event.batches_per_step == 10 assert state.model_log_cadence == 2
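
For reference, below is a minimal standalone sketch of the event-order check that this patch adds to `CompressionLifecycle._validate_event_order`, replacing the removed `EventLifecycle` classes. The `OrderValidator` class and the inlined `EventType` enum are illustrative stand-ins (the real enum lives in `llmcompressor.core.events`); this is not part of the diff itself.

```python
from enum import Enum
from typing import List, Optional


class EventType(Enum):
    # stand-in for llmcompressor.core.events.EventType
    BATCH_START = "batch_start"
    LOSS_CALCULATED = "loss_calculated"
    OPTIM_PRE_STEP = "optim_pre_step"
    OPTIM_POST_STEP = "optim_post_step"
    BATCH_END = "batch_end"


class OrderValidator:
    """Mirrors the _validate_event_order logic added in the diff above."""

    def __init__(self) -> None:
        # BATCH_END as the initial value lets BATCH_START be the first real event
        self._last_event_type: Optional[EventType] = EventType.BATCH_END
        self._event_order: List[EventType] = [
            EventType.BATCH_START,
            EventType.LOSS_CALCULATED,
            EventType.OPTIM_PRE_STEP,
            EventType.OPTIM_POST_STEP,
            EventType.BATCH_END,
        ]

    def validate(self, event_type: EventType) -> bool:
        if event_type not in self._event_order:
            # unhandled event types pass through without updating state
            return True

        if event_type == EventType.BATCH_START:
            # a new batch may start after anything except another BATCH_START
            valid = self._last_event_type != EventType.BATCH_START
        else:
            # within a batch, events may only repeat or move forward in order
            last_index = self._event_order.index(self._last_event_type)
            curr_index = self._event_order.index(event_type)
            valid = last_index <= curr_index

        if valid:
            self._last_event_type = event_type
        return valid


validator = OrderValidator()

# a typical training step in order; under gradient accumulation the two
# OPTIM_* events may simply be skipped and the order still validates
for event in (
    EventType.BATCH_START,
    EventType.LOSS_CALCULATED,
    EventType.OPTIM_PRE_STEP,
    EventType.OPTIM_POST_STEP,
    EventType.BATCH_END,
):
    assert validator.validate(event)

# going backwards is rejected: after BATCH_END the next in-order event
# must be BATCH_START, not OPTIM_PRE_STEP
assert not validator.validate(EventType.OPTIM_PRE_STEP)
```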