Commit 7042cef

all test cases working
1 parent c9d30b6 commit 7042cef

1 file changed: +8 -2 lines changed

torchao/prototype/grouped_mm/__init__.py

+8 -2
@@ -243,11 +243,17 @@ def backward(ctx, grad_output: torch.Tensor):
         )
         A_scale = A_fp8_col_major._scale.squeeze()
 
+        # Special case: 2D-2D grouped GEMM, the scales must be multiplied by the number of groups,
+        # which is the size of the `offs` tensor.
+        if grad_output_t_fp8_row_major.ndim == 2 and A_fp8_col_major.ndim == 2:
+            grad_output_t_scale = grad_output_t_scale.repeat(offs.numel())
+            A_scale = A_scale.repeat(offs.numel())
+
         # Compute grad_B = grad_output_t @ A.
         #
-        # Case 1: A=2D, B=3D with A=(M,K), B^T=(B,K,N) case, output=(B,M,N)
+        # Case 1: A=2D, B=3D with A=(M,K), B^T=(B,K,N) case, output=(M,N) <-- special case, B reduced?
         # grad_B = grad_output_t @ A
-        # grad_B = (B,N,M) @ (B,M,K) = (B,N,K)
+        # grad_B = (N,M) @ (M,K) = (N,K) <-- do we need to repeat along dim0 so it's (B,N,K)?
         #
         # Case 2: A=3D, B=2D with A=(B,M,K), B^T=(K,N) case, output=(B,M,N)
         # grad_B = grad_output_t @ A
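
Note on the 2D-2D special case: the sketch below is a dependency-free illustration, not the torchao implementation. It assumes that `offs` holds the cumulative end index of each group along the shared M dimension, and that a 2D-2D grouped GEMM produces one (N,K) block per group; neither assumption is confirmed by this diff. The helpers `tile_1d` and `grouped_mm_2d_2d` are hypothetical names. `tile_1d` mimics what `Tensor.repeat(n)` does to a 1-D tensor: it tiles the whole vector n times, it does not multiply the values.

```python
from typing import List

Matrix = List[List[float]]


def tile_1d(values: List[float], times: int) -> List[float]:
    """Mimic torch.Tensor.repeat(times) on a 1-D tensor: tile the whole vector."""
    return values * times


def grouped_mm_2d_2d(lhs: Matrix, rhs: Matrix, offs: List[int]) -> List[Matrix]:
    """Reference 2D-2D grouped GEMM (assumed semantics, see note above).

    lhs is (N, M), rhs is (M, K). Each entry of `offs` is the exclusive end
    index of one group along M. Returns one (N, K) matrix per group, so the
    result has shape (G, N, K) with G = len(offs).
    """
    n, k = len(lhs), len(rhs[0])
    out: List[Matrix] = []
    start = 0
    for end in offs:
        # Contract only over this group's slice of the shared M dimension.
        block = [
            [sum(lhs[i][m] * rhs[m][j] for m in range(start, end)) for j in range(k)]
            for i in range(n)
        ]
        out.append(block)
        start = end
    return out


scales = [0.5, 2.0]
offs = [1, 3]  # two groups: M indices [0, 1) and [1, 3)
tiled = tile_1d(scales, len(offs))  # one copy of the scale vector per group

grad_output_t = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # (N=2, M=3)
A = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]            # (M=3, K=2)
grad_B = grouped_mm_2d_2d(grad_output_t, A, offs)   # (G=2, N=2, K=2)
```

Under these assumptions, no explicit repeat along dim 0 is needed to reach a (B,N,K) result: the group dimension comes from splitting M by `offs`, and the tiled scale vector has one (N,) or (K,) segment per group.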
