Skip to content

Commit 3fe6445

Browse files
committed
Enable TransformerEngine-backed Tensor Parallelism with Llama3.
Signed-off-by: Cory Ye <cye@nvidia.com>
1 parent 6f259a8 commit 3fe6445

File tree

4 files changed

+165
-15
lines changed

4 files changed

+165
-15
lines changed

bionemo-recipes/recipes/llama3_native_te/checkpoint.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from torch.distributed.checkpoint.state_dict_saver import async_save as dcp_async_save
3535
from torch.distributed.checkpoint.state_dict_saver import save as dcp_save
3636
from torch.distributed.checkpoint.stateful import Stateful
37+
from torch.distributed.tensor import DTensor
3738
from torchdata.stateful_dataloader import StatefulDataLoader
3839

3940
from distributed_config import DistributedConfig
@@ -219,8 +220,28 @@ class AppState(Stateful):
219220
epoch: int = 0
220221

221222
def state_dict(self):
222-
"""Get the state dict for the model, optimizer, scheduler, and step."""
223+
"""
224+
Get the state dict for the model, optimizer, scheduler, and step.
225+
This factory both retrieves the model state dictionary when saving
226+
checkpoints and initializes a destination for the state read from
227+
DCP checkpoint files when loading checkpoints.
228+
"""
223229
model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
230+
for fqn in list(model_state_dict.keys()):
231+
# Get the model parameter.
232+
model_param = model_state_dict[fqn]
233+
if isinstance(model_param, DTensor):
234+
model_param = model_param.to_local()
235+
if model_param.numel() == 0 and fqn in optimizer_state_dict['state']:
236+
# Empty model parameter. Clear the associated optimizer state
237+
# when initializing the optimizer state upon DCP load, because
238+
# empty optimizer state DTensors are not checkpointed with DCP,
239+
# yet get_state_dict / _init_optim_state produce empty Tensors.
240+
# TransformerEngine uses empty Tensors for dummy Parameters.
241+
optimizer_state_dict['state'][fqn] = {}
242+
if fqn.endswith("._extra_state"):
243+
# Evict `_extra_state` quantization data from model checkpoint.
244+
model_state_dict.pop(fqn)
224245
return {
225246
"model": model_state_dict,
226247
"optim": optimizer_state_dict,
@@ -230,12 +251,19 @@ def state_dict(self):
230251
}
231252

232253
def load_state_dict(self, state_dict: dict):
233-
"""Load the state dict for the model, optimizer, scheduler, and step."""
254+
"""
255+
Load the state dict for the model, optimizer, scheduler, and step.
256+
Given the checkpoint-loaded state_dict, set the state of the model,
257+
optimizer, scheduler, step, and epoch to the values in state_dict.
258+
"""
234259
set_state_dict(
235260
self.model,
236261
self.optimizer,
237262
model_state_dict=state_dict["model"],
238263
optim_state_dict=state_dict["optim"],
264+
# Non-strict checkpoint loading ignores empty optimizer states,
265+
# skips loading non-FP8 checkpoint weights (e.g. _extra_state).
266+
options=StateDictOptions(strict=False),
239267
)
240268
self.scheduler.load_state_dict(state_dict["scheduler"])
241269
self.step = state_dict["step"]
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
defaults:
2+
- L0_sanity
3+
- _self_
4+
5+
tp_size: 2
6+
cp_size: 2
7+
8+
dataset:
9+
# CP2 * (8 for FP8 Activations, 16 for FP8 Parameters)
10+
pad_sequences_to_be_divisible_by: 32
11+
12+
fp8_config:
13+
enabled: true
14+
fp8_recipe: transformer_engine.common.recipe.DelayedScaling
15+
fp8_format: "HYBRID"
16+
fp8_recipe_kwargs: {}
17+
quantized_model_init_kwargs:
18+
# TODO(@cspades): Quantized parameters are
19+
# NOT supported with DCP checkpointing.
20+
enabled: true
21+
22+
checkpoint:
23+
ckpt_dir: ./fsdp_tp_ckpts
24+
save_final_model: true
25+
26+
config_kwargs:
27+
attn_input_format: "bshd" # Alternatively "thd" on datacenter hardware.
28+
self_attn_mask_type: "causal" # Alternatively "padding_causal" for THD inputs.
29+
tensor_parallel: true # Tensor Parallelism for TE
30+
sequence_parallel: true # Sequence parallelism for LayerNorm on TP ranks.
31+
tp_size: ${tp_size} # Tensor Parallel Size

bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import torch.nn as nn
2323
import transformer_engine.pytorch
2424
import transformers
25+
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module
26+
from torch.distributed.tensor.placement_types import Replicate
2527
from transformer_engine.pytorch.attention import InferenceParams
2628
from transformer_engine.pytorch.attention.inference import PagedKVCacheManager
2729
from transformer_engine.pytorch.attention.rope import RotaryPositionEmbedding
@@ -49,6 +51,18 @@ class NVLlamaConfig(LlamaConfig):
4951
# "thd" = Total tokens (packed/unpadded), Head, Dimension (sequence packing format)
5052
attn_input_format: str = "thd"
5153
self_attn_mask_type: str = "padding_causal"
54+
tensor_parallel: bool = False
55+
sequence_parallel: bool = False
56+
tp_size: int = 1
57+
tp_mesh: torch.distributed.DeviceMesh | None = None
58+
weight_mesh: torch.distributed.DeviceMesh | None = None
59+
60+
def to_dict(self):
    """Serialize the config to a plain dict for checkpointing.

    DeviceMesh handles are process-local runtime objects and cannot be
    serialized, so the `tp_mesh` and `weight_mesh` entries are stripped
    from the serialized form.
    """
    serializable = super().to_dict()
    for mesh_key in ("tp_mesh", "weight_mesh"):
        serializable.pop(mesh_key, None)
    return serializable
5266

5367

5468
class NVLlamaPreTrainedModel(PreTrainedModel):
@@ -114,9 +128,35 @@ def __init__(self, config: LlamaConfig):
114128
self.config = config
115129
self.padding_idx = config.pad_token_id
116130
self.vocab_size = config.vocab_size
131+
self.tp_mesh = config.tp_mesh
117132

118133
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx, dtype=config.dtype)
119134

135+
# Tensor-parallelize torch.nn.Embedding. Combines DTensor-based TP with TE-based TP.
if config.tensor_parallel:
    # BUGFIX: the original code used `assert (cond, "msg")`, which asserts a
    # non-empty tuple — always truthy — so neither validation could ever fire.
    # Raise explicitly instead (also robust under `python -O`, which strips asserts).
    if self.tp_mesh is None:
        raise ValueError("[NVLlamaModel] Tensor parallelism requires a NVLlamaConfig.tp_mesh.")
    if self.tp_mesh.size() != config.tp_size:
        raise ValueError(
            f"[NVLlamaModel] DeviceMesh TP size ({self.tp_mesh.size()}) "
            f"does not match configured TP size ({config.tp_size})."
        )
    # NOTE(@cspades): Because the TELinear head is weight-tied to torch.nn.Embedding
    # during HuggingFace post-init, this will automatically convert the TELinear head
    # weight into a DTensor with the correct sharding placements prior to FSDP2
    # fully_shard(), and no need to call TELinear.set_device_mesh().
    parallelize_module(
        self.embed_tokens,
        self.tp_mesh,
        # Un-sharded output activations for compatible input to TETransformer.
        # NOTE(@cspades): ColwiseParallel -> torch.nn.Embedding -> Shard(dim=1)
        # RowwiseParallel doesn't support output_layouts=Replicate() with
        # torch.compile: https://github.com/pytorch/torchtitan/issues/534
        ColwiseParallel(input_layouts=Replicate(), output_layouts=Replicate()),
    )
159+
120160
def _init_method(x):
121161
torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range)
122162

@@ -142,6 +182,11 @@ def _init_method(x):
142182
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
143183
init_method=_init_method,
144184
output_layer_init_method=_init_method,
185+
set_parallel_mode=config.tensor_parallel,
186+
sequence_parallel=config.sequence_parallel,
187+
tp_size=config.tp_size,
188+
tp_mesh=config.tp_mesh,
189+
weight_mesh=config.weight_mesh,
145190
)
146191
for layer_idx in range(config.num_hidden_layers)
147192
]
@@ -152,6 +197,8 @@ def _init_method(x):
152197
dtype=config.dtype,
153198
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
154199
)
200+
# Norm modules are non-Base TransformerEngine modules that require a manual call for TP.
201+
self.norm.set_device_mesh(tp_mesh=config.tp_mesh, weight_mesh=config.weight_mesh)
155202

156203
# We use TE's RotaryPositionEmbedding, but we ensure that we use the same inv_freq as the original
157204
# LlamaRotaryEmbedding.
@@ -283,6 +330,7 @@ def __init__(self, config):
283330
super().__init__(config)
284331
self.model = NVLlamaModel(config)
285332
self.vocab_size = config.vocab_size
333+
self.tp_mesh = config.tp_mesh
286334
with transformer_engine.pytorch.quantized_model_init(enabled=False):
287335
self.lm_head = transformer_engine.pytorch.Linear(
288336
config.hidden_size,
@@ -291,9 +339,19 @@ def __init__(self, config):
291339
params_dtype=config.dtype,
292340
device="meta" if torch.get_default_device() == torch.device("meta") else "cuda",
293341
init_method=lambda x: torch.nn.init.normal_(x, mean=0.0, std=config.initializer_range),
342+
parallel_mode="row" if config.tensor_parallel else None,
343+
# This scatters your output, not ever needed for final layer.
344+
# Will all-reduce the output instead, as required.
345+
sequence_parallel=False,
346+
tp_size=config.tp_size,
294347
)
348+
if self.config.tensor_parallel:
349+
# If using tensor parallelism, the head weights have already been tied
350+
# to the embedding weights. Just set the tensor parallel group for TE.
351+
# No parameter quantization either, so no need for weight_mesh.
352+
self.lm_head.set_tensor_parallel_group(self.tp_mesh.get_group())
295353

296-
# Initialize weights and apply final processing
354+
# Initialize weights and apply final processing. Ties weights.
297355
self.post_init()
298356

299357
def forward(
@@ -346,6 +404,13 @@ def forward(
346404
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
347405
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
348406

407+
if self.config.tensor_parallel:
    # If using TP, shard the activation across the TP group to feed the
    # row-parallel LM head. Each rank keeps its contiguous slice of the
    # hidden dimension.
    # BUGFIX: slice with `...` on the last axis instead of `[:, :, a:b]`.
    # The code just below branches on `hidden_states.ndim == 3`, so 2-D
    # (THD / packed) activations also reach this point, and three-axis
    # basic indexing raises IndexError on a 2-D tensor.
    tp_rank = self.tp_mesh.get_local_rank()
    tp_stride = hidden_states.shape[-1] // self.config.tp_size
    hidden_states = hidden_states[..., tp_rank * tp_stride : (tp_rank + 1) * tp_stride]
413+
349414
with transformer_engine.pytorch.autocast(enabled=False):
350415
if hidden_states.ndim == 3:
351416
logits = self.lm_head(hidden_states[:, slice_indices, :])

bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py renamed to bionemo-recipes/recipes/llama3_native_te/train_fsdp2_nd_parallel.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
logger.setLevel(logging.INFO)
5858

5959

60-
@hydra.main(config_path="hydra_config", config_name="L0_sanity_cp", version_base="1.2")
60+
@hydra.main(config_path="hydra_config", config_name="L2_sanity_nd", version_base="1.2")
6161
def main(args: DictConfig) -> float | None:
6262
"""Train Llama3 with TE layers using FSDP2 with Context Parallelism.
6363
@@ -73,8 +73,8 @@ def main(args: DictConfig) -> float | None:
7373

7474
device_mesh = init_device_mesh(
7575
"cuda",
76-
mesh_shape=(dist_config.world_size // args.cp_size, args.cp_size),
77-
mesh_dim_names=("dp", "cp"),
76+
mesh_shape=(dist_config.world_size // (args.cp_size * args.tp_size), args.cp_size, args.tp_size),
77+
mesh_dim_names=("dp", "cp", "tp"),
7878
)
7979
logger.info("Created device mesh: %s", device_mesh)
8080

@@ -85,6 +85,22 @@ def main(args: DictConfig) -> float | None:
8585

8686
# --- Model Initialization ---
8787
config = NVLlamaConfig.from_pretrained(args.config_name_or_path, dtype=torch.bfloat16, **args.config_kwargs)
88+
89+
# Identify DeviceMesh that are propagated to `set_device_mesh` in TransformerEngine modules.
90+
# These will convert TransformerEngine parameters into DTensors. Alternatively, users can
91+
# manually call the conversion using `TransformerEngineModule.set_device_mesh(...)`` before
92+
# `reset_parameters` (which triggers quantization) if the module supports DTensor parameters.
93+
if config.tensor_parallel:
94+
config.tp_mesh = device_mesh["tp"]
95+
if (
96+
args.fp8_config.quantized_model_init_kwargs.enabled
97+
and isinstance(fp8_recipe, transformer_engine.common.recipe.Float8CurrentScaling)
98+
):
99+
# When using per-tensor FP8 recipes for quantized parameters, TransformerEngine
100+
# requires a weight sharding mesh for absmax reduction across distributed weights.
101+
# If not provided, will default to DTensor.device_mesh.get_group(), which is not
102+
# appropriate if HSDP (DP-Replicate x DP-Shard) is used.
103+
config.weight_mesh = device_mesh["dp", "cp", "tp"]._flatten("weight_mesh")
88104

89105
# Optionally use transformer engine to initialize only fp8 versions of weights by setting
90106
# `fp8_config.quantized_model_init_kwargs.enabled` to `True`, as opposed to using the default where both bfloat16
@@ -100,6 +116,8 @@ def main(args: DictConfig) -> float | None:
100116
logger.info("Initialized Model:\n%s", model)
101117

102118
# --- Distributed Wrapping (FSDP2 + CP) ---
119+
120+
# Create a flattened mesh for FSDP2-CP sharding. This will shard the model across both the DP and CP ranks.
103121
cp_dp_mesh = device_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_shard_cp")
104122

105123
# Shard the transformer layers with FSDP. For Llama3, the transformer stack is in model.model.layers.
@@ -108,7 +126,7 @@ def main(args: DictConfig) -> float | None:
108126
fully_shard(layer, mesh=cp_dp_mesh)
109127
fully_shard(model, mesh=cp_dp_mesh)
110128

111-
# Attach the CP group to the model.
129+
# Attach the CP ProcessGroup to the TransformerEngine model.
112130
for layer in model.model.layers:
113131
layer.set_context_parallel_group(
114132
device_mesh["cp"].get_group(),
@@ -137,9 +155,12 @@ def main(args: DictConfig) -> float | None:
137155
logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2")
138156
OmegaConf.update(args, "dataset.pad_sequences_to_be_divisible_by", device_mesh["cp"].size() * 2)
139157

140-
# We only create the dataloader on rank 0, which is responsible for loading data for all CP (and eventually TP)
141-
# ranks. This ensures that the data remains synchronized, even if we're using a non-deterministic data pipeline.
142-
if device_mesh["cp"].get_local_rank() == 0:
158+
# We only create the dataloader on rank 0, which is responsible for loading data for all CP (and TP) ranks.
159+
# This ensures that the data remains synchronized, even if we're using a non-deterministic data pipeline.
160+
cp_tp_mesh = device_mesh["cp", "tp"]._flatten(mesh_dim_name="cp_tp")
161+
if cp_tp_mesh.get_local_rank() == 0:
162+
# We only create the dataloader on CP-TP Rank 0 and pass it to a ContextParallelDataLoaderWrapper
163+
# that will shard, replicate, and distribute the data across the flattened CP and TP group.
143164
if args.use_sequence_packing:
144165
train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset)
145166
else:
@@ -156,8 +177,8 @@ def main(args: DictConfig) -> float | None:
156177
train_dataloader = None
157178
dataset_or_sampler = None
158179

159-
# On all ranks, we create a ContextParallelDataLoaderWrapper that broadcasts the data from cp rank 0.
160-
train_dataloader = ContextParallelDataLoaderWrapper(train_dataloader, device_mesh["cp"])
180+
# Deliver CP-sharded replicates to a flattened CP-TP mesh.
181+
train_dataloader = ContextParallelDataLoaderWrapper(train_dataloader, cp_tp_mesh)
161182

162183
# --- Checkpoint Resume ---
163184
ckpt_path = Path(args.checkpoint.ckpt_dir) / "train_fsdp2" if args.checkpoint.ckpt_dir else None
@@ -170,7 +191,6 @@ def main(args: DictConfig) -> float | None:
170191
ckpt_path=ckpt_path,
171192
dist_config=dist_config,
172193
dataloader=train_dataloader,
173-
process_group=cp_dp_mesh.get_group(),
174194
)
175195
logger.info("Checkpoint loaded, resuming from step %s, epoch %s", start_step, epoch)
176196
else:
@@ -226,6 +246,13 @@ def main(args: DictConfig) -> float | None:
226246
)
227247

228248
if ckpt_path and should_save_checkpoint(step, args.checkpoint.save_every_n_steps):
249+
if args.checkpoint.async_save and args.fp8_config.quantized_model_init_kwargs.enabled:
250+
logger.info(
251+
"Asynchronous checkpointing is not supported with TransformerEngine "
252+
"quantized parameters and FSDP2. Using synchronous checkpointing "
253+
"(checkpoint.async_save=false)..."
254+
)
255+
OmegaConf.update(args, "checkpoint.async_save", False)
229256
save_checkpoint_fsdp2(
230257
model=model,
231258
optimizer=optimizer,
@@ -235,7 +262,6 @@ def main(args: DictConfig) -> float | None:
235262
epoch=epoch,
236263
dist_config=dist_config,
237264
dataloader=train_dataloader if args.dataset.use_stateful_dataloader else None,
238-
process_group=cp_dp_mesh.get_group(),
239265
max_checkpoints=args.checkpoint.max_checkpoints,
240266
async_save=args.checkpoint.async_save,
241267
)
@@ -268,4 +294,4 @@ def main(args: DictConfig) -> float | None:
268294

269295

270296
if __name__ == "__main__":
271-
main()
297+
main()

0 commit comments

Comments
 (0)