@@ -137,8 +137,8 @@ def __init__(self, cfg: DictConfig) -> None:
             )
             self._log_peak_memory_stats = False
 
-        _, rank = utils.get_world_size_and_rank()
-        self._is_rank_zero = rank == 0
+        self.world_size, self.rank = utils.get_world_size_and_rank()
+        self._is_rank_zero = self.rank == 0
 
         # Training cfg
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
@@ -521,6 +521,20 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
             )
 
+        # Apply TP if specified
+        mesh_shape = (1, 8)
+        device_mesh = init_device_mesh(
+            "cuda", mesh_shape, mesh_dim_names=("dp", "tp")
+        )
+
+        # Use the local numbers (num_heads, num_kv_heads, embed_dim) to account for tensor parallelism
+        training.prepare_mha_for_tp(model, device_mesh["tp"])
+        parallelize_module(
+            model,
+            device_mesh["tp"],
+            parallelize_plan=config.instantiate(cfg.parallelize_plan),
+        )
+
         # For FSDP sharding
         fsdp_shard_conditions = [
             partial(
@@ -533,6 +547,7 @@ def _setup_model(
             shard_conditions=fsdp_shard_conditions,
             cpu_offload=fsdp_cpu_offload,
             reshard_after_forward=reshard_after_forward,
+            device_mesh=device_mesh["dp"],
         )
 
         with training.set_default_dtype(self._dtype), self._device:
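For context on the pattern these hunks introduce (a 2D device mesh with tensor parallelism over the "tp" dimension and FSDP over the "dp" dimension), here is a minimal standalone sketch, not the recipe's code. It assumes a single node with 8 GPUs launched via `torchrun --nproc_per_node 8`; the toy `FeedForward` module and its column/row-wise plan are illustrative stand-ins for the recipe's model and `cfg.parallelize_plan`, and the `fully_shard` import path may differ across torch versions.

```python
# Hedged sketch of the 2D-mesh (FSDP "dp" x tensor-parallel "tp") pattern above.
import os

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # path may vary by torch version
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class FeedForward(nn.Module):
    """Toy stand-in for a transformer MLP block."""

    def __init__(self, dim: int = 1024, hidden: int = 4096) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(torch.relu(self.w1(x)))


def main() -> None:
    # One process per GPU under torchrun; pin each rank to its local GPU.
    torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))

    # 2D mesh matching the hard-coded (1, 8) shape in the diff:
    # dp (FSDP) size 1, tp (tensor parallel) size 8.
    device_mesh = init_device_mesh("cuda", (1, 8), mesh_dim_names=("dp", "tp"))

    model = FeedForward().cuda()

    # Shard w1 column-wise and w2 row-wise across the "tp" sub-mesh,
    # analogous to parallelize_module(model, device_mesh["tp"], parallelize_plan).
    parallelize_module(
        model,
        device_mesh["tp"],
        parallelize_plan={"w1": ColwiseParallel(), "w2": RowwiseParallel()},
    )

    # Apply FSDP only over the "dp" sub-mesh, analogous to passing
    # device_mesh["dp"] into the recipe's sharding call.
    fully_shard(model, mesh=device_mesh["dp"])

    out = model(torch.randn(2, 1024, device="cuda"))
    print(out.shape)


if __name__ == "__main__":
    main()
```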
@@ -638,8 +653,6 @@ def _setup_data(
         DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
         iterable datasets and streaming datasets are not supported.
         """
-        world_size, rank = utils.get_world_size_and_rank()
-
         if isinstance(cfg_dataset, ListConfig):
             datasets = [
                 config.instantiate(single_cfg_dataset, self._tokenizer)
@@ -657,7 +670,7 @@ def _setup_data(
         collate_fn = _get_component_from_path(collate_fn)
 
         sampler = DistributedSampler(
-            ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
+            ds, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle, seed=0
         )
         dataloader = DataLoader(
             dataset=ds,
@@ -687,8 +700,6 @@ def train(self) -> None:
         # clean up before training begins
         training.cleanup_before_training()
 
-        world_size, rank = utils.get_world_size_and_rank()
-
         # zero out the gradients before starting training
         if not self._optimizer_in_bwd:
             self._optimizer.zero_grad()
@@ -708,7 +719,7 @@ def train(self) -> None:
             # in case shuffle is True
             self._sampler.set_epoch(curr_epoch)
 
-            pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0))
+            pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
             for idx, batch in enumerate(self._dataloader):
                 if (
                     self.max_steps_per_epoch is not None