
Commit eba50ab

adamomainz authored and facebook-github-bot committed

production data used in fp32_gemm, bf16_gemm, and softmax

Summary: Adds support for more production data usage in tritonBench. The last operator left for this first cut of metric changes is HSTU; weights are working here as well.

Reviewed By: xuzhao9

Differential Revision: D65779069

fbshipit-source-id: a81237e39e407c47b3304e0bc9c4a00aebefb73c

1 parent 3c83e0b commit eba50ab

File tree

4 files changed: +36 -6 lines changed
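Taken together, the four diffs below apply one pattern to each operator's __init__: gate on production data availability, then fall back to the operator's existing synthetic shapes. A minimal sketch of that pattern, with a hypothetical select_shapes helper and FALLBACK_SHAPES list standing in for each operator's specifics:

# Hypothetical distillation of the pattern this commit applies: prefer
# frozen production shapes when running inside fbcode with the
# production_shapes flag set, otherwise keep the synthetic defaults.
FALLBACK_SHAPES = [(256, 256, 256), (1024, 1024, 1024)]  # stand-in defaults

def select_shapes(tb_args, is_fbcode, name, op_type):
    if is_fbcode and tb_args.production_shapes:
        # Internal-only path; backed by fb.durin_data in this commit.
        from tritonbench.utils.data_utils import get_production_shapes
        return get_production_shapes(name, op_type)
    return FALLBACK_SHAPES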

tritonbench/operators/fp8_gemm_rowwise/operator.py (+4 -1)

@@ -6,6 +6,7 @@
 
 import torch
 import triton
+from tritonbench.utils.data_utils import get_production_shapes
 
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
@@ -113,7 +114,9 @@ def __init__(
         super().__init__(tb_args, extra_args)
         self.use_cuda_graphs = True
         addmm_args = parse_args(self.extra_args)
-        if addmm_args.m and addmm_args.n and addmm_args.k:
+        if tb_args.production_shapes:
+            self.shapes = get_production_shapes(self.name, "fp8_gemm")
+        elif addmm_args.m and addmm_args.n and addmm_args.k:
             self.shapes = [(addmm_args.m, addmm_args.n, addmm_args.k)]
         elif addmm_args.llama:
             self.shapes = gemm_shapes()
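Note the ordering of the rewritten branches: production shapes now take precedence even when explicit m/n/k arguments are passed. A hedged, self-contained illustration of that precedence, where SimpleNamespace objects and the stub loader are stand-ins for the real parsed args and the durin loader:

from types import SimpleNamespace

def resolve_shapes(tb_args, addmm_args, name, production_loader):
    # New behavior: production shapes win even when m/n/k are all set.
    if tb_args.production_shapes:
        return production_loader(name, "fp8_gemm")
    if addmm_args.m and addmm_args.n and addmm_args.k:
        return [(addmm_args.m, addmm_args.n, addmm_args.k)]
    return []  # remaining fallbacks (llama shapes, defaults) elided

tb_args = SimpleNamespace(production_shapes=True)
addmm_args = SimpleNamespace(m=1024, n=1024, k=1024)
print(resolve_shapes(tb_args, addmm_args, "fp8_gemm_rowwise",
                     lambda name, op_type: [(4096, 4096, 4096)]))
# -> [(4096, 4096, 4096)]: production data wins despite explicit m/n/k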

tritonbench/operators/gemm/operator.py (+11 -2)

@@ -8,6 +8,7 @@
 import torch
 import torch._inductor.config as inductor_config
 import triton
+from tritonbench.utils.data_utils import get_production_shapes
 
 from tritonbench.utils.path_utils import REPO_PATH
 
@@ -129,7 +130,9 @@ def __init__(
         self.use_cuda_graphs = False
         gemm_args = parse_args(self.extra_args)
         self.layout = gemm_args.layout
-        if gemm_args.input:
+        if IS_FBCODE and tb_args.production_shapes:
+            self.shapes = get_production_shapes(self.name, f"{tb_args.precision}_gemm")
+        elif gemm_args.input:
             self.shapes = read_shapes_from_csv(gemm_args.input)
         elif gemm_args.splitk:
             self.shapes = SPLIT_K_SHAPES
@@ -286,7 +289,13 @@ def _scaled_randn(*args, scale: float, **kwargs) -> torch.Tensor:
 
     def get_input_iter(self) -> Generator:
         for shape in self.shapes:
-            m, n, k, bias = shape
+            if len(shape) == 4:
+                m, n, k, bias = shape
+            elif len(shape) == 3:
+                m, n, k = shape
+                bias = None
+            else:
+                raise ValueError(f"Invalid shape {shape}")
             a = self._scaled_randn(
                 (m, k), scale=k, device=self.device, dtype=self.dtype
             )
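get_input_iter now accepts both the 4-tuple synthetic shapes (with a bias entry) and the bare (m, n, k) 3-tuples that the production loader yields. A standalone sketch of the unpacking, runnable outside the operator class:

# Standalone sketch of the new unpacking: synthetic gemm shapes carry a
# fourth bias entry, while production shapes are bare (m, n, k) tuples.
def unpack_shape(shape):
    if len(shape) == 4:
        m, n, k, bias = shape
    elif len(shape) == 3:
        m, n, k = shape
        bias = None
    else:
        raise ValueError(f"Invalid shape {shape}")
    return m, n, k, bias

print(unpack_shape((8192, 8192, 8192)))  # (8192, 8192, 8192, None)
print(unpack_shape((512, 512, 512, 1)))  # bias entry passed through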

tritonbench/operators/softmax/operator.py (+7 -3)

@@ -3,10 +3,12 @@
 import torch
 import triton
 import triton.language as tl
+from tritonbench.utils.data_utils import get_production_shapes
 
 from tritonbench.utils.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,
+    IS_FBCODE,
     register_benchmark,
     register_metric,
 )
@@ -101,13 +103,15 @@ def _inner():
 
     def get_input_iter(self):
         M = 4096
-        for i in range(2, 100):
-            N = 128 * i
+        shapes = ((M, 128 * i) for i in range(2, 100))
+        if IS_FBCODE and self.tb_args.production_shapes:
+            shapes = get_production_shapes(self.name, "softmax")
+        for M, N in shapes:
             yield (torch.randn([M, N], dtype=self.dtype, device=self.device),)
 
     def get_x_val(self, example_inputs) -> int:
         shape = example_inputs[0].size()
-        return shape[1]
+        return [shape[0], shape[1]]
 
     @register_metric()
     def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics) -> float:
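The synthetic (M, N) sweep is rewritten as a generator expression so a production shape list can be swapped in wholesale; both sources yield 2-tuples consumed by the same loop. A hedged standalone sketch of that swap, where input_shapes is a hypothetical stand-in for the method:

# Hypothetical stand-in for the rewritten get_input_iter: the synthetic
# sweep is a generator of (M, N) pairs, so a production shape list with
# the same 2-tuple form can replace it without touching the loop body.
def input_shapes(use_production, production_shapes=None):
    M = 4096
    shapes = ((M, 128 * i) for i in range(2, 100))
    if use_production and production_shapes:
        shapes = production_shapes
    for M, N in shapes:
        yield (M, N)

print(next(input_shapes(False)))                # (4096, 256)
print(next(input_shapes(True, [(128, 1024)])))  # (128, 1024)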

tritonbench/utils/data_utils.py (+14, new file)

@@ -0,0 +1,14 @@
+from .triton_op import IS_FBCODE
+
+
+def get_production_shapes(op_name, op_type):
+    """Gets a list of production shapes for benchmarking."""
+    if IS_FBCODE:
+        from .fb.durin_data import productionDataLoader
+
+        return [
+            shape
+            for shape in productionDataLoader.get_shapes_from_frozen_durin(
+                op_name, op_type
+            )
+        ]
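Note that get_production_shapes has no else branch, so it implicitly returns None outside fbcode; callers are expected to gate on IS_FBCODE (or tb_args.production_shapes) first, as the operator __init__ changes above do. A hedged usage sketch, assuming "softmax" is the registered operator name:

# Hedged usage sketch: outside fbcode the function falls through and
# returns None, so gate before relying on the result.
from tritonbench.utils.triton_op import IS_FBCODE
from tritonbench.utils.data_utils import get_production_shapes

if IS_FBCODE:
    shapes = get_production_shapes("softmax", "softmax")
else:
    shapes = [(4096, 128 * i) for i in range(2, 100)]  # synthetic fallback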
