1414 triton_fp8_per_group_colwise_scales ,
1515 triton_fp8_rowwise_3d_transpose_rhs ,
1616)
17- from torchao .prototype .moe_training .utils import _is_column_major
17+ from torchao .prototype .moe_training .utils import (
18+ _is_column_major ,
19+ pad_token_groups ,
20+ unpad_token_groups ,
21+ )
1822
1923
2024def _to_fp8_rowwise_then_scaled_grouped_mm (
@@ -23,6 +27,7 @@ def _to_fp8_rowwise_then_scaled_grouped_mm(
2327 offs : torch .Tensor ,
2428 out_dtype : Optional [torch .dtype ] = torch .bfloat16 ,
2529 float8_dtype : torch .dtype = torch .float8_e4m3fn ,
30+ pad_token_groups_for_grouped_mm : bool = True ,
2631) -> torch .Tensor :
2732 """
2833 Differentiable FP8 grouped matrix multiplication with dynamic FP8 rowwise quantization.
@@ -39,6 +44,9 @@ def _to_fp8_rowwise_then_scaled_grouped_mm(
3944 offs: Offset tensor of shape (num_groups + 1,) with dtype int32, defining
4045 group boundaries for the grouped GEMM operation. Group sizes must be divisible by 16.
4146 out_dtype: Output dtype for the result. Defaults to torch.bfloat16.
47+ float8_dtype: Float8 dtype for quantization. Defaults to torch.float8_e4m3fn.
48+ pad_token_groups_for_grouped_mm: Whether to pad token groups to the next multiple of 16
49+ (requirement for FP8 grouped GEMM). If your tokens are already padded, set to False.
4250
4351 Returns:
4452 torch.Tensor: Result of grouped matrix multiplication with shape (M, N).
@@ -49,7 +57,9 @@ def _to_fp8_rowwise_then_scaled_grouped_mm(
4957 - Scales are computed per-row and rounded to powers of 2 for efficiency
5058 - This function is fully differentiable via custom autograd implementation
5159 """
52- return _Float8GroupedMM .apply (A , B_t , offs , out_dtype , float8_dtype )
60+ return _Float8GroupedMM .apply (
61+ A , B_t , offs , out_dtype , float8_dtype , pad_token_groups_for_grouped_mm
62+ )
5363
5464
5565class _Float8GroupedMM (torch .autograd .Function ):
@@ -63,6 +73,7 @@ def forward(
6373 offs : Optional [torch .Tensor ] = None ,
6474 out_dtype : Optional [torch .dtype ] = torch .bfloat16 ,
6575 float8_dtype : torch .dtype = torch .float8_e4m3fn ,
76+ pad_token_groups_for_grouped_mm : bool = True ,
6677 ) -> torch .Tensor :
6778 # torchao _quantize_then_scaled_grouped_mm only supports A=2D|3D and B=3D.
6879 assert A .ndim == 2 or A .ndim == 3 , "A must be 2D or 3D"
@@ -97,17 +108,33 @@ def forward(
97108 # Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major.
98109 assert _is_column_major (B_t ), "B must be column-major"
99110
111+ # Save original group_end_offsets and num_tokens before padding
112+ num_tokens = A .shape [0 ]
113+ padded_group_start_offsets = None
114+ padded_group_end_offsets = None
115+
116+ # Conditionally pad token groups if not aligned to 16
117+ if pad_token_groups_for_grouped_mm :
118+ padded_A , padded_group_start_offsets , padded_group_end_offsets = (
119+ pad_token_groups (
120+ A , offs , alignment_size = 16
121+ ) # TODO: support emulated mode
122+ )
123+ else :
124+ padded_A = A
125+ padded_group_end_offsets = offs
126+
100127 # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
101- # A shape: (M, K) or (B, M, K)
102- # A_scales shape: (M,1) or (B, M, 1)
128+ # padded_A shape: (M, K) or (padded_M, K) if padding was used
129+ # A_scales shape: (M,1) or (padded_M, 1) if padding was used
103130 A_scales = tensor_to_scale (
104- A ,
131+ padded_A ,
105132 float8_dtype ,
106133 scaling_granularity = ScalingGranularity .AXISWISE ,
107134 axiswise_dim = - 1 ,
108135 round_scales_to_power_of_2 = True ,
109136 )
110- A_scaled = A .to (torch .float32 ) * A_scales
137+ A_scaled = padded_A .to (torch .float32 ) * A_scales
111138 A_data_row_major = to_fp8_saturated (A_scaled , float8_dtype )
112139
113140 # Convert B to float8, column-major for right operand of grouped GEMM.
@@ -125,9 +152,13 @@ def forward(
125152 B_t_data_col_major = to_fp8_saturated (B_t_scaled , float8_dtype )
126153
127154 # Store what we need for backward.
128- ctx .save_for_backward (A , B_t , offs )
155+ ctx .save_for_backward (
156+ padded_A , B_t , offs , padded_group_start_offsets , padded_group_end_offsets
157+ )
129158 ctx .out_dtype = out_dtype
130159 ctx .float8_dtype = float8_dtype
160+ ctx .pad_token_groups_for_grouped_mm = pad_token_groups_for_grouped_mm
161+ ctx .num_tokens = num_tokens
131162
132163 # Perform scaled grouped GEMM and return result.
133164 # output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N)
@@ -139,45 +170,77 @@ def forward(
139170 )
140171
141172 # Squeeze empty dims out of scales, to comply with grouped mm API.
142- # A_scales shape: (M,1) or (B, M , 1)
173+ # A_scales shape: (M,1) or (padded_M , 1)
143174 # B_t_scales shape: (E, 1, N)
144175 A_scales = A_scales .squeeze (- 1 )
145176 B_t_scales = B_t_scales .squeeze (1 )
146- return torch ._scaled_grouped_mm (
177+ output = torch ._scaled_grouped_mm (
147178 A_data_row_major ,
148179 B_t_data_col_major ,
149180 A_scales .reciprocal (), # Reciprocals are needed for rescaling the output.
150181 B_t_scales .reciprocal (),
151- offs ,
182+ padded_group_end_offsets ,
152183 out_dtype = out_dtype ,
153184 use_fast_accum = True ,
154185 )
155186
187+ # Unpad output if padding was used
188+ if pad_token_groups_for_grouped_mm :
189+ output = unpad_token_groups (
190+ output ,
191+ offs ,
192+ padded_group_start_offsets ,
193+ num_tokens ,
194+ alignment_size = 16 ,
195+ )
196+
197+ assert output .shape [0 ] == num_tokens
198+
199+ return output
200+
156201 @staticmethod
157202 def backward (ctx , grad_output : torch .Tensor ):
158- A , B_t , offs = ctx .saved_tensors
203+ (
204+ padded_A ,
205+ B_t ,
206+ original_group_end_offsets ,
207+ padded_group_start_offsets ,
208+ padded_group_end_offsets ,
209+ ) = ctx .saved_tensors
159210 out_dtype = ctx .out_dtype
160211 float8_dtype = ctx .float8_dtype
212+ pad_token_groups_for_grouped_mm = ctx .pad_token_groups_for_grouped_mm
213+ num_tokens = ctx .num_tokens
214+
215+ # Pad grad_output if padding was used in forward (needed for both dgrad and wgrad)
216+ if pad_token_groups_for_grouped_mm :
217+ padded_grad_output , _ , _ = pad_token_groups (
218+ grad_output ,
219+ original_group_end_offsets ,
220+ alignment_size = 16 ,
221+ )
222+ else :
223+ padded_grad_output = grad_output
161224
162225 # Convert grad_output to float8, row-major for left operand of grouped GEMM
163226 # needed for grad_A: grad_output @ B
164227 #
165- # grad_output shape: (Mg, N)
166- # grad_output_scale shape: (Mg, 1)
228+ # padded_grad_output shape: (Mg, N) or (padded_Mg, N) if padding was used
229+ # grad_output_scale shape: (Mg, 1) or (padded_Mg, 1) if padding was used
167230 grad_output_scales = tensor_to_scale (
168- grad_output ,
231+ padded_grad_output ,
169232 float8_dtype ,
170233 scaling_granularity = ScalingGranularity .AXISWISE ,
171234 axiswise_dim = - 1 ,
172235 round_scales_to_power_of_2 = True ,
173236 )
174- grad_output_scaled = grad_output .to (torch .float32 ) * grad_output_scales
237+ grad_output_scaled = padded_grad_output .to (torch .float32 ) * grad_output_scales
175238 grad_output_data_row_major = to_fp8_saturated (grad_output_scaled , float8_dtype )
176239
177240 # Compute B fp8 column-major for right operand of grouped GEMM:
178241 # grad_A = grad_output @ B.
179242 B_data_col_major , B_scales = triton_fp8_rowwise_3d_transpose_rhs (
180- B_t . _data if hasattr ( B_t , "_data" ) else B_t ,
243+ B_t ,
181244 output_dtype = float8_dtype ,
182245 round_scales_to_power_of_2 = True ,
183246 )
@@ -193,7 +256,7 @@ def backward(ctx, grad_output: torch.Tensor):
193256 )
194257
195258 # Squeeze empty dims out of scales, to comply with grouped mm API.
196- # grad_output_scales shape: (M,1) or (B, M , 1)
259+ # grad_output_scales shape: (M,1) or (padded_M , 1)
197260 # B_scales shape: (E, 1, N)
198261 grad_output_scales = grad_output_scales .squeeze (- 1 )
199262 B_scales = B_scales .squeeze (1 )
@@ -202,29 +265,39 @@ def backward(ctx, grad_output: torch.Tensor):
202265 B_data_col_major ,
203266 grad_output_scales .reciprocal (),
204267 B_scales .reciprocal (),
205- offs ,
268+ padded_group_end_offsets ,
206269 out_dtype = out_dtype ,
207270 use_fast_accum = True ,
208271 )
209272
273+ # Unpad grad_A if padding was used
274+ if pad_token_groups_for_grouped_mm :
275+ grad_A = unpad_token_groups (
276+ grad_A ,
277+ original_group_end_offsets ,
278+ padded_group_start_offsets ,
279+ num_tokens ,
280+ alignment_size = 16 ,
281+ )
282+
210283 # grad_B is a special case. both operands of the grouped gemm will be 2D with offsets determing the "groups."
211284 # Compute scales for grad_output_t and A, which are both 2D tensors with offsets which define the "jagged" groups.
212285
213286 # Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM
214287 # needed for grad_B: grad_output_t @ A
215288 # Use transpose method to avoid uncoalesced memory accesses.
216289 grad_out_data_colwise , grad_out_scales = triton_fp8_per_group_colwise_scales (
217- grad_output ,
218- offs ,
290+ padded_grad_output ,
291+ padded_group_end_offsets ,
219292 float8_dtype ,
220293 round_scales_to_power_of_2 = True ,
221294 )
222295 grad_output_t_data_row_major = grad_out_data_colwise .t ()
223296 grad_output_t_scales = grad_out_scales .t ()
224297
225298 A_data_col_major , A_scales = triton_fp8_per_group_colwise_scales (
226- A ,
227- offs ,
299+ padded_A ,
300+ padded_group_end_offsets ,
228301 float8_dtype ,
229302 round_scales_to_power_of_2 = True ,
230303 )
@@ -246,8 +319,8 @@ def backward(ctx, grad_output: torch.Tensor):
246319 A_data_col_major ,
247320 grad_output_t_scales .reciprocal (),
248321 A_scales .reciprocal (),
249- offs ,
322+ padded_group_end_offsets ,
250323 out_dtype = out_dtype ,
251324 use_fast_accum = True ,
252325 )
253- return grad_A , grad_B .transpose (- 2 , - 1 ), None , None , None
326+ return grad_A , grad_B .transpose (- 2 , - 1 ), None , None , None , None