Commit 0dc4c7f

2d-2d working

1 parent dc6bcf3
3 files changed: +27 -16 lines
torchao/float8/float8_tensor.py (+1 -1)

@@ -316,8 +316,8 @@ def __new__(
         assert axiswise_dim in (
             None,
             0,
-            1,
             -1,
+            -2,
         ), f"unsupported axiswise_dim {axiswise_dim}"
         self._axiswise_dim = axiswise_dim
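The change above swaps 1 for -2 in the set of allowed axiswise scaling dims, presumably so the grouped-GEMM path can request scales along the second-to-last dim of a column-major operand. As a hedged illustration only (amax-based "scales" standing in for the library's real float8 scale computation), scaling along dim -2 versus -1 produces differently shaped scale tensors:

import torch

# A 3D activation of shape (B, M, K).
x = torch.randn(4, 8, 16)

# axiswise_dim=-1: reduce over the last dim, one scale per row -> shape (4, 8, 1).
scale_last = x.abs().amax(dim=-1, keepdim=True)

# axiswise_dim=-2: reduce over the second-to-last dim -> shape (4, 1, 16),
# i.e. one scale per column of the trailing matrix.
scale_second_last = x.abs().amax(dim=-2, keepdim=True)

print(scale_last.shape, scale_second_last.shape)  # torch.Size([4, 8, 1]) torch.Size([4, 1, 16])
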
torchao/prototype/grouped_mm/__init__.py (+19 -13)

@@ -58,24 +58,28 @@ def forward(
             float8_recipe_name == Float8LinearRecipeName.ROWWISE
         ), "Only rowwise scaling is supported by torch._scaled_grouped_mm."

-
         assert 2 <= A.ndim <= 3, "A must be 2D or 3D"
         assert 2 <= B.ndim <= 3, "B must be 2D or 3D"

         # Dim 1 of B must match the final dim of A.
-        assert A.size(-1) == B.size(-2), f"shape {A.shape} and {B.shape} are not compatible for _scaled_grouped_mm"
+        assert A.size(-1) == B.size(
+            -2
+        ), f"shape {A.shape} and {B.shape} are not compatible for _scaled_grouped_mm"

         # offsets are required for 2D A tensor, otherwise it should be None.
         if A.ndim == 2 or B.ndim == 2:
             assert offs is not None, "offs must be specified for 2D tensor"
-        else:
-            assert offs is None, "offs must not be specified for 3D tensor"

         # TODO: pad dims to be multiples of 16, as required by torch._scaled_grouped_mm.

         # Fetch float8 config from specified recipe name.
         float8_config = Float8LinearConfig.from_recipe_name(float8_recipe_name)

+        # Store what we need for backward.
+        ctx.save_for_backward(A, B)
+        ctx.float8_config = float8_config
+        ctx.offs = offs
+
         # Convert high precision input tensor to float8, row-major for left operand of grouped GEMM.
         # A shape: (M, K) or (B, M, K)
         # A_scale shape: (M,1) or (B, M, 1)
@@ -94,7 +98,7 @@ def forward(
         )
         A_scale = A_fp8_row_major._scale.squeeze()

-        # Convert B to float8, column-major for right operand of grouped GEMM.
+        # Convert B to float8, column-major for right operand of grouped GEMM.
         # B shape: (K,N) or (B, K, N)
         # B scales must be computed rowwise keeping the outer/final dim, so:
         # B_scale shape: (1,N) or (B, 1, N)
@@ -113,10 +117,11 @@ def forward(
         )
         B_scale = B_fp8_col_major._scale.squeeze()

-        # Store what we need for backward.
-        ctx.save_for_backward(A, B)
-        ctx.float8_config = float8_config
-        ctx.offs = offs
+        # Special case: 2D-2D grouped GEMM, the scales must be multiplied by the number of groups,
+        # which is the size of the `offs` tensor.
+        if A.ndim == 2 and B.ndim == 2:
+            A_scale = A_scale.repeat(offs.numel())
+            B_scale = B_scale.repeat(offs.numel())

         # Perform scaled grouped GEMM and return result.
         # output shape: (M, N) or (B, M, N)
@@ -156,7 +161,6 @@ def backward(ctx, grad_output: torch.Tensor):
         )
         grad_output_scale = grad_output_fp8_row_major._scale.squeeze()

-
         # Convert B to non-transposed, float8, column-major for right operand of grouped GEMM
         # needed for grad_A: grad_output @ B.
         # Since B was transposed before entry to forward, we need to transpose it back here for this.
@@ -167,7 +171,7 @@ def backward(ctx, grad_output: torch.Tensor):
         # - B_scale shape: (1,N) or (B, 1, N)
         # - torch._scaled_grouped_mm requires scales without any empty dims, so squeeze A_scale.
         # - B scale shape: (N,) or (B, N)
-        B_fp8_col_major = hp_tensor_to_float8_dynamic(
+        B_fp8_col_major = hp_tensor_to_float8_dynamic(
             B_non_transposed,
             float8_config.cast_config_input.target_dtype,
             linear_mm_config=LinearMMConfig(),
@@ -180,7 +184,7 @@ def backward(ctx, grad_output: torch.Tensor):
         )
         B_scale = B_fp8_col_major._scale.squeeze()

-        # Compute grad_A.
+        # Compute grad_A.
         #
         # Case 1: A=2D, B=3D with A=(M,K), B^T=(B,K,N), output=(B,M,N)
         # grad_A = grad_output @ B
@@ -237,7 +241,9 @@ def backward(ctx, grad_output: torch.Tensor):
         # - torch._scaled_grouped_mm requires scales without any empty dims, so squeeze A_scale.
         # - A scale shape: (K,) or (B, K)
         A_fp8_col_major = hp_tensor_to_float8_dynamic(
-            A.transpose(-2, -1).contiguous().tranpose(-2,-1), # Convert to column-major
+            A.transpose(-2, -1)
+            .contiguous()
+            .tranpose(-2, -1),  # Convert to column-major
             float8_config.cast_config_input.target_dtype,
             linear_mm_config=LinearMMConfig(),
             gemm_input_role=GemmInputRole.INPUT,
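
The new 2D-2D special case in forward() repeats each row-wise scale once per group, with the group count taken from offs.numel(), as the added comment in the diff states. A minimal shape-bookkeeping sketch of that step, using illustrative amax-based scales in place of the real float8 scale computation and made-up sizes:

import torch

m, n, k_per_group, n_groups = 16, 16, 16, 4
k_total = k_per_group * n_groups

A = torch.randn(m, k_total)   # left operand, groups laid out along K
B = torch.randn(k_total, n)   # right operand, same K layout

# offs marks the end of each group along K: [16, 32, 48, 64].
offs = torch.arange(k_per_group, k_total + 1, k_per_group, dtype=torch.int32)

# Row-wise scales before the special case (illustrative stand-ins):
A_scale = A.abs().amax(dim=-1)  # shape (m,)
B_scale = B.abs().amax(dim=-2)  # shape (n,)

# 2D-2D case: repeat the scales by the number of groups, i.e. offs.numel().
A_scale = A_scale.repeat(offs.numel())  # shape (m * n_groups,)
B_scale = B_scale.repeat(offs.numel())  # shape (n * n_groups,)
print(A_scale.shape, B_scale.shape)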

torchao/prototype/grouped_mm/test_grouped_mm.py (+7 -2)

@@ -93,8 +93,12 @@ def test_grouped_gemm_2d_2d(use_fast_accum, strided):
     out_dtype = torch.bfloat16
     device = "cuda"
     m, n, k, n_groups = 16, 16, 16, 4  # all sizes have to be divisible by 16
-    a = torch.randn(m, k * n_groups + k * int(strided), device=device)[:, :k * n_groups]
-    b = torch.randn(n, k * n_groups + k * int(strided), device=device)[:, :k * n_groups]
+    a = torch.randn(m, k * n_groups + k * int(strided), device=device)[
+        :, : k * n_groups
+    ]
+    b = torch.randn(n, k * n_groups + k * int(strided), device=device)[
+        :, : k * n_groups
+    ]
     offs = torch.arange(k, n_groups * k + 1, k, device=device, dtype=torch.int32)

     # Compute result.
@@ -119,6 +123,7 @@ def test_grouped_gemm_2d_2d(use_fast_accum, strided):
         offs=offs,
     )

+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 def test_tensorwise_scaling_not_supported():
     device = "cuda"
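
For context on the reformatted test setup: the test allocates one extra group's worth of columns when strided is true and slices it away, so the inputs become non-contiguous views, and offs marks the group boundaries along K. A hedged, CPU-only sketch of just that setup (the CUDA float8 grouped-GEMM call itself is omitted):

import torch

m, n, k, n_groups = 16, 16, 16, 4
strided = True

# Over-allocate by one group of columns when strided, then slice it off so the
# resulting views have a larger row stride than their width (non-contiguous).
a = torch.randn(m, k * n_groups + k * int(strided))[:, : k * n_groups]
b = torch.randn(n, k * n_groups + k * int(strided))[:, : k * n_groups]

# Group boundaries along K: [16, 32, 48, 64].
offs = torch.arange(k, n_groups * k + 1, k, dtype=torch.int32)

print(a.is_contiguous(), b.is_contiguous())  # False, False when strided=True
print(offs.tolist())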
