Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions tritonbench/operators/addmm/operator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import itertools
import logging
from typing import Any, Callable, Generator, List, Optional, Tuple

import torch
Expand All @@ -16,6 +17,7 @@
with try_import("HAS_STREAMK"):
from tritonbench.operators.gemm.stream_k import streamk_cuda_matmul

from tritonbench.utils.diode_utils import setup_diode_model, teardown_diode_model
from tritonbench.utils.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
Expand Down Expand Up @@ -85,6 +87,9 @@

BATCH_SCALING_SHAPES = [(1 << i, 512, 512, False) for i in range(6, 21)]

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["tflops", "best_config"]
Expand Down Expand Up @@ -149,6 +154,24 @@ def pt2_addmm_maxautotune(self, a, mat1, mat2) -> Callable:
compiled(a, mat1, mat2)
return lambda: compiled(a, mat1, mat2)

@register_benchmark(enabled=False)
def pt2_addmm_maxautotune_diode(self, a, mat1, mat2) -> Callable:
    """Benchmark ``torch.addmm`` compiled with Inductor max-autotune while
    the Diode config-pruning model is installed.

    Compilation and autotuning happen eagerly here (warm-up call) so that
    the returned zero-arg callable times only the compiled kernel.

    Args:
        a, mat1, mat2: tensors forwarded to ``torch.addmm(a, mat1, mat2)``.

    Returns:
        A zero-argument callable invoking the pre-compiled addmm.
    """
    torch._dynamo.reset()
    logger.info("[DIODE][TritonBench] Run PT2 addmm Max-Autotune Diode benchmark")
    old_diode_configs = setup_diode_model()
    try:
        with inductor_config.patch(
            max_autotune=True,
            max_autotune_gemm_backends="ATEN,TRITON",
            autotune_num_choices_displayed=None,
        ):
            f = lambda a, mat1, mat2: torch.addmm(a, mat1, mat2)
            compiled = torch.compile(f, dynamic=False)
            # Warm-up: trigger compilation + autotuning outside the timed region.
            compiled(a, mat1, mat2)
    finally:
        # Always restore the previous Diode/Inductor state, even if
        # compilation fails — otherwise the Diode choices handler would
        # leak into every subsequent benchmark in this process.
        teardown_diode_model(old_diode_configs)
    return lambda: compiled(a, mat1, mat2)

@register_metric()
def gbps(
self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
Expand Down
26 changes: 26 additions & 0 deletions tritonbench/operators/gemm/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import contextlib
import csv
import itertools
import logging
import os
from typing import Any, Callable, Generator, List, Optional, Tuple

Expand All @@ -19,6 +20,7 @@
blackwell_matmul_tma,
blackwell_matmul_tma_persistent,
)
from tritonbench.utils.diode_utils import setup_diode_model, teardown_diode_model
from tritonbench.utils.triton_utils import has_tlx

if has_tlx():
Expand Down Expand Up @@ -136,6 +138,9 @@ def _tlx_matmul(*args, **kwargs):

PERSISTENT_TUTORIAL_SHAPES = [(8192, 8192, 1 << k, None) for k in range(9, 15)]

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@contextlib.contextmanager
def set_env_variable(key, value):
Expand Down Expand Up @@ -365,6 +370,27 @@ def pt2_matmul_maxautotune(self, a, b, bias) -> Callable:

return lambda: compiled(a, b)

@register_benchmark(enabled=False)
def pt2_matmul_maxautotune_diode(self, a, b, bias) -> Callable:
    """Benchmark ``a.matmul(b)`` (optionally ``+ bias``) compiled with
    Inductor max-autotune while the Diode config-pruning model is installed.

    Compilation and autotuning happen eagerly here (warm-up call) so that
    the returned zero-arg callable times only the compiled kernel.

    Args:
        a, b: matmul operands.
        bias: optional tensor added to the product, or ``None``.

    Returns:
        A zero-argument callable invoking the pre-compiled matmul.
    """
    torch._dynamo.reset()
    logger.info("[DIODE][TritonBench] Run PT2 gemm Max-Autotune Diode benchmark")
    old_diode_configs = setup_diode_model()
    try:
        with inductor_config.patch(
            max_autotune=True,
            max_autotune_gemm_backends="ATEN,TRITON",
            autotune_num_choices_displayed=self.inductor_autotune_num_choices_displayed,
        ):
            if bias is not None:
                f = lambda a, b: a.matmul(b) + bias
            else:
                f = lambda a, b: a.matmul(b)
            compiled = torch.compile(f, dynamic=False)
            # Warm-up: trigger compilation + autotuning outside the timed region.
            compiled(a, b)
    finally:
        # Always restore the previous Diode/Inductor state, even if
        # compilation fails — otherwise the Diode choices handler would
        # leak into every subsequent benchmark in this process.
        teardown_diode_model(old_diode_configs)
    return lambda: compiled(a, b)

@register_benchmark(enabled=not is_cuda())
def streamk_matmul(self, a, b, bias) -> Callable:
return lambda: (
Expand Down
44 changes: 44 additions & 0 deletions tritonbench/utils/diode_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Diode (ML model for pruning autotuning configs) utils for TritonBench operators."""

import diode.torch_diode.config as diode_config
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You might want to use the `with try_import("HAS_DIODE"):` pattern here, so the import is optional. https://github.com/meta-pytorch/tritonbench/blob/main/tritonbench/operators/gemm/operator.py#L34

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, seconding Nikhil here — diode will not always exist in an OSS environment, and an unconditional import will cause failures there.

from diode.torch_diode.choices import DiodeInductorChoices
from diode.torch_diode.models.triton_gemm.model import (
GEMMModelV2,
MODEL_CONFIGS,
)
from diode.torch_diode.registry import register, get_registry
import logging
from torch._inductor.virtualized import V
from torch._inductor.choices import InductorChoices

# Key into MODEL_CONFIGS selecting which pretrained GEMM model config to load.
DIODE_MODEL_CONFIGS_VERSION = "v3_12_04_2025"

logger = logging.getLogger(__name__)
# NOTE(review): forcing a level on a library module's logger overrides the
# application's logging configuration — consider leaving level selection
# to the embedding application.
logger.setLevel(logging.INFO)

def setup_diode_model(topk: int = 1, expand_search_space: bool = True) -> tuple[int, bool]:
    """Install the Diode GEMM model as Inductor's autotune choices handler.

    Overwrites the global Diode configuration with the given ``topk`` and
    ``expand_search_space`` values, registers the pretrained GEMM model
    selected by ``DIODE_MODEL_CONFIGS_VERSION``, and swaps in the Diode
    Inductor choices handler.

    Args:
        topk: number of autotuning configs the model keeps.
        expand_search_space: whether Diode expands the config search space.

    Returns:
        The previous ``(topk, expand_search_space)`` pair, meant to be
        handed back to ``teardown_diode_model`` for restoration.
    """
    logger.info("[DIODE][TritonBench] Setup Diode model.")

    # Save the globals we are about to clobber so the caller can restore them.
    previous = (diode_config.topk, diode_config.expand_search_space)

    diode_config.topk = topk
    diode_config.expand_search_space = expand_search_space

    model = GEMMModelV2(model_config=MODEL_CONFIGS[DIODE_MODEL_CONFIGS_VERSION])
    register(model)
    V.set_choices_handler(DiodeInductorChoices())

    return previous

def teardown_diode_model(old_configs):
    """Undo ``setup_diode_model``.

    Restores the saved Diode config values, clears the Diode model
    registry, and reinstalls the default Inductor choices handler.

    Args:
        old_configs: the ``(topk, expand_search_space)`` pair returned by
            ``setup_diode_model``.
    """
    logger.info("[DIODE][TritonBench] Teardown Diode model.")

    diode_config.topk, diode_config.expand_search_space = old_configs
    get_registry().clear()
    # NOTE: installs a fresh default handler rather than restoring whatever
    # handler was active before setup — matches the original behavior.
    V.set_choices_handler(InductorChoices())
Loading