Commit c19bc88

all test cases working
Parent: c9d30b6

2 files changed, +27 -20 lines changed

torchao/prototype/grouped_mm/__init__.py (+10 -5)
@@ -230,10 +230,9 @@ def backward(ctx, grad_output: torch.Tensor):
         # - A_scale shape: (1,K) or (B, 1, K)
         # - torch._scaled_grouped_mm requires scales without any empty dims, so squeeze A_scale.
         # - A scale shape: (K,) or (B, K)
+        A_col_major = A.transpose(-2, -1).contiguous().transpose(-2, -1)
         A_fp8_col_major = hp_tensor_to_float8_dynamic(
-            A.transpose(-2, -1)
-            .contiguous()
-            .transpose(-2, -1),  # Convert to column-major
+            A_col_major,
             float8_config.cast_config_input.target_dtype,
             linear_mm_config=LinearMMConfig(),
             gemm_input_role=GemmInputRole.INPUT,
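
A note on the refactor in the hunk above: the transpose -> contiguous -> transpose chain hoisted into A_col_major does not reorder values; it returns a tensor with the same shape and contents but column-major strides. A minimal standalone sketch (not part of the commit) showing the resulting layout:

    import torch

    # The double-transpose trick: same shape, same values, but
    # column-major (Fortran-order) strides after the round trip.
    A = torch.randn(4, 8)
    A_col_major = A.transpose(-2, -1).contiguous().transpose(-2, -1)
    print(A.stride())            # (8, 1)  row-major
    print(A_col_major.stride())  # (1, 4)  column-major
    assert torch.equal(A, A_col_major)  # values unchanged, only layout differs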
@@ -243,11 +242,17 @@ def backward(ctx, grad_output: torch.Tensor):
         )
         A_scale = A_fp8_col_major._scale.squeeze()
 
+        # Special case: for a 2D-2D grouped GEMM, the scales must be repeated by the
+        # number of groups, which is the size of the `offs` tensor.
+        if grad_output_t_fp8_row_major.ndim == 2 and A_fp8_col_major.ndim == 2:
+            grad_output_t_scale = grad_output_t_scale.repeat(offs.numel())
+            A_scale = A_scale.repeat(offs.numel())
+
         # Compute grad_B = grad_output_t @ A.
         #
-        # Case 1: A=2D, B=3D with A=(M,K), B^T=(B,K,N) case, output=(B,M,N)
+        # Case 1: A=2D, B=3D with A=(M,K), B^T=(B,K,N) case, output=(M,N) <-- special case, B reduced?
         # grad_B = grad_output_t @ A
-        # grad_B = (B,N,M) @ (B,M,K) = (B,N,K)
+        # grad_B = (N,M) @ (M,K) = (N,K) <-- do we need to repeat along dim0 so it's (B,N,K)?
         #
         # Case 2: A=3D, B=2D with A=(B,M,K), B^T=(K,N) case, output=(B,M,N)
         # grad_B = grad_output_t @ A
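
On the 2D-2D special case added above: the commit's reading is that each of the offs.numel() groups reuses the same squeezed scale vector, so .repeat() concatenates one copy per group. A rough sketch of just that shape arithmetic, under the assumption (not stated in the commit) that the squeezed scale is 1D of length N and group boundaries along K are delimited by offs:

    import torch

    # Assumed shapes for illustration only: a 1D row-wise scale vector and
    # an offs tensor marking n_groups group boundaries along K.
    n, k, n_groups = 16, 16, 4
    offs = torch.arange(k, n_groups * k + 1, k, dtype=torch.int32)  # (4,)
    scale = torch.rand(n)                         # squeezed scale, shape (N,)
    scale_per_group = scale.repeat(offs.numel())  # shape (N * n_groups,)
    assert scale_per_group.numel() == n * offs.numel()  # one copy per group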

torchao/prototype/grouped_mm/test_grouped_mm.py (+17 -15)
@@ -23,10 +23,12 @@ def test_grouped_gemm_2d_3d(use_fast_accum, strided):
     device = "cuda"
     s_int = int(strided)
     m, n, k, n_groups = 16, 32, 16, 4
-    a = torch.randn(m * n_groups, k * (1 + s_int), device=device, requires_grad=True)[:, :k]
-    b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device, requires_grad=True)[
-        :: (1 + s_int), :, :k
+    a = torch.randn(m * n_groups, k * (1 + s_int), device=device, requires_grad=True)[
+        :, :k
     ]
+    b = torch.randn(
+        n_groups * (1 + s_int), n, k * (1 + s_int), device=device, requires_grad=True
+    )[:: (1 + s_int), :, :k]
     offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
     result = _grouped_scaled_mm(
         a,
@@ -62,12 +64,12 @@ def test_grouped_gemm_3d_3d(use_fast_accum, strided):
     device = "cuda"
     s_int = int(strided)
     m, n, k, n_groups = 16, 32, 16, 4
-    a = torch.randn(n_groups * (1 + s_int), m, k * (1 + s_int), device=device, requires_grad=True)[
-        :: (1 + s_int), :, :k
-    ]
-    b = torch.randn(n_groups * (1 + s_int), n, k * (1 + s_int), device=device, requires_grad=True)[
-        :: (1 + s_int), :, :k
-    ]
+    a = torch.randn(
+        n_groups * (1 + s_int), m, k * (1 + s_int), device=device, requires_grad=True
+    )[:: (1 + s_int), :, :k]
+    b = torch.randn(
+        n_groups * (1 + s_int), n, k * (1 + s_int), device=device, requires_grad=True
+    )[:: (1 + s_int), :, :k]
     result = _grouped_scaled_mm(
         a,
         b.transpose(-2, -1),
@@ -99,12 +101,12 @@ def test_grouped_gemm_2d_2d(use_fast_accum, strided):
     out_dtype = torch.bfloat16
     device = "cuda"
     m, n, k, n_groups = 16, 16, 16, 4  # all sizes have to be divisible by 16
-    a = torch.randn(m, k * n_groups + k * int(strided), device=device, requires_grad=True)[
-        :, : k * n_groups
-    ]
-    b = torch.randn(n, k * n_groups + k * int(strided), device=device, requires_grad=True)[
-        :, : k * n_groups
-    ]
+    a = torch.randn(
+        m, k * n_groups + k * int(strided), device=device, requires_grad=True
+    )[:, : k * n_groups]
+    b = torch.randn(
+        n, k * n_groups + k * int(strided), device=device, requires_grad=True
+    )[:, : k * n_groups]
     offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)
 
     # Compute result.
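
These tests share one construction worth noting: the strided variants over-allocate one dimension by a factor of (1 + s_int) and then slice back down to the target size. The slice keeps the parent tensor's strides, so strided=True exercises the grouped GEMM on non-contiguous inputs. A small standalone check of that construction (names local to this sketch, CPU-only for brevity):

    import torch

    # Over-allocate columns, then slice: the view keeps the parent's
    # strides, so the result is non-contiguous.
    m, k, n_groups, s_int = 16, 16, 4, 1  # s_int = int(strided)
    a = torch.randn(m * n_groups, k * (1 + s_int))[:, :k]
    print(a.shape)            # torch.Size([64, 16])
    print(a.stride())         # (32, 1) -- row stride spans the full allocation
    print(a.is_contiguous())  # False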
