Add checkpoint resharding script for faster loading #3801
shuningjin wants to merge 1 commit into main
Conversation
Force-pushed from d549c06 to fa776ba
🤖 Hi @shuningjin, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
This Pull Request introduces a new script, `reshard_checkpoint.py`, designed to re-shard MaxText checkpoints on CPU. This utility is highly effective for reducing checkpoint loading times on TPUs, as demonstrated by the significant performance gains reported for DeepSeek-V3. The PR also includes minor robustness improvements and bug fixes in `llama_or_mistral_ckpt.py`.
🔍 General Feedback
- Performance: The reported 10x reduction in loading time (from 60 min to 6 min) for DeepSeek-V3 is a major improvement for large-scale model training and inference.
- Initialization Timing: A key concern is the timing of JAX initialization in the new script. Setting environment variables like `XLA_FLAGS` after importing JAX-dependent modules may lead to them being ignored if the XLA backend has already been initialized.
- Flexibility: Adding a way to specify or preserve the `step_number` would enhance the utility of the resharding script.
from maxtext.utils import max_utils, max_logging
from maxtext.common import checkpointing
from maxtext.checkpoint_conversion.utils.utils import print_peak_memory
# Dummy configs for the checkpoint_manager
step_number = config.step if hasattr(config, 'step') else 0
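If the goal is to preserve an existing step rather than defaulting to 0, one option is to derive it from the checkpoint directory itself. The sketch below is an assumption, not part of the PR: the helper name and the convention that Orbax step directories are plain integers are both hypothetical.

```python
import os
import re


def latest_step(ckpt_dir: str) -> int:
  """Return the largest numeric step directory under an Orbax-style
  checkpoint root, or 0 if none exist. Hypothetical helper, not in the PR."""
  steps = [int(d) for d in os.listdir(ckpt_dir) if re.fullmatch(r"\d+", d)]
  return max(steps, default=0)
```

This keeps the resharded checkpoint's step aligned with the source checkpoint without requiring a new flag.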
Supplementing the previous review with the missed comment on JAX initialization timing. Overall, the PR is very valuable for optimizing large model checkpoints.
🔍 General Feedback
- Initialization Timing: Setting `XLA_FLAGS` before JAX imports ensures the simulated CPU mesh is correctly established.
# Set JAX environment
jax.config.update("jax_platforms", "cpu")
# Simulate CPU devices as virtual mesh
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}"
if __name__ == "__main__":
  # Define local parser
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--simulated_cpu_devices_count",
      type=int,
      required=False,
      default=16,
      help="Number of simulated CPU devices for sharding the checkpoint",
  )
  # parse_known_args returns the namespace AND the list of remaining arguments
  local_args, remaining_args = parser.parse_known_args()
  # Set JAX environment BEFORE any jax imports if possible,
  # or at least before any jax calls that might trigger initialization.
  os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}"
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # Suppress TensorFlow logging
  import jax
  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
  jax.config.update("jax_platforms", "cpu")
  # Reconstruct model_args (script name + the args MaxText needs)
  model_args = [sys.argv[0]] + remaining_args
  app.run(main, argv=model_args)
Description
This script re-shards a MaxText checkpoint on CPU. The goal is to pre-shard checkpoints (source) to accelerate loading on TPUs (target) by reducing re-sharding overhead.
FIXES: b/504714612
Introduction
Problem: In checkpoint conversion, we typically shard along the 0th dimension (usually the expert dimension for MoE). Consequently, loading is fast when the target sharding is EP (e.g., a few minutes), but noticeably slow for FSDP (e.g., an hour). This is a major bottleneck because FSDP is our most common use case.
Effectiveness: Our experiments show that pre-sharding a checkpoint to fsdp=16 reduces the loading time of DeepSeek-V3 from 60 minutes to 6 minutes on a v5p-128 cluster targeting fsdp=64. Furthermore, the solution scales efficiently to 1k v7x chips, keeping load time to about 10 minutes.
Generalizability: While this was built to solve the FSDP loading bottleneck, the solution generalizes to pre-sharding checkpoints into other target sharding layouts.
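To make "target sharding layout" concrete in JAX terms, here is a minimal sketch: the same MoE weight of shape (num_experts, rows, cols) can be annotated with different `PartitionSpec`s depending on the target. The axis names `"expert"` and `"fsdp"` are illustrative assumptions, not MaxText's exact mesh axis names.

```python
from jax.sharding import PartitionSpec as P

# Illustrative specs for an MoE weight of shape (num_experts, rows, cols);
# axis names here are assumptions, not MaxText's exact configuration.
EP_SPEC = P("expert", None, None)   # shard dim 0 across expert-parallel devices
FSDP_SPEC = P(None, "fsdp", None)   # shard a non-expert dim across fsdp devices
```

Loading a checkpoint saved in the first layout into a mesh expecting the second is what triggers the expensive on-the-fly resharding that this script moves offline.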
Method
The Orbax checkpoint is streamed from storage directly into the target sharded layout on a simulated CPU mesh, and then saved to a new checkpoint.
Key operation trace: `maxengine.load_params` -> `maxtext_utils.setup_decode_state` -> `checkpointing.load_params_from_path` -> `orbax.checkpoint.Checkpointer.restore`
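The simulated-CPU-mesh idea can be sketched independently of MaxText. This is a minimal sketch under stated assumptions: a fresh process (so the XLA flag is set before JAX initializes its backend), and an arbitrary device count and axis name.

```python
import os

# Must run before JAX initializes its backend (fresh process assumed).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.config.update("jax_platforms", "cpu")

# Arrange the 8 simulated CPU devices as a 1-D "fsdp" mesh.
mesh = Mesh(np.array(jax.devices()).reshape(8), ("fsdp",))

# Place an array in the target sharded layout: dim 0 split across "fsdp".
x = jax.device_put(jnp.zeros((8, 4)), NamedSharding(mesh, P("fsdp", None)))
print(len(jax.devices()), x.addressable_shards[0].data.shape)
```

In the actual script, the restore step streams each tensor from storage directly into such a target layout instead of materializing it unsharded first.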
User Guide
Full details are in the docstring.
Key Parameters:
- `--simulated_cpu_devices_count` (defaults to 16). Examples:
  - `--simulated_cpu_devices_count=16 ici_fsdp_parallelism=16`
  - `--simulated_cpu_devices_count=32 ici_fsdp_parallelism=16 ici_expert_parallelism=2`
- `weight_dtype`: the dtype used to load and save the checkpoint. Highly recommend using `weight_dtype=bfloat16`.

Memory Requirements:
- (with `weight_dtype=bfloat16`).

Tests
deepseek3-671b with mtp
Full test details in b/504714612 (comment 3, comment 8)
deepseek2-16b
Reshard:
Inspect structure:
`forward_pass_logit_checker`, load with target sharding fsdp=16:
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.