Feature/resumable lora with metadata #97


Open

wants to merge 3 commits into main

Conversation


@gingofthesouth commented Apr 16, 2025

Summary

This pull request introduces resumable LoRA fine-tuning for mlx-lm, motivated by real-world issues encountered during training. When training was interrupted by Out-Of-Memory (OOM) errors on macOS Metal, there was previously no way to resume progress: training always restarted from iteration 1 and overwrote the adapter files, leaving us back at square one.

This change adds:

  • The ability to resume training from the most recent adapter checkpoint (--resume-adapter-file)
  • Metadata saved in adapter .safetensors files to track the training iteration
  • Optimizer state saved and restored so training resumes accurately
  • Test coverage to validate resume behavior, metadata integrity, and optimizer state

Changes

  • train_model() and train() updated to track iteration count

  • trainer.py saves the current training step and optimizer state in adapter metadata

  • On resume, the --resume-adapter-file checkpoint is loaded and the iteration count and optimizer state are restored (a rough sketch follows this list)

  • Filenames of the form 000XXXX_adapters.safetensors continue correctly from the last step

  • Added test script test_lora_resume.py that tests saving and resuming from iteration and optimizer metadata
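The mechanism, roughly, is sketched below. This is not the exact code in this diff; in particular it assumes mx.save_safetensors accepts a str-to-str metadata dict, as recent MLX versions provide:

```python
import mlx.core as mx
from mlx.utils import tree_flatten


def save_adapters(path, model, iteration):
    # Save the adapter weights and stash the current iteration in the
    # safetensors metadata (metadata values must be strings).
    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
    mx.save_safetensors(str(path), adapter_weights, metadata={"iteration": str(iteration)})


def read_start_iteration(path):
    # Older checkpoints carry no metadata, so fall back to starting at 1.
    _, metadata = mx.load(str(path), return_metadata=True)
    return int(metadata.get("iteration", 0)) + 1
```

The numbered checkpoint files can then continue from the restored value, e.g. f"{start_iteration:07d}_adapters.safetensors" (format assumed from the 000XXXX pattern above).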

Why this is needed

Without resume functionality, training interruptions due to memory limits or system crashes result in wasted compute time and no way to continue training. This update provides robust continuation support while maintaining compatibility with existing workflows.

Test Plan

  • Manual testing with --resume-adapter-file on interrupted training
  • Automated test test_lora_resume.py validates end-to-end resume behavior

@chimezie

Is restoring the step number alone sufficient for resuming the optimizer and learning schedule state, as described here?

@gingofthesouth
Author

Is restoring the step number alone sufficient for resuming the optimizer and learning schedule state, as described here?

Currently, restoring the step number is sufficient to resume adapter weight training from the correct iteration (and this PR accounts for that when resuming from adapter checkpoints with no metadata). However, it does not restore optimizer internals (e.g., momentum buffers, learning rate schedule progress). This patch adds metadata as a foundation for fully resumable training. Future work could store the optimizer state and LR schedule status using this metadata to improve training continuity and reproducibility.
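To illustrate the distinction, here is a rough sketch assuming the usual MLX setup, where a schedule is passed as the optimizer's learning rate and driven by its step count:

```python
import mlx.optimizers as optim

# The schedule is a function of the optimizer's step count, so restoring the
# step restores the learning-rate position...
schedule = optim.cosine_decay(1e-4, 10_000)
optimizer = optim.Adam(learning_rate=schedule)

# ...but Adam's first/second moments are per-parameter entries in
# optimizer.state and start from zero again unless they are saved and
# restored alongside the adapter weights.
```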

@gingofthesouth
Author

Who do I have to pay to get a code review around here? 😅

@awni
Member

awni commented Apr 17, 2025

@gingofthesouth thanks for the contribution and sorry for the delay. Just to set your expectations it might take us a little while (~week) to get to this, but we'll try and take a look soon.

@gingofthesouth
Author

@gingofthesouth thanks for the contribution and sorry for the delay. Just to set your expectations it might take us a little while (~week) to get to this, but we'll try and take a look soon.

No worries at all, we’re all juggling a lot, so I completely understand. I mostly just wanted to get the ball rolling. As I dive deeper into the library and learn more about LLM training and tuning, I’m finding more areas where I could potentially contribute, like optimiser behaviour and restoring state when resuming training.
I also noticed there’s no automated PR pipeline in place (unit tests, linting, static analysis, or even an agent-assisted first-pass review), so I could potentially help there if that would be of any value to the project.

It's a fun journey.

Member

@angeloskath left a comment

I think this is something useful, but it needs a bit of work. I left comments throughout.

mlx_lm/lora.py Outdated
# Log resume state and extract iteration
from safetensors.numpy import safe_open

with safe_open(args.resume_adapter_file, framework="numpy") as f:
Member

You can load the metadata in mlx by doing _, metadata = mx.load(args.resume_adapter_file, return_metadata=True).
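For example (a sketch; the "iteration" key name is just illustrative, and safetensors metadata values come back as strings):

```python
weights, metadata = mx.load(args.resume_adapter_file, return_metadata=True)
start_iteration = int(metadata.get("iteration", 0)) + 1
```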

Author

Thank you for pointing this out. I've moved to this method for retrieving the metadata.

mlx_lm/lora.py Outdated
@@ -215,10 +215,35 @@ def train_model(
raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")

# Resume from weights if provided
start_iteration = 1
Member

I would refactor the following code into a separate function in trainer.py; see my comment about saving, it would be its counterpart.

Author

Done.

@@ -339,11 +340,22 @@ def step(batch):
# Save adapter weights
if it % args.steps_per_save == 0 and rank == 0:
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
Member

I would refactor the following into a save_checkpoint function. For resume to be really useful we need to save the optimizer state as well, so I would save optimizer.state together with the adapter weights.

This also makes any lr schedule work as expected, since the optimizer will be using the same step number etc.

Then you would need a load_checkpoint function, which I would write quite close to the one above: it takes a file, a model, and an optimizer and returns the start iteration.
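Roughly what that pair could look like (a sketch with hypothetical names and keys; whether optimizer.state can be assigned back directly like this should be checked against the MLX version in use):

```python
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten


def save_checkpoint(path, model, optimizer, iteration):
    # One safetensors file: adapter weights plus the optimizer state under a
    # prefix, with the iteration recorded as (string) metadata.
    arrays = dict(tree_flatten(model.trainable_parameters()))
    arrays.update({f"optimizer.{k}": v for k, v in tree_flatten(optimizer.state)})
    mx.save_safetensors(str(path), arrays, metadata={"iteration": str(iteration)})


def load_checkpoint(path, model, optimizer):
    arrays, metadata = mx.load(str(path), return_metadata=True)
    weights = [(k, v) for k, v in arrays.items() if not k.startswith("optimizer.")]
    opt_state = [(k[len("optimizer."):], v) for k, v in arrays.items() if k.startswith("optimizer.")]
    model.load_weights(weights, strict=False)
    if opt_state:
        # Restores the step, the learning-rate schedule position, and the moment buffers.
        optimizer.state = tree_unflatten(opt_state)
    return int(metadata.get("iteration", 0)) + 1
```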

Author

Good idea. I've moved the code to trainer.py and refactored it to include the optimizer code too.
I had a few struggles, but I think I got there in the end. To be honest, coding in Python still feels very foreign to me. I keep wanting to build utility classes, extensions, protocols, and dependency injection patterns, so I find myself a little lost about where to put code.

Any further feedback will be warmly received as I am learning the language ecosystem as I go.

Member

We are not using pytest but unittest; see tests/test_finetune.py.
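For example, the resume test can follow that unittest pattern, roughly like this (a sketch; the class and method names are illustrative, and mx.save_safetensors metadata support is assumed):

```python
import tempfile
import unittest

import mlx.core as mx


class TestLoRAResume(unittest.TestCase):
    def test_iteration_metadata_round_trip(self):
        # Save a tiny tensor with iteration metadata and read it back,
        # mirroring how the resume logic recovers the start iteration.
        with tempfile.TemporaryDirectory() as tmp:
            path = f"{tmp}/0000010_adapters.safetensors"
            mx.save_safetensors(path, {"w": mx.zeros((2, 2))}, metadata={"iteration": "10"})
            _, metadata = mx.load(path, return_metadata=True)
            self.assertEqual(int(metadata["iteration"]) + 1, 11)


if __name__ == "__main__":
    unittest.main()
```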

Author

I reviewed the other tests and cleaned this up. Hopefully the new test code aligns better with the existing testing patterns and test libraries.

…he checkpoint functions to include saving and loading optimizer state. Re-wrote the tests to follow a similar pattern to existing tests.