
Commit ef4cf36

Merge remote-tracking branch 'origin/main' into bench-sparsity
2 parents 2850389 + 09c2760 commit ef4cf36

File tree

74 files changed: +4270 -2037 lines changed


.github/workflows/float8nocompile_test.yaml

+29 -29

@@ -21,33 +21,33 @@ concurrency:
 env:
   HF_TOKEN: ${{ secrets.HF_TOKEN }}
 
-jobs:
-  test:
-    strategy:
-      fail-fast: false
-      matrix:
-        include:
-          - name: SM-89
-            runs-on: linux.g6.4xlarge.experimental.nvidia.gpu
-            torch-spec: '--pre torch --index-url https://download.pytorch.org/whl/nightly/cu121'
-            gpu-arch-type: "cuda"
-            gpu-arch-version: "12.1"
+# jobs:
+#   test:
+#     strategy:
+#       fail-fast: false
+#       matrix:
+#         include:
+#           - name: H100
+#             runs-on: linux.aws.h100
+#             torch-spec: '--pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124'
+#             gpu-arch-type: "cuda"
+#             gpu-arch-version: "12.4"
 
-    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
-    with:
-      timeout: 300
-      runner: ${{ matrix.runs-on }}
-      gpu-arch-type: ${{ matrix.gpu-arch-type }}
-      gpu-arch-version: ${{ matrix.gpu-arch-version }}
-      submodules: recursive
-      script: |
-        conda create -n venv python=3.9 -y
-        conda activate venv
-        export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
-        python -m pip install --upgrade pip
-        pip install ${{ matrix.torch-spec }}
-        pip install -r dev-requirements.txt
-        pip install .
-        cd torchao/prototype/float8nocompile
-        pytest kernels/ --verbose -s
-        pytest test/train_test.py --verbose -s
+#     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+#     with:
+#       timeout: 300
+#       runner: ${{ matrix.runs-on }}
+#       gpu-arch-type: ${{ matrix.gpu-arch-type }}
+#       gpu-arch-version: ${{ matrix.gpu-arch-version }}
+#       submodules: recursive
+#       script: |
+#         conda create -n venv python=3.9 -y
+#         conda activate venv
+#         export PATH=/opt/rh/devtoolset-10/root/usr/bin/:$PATH
+#         python -m pip install --upgrade pip
+#         pip install ${{ matrix.torch-spec }}
+#         pip install -r dev-requirements.txt
+#         pip install .
+#         cd torchao/prototype/float8nocompile
+#         pytest kernels/ --verbose -s
+#         pytest test/train_test.py --verbose -s

.github/workflows/regression_test_rocm.yml

+2

@@ -43,6 +43,8 @@ jobs:
           python -m pip install --upgrade pip
           pip install ${{ matrix.torch-spec }}
           pip install -r dev-requirements.txt
+          pip uninstall -y bitsandbytes
+          pip install --force-reinstall 'https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_multi-backend-refactor/bitsandbytes-0.44.1.dev0-py3-none-manylinux_2_24_x86_64.whl'
           pip install .
           export CONDA=$(dirname $(dirname $(which conda)))
           export LD_LIBRARY_PATH=$CONDA/lib/:$LD_LIBRARY_PATH

.github/workflows/torchao_experimental_test.yml

+3 -2

@@ -36,11 +36,11 @@ jobs:
           # Install executorch first because it installs its own version
           # of torch and torchao, which we do not want to use
           pip install executorch
-          pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu" --force-reinstall
+          pip install torch==2.7.0.dev20250311 --index-url "https://download.pytorch.org/whl/nightly/cpu" --force-reinstall
           pip install numpy
           pip install pytest
           pip install parameterized
-          USE_CPP=1 pip install .
+          USE_CPP=1 TOCHAO_BUILD_KLEIDIAI=1 pip install .
       - name: Run python tests
         run: |
           conda activate venv

@@ -103,6 +103,7 @@ jobs:
           pip install parameterized
           pip install pyyaml
           pip install numpy
+          pip install importlib-metadata
       - name: Print pip freeze
         run: |
           pip freeze

benchmarks/float8/float8_roofline.py

+1 -1

@@ -372,7 +372,7 @@ def run(
     ).requires_grad_()
 
     # get the gradient of the right shape
-    grad_output = torch.randn(N_val, K_val, dtype=torch.bfloat16, device="cuda")
+    grad_output = torch.randn(M_val, N_val, dtype=torch.bfloat16, device="cuda")
 
     # get the bf16 gpu kernel time
     torch._dynamo.reset()
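The one-line fix above follows from linear-layer shape algebra: for an input of shape (M, K) and a weight of shape (N, K), the output, and therefore the gradient flowing back into it, has shape (M, N), not (N, K). A minimal shape check with illustrative sizes (not part of the benchmark):

# Illustrative sanity check of the grad_output shape used in the fix.
import torch

M_val, K_val, N_val = 4, 8, 16
x = torch.randn(M_val, K_val, requires_grad=True)
w = torch.randn(N_val, K_val)
y = torch.nn.functional.linear(x, w)  # output has shape (M, N)
assert y.shape == (M_val, N_val)

grad_output = torch.randn(M_val, N_val)  # must match y's shape
y.backward(grad_output)
assert x.grad.shape == (M_val, K_val)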

benchmarks/float8/training/README.md

+1

@@ -14,5 +14,6 @@ Training parameters can be configured via environment variables.
 - `FLOAT8_RECIPE_WITH_BEST_SETTINGS`: "rowwise" or "tensorwise". Applies float8 training with the specified scaling recipe, as well as additional training configs which are optimal for that scaling recipe. See `float8_training_benchmark.sh` for more details.
 - `BATCH_SIZE`: Defaults to 1.
 - `STEPS`: Defaults to 100.
+- `EXTRA_ARGS`: Extra arguments to pass to torchtitan training script. See [torchtitan](https://github.com/pytorch/torchtitan) docs for the full list of options.
 
 **NOTE**: `torch.compile` and FSDP2 are always used. Other forms of parallelism supported in torchtitan are not yet supported in this script.
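As a hedged sketch of how these variables could be supplied when launching the benchmark script from Python (the torchtitan checkout path and the example flag passed via EXTRA_ARGS are placeholders, not values from this commit):

# Hypothetical driver for float8_training_benchmark.sh; paths and the
# EXTRA_ARGS flag below are illustrative assumptions.
import os
import subprocess

env = dict(
    os.environ,
    TORCHTITAN_ROOT=os.path.expanduser("~/torchtitan"),
    FLOAT8_RECIPE_WITH_BEST_SETTINGS="rowwise",
    BATCH_SIZE="2",
    STEPS="50",
    EXTRA_ARGS="--metrics.log_freq=10",  # forwarded verbatim to run_train.sh
)
subprocess.run(
    ["bash", "benchmarks/float8/training/float8_training_benchmark.sh"],
    env=env,
    check=True,
)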

benchmarks/float8/training/float8_training_benchmark.sh

+2 -1

@@ -22,6 +22,7 @@ if [ -z "${TORCHTITAN_ROOT}" ]; then
     echo " * FLOAT8_RECIPE_WITH_BEST_SETTINGS: "rowwise" or "tensorwise". if set, use float8 training in torchtitan with the specified recipe, including the additional settings which are optimal for that recipe. otherwise, use bf16 mixed precision training."
     echo " * BATCH_SIZE: defaults to 1."
     echo " * STEPS: defaults to 100."
+    echo " * EXTRA_ARGS: additional arguments to pass to the torchtitan training script."
     exit 1
 fi
 

@@ -44,7 +45,7 @@ cd ${TORCHTITAN_ROOT}
 echo "float8 args: ${FLOAT8_ARGS}"
 
 # run the command with the specified arguments
-CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ${TORCHTITAN_ROOT}/run_train.sh --training.steps=${STEPS} --training.batch_size=${BATCH_SIZE} --training.compile ${FLOAT8_ARGS} 2>&1 | tee ${LOG_FILE}
+CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ${TORCHTITAN_ROOT}/run_train.sh --training.steps=${STEPS} --training.batch_size=${BATCH_SIZE} --training.compile ${FLOAT8_ARGS} ${EXTRA_ARGS} 2>&1 | tee ${LOG_FILE}
 
 # return to original working directory
 cd $original_dir

benchmarks/mx_formats/cast_bench.py

+199 (new file)

@@ -0,0 +1,199 @@
from typing import Callable, Tuple

import fire
import torch
import triton
from torch._inductor.utils import do_bench_using_profiling

from torchao.prototype.mx_formats.custom_cast import (
    triton_to_mxfp8_dim1,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx

torch.manual_seed(0)

bytes_per_el_bf16 = 2
bytes_per_el_fp8 = 1


def scale_dim0_reference(x_hp, block_size) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x_hp.is_contiguous()
    x_hp_d0_block = x_hp.reshape(-1, block_size)
    x_hp_d0_block_abs = x_hp_d0_block.abs()
    amax_dim0 = torch.amax(x_hp_d0_block_abs, dim=1).unsqueeze(1)
    x_hp_d0_block_normalized = x_hp_d0_block / amax_dim0
    x_hp_d0_normalized = x_hp_d0_block_normalized.reshape(x_hp.shape)
    return x_hp_d0_normalized, amax_dim0


def scale_dim1_reference(x_hp, block_size) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x_hp.is_contiguous()
    x_hp_d1 = x_hp.t().contiguous()
    x_hp_d1_block = x_hp_d1.reshape(-1, block_size)
    x_hp_d1_block_abs = x_hp_d1_block.abs()
    amax_dim1 = torch.amax(x_hp_d1_block_abs, dim=1).unsqueeze(1)
    x_hp_d1_block_normalized = x_hp_d1_block / amax_dim1
    x_hp_d1_normalized = x_hp_d1_block_normalized.reshape(x_hp_d1.shape)
    return x_hp_d1_normalized, amax_dim1


def scale_dim0_dim1_reference(
    x_hp: torch.Tensor, block_size
) -> Tuple[torch.Tensor, torch.Tensor]:
    # normalize across dim0
    x_hp_d0_normalized, amax_dim0 = scale_dim0_reference(x_hp, block_size)
    # normalize across dim1
    x_hp_d1_normalized, amax_dim1 = scale_dim1_reference(x_hp, block_size)
    return x_hp_d0_normalized, x_hp_d1_normalized.t(), amax_dim0, amax_dim1


def to_mx_dim0_reference(x_hp, block_size):
    scale_d0, data_d0 = to_mx(x_hp, torch.float8_e4m3fn, block_size)
    return data_d0, scale_d0


def to_mx_dim1_reference(x_hp, block_size):
    x_hp = x_hp.t().contiguous()
    scale_d1, data_d1 = to_mx(x_hp, torch.float8_e4m3fn, block_size)
    return data_d1.t(), scale_d1


def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3


def run(
    M: int = 16384,
    K: int = 16384,
    BLOCK_SIZE: int = 32,
    mode: str = "dim0",
):
    print(f"M {M} K {K} BLOCK_SIZE {BLOCK_SIZE}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"torch version: {torch.__version__}")
    print(f"triton version: {triton.__version__}")
    print(f"mode: {mode}")
    assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx", "dim1_mx", "dim1_mx_triton")

    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000

    if mode == "dim0":
        scale_dim0_reference_c = torch.compile(scale_dim0_reference)
        y_d0, s_d0 = scale_dim0_reference_c(x, BLOCK_SIZE)

        for _ in range(2):
            __ = scale_dim0_reference_c(x, BLOCK_SIZE)
        time_us = benchmark_cuda_function_in_microseconds(
            lambda x, b: scale_dim0_reference_c(x, BLOCK_SIZE),
            x,
            BLOCK_SIZE,
        )

        assert y_d0.dtype == torch.bfloat16
        assert s_d0.dtype == torch.bfloat16
        bytes_rw = sum(t.numel() for t in [x, y_d0, s_d0]) * bytes_per_el_bf16
        bps = bytes_rw / (time_us / 1e6)

    elif mode == "dim1":
        scale_dim1_reference_c = torch.compile(scale_dim1_reference)
        y_d1, s_d1 = scale_dim1_reference_c(x, BLOCK_SIZE)

        for _ in range(2):
            __ = scale_dim1_reference_c(x, BLOCK_SIZE)
        time_us = benchmark_cuda_function_in_microseconds(
            lambda x, b: scale_dim1_reference_c(x, BLOCK_SIZE),
            x,
            BLOCK_SIZE,
        )

        assert y_d1.dtype == torch.bfloat16
        assert s_d1.dtype == torch.bfloat16
        bytes_rw = sum(t.numel() for t in [x, y_d1, s_d1]) * bytes_per_el_bf16
        bps = bytes_rw / (time_us / 1e6)

    elif mode == "dim0_dim1":
        scale_dim0_dim1_reference_c = torch.compile(scale_dim0_dim1_reference)
        y_d0, y_d1, s_d0, s_d1 = scale_dim0_dim1_reference_c(x, BLOCK_SIZE)

        for _ in range(2):
            __ = scale_dim0_dim1_reference_c(x, BLOCK_SIZE)
        time_us = benchmark_cuda_function_in_microseconds(
            lambda x, b: scale_dim0_dim1_reference_c(x, BLOCK_SIZE),
            x,
            BLOCK_SIZE,
        )

        assert y_d0.dtype == torch.bfloat16
        assert s_d0.dtype == torch.bfloat16
        assert y_d1.dtype == torch.bfloat16
        assert s_d1.dtype == torch.bfloat16
        bytes_rw = (
            sum(t.numel() for t in [x, y_d0, y_d1, s_d0, s_d1]) * bytes_per_el_bf16
        )
        bps = bytes_rw / (time_us / 1e6)

    elif mode == "dim0_mx":
        to_mx_dim0_reference_c = torch.compile(to_mx_dim0_reference)
        y_d0, s_d0 = to_mx_dim0_reference_c(x, BLOCK_SIZE)

        for _ in range(2):
            __ = to_mx_dim0_reference_c(x, BLOCK_SIZE)
        time_us = benchmark_cuda_function_in_microseconds(
            lambda x, b: to_mx_dim0_reference_c(x, BLOCK_SIZE),
            x,
            BLOCK_SIZE,
        )

        assert y_d0.dtype == torch.float8_e4m3fn
        assert s_d0.dtype == torch.uint8
        bytes_r = x.numel() * bytes_per_el_bf16
        bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
        bps = (bytes_r + bytes_w) / (time_us / 1e6)

    elif mode == "dim1_mx":
        to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
        y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)

        for _ in range(2):
            __ = to_mx_dim1_reference_c(x, BLOCK_SIZE)
        time_us = benchmark_cuda_function_in_microseconds(
            lambda x, b: to_mx_dim1_reference_c(x, BLOCK_SIZE),
            x,
            BLOCK_SIZE,
        )

        assert y_d1.dtype == torch.float8_e4m3fn
        assert s_d1.dtype == torch.uint8
        bytes_r = x.numel() * bytes_per_el_bf16
        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
        bps = (bytes_r + bytes_w) / (time_us / 1e6)

    elif mode == "dim1_mx_triton":
        y_d1, s_d1 = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)

        for _ in range(2):
            __ = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
        time_us = benchmark_cuda_function_in_microseconds(
            lambda x, b: triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE),
            x,
            BLOCK_SIZE,
        )

        assert y_d1.dtype == torch.float8_e4m3fn
        assert s_d1.dtype == torch.float8_e8m0fnu
        bytes_r = x.numel() * bytes_per_el_bf16
        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
        bps = (bytes_r + bytes_w) / (time_us / 1e6)

    else:
        raise AssertionError(f"unknown mode {mode}")

    print("time_us", time_us)
    print("mem_bw_gbps", bps / 1e9)


if __name__ == "__main__":
    fire.Fire(run)
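Because the entry point is wrapped with fire.Fire, run() is exposed as a command line interface and can also be called directly; a hedged usage sketch, assuming a CUDA GPU and that the benchmarks/mx_formats directory is importable (the import path and sizes are illustrative):

# Roughly equivalent to:
#   python benchmarks/mx_formats/cast_bench.py --M 8192 --K 8192 --mode dim1_mx_triton
from cast_bench import run  # assumes benchmarks/mx_formats is on sys.path

run(M=8192, K=8192, BLOCK_SIZE=32, mode="dim1_mx_triton")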

examples/sam2_amg_server/compile_export_utils.py

+3 -2

@@ -119,9 +119,10 @@ def aot_compile(
         "triton.cudagraphs": True,
     }
 
-    from torch.export import export_for_inference
+    from torch.export import export_for_training
 
-    exported = export_for_inference(fn, sample_args, sample_kwargs)
+    exported = export_for_training(fn, sample_args, sample_kwargs, strict=True)
+    exported.run_decompositions()
     output_path = torch._inductor.aoti_compile_and_package(
         exported,
         package_path=str(path),

examples/sam2_vos_example/compile_export_utils.py

+3 -2

@@ -82,9 +82,10 @@ def aot_compile(
         "triton.cudagraphs": True,
     }
 
-    from torch.export import export_for_inference
+    from torch.export import export_for_training
 
-    exported = export_for_inference(fn, sample_args, sample_kwargs)
+    exported = export_for_training(fn, sample_args, sample_kwargs, strict=True)
+    exported.run_decompositions()
     output_path = torch._inductor.aoti_compile_and_package(
         exported,
         package_path=str(path),
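Both changes swap torch.export.export_for_inference for export_for_training followed by decomposition before AOT-Inductor packaging. A minimal standalone sketch of that flow, assuming a recent PyTorch nightly; the toy module and output path are illustrative, not part of the commit:

# Sketch of export_for_training -> run_decompositions -> AOTI packaging.
import torch
from torch.export import export_for_training


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x) * 2.0


model = TinyModel().eval()
sample_args = (torch.randn(4, 8),)

exported = export_for_training(model, sample_args, strict=True)
# run_decompositions() returns a new ExportedProgram lowered to core ATen ops.
exported = exported.run_decompositions()

# Compile and write a .pt2 package to disk (path is a placeholder).
package_path = torch._inductor.aoti_compile_and_package(
    exported,
    package_path="/tmp/tiny_model.pt2",
)
print(package_path)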

setup.py

-12

@@ -299,18 +299,6 @@ def get_extensions():
         extra_compile_args["nvcc"].append("-g")
         extra_link_args.append("/DEBUG")
 
-    curdir = os.path.dirname(os.path.curdir)
-    extensions_dir = os.path.join(curdir, "torchao", "csrc")
-    sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))
-
-    extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
-    cuda_sources = list(
-        glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
-    )
-
-    if use_cuda:
-        sources += cuda_sources
-
     # Get base directory and source paths
     curdir = os.path.dirname(os.path.curdir)
     extensions_dir = os.path.join(curdir, "torchao", "csrc")
