
Commit c4c6c99

update comments
1 parent 72a9b9f commit c4c6c99

1 file changed (+16, -18 lines)


torchao/prototype/grouped_mm/__init__.py

@@ -58,25 +58,29 @@ def forward(
             float8_recipe_name == Float8LinearRecipeName.ROWWISE
         ), "Only rowwise scaling is supported by torch._scaled_grouped_mm."

-        # perform dynamic float8 quantization using the given recipe, if specified
+
         assert 2 <= A.ndim <= 3, "A must be 2D or 3D"
-        assert B.ndim == 3, "B must be 3D"
+        assert 2 <= B.ndim <= 3, "B must be 2D or 3D"

         # Dim 1 of B must match the final dim of A.
-        assert B.size(1) == A.size(-1), "Dim 1 of B must match the final dim of A"
+        assert A.size(-1) == B.size(-2), f"shape {A.shape} and {B.shape} are not compatible for _scaled_grouped_mm"

         # offsets are required for 2D A tensor, otherwise it should be None.
-        if A.ndim == 2:
-            assert offs is not None, "offs must be specified for 2D A tensor"
+        if A.ndim == 2 or B.ndim == 2:
+            assert offs is not None, "offs must be specified for 2D tensor"
         else:
-            assert offs is None, "offs must not be specified for 3D A tensor"
+            assert offs is None, "offs must not be specified for 3D tensor"

         # TODO: pad dims to be multiples of 16, as required by torch._scaled_grouped_mm.

         # Fetch float8 config from specified recipe name.
         float8_config = Float8LinearConfig.from_recipe_name(float8_recipe_name)

         # Convert high precision input tensor to float8.
+        # A shape: (M, K) or (B, M, K)
+        # A_scale shape: (M, 1) or (B, M, 1)
+        # torch._scaled_grouped_mm requires scales without any empty dims, so squeeze A_scale.
+        # A_scale shape: (M,) or (B, M)
         A_fp8 = hp_tensor_to_float8_dynamic(
             A,
             float8_config.cast_config_input.target_dtype,
@@ -88,12 +92,14 @@ def forward(
             ),
             round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
         )
-        # A shape: (M, K)
-        # A_scale shape: (M,1)
-        # squeeze A_scale to be 1D for 2D parent tensor, as required in _scaled_grouped_mm
-        # A_scale shape: (M,)
+        A_fp8_scale = A_fp8._scale.squeeze()

         # Convert high precision weight tensor to float8.
+        # B shape: (K, N) or (B, K, N)
+        # B scales must be computed rowwise keeping the outer/final dim, so:
+        # B_scale shape: (1, N) or (B, 1, N)
+        # torch._scaled_grouped_mm requires scales without any empty dims, so squeeze B_scale.
+        # B_scale shape: (N,) or (B, N)
         B_fp8 = hp_tensor_to_float8_dynamic(
             B,
             float8_config.cast_config_input.target_dtype,
@@ -105,20 +111,12 @@ def forward(
             ),
             round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
         )
-        # B shape: (B, 1, N)
-        # B scales must be computed along the outer/final dim, so B_scale shape: (B, 1, N)
-        # squeeze B_scale to be 2D for parent 3D tensor, as required in _scaled_grouped_mm
-        # B scale shape: (B, N)

         # Store what we need for backward.
         ctx.save_for_backward(A, B)
         ctx.float_config = float8_config
         ctx.offs = offs

-        # For rowwise scaling, torch._scaled_grouped_mm requires scales without any empty dims.
-        A_fp8._scale = A_fp8._scale.squeeze()
-        B_fp8._scale = B_fp8._scale.squeeze()
-
         # Perform scaled grouped GEMM and return result.
         return torch._scaled_grouped_mm(
             A_fp8._data,
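
The comment updates in this commit are all about scale shapes. Below is a minimal sketch of that shape bookkeeping only, not the torchao code path: hp_tensor_to_float8_dynamic computes scales differently, and rowwise_scale_along is a hypothetical helper using simple amax-based scaling. It shows why A's rowwise scale comes out as (M, 1) or (B, M, 1), B's as (1, N) or (B, 1, N), and how squeezing drops the size-1 dims that torch._scaled_grouped_mm does not accept.

# Hypothetical sketch of the rowwise-scale shapes described in the updated comments.
# Plain amax-based scaling is used purely to show shapes; not the torchao implementation.
import torch

def rowwise_scale_along(t: torch.Tensor, dim: int) -> torch.Tensor:
    # Reduce |t| over `dim`, keeping it as a size-1 dim so the scale broadcasts over t.
    amax = t.abs().amax(dim=dim, keepdim=True).clamp(min=1e-12)
    return torch.finfo(torch.float8_e4m3fn).max / amax

# 2D case: A is (M, K), B is (K, N).
A, B = torch.randn(16, 32), torch.randn(32, 64)
A_scale = rowwise_scale_along(A, dim=-1)  # (M, 1): one scale per row of A
B_scale = rowwise_scale_along(B, dim=-2)  # (1, N): one scale per column of B, outer dim kept
print(A_scale.shape, B_scale.shape)                      # torch.Size([16, 1]) torch.Size([1, 64])
print(A_scale.squeeze().shape, B_scale.squeeze().shape)  # torch.Size([16]) torch.Size([64])

# 3D case: A is (G, M, K), B is (G, K, N); squeeze() drops only the size-1 dims,
# so the group dim survives and the scales end up as (G, M) and (G, N).
A3, B3 = torch.randn(4, 16, 32), torch.randn(4, 32, 64)
A3_scale = rowwise_scale_along(A3, dim=-1).squeeze()  # (G, M)
B3_scale = rowwise_scale_along(B3, dim=-2).squeeze()  # (G, N)
print(A3_scale.shape, B3_scale.shape)                    # torch.Size([4, 16]) torch.Size([4, 64])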
