Skip to content

Commit 423dc89

Browse files
phu0ngngjberchtold-nvidia
authored and committed
TE Dense op integration
Signed-off-by: Phuong Nguyen <[email protected]> --------- Signed-off-by: Phuong Nguyen <[email protected]> Co-authored-by: Jeremy Berchtold <[email protected]>
1 parent 151fa9f commit 423dc89

File tree

10 files changed

+724
-4
lines changed

10 files changed

+724
-4
lines changed

end_to_end/gpu/te/README.md

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# MaxText + Transformer Engine E2E Benchmarking
2+
3+
This directory contains scripts for testing MaxText with Transformer Engine (TE) integration across different parallelization strategies and quantization recipes.
4+
5+
Requirements:
6+
- NVIDIA MaxText image with installed Transformer Engine (TE). Suggested to use the latest version of `ghcr.io/nvidia/jax:maxtext`.
7+
- `test-maxtext.sh` script which is available in the suggested image. Otherwise, you can get it [here](https://github.com/NVIDIA/JAX-Toolbox/blob/main/.github/container/test-maxtext.sh).
8+
- NVIDIA GPU(s) with compute capability 9.0 or higher for FP8 quantization, 10.0 or higher for MXFP8 quantization.
9+
10+
## Quick Start
11+
12+
### 1. Run Individual Tests
13+
14+
#### MaxText Baseline with FP8
15+
```bash
16+
MAXTEXT_DIR=/path/to/maxtext bash test-maxtext.sh --data-parallel=1 --tensor-sequence-parallel=1 --fsdp=1 --quantization=fp8 --model llama3.1-8b --steps 100
17+
```
18+
19+
#### TE with DelayedScaling FP8
20+
```bash
21+
MAXTEXT_DIR=/path/to/maxtext bash test-maxtext.sh --data-parallel=1 --tensor-sequence-parallel=1 --fsdp=1 --quantization=te_fp8_delayedscaling --model llama3.1-8b --steps 100
22+
```
23+
24+
#### TE with CurrentScaling FP8
25+
```bash
26+
MAXTEXT_DIR=/path/to/maxtext bash test-maxtext.sh --data-parallel=1 --tensor-sequence-parallel=1 --fsdp=1 --quantization=te_fp8_currentscaling --model llama3.1-8b --steps 100
27+
```
28+
29+
#### TE with MXFP8 Block Scaling
30+
```bash
31+
MAXTEXT_DIR=/path/to/maxtext bash test-maxtext.sh --data-parallel=1 --tensor-sequence-parallel=1 --fsdp=1 --quantization=te_mxfp8 --model llama3.1-8b --steps 100
32+
```
33+
34+
#### Enable Profiling/Tracing
35+
Add profiling arguments to collect XPlane traces (only the last step is traced):
36+
```bash
37+
MAXTEXT_DIR=/path/to/maxtext bash test-maxtext.sh --data-parallel=1 --tensor-sequence-parallel=1 --fsdp=1 --quantization=te_fp8_delayedscaling --model llama3.1-8b --steps 100 --additional-args="profiler=xplane skip_first_n_steps_for_profiler=99 profiler_steps=1"
38+
```
39+
40+
### 2. Run Comprehensive Benchmarking
41+
42+
The `run_single_node_model_parallel.sh` script automatically tests all quantization recipes across multiple parallelization strategies:
43+
44+
#### Basic Usage
45+
```bash
46+
bash run_single_node_model_parallel.sh --model llama3.1-8b --steps 100
47+
```
48+
49+
#### With Tracing Enabled
50+
```bash
51+
bash run_single_node_model_parallel.sh --model llama3.1-8b --steps 100 --trace true
52+
```
53+
54+
#### Collecting traces with custom number of decoder layers
55+
```bash
56+
bash run_single_node_model_parallel.sh --model llama3.1-8b --steps 100 --trace true --num-decoder-layers 4
57+
```
58+
59+
#### Skip Single GPU Tests
60+
```bash
61+
bash run_single_node_model_parallel.sh --model llama3.1-8b --steps 100 --single-gpu-run false
62+
```

end_to_end/gpu/te/normalize.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Normalize raw benchmark results to the percentage difference from the fp8 baseline."""

# Usage: python normalize.py input_raw_results.csv output_summary.{csv|txt} format
# format = 'csv' for comma-separated, 'txt' or 'tsv' for tab-separated
#
# The input file is tab-separated despite the .csv suffix (see _read_groups).

import csv
import sys

# Output column order for the summary table.
HEADER = ["test", "dp", "tpsp", "fsdp", "mean", "stddev", "normalized"]


def _read_groups(input_path):
  """Read the tab-separated raw results file.

  Returns:
    data: dict mapping (dp, tpsp, fsdp) -> {test name -> raw row dict}.
    key_order: list of keys in first-seen order, so output preserves input order.
  """
  data = {}
  key_order = []
  with open(input_path, encoding="utf-8") as f:
    reader = csv.DictReader(f, delimiter="\t")
    for row in reader:
      # Rows with no test name carry no result; skip them.
      if not row.get("test"):
        continue
      # Missing/empty parallelism fields are normalized to "NA".
      key = tuple(row.get(k) if row.get(k) not in (None, "") else "NA" for k in ("dp", "tpsp", "fsdp"))
      if key not in data:
        data[key] = {}
        key_order.append(key)  # remember when first seen
      data[key][row["test"]] = row
  return data, key_order


def _build_rows(data, key_order):
  """Compute the summary rows, normalizing each test against the fp8 baseline.

  The 'fp8' test in each group is the MaxText baseline; it is renamed to
  'maxtext_fp8' in the output and other tests report their mean as a
  percentage difference from it ('-' when no usable baseline exists).
  """
  rows = []
  for key in key_order:
    rowset = data[key]
    base_mean = rowset.get("fp8", {}).get("mean", "NA")
    try:
      base_mean_val = float(base_mean)
      # A zero baseline would divide by zero below; treat it as missing.
      has_baseline = base_mean_val != 0.0
    except (TypeError, ValueError):
      base_mean_val = 1.0  # dummy value, never used when has_baseline is False
      has_baseline = False

    # iterate tests in first-seen order
    for testname, row in rowset.items():
      mean = row["mean"]
      stddev = row["stddev"]
      if mean == "NA":
        normalized = "-"
      elif testname == "fp8":
        testname = "maxtext_fp8"
        normalized = "0.00%" if has_baseline else "-"
      elif has_baseline:
        try:
          normalized = f"{(float(mean) / base_mean_val - 1) * 100:.2f}%"
        except ValueError:
          normalized = "-"
      else:
        normalized = "-"
      rows.append([testname, row["dp"], row["tpsp"], row["fsdp"], mean, stddev, normalized])
  return rows


def _write_summary(output_path, rows, format_type):
  """Append the summary (header + rows) to output_path in the requested format.

  Exits with status 2 on an unrecognized format, matching the original CLI.
  """
  if format_type == "csv":
    with open(output_path, "a", newline="", encoding="utf-8") as out:
      writer = csv.writer(out)
      writer.writerow(HEADER)
      writer.writerows(rows)
  elif format_type in ("txt", "tsv"):
    with open(output_path, "a", encoding="utf-8") as out:
      out.write("\t".join(HEADER) + "\n")
      for r in rows:
        out.write("\t".join(r) + "\n")
  else:
    print("Invalid format type! Use 'csv' or 'txt'/'tsv'.")
    sys.exit(2)


def main(argv):
  """CLI entry point: argv is sys.argv[1:] (input path, output path, format)."""
  if len(argv) < 3:
    print("Usage: normalize.py input_raw_results.csv output_summary.{csv|txt} format")
    print(" format = 'csv' for comma-separated, 'txt' or 'tsv' for tab-separated")
    sys.exit(1)
  input_csv = argv[0]
  output_file = argv[1]
  format_type = argv[2].lower()

  data, key_order = _read_groups(input_csv)
  rows = _build_rows(data, key_order)
  _write_summary(output_file, rows, format_type)
  print(f"Done. Wrote summary to {output_file} as {format_type}.")


if __name__ == "__main__":
  main(sys.argv[1:])
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
""" Plot loss curves from training logs """
15+
16+
# Usage: python plot_loss_curves.py logdir
17+
18+
import re
19+
import os
20+
import sys
21+
import argparse
22+
23+
import matplotlib.pyplot as plt
24+
25+
26+
def parse_loss_data(file_path):
  """
  Extract per-step training metrics from a MaxText log file.

  Scans each line for the pattern:
      completed step: <int>, seconds: <float>, TFLOP/s/device: <float>,
      Tokens/s/device: <float>, total_weights: <int>, loss: <float>

  Returns a list of
  (step, seconds, tflops, tokens_per_sec, total_weights, loss) tuples,
  one per matching line, in file order.
  """
  metric_re = re.compile(
      r"completed step: (\d+), seconds: ([\d.]+), TFLOP/s/device: ([\d.]+), Tokens/s/device: ([\d.]+), total_weights: (\d+), loss: ([\d.]+)"  # pylint: disable=line-too-long
  )
  # One converter per capture group, applied positionally.
  converters = (int, float, float, float, int, float)
  records = []
  with open(file_path, "r", encoding="utf-8") as log:
    for line in log:
      hit = metric_re.search(line)
      if hit is None:
        continue
      records.append(tuple(conv(group) for conv, group in zip(converters, hit.groups())))
  return records
49+
50+
51+
def main(args):
  """Plot per-step loss curves from MaxText log files.

  Args:
    args: CLI argument list (sys.argv[1:]); a single positional 'logdir'
      naming the directory that contains the *.log files.

  Log files are grouped by the parallelism configuration encoded in their
  names (dp<N>_tpsp<N>_fsdp<N>); one PNG is written per configuration to
  the current working directory.
  """
  parser = argparse.ArgumentParser(description="Plot training loss curve from log files.")
  parser.add_argument("logdir", type=str, help="Directory containing training log files.")
  logdir = parser.parse_args(args).logdir

  # Collect regular *.log files directly inside logdir (no recursion).
  log_files = []
  for name in os.listdir(logdir):
    full_path = os.path.join(logdir, name)
    if name.endswith(".log") and os.path.isfile(full_path):
      log_files.append(full_path)

  # Group log files by the (dp, tpsp, fsdp) triple parsed from each filename;
  # files whose names do not carry a config are ignored.
  config_re = re.compile(r"dp(\d+)_tpsp(\d+)_fsdp(\d+)")
  configs = {}
  for log_file in log_files:
    match = config_re.search(os.path.basename(log_file))
    if match is None:
      continue
    key = tuple(int(v) for v in match.groups())
    configs.setdefault(key, []).append(log_file)

  # One figure per configuration, one curve per log file.
  for (dp, tpsp, fsdp), files in configs.items():
    plt.figure(figsize=(8, 5))
    for log_file in files:
      records = parse_loss_data(log_file)
      if not records:
        continue
      plt.plot(
          [rec[0] for rec in records],  # step
          [rec[5] for rec in records],  # loss
          marker="",
          linestyle="-",
          label=os.path.basename(log_file),
      )
    plt.legend()
    plt.xlabel("Step")
    plt.ylabel("Loss")
    plt.title(f"Loss Curves (dp={dp}, tpsp={tpsp}, fsdp={fsdp})")
    plt.grid(True)
    plt.tight_layout()
    out_image_path = f"loss_curves_dp{dp}_tpsp{tpsp}_fsdp{fsdp}.png"
    plt.savefig(out_image_path)
    print(f"Saved plot to {out_image_path}")
    plt.close()


if __name__ == "__main__":
  main(sys.argv[1:])

0 commit comments

Comments
 (0)