
Commit df17170

Initial commit
1 parent d4465c8 commit df17170

File tree

2 files changed: +23 -8 lines changed


recipes/full_finetune_distributed.py (+19 -8)
@@ -137,8 +137,8 @@ def __init__(self, cfg: DictConfig) -> None:
         )
         self._log_peak_memory_stats = False
 
-        _, rank = utils.get_world_size_and_rank()
-        self._is_rank_zero = rank == 0
+        self.world_size, self.rank = utils.get_world_size_and_rank()
+        self._is_rank_zero = self.rank == 0
 
         # Training cfg
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
@@ -521,6 +521,20 @@ def _setup_model(
             model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
         )
 
+        # Apply TP if specified
+        mesh_shape = (1, 8)
+        device_mesh = init_device_mesh(
+            "cuda", mesh_shape, mesh_dim_names=("dp", "tp")
+        )
+
+        # Use the local numbers (num_heads, num_kv_heads, embed_dim) to account for tensor parallelism
+        training.prepare_mha_for_tp(model, device_mesh["tp"])
+        parallelize_module(
+            model,
+            device_mesh["tp"],
+            parallelize_plan=config.instantiate(cfg.parallelize_plan),
+        )
+
         # For FSDP sharding
         fsdp_shard_conditions = [
             partial(
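For readers unfamiliar with the torch.distributed.tensor.parallel API used in the hunk above, here is a minimal, self-contained sketch of the same wiring: a 2D device mesh with named "dp" and "tp" dimensions and a column-/row-wise plan applied to a toy MLP. The module names and plan entries are illustrative assumptions, not the plan that cfg.parallelize_plan instantiates in this commit.

```python
# Hedged sketch of the TP wiring above, not the recipe's actual plan:
# cfg.parallelize_plan is instantiated from config, whereas the plan here
# is a hand-written toy. Run under `torchrun --nproc_per_node=8 ...`.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# 2D mesh: 1 data-parallel group x 8 tensor-parallel ranks (dp * tp == world size).
device_mesh = init_device_mesh("cuda", (1, 8), mesh_dim_names=("dp", "tp"))


class ToyMLP(nn.Module):
    def __init__(self, dim: int = 1024) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, 4 * dim, bias=False)
        self.w2 = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.w2(self.w1(x).relu())


model = ToyMLP().cuda()

# Shard w1 column-wise and w2 row-wise over the "tp" sub-mesh, so the
# intermediate activation stays sharded and only w2's output is reduced.
parallelize_module(
    model,
    device_mesh["tp"],
    parallelize_plan={"w1": ColwiseParallel(), "w2": RowwiseParallel()},
)
```

The training.prepare_mha_for_tp call in the diff is torchtune's helper for adjusting attention hyper-parameters to their per-rank values before parallelization, which is what the in-line "local numbers (num_heads, num_kv_heads, embed_dim)" comment refers to.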
@@ -533,6 +547,7 @@ def _setup_model(
             shard_conditions=fsdp_shard_conditions,
             cpu_offload=fsdp_cpu_offload,
             reshard_after_forward=reshard_after_forward,
+            device_mesh=device_mesh["dp"],
         )
 
         with training.set_default_dtype(self._dtype), self._device:
@@ -638,8 +653,6 @@ def _setup_data(
             DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
             iterable datasets and streaming datasets are not supported.
         """
-        world_size, rank = utils.get_world_size_and_rank()
-
         if isinstance(cfg_dataset, ListConfig):
             datasets = [
                 config.instantiate(single_cfg_dataset, self._tokenizer)
@@ -657,7 +670,7 @@ def _setup_data(
             collate_fn = _get_component_from_path(collate_fn)
 
         sampler = DistributedSampler(
-            ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
+            ds, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle, seed=0
         )
         dataloader = DataLoader(
             dataset=ds,
@@ -687,8 +700,6 @@ def train(self) -> None:
         # clean up before training begins
         training.cleanup_before_training()
 
-        world_size, rank = utils.get_world_size_and_rank()
-
        # zero out the gradients before starting training
         if not self._optimizer_in_bwd:
             self._optimizer.zero_grad()
@@ -708,7 +719,7 @@ def train(self) -> None:
             # in case shuffle is True
             self._sampler.set_epoch(curr_epoch)
 
-            pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0))
+            pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
             for idx, batch in enumerate(self._dataloader):
                 if (
                     self.max_steps_per_epoch is not None
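Taken together, the recipe-side hunks above replace repeated utils.get_world_size_and_rank() lookups with attributes cached once in __init__ and a single _is_rank_zero flag used to gate output. A minimal sketch of that pattern, assuming the process group has already been initialized by torchrun (the class and method below are illustrative, not the recipe itself):

```python
# Minimal sketch of caching rank bookkeeping once and gating output on rank 0.
# Assumes torch.distributed is already initialized (e.g. by torchrun).
import torch.distributed as dist
from tqdm import tqdm


class Recipe:
    def __init__(self) -> None:
        # Cache once instead of re-querying in _setup_data() and train().
        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()
        self._is_rank_zero = self.rank == 0

    def train(self, steps_per_epoch: int) -> None:
        # Only rank 0 draws the progress bar; other ranks stay silent.
        pbar = tqdm(total=steps_per_epoch, disable=not self._is_rank_zero)
        for _ in range(steps_per_epoch):
            pbar.update(1)
        pbar.close()
```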

torchtune/training/_distributed.py (+4 -0)
@@ -508,6 +508,7 @@ def shard_model(
     *,
     cpu_offload: bool,
     reshard_after_forward: bool = True,
+    device_mesh: Optional[DeviceMesh] = None,
 ) -> None:
     """
     Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API.
@@ -534,6 +535,9 @@ def shard_model(
     if cpu_offload:
         fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
 
+    if device_mesh is not None:
+        fsdp_kwargs["mesh"] = device_mesh
+
     # Shard the model with FSDP, iterating in reverse to start with
     # lowest-level modules first
     num_layers_sharded = 0