@@ -58,24 +58,28 @@ def forward(
             float8_recipe_name == Float8LinearRecipeName.ROWWISE
         ), "Only rowwise scaling is supported by torch._scaled_grouped_mm."

-
         assert 2 <= A.ndim <= 3, "A must be 2D or 3D"
         assert 2 <= B.ndim <= 3, "B must be 2D or 3D"

         # Dim 1 of B must match the final dim of A.
-        assert A.size(-1) == B.size(-2), f"shape {A.shape} and {B.shape} are not compatible for _scaled_grouped_mm"
+        assert A.size(-1) == B.size(
+            -2
+        ), f"shape {A.shape} and {B.shape} are not compatible for _scaled_grouped_mm"

         # offsets are required for 2D A tensor, otherwise it should be None.
         if A.ndim == 2 or B.ndim == 2:
             assert offs is not None, "offs must be specified for 2D tensor"
-        else:
-            assert offs is None, "offs must not be specified for 3D tensor"

         # TODO: pad dims to be multiples of 16, as required by torch._scaled_grouped_mm.

         # Fetch float8 config from specified recipe name.
         float8_config = Float8LinearConfig.from_recipe_name(float8_recipe_name)

+        # Store what we need for backward.
+        ctx.save_for_backward(A, B)
+        ctx.float8_config = float8_config
+        ctx.offs = offs
+
         # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
         # A shape: (M, K) or (B, M, K)
         # A_scale shape: (M,1) or (B, M, 1)
@@ -87,14 +91,12 @@ def forward(
             linear_mm_config=LinearMMConfig(),
             gemm_input_role=GemmInputRole.INPUT,
             scaling_granularity=float8_config.cast_config_input.scaling_granularity,
-            axiswise_dim=get_maybe_axiswise_dim(
-                -1, float8_config.cast_config_input.scaling_granularity
-            ),
+            axiswise_dim=-1,
             round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
         )
         A_scale = A_fp8_row_major._scale.squeeze()

-        # Convert B to float8, column-major for right operand of grouped GEMM.
+        # Convert B to float8, column-major for right operand of grouped GEMM.
         # B shape: (K,N) or (B, K, N)
         # B scales must be computed rowwise keeping the outer/final dim, so:
         # B_scale shape: (1,N) or (B, 1, N)
@@ -106,17 +108,16 @@ def forward(
             linear_mm_config=LinearMMConfig(),
             gemm_input_role=GemmInputRole.WEIGHT,
             scaling_granularity=float8_config.cast_config_weight.scaling_granularity,
-            axiswise_dim=get_maybe_axiswise_dim(
-                -2, float8_config.cast_config_input.scaling_granularity
-            ),
+            axiswise_dim=-2,
             round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
         )
         B_scale = B_fp8_col_major._scale.squeeze()

-        # Store what we need for backward.
-        ctx.save_for_backward(A, B)
-        ctx.float8_config = float8_config
-        ctx.offs = offs
+        # Special case: 2D-2D grouped GEMM, the scales must be multiplied by the number of groups,
+        # which is the size of the `offs` tensor.
+        if A.ndim == 2 and B.ndim == 2:
+            A_scale = A_scale.repeat(offs.numel())
+            B_scale = B_scale.repeat(offs.numel())

         # Perform scaled grouped GEMM and return result.
         # output shape: (M, N) or (B, M, N)
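The 2D-2D branch added in the hunk above repeats each rowwise scale once per group before the grouped GEMM is called. Below is a minimal sketch of that bookkeeping using only public torch ops; the sizes, `offs` values, and random scales are illustrative assumptions, not values from this PR.

```python
import torch

# Assumed sizes for illustration only.
M, N, num_groups = 16, 64, 4
A_scale = torch.rand(M)  # rowwise scale of A after .squeeze(): shape (M,)
B_scale = torch.rand(N)  # rowwise scale of B after .squeeze(): shape (N,)
offs = torch.tensor([8, 16, 24, 32])  # cumulative group boundaries; offs.numel() == num_groups

# For the 2D-2D case the kernel expects one copy of each scale per group,
# so the scales are tiled offs.numel() times.
A_scale = A_scale.repeat(offs.numel())  # shape (num_groups * M,)
B_scale = B_scale.repeat(offs.numel())  # shape (num_groups * N,)
assert A_scale.numel() == num_groups * M and B_scale.numel() == num_groups * N
```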
@@ -149,14 +150,11 @@ def backward(ctx, grad_output: torch.Tensor):
             linear_mm_config=LinearMMConfig(),
             gemm_input_role=GemmInputRole.GRAD_OUTPUT,
             scaling_granularity=float8_config.cast_config_grad_output.scaling_granularity,
-            axiswise_dim=get_maybe_axiswise_dim(
-                -1, float8_config.cast_config_grad_output.scaling_granularity
-            ),
+            axiswise_dim=-1,
             round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
         )
         grad_output_scale = grad_output_fp8_row_major._scale.squeeze()

-
         # Convert B to non-transposed, float8, column-major for right operand of grouped GEMM
         # needed for grad_A: grad_output @ B.
         # Since B was transposed before entry to forward, we need to transpose it back here for this.
@@ -167,20 +165,18 @@ def backward(ctx, grad_output: torch.Tensor):
         # - B_scale shape: (1,N) or (B, 1, N)
         # - torch._scaled_grouped_mm requires scales without any empty dims, so squeeze A_scale.
         # - B scale shape: (N,) or (B, N)
-        B_fp8_col_major = hp_tensor_to_float8_dynamic(
+        B_fp8_col_major = hp_tensor_to_float8_dynamic(
             B_non_transposed,
             float8_config.cast_config_input.target_dtype,
             linear_mm_config=LinearMMConfig(),
             gemm_input_role=GemmInputRole.WEIGHT,
             scaling_granularity=float8_config.cast_config_weight.scaling_granularity,
-            axiswise_dim=get_maybe_axiswise_dim(
-                -2, float8_config.cast_config_input.scaling_granularity
-            ),
+            axiswise_dim=-2,
             round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
         )
         B_scale = B_fp8_col_major._scale.squeeze()

-        # Compute grad_A.
+        # Compute grad_A.
         #
         # Case 1: A=2D, B=3D with A=(M,K), B^T=(B,K,N), output=(B,M,N)
         # grad_A = grad_output @ B
@@ -221,9 +217,7 @@ def backward(ctx, grad_output: torch.Tensor):
             linear_mm_config=LinearMMConfig(),
             gemm_input_role=GemmInputRole.GRAD_OUTPUT,
             scaling_granularity=float8_config.cast_config_grad_output.scaling_granularity,
-            axiswise_dim=get_maybe_axiswise_dim(
-                -1, float8_config.cast_config_grad_output.scaling_granularity
-            ),
+            axiswise_dim=-1,
             round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
         )
         grad_output_t_scale = grad_output_t_fp8_row_major._scale.squeeze()
@@ -237,14 +231,14 @@ def backward(ctx, grad_output: torch.Tensor):
         # - torch._scaled_grouped_mm requires scales without any empty dims, so squeeze A_scale.
         # - A scale shape: (K,) or (B, K)
         A_fp8_col_major = hp_tensor_to_float8_dynamic(
-            A.transpose(-2, -1).contiguous().transpose(-2, -1),  # Convert to column-major
+            A.transpose(-2, -1)
+            .contiguous()
+            .transpose(-2, -1),  # Convert to column-major
             float8_config.cast_config_input.target_dtype,
             linear_mm_config=LinearMMConfig(),
             gemm_input_role=GemmInputRole.INPUT,
             scaling_granularity=float8_config.cast_config_input.scaling_granularity,
-            axiswise_dim=get_maybe_axiswise_dim(
-                -2, float8_config.cast_config_input.scaling_granularity
-            ),
+            axiswise_dim=-2,
             round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
         )
         A_scale = A_fp8_col_major._scale.squeeze()
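Several comments in this diff note that rowwise scaling leaves a singleton dim on each scale ((M, 1) for a row-major operand, (1, N) for a column-major one) while torch._scaled_grouped_mm wants scales with no empty dims, hence the repeated `._scale.squeeze()` calls. A small sketch of just that shape handling, with assumed sizes and a plain amax standing in for the actual float8 scale computation:

```python
import torch

# Assumed sizes for illustration only.
M, K, N = 16, 32, 64
A = torch.randn(M, K)  # left operand
B = torch.randn(K, N)  # right operand (memory layout ignored in this sketch)

# Rowwise (axiswise) reduction: keep the reduced dim so the scale broadcasts
# against the original tensor during quantization.
A_scale = A.abs().amax(dim=-1, keepdim=True)  # shape (M, 1)
B_scale = B.abs().amax(dim=-2, keepdim=True)  # shape (1, N)

# torch._scaled_grouped_mm requires scales without empty dims, so squeeze them.
A_scale, B_scale = A_scale.squeeze(), B_scale.squeeze()  # shapes (M,) and (N,)
print(A_scale.shape, B_scale.shape)
```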