Enable memory efficient fine tuning for very long sequences #47

Open
awni wants to merge 4 commits into main from fine_tune_long_sequences

Conversation

@awni awni commented Mar 22, 2025

Use a KV cache to split long sequences into shorter chunks and reduce memory. I didn't add gradient accumulation yet, but we may want it so that the results match for different values of --seq-step-size.

Setting, for example, --seq-step-size 1024 uses far less memory when training on sequences of length 4096. For example:

Without splitting:

mlx_lm.lora --model mlx-community/llama-3.2-1B-Instruct-bf16 --data ../ --train --steps-per-report 5 --max-seq-length 4096
Iter 5: Train loss 1.520, Learning Rate 1.000e-05, It/sec 0.111, Tokens/sec 1825.855, Trained Tokens 81900, Peak mem 69.810 GB

With splitting:

mlx_lm.lora --model mlx-community/llama-3.2-1B-Instruct-bf16 --data ../ --train --steps-per-report 5 --seq-step-size 1024 --max-seq-length 4096
Iter 5: Train loss 1.496, Learning Rate 1.000e-05, It/sec 0.136, Tokens/sec 2220.005, Trained Tokens 81840, Peak mem 19.704 GB
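
A minimal sketch of the chunking idea described above (this is not the code in this PR): run the forward and backward pass one chunk at a time, carrying attention keys/values for earlier chunks in a prompt cache so they act as constants for the backward pass. The `make_prompt_cache` helper and the `cache=` keyword on the model's forward pass are assumed to follow mlx_lm's conventions.

```python
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.cache import make_prompt_cache  # assumed helper


def chunk_loss(model, inputs, targets, cache):
    # Next-token cross entropy for one chunk; earlier chunks are visible
    # to attention only through the KV cache.
    logits = model(inputs, cache=cache)
    return nn.losses.cross_entropy(logits, targets).mean()


def train_step_chunked(model, optimizer, tokens, step_size):
    # tokens: (batch, seq_len) integer array
    inputs, targets = tokens[:, :-1], tokens[:, 1:]
    cache = make_prompt_cache(model)
    loss_and_grad = nn.value_and_grad(model, chunk_loss)
    for start in range(0, inputs.shape[1], step_size):
        end = start + step_size
        loss, grads = loss_and_grad(
            model, inputs[:, start:end], targets[:, start:end], cache
        )
        # Peak memory scales with `step_size` rather than the full
        # sequence length. This variant takes an optimizer step per
        # chunk, i.e. the behavior before gradient accumulation.
        optimizer.update(model, grads)
        mx.eval(model.parameters(), optimizer.state)
    return loss
```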

@angeloskath angeloskath left a comment

Awesome!

awni commented Mar 24, 2025

Interestingly, gradient accumulation seems to make a big difference in convergence:

Waiting for a few more iterations:

Initial validation loss:

Iter 1: Val loss 1.343, Val took 885.710s

Without accumulation:

Iter 10: Train loss 1.191, Learning Rate 1.000e-05, It/sec 0.012, Tokens/sec 555.236, Trained Tokens 468769, Peak mem 61.743 GB
Iter 20: Train loss 1.604, Learning Rate 1.000e-05, It/sec 0.016, Tokens/sec 581.244, Trained Tokens 825642, Peak mem 61.743 GB
Iter 30: Train loss 1.484, Learning Rate 1.000e-05, It/sec 0.013, Tokens/sec 554.594, Trained Tokens 1264592, Peak mem 61.743 GB
Iter 40: Train loss 1.263, Learning Rate 1.000e-05, It/sec 0.018, Tokens/sec 587.876, Trained Tokens 1596749, Peak mem 61.743 GB

With accumulation:

Iter 10: Train loss 1.240, Learning Rate 1.000e-05, It/sec 0.012, Tokens/sec 553.871, Trained Tokens 468769.0, Peak mem 61.736 GB
Iter 20: Train loss 1.222, Learning Rate 1.000e-05, It/sec 0.016, Tokens/sec 581.378, Trained Tokens 825642.0, Peak mem 61.736 GB
Iter 30: Train loss 1.183, Learning Rate 1.000e-05, It/sec 0.013, Tokens/sec 552.957, Trained Tokens 1264592.0, Peak mem 61.736 GB
Iter 40: Train loss 1.094, Learning Rate 1.000e-05, It/sec 0.018, Tokens/sec 588.618, Trained Tokens 1596749.0, Peak mem 61.736 GB

awni commented Mar 24, 2025

I changed this to use gradient accumulation instead of taking an optimization step after each sequence chunk. It seems to converge better for very long sequences, which sort of makes sense since the momentum term doesn't get dominated by a single long sequence.
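
A hedged sketch of the accumulation variant (again, not the PR's actual code, and reusing the assumed `make_prompt_cache` / `cache=` conventions from the earlier sketch): sum the per-chunk gradients and apply a single optimizer update per sequence.

```python
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map
from mlx_lm.models.cache import make_prompt_cache  # assumed helper


def chunk_loss(model, inputs, targets, cache):
    logits = model(inputs, cache=cache)
    return nn.losses.cross_entropy(logits, targets).mean()


def train_step_accumulated(model, optimizer, tokens, step_size):
    inputs, targets = tokens[:, :-1], tokens[:, 1:]
    cache = make_prompt_cache(model)
    loss_and_grad = nn.value_and_grad(model, chunk_loss)
    acc_grads, total_loss, n_chunks = None, 0.0, 0
    for start in range(0, inputs.shape[1], step_size):
        end = start + step_size
        loss, grads = loss_and_grad(
            model, inputs[:, start:end], targets[:, start:end], cache
        )
        # Accumulate gradients across chunks instead of updating per chunk.
        acc_grads = grads if acc_grads is None else tree_map(mx.add, acc_grads, grads)
        total_loss += loss.item()
        n_chunks += 1
    # Average over chunks (a full implementation would weight by the number
    # of tokens in each chunk) and take one optimizer step per sequence.
    acc_grads = tree_map(lambda g: g / n_chunks, acc_grads)
    optimizer.update(model, acc_grads)
    mx.eval(model.parameters(), optimizer.state)
    return total_loss / n_chunks
```

With one update per sequence, the optimizer's momentum statistics stay comparable to the unsplit run, which is consistent with the convergence difference reported above.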

@awni awni force-pushed the fine_tune_long_sequences branch 3 times, most recently from 5f0024f to 55c4e2b on March 27, 2025 at 13:38
@awni awni force-pushed the fine_tune_long_sequences branch from 55c4e2b to 34940c5 on March 27, 2025 at 15:36