 import torch
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
+import torch.nn.functional as F
 from torch.distributed._tensor import DTensor
 from torch.distributed.checkpoint import FileSystemReader
 from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
 from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
 from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
 from torch.distributed.checkpoint.stateful import Stateful
-from torch.distributed.device_mesh import init_device_mesh
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed.pipelining import PipelineStage
 from torch.distributed.pipelining.schedules import (
     PipelineScheduleSingle,
     Schedule1F1B,
     ScheduleGPipe,
     ScheduleInterleaved1F1B,
     ScheduleInterleavedZeroBubble,
     ScheduleLoopedBFS,
 )
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    RowwiseParallel,
+)
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_distributed import (
@@ -58,6 +64,20 @@ def forward(self, x):
         return x
 
 
+class MLPModuleEven(torch.nn.Module):
+    def __init__(self, d_hid: int):
+        super().__init__()
+        self.net1 = nn.Linear(d_hid, d_hid)
+        self.net2 = nn.Linear(d_hid, d_hid)
+        self.net3 = nn.Linear(d_hid, d_hid * 2)
+
+    def forward(self, x):
+        x = F.relu(self.net1(x))
+        x = F.relu(self.net2(x))
+        x = F.relu(self.net3(x))
+        return x
+
+
 class ComposabilityTest(MultiProcessTestCase):
     @classmethod
     def backend_str(cls) -> str:
@@ -354,6 +374,179 @@ def _dcp_test(self):
 
         _dcp_test(self)
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(8)
+    @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "Test requires 8+ GPUs")
+    @parametrize(
+        "ScheduleClass",
+        [
+            ScheduleGPipe,
+            Schedule1F1B,
+            ScheduleInterleaved1F1B,
+            ScheduleLoopedBFS,
+            ScheduleInterleavedZeroBubble,
+        ],
+    )
+    @parametrize(
+        "MixedPrecisionParam",
+        [
+            torch.bfloat16,
+            torch.float32,
+        ],
+    )
+    def test_3d_with_tp_dp_pp(self, ScheduleClass, MixedPrecisionParam):
+        device = torch.device("cuda", self.device)
+        torch.cuda.set_device(self.device)
+        store = torch.distributed.FileStore(self.file_name, self.world_size)
+        torch.distributed.init_process_group(
+            backend="nccl",
+            store=store,
+            rank=self.rank,
+            world_size=self.world_size,
+        )
+        dim = 8
+        tp_size = 2
+        pp_size = 2
+        num_microbatches = 8
+        dp_size = self.world_size // (tp_size * pp_size)
+        device_mesh = init_device_mesh(
+            "cuda",
+            mesh_shape=(dp_size, pp_size, tp_size),
+            mesh_dim_names=("dp", "pp", "tp"),
+        )
+        dp_mesh = device_mesh["dp"]
+        tp_mesh = device_mesh["tp"]
+        pp_mesh = device_mesh["pp"]
+        pp_group = device_mesh["pp"].get_group()
+
+        # create "entire model"
+        total_layers = 8
+        full_model = nn.ModuleList([MLPModuleEven(dim) for _ in range(total_layers)])
+        ref_model = nn.Sequential(*copy.deepcopy(full_model))
+        ref_model.to(self.device)
+
+        # dummy loss needed just to force backwards to run in schedule step
+        def loss_fn(y, target):
+            return y.sum()
+
+        # Apply DP to stage module
+        def apply_fsdp(partial_model):
+            # apply FSDP
+            mp_policy = MixedPrecisionPolicy(
+                param_dtype=MixedPrecisionParam,
+                reduce_dtype=torch.float32,
+            )
+            fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
+            for layer_id in range(len(partial_model)):
+                fully_shard(
+                    partial_model[layer_id],
+                    **fsdp_config,
+                    reshard_after_forward=False,
+                )
+            dp_model = fully_shard(partial_model, **fsdp_config)
+            return dp_model
+
+        def apply_tp(
+            model: nn.Module,
+            tp_mesh: DeviceMesh,
+        ):
+            parallelize_plan = {
+                "net1": ColwiseParallel(),
+                "net2": RowwiseParallel(),
+                "net3": ColwiseParallel(),
+            }
+            for layer in model:
+                parallelize_module(layer, tp_mesh, parallelize_plan)
+            return model
+
+        # Attach to a schedule
+        if issubclass(ScheduleClass, PipelineScheduleSingle):
+            stage_idx = pp_group.rank()
+            partial_model = nn.Sequential(
+                *full_model[stage_idx * 2 : stage_idx * 2 + 2]
+            )
+            partial_model.to(self.device)
+
+            tp_model = apply_tp(partial_model, tp_mesh)
+            dp_model = apply_fsdp(tp_model)
+            pipeline_stage = PipelineStage(
+                dp_model,
+                stage_idx,
+                pp_group.size(),
+                self.device,
+                group=pp_group,
+            )
+            partial_models = [pipeline_stage.submod]
+            pipeline_schedule = ScheduleClass(
+                pipeline_stage,
+                n_microbatches=num_microbatches,
+                loss_fn=loss_fn,
+            )
+        else:
+            n_virtual = 2
+            num_stages = pp_group.size() * n_virtual
+            stages = []
+            for i in range(n_virtual):
+                stage_idx = pp_group.rank() + n_virtual * i
+                # divide the model layers by the number of stages
+                partial_model = nn.Sequential(*full_model[stage_idx : stage_idx + 1])
+                partial_model.to(self.device)
+
+                tp_model = apply_tp(partial_model, tp_mesh)
+                dp_model = apply_fsdp(tp_model)
+                stage = PipelineStage(
+                    dp_model,
+                    stage_idx,
+                    num_stages,
+                    self.device,
+                    group=pp_group,
+                )
+
+                stages.append(stage)
+                partial_models = [pipeline_stage.submod for pipeline_stage in stages]
+            pipeline_schedule = ScheduleClass(
+                stages,
+                n_microbatches=num_microbatches,
+                loss_fn=loss_fn,
+            )
+
+        optimizer_kwargs = {
+            "lr": 0.01,
+            "betas": (0.9, 0.95),
+            "weight_decay": 0.1,
+            "fused": False,
+            "foreach": True,
+        }
+        optimizers = [
+            torch.optim.AdamW(model.parameters(), **optimizer_kwargs)
+            for model in partial_models
+        ]
+
+        for train_step in range(5):
+            for optimizer in optimizers:
+                optimizer.zero_grad()
+            inputs = torch.rand((num_microbatches, dim), device=self.device)
+            labels = torch.rand((num_microbatches, dim), device=self.device)
+            is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+            if pp_mesh.get_local_rank() == 0:
+                pipeline_schedule.step(inputs)
+            elif is_last_stage:
+                losses = []
+                pipeline_schedule.step(target=labels, losses=losses)
+            else:
+                pipeline_schedule.step()
+
+            # accumulate losses across pipeline microbatches
+            loss = (
+                torch.mean(torch.stack(losses))
+                if is_last_stage
+                else torch.Tensor([-1.0])
+            )
+            for optimizer in optimizers:
+                optimizer.step()
+
+        torch.distributed.destroy_process_group()
+
 
 instantiate_parametrized_tests(ComposabilityTest)
 
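For orientation, below is a minimal standalone sketch (not part of the commit above) of how the 3-D ("dp", "pp", "tp") device mesh used by test_3d_with_tp_dp_pp is built and sliced into the sub-meshes handed to FSDP, tensor parallelism, and the pipeline group. The 8-GPU world size, the torchrun launch command, the file name, and the print statement are illustrative assumptions only.

# Illustrative sketch, not the test itself. Assumes 8 GPUs and a launch such as:
#   torchrun --nproc-per-node=8 mesh_sketch.py   (file name is hypothetical)
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

# world_size = dp * pp * tp = 2 * 2 * 2 = 8, matching the test's mesh_shape
mesh_3d = init_device_mesh(
    "cuda",
    mesh_shape=(2, 2, 2),
    mesh_dim_names=("dp", "pp", "tp"),
)

dp_mesh = mesh_3d["dp"]  # 1-D sub-mesh passed to fully_shard (FSDP)
tp_mesh = mesh_3d["tp"]  # 1-D sub-mesh passed to parallelize_module (TP)
pp_group = mesh_3d["pp"].get_group()  # ProcessGroup used by PipelineStage / schedules

# each rank reports its coordinate along every mesh dimension
print(
    f"rank {rank}: dp={dp_mesh.get_local_rank()}, "
    f"pp={mesh_3d['pp'].get_local_rank()}, tp={tp_mesh.get_local_rank()}"
)

dist.destroy_process_group()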