Commit 07be2b5: use gradient accumulation
Parent: d6d5d80

1 file changed (+50, -14)

mlx_lm/tuner/trainer.py

@@ -12,7 +12,7 @@
 import mlx.nn as nn
 import numpy as np
 from mlx.nn.utils import average_gradients
-from mlx.utils import tree_flatten
+from mlx.utils import tree_flatten, tree_map
 from transformers import PreTrainedTokenizer

 from .datasets import CacheDataset
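
tree_map is the only new import here; it is what lets the gradient accumulation further down combine two gradient trees leaf by leaf. A minimal sketch of that pattern, using toy parameter trees rather than anything from the trainer:

    import mlx.core as mx
    from mlx.utils import tree_map

    # Two gradient trees with the same structure (toy values, not from the diff).
    grad = {"w": mx.array([1.0, 2.0]), "b": mx.array([0.0])}
    acc = {"w": mx.array([3.0, 4.0]), "b": mx.array([2.0])}

    # tree_map applies the lambda to matching leaves of both trees,
    # the same shape of update used in seq_split_step below.
    blended = tree_map(lambda g, a: 0.25 * g + 0.75 * a, grad, acc)
    print(blended["w"])  # array([2.5, 3.5], dtype=float32)
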
@@ -80,11 +80,11 @@ class TrainingArgs:
     )


-def default_loss(model, batch, lengths, cache):
+def default_loss(model, batch, lengths, cache=None):
     inputs = batch[:, :-1]
     targets = batch[:, 1:]

-    offset = cache[0].offset
+    offset = cache[0].offset if cache is not None else 0
     logits = model(inputs, cache=cache)
     logits = logits.astype(mx.float32)
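
Only the signature and the offset line of default_loss appear in this hunk; the rest of the function is unchanged and not shown. As a rough sketch of why the offset matters when a prompt cache is threaded through the loss: positions in the current chunk have to be shifted by the number of tokens the cache already holds before any length-based masking. The mask logic below is an assumption for illustration, not code from this commit:

    import mlx.core as mx
    import mlx.nn as nn

    def chunked_loss_sketch(model, batch, lengths, cache=None):
        # Hypothetical illustration of the offset handling above.
        inputs = batch[:, :-1]
        targets = batch[:, 1:]
        offset = cache[0].offset if cache is not None else 0
        logits = model(inputs, cache=cache).astype(mx.float32)

        # Absolute positions of the targets within the full (unchunked) sequence.
        positions = mx.arange(offset + 1, offset + 1 + targets.shape[1])
        mask = positions < lengths[:, None]  # assumed per-sequence length mask
        ce = nn.losses.cross_entropy(logits, targets) * mask
        ntoks = mask.sum()
        return ce.sum() / ntoks, ntoks
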

@@ -184,6 +184,7 @@ def evaluate(

     seq_step_size = seq_step_size or max_seq_length

+    cache = make_prompt_cache(model)
     for _, batch in zip(
         index_iterator,
         iterate_batches(
@@ -193,13 +194,14 @@
             max_seq_length=max_seq_length,
         ),
     ):
-        cache = make_prompt_cache(model)
         seq_length = batch[0].shape[1]
         for s in range(0, seq_length, seq_step_size):
             local_batch = (batch[0][:, s:s+seq_step_size], batch[1])
             losses, toks = loss(model, *local_batch, cache)
             all_losses += losses * toks
             ntokens += toks
+            if s + seq_step_size >= seq_length:
+                reset_prompt_cache(cache)
         mx.eval(all_losses, ntokens)

     all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
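
For reference, the accumulation in the loop above is a token-weighted average: each chunk contributes its mean loss scaled by its token count, and the final division by the total token count recovers the overall mean. A tiny plain-Python check of that bookkeeping, with made-up numbers:

    # Token-weighted averaging across chunks, as in the evaluate() loop above.
    chunk_means = [2.0, 1.0, 4.0]   # hypothetical per-chunk mean losses
    chunk_tokens = [10, 30, 20]     # hypothetical per-chunk token counts

    all_losses = sum(m * t for m, t in zip(chunk_means, chunk_tokens))
    ntokens = sum(chunk_tokens)
    print(all_losses / ntokens)     # 2.1666..., i.e. (20 + 30 + 80) / 60
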
@@ -241,13 +243,14 @@ def train(
     if args.grad_checkpoint:
         grad_checkpoint(model.layers[0])

+    seq_step_size = args.seq_step_size or args.max_seq_length
     cache = make_prompt_cache(model)
     state = [model.state, optimizer.state, mx.random.state]

     @partial(mx.compile, inputs=state, outputs=state)
     def step(batch):
         # Forward and backward pass
-        (lvalue, toks), grad = loss_value_and_grad(model, *batch, cache)
+        (lvalue, toks), grad = loss_value_and_grad(model, *batch)

         # All reduce the gradients if running in distributed mode
         grad = average_gradients(grad)
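
For context (this is not part of the diff), nn.value_and_grad wraps a loss of the form loss(model, ...) and returns both the loss outputs and the gradients with respect to the model's trainable parameters, which is what the compiled step() consumes. A minimal self-contained sketch with a toy model and optimizer standing in for the real ones:

    import mlx.core as mx
    import mlx.nn as nn
    import mlx.optimizers as optim

    model = nn.Linear(4, 4)                 # toy stand-in for the language model
    optimizer = optim.Adam(learning_rate=1e-3)

    def toy_loss(model, x, y):
        # Return (loss, token count) to mirror the (lvalue, toks) tuple above.
        return nn.losses.mse_loss(model(x), y), mx.array(x.shape[0])

    loss_value_and_grad = nn.value_and_grad(model, toy_loss)

    x, y = mx.random.normal((2, 4)), mx.random.normal((2, 4))
    (lvalue, toks), grad = loss_value_and_grad(model, x, y)
    optimizer.update(model, grad)
    mx.eval(model.parameters(), optimizer.state)
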
@@ -264,6 +267,38 @@ def step(batch):

     model.train()
     seq_step_size = args.seq_step_size or args.max_seq_length
+    def seq_split_step(batch):
+        losses = mx.array(0.0)
+        n_tokens = mx.array(0.0)
+        seq_length = batch[0].shape[1]
+        grad_accum = None
+        for s in range(0, seq_length, seq_step_size):
+            local_batch = (batch[0][:, s:s+seq_step_size], batch[1])
+            (lvalue, toks), grad = loss_value_and_grad(model, *local_batch, cache)
+            prev_n_tokens = n_tokens
+            losses += toks * lvalue
+            n_tokens += toks
+
+            if grad_accum is None:
+                grad_accum = grad
+            else:
+                scale_g = toks / n_tokens
+                scale_acc = prev_n_tokens / n_tokens
+                grad_accum = tree_map(lambda g, acc: scale_g * g + scale_acc * acc, grad, grad_accum)
+
+
+            # Let go of the prompt cache before the last eval
+            if s + seq_step_size >= seq_length:
+                reset_prompt_cache(cache)
+            mx.eval(grad_accum, losses, n_tokens)
+
+        grad_accum = average_gradients(grad_accum)
+        optimizer.update(model, grad_accum)
+        return losses / n_tokens, n_tokens
+
+    loss_value_and_grad = nn.value_and_grad(model, loss)
+
+>>>>>>> 568a8d6 (use gradient accumulation)
     losses = 0
     n_tokens = 0
     steps = 0
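
The tree_map update above keeps grad_accum equal to the running token-weighted mean of the per-chunk gradients, so after the last chunk it equals sum_i(toks_i * grad_i) / sum_i(toks_i) without ever holding more than one extra gradient tree. A small scalar check of that recurrence in plain Python, with hypothetical numbers:

    # Verify the running token-weighted mean used in seq_split_step.
    chunk_grads = [0.5, -1.0, 2.0]   # hypothetical per-chunk gradients (scalars)
    chunk_toks = [8, 16, 4]          # hypothetical per-chunk token counts

    grad_accum, n_tokens = None, 0
    for g, toks in zip(chunk_grads, chunk_toks):
        prev_n = n_tokens
        n_tokens += toks
        if grad_accum is None:
            grad_accum = g
        else:
            grad_accum = (toks / n_tokens) * g + (prev_n / n_tokens) * grad_accum

    direct = sum(g * t for g, t in zip(chunk_grads, chunk_toks)) / sum(chunk_toks)
    assert abs(grad_accum - direct) < 1e-9   # both are the token-weighted mean
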
@@ -317,15 +352,16 @@ def step(batch):

         tic = time.perf_counter()

-        seq_length = batch[0].shape[1]
-        for s in range(0, seq_length, seq_step_size):
-            local_batch = (batch[0][:, s:s+seq_step_size], batch[1])
-            lvalue, toks = step(local_batch)
-            losses += lvalue
-            n_tokens += toks
-            steps += 1
-            mx.eval(state, losses, n_tokens)
-        reset_prompt_cache(cache)
+        if batch[0].shape[1] > seq_step_size:
+            lvalue, toks = seq_split_step(batch)
+        else:
+            lvalue, toks = step(batch)
+
+        losses += lvalue
+        n_tokens += toks
+        steps += 1
+        mx.eval(state, losses, n_tokens)
+
         train_time += time.perf_counter() - tic

         # Report training loss if needed
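
The dispatch above only takes the accumulating path when the padded batch is longer than seq_step_size; otherwise the original compiled step() handles the whole batch at once. A quick illustration of how a long sequence would be sliced by the chunked path (sizes are hypothetical, not from the commit):

    seq_length, seq_step_size = 4096, 1024   # hypothetical sizes
    chunks = [(s, min(s + seq_step_size, seq_length))
              for s in range(0, seq_length, seq_step_size)]
    print(chunks)   # [(0, 1024), (1024, 2048), (2048, 3072), (3072, 4096)]
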
