
Commit c2ef66a

xuzhao9 authored and facebook-github-bot committed
Update HSTU and use the OSS wrapper for non-persistent kernels (#53)
Summary:
Add the following features to ragged_hstu:
1. Add tflops metric
2. Use _RaggedAttentionRelativeBiasFunction to wrap the Triton kernel
3. Add backward

Pull Request resolved: #53

Test Plan:
```
$ python run.py --op ragged_attention --metrics latency,tflops --mode bwd
x_val              hstu_triton_ragged_attention-tflops  hstu_triton_ragged_attention-latency
-----------------  ------------------------------------  --------------------------------------
(8, 4, 512, 2048)  0.306747                              2.81939
(8, 4, 512, 2048)  1.65614                               0.867936
(8, 4, 512, 2048)  2.00125                               0.84768
(8, 4, 512, 2048)  2.13756                               0.991968
(8, 4, 512, 2048)  1.96315                               0.902976
(8, 4, 512, 2048)  1.50214                               0.836192
(8, 4, 512, 2048)  1.34825                               0.859936
(8, 4, 512, 2048)  1.90546                               0.97408
(8, 4, 512, 2048)  1.72114                               0.902368
(8, 4, 512, 2048)  2.30999                               1.01107
```

Reviewed By: manman-ren

Differential Revision: D66021701

Pulled By: xuzhao9

fbshipit-source-id: b0d9f32d49e02c113e4aafa597be68c17d952283
1 parent c08a2a8 commit c2ef66a

File tree

6 files changed (+166 -51 lines)


install.py (+6)

@@ -114,6 +114,7 @@ def install_xformers():
     parser.add_argument(
         "--fa3", action="store_true", help="Install optional flash_attention 3 kernels"
     )
+    parser.add_argument("--hstu", action="store_true", help="Install HSTU.")
     parser.add_argument("--jax", action="store_true", help="Install jax nightly")
     parser.add_argument("--tk", action="store_true", help="Install ThunderKittens")
     parser.add_argument("--liger", action="store_true", help="Install Liger-kernel")
@@ -153,6 +154,11 @@ def install_xformers():
     if args.xformers or args.all:
         logger.info("[tritonbench] installing xformers...")
         install_xformers()
+    if args.hstu or args.all:
+        logger.info("[tritonbench] installing hstu...")
+        from tools.hstu.install import install_hstu
+
+        install_hstu()
     logger.info("[tritonbench] installation complete!")
     # run tests to check installation
     if args.test:

tools/hstu/hstu.patch (+13, new file)

@@ -0,0 +1,13 @@
+diff --git a/generative_recommenders/ops/triton/triton_ragged_hstu_attention.py b/generative_recommenders/ops/triton/triton_ragged_hstu_attention.py
+index b4e318b..d6bc894 100644
+--- a/generative_recommenders/ops/triton/triton_ragged_hstu_attention.py
++++ b/generative_recommenders/ops/triton/triton_ragged_hstu_attention.py
+@@ -36,7 +36,7 @@ try:
+         VersionedSpec,
+     )
+ except ImportError:
+-    from hammer.oss.generative_recommenders.ops.triton.utils import (
++    from generative_recommenders.ops.triton.utils import (
+         _switch_to_contiguous_if_needed,
+         autotune_max_seq_len,
+         NamedSpecType,
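For context, the patch only rewrites the OSS fallback branch of an existing try/except import in triton_ragged_hstu_attention.py; presumably the `hammer.oss.*` namespace is not available in the standalone generative-recommenders checkout, so the unpatched fallback would itself fail. A minimal sketch of that pattern (the two import paths come from the hunk above; everything else here is illustrative and assumes the submodule is on `sys.path`):

```python
# Sketch only: in the OSS checkout there is no `hammer` package, so the
# unpatched fallback raises a second ImportError and the module cannot load.
try:
    from hammer.oss.generative_recommenders.ops.triton.utils import (  # pre-patch path
        autotune_max_seq_len,
    )
except ImportError:
    # Post-patch path: resolves against the submodule's own package layout.
    from generative_recommenders.ops.triton.utils import autotune_max_seq_len
```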

tools/hstu/install.py (+34, new file)

@@ -0,0 +1,34 @@
+import os
+import subprocess
+import sys
+from pathlib import Path
+
+PATCH_DIR = str(
+    Path(__file__)
+    .parent.parent.parent.joinpath("submodules", "generative-recommenders")
+    .absolute()
+)
+PATCH_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "hstu.patch")
+
+
+def install_hstu():
+    try:
+        subprocess.check_output(
+            [
+                "patch",
+                "-p1",
+                "--forward",
+                "-i",
+                PATCH_FILE,
+                "-r",
+                "/tmp/rej",
+            ],
+            cwd=PATCH_DIR,
+        )
+    except subprocess.SubprocessError as e:
+        output_str = str(e.output)
+        if "previously applied" in output_str:
+            return
+        else:
+            print(str(output_str))
+            sys.exit(1)
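A usage sketch (hypothetical standalone invocation; in this commit the function is only reached from the top-level install.py via `--hstu` or `--all`). The `--forward` flag is what makes re-running harmless: patch reports the hunks as previously applied, and install_hstu() returns quietly instead of failing:

```python
from tools.hstu.install import install_hstu

# First run: applies tools/hstu/hstu.patch inside
# submodules/generative-recommenders using `patch -p1 --forward`.
install_hstu()

# Second run: patch exits non-zero with a "previously applied" message,
# which install_hstu() catches and treats as success.
install_hstu()
```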

tritonbench/operators/ragged_attention/hstu.py (+50 -46)

@@ -13,16 +13,15 @@
 )
 except ModuleNotFoundError:
     # OSS Import
-    import importlib
+    with add_path(str(SUBMODULE_PATH.joinpath("generative-recommenders"))):
+        from generative_recommenders.ops.triton import triton_ragged_hstu_attention
 
-    with add_path(str(SUBMODULE_PATH)):
-        triton_ragged_hstu_attention = importlib.import_module(
-            "generative-recommenders.ops.triton.triton_ragged_hstu_attention"
-        )
-    _ragged_hstu_attn_fwd = triton_ragged_hstu_attention._ragged_hstu_attn_fwd
     _ragged_hstu_attn_fwd_persistent = (
         triton_ragged_hstu_attention._ragged_hstu_attn_fwd_persistent
     )
+    _RaggedAttentionRelativeBiasFunction = (
+        triton_ragged_hstu_attention._RaggedAttentionRelativeBiasFunction
+    )
 
 @torch.fx.wrap
 def prev_power_of_2(x: int) -> int:
@@ -47,6 +46,7 @@ def __init__(
         num_heads,
         max_seq_len,
         num_buckets,
+        requires_grad,
         persistent_kernel: bool = False,
     ) -> None:
         super().__init__()
@@ -58,13 +58,17 @@ def __init__(
             torch.randn(
                 (self.num_buckets + 1,),
                 dtype=torch.bfloat16,
-            ).cuda()
+            )
+            .requires_grad_(requires_grad)
+            .cuda()
         )
         self.all_pos_weights = torch.nn.Parameter(
             torch.randn(
                 (2 * self.max_seq_len - 1,),
                 dtype=torch.bfloat16,
-            ).cuda()
+            )
+            .requires_grad_(requires_grad)
+            .cuda()
         )
         self.persistent_kernel = persistent_kernel
 
@@ -141,57 +145,57 @@ def forward(
             "HAS_SORT_BY_LENGTH_INDICES": False,
             "sort_by_length_indices": None,
         }
-        if not IS_FBCODE:
-            del kwargs["MAX_ATTN_LEN"]
-            del kwargs["HAS_CONTEXTUAL_SEQ_LEN"]
-            del kwargs["contextual_seq_len"]
-            del kwargs["HAS_SORT_BY_LENGTH_INDICES"]
-            del kwargs["sort_by_length_indices"]
-            kwargs["HAS_MAX_ATTN_LEN"] = False
-            kwargs["max_attn_len"] = 0
 
         if self.persistent_kernel:
             grid = (1216,)
             _ragged_hstu_attn_fwd_persistent[grid](**kwargs)
         else:
-            grid = lambda meta: (  # noqa E731
-                triton.cdiv(N, meta["BLOCK_M"]),
-                Z * H,
+            out = _RaggedAttentionRelativeBiasFunction.apply(
+                self.max_seq_len,  # N
+                kwargs["alpha"],
+                q,
+                k,
+                v,
+                kwargs["seq_offsets"],
+                kwargs["INVALID_MASK_TYPE"],
+                timestamps,
+                self.all_ts_weights,  # ts_weights
+                self.all_pos_weights,  # pos_weights
+                kwargs["CAUSAL"],  # causal,
+                kwargs["num_buckets"],  # num_buckets
+                "sqrt",  # time_bucket_fn
+                kwargs["time_bucket_incr"],  # time_bucket_incr
+                kwargs["time_bucket_div"],  # time_bucket_div
+                kwargs["time_delta"],  # time_delta
+                kwargs["max_pos_ind"],  # max_pos_ind
+                kwargs["num_targets"],
+                None,  # attn_scale
+                kwargs["ATTN_BIAS_TYPE"],  # relative_bias_type
+                kwargs["MAX_ATTN_LEN"],  # max_attn_len
+                kwargs["contextual_seq_len"],  # contextual_seq_len
+                kwargs["sort_by_length_indices"],  # sort_by_length
             )
-            _ragged_hstu_attn_fwd[grid](**kwargs)
 
         return out
 
 
 def get_test_inputs(
-    batch_size, num_heads, max_seq_len
+    batch_size, num_heads, max_seq_len, requires_grad
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    timestamp_deltas: torch.Tensor = (
-        torch.randint(
-            86400,
-            size=(batch_size, max_seq_len + 1),
-        )
-        .requires_grad_(False)
-        .cuda()
-    )
+    timestamp_deltas: torch.Tensor = torch.randint(
+        86400,
+        size=(batch_size, max_seq_len + 1),
+    ).cuda()
     timestamps = timestamp_deltas.cumsum(dim=1)
 
-    lengths = (
-        torch.randint(
-            max_seq_len + 1,
-            size=(batch_size,),
-        )
-        .requires_grad_(False)
-        .cuda()
-    )
-    seq_offsets = (
-        torch.zeros(
-            (batch_size + 1,),
-            dtype=torch.int64,
-        )
-        .requires_grad_(False)
-        .cuda()
-    )
+    lengths = torch.randint(
+        max_seq_len + 1,
+        size=(batch_size,),
+    ).cuda()
+    seq_offsets = torch.zeros(
+        (batch_size + 1,),
+        dtype=torch.int64,
+    ).cuda()
     seq_offsets[1:] = torch.cumsum(
         lengths,
         dim=0,
@@ -203,7 +207,7 @@ def get_test_inputs(
             (L, num_heads, 512),
             dtype=torch.bfloat16,
         )
-        .requires_grad_(False)
+        .requires_grad_(requires_grad)
         .cuda()
     )
     return qkv, seq_offsets, timestamps
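The key change above is that the non-persistent path now calls `_RaggedAttentionRelativeBiasFunction.apply(...)` instead of launching `_ragged_hstu_attn_fwd[grid]` directly, so autograd can reach the Triton kernel and the new `--mode bwd` benchmark has a gradient path. A minimal, illustrative sketch of the `torch.autograd.Function` pattern such a wrapper follows (this is not the real HSTU wrapper, which lives in generative_recommenders and dispatches Triton kernels in both directions):

```python
import torch

class _ScaledMatmulFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, alpha):
        # In the real wrapper this launches the forward Triton kernel.
        ctx.save_for_backward(q, k)
        ctx.alpha = alpha
        return alpha * (q @ k.transpose(-1, -2))

    @staticmethod
    def backward(ctx, grad_out):
        # In the real wrapper this launches the backward Triton kernel(s).
        q, k = ctx.saved_tensors
        grad_q = ctx.alpha * (grad_out @ k)
        grad_k = ctx.alpha * (grad_out.transpose(-1, -2) @ q)
        return grad_q, grad_k, None  # no gradient for alpha

q = torch.randn(4, 8, requires_grad=True)
k = torch.randn(4, 8, requires_grad=True)
out = _ScaledMatmulFunction.apply(q, k, 0.5)
out.sum().backward()  # gradients now flow back to q and k
```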

tritonbench/operators/ragged_attention/operator.py (+62 -4)

@@ -1,8 +1,17 @@
 import argparse
 
-from typing import List, Optional
+from typing import Any, Callable, List, Optional
 
-from tritonbench.utils.triton_op import BenchmarkOperator, register_benchmark
+import torch
+from tritonbench.utils.input import input_filter
+
+from tritonbench.utils.triton_op import (
+    BenchmarkOperator,
+    BenchmarkOperatorMetrics,
+    Mode,
+    register_benchmark,
+    register_metric,
+)
 
 from .hstu import get_test_inputs, RaggedHSTUAttn
 
@@ -30,6 +39,7 @@ def __init__(
         self.num_buckets = args.num_buckets
         # set a default number of inputs
         self._num_inputs = 10 if self._num_inputs is None else self._num_inputs
+        self.requires_grad = not (self.mode == Mode.FWD_NO_GRAD)
 
     @register_benchmark()
     def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps):
@@ -38,17 +48,20 @@ def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps):
             self.num_heads,
             self.max_seq_len,
             self.num_buckets,
+            self.requires_grad,
             persistent_kernel=False,
         )
         return lambda: attn(qkv, seq_offsets, timestamps)
 
-    @register_benchmark()
+    # TODO: enable persistent kernels when the OSS backward is ready
+    @register_benchmark(enabled=False)
     def hstu_triton_ragged_attention_persistent(self, qkv, seq_offsets, timestamps):
         attn = RaggedHSTUAttn(
             self.batch_size,
             self.num_heads,
             self.max_seq_len,
             self.num_buckets,
+            self.requires_grad,
             persistent_kernel=True,
         )
         return lambda: attn(qkv, seq_offsets, timestamps)
@@ -58,5 +71,50 @@ def get_x_val(self, example_inputs):
 
     def get_input_iter(self):
         for _input_id in range(self._num_inputs):
-            inputs = get_test_inputs(self.batch_size, self.num_heads, self.max_seq_len)
+            inputs = get_test_inputs(
+                self.batch_size, self.num_heads, self.max_seq_len, self.requires_grad
+            )
             yield inputs
+
+    def get_bwd_fn(self, fwd_fn: Callable[..., Any]) -> Callable[..., Any]:
+        o = fwd_fn()
+        o_tensor = input_filter(
+            lambda x: isinstance(x, torch.Tensor),
+            o,
+        )
+        do = torch.rand_like(o_tensor)
+        fn = lambda: o_tensor.backward(do, retain_graph=True)
+        return fn
+
+    @register_metric()
+    def tflops(
+        self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics
+    ) -> float:
+        ratio = 2.0  # triangular masking
+        f1 = 0.0
+        f2 = 0.0
+        jagged = True
+        qkv, seq_offsets, timestamps = example_inputs
+        q = qkv[:, :, :128]
+        v = qkv[:, :, 256:384]
+        _, nheads, attn_dim = q.shape
+        _, _, hidden_dim = v.shape
+        max_seqlen = timestamps.size(1) - 1
+
+        for i in range(self.batch_size):
+            seq_len = (
+                int((seq_offsets[i + 1] - seq_offsets[i]).item())
+                if jagged
+                else max_seqlen
+            )
+            # (QK^T), dQ = d(QK^T)K, dK^T = Q^Td(QK^T)
+            f1 += 2 * self.num_heads * attn_dim * seq_len**2 // ratio
+            # (QK^T)V, d(QK^T) = dOV^T, dV = (QK^T)^TdO,
+            f2 += 2 * self.num_heads * hidden_dim * seq_len**2 // ratio
+        if self.mode == Mode.FWD:
+            tflops = f1 + f2  # computes (QK^T) and (QK^T)V
+        elif self.mode == Mode.BWD:
+            tflops = 3 * f1 + 2 * f2  # computes (QK^T), dQ, dK, dV, d(QK^T)
+        elif self.mode == Mode.FWD_BWD:
+            tflops = 4 * f1 + 3 * f2
+        return tflops / metrics.latency * 1e-9
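For reference, a standalone sketch of the same FLOP accounting (illustrative only: the 128-wide Q/K and V slices mirror the `qkv[:, :, :128]` / `qkv[:, :, 256:384]` split above, the sequence lengths and latency are made up, and latency is assumed to be reported in milliseconds, which is what the `1e-9` factor implies):

```python
def hstu_attention_flops(seq_lens, num_heads, attn_dim=128, hidden_dim=128, mode="bwd"):
    """Approximate FLOP count following the tflops metric above."""
    ratio = 2.0  # triangular (causal) masking roughly halves the work
    f1 = f2 = 0.0
    for seq_len in seq_lens:
        f1 += 2 * num_heads * attn_dim * seq_len**2 // ratio    # QK^T (and its grads)
        f2 += 2 * num_heads * hidden_dim * seq_len**2 // ratio  # (QK^T)V (and its grads)
    if mode == "fwd":
        return f1 + f2
    if mode == "bwd":
        return 3 * f1 + 2 * f2
    return 4 * f1 + 3 * f2  # fwd_bwd


# Hypothetical example: 8 sequences of length 512, 4 heads, backward mode.
flops = hstu_attention_flops([512] * 8, num_heads=4, mode="bwd")
latency_ms = 0.9  # made-up latency in milliseconds
tflops = flops / latency_ms * 1e-9  # FLOPs/ms * 1e-9 == TFLOP/s
print(f"{tflops:.2f} TFLOPS")
```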
