@@ -58,24 +58,28 @@ def forward(
             float8_recipe_name == Float8LinearRecipeName.ROWWISE
         ), "Only rowwise scaling is supported by torch._scaled_grouped_mm."
 
-
         assert 2 <= A.ndim <= 3, "A must be 2D or 3D"
         assert 2 <= B.ndim <= 3, "B must be 2D or 3D"
 
         # Dim 1 of B must match the final dim of A.
-        assert A.size(-1) == B.size(-2), f"shape {A.shape} and {B.shape} are not compatible for _scaled_grouped_mm"
+        assert A.size(-1) == B.size(
+            -2
+        ), f"shape {A.shape} and {B.shape} are not compatible for _scaled_grouped_mm"
 
         # offsets are required for 2D A tensor, otherwise it should be None.
         if A.ndim == 2 or B.ndim == 2:
             assert offs is not None, "offs must be specified for 2D tensor"
-        else:
-            assert offs is None, "offs must not be specified for 3D tensor"
 
         # TODO: pad dims to be multiples of 16, as required by torch._scaled_grouped_mm.
 
         # Fetch float8 config from specified recipe name.
         float8_config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
 
+        # Store what we need for backward.
+        ctx.save_for_backward(A, B)
+        ctx.float8_config = float8_config
+        ctx.offs = offs
+
         # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
         # A shape: (M, K) or (B, M, K)
         # A_scale shape: (M,1) or (B, M, 1)
@@ -94,7 +98,7 @@ def forward(
         )
         A_scale = A_fp8_row_major._scale.squeeze()
 
-        # Convert B to float8, column-major for right operand of grouped GEMM.
+        # Convert B to float8, column-major for right operand of grouped GEMM.
         # B shape: (K,N) or (B, K, N)
         # B scales must be computed rowwise keeping the outer/final dim, so:
         # B_scale shape: (1,N) or (B, 1, N)
@@ -113,10 +117,11 @@ def forward(
         )
         B_scale = B_fp8_col_major._scale.squeeze()
 
-        # Store what we need for backward.
-        ctx.save_for_backward(A, B)
-        ctx.float8_config = float8_config
-        ctx.offs = offs
+        # Special case: 2D-2D grouped GEMM, the scales must be multiplied by the number of groups,
+        # which is the size of the `offs` tensor.
+        if A.ndim == 2 and B.ndim == 2:
+            A_scale = A_scale.repeat(offs.numel())
+            B_scale = B_scale.repeat(offs.numel())
 
         # Perform scaled grouped GEMM and return result.
         # output shape: (M, N) or (B, M, N)
@@ -156,7 +161,6 @@ def backward(ctx, grad_output: torch.Tensor):
         )
         grad_output_scale = grad_output_fp8_row_major._scale.squeeze()
 
-
         # Convert B to non-transposed, float8, column-major for right operand of grouped GEMM
         # needed for grad_A: grad_output @ B.
         # Since B was transposed before entry to forward, we need to transpose it back here for this.
@@ -167,7 +171,7 @@ def backward(ctx, grad_output: torch.Tensor):
         # - B_scale shape: (1,N) or (B, 1, N)
         # - torch._scaled_grouped_mm requires scales without any empty dims, so squeeze A_scale.
         # - B scale shape: (N,) or (B, N)
-        B_fp8_col_major = hp_tensor_to_float8_dynamic(
+        B_fp8_col_major = hp_tensor_to_float8_dynamic(
             B_non_transposed,
             float8_config.cast_config_input.target_dtype,
             linear_mm_config=LinearMMConfig(),
@@ -180,7 +184,7 @@ def backward(ctx, grad_output: torch.Tensor):
         )
         B_scale = B_fp8_col_major._scale.squeeze()
 
-        # Compute grad_A.
+        # Compute grad_A.
         #
         # Case 1: A=2D, B=3D with A=(M,K), B^T=(B,K,N), output=(B,M,N)
         # grad_A = grad_output @ B
@@ -237,7 +241,9 @@ def backward(ctx, grad_output: torch.Tensor):
         # - torch._scaled_grouped_mm requires scales without any empty dims, so squeeze A_scale.
         # - A scale shape: (K,) or (B, K)
         A_fp8_col_major = hp_tensor_to_float8_dynamic(
-            A.transpose(-2, -1).contiguous().tranpose(-2, -1),  # Convert to column-major
+            A.transpose(-2, -1)
+            .contiguous()
+            .transpose(-2, -1),  # Convert to column-major
             float8_config.cast_config_input.target_dtype,
             linear_mm_config=LinearMMConfig(),
             gemm_input_role=GemmInputRole.INPUT,
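As a side note on the `# Convert to column-major` chain in the last hunk (`A.transpose(-2, -1).contiguous().transpose(-2, -1)`), below is a minimal standalone sketch of what that idiom does, using only standard PyTorch tensor ops; the shapes are illustrative, not taken from the PR.

import torch

# Start from a row-major 2D tensor: shape (4, 8), strides (8, 1).
A = torch.randn(4, 8)

# transpose -> view with shape (8, 4); contiguous -> materialize that layout in memory;
# transpose back -> same shape (4, 8), but now with column-major strides (1, 4).
A_col_major = A.transpose(-2, -1).contiguous().transpose(-2, -1)

print(A.shape, A.stride())                      # torch.Size([4, 8]) (8, 1)
print(A_col_major.shape, A_col_major.stride())  # torch.Size([4, 8]) (1, 4)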
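Similarly, for the 2D-2D special case added in forward() (where `A_scale` and `B_scale` are repeated by `offs.numel()`), here is a small sketch of the shape effect of that repeat. The group sizes and the reading of `offs` as cumulative group-end offsets are illustrative assumptions, not taken from the diff.

import torch

# Hypothetical: three groups, with `offs` assumed to hold cumulative end offsets.
offs = torch.tensor([128, 256, 384], dtype=torch.int32)
n_groups = offs.numel()  # 3

# Rowwise scales for a 2D operand: one scale per row, shape (M,).
M = 384
A_scale = torch.rand(M)

# The 2D-2D branch replicates the scale vector once per group,
# so its length is multiplied by the number of groups.
A_scale_repeated = A_scale.repeat(n_groups)
print(A_scale.shape, A_scale_repeated.shape)  # torch.Size([384]) torch.Size([1152])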