@@ -143,25 +143,6 @@ def initialize_session(
 
         train_data = self.get_train_dataloader()
 
-        # calculate total_steps_per_epoch
-        # n_gpu handled internally by dataloader
-        total_batch_size = (
-            self.args.per_device_train_batch_size
-            * self.args.gradient_accumulation_steps
-        )
-        if isinstance(self.train_dataset, IterableDataset):
-            logger.warning(
-                "Training is being run with a streamed dataset, "
-                "steps_per_epoch cannot be determined and will default to "
-                "1. LLM Compressor modifiers utilizing this statistic may not "
-                "behave as expected. "
-            )
-            self.total_steps_per_epoch = 1
-        else:
-            self.total_steps_per_epoch = math.ceil(
-                len(self.train_dataset) / total_batch_size
-            )
-
         self.accelerator.wait_for_everyone()
         with summon_full_params_context(self.model, offload_to_cpu=True):
             active_session().initialize(
@@ -175,7 +156,6 @@ def initialize_session(
                 copy_data=False,
                 attach_optim_callbacks=True,
                 fsdp_active=self.is_fsdp_enabled,
-                steps_per_epoch=self.total_steps_per_epoch,
                 metadata=self.metadata,
             )
 
@@ -219,6 +199,29 @@ def create_optimizer(self):
         self._check_super_defined("create_optimizer")
         super().create_optimizer()
 
+        # n_gpu handled internally by dataloader
+        total_batch_size = (
+            self.args.per_device_train_batch_size
+            * self.args.gradient_accumulation_steps
+        )
+
+        if isinstance(self.train_dataset, IterableDataset):
+            logger.warning(
+                "Training is being run with a streamed dataset, "
+                "steps_per_epoch cannot be determined and will default to "
+                "1. LLM Compressor modifiers utilizing this statistic may not "
+                "behave as expected. "
+            )
+            self.total_steps_per_epoch = 1
+        else:
+            self.total_steps_per_epoch = math.ceil(
+                len(self.train_dataset) / total_batch_size
+            )
+
+        active_session().initialize(
+            optimizer=self.optimizer, steps_per_epoch=self.total_steps_per_epoch
+        )
+
         return self.optimizer
 
     def create_scheduler(
@@ -255,7 +258,7 @@ def training_step(
         """
         self._check_super_defined("training_step")
 
-        callbacks.batch_start(batch_data=inputs)
+        callbacks.batch_start(batch_data=inputs, global_step=self.state.epoch)
         model_outputs = super().training_step(
            model=model, inputs=inputs, num_items_in_batch=num_items_in_batch
        )
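
For context, the relocated computation boils down to the following. This is a minimal standalone sketch, not code from the diff; the helper name compute_steps_per_epoch is hypothetical, since the commit inlines this logic directly in create_optimizer:

import math

from torch.utils.data import IterableDataset


def compute_steps_per_epoch(train_dataset, per_device_batch_size, grad_accum_steps):
    # Hypothetical helper for illustration only; the diff inlines this
    # logic in create_optimizer rather than defining a function.
    # n_gpu is handled internally by the dataloader, so the effective
    # batch size is per-device batch size times gradient accumulation.
    total_batch_size = per_device_batch_size * grad_accum_steps
    if isinstance(train_dataset, IterableDataset):
        # Streamed datasets have no usable __len__; fall back to 1,
        # matching the warning path in the diff.
        return 1
    return math.ceil(len(train_dataset) / total_batch_size)

Moving this into create_optimizer means the value is computed once the optimizer exists, so both can be handed to active_session().initialize() together.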