
Conversation

@shuningjin shuningjin (Collaborator) commented Nov 6, 2025

Description

The checkpoint utility to_maxtext.py instantiates a concrete MaxText training state with maxtext_utils.setup_training_state, which results in large memory use. (Note that there is a delay before this shows up in the htop view.)

For instance, when running to_maxtext.py for deepseek3-671b, memory usage exceeds 3.7 TB on CPU and the program is killed. (The ds3-to-MaxText conversion is still under development.)

To solve this problem, we obtain the abstract structure (parameter names and shapes) through jax.eval_shape (see the sketch after this list):

  • when transforming HF to MaxText, we only need the abstract structure
  • this avoids materializing the model weights as well as the optimizer states
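
Below is a minimal, self-contained sketch of this idea, using a toy Flax module rather than the actual MaxText model (ToyModel, dummy_input, and the variable names are illustrative only). jax.eval_shape traces the init function abstractly, so it returns only jax.ShapeDtypeStruct leaves (shape and dtype per named parameter) and never allocates the parameter arrays:

import jax
import jax.numpy as jnp
from flax import linen as nn

class ToyModel(nn.Module):
  @nn.compact
  def __call__(self, x):
    return nn.Dense(features=128)(x)

model = ToyModel()
dummy_input = jnp.ones((1, 64), dtype=jnp.float32)

# A concrete model.init(...) would allocate real parameter arrays;
# jax.eval_shape only returns their abstract structure (shapes and dtypes).
abstract_variables = jax.eval_shape(model.init, jax.random.PRNGKey(0), dummy_input)

# Flatten into (path, ShapeDtypeStruct) pairs, analogous to abstract_params_flat.
abstract_params_flat = jax.tree_util.tree_flatten_with_path(abstract_variables["params"])[0]
for path_tuple, abstract_leaf_value in abstract_params_flat:
  print(path_tuple, abstract_leaf_value.shape, abstract_leaf_value.dtype)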

Memory Analysis

Estimate: for a model with $x$ billion parameters in total, we would theoretically save $12x$ GB and have a current use of $8x$ GB

  • assuming the default settings: weight_dtype=float32, opt_type=adamw

Estimate 1: we can save $12x$ GB

Memory cost of setup_training_state, stored in abstract_state:

  • 4 bytes (weight in f32) + 4 bytes (optimizer state in f32) * 2 (number of AdamW states) = 12 bytes per parameter
  • 12 bytes * $x$ * 1e9 params / 1e9 = $12x$ GB

What if we get rid of optimizer state? We still have the weight copy.

  • One alternative is setup_decode_state, which gets rid of the optimizer states.
  • We would still have the weight copy: need $4x$ GB for weight_dtype=float32 (or $2x$ GB for weight_dtype=bfloat16)

Our solution gets rid of all MaxText weight copies

  • we use abstract_params_flat throughout the program to access the MaxText parameter names and shapes
  • inside the transform loop: for path_tuple, abstract_leaf_value in abstract_params_flat: (see the sketch below)
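
A hedged, hypothetical sketch of that loop is shown below (transform_params, hf_weights, and hf_to_maxtext_fn are illustrative names, not the actual to_maxtext.py helpers); the point is that only abstract names and shapes are read on the MaxText side, while the concrete tensors come from the HF checkpoint:

import numpy as np

def transform_params(abstract_params_flat, hf_weights, hf_to_maxtext_fn):
  # Map HF tensors onto the MaxText parameter structure without materializing MaxText weights.
  maxtext_weights = {}
  for path_tuple, abstract_leaf_value in abstract_params_flat:
    # Derive a MaxText parameter name from the pytree path.
    param_name = "/".join(str(key) for key in path_tuple)
    # Look up (and possibly reshape/merge) the corresponding HF tensor(s).
    hf_tensor = hf_to_maxtext_fn(param_name, hf_weights)
    assert hf_tensor.shape == abstract_leaf_value.shape, param_name
    maxtext_weights[param_name] = np.asarray(hf_tensor, dtype=abstract_leaf_value.dtype)
  return maxtext_weights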

Estimate 2: the current cost is $8x$ GB

  • $x$ billion params * 4 bytes (f32) * 2 copies (HF copy + new MaxText copy) = $8x$ GB
  • if we further set weight_dtype=bfloat16, this would be $4x$ GB
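
As a concrete worked example of the two estimates (using only the formulas above): llama3.1-70b has $x \approx 70$, so the theoretical save is 12 * 70 = 840 GB and the theoretical current use is 8 * 70 = 560 GB, which can be compared against the empirical numbers in the Tests section below.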

Tests

memory comparison (tests 1 & 2)

  • (default: weight_dtype=float32, opt_type=adamw)
  • llama3.1-70b: the save is 757 GB (1360-603) = 10.8$x$, and the current effective cost is 588 GB (603-15) = 8.4$x$
  • qwen3-4b: the save is 45 GB (106-60.7) = 11.2$x$, and the current effective cost is 38 GB (60.7-23) = 9.5$x$
  • This mostly agrees with the analysis: the empirical save is about $11x$, and the empirical current use is about $8$-$10x$.

forward logit check (tests 2 & 3)

1 llama3.1-70b, CPU

conversion

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=llama3.1-70b scan_layers=false \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True
Transforming weights: 100%|████████████████| 723/723 [00:41<00:00, 17.51param/s, RAM: 301.7/3783.8GB (8.5%)]

before (e3ddb1a):

Transforming weights: 100%|████████████████| 723/723 [01:36<00:00,  7.50param/s, RAM: 1093.4/3783.8GB (29.4%)]

2 qwen3-4b, CPU

conversion

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=qwen3-4b scan_layers=false \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True

before (e3ddb1a):

forward logit check

# tokenizer_type=huggingface is hard-coded in its model config
CKPT=gs://runner-maxtext-logs/2025-11-06-09-48/0/items
python3 -m tests.forward_pass_logit_checker src/MaxText/configs/base.yml \
model_name=qwen3-4b attention=dot_product \
override_model_config=true enable_dropout=false tokenizer_type=huggingface \
load_parameters_path=$CKPT scan_layers=false \
per_device_batch_size=1 max_prefill_predict_length=4 max_target_length=8 \
tokenizer_path=Qwen/Qwen3-4B --run_hf_model=True --hf_model_path=Qwen/Qwen3-4B \
--max_kl_div=0.015 --atol=0.5 --rtol=0.5 \
skip_jax_distributed_system=True

max KL=0.004159 (I love to): https://paste.googleplex.com/6539267438411776

3 gemma3-4b-multimodal, CPU

conversion

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml \
model_name=gemma3-4b use_multimodal=true scan_layers=false \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN \
hardware=cpu skip_jax_distributed_system=True

https://paste.googleplex.com/6113203033604096

forward logit check

same test as in b/443777964#comment2

# gcloud storage cp gs://aireenmei-multipod/golden-logits/google/gemma-3-4b-it/golden_gemma-3-4b-it_image_only.jsonl /tmp/golden_gemma-3-4b-it_image_only.jsonl

export GOLDEN_LOGITS=/tmp/golden_gemma-3-4b-it_image_only.jsonl
export CHECKPOINT_TPU_UNSCANNED=gs://runner-maxtext-logs/2025-11-06-10-17/0/items
export SCAN_LAYERS=false
export MODEL_SIZE=gemma3-4b
export SEQ_LEN=260
IDX=$(date +%Y-%m-%d-%H-%M)
export MULTIMODAL=true

PYTHONPATH=src python3 -m tests.forward_pass_logit_checker \
MaxText/configs/base.yml \
load_parameters_path=${CHECKPOINT_TPU_UNSCANNED} \
per_device_batch_size=1 model_name=$MODEL_SIZE max_prefill_predict_length=$SEQ_LEN \
max_target_length=$SEQ_LEN dtype=float32 activations_in_float32=true \
matmul_precision=highest float32_logits=true float32_qk_product=true async_checkpointing=false scan_layers=$SCAN_LAYERS \
use_multimodal=$MULTIMODAL enable_dropout=False attention=dot_product \
--golden_logits_path=$GOLDEN_LOGITS \
--atol=0.5 --rtol=0.5 \
skip_jax_distributed_system=True

max KL divergence = 0.024395478889346123: https://paste.googleplex.com/5583389885333504

4 qwen3-4b, TPU

This should behave the same as when the flags hardware=cpu skip_jax_distributed_system=True are added. This test covers the case where those flags are absent and the JAX distributed system is initialized.

conversion

BASE_OUTPUT_PATH=gs://runner-maxtext-logs/$(date +%Y-%m-%d-%H-%M); \
echo $BASE_OUTPUT_PATH/0/items; \
python -m MaxText.utils.ckpt_conversion.to_maxtext MaxText/configs/base.yml model_name=qwen3-4b scan_layers=false \
base_output_directory=$BASE_OUTPUT_PATH hf_access_token=$HF_TOKEN

https://paste.googleplex.com/4937498778271744

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@hengtaoguo hengtaoguo (Collaborator) left a comment

Thanks for the great fix! Do you know how much memory it saved (%) after this fix?

@github-actions github-actions bot commented Nov 6, 2025

🤖 Hi @shuningjin, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions github-actions bot left a comment

📋 Review Summary

This pull request introduces a significant memory optimization to the checkpoint conversion utility by using jax.eval_shape to avoid materializing the full model. This is a crucial improvement for handling very large models, preventing potential out-of-memory errors.

🔍 General Feedback

  • The implementation is clean, and the use of jax.eval_shape is appropriate for the problem.
  • The new function get_abstract_param is well-designed and placed in a reusable location.
  • The robustness improvement in handling parameter paths is a good addition.
  • The pull request description includes comprehensive testing, which demonstrates the effectiveness and correctness of the changes.

Overall, this is a high-quality contribution that addresses a real-world problem effectively.

@khatwanimohit khatwanimohit (Collaborator) left a comment

LGTM

@shuningjin shuningjin (Collaborator, Author) commented Nov 6, 2025

Thanks for the great fix! Do you know how much memory it saved (%) after this fix?

I have added a memory analysis and experiments to the PR description. For a model with $x$ billion parameters, it saves $12x$ GB theoretically and about $11x$ GB empirically.

@copybara-service copybara-service bot merged commit 93c15e7 into main Nov 6, 2025
137 of 138 checks passed
@copybara-service copybara-service bot deleted the shuningjin-ckpt-opt branch November 6, 2025 19:48