
Commit 78be1bc

use gradient accumulation
1 parent 07be2b5 commit 78be1bc

File tree

1 file changed: +7 -6 lines changed


mlx_lm/tuner/trainer.py (+7, -6)
@@ -16,7 +16,7 @@
 from transformers import PreTrainedTokenizer
 
 from .datasets import CacheDataset
-from ..models.cache import make_prompt_cache, KVCache
+from ..models.cache import KVCache, make_prompt_cache
 
 
 def reset_prompt_cache(cache):
@@ -74,7 +74,7 @@ class TrainingArgs:
         default=False,
         metadata={"help": "Use gradient checkpointing to reduce memory use."},
     )
-    seq_step_size : Optional[int] = field(
+    seq_step_size: Optional[int] = field(
         default=None,
         metadata={"help": "The examples are processsed in seq_step_size chunks."},
     )
@@ -196,7 +196,7 @@ def evaluate(
     ):
         seq_length = batch[0].shape[1]
         for s in range(0, seq_length, seq_step_size):
-            local_batch = (batch[0][:, s:s+seq_step_size], batch[1])
+            local_batch = (batch[0][:, s : s + seq_step_size], batch[1])
             losses, toks = loss(model, *local_batch, cache)
             all_losses += losses * toks
             ntokens += toks
@@ -273,7 +273,7 @@ def seq_split_step(batch):
         seq_length = batch[0].shape[1]
         grad_accum = None
         for s in range(0, seq_length, seq_step_size):
-            local_batch = (batch[0][:, s:s+seq_step_size], batch[1])
+            local_batch = (batch[0][:, s : s + seq_step_size], batch[1])
             (lvalue, toks), grad = loss_value_and_grad(model, *local_batch, cache)
             prev_n_tokens = n_tokens
             losses += toks * lvalue
@@ -284,8 +284,9 @@ def seq_split_step(batch):
             else:
                 scale_g = toks / n_tokens
                 scale_acc = prev_n_tokens / n_tokens
-                grad_accum = tree_map(lambda g, acc: scale_g * g + scale_acc * acc, grad, grad_accum)
-
+                grad_accum = tree_map(
+                    lambda g, acc: scale_g * g + scale_acc * acc, grad, grad_accum
+                )
 
             # Let go of the prompt cache before the last eval
             if s + seq_step_size >= seq_length:
