
Commit ad93aa8

mori360 authored and pytorchmergebot committed
E2E composability testing (pytorch#141398)
Adds a 3D (pp + tp + fsdp) test, `test_3d_with_tp_dp_pp`, to test_pp_composability. Currently provides @parametrize over:

- "ScheduleClass" for pp: ScheduleGPipe, Schedule1F1B, ScheduleInterleaved1F1B, ScheduleLoopedBFS, ScheduleInterleavedZeroBubble
- "MixedPrecisionParam" for fsdp: torch.bfloat16, torch.float32

Future work:
1. Add fp8.
2. Add cp (context parallelism) to enable a 4D test.

Pull Request resolved: pytorch#141398
Approved by: https://github.com/wconstab, https://github.com/kwen2501
1 parent 461bd2c commit ad93aa8
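For orientation, the new test carves 8 GPUs into a 2 x 2 x 2 mesh over the named dimensions (dp, pp, tp). Below is a minimal standalone sketch of just that mesh construction, mirroring the code in the diff; the torchrun launch, script name, and fixed 8-rank world size are assumptions for illustration, matching the test's skip_if_lt_x_gpu(8) guard.

    # Sketch only: launch with `torchrun --nproc-per-node=8 mesh_sketch.py`.
    import os

    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import init_device_mesh

    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # set by torchrun
    dist.init_process_group(backend="nccl")

    tp_size, pp_size = 2, 2
    dp_size = dist.get_world_size() // (tp_size * pp_size)  # 8 // 4 == 2

    # One 3D mesh; each named dimension can be sliced out as a 1D sub-mesh.
    mesh = init_device_mesh(
        "cuda",
        mesh_shape=(dp_size, pp_size, tp_size),
        mesh_dim_names=("dp", "pp", "tp"),
    )
    dp_mesh, pp_mesh, tp_mesh = mesh["dp"], mesh["pp"], mesh["tp"]
    pp_group = pp_mesh.get_group()  # ProcessGroup for pipeline point-to-point
    dist.destroy_process_group()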

File tree

1 file changed: +194 -1 lines changed

test/distributed/_composable/test_composability/test_pp_composability.py

+194 -1
@@ -6,13 +6,14 @@
 import torch
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
+import torch.nn.functional as F
 from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint import FileSystemReader
 from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
 from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
 from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
 from torch.distributed.checkpoint.stateful import Stateful
-from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed.pipelining import PipelineStage
 from torch.distributed.pipelining.schedules import (
@@ -23,6 +24,11 @@
     ScheduleInterleavedZeroBubble,
     ScheduleLoopedBFS,
 )
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    RowwiseParallel,
+)
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_distributed import (
@@ -58,6 +64,20 @@ def forward(self, x):
         return x
 
 
+class MLPModuleEven(torch.nn.Module):
+    def __init__(self, d_hid: int):
+        super().__init__()
+        self.net1 = nn.Linear(d_hid, d_hid)
+        self.net2 = nn.Linear(d_hid, d_hid)
+        self.net3 = nn.Linear(d_hid, d_hid * 2)
+
+    def forward(self, x):
+        x = F.relu(self.net1(x))
+        x = F.relu(self.net2(x))
+        x = F.relu(self.net3(x))
+        return x
+
+
 class ComposabilityTest(MultiProcessTestCase):
     @classmethod
     def backend_str(cls) -> str:
@@ -354,6 +374,179 @@ def _dcp_test(self):
 
         _dcp_test(self)
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(8)
+    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 8+ GPUs")
+    @parametrize(
+        "ScheduleClass",
+        [
+            ScheduleGPipe,
+            Schedule1F1B,
+            ScheduleInterleaved1F1B,
+            ScheduleLoopedBFS,
+            ScheduleInterleavedZeroBubble,
+        ],
+    )
+    @parametrize(
+        "MixedPrecisionParam",
+        [
+            torch.bfloat16,
+            torch.float32,
+        ],
+    )
+    def test_3d_with_tp_dp_pp(self, ScheduleClass, MixedPrecisionParam):
+        device = torch.device("cuda", self.device)
+        torch.cuda.set_device(self.device)
+        store = torch.distributed.FileStore(self.file_name, self.world_size)
+        torch.distributed.init_process_group(
+            backend="nccl",
+            store=store,
+            rank=self.rank,
+            world_size=self.world_size,
+        )
+        dim = 8
+        tp_size = 2
+        pp_size = 2
+        num_microbatches = 8
+        dp_size = self.world_size // (tp_size * pp_size)
+        device_mesh = init_device_mesh(
+            "cuda",
+            mesh_shape=(dp_size, pp_size, tp_size),
+            mesh_dim_names=("dp", "pp", "tp"),
+        )
+        dp_mesh = device_mesh["dp"]
+        tp_mesh = device_mesh["tp"]
+        pp_mesh = device_mesh["pp"]
+        pp_group = device_mesh["pp"].get_group()
+
+        # create "entire model"
+        total_layers = 8
+        full_model = nn.ModuleList([MLPModuleEven(dim) for _ in range(total_layers)])
+        ref_model = nn.Sequential(*copy.deepcopy(full_model))
+        ref_model.to(self.device)
+
+        # dummy loss needed just to force backwards to run in schedule step
+        def loss_fn(y, target):
+            return y.sum()
+
+        # Apply DP to stage module
+        def apply_fsdp(partial_model):
+            # apply FSDP
+            mp_policy = MixedPrecisionPolicy(
+                param_dtype=MixedPrecisionParam,
+                reduce_dtype=torch.float32,
+            )
+            fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+            for layer_id in range(len(partial_model)):
+                fully_shard(
+                    partial_model[layer_id],
+                    **fsdp_config,
+                    reshard_after_forward=False,
+                )
+            dp_model = fully_shard(partial_model, **fsdp_config)
+            return dp_model
+
+        def apply_tp(
+            model: nn.Module,
+            tp_mesh: DeviceMesh,
+        ):
+            parallelize_plan = {
+                "net1": ColwiseParallel(),
+                "net2": RowwiseParallel(),
+                "net3": ColwiseParallel(),
+            }
+            for layer in model:
+                parallelize_module(layer, tp_mesh, parallelize_plan)
+            return model
+
+        # Attach to a schedule
+        if issubclass(ScheduleClass, PipelineScheduleSingle):
+            stage_idx = pp_group.rank()
+            partial_model = nn.Sequential(
+                *full_model[stage_idx * 2 : stage_idx * 2 + 2]
+            )
+            partial_model.to(self.device)
+
+            tp_model = apply_tp(partial_model, tp_mesh)
+            dp_model = apply_fsdp(tp_model)
+            pipeline_stage = PipelineStage(
+                dp_model,
+                stage_idx,
+                pp_group.size(),
+                self.device,
+                group=pp_group,
+            )
+            partial_models = [pipeline_stage.submod]
+            pipeline_schedule = ScheduleClass(
+                pipeline_stage,
+                n_microbatches=num_microbatches,
+                loss_fn=loss_fn,
+            )
+        else:
+            n_virtual = 2
+            num_stages = pp_group.size() * n_virtual
+            stages = []
+            for i in range(n_virtual):
+                stage_idx = pp_group.rank() + n_virtual * i
+                # divide the model layers by the number of stages
+                partial_model = nn.Sequential(*full_model[stage_idx : stage_idx + 1])
+                partial_model.to(self.device)
+
+                tp_model = apply_tp(partial_model, tp_mesh)
+                dp_model = apply_fsdp(tp_model)
+                stage = PipelineStage(
+                    dp_model,
+                    stage_idx,
+                    num_stages,
+                    self.device,
+                    group=pp_group,
+                )
+
+                stages.append(stage)
+                partial_models = [pipeline_stage.submod for pipeline_stage in stages]
+            pipeline_schedule = ScheduleClass(
+                stages,
+                n_microbatches=num_microbatches,
+                loss_fn=loss_fn,
+            )
+
+        optimizer_kwargs = {
+            "lr": 0.01,
+            "betas": (0.9, 0.95),
+            "weight_decay": 0.1,
+            "fused": False,
+            "foreach": True,
+        }
+        optimizers = [
+            torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
+            for model in partial_models
+        ]
+
+        for train_step in range(5):
+            for optimizer in optimizers:
+                optimizer.zero_grad()
+            inputs = torch.rand((num_microbatches, dim), device=self.device)
+            labels = torch.rand((num_microbatches, dim), device=self.device)
+            is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+            if pp_mesh.get_local_rank() == 0:
+                pipeline_schedule.step(inputs)
+            elif is_last_stage:
+                losses = []
+                pipeline_schedule.step(target=labels, losses=losses)
+            else:
+                pipeline_schedule.step()
+
+            # accumulate losses across pipeline microbatches
+            loss = (
+                torch.mean(torch.stack(losses))
+                if is_last_stage
+                else torch.Tensor([-1.0])
+            )
+            for optimizer in optimizers:
+                optimizer.step()
+
+        torch.distributed.destroy_process_group()
+
 
 instantiate_parametrized_tests(ComposabilityTest)

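A note on the per-block tensor-parallel plan used above: ColwiseParallel shards a Linear's output features and RowwiseParallel shards its input features, so the net1 (colwise) -> net2 (rowwise) pair follows the classic Megatron-style pairing with a single all-reduce between them. A standalone sketch of the same plan on a stand-in block follows; the Block class, script name, and 2-rank launch are illustrative assumptions, not part of the commit.

    # Sketch only: launch with `torchrun --nproc-per-node=2 tp_sketch.py`.
    import os

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        parallelize_module,
        RowwiseParallel,
    )


    class Block(nn.Module):
        """Stand-in for the test's MLPModuleEven."""

        def __init__(self, d: int):
            super().__init__()
            self.net1 = nn.Linear(d, d)
            self.net2 = nn.Linear(d, d)
            self.net3 = nn.Linear(d, d * 2)

        def forward(self, x):
            return self.net3(self.net2(self.net1(x)))


    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))  # set by torchrun
    dist.init_process_group(backend="nccl")
    tp_mesh = init_device_mesh(
        "cuda", (dist.get_world_size(),), mesh_dim_names=("tp",)
    )

    block = Block(8).cuda()
    # Same plan the test applies to every layer of a pipeline stage.
    parallelize_module(
        block,
        tp_mesh,
        {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
            "net3": ColwiseParallel(),
        },
    )
    out = block(torch.rand(4, 8, device="cuda"))  # microbatch of 4, dim 8
    dist.destroy_process_group()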
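Finally, the layer-to-stage bookkeeping in the schedule branch is worth spelling out: single-stage schedules (PipelineScheduleSingle subclasses) give each of the two pp ranks one contiguous pair of layers, while the looped/interleaved schedules give each rank n_virtual = 2 one-layer stages strided across the pipeline. Plain-Python arithmetic mirroring the indexing in the diff, illustrative only:

    # Mirrors the diff's indexing with pp size 2 and n_virtual = 2.
    pp_ranks, n_virtual = 2, 2
    for rank in range(pp_ranks):
        # Single-stage branch: full_model[stage_idx * 2 : stage_idx * 2 + 2]
        single = list(range(rank * 2, rank * 2 + 2))
        # Interleaved branch: stage_idx = rank + n_virtual * i, one layer each
        virtual = [rank + n_virtual * i for i in range(n_virtual)]
        print(f"pp rank {rank}: single-stage layers {single}, virtual stage ids {virtual}")
    # -> pp rank 0: single-stage layers [0, 1], virtual stage ids [0, 2]
    # -> pp rank 1: single-stage layers [2, 3], virtual stage ids [1, 3]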