Feature/resumable lora with metadata #97

Open · wants to merge 3 commits into main
Changes from 2 commits
28 changes: 27 additions & 1 deletion mlx_lm/lora.py
@@ -215,10 +215,35 @@ def train_model(
raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")

# Resume from weights if provided
start_iteration = 1
Member

I would refactor the following code into a separate function in trainer.py; see my comment about saving below, as it would be its counterpart.
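
For illustration, the call site in lora.py could then shrink to something like the sketch below. The `load_checkpoint` name and signature are assumptions that mirror the saving comment further down, and this assumes the optimizer has already been constructed at that point:

```python
# Hypothetical call site after the suggested refactor (helper name and signature assumed).
start_iteration = 1
if args.resume_adapter_file is not None:
    start_iteration = load_checkpoint(args.resume_adapter_file, model, optimizer)
```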

Author

Done.

if args.resume_adapter_file is not None:
print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
adapter_file = Path(args.resume_adapter_file)
if adapter_file.is_dir():
safetensor_files = sorted(
adapter_file.glob("*_adapters.safetensors"),
key=lambda f: int(f.name.split("_")[0]),
reverse=True,
)
if not safetensor_files:
raise ValueError("No adapter files found to resume from.")
latest = safetensor_files[0]
print(f"Auto-resuming from latest adapter file: {latest}")
args.resume_adapter_file = str(latest)
else:
print(f"Resuming from: {args.resume_adapter_file}")
model.load_weights(args.resume_adapter_file, strict=False)

# Log resume state and extract iteration
from safetensors.numpy import safe_open

with safe_open(args.resume_adapter_file, framework="numpy") as f:
Member

You can load the metadata in mlx by doing _, metadata = mx.load(args.resume_adapter_file, return_metadata=True).
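
For reference, a minimal sketch of that replacement (assuming the metadata values were stored as strings, as in the save path below):

```python
# Sketch of reading the checkpoint metadata with mx.load instead of safe_open.
import mlx.core as mx

_, metadata = mx.load(args.resume_adapter_file, return_metadata=True)
if "iteration" in metadata:
    start_iteration = int(metadata["iteration"]) + 1
    print(f" Continuing from iteration {start_iteration}")
```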

Author

Thank you for pointing this out. I've moved to this method for retrieving the metadata.

metadata = dict(f.metadata())
print("✅ Resuming from checkpoint:")
print(" Metadata:", metadata)
if "iteration" in metadata:
start_iteration = int(metadata["iteration"]) + 1
print(f" Continuing from iteration {start_iteration}")

print_trainable_parameters(model)

adapter_path = Path(args.adapter_path)
@@ -264,6 +289,7 @@ def train_model(
train_dataset=train_set,
val_dataset=valid_set,
training_callback=training_callback,
start_step=start_iteration,
)


27 changes: 22 additions & 5 deletions mlx_lm/tuner/trainer.py
@@ -207,6 +207,7 @@ def train(
loss: callable = default_loss,
iterate_batches: callable = iterate_batches,
training_callback: TrainingCallback = None,
start_step: int = 1,
):
mx.set_wired_limit(mx.metal.device_info()["max_recommended_working_set_size"])
print(f"Starting training..., iters: {args.iters}")
@@ -247,7 +248,7 @@ def step(batch):
train_time = 0
# Main training loop
for it, batch in zip(
range(1, args.iters + 1),
range(start_step, args.iters + 1),
iterate_batches(
dataset=train_dataset,
tokenizer=tokenizer,
tic = time.perf_counter()
# Report validation loss if needed, the first validation loss
# is always measured before any training.
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
if it == start_step or it % args.steps_per_eval == 0 or it == args.iters:
tic = time.perf_counter()
val_loss = evaluate(
model=model,
@@ -339,11 +340,22 @@ def step(batch):
# Save adapter weights
if it % args.steps_per_save == 0 and rank == 0:
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
Member

I would refactor the following into a save_checkpoint function. For resume to be really useful we also need to save the optimizer state, so alongside the adapter weights I would save optimizer.state as well.

This also makes any lr schedule work as expected, since the optimizer will be using the same step number, etc.

Then you would need a load_checkpoint function, which I would write quite close to the one above: it takes a file, a model, and an optimizer and returns the start iteration.
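
A rough sketch of the pair being described; the helper names, the separate optimizer.safetensors file, and the exact state-restoration call are illustrative assumptions, not this PR's final code:

```python
# Illustrative sketch only. Assumes a separate optimizer.safetensors file next to
# the adapters and that Optimizer.state can be restored by assignment.
from pathlib import Path

import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten


def save_checkpoint(adapter_file, model, optimizer, iteration, trained_tokens):
    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
    metadata = {"iteration": str(iteration), "trained_tokens": str(trained_tokens)}
    mx.save_safetensors(str(adapter_file), adapter_weights, metadata=metadata)
    # Persist optimizer state so lr schedules and Adam moments resume correctly.
    opt_state = dict(tree_flatten(optimizer.state))
    mx.save_safetensors(str(Path(adapter_file).parent / "optimizer.safetensors"), opt_state)


def load_checkpoint(adapter_file, model, optimizer):
    weights, metadata = mx.load(str(adapter_file), return_metadata=True)
    model.load_weights(list(weights.items()), strict=False)
    opt_file = Path(adapter_file).parent / "optimizer.safetensors"
    if opt_file.exists():
        optimizer.state = tree_unflatten(list(mx.load(str(opt_file)).items()))
    return int(metadata.get("iteration", "0")) + 1
```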

Author

Good idea. I've moved the code to trainer.py and refactored it to include the optimizer code too.
I had a few struggles but I think I got there in the end. To be honest, coding in Python still feels very foreign to me; I keep wanting to build utility classes, extensions, protocols, and dependency injection patterns, so I find myself a little lost about where to put code.

Any further feedback will be warmly received, as I am learning the language ecosystem as I go.


metadata = {
"iteration": str(it),
"trained_tokens": str(trained_tokens),
"loss": f"{train_loss:.6f}",
}

mx.save_safetensors(
str(args.adapter_file),
adapter_weights,
metadata=metadata,
)
checkpoint = (
Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
)
mx.save_safetensors(str(checkpoint), adapter_weights)
mx.save_safetensors(str(checkpoint), adapter_weights, metadata=metadata)
print(
f"Iter {it}: Saved adapter weights to "
f"{args.adapter_file} and {checkpoint}."
@@ -352,5 +364,10 @@ def step(batch):
# Save final weights
if rank == 0:
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
metadata = {
"iteration": str(args.iters),
"trained_tokens": str(trained_tokens),
"final": "true",
}
mx.save_safetensors(str(args.adapter_file), adapter_weights, metadata=metadata)
print(f"Saved final weights to {args.adapter_file}.")
115 changes: 115 additions & 0 deletions tests/test_lora_resume.py
Member

We are not using pytest but unittest; see tests/test_finetune.py.
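
For reference, a minimal unittest-style skeleton of the same check; the class name, temp-dir handling, and reuse of the run_training helper from the file below are assumptions:

```python
# Rough unittest-style sketch, patterned after tests/test_finetune.py (names assumed).
import tempfile
import unittest
from pathlib import Path

import mlx.core as mx


class TestLoraResume(unittest.TestCase):
    def test_adapter_resume_and_metadata(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            adapter_file = Path(tmp_dir) / "adapters.safetensors"
            run_training(iters=5, adapter_file=adapter_file)  # helper defined below
            checkpoint = adapter_file.parent / "0000005_adapters.safetensors"
            self.assertTrue(checkpoint.exists())
            _, metadata = mx.load(str(checkpoint), return_metadata=True)
            self.assertEqual(metadata["iteration"], "5")


if __name__ == "__main__":
    unittest.main()
```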

Author

I reviewed the other tests and cleaned this up. Hopefully the new test code now aligns better with the existing testing patterns and test libraries.

@@ -0,0 +1,115 @@
import os
import shutil
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import pytest
from safetensors.numpy import safe_open

from mlx_lm.tuner.datasets import CacheDataset
from mlx_lm.tuner.trainer import TrainingArgs, train

# ---------------------
# Mock Components
# ---------------------


class MockModel(nn.Module):
def __init__(self, vocab_size=128, hidden_size=32):
super().__init__()
self.embed = nn.Embedding(vocab_size, hidden_size)
self.layers = [nn.Linear(hidden_size, hidden_size) for _ in range(2)]
self.out = nn.Linear(hidden_size, vocab_size)

def __call__(self, x):
x = self.embed(x)
for layer in self.layers:
x = layer(x)
return self.out(x)

def freeze(self):
pass

def unfreeze(self):
pass

def train(self):
pass

def eval(self):
pass


class DummyDataset:
def __getitem__(self, idx):
return [i % 100 for i in range(32)]

def __len__(self):
return 1000


class DummyTokenizer:
def __call__(self, texts):
return [[i % 100 for i, _ in enumerate(text.split())] for text in texts]


# ---------------------
# Training Runner
# ---------------------


def run_training(iters, adapter_file, resume_from=None):
model = MockModel()
dataset = DummyDataset()
train_set = CacheDataset(dataset)
val_set = CacheDataset(dataset)
tokenizer = DummyTokenizer()
optimizer = optim.Adam(learning_rate=1e-4)

args = TrainingArgs(
iters=iters,
batch_size=16,
val_batches=2,
steps_per_report=5,
steps_per_save=5,
adapter_file=adapter_file,
max_seq_length=64,
)

if resume_from:
model.load_weights(resume_from, strict=False)

train(model, tokenizer, optimizer, train_set, val_set, args=args)


# ---------------------
# Test Case
# ---------------------


@pytest.mark.order(1)
def test_adapter_resume_and_metadata(tmp_path):
adapter_dir = tmp_path / "adapters"
adapter_dir.mkdir(parents=True, exist_ok=True)

# Step 1: Train for 5 iters
adapter_file = adapter_dir / "adapters.safetensors"
run_training(iters=5, adapter_file=adapter_file)

assert (adapter_dir / "0000005_adapters.safetensors").exists()

# Step 2: Resume for 5 more iters (should end at 10)
resume_file = adapter_dir / "0000005_adapters.safetensors"
run_training(iters=10, adapter_file=adapter_file, resume_from=resume_file)

final_file = adapter_dir / "0000010_adapters.safetensors"
assert final_file.exists()

# Step 3: Check metadata
with safe_open(str(final_file), framework="numpy") as f:
metadata = f.metadata()
assert "step" in metadata
assert metadata["step"] == "10"