[not for land] TE experiments
Summary:

Test Plan:

```
with-proxy CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.use_te
```
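
Presumably the float8 path can be exercised as well by adding the second flag introduced in this commit (a variant of the same smoke-test command, not verified here):

```
with-proxy CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --training.use_te --training.use_te_float8
```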

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo committed Jul 23, 2024
1 parent 0f70507 commit 0da00ff
Showing 4 changed files with 111 additions and 1 deletion.
27 changes: 27 additions & 0 deletions test/test_te.py
@@ -0,0 +1,27 @@
import torch
import torch.nn as nn
import torchtitan.te_utils as te_utils
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
maybe_te_float8_ctx = te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)

def test():
    # for now, single GPU smoke test of TE fp8

    x = torch.randn(32, 32, device='cuda')

    m = nn.Sequential(nn.Linear(32, 32)).cuda()
    te_utils.swap_linear_to_te_linear(m)
    print(m)

    with maybe_te_float8_ctx:
        y = m(x)
    y.sum().backward()

    print('done')

if __name__ == '__main__':
    test()
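
The smoke test above can presumably also be run on its own, assuming a machine with a CUDA GPU and a working TransformerEngine install:

```
python test/test_te.py
```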
14 changes: 14 additions & 0 deletions torchtitan/config_manager.py
@@ -370,6 +370,20 @@ def __init__(self):
            default=False,
            help="Whether precompute float8 scales dynamically for FSDP",
        )
        self.parser.add_argument(
            "--training.use_te",
            action="store_true",
            help="""
                If true, uses TransformerEngine (not for land)
            """,
        )
        self.parser.add_argument(
            "--training.use_te_float8",
            action="store_true",
            help="""
                If true, enables TransformerEngine's float8 integration (not for land)
            """,
        )
        self.parser.add_argument(
            "--training.gc_freq",
            type=int,
57 changes: 57 additions & 0 deletions torchtitan/te_utils.py
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Utilities for testing TransformerEngine
Note: I attempted to hack in DTensor-based TP/SP to te.Linear in the
link below, and gave up for now as it seemed to be a lot of remaining work.
We can power through that if needed later.
* https://gist.github.com/vkuzo/64d5362b63dd6c76410464e020d9a35f
Note: I looked into using te.LayerNormLinear, and that would require changing
how Attention and FFN are defined in torchtitan to use a single gemm for
attn.kqv and ffn.w1_w3. Punting for now but we can do this later if needed.
"""

import contextlib
import os

# required for current build to work with fp8 on devgpu003.cco3
# context: https://github.com/NVIDIA/TransformerEngine/pull/575
# error stack trace if not enabled: https://gist.github.com/vkuzo/8e78282f4a986961753fba25249fdf77
os.environ["NVTE_UNFUSED_FP8_UPDATE"] = "1"

import torch

# import transformer_engine as te
import transformer_engine.pytorch as te

from transformer_engine.common.recipe import Format, DelayedScaling
te_fp8_format = Format.HYBRID
te_fp8_recipe = DelayedScaling(fp8_format=te_fp8_format, amax_history_len=16, amax_compute_algo="max")

def swap_linear_to_te_linear(model, fqn=''):
    for name, child in model.named_children():
        new_fqn = f"{fqn}.{name}"
        if isinstance(child, torch.nn.Linear):
            te_linear = te.Linear(child.in_features, child.out_features, bias=child.bias is not None)
            te_linear.weight = child.weight
            te_linear.bias = child.bias
            setattr(model, name, te_linear)
        else:
            swap_linear_to_te_linear(child, new_fqn)

def get_maybe_fp8_autocast(job_config):
    # not for land - set up TransformerEngine fp8 autocast
    # Note: te.fp8_autocast has to be created at every training iteration.
    # If we try to create it once and reuse, we get this error:
    # https://gist.github.com/vkuzo/d9840328c8bdc2901b8d04aa570ecb5b
    maybe_te_float8_ctx = contextlib.nullcontext()
    if job_config.training.use_te and job_config.training.use_te_float8:
        maybe_te_float8_ctx = te.fp8_autocast(enabled=True, fp8_recipe=te_fp8_recipe)
    return maybe_te_float8_ctx
14 changes: 13 additions & 1 deletion train.py
@@ -31,6 +31,7 @@
    maybe_build_fp8_linear,
    maybe_precompute_fp8_dynamic_scale_for_fsdp,
)
import torchtitan.te_utils as te_utils
from torchtitan.logging_utils import init_logger, logger
from torchtitan.lr_scheduling import get_lr_schedulers
from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger
@@ -238,6 +239,12 @@ def loss_fn(pred, labels):
    # swap to Float8Linear based on fp8 config
    maybe_build_fp8_linear(whole_model, job_config, parallel_dims.dp_enabled)

    # not for land - set up TransformerEngine
    if job_config.training.use_te:
        print('before', whole_model)
        te_utils.swap_linear_to_te_linear(whole_model)
        print('after', whole_model)

    # log model size
    model_param_count = get_num_params(whole_model)
    num_flop_per_token = get_num_flop_per_token(
@@ -377,7 +384,11 @@ def loss_fn(pred, labels):
        labels = labels.cuda()
        optimizers.zero_grad()

        # not for land - set up TransformerEngine fp8 autocast
        maybe_te_float8_ctx = te_utils.get_maybe_fp8_autocast(job_config)

        if parallel_dims.pp_enabled:
            assert not job_config.training.use_te, "unsupported"
            # pipeline parallel forward / backward inside step() call
            is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

@@ -399,7 +410,8 @@
        else:
            # Non-PP forward / backward
            with train_context():
                with maybe_te_float8_ctx:
                    pred = model(input_ids)
                loss = loss_fn(pred, labels)
                # pred.shape=(bs, seq_len, vocab_size)
                # need to free it before bwd to avoid peaking memory
