Commit 4c5e9db

2d-2d working

1 parent dc6bcf3 commit 4c5e9db

3 files changed: +33 −34 lines

torchao/float8/float8_tensor.py

+1-1
@@ -316,8 +316,8 @@ def __new__(
         assert axiswise_dim in (
             None,
             0,
-            1,
             -1,
+            -2,
         ), f"unsupported axiswise_dim {axiswise_dim}"
         self._axiswise_dim = axiswise_dim
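This hunk narrows the accepted axiswise_dim values to negative indices only (dropping 1, adding -2). As a rough illustration of what the two remaining options mean, the sketch below computes a per-row scale for dim -1 and a per-column scale for dim -2; the amax-based formula is only a stand-in, not Float8Tensor's exact scaling code.

import torch

x = torch.randn(4, 8)
f8_max = torch.finfo(torch.float8_e4m3fn).max

# axiswise_dim=-1: reduce over the last dim -> one scale per row, shape (4, 1).
row_scale = f8_max / x.abs().amax(dim=-1, keepdim=True)

# axiswise_dim=-2: reduce over the second-to-last dim -> one scale per column, shape (1, 8).
col_scale = f8_max / x.abs().amax(dim=-2, keepdim=True)

print(row_scale.shape, col_scale.shape)  # torch.Size([4, 1]) torch.Size([1, 8])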

torchao/prototype/grouped_mm/__init__.py

+25-31
@@ -58,24 +58,28 @@ def forward(
             float8_recipe_name == Float8LinearRecipeName.ROWWISE
         ), "Only rowwise scaling is supported by torch._scaled_grouped_mm."
 
-
         assert 2 <= A.ndim <= 3, "A must be 2D or 3D"
         assert 2 <= B.ndim <= 3, "B must be 2D or 3D"
 
         # Dim 1 of B must match the final dim of A.
-        assert A.size(-1) == B.size(-2), f"shape {A.shape} and {B.shape} are not compatible for _scaled_grouped_mm"
+        assert A.size(-1) == B.size(
+            -2
+        ), f"shape {A.shape} and {B.shape} are not compatible for _scaled_grouped_mm"
 
         # offsets are required for 2D A tensor, otherwise it should be None.
         if A.ndim == 2 or B.ndim == 2:
             assert offs is not None, "offs must be specified for 2D tensor"
-        else:
-            assert offs is None, "offs must not be specified for 3D tensor"
 
         # TODO: pad dims to be multiples of 16, as required by torch._scaled_grouped_mm.
 
         # Fetch float8 config from specified recipe name.
         float8_config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
 
+        # Store what we need for backward.
+        ctx.save_for_backward(A, B)
+        ctx.float8_config = float8_config
+        ctx.offs = offs
+
         # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
         # A shape: (M, K) or (B, M, K)
         # A_scale shape: (M,1) or (B, M, 1)
@@ -87,14 +91,12 @@ def forward(
             linear_mm_config=LinearMMConfig(),
             gemm_input_role=GemmInputRole.INPUT,
             scaling_granularity=float8_config.cast_config_input.scaling_granularity,
-            axiswise_dim=get_maybe_axiswise_dim(
-                -1, float8_config.cast_config_input.scaling_granularity
-            ),
+            axiswise_dim=-1,
             round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
         )
         A_scale = A_fp8_row_major._scale.squeeze()
 
-        # Convert B to float8, column-major for right operand of grouped GEMM.
+        # Convert B to float8, column-major for right operand of grouped GEMM.
         # B shape: (K,N) or (B, K, N)
         # B scales must be computed rowwise keeping the outer/final dim, so:
         # B_scale shape: (1,N) or (B, 1, N)
@@ -106,17 +108,16 @@ def forward(
             linear_mm_config=LinearMMConfig(),
             gemm_input_role=GemmInputRole.WEIGHT,
             scaling_granularity=float8_config.cast_config_weight.scaling_granularity,
-            axiswise_dim=get_maybe_axiswise_dim(
-                -2, float8_config.cast_config_input.scaling_granularity
-            ),
+            axiswise_dim=-2,
             round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
         )
         B_scale = B_fp8_col_major._scale.squeeze()
 
-        # Store what we need for backward.
-        ctx.save_for_backward(A, B)
-        ctx.float8_config = float8_config
-        ctx.offs = offs
+        # Special case: 2D-2D grouped GEMM, the scales must be multiplied by the number of groups,
+        # which is the size of the `offs` tensor.
+        if A.ndim == 2 and B.ndim == 2:
+            A_scale = A_scale.repeat(offs.numel())
+            B_scale = B_scale.repeat(offs.numel())
 
         # Perform scaled grouped GEMM and return result.
         # output shape: (M, N) or (B, M, N)
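The special case added in the hunk above tiles the rowwise scales once per group so their length matches the 2D-2D layout consumed by torch._scaled_grouped_mm. A shape-only sketch of that bookkeeping, with a reciprocal-amax scale standing in for the real float8 cast:

import torch

m, n, k, n_groups = 16, 16, 16, 4
A = torch.randn(m, k * n_groups)    # (M, K_total): groups laid out along K
B_t = torch.randn(n, k * n_groups)  # transposed weight sharing the same K_total layout
offs = torch.arange(k, n_groups * k + 1, k, dtype=torch.int32)  # group end offsets along K

# Stand-in rowwise scales, analogous to A_scale / B_scale after .squeeze().
A_scale = 1.0 / A.abs().amax(dim=-1)    # shape (M,)
B_scale = 1.0 / B_t.abs().amax(dim=-1)  # shape (N,)

# 2D-2D special case: repeat the scales once per group (offs.numel() groups).
A_scale = A_scale.repeat(offs.numel())  # shape (n_groups * M,) = (64,)
B_scale = B_scale.repeat(offs.numel())  # shape (n_groups * N,) = (64,)
print(A_scale.shape, B_scale.shape)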
@@ -149,14 +150,11 @@ def backward(ctx, grad_output: torch.Tensor):
             linear_mm_config=LinearMMConfig(),
             gemm_input_role=GemmInputRole.GRAD_OUTPUT,
             scaling_granularity=float8_config.cast_config_grad_output.scaling_granularity,
-            axiswise_dim=get_maybe_axiswise_dim(
-                -1, float8_config.cast_config_grad_output.scaling_granularity
-            ),
+            axiswise_dim=-1,
             round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
         )
         grad_output_scale = grad_output_fp8_row_major._scale.squeeze()
 
-
         # Convert B to non-transposed, float8, column-major for right operand of grouped GEMM
         # needed for grad_A: grad_output @ B.
         # Since B was transposed before entry to forward, we need to transpose it back here for this.
@@ -167,20 +165,18 @@ def backward(ctx, grad_output: torch.Tensor):
         # - B_scale shape: (1,N) or (B, 1, N)
         # - torch._scaled_grouped_mm requires scales without any empty dims, so squeeze A_scale.
         # - B scale shape: (N,) or (B, N)
-        B_fp8_col_major = hp_tensor_to_float8_dynamic(
+        B_fp8_col_major = hp_tensor_to_float8_dynamic(
             B_non_transposed,
             float8_config.cast_config_input.target_dtype,
             linear_mm_config=LinearMMConfig(),
             gemm_input_role=GemmInputRole.WEIGHT,
             scaling_granularity=float8_config.cast_config_weight.scaling_granularity,
-            axiswise_dim=get_maybe_axiswise_dim(
-                -2, float8_config.cast_config_input.scaling_granularity
-            ),
+            axiswise_dim=-2,
             round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
         )
         B_scale = B_fp8_col_major._scale.squeeze()
 
-        # Compute grad_A.
+        # Compute grad_A.
         #
         # Case 1: A=2D, B=3D with A=(M,K), B^T=(B,K,N), output=(B,M,N)
         # grad_A = grad_output @ B
@@ -221,9 +217,7 @@ def backward(ctx, grad_output: torch.Tensor):
             linear_mm_config=LinearMMConfig(),
             gemm_input_role=GemmInputRole.GRAD_OUTPUT,
             scaling_granularity=float8_config.cast_config_grad_output.scaling_granularity,
-            axiswise_dim=get_maybe_axiswise_dim(
-                -1, float8_config.cast_config_grad_output.scaling_granularity
-            ),
+            axiswise_dim=-1,
             round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
         )
         grad_output_t_scale = grad_output_t_fp8_row_major._scale.squeeze()
@@ -237,14 +231,14 @@ def backward(ctx, grad_output: torch.Tensor):
         # - torch._scaled_grouped_mm requires scales without any empty dims, so squeeze A_scale.
         # - A scale shape: (K,) or (B, K)
         A_fp8_col_major = hp_tensor_to_float8_dynamic(
-            A.transpose(-2, -1).contiguous().transpose(-2, -1),  # Convert to column-major
+            A.transpose(-2, -1)
+            .contiguous()
+            .transpose(-2, -1),  # Convert to column-major
             float8_config.cast_config_input.target_dtype,
             linear_mm_config=LinearMMConfig(),
             gemm_input_role=GemmInputRole.INPUT,
             scaling_granularity=float8_config.cast_config_input.scaling_granularity,
-            axiswise_dim=get_maybe_axiswise_dim(
-                -2, float8_config.cast_config_input.scaling_granularity
-            ),
+            axiswise_dim=-2,
             round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
         )
         A_scale = A_fp8_col_major._scale.squeeze()
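The transpose-contiguous-transpose chain in the last hunk yields a tensor with the same logical shape but column-major strides, which is what the column-major operand of the grouped GEMM expects. A small standalone check of that layout trick:

import torch

A = torch.randn(4, 8)
A_col_major = A.transpose(-2, -1).contiguous().transpose(-2, -1)

# Same shape and values, but the strides now describe column-major storage.
print(A.shape, A.stride())                      # torch.Size([4, 8]) (8, 1)
print(A_col_major.shape, A_col_major.stride())  # torch.Size([4, 8]) (1, 4)
assert torch.equal(A, A_col_major)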

torchao/prototype/grouped_mm/test_grouped_mm.py

+7-2
@@ -93,8 +93,12 @@ def test_grouped_gemm_2d_2d(use_fast_accum, strided):
     out_dtype = torch.bfloat16
     device = "cuda"
     m, n, k, n_groups = 16, 16, 16, 4  # all sizes have to be divisible by 16
-    a = torch.randn(m, k * n_groups + k * int(strided), device=device)[:, :k * n_groups]
-    b = torch.randn(n, k * n_groups + k * int(strided), device=device)[:, :k * n_groups]
+    a = torch.randn(m, k * n_groups + k * int(strided), device=device)[
+        :, : k * n_groups
+    ]
+    b = torch.randn(n, k * n_groups + k * int(strided), device=device)[
+        :, : k * n_groups
+    ]
     offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)
 
     # Compute result.
@@ -119,6 +123,7 @@ def test_grouped_gemm_2d_2d(use_fast_accum, strided):
         offs=offs,
     )
 
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 def test_tensorwise_scaling_not_supported():
     device = "cuda"

0 commit comments