Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions tritonbench/metadata/oss_cuda_kernels.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,6 @@ blackwell_attentions:
sdpa:
tags:
- pt2
tlx_blackwell_ws_pipelined_fwd:
tags:
- tlx
tlx_blackwell_ws_pipelined_persistent_fwd:
tags:
- tlx
Expand Down
18 changes: 0 additions & 18 deletions tritonbench/operators/blackwell_attentions/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,24 +683,6 @@ def fn(q, k, v):

return preproc_noop, fn

# Only works with triton beta, forward only.
@register_benchmark(enabled=HAS_TLX)
@multi_input_wrapper
def tlx_blackwell_ws_pipelined_fwd(self, *args) -> Tuple[Callable, Callable]:
if self.D_HEAD < 128:
raise NotImplementedError("TLX only supports d_head >= 128")

def fn(q, k, v):
return tlx_blackwell(
q,
k,
v,
self.sm_scale,
self.causal,
)

return preproc_noop, fn

# Only works with triton beta.
@register_benchmark(enabled=HAS_TLX)
@multi_input_wrapper
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .operator import Operator
117 changes: 117 additions & 0 deletions tritonbench/operators/mxfp8_blackwell_attentions/operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import math
from typing import Any, Callable, Generator, List, Optional, Tuple

import torch

from tritonbench.utils.env_utils import is_blackwell

# Optional dependency: the TLX MXFP8 attention tutorial kernel only exists in
# triton builds that ship triton.language.extra.tlx. AttributeError is caught
# too because a triton build may expose `triton.language.extra` without the
# `tlx` submodule/tutorials.
try:
    # @manual=//triton:triton
    import triton.language.extra.tlx as tlx  # type: ignore
    from triton.language.extra.tlx.tutorials.blackwell_fa_ws_pipelined_persistent_mxfp8 import (
        attention as tlx_mxfp8_attention,
        generate_attention_inputs,
    )

    HAS_TLX_MXFP8 = True
except (ImportError, AttributeError):
    # Benchmark registration below is gated on this flag, so the operator
    # loads cleanly even without a TLX-enabled triton.
    HAS_TLX_MXFP8 = False

# Kernel requires Blackwell-class hardware; evaluated once at import time.
IS_BLACKWELL = is_blackwell()

from tritonbench.utils.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
Mode as BenchmarkMode,
register_benchmark,
register_metric,
register_x_val,
)


def parse_op_args(args: List[str]):
    """Parse operator-specific command-line flags.

    Args:
        args: Raw extra-argument strings forwarded by the benchmark harness.

    Returns:
        An ``argparse.Namespace`` with ``batch``, ``seq_len``, ``n_heads``,
        ``d_head`` and ``causal`` attributes.
    """
    parser = argparse.ArgumentParser()
    # All integer-valued flags share the same registration shape.
    int_flags = [
        ("--batch", 4, "Batch size"),
        ("--seq-len", None, "Sequence length"),
        ("--n-heads", 48, "Number of heads"),
        ("--d-head", 128, "Head dimension"),
    ]
    for flag, default, help_text in int_flags:
        parser.add_argument(flag, type=int, default=default, help=help_text)
    parser.add_argument("--causal", action="store_true", help="Enable causal masking")
    return parser.parse_args(args)


class Operator(BenchmarkOperator):
    """Benchmark operator for the TLX MXFP8 persistent Blackwell attention kernel."""

    DEFAULT_PRECISION = "bf16"
    DEFAULT_METRICS = ["latency", "tflops"]

    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
        """Parse operator-specific flags and cache the attention configuration."""
        super().__init__(tb_args, extra_args)
        op_args = parse_op_args(self.extra_args)
        self.BATCH = op_args.batch
        self.SEQ_LEN = op_args.seq_len
        self.H = op_args.n_heads
        self.D_HEAD = op_args.d_head
        self.causal = op_args.causal
        # Standard attention softmax scaling: 1 / sqrt(d_head).
        self.sm_scale = 1.0 / math.sqrt(self.D_HEAD)

    def get_input_iter(self) -> Generator:
        """Yield ``(q, k, v, q_scale, k_scale, v_scale)`` tuples per sequence length.

        With ``--seq-len`` set, a single configuration is produced; otherwise
        sequence lengths sweep powers of two from 2**7 through 2**15.
        """
        if self.SEQ_LEN is None:
            seq_lens = [2**exp for exp in range(7, 16)]
        else:
            seq_lens = [self.SEQ_LEN]

        for n_ctx in seq_lens:
            tensor_shape = (self.BATCH, self.H, n_ctx, self.D_HEAD)
            q_pack, k_pack, v_pack = generate_attention_inputs(
                tensor_shape, self.device, torch.float8_e4m3fn
            )
            # Each pack is (data, scale, <unused>); only data + scale feed the kernel.
            q_data, q_scale, _ = q_pack
            k_data, k_scale, _ = k_pack
            v_data, v_scale, _ = v_pack
            yield (q_data, k_data, v_data, q_scale, k_scale, v_scale)

    @register_benchmark(enabled=HAS_TLX_MXFP8 and IS_BLACKWELL)
    def tlx_mxfp8_persistent(
        self, q, k, v, q_scale, k_scale, v_scale
    ) -> Callable:
        """Return a thunk invoking the TLX MXFP8 persistent attention kernel."""

        def _run():
            return tlx_mxfp8_attention(
                q, k, v, q_scale, k_scale, v_scale, self.sm_scale, self.causal
            )

        return _run

    @register_x_val(label="(B, H, SeqLen, D)")
    def get_x_val(self, example_inputs) -> Tuple[int, int, int, int]:
        """Report the query tensor's (batch, heads, seq_len, d_head) as the x value."""
        batch, heads, seq_len, d_head = example_inputs[0].shape
        return (batch, heads, seq_len, d_head)

    @register_metric(x_only=True)
    def flops(
        self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
    ) -> float:
        """Estimate attention FLOPs for the given inputs.

        Counts the two matmuls (QK^T and PV); causal masking halves the work,
        and backward / forward+backward modes apply the usual 2.5x / 3.5x
        multipliers.
        """
        q, k = example_inputs[0], example_inputs[1]
        batch, heads, n_ctx_q, d_head = q.shape
        n_ctx_kv = k.shape[2]

        # One matmul is 2*M*N*K FLOPs; attention performs two of them.
        total = 2 * (2.0 * batch * heads * n_ctx_q * n_ctx_kv * d_head)
        if self.causal:
            total *= 0.5

        if self.mode == BenchmarkMode.BWD:
            total *= 2.5
        elif self.mode == BenchmarkMode.FWD_BWD:
            total *= 3.5
        return total
Loading