
enable Context Parallel #592


Merged
merged 25 commits on Oct 23, 2024

Changes from 2 commits (25 commits total)
3bf7333
enable Context Parallel
XilunWu Sep 30, 2024
afb1051
Update on "enable Context Parallel"
XilunWu Sep 30, 2024
f99a6f5
Update on "enable Context Parallel"
XilunWu Oct 3, 2024
4ad6881
Update on "enable Context Parallel"
XilunWu Oct 3, 2024
038b5ce
Update on "enable Context Parallel"
XilunWu Oct 4, 2024
4758df2
Update base for Update on "enable Context Parallel"
XilunWu Oct 21, 2024
a6758dd
Update on "enable Context Parallel"
XilunWu Oct 21, 2024
f570fa8
Update base for Update on "enable Context Parallel"
XilunWu Oct 21, 2024
c102f73
Update on "enable Context Parallel"
XilunWu Oct 21, 2024
83230fd
Update base for Update on "enable Context Parallel"
XilunWu Oct 21, 2024
2863907
Update on "enable Context Parallel"
XilunWu Oct 21, 2024
534ce58
Update base for Update on "enable Context Parallel"
XilunWu Oct 21, 2024
0c355e6
Update on "enable Context Parallel"
XilunWu Oct 21, 2024
172717d
Update base for Update on "enable Context Parallel"
XilunWu Oct 22, 2024
b89e59b
Update on "enable Context Parallel"
XilunWu Oct 22, 2024
a5e453f
Update base for Update on "enable Context Parallel"
XilunWu Oct 22, 2024
e319ab9
Update on "enable Context Parallel"
XilunWu Oct 22, 2024
99fe0bc
Update base for Update on "enable Context Parallel"
XilunWu Oct 22, 2024
9bec02c
Update on "enable Context Parallel"
XilunWu Oct 22, 2024
47c0078
Update base for Update on "enable Context Parallel"
XilunWu Oct 22, 2024
15c00d5
Update on "enable Context Parallel"
XilunWu Oct 22, 2024
bba36b4
Update base for Update on "enable Context Parallel"
XilunWu Oct 22, 2024
346d721
Update on "enable Context Parallel"
XilunWu Oct 22, 2024
a5d1fdf
Update base for Update on "enable Context Parallel"
XilunWu Oct 22, 2024
8045cad
Update on "enable Context Parallel"
XilunWu Oct 22, 2024
1 change: 1 addition & 0 deletions estimation.py
@@ -66,6 +66,7 @@ def estimate_memory(job_config: JobConfig):
parallel_dims = ParallelDims(
dp_shard=job_config.training.data_parallel_shard_degree,
dp_replicate=job_config.training.data_parallel_replicate_degree,
cp=job_config.experimental.context_parallel_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
23 changes: 23 additions & 0 deletions test_runner.py
@@ -316,6 +316,29 @@ def build_test_list():
"hsdp+tp",
ngpu=8,
),
OverrideDefinitions(
[
[
"--training.data_parallel_shard_degree=2",
"--experimental.context_parallel_degree=2",
]
],
"FSDP+CP",
"fsdp+cp",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.data_parallel_shard_degree=2",
"--training.data_parallel_replicate_degree=2",
"--experimental.context_parallel_degree=2",
]
],
"HSDP+CP",
"hsdp+cp",
ngpu=8,
),
OverrideDefinitions(
[
[
6 changes: 6 additions & 0 deletions torchtitan/config_manager.py
@@ -326,6 +326,12 @@ def __init__(self):
action="store_true",
help="Enable CompiledAutograd to compile the backward.",
)
self.parser.add_argument(
"--experimental.context_parallel_degree",
type=int,
default=1,
help="Context parallelism degree. 1 means disabled.",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
4 changes: 2 additions & 2 deletions torchtitan/models/llama/model.py
@@ -415,8 +415,8 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
return precompute_freqs_cis(
self.model_args.dim // self.model_args.n_heads,
# Need to compute until at least the max token limit for generation
# (use 2x max sequence length to be safe)
self.model_args.max_seq_len * 2,
# Note: removed the 2x relaxing in CP enablement
self.model_args.max_seq_len,

Contributor:
cc @tianyu-l Want to understand whether this is okay?

For a general use case, we could also expand CP to support a stride-like feature.


Contributor:
Could you please elaborate a bit on why this change was needed by CP?


Contributor:
@tianyu-l CP parallelizes along the sequence dimension, so anything related to the sequence dimension needs to be sharded. freqs_cis is the positional embedding and has to be sharded according to the sequence length, so it is easier to support CP if everything has the same sequence length.


Contributor:
Sounds reasonable to me. @awgu to confirm this is OK.

Also we need to add a note in docs/composability.md to clarify why this (model change) is needed. It can be addressed in a separate PR; in that case please create an issue / leave a TODO.

self.model_args.rope_theta,
)

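To illustrate the point made in the review thread above: CP splits every sequence-dimension tensor across the CP ranks, so the precomputed freqs_cis table has to cover exactly the same sequence length as the inputs it is sharded alongside. The snippet below is a minimal single-process sketch of that invariant; the even torch.chunk split and the tensor shapes are assumptions for illustration only, not the load-balanced sharding the actual CP implementation performs.

# Minimal sketch: why freqs_cis must share the inputs' sequence length under CP.
# The real sharding is done by torch.distributed.tensor.experimental.context_parallel;
# torch.chunk here only illustrates the per-rank slicing along the seq dimension.
import torch

cp_degree = 2
bs, seq_len, head_dim = 2, 8, 16

input_ids = torch.randint(0, 100, (bs, seq_len))   # sharded along dim 1
labels = torch.randint(0, 100, (bs, seq_len))      # sharded along dim 1
freqs_cis = torch.randn(seq_len, head_dim // 2)    # sharded along dim 0

inputs_per_rank = torch.chunk(input_ids, cp_degree, dim=1)
labels_per_rank = torch.chunk(labels, cp_degree, dim=1)
freqs_per_rank = torch.chunk(freqs_cis, cp_degree, dim=0)

# The per-rank shards only line up if freqs_cis was precomputed for exactly
# max_seq_len positions; a 2x-longer table would no longer match shard-for-shard.
assert inputs_per_rank[0].shape[1] == freqs_per_rank[0].shape[0]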
21 changes: 14 additions & 7 deletions torchtitan/parallelisms/parallel_dims.py
@@ -15,6 +15,7 @@
class ParallelDims:
dp_replicate: int
dp_shard: int
cp: int
tp: int
pp: int
world_size: int
@@ -24,36 +25,38 @@ def __post_init__(self):
self._validate()

def _validate(self):
dp_replicate, dp_shard, tp, pp = (
dp_replicate, dp_shard, cp, tp, pp = (
self.dp_replicate,
self.dp_shard,
self.cp,
self.tp,
self.pp,
)
for d in (dp_replicate, tp, pp):
for d in (dp_replicate, cp, tp, pp):
assert d >= 1, "Parallelism degree should be >= 1, except for dp_shard"
assert dp_shard == -1 or dp_shard >= 1, " dp_shard must -1 or >=1."

dp = dp_replicate * dp_shard
if dp < 0:
dp = self.world_size // (tp * pp)
dp = self.world_size // (cp * tp * pp)
self.dp_shard = dp_shard = dp // dp_replicate

assert dp_replicate >= 1
assert dp_shard >= 1
assert cp >= 1, cp
assert tp >= 1, tp
assert pp >= 1, pp
assert dp_replicate * dp_shard * tp * pp == self.world_size, (
assert dp_replicate * dp_shard * cp * tp * pp == self.world_size, (
f"Invalid parallel dims: dp_replicate({dp_replicate}) * dp_shard({dp_shard}) * "
f"tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
f"cp({cp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
)

def build_mesh(self, device_type):
dims = []
names = []
for d, name in zip(
[self.pp, self.dp_replicate, self.dp_shard, self.tp],
["pp", "dp_replicate", "dp_shard", "tp"],
[self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp],
["pp", "dp_replicate", "dp_shard", "cp", "tp"],
strict=True,
):
if d > 1:
@@ -86,6 +89,10 @@ def dp_replicate_enabled(self):
def dp_shard_enabled(self):
return self.dp_shard > 1

@property
def cp_enabled(self):
return self.cp > 1

@property
def tp_enabled(self):
return self.tp > 1
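As a quick illustration of the updated validation rule: with cp added, the product of all parallelism degrees (after resolving dp_shard == -1) must equal the world size. The degrees below are hypothetical and only walk through the arithmetic that _validate performs.

# Hypothetical degrees for an 8-GPU HSDP + CP job (no TP/PP).
world_size = 8
dp_replicate, dp_shard, cp, tp, pp = 2, -1, 2, 1, 1

# dp_shard == -1 means "infer it from whatever is left over":
if dp_shard == -1:
    dp = world_size // (cp * tp * pp)   # 8 // (2 * 1 * 1) = 4
    dp_shard = dp // dp_replicate       # 4 // 2 = 2

# The check added in this PR: cp now participates in the product.
assert dp_replicate * dp_shard * cp * tp * pp == world_size  # 2*2*2*1*1 == 8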
9 changes: 8 additions & 1 deletion torchtitan/parallelisms/parallelize_llama.py
@@ -8,9 +8,11 @@
# training techniques (e.g. activation checkpointing and compile) to the Llama model.

from collections import defaultdict
from typing import Tuple

import torch
import torch.nn as nn

from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._composable.replicate import replicate
@@ -72,13 +74,18 @@ def parallelize_llama(
)
apply_compile(model)

if parallel_dims.dp_enabled:
if parallel_dims.dp_enabled or parallel_dims.cp_enabled:
if parallel_dims.dp_shard_enabled:
if parallel_dims.dp_replicate_enabled:
dp_mesh = world_mesh["dp_replicate", "dp_shard"]
else:
dp_mesh = world_mesh["dp"]

if parallel_dims.cp_enabled:
dp_dim_names = dp_mesh.mesh_dim_names
assert isinstance(dp_dim_names, Tuple)
dp_mesh = world_mesh[(*dp_dim_names, "cp")]._flatten()

apply_fsdp(
model,
dp_mesh,
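The gist of the change above is that when CP is enabled, FSDP should shard parameters over the data-parallel and context-parallel ranks together, so the dp mesh dims get flattened with "cp" into a single 1-D mesh. Below is a hedged sketch of that flattening with the DeviceMesh API; it assumes a 4-rank torchrun launch and reuses the private _flatten() helper the PR itself relies on, so treat it as illustrative rather than a stable recipe.

# Sketch: flatten a (dp_shard, cp) mesh into the 1-D mesh FSDP shards over.
# Assumes: torchrun --nproc_per_node=4 this_script.py  (4 CUDA devices)
import os

import torch
from torch.distributed.device_mesh import init_device_mesh

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

world_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp_shard", "cp"))

# Without CP, fully_shard would receive world_mesh["dp_shard"] (2 ranks);
# with CP enabled, the PR flattens (dp_shard, cp) so parameters are sharded
# across all dp_shard * cp = 4 ranks.
dp_mesh = world_mesh["dp_shard", "cp"]._flatten()
print(dp_mesh)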
53 changes: 48 additions & 5 deletions train.py
@@ -10,6 +10,11 @@
from datetime import timedelta

import torch

from typing import List, Optional, Set
from functools import partial

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.elastic.multiprocessing.errors import record

from torchtitan import utils
@@ -28,17 +33,47 @@
)
from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling

try:
from torch.distributed.tensor.experimental import context_parallel
except ImportError:
print(
f"PyTorch version {torch.__version__} does not include the experimental "
"Context Parallel API. Please update to a newer version."
)


def get_train_context(
enable_loss_parallel: bool,
enable_compiled_autograd: bool,
cp_mesh: Optional[DeviceMesh] = None,
):
if cp_mesh is not None:
context_parallel_ctx = partial(context_parallel, mesh=cp_mesh)

def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
@contextlib.contextmanager
def context():
def context(
cp_buffers: List[torch.Tensor],
cp_seq_dims: List[int],
cp_no_restore_buffers: Set[torch.Tensor],
):
with contextlib.ExitStack() as stack:
if enable_loss_parallel:
stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())

if enable_compiled_autograd:
stack.enter_context(
torch._dynamo.utils.maybe_enable_compiled_autograd(True)
)

if cp_mesh is not None:
stack.enter_context(
context_parallel_ctx(
buffers=cp_buffers,
buffer_seq_dims=cp_seq_dims,
no_restore_buffers=cp_no_restore_buffers,
)
)

yield

return context
@@ -61,6 +96,7 @@ def main(job_config: JobConfig):
parallel_dims = ParallelDims(
dp_shard=job_config.training.data_parallel_shard_degree,
dp_replicate=job_config.training.data_parallel_replicate_degree,
cp=job_config.experimental.context_parallel_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
@@ -226,6 +262,7 @@ def loss_fn(pred, labels):
train_context = get_train_context(
parallel_dims.loss_parallel_enabled,
job_config.experimental.enable_compiled_autograd,
world_mesh["cp"] if parallel_dims.cp_enabled else None,
)

# variables used to keep info for metrics logging
@@ -259,18 +296,24 @@ def loss_fn(pred, labels):
data_load_start = time.perf_counter()
batch = next(data_iterator)
input_ids, labels = batch
ntokens_since_last_log += labels.numel()
ntokens_since_last_log += labels.numel() // parallel_dims.cp
data_loading_times.append(time.perf_counter() - data_load_start)

input_ids = input_ids.cuda()
labels = labels.cuda()
optimizers.zero_grad()

training_context = train_context(
cp_buffers=[input_ids, labels, model.freqs_cis],
cp_seq_dims=[1, 1, 0],
cp_no_restore_buffers={input_ids, labels},
)

if parallel_dims.pp_enabled:
# Pipeline Parallel forward / backward inside step() call
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

with train_context():
with training_context:
if pp_mesh.get_local_rank() == 0:
pp_schedule.step(input_ids)
elif is_last_stage:
Expand All @@ -287,7 +330,7 @@ def loss_fn(pred, labels):
)
else:
# Non-PP forward / backward
with train_context():
with training_context:
pred = model(input_ids)
loss = loss_fn(pred, labels)
# pred.shape=(bs, seq_len, vocab_size)
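For readers who want to see the experimental API underneath get_train_context: the sketch below calls context_parallel directly with the same keyword arguments this PR uses (mesh, buffers, buffer_seq_dims, no_restore_buffers). The shapes, the 2-rank launch, and the comments about what happens inside the context are assumptions for illustration; the API is experimental and may change.

# Sketch: using torch.distributed.tensor.experimental.context_parallel directly.
# Assumes: torchrun --nproc_per_node=2 this_script.py  (2 CUDA devices) and a
# PyTorch build recent enough to ship the experimental API (hence the
# try/except import in train.py above).
import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

cp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("cp",))

input_ids = torch.randint(0, 32000, (4, 128), device="cuda")  # (bs, seq)
freqs_cis = torch.randn(128, 64, device="cuda")               # (seq, head_dim // 2)

with context_parallel(
    mesh=cp_mesh,
    buffers=[input_ids, freqs_cis],
    buffer_seq_dims=[1, 0],
    no_restore_buffers={input_ids},  # inputs are consumed within the step
):
    # Inside the context the listed buffers are sharded along their seq dims,
    # so on each rank the seq dimension is expected to shrink to seq_len // 2;
    # a model forward run here would take the CP attention path.
    print(input_ids.shape, freqs_cis.shape)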
1 change: 1 addition & 0 deletions train_configs/debug_model.toml
@@ -42,6 +42,7 @@ compile = false
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)

[experimental]
context_parallel_degree = 1
pipeline_parallel_degree = 1
enable_async_tensor_parallel = false
