Retrieve schedules from get_schedule_class()
[ghstack-poisoned]
H-Huang committed Oct 1, 2024
1 parent eef8bb2 commit c3d9a80
Showing 3 changed files with 8 additions and 21 deletions.
4 changes: 2 additions & 2 deletions test_runner.py
@@ -142,7 +142,7 @@ def build_test_list():
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 4",
"--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
"--experimental.pipeline_parallel_schedule flexible_interleaved_1f1b",
"--experimental.pipeline_parallel_schedule FlexibleInterleaved1F1B",
],
],
"PP looped flexible 1f1b test",
@@ -265,7 +265,7 @@ def build_test_list():
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 4",
"--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
"--experimental.pipeline_parallel_schedule interleaved_1f1b",
"--experimental.pipeline_parallel_schedule Interleaved1F1B",
],
],
"PP looped 1f1b test",
4 changes: 2 additions & 2 deletions torchtitan/config_manager.py
@@ -299,14 +299,14 @@ def __init__(self):
self.parser.add_argument(
"--experimental.pipeline_parallel_schedule",
type=str,
choices=["1f1b", "gpipe", "interleaved_1f1b", "flexible_interleaved_1f1b"],
choices=["1f1b", "gpipe", "Interleaved1F1B", "FlexibleInterleaved1F1B"],
default="1f1b",
help="""
Specify the Pipeline Parallel schedule to use.
The schedule must be compatible with the split points and stages_per_rank.
- Looped schedules (e.g. interleaved_1f1b) require specifying pipeline_paralle_degree = number of ranks,
+ Looped schedules (e.g. Interleaved1F1B) require specifying pipeline_parallel_degree = number of ranks,
and split_points = number of stages - 1""",
)
self.parser.add_argument(
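As a worked example of the constraint described in the help text above (illustrative only, not part of the commit; variable names are hypothetical): the test entries in test_runner.py pass pipeline_parallel_degree 4 together with 7 split points, which satisfies the stated relationship as follows.

pipeline_parallel_degree = 4  # number of pipeline ranks
split_points = [
    "layers.1", "layers.2", "layers.3", "layers.4",
    "layers.5", "layers.6", "layers.7",
]
num_stages = len(split_points) + 1                        # split_points = number of stages - 1, so 8 stages
stages_per_rank = num_stages // pipeline_parallel_degree  # a looped schedule places 2 stages on each rank
assert num_stages % pipeline_parallel_degree == 0         # stages must divide evenly across the ranks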
21 changes: 4 additions & 17 deletions torchtitan/parallelisms/pipelining_utils.py
@@ -6,34 +6,21 @@
from typing import Tuple

from torch.distributed.pipelining import (
-    Schedule1F1B,
    ScheduleFlexibleInterleaved1F1B,
-    ScheduleGPipe,
    ScheduleInterleaved1F1B,
)
+from torch.distributed.pipelining.schedules import get_schedule_class
from torchtitan.logging import logger


def build_pipeline_schedule(job_config, stages, loss_fn):
    looped_schedule = False

-    if job_config.experimental.pipeline_parallel_schedule == "1f1b":
-        schedule_class = Schedule1F1B
-    elif job_config.experimental.pipeline_parallel_schedule == "gpipe":
-        schedule_class = ScheduleGPipe
-    elif job_config.experimental.pipeline_parallel_schedule == "interleaved_1f1b":
-        schedule_class = ScheduleInterleaved1F1B
-        looped_schedule = True
-    elif (
+    schedule_class = get_schedule_class(
        job_config.experimental.pipeline_parallel_schedule
-        == "flexible_interleaved_1f1b"
-    ):
-        schedule_class = ScheduleFlexibleInterleaved1F1B
+    )
+    if schedule_class in [ScheduleInterleaved1F1B, ScheduleFlexibleInterleaved1F1B]:
        looped_schedule = True
-    else:
-        raise NotImplementedError(
-            f"{job_config.experimental.pipeline_parallel_schedule} is not implemented"
-        )
    logger.info(
        f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule}"
    )
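For context on the hunk above, here is a minimal sketch (not part of the commit) of the lookup that the rewritten build_pipeline_schedule now relies on: get_schedule_class resolves the schedule name string from the config into the corresponding schedule class, replacing the hand-written if/elif chain. It assumes the name passed in matches one of the choices added in config_manager.py above.

from torch.distributed.pipelining import (
    ScheduleFlexibleInterleaved1F1B,
    ScheduleInterleaved1F1B,
)
from torch.distributed.pipelining.schedules import get_schedule_class

# "Interleaved1F1B" mirrors the value used by the updated test entries above.
schedule_class = get_schedule_class("Interleaved1F1B")
assert schedule_class is ScheduleInterleaved1F1B

# Looped schedules (multiple stages per rank) are now detected by class rather than by string.
looped_schedule = schedule_class in [
    ScheduleInterleaved1F1B,
    ScheduleFlexibleInterleaved1F1B,
]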
