5757logger .setLevel (logging .INFO )
5858
5959
60- @hydra .main (config_path = "hydra_config" , config_name = "L0_sanity_cp " , version_base = "1.2" )
60+ @hydra .main (config_path = "hydra_config" , config_name = "L2_sanity_nd " , version_base = "1.2" )
6161def main (args : DictConfig ) -> float | None :
6262 """Train Llama3 with TE layers using FSDP2 with Context Parallelism.
6363
@@ -73,8 +73,8 @@ def main(args: DictConfig) -> float | None:
7373
7474 device_mesh = init_device_mesh (
7575 "cuda" ,
76- mesh_shape = (dist_config .world_size // args .cp_size , args .cp_size ),
77- mesh_dim_names = ("dp" , "cp" ),
76+ mesh_shape = (dist_config .world_size // ( args .cp_size * args . tp_size ) , args .cp_size , args . tp_size ),
77+ mesh_dim_names = ("dp" , "cp" , "tp" ),
7878 )
7979 logger .info ("Created device mesh: %s" , device_mesh )
8080
@@ -85,6 +85,22 @@ def main(args: DictConfig) -> float | None:
8585
8686 # --- Model Initialization ---
8787 config = NVLlamaConfig .from_pretrained (args .config_name_or_path , dtype = torch .bfloat16 , ** args .config_kwargs )
88+
89+ # Identify the DeviceMeshes that are propagated to `set_device_mesh` in TransformerEngine modules.
90+ # These will convert TransformerEngine parameters into DTensors. Alternatively, users can
91+ # manually call the conversion using `TransformerEngineModule.set_device_mesh(...)` before
92+ # `reset_parameters` (which triggers quantization) if the module supports DTensor parameters.
93+ if config .tensor_parallel :
94+ config .tp_mesh = device_mesh ["tp" ]
95+ if (
96+ args .fp8_config .quantized_model_init_kwargs .enabled
97+ and isinstance (fp8_recipe , transformer_engine .common .recipe .Float8CurrentScaling )
98+ ):
99+ # When using per-tensor FP8 recipes for quantized parameters, TransformerEngine
100+ # requires a weight sharding mesh for absmax reduction across distributed weights.
101+ # If not provided, will default to DTensor.device_mesh.get_group(), which is not
102+ # appropriate if HSDP (DP-Replicate x DP-Shard) is used.
103+ config .weight_mesh = device_mesh ["dp" , "cp" , "tp" ]._flatten ("weight_mesh" )
88104
89105 # Optionally use transformer engine to initialize only fp8 versions of weights by setting
90106 # `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16
@@ -100,6 +116,8 @@ def main(args: DictConfig) -> float | None:
100116 logger .info ("Initialized Model:\n %s" , model )
101117
102118 # --- Distributed Wrapping (FSDP2 + CP) ---
119+
120+ # Create a flattened mesh for FSDP2-CP sharding. This will shard the model across both the DP and CP ranks.
103121 cp_dp_mesh = device_mesh ["dp" , "cp" ]._flatten (mesh_dim_name = "dp_shard_cp" )
104122
105123 # Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers.
@@ -108,7 +126,7 @@ def main(args: DictConfig) -> float | None:
108126 fully_shard (layer , mesh = cp_dp_mesh )
109127 fully_shard (model , mesh = cp_dp_mesh )
110128
111- # Attach the CP group to the model.
129+ # Attach the CP ProcessGroup to the TransformerEngine model.
112130 for layer in model .model .layers :
113131 layer .set_context_parallel_group (
114132 device_mesh ["cp" ].get_group (),
@@ -137,9 +155,12 @@ def main(args: DictConfig) -> float | None:
137155 logger .info ("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2" )
138156 OmegaConf .update (args , "dataset.pad_sequences_to_be_divisible_by" , device_mesh ["cp" ].size () * 2 )
139157
140- # We only create the dataloader on rank 0, which is responsible for loading data for all CP (and eventually TP)
141- # ranks. This ensures that the data remains synchronized, even if we're using a non-deterministic data pipeline.
142- if device_mesh ["cp" ].get_local_rank () == 0 :
158+ # We only create the dataloader on CP-TP rank 0, which is responsible for loading data for all CP (and TP) ranks.
159+ # This ensures that the data remains synchronized, even if we're using a non-deterministic data pipeline.
160+ cp_tp_mesh = device_mesh ["cp" , "tp" ]._flatten (mesh_dim_name = "cp_tp" )
161+ if cp_tp_mesh .get_local_rank () == 0 :
162+ # We only create the dataloader on CP-TP Rank 0 and pass it to a ContextParallelDataLoaderWrapper
163+ # that will shard, replicate, and distribute the data across the flattened CP and TP group.
143164 if args .use_sequence_packing :
144165 train_dataloader , dataset_or_sampler = create_thd_dataloader (dist_config , ** args .dataset )
145166 else :
@@ -156,8 +177,8 @@ def main(args: DictConfig) -> float | None:
156177 train_dataloader = None
157178 dataset_or_sampler = None
158179
159- # On all ranks, we create a ContextParallelDataLoaderWrapper that broadcasts the data from cp rank 0 .
160- train_dataloader = ContextParallelDataLoaderWrapper (train_dataloader , device_mesh [ "cp" ] )
180+ # Deliver CP-sharded replicates to a flattened CP-TP mesh .
181+ train_dataloader = ContextParallelDataLoaderWrapper (train_dataloader , cp_tp_mesh )
161182
162183 # --- Checkpoint Resume ---
163184 ckpt_path = Path (args .checkpoint .ckpt_dir ) / "train_fsdp2" if args .checkpoint .ckpt_dir else None
@@ -170,7 +191,6 @@ def main(args: DictConfig) -> float | None:
170191 ckpt_path = ckpt_path ,
171192 dist_config = dist_config ,
172193 dataloader = train_dataloader ,
173- process_group = cp_dp_mesh .get_group (),
174194 )
175195 logger .info ("Checkpoint loaded, resuming from step %s, epoch %s" , start_step , epoch )
176196 else :
@@ -226,6 +246,13 @@ def main(args: DictConfig) -> float | None:
226246 )
227247
228248 if ckpt_path and should_save_checkpoint (step , args .checkpoint .save_every_n_steps ):
249+ if args .checkpoint .async_save and args .fp8_config .quantized_model_init_kwargs .enabled :
250+ logger .info (
251+ "Asynchronous checkpointing is not supported with TransformerEngine "
252+ "quantized parameters and FSDP2. Using synchronous checkpointing "
253+ "(checkpoint.async_save=false)..."
254+ )
255+ OmegaConf .update (args , "checkpoint.async_save" , False )
229256 save_checkpoint_fsdp2 (
230257 model = model ,
231258 optimizer = optimizer ,
@@ -235,7 +262,6 @@ def main(args: DictConfig) -> float | None:
235262 epoch = epoch ,
236263 dist_config = dist_config ,
237264 dataloader = train_dataloader if args .dataset .use_stateful_dataloader else None ,
238- process_group = cp_dp_mesh .get_group (),
239265 max_checkpoints = args .checkpoint .max_checkpoints ,
240266 async_save = args .checkpoint .async_save ,
241267 )
@@ -268,4 +294,4 @@ def main(args: DictConfig) -> float | None:
268294
269295
270296if __name__ == "__main__" :
271- main ()
297+ main ()
0 commit comments