99
1010import logging
1111import unittest
12- from typing import Tuple
12+ from typing import Optional , Tuple
1313
1414import torch
1515import triton # noqa: F401
2828from hypothesis import given , settings , strategies as st , Verbosity
2929
3030try :
31- # pyre-ignore[21]
3231 # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
3332 from fbgemm_gpu import open_source
3433
35- # pyre-ignore[21]
3634 # @manual=//deeplearning/fbgemm/fbgemm_gpu:test_utils
3735 from fbgemm_gpu .docs .version import __version__ # noqa: F401
3836except Exception :
@@ -58,6 +56,7 @@ class GatherScatterTests(unittest.TestCase):
5856 E = st .sampled_from ([2 , 4 , 8 ]),
5957 T = st .sampled_from ([1 , 128 , 2048 , 4096 , 16384 ]),
6058 D = st .sampled_from ([5120 , 7168 ]),
59+ partial = st .sampled_from ([True , False ]),
6160 rowmajor = st .sampled_from ([True , False ]),
6261 compiled = st .sampled_from ([True , False ]),
6362 )
@@ -67,6 +66,7 @@ def test_gather_scale_dense_tokens(
6766 E : int ,
6867 T : int ,
6968 D : int ,
69+ partial : bool ,
7070 rowmajor : bool ,
7171 compiled : bool ,
7272 ) -> None :
@@ -78,6 +78,22 @@ def test_gather_scale_dense_tokens(
7878 token_indices : torch .Tensor = torch .randperm (T , device = "cuda" ).to (torch .int32 )
7979 scores : torch .Tensor = torch .rand ((E , T ), dtype = torch .bfloat16 , device = "cuda" )
8080
81+ num_valid_tokens : int = T
82+ valid_token_count : Optional [torch .Tensor ] = None
83+ partial_expert_indices : torch .Tensor = expert_indices
84+ partial_token_indices : torch .Tensor = token_indices
85+ if partial :
86+ num_valid_tokens = T // 2
87+ valid_token_count = torch .tensor (
88+ [num_valid_tokens ], dtype = torch .int32 , device = "cuda"
89+ )
90+ partial_expert_indices = torch .where (
91+ torch .arange (T ).cuda () < num_valid_tokens , expert_indices , - 1
92+ )
93+ partial_token_indices = torch .where (
94+ torch .arange (T ).cuda () < num_valid_tokens , token_indices , - 1
95+ )
96+
8197 def torch_fn () -> torch .Tensor :
8298 shuffled_x = torch .index_select (x , dim = 0 , index = token_indices )
8399 shuffled_scores = torch .index_select (scores , dim = 1 , index = token_indices )
@@ -96,17 +112,26 @@ def triton_fn() -> torch.Tensor:
96112 op = gather_scale_dense_tokens
97113 if compiled :
98114 op = torch .compile (op )
99- test_output = op (x , token_indices , expert_indices , scores_ )
115+ test_output = op (
116+ x ,
117+ partial_token_indices ,
118+ partial_expert_indices ,
119+ scores_ ,
120+ valid_token_count ,
121+ )
100122 return test_output
101123
102124 test_output = triton_fn ()
103125
104- torch .testing .assert_close (torch_output , test_output )
126+ torch .testing .assert_close (
127+ torch_output [:num_valid_tokens ], test_output [:num_valid_tokens ]
128+ )
105129
106130 @given (
107131 E = st .sampled_from ([2 , 4 , 8 ]),
108132 T = st .sampled_from ([1 , 128 , 2048 , 4096 , 16384 ]),
109133 D = st .sampled_from ([5120 , 7168 ]),
134+ partial = st .sampled_from ([True , False ]),
110135 rowmajor = st .sampled_from ([True , False ]),
111136 compiled = st .sampled_from ([True , False ]),
112137 )
@@ -116,6 +141,7 @@ def test_gather_scale_quant_dense_tokens(
116141 E : int ,
117142 T : int ,
118143 D : int ,
144+ partial : bool ,
119145 rowmajor : bool ,
120146 compiled : bool ,
121147 ) -> None :
@@ -126,9 +152,24 @@ def test_gather_scale_quant_dense_tokens(
126152 expert_indices : torch .Tensor = torch .randint (0 , E , (T ,), device = "cuda" )
127153 token_indices : torch .Tensor = torch .randperm (T , device = "cuda" ).to (torch .int32 )
128154 scores : torch .Tensor = torch .randn ((E , T ), dtype = torch .bfloat16 , device = "cuda" )
129-
130155 scale_ub = torch .tensor ([1200 ], dtype = torch .float , device = "cuda" )
131156
157+ num_valid_tokens : int = T
158+ valid_token_count : Optional [torch .Tensor ] = None
159+ partial_expert_indices : torch .Tensor = expert_indices
160+ partial_token_indices : torch .Tensor = token_indices
161+ if partial :
162+ num_valid_tokens = T // 2
163+ valid_token_count = torch .tensor (
164+ [num_valid_tokens ], dtype = torch .int32 , device = "cuda"
165+ )
166+ partial_expert_indices = torch .where (
167+ torch .arange (T ).cuda () < num_valid_tokens , expert_indices , - 1
168+ )
169+ partial_token_indices = torch .where (
170+ torch .arange (T ).cuda () < num_valid_tokens , token_indices , - 1
171+ )
172+
132173 def torch_fn () -> Tuple [torch .Tensor , torch .Tensor ]:
133174 shuffled_x = torch .index_select (x , dim = 0 , index = token_indices )
134175 shuffled_scores = torch .index_select (scores , dim = 1 , index = token_indices )
@@ -156,25 +197,37 @@ def triton_fn() -> Tuple[torch.Tensor, torch.Tensor]:
156197 if compiled :
157198 op = torch .compile (op )
158199 test_output_q , test_output_scales = op (
159- x , token_indices , expert_indices , scores_ , scale_ub
200+ x ,
201+ partial_token_indices ,
202+ partial_expert_indices ,
203+ scores_ ,
204+ scale_ub ,
205+ valid_token_count ,
160206 )
161207 return test_output_q , test_output_scales
162208
163209 test_output_q , test_output_scales = triton_fn ()
164210 test_output = test_output_q .to (torch .float32 ) * test_output_scales .view (- 1 , 1 )
165211
166- torch .testing .assert_close (torch_output , test_output , atol = 1e-3 , rtol = 1.6e-2 )
212+ torch .testing .assert_close (
213+ torch_output [:num_valid_tokens ],
214+ test_output [:num_valid_tokens ],
215+ atol = 1e-3 ,
216+ rtol = 1.6e-2 ,
217+ )
167218
168219 @given (
169220 num_tokens = st .sampled_from ([1 , 128 , 2048 , 4096 , 16384 ]),
170221 dim = st .sampled_from ([5120 ]),
222+ partial = st .sampled_from ([True , False ]),
171223 compiled = st .sampled_from ([True , False ]),
172224 )
173225 @settings (verbosity = Verbosity .verbose , max_examples = _MAX_SAMPLES , deadline = None )
174226 def test_scatter_add_dense_tokens (
175227 self ,
176228 num_tokens : int ,
177229 dim : int ,
230+ partial : bool ,
178231 compiled : bool ,
179232 ) -> None :
180233 torch .manual_seed (0 )
@@ -190,6 +243,18 @@ def test_scatter_add_dense_tokens(
190243 torch .int32
191244 )
192245
246+ num_valid_tokens : int = num_tokens
247+ valid_token_count : Optional [torch .Tensor ] = None
248+ partial_token_indices : torch .Tensor = token_indices
249+ if partial :
250+ num_valid_tokens = num_tokens // 2
251+ valid_token_count = torch .tensor (
252+ [num_valid_tokens ], dtype = torch .int32 , device = "cuda"
253+ )
254+ partial_token_indices = torch .where (
255+ torch .arange (num_tokens ).cuda () < num_valid_tokens , token_indices , - 1
256+ )
257+
193258 test_out_tokens : torch .Tensor = out_tokens .clone ()
194259 ref_out_tokens : torch .Tensor = out_tokens .clone ()
195260
@@ -201,11 +266,12 @@ def fn() -> None:
201266 test_out_tokens ,
202267 in_tokens ,
203268 token_indices ,
269+ valid_token_count ,
204270 )
205271
206272 fn ()
207273
208- token_indices : torch .Tensor = token_indices .to (torch .int64 )
274+ token_indices : torch .Tensor = token_indices [: num_valid_tokens ] .to (torch .int64 )
209275
210276 def ref_fn () -> None :
211277 ref_out_tokens .scatter_add_ (
@@ -217,7 +283,10 @@ def ref_fn() -> None:
217283 ref_fn ()
218284
219285 torch .testing .assert_close (
220- test_out_tokens , ref_out_tokens , atol = 1e-3 , rtol = 1.6e-2
286+ test_out_tokens [:num_valid_tokens ],
287+ ref_out_tokens [:num_valid_tokens ],
288+ atol = 1e-3 ,
289+ rtol = 1.6e-2 ,
221290 )
222291
223292 @given (
0 commit comments