@@ -137,8 +137,11 @@ def __init__(self, cfg: DictConfig) -> None:
        )
        self._log_peak_memory_stats = False

+        # Distributed variables
        self.world_size, self.rank = utils.get_world_size_and_rank()
        self._is_rank_zero = self.rank == 0
+        self.nnodes = self.world_size // torch.cuda.device_count()  # number of nodes, assuming one process per GPU
+        self.enable_tensor_parallel = cfg.get("enable_tensor_parallel", False)

        # Training cfg
        self._resume_from_checkpoint = cfg.resume_from_checkpoint
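
For intuition, here is a minimal sketch (not part of the patch) of how these variables combine later in `_setup_model`: with 2 nodes of 8 GPUs each and one process per GPU, data parallelism runs across nodes and tensor parallelism stays within a node. The numbers below are illustrative.

```python
# Illustrative arithmetic only (assumes one process per GPU, 2 nodes x 8 GPUs).
world_size = 16
nnodes = 2
mesh_shape = (nnodes, world_size // nnodes)  # ("dp", "tp") == (2, 8)
assert mesh_shape[0] * mesh_shape[1] == world_size
```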
@@ -521,21 +524,22 @@ def _setup_model(
                model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
            )

-        # Apply TP if specified
-        mesh_shape = (1, 8)
-        device_mesh = init_device_mesh(
-            "cuda", tp_mesh_shape, mesh_dim_names=("dp", "tp")
-        )
-
-        # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor paralell
-        training.prepare_mha_for_tp(model, device_mesh["tp"])
-        parallelize_module(
-            model,
-            device_mesh["tp"],
-            parallelize_plan=config.instantiate(cfg.parallelize_plan),
-        )
+        device_mesh = {}
+        if self.enable_tensor_parallel:
+            mesh_shape = (self.nnodes, self.world_size // self.nnodes)
+            device_mesh = init_device_mesh(
+                "cuda", mesh_shape, mesh_dim_names=("dp", "tp")
+            )
+            # Use the local numbers (num_heads, num_kv_heads, embed_dim) to account for tensor parallelism
+            training.prepare_mha_for_tp(model, device_mesh["tp"])
+            # Apply tensor parallelism to the model
+            parallelize_module(
+                model,
+                device_mesh["tp"],
+                parallelize_plan=config.instantiate(cfg.parallelize_plan),
+            )

-        # For FSDP sharding
+        # Shard the model
        fsdp_shard_conditions = [
            partial(
                training.get_shard_conditions,
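
`cfg.parallelize_plan` is instantiated from the config and is not shown in this diff. As a rough sketch of what such a plan can look like with the `torch.distributed.tensor.parallel` API, assuming a standard transformer layer with hypothetical submodule names (`q_proj`, `output_proj`, `w1`/`w2`/`w3`):

```python
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Hypothetical per-layer plan: column-parallel up-projections, row-parallel down-projections.
# Module names are assumptions; the real plan comes from cfg.parallelize_plan.
layer_tp_plan = {
    "attn.q_proj": ColwiseParallel(),
    "attn.k_proj": ColwiseParallel(),
    "attn.v_proj": ColwiseParallel(),
    "attn.output_proj": RowwiseParallel(),
    "mlp.w1": ColwiseParallel(),
    "mlp.w3": ColwiseParallel(),
    "mlp.w2": RowwiseParallel(),
}

# Applied per transformer layer over the "tp" sub-mesh:
# for layer in model.layers:
#     parallelize_module(layer, device_mesh["tp"], layer_tp_plan)
```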
@@ -547,7 +551,7 @@ def _setup_model(
            shard_conditions=fsdp_shard_conditions,
            cpu_offload=fsdp_cpu_offload,
            reshard_after_forward=reshard_after_forward,
-            device_mesh=device_mesh["dp"],
+            dp_device_mesh=device_mesh["dp"] if self.enable_tensor_parallel else None,
        )

        with training.set_default_dtype(self._dtype), self._device:
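
The sharding call above now receives only the `"dp"` sub-mesh, so parameters are sharded across the data-parallel dimension while tensor parallelism stays within a node. A minimal sketch of that composition using the FSDP2 `fully_shard` API, assuming a 2 x 8 mesh, an already-initialized process group, and a `model` with a `layers` container (none of which are shown in this diff):

```python
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh

# Assumes torchrun with 16 ranks and a transformer-style `model` built beforehand.
mesh = init_device_mesh("cuda", (2, 8), mesh_dim_names=("dp", "tp"))

# Tensor parallelism would be applied per layer over mesh["tp"] first (see the plan
# sketch earlier); FSDP then shards parameters over the orthogonal mesh["dp"] dim.
for layer in model.layers:
    fully_shard(layer, mesh=mesh["dp"], reshard_after_forward=True)
fully_shard(model, mesh=mesh["dp"], reshard_after_forward=True)
```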