Skip to content

Commit 5ebd10d

Browse files
Authored by danielvegamyhre, with co-authors github-actions[bot] and drisspg
[moe training] apply per group padding to fp8 grouped mm (#4045)
* [moe training] apply per group padding to fp8 grouped mm
* Add module docstring to moe_training/utils.py

---------

Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Driss Guessous <drisspg@users.noreply.github.com>
1 parent 605a22e commit 5ebd10d

File tree

7 files changed

+223
-102
lines changed

7 files changed

+223
-102
lines changed

test/prototype/moe_training/reference_moe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,9 @@ def generate_permute_indices(
136136
"""
137137
Prepare permutation indices and the number of tokens for each expert.
138138
"""
139+
# if using generate_permute_indices, capture scalar outputs to avoid graph break
140+
torch._dynamo.config.capture_scalar_outputs = True
141+
139142
start_index_values = (
140143
torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group
141144
)

test/prototype/moe_training/test_training.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,11 @@ def test_moe_training(
9797

9898
# FP8_ROWWISE hardware path requires SM90 (CUDA) or MI300/MI350 (ROCm)
9999
if recipe == Float8TrainingRecipe.FP8_ROWWISE:
100+
if compile:
101+
pytest.skip(
102+
"https://github.com/pytorch/ao/issues/4048: 'FakeTensor' object has no attribute '__tensor_flatten__'"
103+
)
104+
100105
if is_ROCM():
101106
if not (is_MI300() or is_MI350()):
102107
pytest.skip("FP8 rowwise test requires MI300 or MI350 on ROCm")

torchao/prototype/moe_training/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ class Float8TrainingOpConfig(TrainingOpBaseConfig):
5757
# Output dtype for the FP8 grouped GEMMs.
5858
out_dtype: Optional[torch.dtype] = torch.bfloat16
5959

60+
# TODO: support pad_token_groups_for_grouped_mm field like MXFP8TrainingOpConfig
61+
6062
@classmethod
6163
def from_recipe(
6264
cls,

torchao/prototype/moe_training/fp8_grouped_mm.py

Lines changed: 97 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
triton_fp8_per_group_colwise_scales,
1515
triton_fp8_rowwise_3d_transpose_rhs,
1616
)
17-
from torchao.prototype.moe_training.utils import _is_column_major
17+
from torchao.prototype.moe_training.utils import (
18+
_is_column_major,
19+
pad_token_groups,
20+
unpad_token_groups,
21+
)
1822

1923

2024
def _to_fp8_rowwise_then_scaled_grouped_mm(
@@ -23,6 +27,7 @@ def _to_fp8_rowwise_then_scaled_grouped_mm(
2327
offs: torch.Tensor,
2428
out_dtype: Optional[torch.dtype] = torch.bfloat16,
2529
float8_dtype: torch.dtype = torch.float8_e4m3fn,
30+
pad_token_groups_for_grouped_mm: bool = True,
2631
) -> torch.Tensor:
2732
"""
2833
Differentiable FP8 grouped matrix multiplication with dynamic FP8 rowwise quantization.
@@ -39,6 +44,9 @@ def _to_fp8_rowwise_then_scaled_grouped_mm(
3944
offs: Offset tensor of shape (num_groups + 1,) with dtype int32, defining
4045
group boundaries for the grouped GEMM operation. Group sizes must be divisible by 16.
4146
out_dtype: Output dtype for the result. Defaults to torch.bfloat16.
47+
float8_dtype: Float8 dtype for quantization. Defaults to torch.float8_e4m3fn.
48+
pad_token_groups_for_grouped_mm: Whether to pad token groups to the next multiple of 16
49+
(requirement for FP8 grouped GEMM). If your tokens are already padded, set to False.
4250
4351
Returns:
4452
torch.Tensor: Result of grouped matrix multiplication with shape (M, N).
@@ -49,7 +57,9 @@ def _to_fp8_rowwise_then_scaled_grouped_mm(
4957
- Scales are computed per-row and rounded to powers of 2 for efficiency
5058
- This function is fully differentiable via custom autograd implementation
5159
"""
52-
return _Float8GroupedMM.apply(A, B_t, offs, out_dtype, float8_dtype)
60+
return _Float8GroupedMM.apply(
61+
A, B_t, offs, out_dtype, float8_dtype, pad_token_groups_for_grouped_mm
62+
)
5363

5464

5565
class _Float8GroupedMM(torch.autograd.Function):
@@ -63,6 +73,7 @@ def forward(
6373
offs: Optional[torch.Tensor] = None,
6474
out_dtype: Optional[torch.dtype] = torch.bfloat16,
6575
float8_dtype: torch.dtype = torch.float8_e4m3fn,
76+
pad_token_groups_for_grouped_mm: bool = True,
6677
) -> torch.Tensor:
6778
# torchao _quantize_then_scaled_grouped_mm only supports A=2D|3D and B=3D.
6879
assert A.ndim == 2 or A.ndim == 3, "A must be 2D or 3D"
@@ -97,17 +108,33 @@ def forward(
97108
# Due to hardware requirements, the right operand in a scaled grouped GEMM must be column-major.
98109
assert _is_column_major(B_t), "B must be column-major"
99110

111+
# Save original group_end_offsets and num_tokens before padding
112+
num_tokens = A.shape[0]
113+
padded_group_start_offsets = None
114+
padded_group_end_offsets = None
115+
116+
# Conditionally pad token groups if not aligned to 16
117+
if pad_token_groups_for_grouped_mm:
118+
padded_A, padded_group_start_offsets, padded_group_end_offsets = (
119+
pad_token_groups(
120+
A, offs, alignment_size=16
121+
) # TODO: support emulated mode
122+
)
123+
else:
124+
padded_A = A
125+
padded_group_end_offsets = offs
126+
100127
# Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
101-
# A shape: (M, K) or (B, M, K)
102-
# A_scales shape: (M,1) or (B, M, 1)
128+
# padded_A shape: (M, K) or (padded_M, K) if padding was used
129+
# A_scales shape: (M,1) or (padded_M, 1) if padding was used
103130
A_scales = tensor_to_scale(
104-
A,
131+
padded_A,
105132
float8_dtype,
106133
scaling_granularity=ScalingGranularity.AXISWISE,
107134
axiswise_dim=-1,
108135
round_scales_to_power_of_2=True,
109136
)
110-
A_scaled = A.to(torch.float32) * A_scales
137+
A_scaled = padded_A.to(torch.float32) * A_scales
111138
A_data_row_major = to_fp8_saturated(A_scaled, float8_dtype)
112139

113140
# Convert B to float8, column-major for right operand of grouped GEMM.
@@ -125,9 +152,13 @@ def forward(
125152
B_t_data_col_major = to_fp8_saturated(B_t_scaled, float8_dtype)
126153

127154
# Store what we need for backward.
128-
ctx.save_for_backward(A, B_t, offs)
155+
ctx.save_for_backward(
156+
padded_A, B_t, offs, padded_group_start_offsets, padded_group_end_offsets
157+
)
129158
ctx.out_dtype = out_dtype
130159
ctx.float8_dtype = float8_dtype
160+
ctx.pad_token_groups_for_grouped_mm = pad_token_groups_for_grouped_mm
161+
ctx.num_tokens = num_tokens
131162

132163
# Perform scaled grouped GEMM and return result.
133164
# output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N)
@@ -139,45 +170,77 @@ def forward(
139170
)
140171

141172
# Squeeze empty dims out of scales, to comply with grouped mm API.
142-
# A_scales shape: (M,1) or (B, M, 1)
173+
# A_scales shape: (M,1) or (padded_M, 1)
143174
# B_t_scales shape: (E, 1, N)
144175
A_scales = A_scales.squeeze(-1)
145176
B_t_scales = B_t_scales.squeeze(1)
146-
return torch._scaled_grouped_mm(
177+
output = torch._scaled_grouped_mm(
147178
A_data_row_major,
148179
B_t_data_col_major,
149180
A_scales.reciprocal(), # Reciprocals are needed for rescaling the output.
150181
B_t_scales.reciprocal(),
151-
offs,
182+
padded_group_end_offsets,
152183
out_dtype=out_dtype,
153184
use_fast_accum=True,
154185
)
155186

187+
# Unpad output if padding was used
188+
if pad_token_groups_for_grouped_mm:
189+
output = unpad_token_groups(
190+
output,
191+
offs,
192+
padded_group_start_offsets,
193+
num_tokens,
194+
alignment_size=16,
195+
)
196+
197+
assert output.shape[0] == num_tokens
198+
199+
return output
200+
156201
@staticmethod
157202
def backward(ctx, grad_output: torch.Tensor):
158-
A, B_t, offs = ctx.saved_tensors
203+
(
204+
padded_A,
205+
B_t,
206+
original_group_end_offsets,
207+
padded_group_start_offsets,
208+
padded_group_end_offsets,
209+
) = ctx.saved_tensors
159210
out_dtype = ctx.out_dtype
160211
float8_dtype = ctx.float8_dtype
212+
pad_token_groups_for_grouped_mm = ctx.pad_token_groups_for_grouped_mm
213+
num_tokens = ctx.num_tokens
214+
215+
# Pad grad_output if padding was used in forward (needed for both dgrad and wgrad)
216+
if pad_token_groups_for_grouped_mm:
217+
padded_grad_output, _, _ = pad_token_groups(
218+
grad_output,
219+
original_group_end_offsets,
220+
alignment_size=16,
221+
)
222+
else:
223+
padded_grad_output = grad_output
161224

162225
# Convert grad_output to float8, row-major for left operand of grouped GEMM
163226
# needed for grad_A: grad_output @ B
164227
#
165-
# grad_output shape: (Mg, N)
166-
# grad_output_scale shape: (Mg, 1)
228+
# padded_grad_output shape: (Mg, N) or (padded_Mg, N) if padding was used
229+
# grad_output_scale shape: (Mg, 1) or (padded_Mg, 1) if padding was used
167230
grad_output_scales = tensor_to_scale(
168-
grad_output,
231+
padded_grad_output,
169232
float8_dtype,
170233
scaling_granularity=ScalingGranularity.AXISWISE,
171234
axiswise_dim=-1,
172235
round_scales_to_power_of_2=True,
173236
)
174-
grad_output_scaled = grad_output.to(torch.float32) * grad_output_scales
237+
grad_output_scaled = padded_grad_output.to(torch.float32) * grad_output_scales
175238
grad_output_data_row_major = to_fp8_saturated(grad_output_scaled, float8_dtype)
176239

177240
# Compute B fp8 column-major for right operand of grouped GEMM:
178241
# grad_A = grad_output @ B.
179242
B_data_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
180-
B_t._data if hasattr(B_t, "_data") else B_t,
243+
B_t,
181244
output_dtype=float8_dtype,
182245
round_scales_to_power_of_2=True,
183246
)
@@ -193,7 +256,7 @@ def backward(ctx, grad_output: torch.Tensor):
193256
)
194257

195258
# Squeeze empty dims out of scales, to comply with grouped mm API.
196-
# grad_output_scales shape: (M,1) or (B, M, 1)
259+
# grad_output_scales shape: (M,1) or (padded_M, 1)
197260
# B_scales shape: (E, 1, N)
198261
grad_output_scales = grad_output_scales.squeeze(-1)
199262
B_scales = B_scales.squeeze(1)
@@ -202,29 +265,39 @@ def backward(ctx, grad_output: torch.Tensor):
202265
B_data_col_major,
203266
grad_output_scales.reciprocal(),
204267
B_scales.reciprocal(),
205-
offs,
268+
padded_group_end_offsets,
206269
out_dtype=out_dtype,
207270
use_fast_accum=True,
208271
)
209272

273+
# Unpad grad_A if padding was used
274+
if pad_token_groups_for_grouped_mm:
275+
grad_A = unpad_token_groups(
276+
grad_A,
277+
original_group_end_offsets,
278+
padded_group_start_offsets,
279+
num_tokens,
280+
alignment_size=16,
281+
)
282+
210283
# grad_B is a special case. both operands of the grouped gemm will be 2D with offsets determing the "groups."
211284
# Compute scales for grad_output_t and A, which are both 2D tensors with offsets which define the "jagged" groups.
212285

213286
# Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM
214287
# needed for grad_B: grad_output_t @ A
215288
# Use transpose method to avoid uncoalesced memory accesses.
216289
grad_out_data_colwise, grad_out_scales = triton_fp8_per_group_colwise_scales(
217-
grad_output,
218-
offs,
290+
padded_grad_output,
291+
padded_group_end_offsets,
219292
float8_dtype,
220293
round_scales_to_power_of_2=True,
221294
)
222295
grad_output_t_data_row_major = grad_out_data_colwise.t()
223296
grad_output_t_scales = grad_out_scales.t()
224297

225298
A_data_col_major, A_scales = triton_fp8_per_group_colwise_scales(
226-
A,
227-
offs,
299+
padded_A,
300+
padded_group_end_offsets,
228301
float8_dtype,
229302
round_scales_to_power_of_2=True,
230303
)
@@ -246,8 +319,8 @@ def backward(ctx, grad_output: torch.Tensor):
246319
A_data_col_major,
247320
grad_output_t_scales.reciprocal(),
248321
A_scales.reciprocal(),
249-
offs,
322+
padded_group_end_offsets,
250323
out_dtype=out_dtype,
251324
use_fast_accum=True,
252325
)
253-
return grad_A, grad_B.transpose(-2, -1), None, None, None
326+
return grad_A, grad_B.transpose(-2, -1), None, None, None, None

0 commit comments

Comments
 (0)