@@ -12,7 +12,7 @@
 import mlx.nn as nn
 import numpy as np
 from mlx.nn.utils import average_gradients
-from mlx.utils import tree_flatten
+from mlx.utils import tree_flatten, tree_map
 from transformers import PreTrainedTokenizer
 
 from .datasets import CacheDataset
@@ -80,11 +80,11 @@ class TrainingArgs:
     )
 
 
-def default_loss(model, batch, lengths, cache):
+def default_loss(model, batch, lengths, cache=None):
     inputs = batch[:, :-1]
     targets = batch[:, 1:]
 
-    offset = cache[0].offset
+    offset = cache[0].offset if cache is not None else 0
     logits = model(inputs, cache=cache)
     logits = logits.astype(mx.float32)
 
@@ -184,6 +184,7 @@ def evaluate(
 
     seq_step_size = seq_step_size or max_seq_length
 
+    cache = make_prompt_cache(model)
    for _, batch in zip(
        index_iterator,
        iterate_batches(
@@ -193,13 +194,14 @@ def evaluate(
            max_seq_length=max_seq_length,
        ),
    ):
-        cache = make_prompt_cache(model)
        seq_length = batch[0].shape[1]
        for s in range(0, seq_length, seq_step_size):
            local_batch = (batch[0][:, s : s + seq_step_size], batch[1])
            losses, toks = loss(model, *local_batch, cache)
            all_losses += losses * toks
            ntokens += toks
+            if s + seq_step_size >= seq_length:
+                reset_prompt_cache(cache)
            mx.eval(all_losses, ntokens)
 
    all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
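In the evaluation loop above, each chunk contributes `losses * toks` to `all_losses` and `toks` to `ntokens`, so dividing the two afterwards (the normalization itself is outside this hunk) yields a token-weighted mean over the full sequence. A minimal sketch of that weighting, with made-up per-chunk numbers:

import numpy as np

# Hypothetical per-chunk mean losses and token counts for one long sequence.
chunk_losses = np.array([2.31, 2.07, 1.95])
chunk_tokens = np.array([512, 512, 276])

all_losses = float(np.sum(chunk_losses * chunk_tokens))  # mirrors `all_losses += losses * toks`
ntokens = int(np.sum(chunk_tokens))                       # mirrors `ntokens += toks`
print(all_losses / ntokens)  # token-weighted mean loss over the full sequence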
@@ -241,13 +243,14 @@ def train(
    if args.grad_checkpoint:
        grad_checkpoint(model.layers[0])
 
+    seq_step_size = args.seq_step_size or args.max_seq_length
    cache = make_prompt_cache(model)
    state = [model.state, optimizer.state, mx.random.state]
 
    @partial(mx.compile, inputs=state, outputs=state)
    def step(batch):
        # Forward and backward pass
-        (lvalue, toks), grad = loss_value_and_grad(model, *batch, cache)
+        (lvalue, toks), grad = loss_value_and_grad(model, *batch)
 
        # All reduce the gradients if running in distributed mode
        grad = average_gradients(grad)
@@ -264,6 +267,38 @@ def step(batch):
 
    model.train()
    seq_step_size = args.seq_step_size or args.max_seq_length
+    def seq_split_step(batch):
+        losses = mx.array(0.0)
+        n_tokens = mx.array(0.0)
+        seq_length = batch[0].shape[1]
+        grad_accum = None
+        for s in range(0, seq_length, seq_step_size):
+            local_batch = (batch[0][:, s : s + seq_step_size], batch[1])
+            (lvalue, toks), grad = loss_value_and_grad(model, *local_batch, cache)
+            prev_n_tokens = n_tokens
+            losses += toks * lvalue
+            n_tokens += toks
+
+            if grad_accum is None:
+                grad_accum = grad
+            else:
+                scale_g = toks / n_tokens
+                scale_acc = prev_n_tokens / n_tokens
+                grad_accum = tree_map(lambda g, acc: scale_g * g + scale_acc * acc, grad, grad_accum)
+
+
+            # Let go of the prompt cache before the last eval
+            if s + seq_step_size >= seq_length:
+                reset_prompt_cache(cache)
+            mx.eval(grad_accum, losses, n_tokens)
+
+        grad_accum = average_gradients(grad_accum)
+        optimizer.update(model, grad_accum)
+        return losses / n_tokens, n_tokens
+
+    loss_value_and_grad = nn.value_and_grad(model, loss)
+
+>>>>>>> 568a8d6 (use gradient accumulation)
    losses = 0
    n_tokens = 0
    steps = 0
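The `grad_accum` update in `seq_split_step` maintains a running token-weighted average: the incoming gradient is scaled by `toks / n_tokens` and the previous accumulator by `prev_n_tokens / n_tokens`. A minimal NumPy check of that identity (illustrative only, with hypothetical chunk token counts and stand-in gradient vectors):

import numpy as np

toks_per_chunk = [512, 512, 37]                        # hypothetical token counts per chunk
grads = [np.random.randn(4) for _ in toks_per_chunk]   # stand-ins for per-chunk gradients

grad_accum, n_tokens = None, 0.0
for t, g in zip(toks_per_chunk, grads):
    prev_n_tokens = n_tokens
    n_tokens += t
    if grad_accum is None:
        grad_accum = g
    else:
        # Same rescaling as the tree_map update in the diff above.
        grad_accum = (t / n_tokens) * g + (prev_n_tokens / n_tokens) * grad_accum

# Equivalent closed form: a token-weighted mean of the per-chunk gradients.
expected = sum(t * g for t, g in zip(toks_per_chunk, grads)) / sum(toks_per_chunk)
assert np.allclose(grad_accum, expected)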
@@ -317,15 +352,16 @@ def step(batch):
 
        tic = time.perf_counter()
 
-        seq_length = batch[0].shape[1]
-        for s in range(0, seq_length, seq_step_size):
-            local_batch = (batch[0][:, s : s + seq_step_size], batch[1])
-            lvalue, toks = step(local_batch)
-            losses += lvalue
-            n_tokens += toks
-            steps += 1
-            mx.eval(state, losses, n_tokens)
-        reset_prompt_cache(cache)
+        if batch[0].shape[1] > seq_step_size:
+            lvalue, toks = seq_split_step(batch)
+        else:
+            lvalue, toks = step(batch)
+
+        losses += lvalue
+        n_tokens += toks
+        steps += 1
+        mx.eval(state, losses, n_tokens)
+
        train_time += time.perf_counter() - tic
 
        # Report training loss if needed
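The dispatch above only calls `seq_split_step` when a batch's sequence length exceeds `seq_step_size`; inside it, the sequence axis is walked in fixed `seq_step_size`-wide slices, with only the final slice possibly shorter. A small illustration of that slicing, using hypothetical shapes:

import numpy as np

seq_step_size = 512
batch_tokens = np.zeros((2, 1300), dtype=np.int32)  # hypothetical (batch, sequence) token ids

# Same slicing pattern as the chunk loops in the diff above.
chunks = [
    batch_tokens[:, s : s + seq_step_size]
    for s in range(0, batch_tokens.shape[1], seq_step_size)
]
print([c.shape for c in chunks])  # [(2, 512), (2, 512), (2, 276)]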