Feature/resumable lora with metadata #97
base: main
Changes from 2 commits
@@ -215,10 +215,35 @@ def train_model(
         raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")
 
     # Resume from weights if provided
+    start_iteration = 1
     if args.resume_adapter_file is not None:
         print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
+        adapter_file = Path(args.resume_adapter_file)
+        if adapter_file.is_dir():
+            safetensor_files = sorted(
+                adapter_file.glob("*_adapters.safetensors"),
+                key=lambda f: int(f.name.split("_")[0]),
+                reverse=True,
+            )
+            if not safetensor_files:
+                raise ValueError("No adapter files found to resume from.")
+            latest = safetensor_files[0]
+            print(f"Auto-resuming from latest adapter file: {latest}")
+            args.resume_adapter_file = str(latest)
+        else:
+            print(f"Resuming from: {args.resume_adapter_file}")
         model.load_weights(args.resume_adapter_file, strict=False)
 
+        # Log resume state and extract iteration
+        from safetensors.numpy import safe_open
+
+        with safe_open(args.resume_adapter_file, framework="numpy") as f:
Review comment: You can load the metadata in mlx by doing …

Reply: Thank you for pointing this out. I've moved to this method for retrieving the metadata.
+            metadata = dict(f.metadata())
+            print("✅ Resuming from checkpoint:")
+            print(" Metadata:", metadata)
+            if "iteration" in metadata:
+                start_iteration = int(metadata["iteration"]) + 1
+                print(f" Continuing from iteration {start_iteration}")
+
     print_trainable_parameters(model)
 
     adapter_path = Path(args.adapter_path)
@@ -264,6 +289,7 @@ def train_model(
         train_dataset=train_set,
         val_dataset=valid_set,
         training_callback=training_callback,
+        start_step=start_iteration,
     )
@@ -207,6 +207,7 @@ def train(
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
     training_callback: TrainingCallback = None,
+    start_step: int = 1,
 ):
     mx.set_wired_limit(mx.metal.device_info()["max_recommended_working_set_size"])
     print(f"Starting training..., iters: {args.iters}")
@@ -247,7 +248,7 @@ def step(batch):
     train_time = 0
     # Main training loop
     for it, batch in zip(
-        range(1, args.iters + 1),
+        range(start_step, args.iters + 1),
         iterate_batches(
             dataset=train_dataset,
             tokenizer=tokenizer,
@@ -259,7 +260,7 @@ def step(batch):
         tic = time.perf_counter()
         # Report validation loss if needed, the first validation loss
         # is always measured before any training.
-        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
+        if it == start_step or it % args.steps_per_eval == 0 or it == args.iters:
             tic = time.perf_counter()
             val_loss = evaluate(
                 model=model,
@@ -339,11 +340,22 @@ def step(batch):
         # Save adapter weights
         if it % args.steps_per_save == 0 and rank == 0:
             adapter_weights = dict(tree_flatten(model.trainable_parameters()))
-            mx.save_safetensors(str(args.adapter_file), adapter_weights)
Review comment: I would refactor the following into a … This also makes any lr schedule work as expected, as the optimizer will be using the same step number, etc. Then you would need a …

Reply: Good idea. I've moved the code to trainer.py and refactored it to include the optimizer code too. Any further feedback will be warmly received, as I am learning the language ecosystem as I go.
+
+            metadata = {
+                "iteration": str(it),
+                "trained_tokens": str(trained_tokens),
+                "loss": f"{train_loss:.6f}",
+            }
+
+            mx.save_safetensors(
+                str(args.adapter_file),
+                adapter_weights,
+                metadata=metadata,
+            )
             checkpoint = (
                 Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
             )
-            mx.save_safetensors(str(checkpoint), adapter_weights)
+            mx.save_safetensors(str(checkpoint), adapter_weights, metadata=metadata)
             print(
                 f"Iter {it}: Saved adapter weights to "
                 f"{args.adapter_file} and {checkpoint}."
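On the refactor suggested in the thread above: a rough sketch of what a shared save helper in trainer.py might look like, including the optimizer state so that a resumed run keeps the same step count and any learning-rate schedule stays in sync. The name `save_checkpoint`, the separate optimizer file, and its naming convention are assumptions for illustration, not the PR's final code:

```python
from pathlib import Path

import mlx.core as mx
from mlx.utils import tree_flatten


def save_checkpoint(model, optimizer, adapter_file, it, trained_tokens, train_loss):
    """Hypothetical helper: save adapter weights, optimizer state, and metadata."""
    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
    metadata = {
        "iteration": str(it),
        "trained_tokens": str(trained_tokens),
        "loss": f"{train_loss:.6f}",
    }

    # Rolling "latest" file plus a numbered checkpoint, mirroring the PR's layout.
    mx.save_safetensors(str(adapter_file), adapter_weights, metadata=metadata)
    checkpoint = Path(adapter_file).parent / f"{it:07d}_adapters.safetensors"
    mx.save_safetensors(str(checkpoint), adapter_weights, metadata=metadata)

    # Optimizer state alongside the weights (assumed file naming), so the step
    # count and any lr schedule can resume where they left off.
    opt_state = dict(tree_flatten(optimizer.state))
    opt_file = Path(adapter_file).parent / f"{it:07d}_optimizer.safetensors"
    mx.save_safetensors(str(opt_file), opt_state)
```

A matching load helper, sketched at the end of this page, would read these files back and pass the start step into train().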
@@ -352,5 +364,10 @@ def step(batch):
     # Save final weights
     if rank == 0:
         adapter_weights = dict(tree_flatten(model.trainable_parameters()))
-        mx.save_safetensors(str(args.adapter_file), adapter_weights)
+        metadata = {
+            "iteration": str(args.iters),
+            "trained_tokens": str(trained_tokens),
+            "final": "true",
+        }
+        mx.save_safetensors(str(args.adapter_file), adapter_weights, metadata=metadata)
         print(f"Saved final weights to {args.adapter_file}.")
Review comment: We are not using …

Reply: I viewed the other tests and cleaned this up. Hopefully the new test code aligns better with the existing testing patterns and test libraries.
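For context on this thread: the repo's existing tests appear to be unittest-based, so the cleaned-up test presumably moved away from the pytest fixtures used below. A self-contained sketch of the metadata round-trip check in that style; the class and test names are mine, and it saves a dummy tensor rather than running the real trainer:

```python
import tempfile
import unittest
from pathlib import Path

import mlx.core as mx


class TestAdapterMetadata(unittest.TestCase):
    """Hypothetical unittest-style check of the checkpoint metadata round-trip."""

    def test_metadata_round_trip(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            checkpoint = Path(tmp_dir) / "0000010_adapters.safetensors"
            weights = {"lora_a": mx.zeros((4, 4))}
            metadata = {"iteration": "10", "trained_tokens": "1234"}

            # Same save call the trainer uses for checkpoints.
            mx.save_safetensors(str(checkpoint), weights, metadata=metadata)

            # Read it back the way a resume would.
            _, loaded_metadata = mx.load(str(checkpoint), return_metadata=True)
            self.assertEqual(loaded_metadata["iteration"], "10")


if __name__ == "__main__":
    unittest.main()
```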
@@ -0,0 +1,115 @@
import os
import shutil
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import pytest
from safetensors.numpy import safe_open

from mlx_lm.tuner.datasets import CacheDataset
from mlx_lm.tuner.trainer import TrainingArgs, train

# ---------------------
# Mock Components
# ---------------------


class MockModel(nn.Module):
    def __init__(self, vocab_size=128, hidden_size=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.layers = [nn.Linear(hidden_size, hidden_size) for _ in range(2)]
        self.out = nn.Linear(hidden_size, vocab_size)

    def __call__(self, x):
        x = self.embed(x)
        for layer in self.layers:
            x = layer(x)
        return self.out(x)

    def freeze(self):
        pass

    def unfreeze(self):
        pass

    def train(self):
        pass

    def eval(self):
        pass


class DummyDataset:
    def __getitem__(self, idx):
        return [i % 100 for i in range(32)]

    def __len__(self):
        return 1000


class DummyTokenizer:
    def __call__(self, texts):
        return [[i % 100 for i, _ in enumerate(text.split())] for text in texts]


# ---------------------
# Training Runner
# ---------------------


def run_training(iters, adapter_file, resume_from=None):
    model = MockModel()
    dataset = DummyDataset()
    train_set = CacheDataset(dataset)
    val_set = CacheDataset(dataset)
    tokenizer = DummyTokenizer()
    optimizer = optim.Adam(learning_rate=1e-4)

    args = TrainingArgs(
        iters=iters,
        batch_size=16,
        val_batches=2,
        steps_per_report=5,
        steps_per_save=5,
        adapter_file=adapter_file,
        max_seq_length=64,
    )

    if resume_from:
        model.load_weights(str(resume_from), strict=False)

    train(model, tokenizer, optimizer, train_set, val_set, args=args)


# ---------------------
# Test Case
# ---------------------


@pytest.mark.order(1)
def test_adapter_resume_and_metadata(tmp_path):
    adapter_dir = tmp_path / "adapters"
    adapter_dir.mkdir(parents=True, exist_ok=True)

    # Step 1: Train for 5 iters
    adapter_file = adapter_dir / "adapters.safetensors"
    run_training(iters=5, adapter_file=adapter_file)

    assert (adapter_dir / "0000005_adapters.safetensors").exists()

    # Step 2: Resume for 5 more iters (should end at 10)
    resume_file = adapter_dir / "0000005_adapters.safetensors"
    run_training(iters=10, adapter_file=adapter_file, resume_from=resume_file)

    final_file = adapter_dir / "0000010_adapters.safetensors"
    assert final_file.exists()

    # Step 3: Check metadata
    with safe_open(str(final_file), framework="numpy") as f:
        metadata = f.metadata()
        # The trainer writes the step count under the "iteration" key.
        assert "iteration" in metadata
        assert metadata["iteration"] == "10"
Review comment: I would refactor the following code into a separate function in trainer.py; see my comment about saving, it would be its counterpart.

Reply: Done.
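A rough sketch of what that loading counterpart in trainer.py could look like, mirroring the save helper sketched earlier: it resolves the newest checkpoint, restores weights and optimizer state, and returns the next iteration to run. The name `load_checkpoint`, the optimizer-file convention, and the return value are illustrative assumptions, not the PR's final implementation:

```python
from pathlib import Path

import mlx.core as mx
from mlx.utils import tree_unflatten


def load_checkpoint(model, optimizer, resume_path):
    """Hypothetical helper: restore adapter weights, optimizer state, and step."""
    resume_path = Path(resume_path)

    # If a directory is given, pick the newest numbered checkpoint in it.
    if resume_path.is_dir():
        candidates = sorted(
            resume_path.glob("*_adapters.safetensors"),
            key=lambda f: int(f.name.split("_")[0]),
        )
        if not candidates:
            raise ValueError("No adapter files found to resume from.")
        resume_path = candidates[-1]

    # Load weights plus metadata in one call (recent MLX).
    weights, metadata = mx.load(str(resume_path), return_metadata=True)
    model.load_weights(list(weights.items()), strict=False)

    # Restore optimizer state if a matching file exists (assumed layout, and
    # assumes Optimizer.state can be assigned directly).
    opt_file = resume_path.parent / resume_path.name.replace("_adapters", "_optimizer")
    if opt_file.exists():
        optimizer.state = tree_unflatten(list(mx.load(str(opt_file)).items()))

    # Next iteration to run.
    return int(metadata.get("iteration", 0)) + 1
```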