
Commit 0da00ff

[not for land] TE experiments
Summary:

Test Plan:
```
with-proxy CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.use_te
```
Reviewers:
Subscribers:
Tasks:
Tags:
1 parent 0f70507 commit 0da00ff

4 files changed (+111, −1)

test/test_te.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+import torch
+import torch.nn as nn
+import torchtitan.te_utils as te_utils
+import transformer_engine.pytorch as te
+from transformer_engine.common.recipe import Format, DelayedScaling
+
+fp8_format = Format.HYBRID
+fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
+maybe_te_float8_ctx = te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)
+
+def test():
+    # for now, single GPU smoke test of TE fp8
+
+    x = torch.randn(32, 32, device='cuda')
+
+    m = nn.Sequential(nn.Linear(32, 32)).cuda()
+    te_utils.swap_linear_to_te_linear(m)
+    print(m)
+
+    with maybe_te_float8_ctx:
+        y = m(x)
+    y.sum().backward()
+
+    print('done')
+
+if __name__ == '__main__':
+    test()
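
One possible extension of this smoke test (not part of this commit) would keep an unswapped copy of the model and report how far the TE fp8 output drifts from the original nn.Linear output, for example:

```
import copy

import torch
import torch.nn as nn
import torchtitan.te_utils as te_utils
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

def test_numerics():
    # hypothetical follow-up check, not in this commit
    x = torch.randn(32, 32, device='cuda')
    m = nn.Sequential(nn.Linear(32, 32)).cuda()
    m_ref = copy.deepcopy(m)  # keep the original nn.Linear as a reference

    te_utils.swap_linear_to_te_linear(m)

    fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        y = m(x)
    y_ref = m_ref(x)

    # fp8 output will not match exactly; report the error magnitude instead
    print('max abs diff vs reference:', (y - y_ref).abs().max().item())
```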

torchtitan/config_manager.py

Lines changed: 14 additions & 0 deletions
@@ -370,6 +370,20 @@ def __init__(self):
             default=False,
             help="Whether precompute float8 scales dynamically for FSDP",
         )
+        self.parser.add_argument(
+            "--training.use_te",
+            action="store_true",
+            help="""
+            If true, uses TransformerEngine (not for land)
+            """,
+        )
+        self.parser.add_argument(
+            "--training.use_te_float8",
+            action="store_true",
+            help="""
+            If true, enables TransformerEngine's float8 integration (not for land)
+            """,
+        )
         self.parser.add_argument(
             "--training.gc_freq",
             type=int,
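
For reference, a minimal sketch (not the actual config_manager implementation) of how dotted store_true flags like these end up readable as job_config.training.use_te in train.py and te_utils.py:

```
import argparse
from types import SimpleNamespace

parser = argparse.ArgumentParser()
parser.add_argument("--training.use_te", action="store_true")
parser.add_argument("--training.use_te_float8", action="store_true")

args = parser.parse_args(["--training.use_te"])

# group "section.key" argparse entries into nested namespaces
sections = {}
for full_key, value in vars(args).items():
    section, key = full_key.split(".", 1)
    sections.setdefault(section, {})[key] = value
job_config = SimpleNamespace(
    **{section: SimpleNamespace(**kv) for section, kv in sections.items()}
)

print(job_config.training.use_te)         # True
print(job_config.training.use_te_float8)  # False
```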

torchtitan/te_utils.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Utilities for testing TransformerEngine
+
+Note: I attempted to hack in DTensor-based TP/SP to te.Linear in the
+link below, and gave up for now as it seemed to be a lot of remaining work.
+We can power through that if needed later.
+* https://gist.github.com/vkuzo/64d5362b63dd6c76410464e020d9a35f
+
+Note: I looked into using te.LayerNormLinear, and that would require changing
+how Attention and FFN are defined in torchtitan to use a single gemm for
+attn.kqv and ffn.w1_w3. Punting for now but we can do this later if needed.
+
+"""
+
+import contextlib
+import os
+
+# required for current build to work with fp8 on devgpu003.cco3
+# context: https://github.com/NVIDIA/TransformerEngine/pull/575
+# error stack trace if not enabled: https://gist.github.com/vkuzo/8e78282f4a986961753fba25249fdf77
+os.environ["NVTE_UNFUSED_FP8_UPDATE"] = "1"
+
+import torch
+
+# import transformer_engine as te
+import transformer_engine.pytorch as te
+
+from transformer_engine.common.recipe import Format, DelayedScaling
+te_fp8_format = Format.HYBRID
+te_fp8_recipe = DelayedScaling(fp8_format=te_fp8_format, amax_history_len=16, amax_compute_algo="max")
+
+def swap_linear_to_te_linear(model, fqn=''):
+    for name, child in model.named_children():
+        new_fqn = f"{fqn}.{name}"
+        if isinstance(child, torch.nn.Linear):
+            te_linear = te.Linear(child.in_features, child.out_features, bias=child.bias is not None)
+            te_linear.weight = child.weight
+            te_linear.bias = child.bias
+            setattr(model, name, te_linear)
+        else:
+            swap_linear_to_te_linear(child, new_fqn)
+
+def get_maybe_fp8_autocast(job_config):
+    # not for land - set up TransformerEngine fp8 autocast
+    # Note: te.fp8_autocast has to be created at every training iteration.
+    # If we try to create it once and reuse, we get this error:
+    # https://gist.github.com/vkuzo/d9840328c8bdc2901b8d04aa570ecb5b
+    maybe_te_float8_ctx = contextlib.nullcontext()
+    if job_config.training.use_te and job_config.training.use_te_float8:
+        maybe_te_float8_ctx = te.fp8_autocast(enabled=False, fp8_recipe=te_fp8_recipe)
+    return maybe_te_float8_ctx
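
For context, a minimal single-GPU sketch of how these two helpers are meant to compose: the linear swap happens once after the model is built, while the fp8 autocast context is recreated every training iteration (see the note above). The job_config and model below are hypothetical stand-ins, not torchtitan's real objects:

```
import torch
import torch.nn as nn
from types import SimpleNamespace

import torchtitan.te_utils as te_utils

# hypothetical stand-ins for torchtitan's real model and job config
job_config = SimpleNamespace(training=SimpleNamespace(use_te=True, use_te_float8=True))
model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64)).cuda()

if job_config.training.use_te:
    te_utils.swap_linear_to_te_linear(model)  # one-time swap after model construction

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for step in range(3):
    x = torch.randn(32, 64, device="cuda")
    optimizer.zero_grad()
    # recreated every iteration, per the note in get_maybe_fp8_autocast
    maybe_te_float8_ctx = te_utils.get_maybe_fp8_autocast(job_config)
    with maybe_te_float8_ctx:
        out = model(x)
    out.sum().backward()
    optimizer.step()
```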

train.py

Lines changed: 13 additions & 1 deletion
@@ -31,6 +31,7 @@
     maybe_build_fp8_linear,
     maybe_precompute_fp8_dynamic_scale_for_fsdp,
 )
+import torchtitan.te_utils as te_utils
 from torchtitan.logging_utils import init_logger, logger
 from torchtitan.lr_scheduling import get_lr_schedulers
 from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger
@@ -238,6 +239,12 @@ def loss_fn(pred, labels):
     # swap to Float8Linear base on fp8 config
     maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)
 
+    # not for land - set up TransformerEngine
+    if job_config.training.use_te:
+        print('before', whole_model)
+        te_utils.swap_linear_to_te_linear(whole_model)
+        print('after', whole_model)
+
     # log model size
     model_param_count = get_num_params(whole_model)
     num_flop_per_token = get_num_flop_per_token(
@@ -377,7 +384,11 @@ def loss_fn(pred, labels):
         labels = labels.cuda()
         optimizers.zero_grad()
 
+        # not for land - set up TransformerEngine fp8 autocast
+        maybe_te_float8_ctx = te_utils.get_maybe_fp8_autocast(job_config)
+
         if parallel_dims.pp_enabled:
+            assert not job_config.training.use_te, "unsupported"
             # pipeline parallel forward / backward inside step() call
             is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
 
@@ -399,7 +410,8 @@ def loss_fn(pred, labels):
         else:
             # Non-PP forward / backward
             with train_context():
-                pred = model(input_ids)
+                with maybe_te_float8_ctx:
+                    pred = model(input_ids)
                 loss = loss_fn(pred, labels)
                 # pred.shape=(bs, seq_len, vocab_size)
                 # need to free to before bwd to avoid peaking memory
