Commit fa776ba

Add checkpoint resharding script for faster loading
1 parent 1adb45c commit fa776ba

2 files changed: 158 additions & 6 deletions

src/maxtext/checkpoint_conversion/reshard_checkpoint.py

Lines changed: 151 additions & 0 deletions
@@ -0,0 +1,151 @@
# Copyright 2023–2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
16+
This script re-shards a MaxText checkpoint on CPU, assuming linen format.
17+
- The Orbax checkpoint is streamed from storage directly into the target sharded layout on a simulated CPU mesh,
18+
and then saved to a new checkpoint.
19+
- The goal is to pre-shard checkpoints (source) to accelerate loading on TPUs (target) by reducing re-sharding overhead.
20+
E.g., when target sharding is fsdp=64, checkpoint loading time varies across source sharding (fsdp=64 < fsdp=16 < ep=16)
21+
22+
Key Parameters:
23+
- `--simulated_cpu_devices_count` (defaults to 16). Examples:
24+
- **Suitable for most cases**: `--simulated_cpu_devices_count=16 ici_fsdp_parallelism=16`
25+
- More customization: `--simulated_cpu_devices_count=32 ici_fsdp_parallelism=16 ici_expert_parallelism=2`
26+
- `weight_dtype`: The dtype used to load and save the checkpoint. **Highly recommend** using `weight_dtype=bfloat16`.
27+
- `load_parameters_path`: The input checkpoint path (GCS or local).
28+
- `base_output_directory`: The output directory (GCS or local).
29+
- The output checkpoint path will be `<base_output_directory>/0/items`
30+
31+
Memory Requirements:
32+
- For X billion parameters, needs slightly over 2X GB RAM (each param takes 2 bytes with `weight_dtype=bfloat16`).
33+
- Note: We only hold one model copy in memory, as the re-sharding happens dynamically during the read operation.
34+
Additional buffer memory is needed mainly for the I/O streaming overhead, usually small compared to model weight.
35+
- Example: DeepSeek-V3 with MTP layers has 685B parameters, uses 1.37 TB for weights, and hits a peak RAM of ~1.45 TB.
36+
37+
Example Commands:
38+
39+
python3 -m maxtext.checkpoint_conversion.reshard_checkpoint \
40+
model_name=deepseek2-16b attention=dot_product mla_naive_kvcache=false \
41+
scan_layers=True load_parameters_path=<input_ckpt_path> \
42+
base_output_directory=<output_ckpt_dir> \
43+
weight_dtype=bfloat16 \
44+
checkpoint_storage_concurrent_gb=1024 checkpoint_storage_use_ocdbt=True checkpoint_storage_use_zarr3=True \
45+
skip_jax_distributed_system=True ici_fsdp_parallelism=16 \
46+
--simulated_cpu_devices_count=16
47+
48+
python3 -m maxtext.checkpoint_conversion.reshard_checkpoint \
49+
model_name=deepseek3-671b mtp_num_layers=1 mtp_loss_scaling_factor=0.1 attention=dot_product mla_naive_kvcache=false \
50+
scan_layers=True load_parameters_path=<input_ckpt_path> \
51+
base_output_directory=<output_ckpt_dir> \
52+
weight_dtype=bfloat16 \
53+
checkpoint_storage_concurrent_gb=1024 checkpoint_storage_use_ocdbt=True checkpoint_storage_use_zarr3=True \
54+
skip_jax_distributed_system=True ici_fsdp_parallelism=16 ici_expert_parallelism=2 \
55+
--simulated_cpu_devices_count=32
56+
"""

import argparse
import os
import sys
import time
from typing import Sequence
from absl import app

import jax
from flax.training import train_state

from maxtext.configs import pyconfig
from maxtext.inference.maxengine import maxengine
from maxtext.utils import max_utils, max_logging
from maxtext.common import checkpointing
from maxtext.checkpoint_conversion.utils.utils import print_peak_memory

def main(argv: Sequence[str]) -> None:
  config = pyconfig.initialize(argv)
  max_utils.print_system_information()
  max_logging.log(f"Load and save checkpoint with weight dtype: {config.weight_dtype}")

  # 1. The engine sets up the mesh based on the config.
  engine = maxengine.MaxEngine(config)
  rng = jax.random.PRNGKey(1234)
  rng, rng_load_params = jax.random.split(rng)

  # 2. Load the parameters and reshard them onto the mesh.
  start = time.time()
  params = engine.load_params(rng_load_params)
  max_logging.log(f"Elapsed time for checkpoint load (with reshard): {(time.time() - start) / 60:.2f} min")

  # 3. Save the resharded checkpoint.
  start = time.time()
  save_ckpt_directory = config.base_output_directory

  # Dummy configs for the checkpoint_manager.
  step_number = 0
  enable_checkpointing = True
  async_checkpointing = False
  save_interval_steps = 1
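  # With step_number=0, the checkpoint is written under <base_output_directory>/0/items, matching the output
  # path documented in the docstring above.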

  checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
      save_ckpt_directory,
      enable_checkpointing,
      async_checkpointing,
      save_interval_steps,
      use_ocdbt=config.checkpoint_storage_use_ocdbt,
      use_zarr3=config.checkpoint_storage_use_zarr3,
  )
  if checkpoint_manager is None:
    raise RuntimeError("Failed to create Orbax checkpoint manager.")

  # Wrap the params in a TrainState so the saved tree has the structure MaxText expects; apply_fn and tx are
  # unused for a parameter-only save.
  state_new = train_state.TrainState(
      step=step_number, apply_fn=None, params=params, tx=None, opt_state={}  # type: ignore
  )

  if checkpointing.save_checkpoint(checkpoint_manager, step_number, state_new):
    save_ckpt_path = os.path.join(save_ckpt_directory, str(step_number), "items")
    max_logging.log(f"Saved checkpoint: {save_ckpt_path}")
    # Upon preemption, exit when and only when all ongoing saves are complete.
    checkpoint_manager.wait_until_finished()

  max_logging.log(f"Elapsed time for checkpoint save: {(time.time() - start) / 60:.2f} min")
  print_peak_memory()


if __name__ == "__main__":
  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # Suppress TensorFlow logging

  # Define the local parser for script-specific flags.
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "--simulated_cpu_devices_count",
      type=int,
      required=False,
      default=16,
      help="Number of simulated CPU devices for sharding the checkpoint",
  )

  # parse_known_args returns the namespace AND the list of remaining arguments.
  local_args, remaining_args = parser.parse_known_args()

  # Reconstruct model_args (script name + the args MaxText needs).
  model_args = [sys.argv[0]] + remaining_args

  # Set up the JAX environment.
  jax.config.update("jax_platforms", "cpu")
  # Simulate CPU devices as a virtual mesh; XLA_FLAGS must be set before JAX initializes its backends.
  os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}"

  app.run(main, argv=model_args)
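
For readers unfamiliar with the simulated-device trick the script relies on: XLA can expose a single host CPU as many devices via `--xla_force_host_platform_device_count`, and JAX will build a mesh over them, so checkpoint arrays can be laid out exactly as they would be on a real TPU slice. A minimal, self-contained sketch (the mesh axis name and array shape are illustrative, not the mesh MaxText derives from the `ici_*` flags):

import os

# Must run before JAX initializes its backends.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=16"

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

jax.config.update("jax_platforms", "cpu")

devices = np.array(jax.devices())           # 16 simulated CPU devices
mesh = Mesh(devices, axis_names=("fsdp",))  # a single FSDP-like axis, for illustration
sharding = NamedSharding(mesh, PartitionSpec("fsdp"))

# A toy "weight", split across the simulated devices the same way a real mesh would split it.
w = jax.device_put(np.zeros((1024, 512), dtype=np.float32), sharding)
print(w.sharding)                           # NamedSharding over the 16-device mesh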

src/maxtext/checkpoint_conversion/standalone_scripts/llama_or_mistral_ckpt.py

Lines changed: 7 additions & 6 deletions
@@ -1760,17 +1760,18 @@ def save_weights_to_checkpoint(
       use_ocdbt=use_ocdbt,
       use_zarr3=use_zarr3,
   )
+  if checkpoint_manager is None:
+    raise RuntimeError("Failed to create Orbax checkpoint manager.")
 
   state_new = train_state.TrainState(
-      step=0, apply_fn=None, params={"params": jax_weights}, tx=None, opt_state={}  # type: ignore
+      step=step_number_to_save_new_ckpt, apply_fn=None, params={"params": jax_weights}, tx=None, opt_state={}  # type: ignore
   )
 
   logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
-  if checkpoint_manager is not None:
-    if checkpointing.save_checkpoint(checkpoint_manager, step_number_to_save_new_ckpt, state_new):
-      max_logging.log(f"saved a checkpoint at step {step_number_to_save_new_ckpt}")
-      # Upon preemption, exit when and only when all ongoing saves are complete.
-      checkpoint_manager.wait_until_finished()
+  if checkpointing.save_checkpoint(checkpoint_manager, step_number_to_save_new_ckpt, state_new):
+    max_logging.log(f"saved a checkpoint at step {step_number_to_save_new_ckpt}")
+    # Upon preemption, exit when and only when all ongoing saves are complete.
+    checkpoint_manager.wait_until_finished()
 
   max_logging.log(f"Elapse for checkpoint save: {(time.time() - start) / 60:.2f} min")
 
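For context on the step-number fix in this hunk: Orbax's CheckpointManager lays checkpoints out in step-numbered subdirectories, so saving at the requested step instead of a hardcoded 0 changes where the checkpoint lands. A minimal sketch using plain orbax-checkpoint (the directory and pytree are illustrative; MaxText's manager additionally names the saved item `items`, which is why the script above documents `<base_output_directory>/0/items`):

import numpy as np
import orbax.checkpoint as ocp

ckpt_dir = "/tmp/reshard_demo"  # illustrative path
mngr = ocp.CheckpointManager(ckpt_dir)
tree = {"params": {"w": np.arange(3, dtype=np.float32)}}

mngr.save(5, args=ocp.args.StandardSave(tree))  # written under /tmp/reshard_demo/5/
mngr.wait_until_finished()                      # block until any async save completes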