
Commit 6a8e573

Dhruva Kaushal authored and facebook-github-bot committed
Enabling cutlass
Summary: Enabling the cutlass kernel for fp8_grouped_gemm. The large sizes are running into a CUDA issue, so this diff temporarily removes them while I figure out a fix with the kernel writers.

Reviewed By: htyu

Differential Revision: D71371989

fbshipit-source-id: 534efcd2526aa6c901eb329518d519f22c51b493
1 parent 62d4b4f commit 6a8e573

File tree

1 file changed: +56 -27 lines changed


tritonbench/operators/fp8_gemm_grouped/operator.py

+56 -27
@@ -198,12 +198,15 @@ def parse_args(args: List[str]) -> argparse.Namespace:
         # (256, 256, 256),
         # (512, 512, 512),
         # (2048, 2048, 2048),
+        (1024, 1024, 1024),
+        (2048, 1024, 1024),
+        (2048, 2048, 2048),
         (4096, 4096, 4096),
         (8192, 4096, 4096),
-        (16384, 4096, 4096),
-        (8192, 8192, 8192),
-        (16384, 8192, 8192),
-        (16384, 16384, 16384),
+        # (16384, 4096, 4096),
+        # (8192, 8192, 8192),
+        # (16384, 8192, 8192),
+        # (16384, 16384, 16384),
         # (1, 2304, 2048),
         # (1, 8192, 16384),
         # (4, 4096, 2304),
@@ -285,16 +288,37 @@ def cumulative_sum_with_initial_offset(tensor):
     return cumsum


-# TODO: remove this.
-def reshape_tensor(W, m_sizes):
-    N = W.shape[0] // torch.sum(m_sizes)
-    return torch.cat(
-        [
-            x.reshape(-1, N, W.shape[-1])
-            for x in torch.split(W, [size * N for size in m_sizes], dim=0)
-        ],
-        dim=0,
-    )
+def reshape_tensor(input_tensor, m_sizes):
+    """
+    Reshape the input tensor into a specified grouped format.
+    This function takes an input tensor and reshapes it into a 3D tensor
+    with dimensions (G, N, K), where:
+    - G is the number of groups, determined by the length of m_sizes.
+    - N is the size of each group, calculated as the integer division of
+      the first dimension of the input tensor by G.
+    - K is the size of the second dimension of the input tensor.
+    Args:
+        input_tensor (torch.Tensor): The input tensor to be reshaped. It is
+            expected to have at least two dimensions.
+        m_sizes (list): A list whose length determines the number of groups (G).
+    Returns:
+        torch.Tensor: The reshaped tensor with dimensions (G, N, K).
+    Raises:
+        ValueError: If the size of the first dimension of input_tensor is not
+            divisible by the number of groups (G).
+    """
+    # Calculate the number of groups (G) based on the length of m_sizes
+    G = len(m_sizes)
+
+    # Calculate the size of each group (N) by dividing the first dimension of
+    # the input tensor by the number of groups (G)
+    N = input_tensor.size(0) // G
+
+    # Get the size of the second dimension (K) of the input tensor
+    K = input_tensor.size(1)
+    # Reshape the input tensor to have dimensions (G, N, K)
+    reshaped_tensor = input_tensor.view(G, N, K)
+    return reshaped_tensor


 class Operator(BenchmarkOperator):
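
Note (not part of the commit): a minimal sketch of what the new reshape_tensor does, with sizes invented for illustration. The stacked weight tensor of shape (G * N, K) is viewed as one (N, K) slice per group, i.e. (G, N, K).

import torch

# Hypothetical sizes, chosen only for this example.
m_sizes = [3, 1, 4, 2]                 # reshape_tensor only uses len(m_sizes) == G
G, N, K = len(m_sizes), 8, 16

stacked_B = torch.randn(G * N, K)      # group_B as the benchmark stacks it: (G * N, K)
grouped_B = stacked_B.view(G, N, K)    # what reshape_tensor(stacked_B, m_sizes) returns

# Group g corresponds to rows [g * N, (g + 1) * N) of the stacked tensor.
assert torch.equal(grouped_B[2], stacked_B[2 * N : 3 * N])
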
@@ -341,6 +365,11 @@ def __init__(
         # Enable CUDA graphs for this operator
         self.use_cuda_graphs = True

+        # Enable fp8_fast_accum by default. The cutlass kernel does not support configuring
+        # this parameter as of now. By default it is true, but there will be correctness issues
+        # vs the cutlass kernel, if fp8_fast_accum is turned off.
+        self.fp8_fast_accum = True
+
         # Parse the additional command-line arguments
         addmm_args = parse_args(self.extra_args)

@@ -387,10 +416,17 @@ def _triton(self, group_A, group_B, m_sizes, a_scale, b_scale) -> Callable:

         # Return a lambda function that calls the grouped_gemm_fp8_rowwise function
         return lambda: grouped_gemm_fp8_rowwise(
-            group_A, group_B, m_sizes, a_scale, b_scale
+            group_A,
+            group_B,
+            m_sizes,
+            a_scale,
+            b_scale,
+            use_fast_accum=self.fp8_fast_accum,
         )

-    @register_benchmark(enabled=False, label="ck" if torch.version.hip else "cutlass")
+    @register_benchmark(
+        enabled=HAS_CUTLASS_OR_CK, label="ck" if torch.version.hip else "cutlass"
+    )
     def _cutlass_or_ck(self, group_A, group_B, m_sizes, a_scale, b_scale) -> Callable:
         """
         Returns a lambda function that performs the Cutlass or CK FP8 GEMM grouped operation.
@@ -405,17 +441,16 @@ def _cutlass_or_ck(self, group_A, group_B, m_sizes, a_scale, b_scale) -> Callable:
         Returns:
             Callable: A lambda function that performs the Cutlass or CK FP8 GEMM grouped operation.
         """
-
-        # Calculate the cumulative sum of the group sizes with an initial offset
-        cum_sum = cumulative_sum_with_initial_offset(m_sizes).to(torch.int64)
+        # Reshape group_B to match the format expected by the cutlass implementation (G, N, K)
+        reshaped_group_B = reshape_tensor(group_B, m_sizes)

         # Return a lambda function that calls the cutlass_or_ck_fp8_grouped_mm function
         return lambda: cutlass_or_ck_fp8_grouped_mm(
             group_A,
-            group_B,
+            reshaped_group_B,
             a_scale,
             b_scale,
-            cum_sum,
+            m_sizes.to(torch.int64),
         )

     @register_x_val(label="(group_size, M, N, K)")
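
Note (not part of the commit): taken together with the _triton hunk above, the two benchmark paths now feed their kernels differently shaped inputs. The Triton path keeps group_B flat as (G * N, K) and passes m_sizes plus use_fast_accum; the cutlass/CK path passes the (G, N, K) view and m_sizes cast to int64. A rough sketch of the contrast follows; the dtypes, scale shapes, and problem sizes are placeholders, and the kernel calls are left commented out since only the argument order shown in the diff is known here.

import torch

# Placeholder problem sizes and inputs for illustration only.
G, M, N, K = 4, 1024, 1024, 1024
m_sizes = torch.full((G,), M, dtype=torch.int32)

group_A = torch.randn(G * M, K).to(torch.float8_e4m3fn)   # stacked activations
group_B = torch.randn(G * N, K).to(torch.float8_e4m3fn)   # stacked weights
a_scale = torch.rand(G * M)                                # rowwise scales (assumed layout)
b_scale = torch.rand(G * N)

# Triton path: flat weights, per-group sizes, fast-accum flag (see _triton above).
# out = grouped_gemm_fp8_rowwise(
#     group_A, group_B, m_sizes, a_scale, b_scale, use_fast_accum=True
# )

# Cutlass/CK path: weights viewed as (G, N, K), group sizes as int64 (see _cutlass_or_ck above).
# out = cutlass_or_ck_fp8_grouped_mm(
#     group_A, group_B.view(G, N, K), a_scale, b_scale, m_sizes.to(torch.int64)
# )
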
@@ -514,12 +549,6 @@ def get_input_iter(self) -> Generator:

         Yields:
             tuple: A tuple containing the input tensors and their corresponding scales.
-
-        Notes:
-            The current cutlass imp0lementation of f8f8bf16 grouped gemm has a different
-            input format than the triton implementation.
-            D69544396 will update the function signature to match the 2 implementations.
-            Disabling the cutlass implementation until it lands.
         """

         # Iterate over all possible group sizes and shapes
