62 changes: 62 additions & 0 deletions end_to_end/gpu/te/README.md
@@ -0,0 +1,62 @@
# MaxText + Transformer Engine E2E Benchmarking

This directory contains scripts for testing MaxText with Transformer Engine (TE) integration across different parallelization strategies and quantization recipes.

Requirements:
- An NVIDIA MaxText container image with Transformer Engine (TE) installed. The latest `ghcr.io/nvidia/jax:maxtext` image is recommended.
- The `test-maxtext.sh` script, which is included in the suggested image. Otherwise, you can get it [here](https://github.com/NVIDIA/JAX-Toolbox/blob/main/.github/container/test-maxtext.sh).
- NVIDIA GPU(s) with compute capability 9.0 or higher for FP8 quantization, or 10.0 or higher for MXFP8 quantization (see the check below).
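
To verify the compute capability of your GPUs, one option (assuming a reasonably recent NVIDIA driver whose `nvidia-smi` supports the `compute_cap` query field) is:

```bash
# Prints one line per visible GPU, e.g. "NVIDIA H100 80GB HBM3, 9.0"
nvidia-smi --query-gpu=name,compute_cap --format=csv,noheader
```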

## Quick Start

### 1. Run Individual Tests

#### MaxText Baseline with FP8
```bash
MAXTEXT_DIR=/path/to/maxtext bash test-maxtext.sh --data-parallel=1 --tensor-sequence-parallel=1 --fsdp=1 --quantization=fp8 --model llama3.1-8b --steps 100
```

#### TE with DelayedScaling FP8
```bash
MAXTEXT_DIR=/path/to/maxtext bash test-maxtext.sh --data-parallel=1 --tensor-sequence-parallel=1 --fsdp=1 --quantization=te_fp8_delayedscaling --model llama3.1-8b --steps 100
```

#### TE with CurrentScaling FP8
```bash
MAXTEXT_DIR=/path/to/maxtext bash test-maxtext.sh --data-parallel=1 --tensor-sequence-parallel=1 --fsdp=1 --quantization=te_fp8_currentscaling --model llama3.1-8b --steps 100
```

#### TE with MXFP8 Block Scaling
```bash
MAXTEXT_DIR=/path/to/maxtext bash test-maxtext.sh --data-parallel=1 --tensor-sequence-parallel=1 --fsdp=1 --quantization=te_mxfp8 --model llama3.1-8b --steps 100
```

#### Enable Profiling/Tracing
Add profiling arguments to collect XPlane traces (only the last step is traced):
```bash
MAXTEXT_DIR=/path/to/maxtext bash test-maxtext.sh --data-parallel=1 --tensor-sequence-parallel=1 --fsdp=1 --quantization=te_fp8_delayedscaling --model llama3.1-8b --steps 100 --additional-args="profiler=xplane skip_first_n_steps_for_profiler=99 profiler_steps=1"
```
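
The collected XPlane trace can typically be inspected with TensorBoard's profiler plugin. A minimal sketch, assuming the trace is written under the run's output directory (the exact path depends on your output settings):

```bash
pip install tensorboard tensorboard-plugin-profile
tensorboard --logdir /path/to/maxtext/output  # open the "Profile" tab in the UI
```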

### 2. Run Comprehensive Benchmarking

The `run_single_node_model_parallel.sh` script automatically tests all quantization recipes across multiple parallelization strategies:

#### Basic Usage
```bash
bash run_single_node_model_parallel.sh --model llama3.1-8b --steps 100
```

#### With Tracing Enabled
```bash
bash run_single_node_model_parallel.sh --model llama3.1-8b --steps 100 --trace true
```

#### Collecting Traces with a Custom Number of Decoder Layers
```bash
bash run_single_node_model_parallel.sh --model llama3.1-8b --steps 100 --trace true --num-decoder-layers 4
```

#### Skip Single GPU Tests
```bash
bash run_single_node_model_parallel.sh --model llama3.1-8b --steps 100 --single-gpu-run false
```
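
#### Post-Processing Results
The helper scripts added alongside this README can summarize and visualize a benchmarking run. A sketch with placeholder file names (substitute whatever your run produced):

```bash
# Normalize raw timing results against the MaxText fp8 baseline; the input is tab-separated.
python normalize.py input_raw_results.csv summary.csv csv

# Produce one loss-curve PNG per (dp, tpsp, fsdp) configuration found in the *.log files.
python plot_loss_curves.py /path/to/logdir
```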
104 changes: 104 additions & 0 deletions end_to_end/gpu/te/normalize.py
@@ -0,0 +1,104 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Normalize the raw results to get the percentage difference from the baseline"""

# Usage: python normalize.py input_raw_results.csv output_summary.{csv|txt} format
#   format = 'csv' for comma-separated output, 'txt' or 'tsv' for tab-separated output
# Note: the input file is read as tab-separated, regardless of its extension.
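#
# Example (illustrative values; columns are tab-separated in the real input):
#   dp  tpsp  fsdp  test                   mean   stddev
#   1   8     1     fp8                    12.00  0.04
#   1   8     1     te_fp8_delayedscaling  12.34  0.05
# For each (dp, tpsp, fsdp) key, the "fp8" row is the MaxText baseline and is
# reported as "0.00%"; every other test is reported as
# (mean / baseline_mean - 1) * 100, so the row above normalizes to "2.83%".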

import csv
import sys

if len(sys.argv) < 4:
  print("Usage: normalize.py input_raw_results.csv output_summary.{csv|txt} format")
  print("  format = 'csv' for comma-separated, 'txt' or 'tsv' for tab-separated")
  sys.exit(1)

input_csv = sys.argv[1]
output_file = sys.argv[2]
format_type = sys.argv[3].lower()

data = {}
key_order = [] # preserve order of keys

# Read input TSV
with open(input_csv, encoding="utf-8") as f:
  reader = csv.DictReader(f, delimiter="\t")
  for row in reader:
    key = tuple(row.get(k) if row.get(k) not in [None, ""] else "NA" for k in ["dp", "tpsp", "fsdp"])
    if not row.get("test"):
      continue
    if key not in data:
      data[key] = {}
      key_order.append(key)  # remember when first seen
    data[key][row["test"]] = row

header = ["test", "dp", "tpsp", "fsdp", "mean", "stddev", "normalized"]
rows = []

# iterate keys in first-seen order
for key in key_order:
  rowset = data[key]
  baseline = rowset.get("fp8", {})
  base_mean = baseline.get("mean", "NA")
  try:
    base_mean_val = float(base_mean)
    has_baseline = True
  except ValueError:
    base_mean_val = 1.0  # dummy value for pylint
    has_baseline = False

  # iterate tests in first-seen order
  for testname in rowset:
    row = rowset[testname]
    mean = row["mean"]
    stddev = row["stddev"]
    if mean == "NA":
      normalized = "-"
    elif testname == "fp8":
      testname = "maxtext_fp8"
      normalized = "0.00%" if has_baseline else "-"
    elif has_baseline and mean != "NA":
      try:
        normalized = f"{(float(mean) / base_mean_val - 1) * 100:.2f}%"
      except ValueError:
        normalized = "-"
    else:
      normalized = "-"
    rows.append(
        [
            testname,
            row["dp"],
            row["tpsp"],
            row["fsdp"],
            mean,
            stddev,
            normalized,
        ]
    )

if format_type == "csv":
  with open(output_file, "a", newline="", encoding="utf-8") as out:
    writer = csv.writer(out)
    writer.writerow(header)
    writer.writerows(rows)
elif format_type in ("txt", "tsv"):
  with open(output_file, "a", encoding="utf-8") as out:
    out.write("\t".join(header) + "\n")
    for r in rows:
      out.write("\t".join(r) + "\n")
else:
  print("Invalid format type! Use 'csv' or 'txt'/'tsv'.")
  sys.exit(2)

print(f"Done. Wrote summary to {output_file} as {format_type}.")
103 changes: 103 additions & 0 deletions end_to_end/gpu/te/plot_loss_curves.py
@@ -0,0 +1,103 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Plot loss curves from training logs """

# Usage: python plot_loss_curves.py logdir
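#
# Each *.log file is expected to contain MaxText step lines such as (illustrative):
#   completed step: 10, seconds: 2.512, TFLOP/s/device: 401.3, Tokens/s/device: 5123.4, total_weights: 8388608, loss: 2.345
# and a "dp<D>_tpsp<T>_fsdp<F>" parallelism config embedded in its filename.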

import re
import os
import sys
import argparse

import matplotlib.pyplot as plt


def parse_loss_data(file_path):
  """
  Parses a text file for lines matching the pattern:
    completed step: <int>, seconds: <float>, TFLOP/s/device: <float>,
    Tokens/s/device: <float>, total_weights: <int>, loss: <float>
  Returns a list of tuples with the extracted values.
  """
  pattern = re.compile(
      r"completed step: (\d+), seconds: ([\d.]+), TFLOP/s/device: ([\d.]+), Tokens/s/device: ([\d.]+), total_weights: (\d+), loss: ([\d.]+)"  # pylint: disable=line-too-long
  )
  results = []
  with open(file_path, "r", encoding="utf-8") as f:
    for line in f:
      match = pattern.search(line)
      if match:
        step = int(match.group(1))
        seconds = float(match.group(2))
        tflops = float(match.group(3))
        tokens_per_sec = float(match.group(4))
        total_weights = int(match.group(5))
        loss = float(match.group(6))
        results.append((step, seconds, tflops, tokens_per_sec, total_weights, loss))
  return results


def main(args):
  parser = argparse.ArgumentParser(description="Plot training loss curve from log files.")
  parser.add_argument("logdir", type=str, help="Directory containing training log files.")
  parsed_args = parser.parse_args(args)

  logdir = parsed_args.logdir
  log_files = [
      os.path.join(logdir, f)
      for f in os.listdir(logdir)
      if os.path.isfile(os.path.join(logdir, f)) and f.endswith(".log")
  ]

  # Extract parallelism configs from filenames
  config_pattern = re.compile(r"dp(\d+)_tpsp(\d+)_fsdp(\d+)")
  configs = {}
  for log_file in log_files:
    fname = os.path.basename(log_file)
    match = config_pattern.search(fname)
    if match:
      dp, tpsp, fsdp = match.groups()
      key = (int(dp), int(tpsp), int(fsdp))
      configs.setdefault(key, []).append(log_file)

  # Plot for each config
  for (dp, tpsp, fsdp), files in configs.items():
    plt.figure(figsize=(8, 5))
    for log_file in files:
      data = parse_loss_data(log_file)
      if not data:
        continue
      steps = [item[0] for item in data]
      losses = [item[5] for item in data]
      plt.plot(
          steps,
          losses,
          marker="",
          linestyle="-",
          label=os.path.basename(log_file),
      )
    plt.legend()
    plt.xlabel("Step")
    plt.ylabel("Loss")
    plt.title(f"Loss Curves (dp={dp}, tpsp={tpsp}, fsdp={fsdp})")
    plt.grid(True)
    plt.tight_layout()
    out_image_path = f"loss_curves_dp{dp}_tpsp{tpsp}_fsdp{fsdp}.png"
    plt.savefig(out_image_path)
    print(f"Saved plot to {out_image_path}")
    plt.close()


if __name__ == "__main__":
  main(sys.argv[1:])