
Commit 0a91ed0

omarpavelmeta authored and facebook-github-bot committed
Add inductor_flex_attention_bwd operator
Summary: Add TritonBench operator to benchmark the flex attention backward pass inductor kernel (triton_tem_fused_flex_attention_backward_zeros_1). Uses FWD_ONLY=True but manually times backward via output.backward(dy, retain_graph=True). Compares aten (eager) vs inductor (torch.compile). Backward FLOP count uses 2.5x multiplier (2.0 bwd + 0.5 recompute). Default config: B=8, H=16, D=128, bf16, requires_grad=True on q/k/v. Reviewed By: stashuk-olek Differential Revision: D95461827
1 parent 15ea2e3 commit 0a91ed0
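A hedged usage sketch: assuming TritonBench's usual run.py entry point (the flag spelling below is illustrative, not confirmed by this commit), the operator would be driven like

    python run.py --op inductor_flex_attention_bwd --metrics latency,speedup,tflops
    python run.py --op inductor_flex_attention_bwd --seq-len 4096

where --seq-len, --batch, --n-heads, and --d-head are the operator-specific flags parsed in parse_op_args below, overriding the B=8, H=16, D=128 defaults.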

2 files changed

Lines changed: 145 additions & 0 deletions

__init__.py — Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
from .operator import Operator
operator.py — Lines changed: 144 additions & 0 deletions

@@ -0,0 +1,144 @@
import argparse
from typing import Callable, Generator, List, Optional, Tuple

import torch
from torch.nn.attention.flex_attention import (
    BlockMask,
    create_block_mask,
    flex_attention,
)

from tritonbench.operators.flex_attention.mods import causal_mask
from tritonbench.utils.triton_op import (
    BenchmarkOperator,
    BenchmarkOperatorMetrics,
    register_benchmark,
    register_metric,
    register_x_val,
)


torch._dynamo.config.automatic_dynamic_shapes = False


def parse_op_args(args: List[str]):
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch", type=int, default=8, help="Batch size")
    parser.add_argument("--n-heads", type=int, default=16, help="Number of heads")
    parser.add_argument("--d-head", type=int, default=128, help="Head dimension")
    parser.add_argument(
        "--seq-len", type=int, default=None, help="Fixed sequence length"
    )
    return parser.parse_args(args)


class Operator(BenchmarkOperator):
    DEFAULT_METRICS = ["latency", "speedup", "tflops"]
    DEFAULT_PRECISION = "bf16"
    FWD_ONLY = True  # We handle backward timing internally
    is_compute_bound = True

    def __init__(
        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
    ):
        super().__init__(tb_args, extra_args)
        args = parse_op_args(self.extra_args)
        self.batch_size = args.batch
        self.num_heads = args.n_heads
        self.head_dim = args.d_head
        self.seq_len = args.seq_len

    @register_x_val(label="(B, H, S, D)")
    def get_x_val(self, example_inputs) -> str:
        q, k, v, block_mask = example_inputs
        B, H, S, D = q.shape
        return f"({B}, {H}, {S}, {D})"

    @register_benchmark(baseline=True)
    def aten(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        block_mask: Optional[BlockMask],
    ) -> Callable:
        output = flex_attention(q, k, v, block_mask=block_mask)
        dy = torch.randn_like(output)
        # Warmup backward
        output.backward(dy, retain_graph=True)

        def bwd_fn():
            for t in [q, k, v]:
                t.grad = None
            output.backward(dy, retain_graph=True)

        return bwd_fn

    @register_benchmark()
    def inductor(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        block_mask: Optional[BlockMask],
    ) -> Callable:
        compiled_fn = torch.compile(flex_attention, fullgraph=True)
        output = compiled_fn(q, k, v, block_mask=block_mask)
        dy = torch.randn_like(output)
        # Warmup backward
        output.backward(dy, retain_graph=True)

        def bwd_fn():
            for t in [q, k, v]:
                t.grad = None
            output.backward(dy, retain_graph=True)

        return bwd_fn

    @register_metric()
    def tflops(
        self, fn_name: str, example_inputs: Tuple, metrics: BenchmarkOperatorMetrics
    ):
        q, k, v, block_mask = example_inputs
        B, H, S, D = q.shape

        # Backward is ~2.5x forward FLOPs (2.0 bwd + 0.5 recompute)
        # Forward: 2 * B * H * S^2 * D * 2 (QK + OV matmuls)
        flops = 2.5 * 2.0 * B * H * S * S * D * 2

        # Adjust for block sparsity
        if block_mask is not None:
            sparsity = block_mask.sparsity() / 100.0
            flops *= 1 - sparsity

        tflops = flops / metrics.latency / 1e12
        return (
            tflops,
            flops / metrics.latency.max / 1e12,
            flops / metrics.latency.min / 1e12,
        )

    def get_input_iter(self) -> Generator:
        B = self.batch_size
        H = self.num_heads
        D = self.head_dim

        if self.seq_len:
            seq_lens = [self.seq_len]
        else:
            seq_lens = [2**i for i in range(7, 15)]  # 128 to 16384

        compiled_block_mask = torch.compile(create_block_mask)

        for S in seq_lens:
            q = torch.randn(
                B, H, S, D, device=self.device, dtype=self.dtype, requires_grad=True
            )
            k = torch.randn(
                B, H, S, D, device=self.device, dtype=self.dtype, requires_grad=True
            )
            v = torch.randn(
                B, H, S, D, device=self.device, dtype=self.dtype, requires_grad=True
            )
            block_mask = compiled_block_mask(causal_mask, 1, 1, S, S, device=self.device)
            yield q, k, v, block_mask
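As a sanity check on the tflops() accounting above, here is a minimal sketch (not part of the commit) that evaluates the same formula for the default B=8, H=16, D=128 shape at S=4096; the 50% causal sparsity is an assumed approximation standing in for the real block_mask.sparsity() value:

    # Hypothetical back-of-envelope check of the backward FLOP formula.
    B, H, S, D = 8, 16, 4096, 128           # default config at one sequence length
    fwd_flops = 2 * B * H * S * S * D * 2   # QK^T and PV matmuls, 2 FLOPs per MAC
    bwd_flops = 2.5 * fwd_flops             # 2.0x backward + 0.5x recompute
    effective = bwd_flops * (1 - 0.5)       # causal mask drops roughly half the blocks
    print(f"{effective / 1e12:.2f} TFLOPs per backward call")  # ~1.37

Dividing this by the measured backward latency (in seconds) reproduces the TFLOPS metric the operator reports.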
