@@ -76,12 +76,12 @@ def forward(
        # Fetch float8 config from specified recipe name.
        float8_config = Float8LinearConfig.from_recipe_name(float8_recipe_name)

-        # Convert high precision input tensor to float8.
+        # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
        # A shape: (M, K) or (B, M, K)
        # A_scale shape: (M, 1) or (B, M, 1)
        # torch._scaled_grouped_mm requires scales without any empty dims, so squeeze A_scale.
        # A_scale shape: (M,) or (B, M)
-        A_fp8 = hp_tensor_to_float8_dynamic(
+        A_fp8_row_major = hp_tensor_to_float8_dynamic(
            A,
            float8_config.cast_config_input.target_dtype,
            linear_mm_config=LinearMMConfig(),
@@ -92,15 +92,15 @@ def forward(
            ),
            round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
        )
-        A_fp8_scale = A_fp8._scale.squeeze()
+        A_scale = A_fp8_row_major._scale.squeeze()

-        # Convert high precision weight tensor to float8.
+        # Convert B to float8, column-major for right operand of grouped GEMM.
        # B shape: (K, N) or (B, K, N)
        # B scales must be computed rowwise keeping the outer/final dim, so:
        # B_scale shape: (1, N) or (B, 1, N)
        # torch._scaled_grouped_mm requires scales without any empty dims, so squeeze B_scale.
        # B_scale shape: (N,) or (B, N)
-        B_fp8 = hp_tensor_to_float8_dynamic(
+        B_fp8_col_major = hp_tensor_to_float8_dynamic(
            B,
            float8_config.cast_config_input.target_dtype,
            linear_mm_config=LinearMMConfig(),
@@ -111,23 +111,129 @@ def forward(
            ),
            round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
        )
+        B_scale = B_fp8_col_major._scale.squeeze()

        # Store what we need for backward.
        ctx.save_for_backward(A, B)
-        ctx.float_config = float8_config
+        ctx.float8_config = float8_config
        ctx.offs = offs

        # Perform scaled grouped GEMM and return result.
+        # output shape: (M, N) or (B, M, N)
        return torch._scaled_grouped_mm(
-            A_fp8._data,
-            B_fp8._data,
-            A_fp8._scale,
-            B_fp8._scale,
+            A_fp8_row_major._data,
+            B_fp8_col_major._data,
+            A_scale,
+            B_scale,
            offs,
            out_dtype=out_dtype,
            use_fast_accum=use_fast_accum,
        )

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
-        return None, None, None, None, None, None
+        A, B = ctx.saved_tensors
+        offs = ctx.offs
+        float8_config = ctx.float8_config
+
+        # Convert grad_output to float8, row-major for left operand of grouped GEMM.
+        # grad_output shape: (M, N) or (B, M, N)
+        # grad_output_scale shape: (M, 1) or (B, M, 1)
+        # squeeze grad_output_scale to remove empty dim, as required by torch._scaled_grouped_mm.
+        # grad_output_scale shape: (M,) or (B, M)
+        grad_output_fp8_row_major = hp_tensor_to_float8_dynamic(
+            grad_output,
+            float8_config.cast_config_grad_output.target_dtype,
+            linear_mm_config=LinearMMConfig(),
+            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
+            scaling_granularity=float8_config.cast_config_grad_output.scaling_granularity,
+            axiswise_dim=get_maybe_axiswise_dim(
+                -1, float8_config.cast_config_grad_output.scaling_granularity
+            ),
+            round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
+        )
+        grad_output_scale = grad_output_fp8_row_major._scale.squeeze()
+
+        # Convert B to float8, column-major for right operand of grouped GEMM.
+        # B shape: (K, N) or (B, K, N)
+        # B scales must be computed rowwise keeping the outer/final dim, so:
+        # B_scale shape: (1, N) or (B, 1, N)
+        # torch._scaled_grouped_mm requires scales without any empty dims, so squeeze B_scale.
+        # B_scale shape: (N,) or (B, N)
+        B_fp8_col_major = hp_tensor_to_float8_dynamic(
+            B,
+            float8_config.cast_config_input.target_dtype,
+            linear_mm_config=LinearMMConfig(),
+            gemm_input_role=GemmInputRole.WEIGHT,
+            scaling_granularity=float8_config.cast_config_weight.scaling_granularity,
+            axiswise_dim=get_maybe_axiswise_dim(
+                1, float8_config.cast_config_input.scaling_granularity
+            ),
+            round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
+        )
+        B_scale = B_fp8_col_major._scale.squeeze()
+
+        # grad_input = grad_output @ weight
+        # grad_A = grad_output @ B
+        grad_A = torch._scaled_grouped_mm(
+            grad_output_fp8_row_major._data,
+            B_fp8_col_major._data,
+            grad_output_scale,
+            B_scale,
+            offs,
+            out_dtype=grad_output.dtype,
+            use_fast_accum=False,
+        )
+
+        # Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM.
+        grad_output_t = grad_output.transpose(-2, -1)
+
+        # grad_output_t shape: (N, M) or (B, N, M)
+        # grad_output_t_scale shape: (N, 1) or (B, N, 1)
+        # squeeze grad_output_t_scale to remove empty dim, as required by torch._scaled_grouped_mm.
+        # grad_output_t_scale shape: (N,) or (B, N)
+        grad_output_t_fp8 = hp_tensor_to_float8_dynamic(
+            grad_output_t,
+            float8_config.cast_config_grad_output.target_dtype,
+            linear_mm_config=LinearMMConfig(),
+            gemm_input_role=GemmInputRole.GRAD_OUTPUT,
+            scaling_granularity=float8_config.cast_config_grad_output.scaling_granularity,
+            axiswise_dim=get_maybe_axiswise_dim(
+                -1, float8_config.cast_config_grad_output.scaling_granularity
+            ),
+            round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
+        )
+        grad_output_t_scale = grad_output_t_fp8._scale.squeeze()
+
+        # Convert A to float8, column-major for right operand of grouped GEMM.
+        # A shape: (M, K) or (B, M, K)
+        # A scales must be computed rowwise keeping the outer/final dim, for right operand in grouped GEMM, so:
+        # A_scale shape: (1, K) or (B, 1, K)
+        # torch._scaled_grouped_mm requires scales without any empty dims, so squeeze A_scale.
+        # A_scale shape: (K,) or (B, K)
+        A_fp8 = hp_tensor_to_float8_dynamic(
+            A,
+            float8_config.cast_config_input.target_dtype,
+            linear_mm_config=LinearMMConfig(),
+            gemm_input_role=GemmInputRole.INPUT,
+            scaling_granularity=float8_config.cast_config_input.scaling_granularity,
+            axiswise_dim=get_maybe_axiswise_dim(
+                1, float8_config.cast_config_input.scaling_granularity
+            ),
+            round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
+        )
+        A_scale = A_fp8._scale.squeeze()
+
+        # grad_weight = grad_output_t @ input
+        # grad_B = grad_output_t @ A
+        grad_B = torch._scaled_grouped_mm(
+            grad_output_t_fp8._data,
+            A_fp8._data,
+            grad_output_t_scale,
+            A_scale,
+            offs,
+            out_dtype=grad_output.dtype,
+            use_fast_accum=False,
+        )
+
+        return grad_A, grad_B, None, None, None, None
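
A minimal smoke-test sketch for exercising the new backward pass (not part of the diff). It assumes the autograd function class is importable and passed in as `float8_grouped_mm_cls`, that its `forward` takes `(A, B, float8_recipe_name, offs, out_dtype, use_fast_accum)` in that order (mirroring the six gradients returned above), and that `"rowwise"` is a recipe name accepted by `Float8LinearConfig.from_recipe_name`; adjust names, import path, and argument order to the actual module.

```python
import torch

def smoke_test(float8_grouped_mm_cls):
    """Run one forward/backward pass through the (assumed) autograd function."""
    device = "cuda"
    M, K, N, n_groups = 256, 512, 1024, 4

    # Token activations (M, K) and stacked per-group weights (n_groups, K, N),
    # matching the "A: (M, K)" / "B: (B, K, N)" shapes in the comments above.
    A = torch.randn(M, K, device=device, dtype=torch.bfloat16, requires_grad=True)
    B = torch.randn(n_groups, K, N, device=device, dtype=torch.bfloat16, requires_grad=True)

    # Cumulative end offset of each token group along M, int32 as
    # torch._scaled_grouped_mm expects.
    group = M // n_groups
    offs = torch.arange(group, M + 1, group, device=device, dtype=torch.int32)

    # Assumed argument order: (A, B, float8_recipe_name, offs, out_dtype, use_fast_accum).
    out = float8_grouped_mm_cls.apply(A, B, "rowwise", offs, torch.bfloat16, False)
    out.sum().backward()
    assert A.grad is not None and B.grad is not None
```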