@@ -198,12 +198,15 @@ def parse_args(args: List[str]) -> argparse.Namespace:
    # (256, 256, 256),
    # (512, 512, 512),
    # (2048, 2048, 2048),
+    (1024, 1024, 1024),
+    (2048, 1024, 1024),
+    (2048, 2048, 2048),
    (4096, 4096, 4096),
    (8192, 4096, 4096),
-    (16384, 4096, 4096),
-    (8192, 8192, 8192),
-    (16384, 8192, 8192),
-    (16384, 16384, 16384),
+    # (16384, 4096, 4096),
+    # (8192, 8192, 8192),
+    # (16384, 8192, 8192),
+    # (16384, 16384, 16384),
    # (1, 2304, 2048),
    # (1, 8192, 16384),
    # (4, 4096, 2304),
@@ -285,16 +288,37 @@ def cumulative_sum_with_initial_offset(tensor):
    return cumsum


-# TODO: remove this.
-def reshape_tensor(W, m_sizes):
-    N = W.shape[0] // torch.sum(m_sizes)
-    return torch.cat(
-        [
-            x.reshape(-1, N, W.shape[-1])
-            for x in torch.split(W, [size * N for size in m_sizes], dim=0)
-        ],
-        dim=0,
-    )
+def reshape_tensor(input_tensor, m_sizes):
+    """
+    Reshape the input tensor into a specified grouped format.
+    This function takes an input tensor and reshapes it into a 3D tensor
+    with dimensions (G, N, K), where:
+    - G is the number of groups, determined by the length of m_sizes.
+    - N is the size of each group, calculated as the integer division of
+      the first dimension of the input tensor by G.
+    - K is the size of the second dimension of the input tensor.
+    Args:
+        input_tensor (torch.Tensor): The input tensor to be reshaped. It is
+            expected to have at least two dimensions.
+        m_sizes (list): A list whose length determines the number of groups (G).
+    Returns:
+        torch.Tensor: The reshaped tensor with dimensions (G, N, K).
+    Raises:
+        ValueError: If the size of the first dimension of input_tensor is not
+            divisible by the number of groups (G).
+    """
+    # Calculate the number of groups (G) based on the length of m_sizes
+    G = len(m_sizes)
+
+    # Calculate the size of each group (N) by dividing the first dimension of
+    # the input tensor by the number of groups (G)
+    N = input_tensor.size(0) // G
+
+    # Get the size of the second dimension (K) of the input tensor
+    K = input_tensor.size(1)
+    # Reshape the input tensor to have dimensions (G, N, K)
+    reshaped_tensor = input_tensor.view(G, N, K)
+    return reshaped_tensor


class Operator(BenchmarkOperator):
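As a quick illustration of the new helper's contract, here is a minimal sketch; the shapes and values below are made up for illustration and are not taken from the benchmark:

import torch

def reshape_tensor(input_tensor, m_sizes):
    # Same logic as the helper added above: (G * N, K) -> (G, N, K).
    G = len(m_sizes)
    N = input_tensor.size(0) // G
    K = input_tensor.size(1)
    return input_tensor.view(G, N, K)

m_sizes = [16, 32, 64]                # only len(m_sizes) == G matters to the reshape
group_B = torch.randn(3 * 256, 512)   # stacked as (G * N, K) with G=3, N=256, K=512
reshaped = reshape_tensor(group_B, m_sizes)
print(reshaped.shape)                 # torch.Size([3, 256, 512])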
@@ -341,6 +365,11 @@ def __init__(
        # Enable CUDA graphs for this operator
        self.use_cuda_graphs = True

+        # Enable fp8_fast_accum by default. The cutlass kernel does not currently expose
+        # this parameter and always runs with fast accumulation, so turning fp8_fast_accum
+        # off here would produce results that do not match the cutlass kernel.
+        self.fp8_fast_accum = True
+
        # Parse the additional command-line arguments
        addmm_args = parse_args(self.extra_args)
@@ -387,10 +416,17 @@ def _triton(self, group_A, group_B, m_sizes, a_scale, b_scale) -> Callable:

        # Return a lambda function that calls the grouped_gemm_fp8_rowwise function
        return lambda: grouped_gemm_fp8_rowwise(
-            group_A, group_B, m_sizes, a_scale, b_scale
+            group_A,
+            group_B,
+            m_sizes,
+            a_scale,
+            b_scale,
+            use_fast_accum=self.fp8_fast_accum,
        )

-    @register_benchmark(enabled=False, label="ck" if torch.version.hip else "cutlass")
+    @register_benchmark(
+        enabled=HAS_CUTLASS_OR_CK, label="ck" if torch.version.hip else "cutlass"
+    )
    def _cutlass_or_ck(self, group_A, group_B, m_sizes, a_scale, b_scale) -> Callable:
        """
        Returns a lambda function that performs the Cutlass or CK FP8 GEMM grouped operation.
@@ -405,17 +441,16 @@ def _cutlass_or_ck(self, group_A, group_B, m_sizes, a_scale, b_scale) -> Callable:
        Returns:
            Callable: A lambda function that performs the Cutlass or CK FP8 GEMM grouped operation.
        """
-
-        # Calculate the cumulative sum of the group sizes with an initial offset
-        cum_sum = cumulative_sum_with_initial_offset(m_sizes).to(torch.int64)
+        # Reshape group_B to match the format expected by the cutlass implementation (G, N, K)
+        reshaped_group_B = reshape_tensor(group_B, m_sizes)

        # Return a lambda function that calls the cutlass_or_ck_fp8_grouped_mm function
        return lambda: cutlass_or_ck_fp8_grouped_mm(
            group_A,
-            group_B,
+            reshaped_group_B,
            a_scale,
            b_scale,
-            cum_sum,
+            m_sizes.to(torch.int64),
        )

    @register_x_val(label="(group_size, M, N, K)")
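For orientation, a rough sketch of the tensor layouts the two backends now receive; the shapes, dtypes, and group sizes below are illustrative assumptions, not values from the benchmark:

import torch

G, N, K = 4, 1024, 1024
m_sizes = torch.tensor([256, 512, 128, 1024], dtype=torch.int32)

# Triton path: group_B stays a stacked 2D (G * N, K) tensor, passed together with m_sizes
# (as in _triton above).
group_B_stacked = torch.randn(G * N, K)

# Cutlass/CK path: group_B is viewed as (G, N, K) and m_sizes is converted to int64
# (as in _cutlass_or_ck above).
group_B_grouped = group_B_stacked.view(G, N, K)
m_sizes_int64 = m_sizes.to(torch.int64)
print(group_B_grouped.shape, m_sizes_int64.dtype)  # torch.Size([4, 1024, 1024]) torch.int64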
@@ -514,12 +549,6 @@ def get_input_iter(self) -> Generator:

        Yields:
            tuple: A tuple containing the input tensors and their corresponding scales.
-
-        Notes:
-            The current cutlass implementation of f8f8bf16 grouped gemm has a different
-            input format than the triton implementation.
-            D69544396 will update the function signature to match the 2 implementations.
-            Disabling the cutlass implementation until it lands.
        """

        # Iterate over all possible group sizes and shapes