Skip to content

Commit 34940c5

Browse files
committed
fix test
1 parent 78be1bc commit 34940c5

File tree

2 files changed

+9
-11
lines changed

2 files changed

+9
-11
lines changed

mlx_lm/tuner/trainer.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from mlx.utils import tree_flatten, tree_map
1616
from transformers import PreTrainedTokenizer
1717

18-
from .datasets import CacheDataset
1918
from ..models.cache import KVCache, make_prompt_cache
19+
from .datasets import CacheDataset
2020

2121

2222
def reset_prompt_cache(cache):
@@ -267,6 +267,7 @@ def step(batch):
267267

268268
model.train()
269269
seq_step_size = args.seq_step_size or args.max_seq_length
270+
270271
def seq_split_step(batch):
271272
losses = mx.array(0.0)
272273
n_tokens = mx.array(0.0)
@@ -299,7 +300,6 @@ def seq_split_step(batch):
299300

300301
loss_value_and_grad = nn.value_and_grad(model, loss)
301302

302-
>>>>>>> 568a8d6 (use gradient accumulation)
303303
losses = 0
304304
n_tokens = 0
305305
steps = 0
@@ -321,7 +321,6 @@ def seq_split_step(batch):
321321
# is always measured before any training.
322322
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
323323
tic = time.perf_counter()
324-
val_loss = 0.0
325324
val_loss = evaluate(
326325
model=model,
327326
dataset=val_dataset,

tests/test_finetune.py

+7-8
Original file line numberDiff line numberDiff line change
@@ -370,11 +370,10 @@ def test_evaluate_calls(self):
370370
mock_iterate_batches = MagicMock()
371371

372372
mock_iterate_batches.return_value = [
373-
(MagicMock(), MagicMock()),
374-
(MagicMock(), MagicMock()),
375-
(MagicMock(), MagicMock()),
376-
(MagicMock(), MagicMock()),
377-
(MagicMock(), MagicMock()),
373+
(mx.ones((2, 8), mx.int32), mx.ones((2, 2), mx.int32)),
374+
(mx.ones((2, 8), mx.int32), mx.ones((2, 2), mx.int32)),
375+
(mx.ones((2, 8), mx.int32), mx.ones((2, 2), mx.int32)),
376+
(mx.ones((2, 8), mx.int32), mx.ones((2, 2), mx.int32)),
378377
]
379378

380379
mock_default_loss.side_effect = [
@@ -412,9 +411,9 @@ def test_evaluate_infinite_batches(self):
412411
mock_iterate_batches = MagicMock()
413412

414413
mock_iterate_batches.return_value = [
415-
(MagicMock(), MagicMock()),
416-
(MagicMock(), MagicMock()),
417-
(MagicMock(), MagicMock()),
414+
(mx.ones((2, 8), mx.int32), mx.ones((2, 2), mx.int32)),
415+
(mx.ones((2, 8), mx.int32), mx.ones((2, 2), mx.int32)),
416+
(mx.ones((2, 8), mx.int32), mx.ones((2, 2), mx.int32)),
418417
]
419418

420419
mock_default_loss.side_effect = [

0 commit comments

Comments
 (0)