Skip to content

Commit cdcb2bd

Browse files
jananisriram authored and facebook-github-bot committed
Add Diode Inductor max-autotune benchmarks to gemm, addmm, bmm
Summary: Add Inductor max-autotune benchmarking using Diode to TritonBench for the `gemm`, `addmm`, and `bmm` operators. This diff reduces friction in verifying Diode results against the base max-autotune results. NOTE: The TritonBench benchmark currently includes the overhead of setting up and tearing down Diode within the benchmark, i.e. metrics like tflops and latency. Actual Diode benchmarking numbers should be slightly better than what TritonBench reports. Differential Revision: D89398916
1 parent af04af0 commit cdcb2bd

3 files changed

Lines changed: 93 additions & 0 deletions

File tree

tritonbench/operators/addmm/operator.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import argparse
22
import itertools
3+
import logging
34
from typing import Any, Callable, Generator, List, Optional, Tuple
45

56
import torch
@@ -16,6 +17,7 @@
1617
with try_import("HAS_STREAMK"):
1718
from tritonbench.operators.gemm.stream_k import streamk_cuda_matmul
1819

20+
from tritonbench.utils.diode_utils import setup_diode_model, teardown_diode_model
1921
from tritonbench.utils.triton_op import (
2022
BenchmarkOperator,
2123
BenchmarkOperatorMetrics,
@@ -85,6 +87,9 @@
8587

8688
BATCH_SCALING_SHAPES = [(1 << i, 512, 512, False) for i in range(6, 21)]
8789

90+
logger = logging.getLogger(__name__)
91+
logger.setLevel(logging.INFO)
92+
8893

8994
class Operator(BenchmarkOperator):
9095
DEFAULT_METRICS = ["tflops", "best_config"]
@@ -149,6 +154,24 @@ def pt2_addmm_maxautotune(self, a, mat1, mat2) -> Callable:
149154
compiled(a, mat1, mat2)
150155
return lambda: compiled(a, mat1, mat2)
151156

157+
@register_benchmark(enabled=False)
def pt2_addmm_maxautotune_diode(self, a, mat1, mat2) -> Callable:
    """Benchmark PT2 ``torch.addmm`` under max-autotune with Diode config pruning.

    Installs the Diode model as Inductor's autotune-choice handler, compiles
    and warms up ``addmm`` under ``max_autotune``, then tears Diode down so
    later benchmarks in the same process see the default handler again.

    NOTE: Diode setup/teardown happens inside this function, so its overhead
    is included in the reported metrics (tflops, latency); standalone Diode
    numbers should be slightly better.

    Args:
        a: bias input to ``torch.addmm``.
        mat1: left matrix operand.
        mat2: right matrix operand.

    Returns:
        A zero-argument callable invoking the already-compiled kernel.
    """
    torch._dynamo.reset()
    logger.info("[DIODE][TritonBench] Run PT2 addmm Max-Autotune Diode benchmark")
    old_diode_configs = setup_diode_model()
    try:
        with inductor_config.patch(
            max_autotune=True,
            max_autotune_gemm_backends="ATEN,TRITON",
            autotune_num_choices_displayed=None,
        ):

            def f(a, mat1, mat2):
                return torch.addmm(a, mat1, mat2)

            compiled = torch.compile(f, dynamic=False)
            # Warmup call triggers compilation (and Diode-pruned autotuning)
            # here, so the returned lambda measures only the compiled kernel.
            compiled(a, mat1, mat2)
    finally:
        # Restore Diode config and the default Inductor handler even if
        # compilation fails, so subsequent benchmarks are unaffected.
        teardown_diode_model(old_diode_configs)
    return lambda: compiled(a, mat1, mat2)
174+
152175
@register_metric()
153176
def gbps(
154177
self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics

tritonbench/operators/gemm/operator.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import contextlib
33
import csv
44
import itertools
5+
import logging
56
import os
67
from typing import Any, Callable, Generator, List, Optional, Tuple
78

@@ -19,6 +20,7 @@
1920
blackwell_matmul_tma,
2021
blackwell_matmul_tma_persistent,
2122
)
23+
from tritonbench.utils.diode_utils import setup_diode_model, teardown_diode_model
2224
from tritonbench.utils.triton_utils import has_tlx
2325

2426
if has_tlx():
@@ -136,6 +138,9 @@ def _tlx_matmul(*args, **kwargs):
136138

137139
PERSISTENT_TUTORIAL_SHAPES = [(8192, 8192, 1 << k, None) for k in range(9, 15)]
138140

141+
logger = logging.getLogger(__name__)
142+
logger.setLevel(logging.INFO)
143+
139144

140145
@contextlib.contextmanager
141146
def set_env_variable(key, value):
@@ -365,6 +370,27 @@ def pt2_matmul_maxautotune(self, a, b, bias) -> Callable:
365370

366371
return lambda: compiled(a, b)
367372

373+
@register_benchmark(enabled=False)
def pt2_matmul_maxautotune_diode(self, a, b, bias) -> Callable:
    """Benchmark PT2 matmul under max-autotune with Diode config pruning.

    Installs the Diode model as Inductor's autotune-choice handler, compiles
    and warms up the (optionally biased) matmul under ``max_autotune``, then
    tears Diode down so later benchmarks see the default handler again.

    NOTE: Diode setup/teardown happens inside this function, so its overhead
    is included in the reported metrics (tflops, latency); standalone Diode
    numbers should be slightly better.

    Args:
        a: left matrix operand.
        b: right matrix operand.
        bias: optional bias added to the product; skipped when ``None``.

    Returns:
        A zero-argument callable invoking the already-compiled kernel.
    """
    torch._dynamo.reset()
    logger.info("[DIODE][TritonBench] Run PT2 gemm Max-Autotune Diode benchmark")
    old_diode_configs = setup_diode_model()
    try:
        with inductor_config.patch(
            max_autotune=True,
            max_autotune_gemm_backends="ATEN,TRITON",
            autotune_num_choices_displayed=self.inductor_autotune_num_choices_displayed,
        ):
            if bias is not None:

                def f(a, b):
                    return a.matmul(b) + bias

            else:

                def f(a, b):
                    return a.matmul(b)

            compiled = torch.compile(f, dynamic=False)
            # Warmup call triggers compilation (and Diode-pruned autotuning)
            # here, so the returned lambda measures only the compiled kernel.
            compiled(a, b)
    finally:
        # Restore Diode config and the default Inductor handler even if
        # compilation fails, so subsequent benchmarks are unaffected.
        teardown_diode_model(old_diode_configs)
    return lambda: compiled(a, b)
393+
368394
@register_benchmark(enabled=not is_cuda())
369395
def streamk_matmul(self, a, b, bias) -> Callable:
370396
return lambda: (

tritonbench/utils/diode_utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""Diode (ML model for pruning autotuning configs) utils for TritonBench operators."""
2+
3+
import diode.torch_diode.config as diode_config
4+
from diode.torch_diode.choices import DiodeInductorChoices
5+
from diode.torch_diode.models.triton_gemm.model import (
6+
GEMMModelV2,
7+
MODEL_CONFIGS,
8+
)
9+
from diode.torch_diode.registry import register, get_registry
10+
import logging
11+
from torch._inductor.virtualized import V
12+
from torch._inductor.choices import InductorChoices
13+
14+
DIODE_MODEL_CONFIGS_VERSION = "v3_12_04_2025"
15+
16+
logger = logging.getLogger(__name__)
17+
logger.setLevel(logging.INFO)
18+
19+
def setup_diode_model(
    topk: int = 1, expand_search_space: bool = True
) -> tuple[int, bool]:
    """Install the Diode GEMM model as Inductor's autotune-choice handler.

    Overwrites the module-level Diode config with *topk* /
    *expand_search_space*, registers the pinned GEMM model version, and
    routes Inductor's choice selection through ``DiodeInductorChoices``.

    Args:
        topk: number of autotune configs Diode keeps.
        expand_search_space: whether Diode expands the config search space.

    Returns:
        The previous ``(topk, expand_search_space)`` pair, suitable for
        passing back to ``teardown_diode_model``.
    """
    logger.info("[DIODE][TritonBench] Setup Diode model.")

    previous = (diode_config.topk, diode_config.expand_search_space)
    diode_config.topk = topk
    diode_config.expand_search_space = expand_search_space

    # Register the pinned model version with the Diode registry.
    model: GEMMModelV2 = GEMMModelV2(
        model_config=MODEL_CONFIGS[DIODE_MODEL_CONFIGS_VERSION]
    )
    register(model)

    # From here on, Inductor autotuning choices go through Diode.
    V.set_choices_handler(DiodeInductorChoices())

    return previous
36+
37+
def teardown_diode_model(old_configs: tuple[int, bool]) -> None:
    """Undo ``setup_diode_model``: restore config and the default handler.

    Args:
        old_configs: the ``(topk, expand_search_space)`` pair previously
            returned by ``setup_diode_model``.
    """
    logger.info("[DIODE][TritonBench] Teardown Diode model.")

    diode_config.topk, diode_config.expand_search_space = old_configs
    get_registry().clear()
    # NOTE(review): installs a fresh default InductorChoices rather than
    # restoring whatever handler was active before setup — confirm no other
    # custom choices handler is expected to survive a setup/teardown cycle.
    V.set_choices_handler(InductorChoices())

0 commit comments

Comments
 (0)