Skip to content

Commit f722ab4

Browse files
committed
calculate total_steps_per_epoch earlier, remove compression_ready
Signed-off-by: Kyle Sayers <[email protected]>
1 parent f15df69 commit f722ab4

File tree

4 files changed

+20
-51
lines changed

4 files changed

+20
-51
lines changed

src/llmcompressor/core/lifecycle.py

-6
Original file line numberDiff line numberDiff line change
@@ -218,12 +218,6 @@ def _check_setup_event_lifecycle(self, event_type: EventType):
218218
"Cannot invoke event before recipe, model, and start are set"
219219
)
220220

221-
if not self.state.compression_ready:
222-
logger.error("Cannot invoke event before recipe, model, and start are set")
223-
raise ValueError(
224-
"Cannot invoke event before recipe, model, and start are set"
225-
)
226-
227221
logger.debug("Setting up event lifecycle for event type: {}", event_type)
228222

229223
for mod in self.modifiers:

src/llmcompressor/core/state.py

-12
Original file line numberDiff line numberDiff line change
@@ -119,18 +119,6 @@ class State:
119119
model_log_cadence: Optional[float] = None
120120
_last_log_step: Union[float, int, None] = None
121121

122-
@property
123-
def compression_ready(self) -> bool:
124-
"""
125-
Check if the model and optimizer are set for compression.
126-
127-
:return: True if model and optimizer are set, False otherwise
128-
:rtype: bool
129-
"""
130-
ready = self.model is not None and self.optimizer is not None
131-
logger.debug("Compression ready: {}", ready)
132-
return ready
133-
134122
def update(
135123
self,
136124
model: Any = None,

src/llmcompressor/transformers/finetune/session_mixin.py

+20-23
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,25 @@ def initialize_session(
149149

150150
train_data = self.get_train_dataloader()
151151

152+
# calculate total_steps_per_epoch
153+
# n_gpu handled internally by dataloader
154+
total_batch_size = (
155+
self.args.per_device_train_batch_size
156+
* self.args.gradient_accumulation_steps
157+
)
158+
if isinstance(self.train_dataset, IterableDataset):
159+
logger.warning(
160+
"Training is being run with a streamed dataset, "
161+
"steps_per_epoch cannot be determined and will default to "
162+
"1. LLM Compressor modifiers utilizing this statistic may not "
163+
"behave as expected. "
164+
)
165+
self.total_steps_per_epoch = 1
166+
else:
167+
self.total_steps_per_epoch = math.ceil(
168+
len(self.train_dataset) / total_batch_size
169+
)
170+
152171
self.accelerator.wait_for_everyone()
153172
with summon_full_params_context(self.model, offload_to_cpu=True):
154173
initialize(
@@ -161,6 +180,7 @@ def initialize_session(
161180
start=epoch,
162181
copy_data=False,
163182
fsdp_active=self.is_fsdp_enabled,
183+
steps_per_epoch=self.total_steps_per_epoch,
164184
metadata=self.metadata,
165185
)
166186
self.accelerator.wait_for_everyone()
@@ -203,29 +223,6 @@ def create_optimizer(self):
203223
self._check_super_defined("create_optimizer")
204224
super().create_optimizer()
205225

206-
# n_gpu handled internally by dataloader
207-
total_batch_size = (
208-
self.args.per_device_train_batch_size
209-
* self.args.gradient_accumulation_steps
210-
)
211-
212-
if isinstance(self.train_dataset, IterableDataset):
213-
logger.warning(
214-
"Training is being run with a streamed dataset, "
215-
"steps_per_epoch cannot be determined and will default to "
216-
"1. LLM Compressor modifiers utilizing this statistic may not "
217-
"behave as expected. "
218-
)
219-
self.total_steps_per_epoch = 1
220-
else:
221-
self.total_steps_per_epoch = math.ceil(
222-
len(self.train_dataset) / total_batch_size
223-
)
224-
225-
active_session().state.update(
226-
optimizer=self.optimizer, steps_per_epoch=self.total_steps_per_epoch
227-
)
228-
229226
return self.optimizer
230227

231228
def create_scheduler(

tests/unit/core/test_state.py

-10
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,6 @@ def test_state_update():
6767
assert state.model_log_cadence == 2
6868

6969

70-
@pytest.mark.regression
71-
def test_state_sparsification_ready():
72-
state = State()
73-
assert not state.compression_ready
74-
75-
state.model = "model"
76-
state.optimizer = "optimizer"
77-
assert state.compression_ready
78-
79-
8070
@pytest.mark.regression
8171
def test_state_update_loggers():
8272
state = State()

0 commit comments

Comments (0)