@@ -243,11 +243,17 @@ def backward(ctx, grad_output: torch.Tensor):
         )
         A_scale = A_fp8_col_major._scale.squeeze()
 
+        # Special case: for a 2D-2D grouped GEMM, the scales must be repeated by the number of groups,
+        # which is the size of the `offs` tensor.
+        if grad_output_t_fp8_row_major.ndim == 2 and A_fp8_col_major.ndim == 2:
+            grad_output_t_scale = grad_output_t_scale.repeat(offs.numel())
+            A_scale = A_scale.repeat(offs.numel())
+
         # Compute grad_B = grad_output_t @ A.
         #
-        # Case 1: A=2D, B=3D with A=(M,K), B^T=(B,K,N) case, output=(B,M,N)
+        # Case 1: A=2D, B=3D with A=(M,K), B^T=(B,K,N) case, output=(M,N) <-- special case, B reduced?
         # grad_B = grad_output_t @ A
-        # grad_B = (B,N,M) @ (B,M,K) = (B,N,K)
+        # grad_B = (N,M) @ (M,K) = (N,K) <-- do we need to repeat along dim0 so it's (B,N,K)?
         #
         # Case 2: A=3D, B=2D with A=(B,M,K), B^T=(K,N) case, output=(B,M,N)
         # grad_B = grad_output_t @ A
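A minimal sketch of the scale handling in the 2D-2D special case above (illustrative values only, not part of the PR): with tensorwise scaling each operand carries a single scale, and repeat(offs.numel()) replicates it so there is one scale entry per group defined by `offs`.

import torch

# Hypothetical example values; `offs` marks group offsets, so offs.numel() is the group count.
offs = torch.tensor([16, 32, 48])            # 3 groups
A_scale = torch.tensor([0.02])               # single tensorwise scale for A
grad_output_t_scale = torch.tensor([0.5])    # single tensorwise scale for grad_output_t

# Repeat each scalar scale once per group, mirroring the special case in the diff above.
A_scale = A_scale.repeat(offs.numel())                          # tensor([0.0200, 0.0200, 0.0200])
grad_output_t_scale = grad_output_t_scale.repeat(offs.numel())  # tensor([0.5000, 0.5000, 0.5000])

print(A_scale.shape, grad_output_t_scale.shape)                 # torch.Size([3]) torch.Size([3])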