
Commit 4e04022

allow other axiswise dims so we can pass in 3D B tensor transposed
1 parent 5099838 commit 4e04022

4 files changed, +22 -21 lines

torchao/float8/float8_ops.py

+3
@@ -151,6 +151,9 @@ def float8_transpose(aten_op, args, kwargs=None):
     else:
         new_scale = args[0]._scale
 
+    if aten_op == aten.transpose.int:
+        _assert_tensorwise_scale(aten_op, args[0]._scale)
+
     old_axiswise_dim = args[0]._axiswise_dim
     new_axiswise_dim = old_axiswise_dim
     if old_axiswise_dim is not None:
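The added guard sends the integer-dim transpose overload (aten.transpose.int) down the tensorwise-only path, while the axiswise bookkeeping below it keeps handling the remaining transposes by moving the scaled dim along with the swap. A rough standalone sketch of that rule, using a hypothetical helper rather than the actual code in this file:

```python
# Hypothetical helper, not torchao code: if the axiswise-scaled dim is one of
# the two transposed dims, the scale moves to the other one; otherwise it stays.
def remap_axiswise_dim(old_dim, dim0, dim1, ndim):
    if old_dim is None:
        return None  # tensorwise scaling: nothing to remap
    old_dim, dim0, dim1 = old_dim % ndim, dim0 % ndim, dim1 % ndim
    if old_dim == dim0:
        return dim1
    if old_dim == dim1:
        return dim0
    return old_dim

assert remap_axiswise_dim(-1, -2, -1, ndim=3) == 1  # 3D tensor, last two dims swapped
assert remap_axiswise_dim(0, 0, 1, ndim=2) == 1     # 2D transpose
assert remap_axiswise_dim(None, 0, 1, ndim=2) is None
```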

torchao/float8/float8_tensor.py

+1 -1
@@ -313,7 +313,7 @@ def __new__(
             linear_mm_config if linear_mm_config is not None else LinearMMConfig()
         )
         self._gemm_input_role = gemm_input_role
-        assert axiswise_dim in (None, 0, -1), f"unsupported axiswise_dim {axiswise_dim}"
+        assert axiswise_dim in (None, 0, 1, -1), f"unsupported axiswise_dim {axiswise_dim}"
         self._axiswise_dim = axiswise_dim
 
         return self
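Permitting axiswise_dim == 1 is what lets a 3D B tensor be handed in already transposed: row-wise scales computed along the last dim of B are exactly the scales one would compute along dim 1 of B.transpose(-2, -1). A small pure-PyTorch check of that equivalence (illustrative only, with amax standing in for the float8 scale computation):

```python
import torch

b = torch.randn(4, 8, 16)
# Scales along the last dim of b ...
scale_last = b.abs().amax(dim=-1, keepdim=True)                   # (4, 8, 1)
# ... match scales along dim 1 of b with its last two dims swapped.
scale_dim1 = b.transpose(-2, -1).abs().amax(dim=1, keepdim=True)  # (4, 1, 8)
assert torch.equal(scale_last.transpose(-2, -1), scale_dim1)
```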

torchao/prototype/grouped_mm/__init__.py

+9 -10
@@ -47,7 +47,7 @@ class _Float8GroupedMM(torch.autograd.Function):
     def forward(
         ctx,
         A: torch.Tensor,
-        B: torch.Tensor,
+        B_t: torch.Tensor,
         float8_recipe_name: Float8LinearRecipeName,
         offs: Optional[torch.Tensor] = None,
         out_dtype: Optional[torch.dtype] = None,
@@ -60,7 +60,7 @@ def forward(
 
         # perform dynamic float8 quantization using the given recipe, if specified
         assert 2 <= A.ndim <= 3, "A must be 2D or 3D"
-        assert B.ndim == 3, "B must be 3D"
+        assert B_t.ndim == 3, "B must be 3D"
         if A.ndim == 2:
             assert offs is not None, "offs must be specified for 2D A tensor"
         else:
@@ -83,34 +83,33 @@ def forward(
         )
 
         # Convert high precision weight tensor to float8.
-        B_fp8 = hp_tensor_to_float8_dynamic(
-            B,
+        B_t_fp8 = hp_tensor_to_float8_dynamic(
+            B_t,
             float8_config.cast_config_input.target_dtype,
             linear_mm_config=LinearMMConfig(),
             gemm_input_role=GemmInputRole.WEIGHT,
             scaling_granularity=float8_config.cast_config_weight.scaling_granularity,
             axiswise_dim=get_maybe_axiswise_dim(
-                -1, float8_config.cast_config_input.scaling_granularity
+                1, float8_config.cast_config_input.scaling_granularity
             ),
             round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
         )
-        B_fp8_t = B_fp8.transpose(-2, -1)
 
         # Store what we need for backward.
-        ctx.save_for_backward(A, B)
+        ctx.save_for_backward(A, B_t)
         ctx.float_config = float8_config
        ctx.offs = offs
 
         # For rowwise scaling, torch._scaled_grouped_mm requires scales without any empty dims.
         A_fp8._scale = A_fp8._scale.squeeze()
-        B_fp8_t._scale = B_fp8_t._scale.squeeze()
+        B_t_fp8._scale = B_t_fp8._scale.squeeze()
 
         # Perform scaled grouped GEMM and return result.
         return torch._scaled_grouped_mm(
             A_fp8._data,
-            B_fp8_t._data,
+            B_t_fp8._data,
             A_fp8._scale,
-            B_fp8_t._scale,
+            B_t_fp8._scale,
             offs,
             out_dtype=out_dtype,
             use_fast_accum=use_fast_accum,
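With the renamed argument, callers pass the 3D weight already transposed and forward() quantizes it along dim 1 directly, so the later Float8Tensor transpose (and the tensorwise assert it would now hit) is no longer needed. A hedged sketch of the new calling convention; the import paths, the ROWWISE recipe name, and the shapes are assumptions based on the surrounding code and tests, and a GPU supported by torch._scaled_grouped_mm is required:

```python
import torch

# Assumed import locations; adjust if the actual module paths differ.
from torchao.float8.config import Float8LinearRecipeName
from torchao.prototype.grouped_mm import _grouped_scaled_mm

G, M, K, N = 4, 16, 64, 32
a = torch.randn(G * M, K, device="cuda", dtype=torch.bfloat16)  # 2D A, groups stacked along rows
b = torch.randn(G, N, K, device="cuda", dtype=torch.bfloat16)   # 3D B in its natural layout
offs = torch.arange(M, G * M + 1, M, device="cuda", dtype=torch.int32)

# After this commit the caller passes B with its last two dims transposed,
# matching the axiswise_dim=1 quantization inside forward().
out = _grouped_scaled_mm(
    a,
    b.transpose(-2, -1),
    offs=offs,
    float8_recipe=Float8LinearRecipeName.ROWWISE,  # assumed recipe name; TENSORWISE is rejected
    out_dtype=torch.bfloat16,
)
```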

torchao/prototype/grouped_mm/test_grouped_mm.py

+9 -10
@@ -30,7 +30,7 @@ def test_grouped_gemm_2d_3d(use_fast_accum, strided):
     offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32)
     result = _grouped_scaled_mm(
         a,
-        b,
+        b.transpose(-2, -1),
         offs=offs,
         float8_recipe=float8_recipe_name,
         out_dtype=out_dtype,
@@ -41,7 +41,7 @@ def test_grouped_gemm_2d_3d(use_fast_accum, strided):
     validate_grouped_mm(
         result,
         a,
-        b,
+        b.transpose(-2, -1),
         n_groups,
         out_dtype,
         use_fast_accum,
@@ -67,7 +67,7 @@ def test_grouped_gemm_3d_3d(use_fast_accum, strided):
     ]
     result = _grouped_scaled_mm(
         a,
-        b,
+        b.transpose(-2, -1),
         float8_recipe=float8_recipe_name,
         out_dtype=out_dtype,
         use_fast_accum=use_fast_accum,
@@ -77,7 +77,7 @@ def test_grouped_gemm_3d_3d(use_fast_accum, strided):
     validate_grouped_mm(
         result,
         a,
-        b,
+        b.transpose(-2, -1),
         n_groups,
         out_dtype,
         use_fast_accum,
@@ -95,7 +95,7 @@ def test_tensorwise_scaling_not_supported():
     with pytest.raises(AssertionError):
         _grouped_scaled_mm(
             a,
-            b,
+            b.transpose(-2, -1),
             offs=offs,
             float8_recipe=Float8LinearRecipeName.TENSORWISE,
             out_dtype=torch.bfloat16,
@@ -131,8 +131,8 @@ def validate_grouped_mm(
         round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
     )
 
-    B_fp8 = hp_tensor_to_float8_dynamic(
-        B,
+    B_t_fp8 = hp_tensor_to_float8_dynamic(
+        B.transpose(-2, -1),
         float8_config.cast_config_input.target_dtype,
         linear_mm_config=LinearMMConfig(),
         gemm_input_role=GemmInputRole.WEIGHT,
@@ -142,11 +142,10 @@ def validate_grouped_mm(
         ),
         round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
     )
-    B_fp8_t = B_fp8.transpose(-2, -1)
 
     # grouped_scaled_mm doesn't support empty dims
     scale_A = A_fp8._scale.squeeze()
-    scale_B = B_fp8_t._scale.squeeze()
+    scale_B = B_t_fp8._scale.squeeze()
 
     A_list, B_list, A_scale_list, B_scale_list, result_list = [], [], [], [], []
     start = 0
@@ -160,7 +159,7 @@ def validate_grouped_mm(
             start = offs_cpu[i]
     else:
         A_list = A_fp8._data
-        B_list = B_fp8_t._data
+        B_list = B_t_fp8._data
 
         A_scale_list = scale_A
         B_scale_list = scale_B
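validate_grouped_mm now mirrors the same convention, quantizing B.transpose(-2, -1) along dim 1 before building its per-group reference. A minimal pure-PyTorch sketch of that reference structure for the 2D-A case (plain matmul, no float8; names and shapes are illustrative):

```python
import torch

G, M, K, N = 2, 4, 8, 6
a = torch.randn(G * M, K)
b = torch.randn(G, N, K)
b_t = b.transpose(-2, -1)                 # (G, K, N), as passed after this commit
offs = torch.arange(M, G * M + 1, M)      # row offsets marking the end of each group in A

# Split A's rows by offs and matmul each chunk with the matching expert of b_t.
chunks, start = [], 0
for i, end in enumerate(offs.tolist()):
    chunks.append(a[start:end] @ b_t[i])  # (M, K) @ (K, N) -> (M, N)
    start = end
out_ref = torch.cat(chunks)               # (G * M, N)
```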
