Feature/resumable lora with metadata #97
Conversation
Is restoring the step number alone sufficient for resuming the optimizer and learning schedule state, as described here?
Currently, restoring the step number is sufficient to resume adapter weight training from the correct iteration (and this PR accounts for that when resuming from adapter checkpoints with no metadata). However, it does not restore optimizer internals (e.g., momentum buffers, learning rate schedule progress). This patch adds metadata as a foundation for fully resumable training. Future work could store the optimizer state and LR schedule status using this metadata to improve training continuity and reproducibility.
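As a minimal sketch of that foundation, assuming MLX's string-to-string safetensors metadata support; the "iteration" key is illustrative, and model, args, and it stand in for the objects already in scope in the training loop rather than this PR's exact code:

```python
import mlx.core as mx
from mlx.utils import tree_flatten

# Saving: attach the current iteration to the adapter checkpoint as metadata
# (safetensors metadata values must be strings).
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(
    str(args.adapter_file), adapter_weights, metadata={"iteration": str(it)}
)

# Resuming: read the metadata back; checkpoints saved before this change carry
# no metadata, so fall back to starting at iteration 1.
_, metadata = mx.load(args.resume_adapter_file, return_metadata=True)
start_iteration = int(metadata.get("iteration", "1")) if metadata else 1
```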
Who do I have to pay to get a code review around here? 😅
@gingofthesouth thanks for the contribution and sorry for the delay. Just to set your expectations, it might take us a little while (~week) to get to this, but we'll try and take a look soon.
No worries at all, we're all juggling a lot, so I completely understand. I mostly just wanted to get the ball rolling. As I dive deeper into the library and learn more about LLM training and tuning, I'm finding more areas where I could potentially contribute, like optimiser behaviour and restoring state when resuming training. It's a fun journey.
I think this is something useful, but it needs a bit of work. I left comments throughout.
mlx_lm/lora.py
Outdated
# Log resume state and extract iteration
from safetensors.numpy import safe_open

with safe_open(args.resume_adapter_file, framework="numpy") as f:
You can load the metadata in mlx by doing _, metadata = mx.load(args.resume_adapter_file, return_metadata=True).
Thank you for pointing this out. I've moved to this method for retrieving the metadata.
mlx_lm/lora.py
Outdated
@@ -215,10 +215,35 @@ def train_model(
        raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")

    # Resume from weights if provided
    start_iteration = 1
I would refactor the following code into a separate function in trainer.py; see my comment about saving, it would be its counterpart.
Done.
@@ -339,11 +340,22 @@ def step(batch):
        # Save adapter weights
        if it % args.steps_per_save == 0 and rank == 0:
            adapter_weights = dict(tree_flatten(model.trainable_parameters()))
            mx.save_safetensors(str(args.adapter_file), adapter_weights)
I would refactor the following into a save_checkpoint function. In order for resume to be really useful we need to save the optimizer state as well, so together with the adapter weights I would save the optimizer.state too. This also makes any lr schedule work as expected, as the optimizer will be using the same step number etc. Then you would need a load_checkpoint function, which I would write quite close to the one above, that takes a file, a model, and an optimizer and returns the start iteration.
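A rough sketch of what such a pair could look like, assuming the adapter weights and a flattened optimizer.state can share one safetensors file under a hypothetical "optimizer." key prefix, with the iteration kept as string metadata; the function names, prefix, and metadata key are illustrative and not the final trainer.py API:

```python
import mlx.core as mx
from mlx.utils import tree_flatten, tree_unflatten


def save_checkpoint(file, model, optimizer, iteration):
    # Adapter weights and flattened optimizer state go into one safetensors
    # file; the "optimizer." prefix keeps the two groups separable on load.
    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
    optimizer_state = {f"optimizer.{k}": v for k, v in tree_flatten(optimizer.state)}
    mx.save_safetensors(
        str(file),
        {**adapter_weights, **optimizer_state},
        metadata={"iteration": str(iteration)},
    )


def load_checkpoint(file, model, optimizer):
    # Returns the iteration to resume from; checkpoints written without
    # metadata (older files) fall back to starting at iteration 1.
    weights, metadata = mx.load(str(file), return_metadata=True)
    adapter_weights = [(k, v) for k, v in weights.items() if not k.startswith("optimizer.")]
    optimizer_state = [
        (k[len("optimizer."):], v) for k, v in weights.items() if k.startswith("optimizer.")
    ]
    model.load_weights(adapter_weights, strict=False)
    if optimizer_state:
        # Assumes the optimizer exposes a settable .state, as mlx optimizers do.
        optimizer.state = tree_unflatten(optimizer_state)
    return int(metadata.get("iteration", "1")) if metadata else 1
```

Restoring optimizer.state this way also carries the step count and any schedule-derived learning rate forward, which is what makes the LR schedule resume correctly.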
Good idea. I've moved the code to trainer.py and refactored it to include the optimizer code too.
I had a few struggles but think I got there in the end. I am going to be honest, I'm finding that coding in Python is very foreign to me. I keep thinking to build utility classes, extensions, protocols and dependency injection patterns. Thus I find myself a little lost about where to put code.
Any further feedback will be warmly received as I am learning the language ecosystem as I go.
tests/test_lora_resume.py
Outdated
We are not using pytest but unittest, see tests/test_finetune.py.
I viewed the other tests and cleaned this up. Hopefully the new test code aligns better with the existing testing patterns and test libraries.
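As a hedged illustration of that pattern (the class name, metadata key, and dummy weights are assumptions, not the exact contents of tests/test_lora_resume.py), a unittest-style check that the iteration survives a save/load round trip through safetensors metadata might look like:

```python
import tempfile
import unittest
from pathlib import Path

import mlx.core as mx


class TestLoraResume(unittest.TestCase):
    def test_iteration_metadata_round_trip(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            adapter_file = Path(tmp_dir) / "adapters.safetensors"
            weights = {"lora_a": mx.zeros((4, 4))}

            # Save dummy adapter weights with the iteration stored as metadata.
            mx.save_safetensors(
                str(adapter_file), weights, metadata={"iteration": "42"}
            )

            # Load them back and check the resume iteration is recovered.
            _, metadata = mx.load(str(adapter_file), return_metadata=True)
            self.assertEqual(int(metadata["iteration"]), 42)


if __name__ == "__main__":
    unittest.main()
```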
…he checkpoint functions to include saving and loading optimizer state. Re-wrote the tests to follow a similar pattern to existing tests.
Summary
This pull request introduces resumable LoRA fine-tuning for mlx-lm, motivated by real-world issues encountered during training. When training was interrupted due to Out-Of-Memory (OOM) errors on macOS Metal, there was previously no way to resume progress. Training would always restart from iteration 1, overwriting adapter files and leaving us back at square one.
This change adds:
Changes
- train_model() and train() updated to track the iteration count
- trainer.py saves the current training step and optimizer state in adapter metadata
- On resume, the --resume-adapter-file checkpoint is loaded and the iteration count and optimizer state are restored
- Filenames of the form 000XXXX_adapters.safetensors continue correctly from the last step
- Added test script test_lora_resume.py that tests saving and resuming from iteration and optimizer metadata
Why this is needed
Without resume functionality, training interruptions due to memory limits or system crashes result in wasted compute time and no way to continue training. This update provides robust continuation support while maintaining compatibility with existing workflows.
Test Plan