Save memory for checkpoint utility with abstract structure #2609
Conversation
Thanks for the great fix! Do you know how much memory it saved (%) after this fix?
🤖 Hi @shuningjin, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This pull request introduces a significant memory optimization to the checkpoint conversion utility by using jax.eval_shape to avoid materializing the full model. This is a crucial improvement for handling very large models, preventing potential out-of-memory errors.
🔍 General Feedback
- The implementation is clean, and the use of `jax.eval_shape` is appropriate for the problem.
- The new function `get_abstract_param` is well-designed and placed in a reusable location.
- The robustness improvement in handling parameter paths is a good addition.
- The pull request description includes comprehensive testing, which demonstrates the effectiveness and correctness of the changes.
Overall, this is a high-quality contribution that addresses a real-world problem effectively.
LGTM
I have added the memory analysis and experiments to the PR description. For a model with a total of $x$ billion parameters, we theoretically save about $12x$ GB (see Memory Analysis below).
Description
The checkpoint utility `to_maxtext.py` instantiates a concrete MaxText training state with `maxtext_utils.setup_training_state`, which results in large memory use (note there is a delay in the htop view). For instance, during `to_maxtext.py`, deepseek3-671b has memory usage exceeding 3.7 TB on CPU, leading to the program being killed. (DeepSeek3-to-MaxText conversion is still under development.)

To solve this problem, we obtain the abstract structure (parameter names and shapes) through `jax.eval_shape` instead of materializing the weights.
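A minimal sketch of the idea (illustrative only; the initializer and variable names below are hypothetical stand-ins, not the exact code from this PR):

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the real initializer; in the utility,
# maxtext_utils.setup_training_state builds the full MaxText train state.
def init_params(rng):
  return {
      "decoder": {"embedding": jax.random.normal(rng, (32000, 4096), dtype=jnp.bfloat16)},
      "lm_head": {"kernel": jax.random.normal(rng, (4096, 32000), dtype=jnp.bfloat16)},
  }

rng = jax.random.PRNGKey(0)

# jax.eval_shape traces init_params without allocating the weights: every leaf
# of the result is a jax.ShapeDtypeStruct carrying only shape and dtype.
abstract_params = jax.eval_shape(init_params, rng)

# Flatten to (path, ShapeDtypeStruct) pairs; the conversion only needs the
# MaxText parameter names and shapes, not concrete arrays.
abstract_params_flat = jax.tree_util.tree_flatten_with_path(abstract_params)[0]
for path_tuple, abstract_leaf_value in abstract_params_flat:
  print(path_tuple, abstract_leaf_value.shape, abstract_leaf_value.dtype)
```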
Memory Analysis
Estimate: for a model with a total of $x$ billion parameters, we would theoretically save $12x$ GB and have $8x$ GB of current use.

Estimate 1: we can save $12x$ GB
- This is the memory cost of `setup_training_state`, stored in `abstract_state`.
- What if we only get rid of the optimizer state (e.g. `setup_decode_state`, which drops the optimizer states)? We would still hold a full copy of the weights.
- Our solution gets rid of all MaxText weights: we keep `abstract_params_flat` throughout the program to access the MaxText parameter names and shapes.
  - Before: in `for path_tuple, abstract_leaf_value in abstract_params_flat:`, `abstract_leaf_value` is a concrete array, i.e. all MaxText weights are kept in memory.
  - After: `abstract_leaf_value` is a `jax.ShapeDtypeStruct`, a container that holds only the shape and dtype of an array.

Estimate 2: the current cost is $8x$ GB (with `weight_dtype=bfloat16`, it would be …).
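For intuition on the $12x$ GB figure (an interpretation, not spelled out in the PR): it matches a float32 train state under an AdamW-style optimizer, i.e. weights plus two moment buffers at 4 bytes per parameter each, $3 \times 4x = 12x$ GB. For the llama3.1-70b test case below ($x \approx 70$), the theoretical saving is therefore roughly $12 \times 70 = 840$ GB.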
Tests
memory comparison: 1 & 2
forward logit check: 2 & 3
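The forward logit check compares per-position next-token distributions between the original and converted checkpoints and reports the maximum KL divergence. A rough sketch of how such a number can be computed (assuming logits of shape [seq_len, vocab]; this is not the exact test script):

```python
import jax
import jax.numpy as jnp

def max_kl_divergence(ref_logits, test_logits):
  """Max over sequence positions of KL(ref || test) for [seq_len, vocab] logits."""
  ref_logp = jax.nn.log_softmax(ref_logits, axis=-1)
  test_logp = jax.nn.log_softmax(test_logits, axis=-1)
  # Per-position KL divergence, summed over the vocabulary dimension.
  kl = jnp.sum(jnp.exp(ref_logp) * (ref_logp - test_logp), axis=-1)
  return jnp.max(kl)
```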
1 llama3.1-70b, CPU
conversion
before (e3ddb1a):
2 qwen3-4b, CPU
conversion
before (e3ddb1a):
forward logit check
max KL=0.004159 (I love to): https://paste.googleplex.com/6539267438411776
3 gemma3-4b-multimodal, CPU
conversion
https://paste.googleplex.com/6113203033604096
forward logit check
same test as in b/443777964#comment2
max KL divergence = 0.024395478889346123: https://paste.googleplex.com/5583389885333504
4 qwen3-4b, TPU
It should behave the same when adding the flags `hardware=cpu skip_jax_distributed_system=True`; this tests the case where the flags are absent and the JAX distributed system is initialized.
conversion
https://paste.googleplex.com/4937498778271744
Checklist
Before submitting this PR, please make sure (put X in square brackets):
- [ ] `gemini-review` label.