
Commit 6e8041d

Updates
1 parent df17170 commit 6e8041d

File tree

1 file changed: +19 −15

recipes/full_finetune_distributed.py

@@ -137,8 +137,11 @@ def __init__(self, cfg: DictConfig) -> None:
             )
             self._log_peak_memory_stats = False

+        # Distributed variables
         self.world_size, self.rank = utils.get_world_size_and_rank()
         self._is_rank_zero = self.rank == 0
+        self.nnodes = dist.get_local_size()
+        self.enable_tensor_parallel = cfg.get("enable_tensor_parallel", False)

         # Training cfg
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
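
The new attributes above feed the 2D mesh shape built later in _setup_model. The snippet below is an illustrative sketch, not part of the commit, of how a (dp, tp) shape of the same form can be derived, assuming a torchrun launch, which sets the WORLD_SIZE and LOCAL_WORLD_SIZE environment variables:

# Illustrative sketch only, not recipe code. Assumes torchrun has set
# WORLD_SIZE (total ranks) and LOCAL_WORLD_SIZE (ranks per node).
import os

world_size = int(os.environ["WORLD_SIZE"])               # e.g. 16
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])   # e.g. 8 GPUs per node
nnodes = world_size // local_world_size                  # e.g. 2 nodes

# Same shape as mesh_shape = (self.nnodes, self.world_size // self.nnodes):
# a data-parallel dimension across nodes, tensor parallelism within each node.
mesh_shape = (nnodes, world_size // nnodes)              # e.g. (2, 8)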
@@ -521,21 +524,22 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
             )

-        # Apply TP if specified
-        mesh_shape = (1, 8)
-        device_mesh = init_device_mesh(
-            "cuda", tp_mesh_shape, mesh_dim_names=("dp", "tp")
-        )
-
-        # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor paralell
-        training.prepare_mha_for_tp(model, device_mesh["tp"])
-        parallelize_module(
-            model,
-            device_mesh["tp"],
-            parallelize_plan=config.instantiate(cfg.parallelize_plan),
-        )
+        device_mesh = {}
+        if self.enable_tensor_parallel:
+            mesh_shape = (self.nnodes, self.world_size // self.nnodes)
+            device_mesh = init_device_mesh(
+                "cuda", mesh_shape, mesh_dim_names=("dp", "tp")
+            )
+            # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor paralell
+            training.prepare_mha_for_tp(model, device_mesh["tp"])
+            # Apply tensor parallelism to the model
+            parallelize_module(
+                model,
+                device_mesh["tp"],
+                parallelize_plan=config.instantiate(cfg.parallelize_plan),
+            )

-        # For FSDP sharding
+        # Shard the model
         fsdp_shard_conditions = [
             partial(
                 training.get_shard_conditions,
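
The added branch relies on two PyTorch distributed APIs: torch.distributed.device_mesh.init_device_mesh builds a 2D ("dp", "tp") mesh, and torch.distributed.tensor.parallel.parallelize_module applies a parallelize plan over the "tp" sub-mesh. Below is a minimal, self-contained sketch of that pattern; the FeedForward module, the hard-coded (2, 4) mesh shape, and the explicit plan are illustrative assumptions, whereas the recipe instantiates its plan from cfg.parallelize_plan.

# Illustrative sketch of the device-mesh + tensor-parallel pattern; run under
# torchrun with 8 ranks. The toy module and plan are assumptions, not recipe code.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

class FeedForward(nn.Module):
    def __init__(self, dim: int = 1024, hidden: int = 4096) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(torch.relu(self.w1(x)))

# 2D mesh: outer "dp" dimension for data parallelism, inner "tp" dimension for
# tensor parallelism. Slicing by name returns the sub-mesh for that dimension.
device_mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

# A parallelize plan maps submodule names to parallel styles.
tp_plan = {
    "w1": ColwiseParallel(),   # shard the first projection column-wise
    "w2": RowwiseParallel(),   # shard the second projection row-wise
}
model = parallelize_module(
    FeedForward().cuda(), device_mesh["tp"], parallelize_plan=tp_plan
)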
@@ -547,7 +551,7 @@ def _setup_model(
             shard_conditions=fsdp_shard_conditions,
             cpu_offload=fsdp_cpu_offload,
             reshard_after_forward=reshard_after_forward,
-            device_mesh=device_mesh["dp"],
+            dp_device_mesh=device_mesh.get("dp"),
         )

         with training.set_default_dtype(self._dtype), self._device:
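
Because device_mesh stays an empty dict when enable_tensor_parallel is False, device_mesh.get("dp") passes None and the model is sharded over all ranks by default. The recipe's training.shard_model internals are not shown in this diff; the rough sketch below shows how an optional dp sub-mesh could be consumed, assuming the FSDP2 fully_shard API and a hypothetical model.layers container.

# Illustrative sketch only: how an optional dp sub-mesh could drive FSDP2
# sharding. training.shard_model's actual implementation is not in this diff.
from typing import Optional

from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import DeviceMesh

def shard_with_optional_mesh(model, dp_device_mesh: Optional[DeviceMesh] = None) -> None:
    # With mesh=None, fully_shard shards over all ranks in the default process
    # group; with a dp sub-mesh, it shards only along the "dp" dimension.
    for layer in model.layers:
        fully_shard(layer, mesh=dp_device_mesh)
    fully_shard(model, mesh=dp_device_mesh)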
