|
| 1 | +# Copyright 2023–2026 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +""" |
| 16 | +This script re-shards a MaxText checkpoint on CPU, assuming linen format. |
| 17 | +- The Orbax checkpoint is streamed from storage directly into the target sharded layout on a simulated CPU mesh, |
| 18 | + and then saved to a new checkpoint. |
| 19 | +- The goal is to pre-shard checkpoints (source) to accelerate loading on TPUs (target) by reducing re-sharding overhead. |
| 20 | + E.g., when target sharding is fsdp=64, checkpoint loading time varies across source sharding (fsdp=64 < fsdp=16 < ep=16) |
| 21 | +
|
| 22 | +Key Parameters: |
| 23 | +- `--simulated_cpu_devices_count` (defaults to 16). Examples: |
| 24 | + - **Suitable for most cases**: `--simulated_cpu_devices_count=16 ici_fsdp_parallelism=16` |
| 25 | + - More customization: `--simulated_cpu_devices_count=32 ici_fsdp_parallelism=16 ici_expert_parallelism=2` |
| 26 | +- `weight_dtype`: The dtype used to load and save the checkpoint. **Highly recommend** using `weight_dtype=bfloat16`. |
| 27 | +- `load_parameters_path`: The input checkpoint path (GCS or local). |
| 28 | +- `base_output_directory`: The output directory (GCS or local). |
| 29 | + - The output checkpoint path will be `<base_output_directory>/0/items` |
| 30 | +
|
| 31 | +Memory Requirements: |
| 32 | +- For X billion parameters, needs slightly over 2X GB RAM (each param takes 2 bytes with `weight_dtype=bfloat16`). |
| 33 | +- Note: We only hold one model copy in memory, as the re-sharding happens dynamically during the read operation. |
| 34 | + Additional buffer memory is needed mainly for the I/O streaming overhead, usually small compared to model weight. |
| 35 | +- Example: DeepSeek-V3 with MTP layers has 685B parameters, uses 1.37 TB for weights, and hits a peak RAM of ~1.45 TB. |
| 36 | +
|
| 37 | +Example Commands: |
| 38 | +
|
| 39 | +python3 -m maxtext.checkpoint_conversion.reshard_checkpoint \ |
| 40 | + model_name=deepseek2-16b attention=dot_product mla_naive_kvcache=false \ |
| 41 | + scan_layers=True load_parameters_path=<input_ckpt_path> \ |
| 42 | + base_output_directory=<output_ckpt_dir> \ |
| 43 | + weight_dtype=bfloat16 \ |
| 44 | + checkpoint_storage_concurrent_gb=1024 checkpoint_storage_use_ocdbt=True checkpoint_storage_use_zarr3=True \ |
| 45 | + skip_jax_distributed_system=True ici_fsdp_parallelism=16 \ |
| 46 | + --simulated_cpu_devices_count=16 |
| 47 | +
|
| 48 | +python3 -m maxtext.checkpoint_conversion.reshard_checkpoint \ |
| 49 | + model_name=deepseek3-671b mtp_num_layers=1 mtp_loss_scaling_factor=0.1 attention=dot_product mla_naive_kvcache=false \ |
| 50 | + scan_layers=True load_parameters_path=<input_ckpt_path> \ |
| 51 | + base_output_directory=<output_ckpt_dir> \ |
| 52 | + weight_dtype=bfloat16 \ |
| 53 | + checkpoint_storage_concurrent_gb=1024 checkpoint_storage_use_ocdbt=True checkpoint_storage_use_zarr3=True \ |
| 54 | + skip_jax_distributed_system=True ici_fsdp_parallelism=16 ici_expert_parallelism=2 \ |
| 55 | + --simulated_cpu_devices_count=32 |
| 56 | +""" |
| 57 | + |
| 58 | + |
| 59 | +import argparse |
| 60 | +import os |
| 61 | +import sys |
| 62 | +import time |
| 63 | +from typing import Sequence |
| 64 | +from absl import app |
| 65 | + |
| 66 | +import jax |
| 67 | +from flax.training import train_state |
| 68 | + |
| 69 | +from maxtext.configs import pyconfig |
| 70 | +from maxtext.inference.maxengine import maxengine |
| 71 | +from maxtext.utils import max_utils, max_logging |
| 72 | +from maxtext.common import checkpointing |
| 73 | +from maxtext.checkpoint_conversion.utils.utils import print_peak_memory |
| 74 | + |
| 75 | + |
| 76 | +def main(argv: Sequence[str]) -> None: |
| 77 | + config = pyconfig.initialize(argv) |
| 78 | + max_utils.print_system_information() |
| 79 | + max_logging.log(f"Load and save checkpoint with weight dtype: {config.weight_dtype}") |
| 80 | + |
| 81 | + # 1. Engine sets up the mesh based on config |
| 82 | + engine = maxengine.MaxEngine(config) |
| 83 | + rng = jax.random.PRNGKey(1234) |
| 84 | + rng, rng_load_params = jax.random.split(rng) |
| 85 | + |
| 86 | + # 2. Load parameters and reshard with the mesh |
| 87 | + start = time.time() |
| 88 | + params = engine.load_params(rng_load_params) |
| 89 | + max_logging.log(f"Elapse for checkpoint load (with reshard): {(time.time() - start) / 60:.2f} min") |
| 90 | + |
| 91 | + # 3. Save checkpoint |
| 92 | + start = time.time() |
| 93 | + save_ckpt_directory = config.base_output_directory |
| 94 | + |
| 95 | + # Dummy configs for the checkpoint_manager |
| 96 | + step_number = 0 |
| 97 | + enable_checkpointing = True |
| 98 | + async_checkpointing = False |
| 99 | + save_interval_steps = 1 |
| 100 | + |
| 101 | + checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( |
| 102 | + save_ckpt_directory, |
| 103 | + enable_checkpointing, |
| 104 | + async_checkpointing, |
| 105 | + save_interval_steps, |
| 106 | + use_ocdbt=config.checkpoint_storage_use_ocdbt, |
| 107 | + use_zarr3=config.checkpoint_storage_use_zarr3, |
| 108 | + ) |
| 109 | + if checkpoint_manager is None: |
| 110 | + raise RuntimeError("Failed to create Orbax checkpoint manager.") |
| 111 | + |
| 112 | + state_new = train_state.TrainState( |
| 113 | + step=step_number, apply_fn=None, params=params, tx=None, opt_state={} # type: ignore |
| 114 | + ) |
| 115 | + |
| 116 | + if checkpointing.save_checkpoint(checkpoint_manager, step_number, state_new): |
| 117 | + save_ckpt_path = os.path.join(save_ckpt_directory, str(step_number), "items") |
| 118 | + max_logging.log(f"Saved checkpoint: {save_ckpt_path}") |
| 119 | + # Upon preemption, exit when and only when all ongoing saves are complete. |
| 120 | + checkpoint_manager.wait_until_finished() |
| 121 | + |
| 122 | + max_logging.log(f"Elapse for checkpoint save: {(time.time() - start) / 60:.2f} min") |
| 123 | + print_peak_memory() |
| 124 | + |
| 125 | + |
| 126 | +if __name__ == "__main__": |
| 127 | + jax.config.update("jax_default_prng_impl", "unsafe_rbg") |
| 128 | + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" # Suppress TensorFlow logging |
| 129 | + |
| 130 | + # Define local parser |
| 131 | + parser = argparse.ArgumentParser() |
| 132 | + parser.add_argument( |
| 133 | + "--simulated_cpu_devices_count", |
| 134 | + type=int, |
| 135 | + required=False, |
| 136 | + default=16, |
| 137 | + help="Number of simulated CPU devices for sharding the checkpoint", |
| 138 | + ) |
| 139 | + |
| 140 | + # Parse known args returns the namespace AND the list of remaining arguments |
| 141 | + local_args, remaining_args = parser.parse_known_args() |
| 142 | + |
| 143 | + # Reconstruct model_args (script name + the args MaxText needs) |
| 144 | + model_args = [sys.argv[0]] + remaining_args |
| 145 | + |
| 146 | + # Set JAX environment |
| 147 | + jax.config.update("jax_platforms", "cpu") |
| 148 | + # Simulate CPU devices as virtual mesh |
| 149 | + os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={local_args.simulated_cpu_devices_count}" |
| 150 | + |
| 151 | + app.run(main, argv=model_args) |
0 commit comments