enable Context Parallel #592

Draft · wants to merge 5 commits into base: gh/XilunWu/6/base
Changes from 2 commits
1 change: 1 addition & 0 deletions estimation.py
@@ -66,6 +66,7 @@ def estimate_memory(job_config: JobConfig):
parallel_dims = ParallelDims(
dp_shard=job_config.training.data_parallel_shard_degree,
dp_replicate=job_config.training.data_parallel_replicate_degree,
cp=job_config.experimental.context_parallel_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
23 changes: 23 additions & 0 deletions test_runner.py
@@ -316,6 +316,29 @@ def build_test_list():
"hsdp+tp",
ngpu=8,
),
OverrideDefinitions(
[
[
"--training.data_parallel_shard_degree=2",
"--experimental.context_parallel_degree=2",
]
],
"FSDP+CP",
"fsdp+cp",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.data_parallel_shard_degree=2",
"--training.data_parallel_replicate_degree=2",
"--experimental.context_parallel_degree=2",
]
],
"HSDP+CP",
"hsdp+cp",
ngpu=8,
),
OverrideDefinitions(
[
[
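
Note on the ngpu values: each test asks for the product of its parallel degrees, so FSDP+CP needs dp_shard(2) x cp(2) = 4 GPUs and HSDP+CP needs dp_replicate(2) x dp_shard(2) x cp(2) = 8. A standalone sketch of that arithmetic (illustrative only, not part of the test runner):

```python
# Sanity check (illustrative): the requested degrees must multiply to ngpu.
test_configs = {
    "fsdp+cp": ({"dp_shard": 2, "cp": 2}, 4),
    "hsdp+cp": ({"dp_replicate": 2, "dp_shard": 2, "cp": 2}, 8),
}

for name, (degrees, expected_ngpu) in test_configs.items():
    product = 1
    for degree in degrees.values():
        product *= degree
    assert product == expected_ngpu, (name, product, expected_ngpu)
    print(f"{name}: {degrees} -> ngpu={product}")
```
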
6 changes: 6 additions & 0 deletions torchtitan/config_manager.py
@@ -326,6 +326,12 @@ def __init__(self):
action="store_true",
help="Enable CompiledAutograd to compile the backward.",
)
self.parser.add_argument(
"--experimental.context_parallel_degree",
type=int,
default=1,
help="Context parallelism degree. 1 means disabled.",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
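
A quick illustration of the new flag in isolation (a standalone argparse sketch, not torchtitan's JobConfig wiring): the degree defaults to 1, which keeps CP disabled, and any value above 1 enables it. Because argparse keeps the dot in the destination name, the value is read back with getattr:

```python
import argparse

# Minimal sketch mirroring the new option; not the actual JobConfig plumbing.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--experimental.context_parallel_degree",
    type=int,
    default=1,
    help="Context parallelism degree. 1 means disabled.",
)

args = parser.parse_args(["--experimental.context_parallel_degree", "2"])
cp_degree = getattr(args, "experimental.context_parallel_degree")
print(cp_degree, cp_degree > 1)  # 2 True -> CP enabled
```
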
4 changes: 2 additions & 2 deletions torchtitan/models/llama/model.py
@@ -415,8 +415,8 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
return precompute_freqs_cis(
self.model_args.dim // self.model_args.n_heads,
# Need to compute until at least the max token limit for generation
# (use 2x max sequence length to be safe)
self.model_args.max_seq_len * 2,
# Note: the 2x safety margin was removed as part of CP enablement
self.model_args.max_seq_len,
Contributor:
cc @tianyu-l, want to understand whether this is okay. For the general use case, we could also extend CP to support a stride-like feature.

Contributor:
Could you please elaborate a bit on why this change was needed for CP?

self.model_args.rope_theta,
)

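
On the two comments above: train.py (later in this diff) registers model.freqs_cis as a CP buffer with sequence dim 0, and context_parallel shards every registered buffer along its sequence dim across the CP ranks. Keeping the 2x margin would make the freqs_cis shards misalign with the shards of input_ids and labels, which is presumably why it was dropped. A rough sketch of the length constraint, with naive chunking standing in for the real load-balanced sharding:

```python
import torch

seq_len, cp_degree = 8, 2
input_ids = torch.arange(seq_len).unsqueeze(0)  # (1, seq_len), sharded on dim 1
freqs_cis = torch.arange(seq_len)               # (seq_len,),   sharded on dim 0

ids_shards = input_ids.chunk(cp_degree, dim=1)
freq_shards = freqs_cis.chunk(cp_degree, dim=0)

for rank, (ids, freqs) in enumerate(zip(ids_shards, freq_shards)):
    # Positions and rotary frequencies line up on each rank only because the
    # two buffers share the same sequence length; a 2x-longer freqs_cis would
    # hand each rank frequencies for the wrong token positions.
    assert ids.shape[1] == freqs.shape[0], (rank, ids.shape, freqs.shape)
```
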
21 changes: 14 additions & 7 deletions torchtitan/parallelisms/parallel_dims.py
@@ -15,6 +15,7 @@
class ParallelDims:
dp_replicate: int
dp_shard: int
cp: int
tp: int
pp: int
world_size: int
@@ -24,36 +25,38 @@ def __post_init__(self):
self._validate()

def _validate(self):
dp_replicate, dp_shard, tp, pp = (
dp_replicate, dp_shard, cp, tp, pp = (
self.dp_replicate,
self.dp_shard,
self.cp,
self.tp,
self.pp,
)
for d in (dp_replicate, tp, pp):
for d in (dp_replicate, cp, tp, pp):
assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."

dp = dp_replicate * dp_shard
if dp < 0:
dp = self.world_size // (tp * pp)
dp = self.world_size // (cp * tp * pp)
self.dp_shard = dp_shard = dp // dp_replicate

assert dp_replicate >= 1
assert dp_shard >= 1
assert cp >= 1, cp
assert tp >= 1, tp
assert pp >= 1, pp
assert dp_replicate * dp_shard * tp * pp == self.world_size, (
assert dp_replicate * dp_shard * cp * tp * pp == self.world_size, (
f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
f"tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
)

def build_mesh(self, device_type):
dims = []
names = []
for d, name in zip(
[self.pp, self.dp_replicate, self.dp_shard, self.tp],
["pp", "dp_replicate", "dp_shard", "tp"],
[self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
["pp", "dp_replicate", "dp_shard", "cp", "tp"],
strict=True,
):
if d > 1:
@@ -86,6 +89,10 @@ def dp_replicate_enabled(self):
def dp_shard_enabled(self):
return self.dp_shard > 1

@property
def cp_enabled(self):
return self.cp > 1

@property
def tp_enabled(self):
return self.tp > 1
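
Worked example of the updated validation for the HSDP+CP test case above (world_size=8, dp_shard left at -1); a standalone restatement of the arithmetic, not the class itself:

```python
# dp_shard == -1 means "infer from the remaining degrees".
world_size, dp_replicate, dp_shard, cp, tp, pp = 8, 2, -1, 2, 1, 1

if dp_replicate * dp_shard < 0:
    dp = world_size // (cp * tp * pp)   # 8 // (2 * 1 * 1) = 4 data-parallel ranks
    dp_shard = dp // dp_replicate       # 4 // 2 = 2

assert dp_replicate * dp_shard * cp * tp * pp == world_size  # 2*2*2*1*1 == 8
print(dp_shard)  # 2
```
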
9 changes: 8 additions & 1 deletion torchtitan/parallelisms/parallelize_llama.py
@@ -8,9 +8,11 @@
# training techniques (e.g. activation checkpointing and compile) to the Llama model.

from collections import defaultdict
from typing import Tuple

import torch
import torch.nn as nn

from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._composable.replicate import replicate
@@ -72,13 +74,18 @@ def parallelize_llama(
)
apply_compile(model)

if parallel_dims.dp_enabled:
if parallel_dims.dp_enabled or parallel_dims.cp_enabled:
if parallel_dims.dp_shard_enabled:
if parallel_dims.dp_replicate_enabled:
dp_mesh = world_mesh["dp_replicate", "dp_shard"]
else:
dp_mesh = world_mesh["dp"]

if parallel_dims.cp_enabled:
dp_dim_names = dp_mesh.mesh_dim_names
assert isinstance(dp_dim_names, Tuple)
dp_mesh = world_mesh[(*dp_dim_names, "cp")]._flatten()
Contributor:
I remember we want to initialize all the PGs at the very beginning. Can we move this to parallel_dims.py and use mesh_dim_name to rename it to dp_shard_cp?


apply_fsdp(
model,
dp_mesh,
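
Regarding the flattening above (and the suggestion to move it into parallel_dims.py): with CP enabled, FSDP should shard parameters and reduce gradients across all dp * cp ranks, while CP itself only splits the sequence dimension of activations, so the dp and cp mesh dims are collapsed into one 1-D mesh. A standalone sketch of that step on a CPU/gloo mesh, assuming a 2 (dp_shard) x 2 (cp) layout and run under e.g. torchrun --nproc_per_node=4; note that _flatten() is a private/experimental DeviceMesh method mirrored from the diff:

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# 2 (dp_shard) x 2 (cp) mesh; init_device_mesh also sets up the process group.
world_mesh = init_device_mesh("cpu", (2, 2), mesh_dim_names=("dp_shard", "cp"))

# Collapse (dp_shard, cp) into a single mesh so FSDP shards and reduces over
# all dp_shard * cp ranks; CP only shards the sequence dim of activations.
dp_mesh = world_mesh["dp_shard", "cp"]._flatten()

print(f"rank {dist.get_rank()}: flattened dp mesh size = {dp_mesh.size()}")  # 4
dist.destroy_process_group()
```
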
53 changes: 48 additions & 5 deletions train.py
@@ -10,6 +10,11 @@
from datetime import timedelta

import torch

from typing import List, Optional, Set
from functools import partial

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.elastic.multiprocessing.errors import record

from torchtitan import utils
@@ -28,17 +33,47 @@
)
from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling

try:
from torch.distributed.tensor.experimental import context_parallel
except ImportError:
print(
f"PyTorch version {torch.__version__} does not include the experimental "
"Context Parallel API. Please update to a newer version."
)


def get_train_context(
enable_loss_parallel: bool,
enable_compiled_autograd: bool,
cp_mesh: Optional[DeviceMesh] = None,
):
if cp_mesh is not None:
context_parallel_ctx = partial(context_parallel, mesh=cp_mesh)

def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
@contextlib.contextmanager
def context():
def context(
cp_buffers: List[torch.Tensor],
cp_seq_dims: List[int],
cp_no_restore_buffers: Set[torch.Tensor],
):
with contextlib.ExitStack() as stack:
if enable_loss_parallel:
stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())

if enable_compiled_autograd:
stack.enter_context(
torch._dynamo.utils.maybe_enable_compiled_autograd(True)
)

if cp_mesh is not None:
stack.enter_context(
context_parallel_ctx(
buffers=cp_buffers,
buffer_seq_dims=cp_seq_dims,
no_restore_buffers=cp_no_restore_buffers,
)
)

yield

return context
@@ -61,6 +96,7 @@ def main(job_config: JobConfig):
parallel_dims = ParallelDims(
dp_shard=job_config.training.data_parallel_shard_degree,
dp_replicate=job_config.training.data_parallel_replicate_degree,
cp=job_config.experimental.context_parallel_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
@@ -226,6 +262,7 @@ def loss_fn(pred, labels):
train_context = get_train_context(
parallel_dims.loss_parallel_enabled,
job_config.experimental.enable_compiled_autograd,
world_mesh["cp"] if parallel_dims.cp_enabled else None,
)

# variables used to keep info for metrics logging
@@ -259,18 +296,24 @@ def loss_fn(pred, labels):
data_load_start = time.perf_counter()
batch = next(data_iterator)
input_ids, labels = batch
ntokens_since_last_log += labels.numel()
ntokens_since_last_log += labels.numel() // parallel_dims.cp
data_loading_times.append(time.perf_counter() - data_load_start)

input_ids = input_ids.cuda()
labels = labels.cuda()
optimizers.zero_grad()

training_context = train_context(
cp_buffers=[input_ids, labels, model.freqs_cis],
cp_seq_dims=[1, 1, 0],
cp_no_restore_buffers={input_ids, labels},
)

if parallel_dims.pp_enabled:
# Pipeline Parallel forward / backward inside step() call
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

with train_context():
with training_context:
if pp_mesh.get_local_rank() == 0:
pp_schedule.step(input_ids)
elif is_last_stage:
@@ -287,7 +330,7 @@ def loss_fn(pred, labels):
)
else:
# Non-PP forward / backward
with train_context():
with training_context:
pred = model(input_ids)
loss = loss_fn(pred, labels)
# pred.shape=(bs, seq_len, vocab_size)
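
Structurally, get_train_context now takes the CP mesh and returns a context manager that accepts per-step buffer arguments, with context_parallel added as a third optional layer in the ExitStack. A single-process sketch of that layering, with nullcontext standing in for the distributed contexts (loss_parallel and context_parallel); the context object is rebuilt every iteration because a contextlib-based context manager can only be entered once:

```python
import contextlib
from typing import List, Set

import torch


def get_train_context(enable_loss_parallel: bool, cp_enabled: bool):
    @contextlib.contextmanager
    def context(
        cp_buffers: List[torch.Tensor],
        cp_seq_dims: List[int],
        cp_no_restore_buffers: Set[torch.Tensor],
    ):
        with contextlib.ExitStack() as stack:
            if enable_loss_parallel:
                stack.enter_context(contextlib.nullcontext())  # stand-in for loss_parallel()
            if cp_enabled:
                # Real code enters context_parallel(mesh, buffers=cp_buffers,
                # buffer_seq_dims=cp_seq_dims, no_restore_buffers=cp_no_restore_buffers),
                # which shards each buffer along its sequence dim for the duration.
                stack.enter_context(contextlib.nullcontext())  # stand-in
            yield

    return context


train_context = get_train_context(enable_loss_parallel=True, cp_enabled=True)

input_ids = torch.zeros(2, 8, dtype=torch.long)
labels = torch.zeros(2, 8, dtype=torch.long)
freqs_cis = torch.zeros(8)

for _ in range(2):
    # A @contextmanager object is single-use, so it is recreated every step.
    training_context = train_context(
        cp_buffers=[input_ids, labels, freqs_cis],
        cp_seq_dims=[1, 1, 0],
        cp_no_restore_buffers={input_ids, labels},
    )
    with training_context:
        pass  # forward / backward would run here
```
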
1 change: 1 addition & 0 deletions train_configs/debug_model.toml
@@ -42,6 +42,7 @@ compile = false
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)

[experimental]
context_parallel_degree = 1
pipeline_parallel_degree = 1
enable_async_tensor_parallel = false
