Skip to content

Commit 492e5df

Browse files
levendleefacebook-github-bot
authored andcommitted
Triton based Gather/Scatter kernels runs on valid tokens.
Summary: Triton based Gather/Scatter kernels runs on valid tokens. Differential Revision: D75320859
1 parent aa2fe3d commit 492e5df

File tree

2 files changed

+121
-12
lines changed

2 files changed

+121
-12
lines changed

fbgemm_gpu/experimental/gen_ai/gen_ai/moe/gather_scatter.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ def gather_scale_dense_tokens(
2121
token_indices: torch.Tensor,
2222
expert_indices: torch.Tensor,
2323
scores: torch.Tensor,
24+
valid_token_count: Optional[torch.Tensor] = None,
2425
) -> torch.Tensor:
2526
T, D = x.shape
2627
E = scores.shape[1]
@@ -58,6 +59,7 @@ def gather_scale_dense_tokens(
5859
scores,
5960
stride_t,
6061
stride_e,
62+
valid_token_count,
6163
D, # pyre-ignore
6264
BLOCK_D_OUTER, # pyre-ignore
6365
BLOCK_D_INNER, # pyre-ignore
@@ -71,6 +73,7 @@ def gather_scale_quant_dense_tokens(
7173
expert_indices: torch.Tensor,
7274
scores: torch.Tensor,
7375
scale_ub: Optional[torch.Tensor] = None,
76+
valid_token_count: Optional[torch.Tensor] = None,
7477
) -> Tuple[torch.Tensor, torch.Tensor]:
7578
T, D = x.shape
7679
E = scores.shape[1]
@@ -104,6 +107,7 @@ def gather_scale_quant_dense_tokens(
104107
scale_ub,
105108
stride_t,
106109
stride_e,
110+
valid_token_count,
107111
D,
108112
TL_FP8_DTYPE=tl_dtype,
109113
MAX_FP8=max_fp8,
@@ -117,6 +121,7 @@ def scatter_add_dense_tokens(
117121
out_tokens: torch.Tensor, # [T, D]
118122
in_tokens: torch.Tensor, # [a, D]
119123
token_indices: torch.Tensor, # [a]
124+
valid_token_count: Optional[torch.Tensor] = None,
120125
) -> None:
121126
assert torch.version.hip is not None or (
122127
torch.version.cuda is not None and torch.version.cuda >= "12.4"
@@ -144,6 +149,7 @@ def scatter_add_dense_tokens(
144149
out_tokens,
145150
in_tokens,
146151
token_indices,
152+
valid_token_count,
147153
D, # pyre-ignore
148154
BLOCK_D_OUTER, # pyre-ignore
149155
BLOCK_D_INNER, # pyre-ignore
@@ -206,6 +212,7 @@ def gather_scale_dense_tokens_meta(
206212
token_indices,
207213
expert_indices,
208214
scores,
215+
valid_token_count=None,
209216
):
210217
D = x.shape[1]
211218
a = token_indices.shape[0]
@@ -218,12 +225,14 @@ def gather_scale_dense_tokens_cuda(
218225
token_indices,
219226
expert_indices,
220227
scores,
228+
valid_token_count=None,
221229
):
222230
return gather_scale_dense_tokens(
223231
x,
224232
token_indices,
225233
expert_indices,
226234
scores,
235+
valid_token_count,
227236
)
228237

229238

@@ -241,7 +250,8 @@ def gather_scale_quant_dense_tokens_meta(
241250
token_indices,
242251
expert_indices,
243252
scores,
244-
scale_ub,
253+
scale_ub=None,
254+
valid_token_count=None,
245255
):
246256
D = x.shape[1]
247257
a = token_indices.shape[0]
@@ -258,13 +268,15 @@ def gather_scale_quant_dense_tokens_cuda(
258268
expert_indices,
259269
scores,
260270
scale_ub=None,
271+
valid_token_count=None,
261272
):
262273
return gather_scale_quant_dense_tokens(
263274
x,
264275
token_indices,
265276
expert_indices,
266277
scores,
267278
scale_ub,
279+
valid_token_count,
268280
)
269281

270282

@@ -281,6 +293,7 @@ def scatter_add_dense_tokens_meta(
281293
out_tokens,
282294
in_tokens,
283295
token_indices,
296+
valid_token_count=None,
284297
):
285298
return None
286299

@@ -290,8 +303,11 @@ def scatter_add_dense_tokens_cuda(
290303
out_tokens,
291304
in_tokens,
292305
token_indices,
306+
valid_token_count=None,
293307
):
294-
return scatter_add_dense_tokens(out_tokens, in_tokens, token_indices)
308+
return scatter_add_dense_tokens(
309+
out_tokens, in_tokens, token_indices, valid_token_count
310+
)
295311

296312

297313
_SCATTER_ADD_PADDED_TOKENS_OP_NAME = "fbgemm::scatter_add_padded_tokens"
@@ -337,13 +353,21 @@ def _fbgemm_gather_scale_dense_tokens(
337353
scores,
338354
stride_t,
339355
stride_e,
356+
valid_token_count,
340357
D: tl.constexpr,
341358
BLOCK_D_OUTER: tl.constexpr,
342359
BLOCK_D_INNER: tl.constexpr,
343360
):
344361
output_token_index = tl.program_id(0)
345362
feature_offset = tl.program_id(1) * BLOCK_D_OUTER
346363

364+
if valid_token_count is not None:
365+
valid_token_count = tl.load(
366+
valid_token_count, None, eviction_policy="evict_last"
367+
)
368+
if output_token_index >= valid_token_count:
369+
return
370+
347371
input_token_index = tl.load(
348372
token_indices + output_token_index, None, eviction_policy="evict_last"
349373
)
@@ -383,13 +407,21 @@ def _fbgemm_scatter_add_dense_tokens(
383407
out_tokens,
384408
in_tokens,
385409
token_indices,
410+
valid_token_count,
386411
D: tl.constexpr,
387412
BLOCK_D_OUTER: tl.constexpr,
388413
BLOCK_D_INNER: tl.constexpr,
389414
):
390415
input_token_index = tl.program_id(0).to(tl.int64)
391416
feature_offset = tl.program_id(1) * BLOCK_D_OUTER + tl.arange(0, BLOCK_D_INNER)[:]
392417

418+
if valid_token_count is not None:
419+
valid_token_count = tl.load(
420+
valid_token_count, None, eviction_policy="evict_last"
421+
)
422+
if input_token_index >= valid_token_count:
423+
return
424+
393425
output_token_index = tl.load(
394426
token_indices + input_token_index, None, eviction_policy="evict_last"
395427
).to(tl.int64)
@@ -429,6 +461,7 @@ def _fbgemm_gather_scale_fp8_rowwise_quant_dense_tokens(
429461
scale_ub_ptr,
430462
stride_t,
431463
stride_e,
464+
valid_token_count,
432465
D: tl.constexpr,
433466
TL_FP8_DTYPE: tl.constexpr,
434467
MAX_FP8: tl.constexpr,
@@ -440,6 +473,13 @@ def _fbgemm_gather_scale_fp8_rowwise_quant_dense_tokens(
440473

441474
output_token_index = tl.program_id(0)
442475

476+
if valid_token_count is not None:
477+
valid_token_count = tl.load(
478+
valid_token_count, None, eviction_policy="evict_last"
479+
)
480+
if output_token_index >= valid_token_count:
481+
return
482+
443483
input_token_index = tl.load(
444484
token_indices_ptr + output_token_index, None, eviction_policy="evict_first"
445485
)

fbgemm_gpu/experimental/gen_ai/test/moe/gather_scatter_test.py

Lines changed: 79 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import logging
1111
import unittest
12-
from typing import Tuple
12+
from typing import Optional, Tuple
1313

1414
import torch
1515
import triton # noqa: F401
@@ -28,11 +28,9 @@
2828
from hypothesis import given, settings, strategies as st, Verbosity
2929

3030
try:
31-
# pyre-ignore[21]
3231
# @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
3332
from fbgemm_gpu import open_source
3433

35-
# pyre-ignore[21]
3634
# @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
3735
from fbgemm_gpu.docs.version import __version__ # noqa: F401
3836
except Exception:
@@ -58,6 +56,7 @@ class GatherScatterTests(unittest.TestCase):
5856
E=st.sampled_from([2, 4, 8]),
5957
T=st.sampled_from([1, 128, 2048, 4096, 16384]),
6058
D=st.sampled_from([5120, 7168]),
59+
partial=st.sampled_from([True, False]),
6160
rowmajor=st.sampled_from([True, False]),
6261
compiled=st.sampled_from([True, False]),
6362
)
@@ -67,6 +66,7 @@ def test_gather_scale_dense_tokens(
6766
E: int,
6867
T: int,
6968
D: int,
69+
partial: bool,
7070
rowmajor: bool,
7171
compiled: bool,
7272
) -> None:
@@ -78,6 +78,22 @@ def test_gather_scale_dense_tokens(
7878
token_indices: torch.Tensor = torch.randperm(T, device="cuda").to(torch.int32)
7979
scores: torch.Tensor = torch.rand((E, T), dtype=torch.bfloat16, device="cuda")
8080

81+
num_valid_tokens: int = T
82+
valid_token_count: Optional[torch.Tensor] = None
83+
partial_expert_indices: torch.Tensor = expert_indices
84+
partial_token_indices: torch.Tensor = token_indices
85+
if partial:
86+
num_valid_tokens = T // 2
87+
valid_token_count = torch.tensor(
88+
[num_valid_tokens], dtype=torch.int32, device="cuda"
89+
)
90+
partial_expert_indices = torch.where(
91+
torch.arange(T).cuda() < num_valid_tokens, expert_indices, -1
92+
)
93+
partial_token_indices = torch.where(
94+
torch.arange(T).cuda() < num_valid_tokens, token_indices, -1
95+
)
96+
8197
def torch_fn() -> torch.Tensor:
8298
shuffled_x = torch.index_select(x, dim=0, index=token_indices)
8399
shuffled_scores = torch.index_select(scores, dim=1, index=token_indices)
@@ -96,17 +112,26 @@ def triton_fn() -> torch.Tensor:
96112
op = gather_scale_dense_tokens
97113
if compiled:
98114
op = torch.compile(op)
99-
test_output = op(x, token_indices, expert_indices, scores_)
115+
test_output = op(
116+
x,
117+
partial_token_indices,
118+
partial_expert_indices,
119+
scores_,
120+
valid_token_count,
121+
)
100122
return test_output
101123

102124
test_output = triton_fn()
103125

104-
torch.testing.assert_close(torch_output, test_output)
126+
torch.testing.assert_close(
127+
torch_output[:num_valid_tokens], test_output[:num_valid_tokens]
128+
)
105129

106130
@given(
107131
E=st.sampled_from([2, 4, 8]),
108132
T=st.sampled_from([1, 128, 2048, 4096, 16384]),
109133
D=st.sampled_from([5120, 7168]),
134+
partial=st.sampled_from([True, False]),
110135
rowmajor=st.sampled_from([True, False]),
111136
compiled=st.sampled_from([True, False]),
112137
)
@@ -116,6 +141,7 @@ def test_gather_scale_quant_dense_tokens(
116141
E: int,
117142
T: int,
118143
D: int,
144+
partial: bool,
119145
rowmajor: bool,
120146
compiled: bool,
121147
) -> None:
@@ -126,9 +152,24 @@ def test_gather_scale_quant_dense_tokens(
126152
expert_indices: torch.Tensor = torch.randint(0, E, (T,), device="cuda")
127153
token_indices: torch.Tensor = torch.randperm(T, device="cuda").to(torch.int32)
128154
scores: torch.Tensor = torch.randn((E, T), dtype=torch.bfloat16, device="cuda")
129-
130155
scale_ub = torch.tensor([1200], dtype=torch.float, device="cuda")
131156

157+
num_valid_tokens: int = T
158+
valid_token_count: Optional[torch.Tensor] = None
159+
partial_expert_indices: torch.Tensor = expert_indices
160+
partial_token_indices: torch.Tensor = token_indices
161+
if partial:
162+
num_valid_tokens = T // 2
163+
valid_token_count = torch.tensor(
164+
[num_valid_tokens], dtype=torch.int32, device="cuda"
165+
)
166+
partial_expert_indices = torch.where(
167+
torch.arange(T).cuda() < num_valid_tokens, expert_indices, -1
168+
)
169+
partial_token_indices = torch.where(
170+
torch.arange(T).cuda() < num_valid_tokens, token_indices, -1
171+
)
172+
132173
def torch_fn() -> Tuple[torch.Tensor, torch.Tensor]:
133174
shuffled_x = torch.index_select(x, dim=0, index=token_indices)
134175
shuffled_scores = torch.index_select(scores, dim=1, index=token_indices)
@@ -156,25 +197,37 @@ def triton_fn() -> Tuple[torch.Tensor, torch.Tensor]:
156197
if compiled:
157198
op = torch.compile(op)
158199
test_output_q, test_output_scales = op(
159-
x, token_indices, expert_indices, scores_, scale_ub
200+
x,
201+
partial_token_indices,
202+
partial_expert_indices,
203+
scores_,
204+
scale_ub,
205+
valid_token_count,
160206
)
161207
return test_output_q, test_output_scales
162208

163209
test_output_q, test_output_scales = triton_fn()
164210
test_output = test_output_q.to(torch.float32) * test_output_scales.view(-1, 1)
165211

166-
torch.testing.assert_close(torch_output, test_output, atol=1e-3, rtol=1.6e-2)
212+
torch.testing.assert_close(
213+
torch_output[:num_valid_tokens],
214+
test_output[:num_valid_tokens],
215+
atol=1e-3,
216+
rtol=1.6e-2,
217+
)
167218

168219
@given(
169220
num_tokens=st.sampled_from([1, 128, 2048, 4096, 16384]),
170221
dim=st.sampled_from([5120]),
222+
partial=st.sampled_from([True, False]),
171223
compiled=st.sampled_from([True, False]),
172224
)
173225
@settings(verbosity=Verbosity.verbose, max_examples=_MAX_SAMPLES, deadline=None)
174226
def test_scatter_add_dense_tokens(
175227
self,
176228
num_tokens: int,
177229
dim: int,
230+
partial: bool,
178231
compiled: bool,
179232
) -> None:
180233
torch.manual_seed(0)
@@ -190,6 +243,18 @@ def test_scatter_add_dense_tokens(
190243
torch.int32
191244
)
192245

246+
num_valid_tokens: int = num_tokens
247+
valid_token_count: Optional[torch.Tensor] = None
248+
partial_token_indices: torch.Tensor = token_indices
249+
if partial:
250+
num_valid_tokens = num_tokens // 2
251+
valid_token_count = torch.tensor(
252+
[num_valid_tokens], dtype=torch.int32, device="cuda"
253+
)
254+
partial_token_indices = torch.where(
255+
torch.arange(num_tokens).cuda() < num_valid_tokens, token_indices, -1
256+
)
257+
193258
test_out_tokens: torch.Tensor = out_tokens.clone()
194259
ref_out_tokens: torch.Tensor = out_tokens.clone()
195260

@@ -201,11 +266,12 @@ def fn() -> None:
201266
test_out_tokens,
202267
in_tokens,
203268
token_indices,
269+
valid_token_count,
204270
)
205271

206272
fn()
207273

208-
token_indices: torch.Tensor = token_indices.to(torch.int64)
274+
token_indices: torch.Tensor = token_indices[:num_valid_tokens].to(torch.int64)
209275

210276
def ref_fn() -> None:
211277
ref_out_tokens.scatter_add_(
@@ -217,7 +283,10 @@ def ref_fn() -> None:
217283
ref_fn()
218284

219285
torch.testing.assert_close(
220-
test_out_tokens, ref_out_tokens, atol=1e-3, rtol=1.6e-2
286+
test_out_tokens[:num_valid_tokens],
287+
ref_out_tokens[:num_valid_tokens],
288+
atol=1e-3,
289+
rtol=1.6e-2,
221290
)
222291

223292
@given(

0 commit comments

Comments
 (0)