Commit 4f385e5

add detailed comments
1 parent dc013a3 commit 4f385e5

1 file changed: +21 -0 lines changed


torchao/prototype/grouped_mm/__init__.py

@@ -88,6 +88,10 @@ def forward(
            ),
            round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
        )
+        # A shape: (M, K)
+        # A_scale shape: (M, 1)
+        # squeeze A_scale to be 1D for the 2D tensor A in _scaled_grouped_mm
+        # A_scale shape: (M,)

        # Convert high precision weight tensor to float8.
        B_fp8 = hp_tensor_to_float8_dynamic(
@@ -101,6 +105,23 @@ def forward(
            ),
            round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
        )
+        # B shape: (B, K, N) => compatible for matmul with A shape (M, K): (M, K) @ (B, K, N) = (B, M, N)
+        # B_scale shape: (B, K, 1) (using axiswise_dim=-1)
+        # squeeze B_scale to be 2D for the 3D tensor B in _scaled_grouped_mm
+        # B_scale shape: (B, K)
+
+        # This fails the check in _scaled_grouped_mm: "scale.size(1) == mat.size(1 + dim)" where dim=1 for matrix B, because K != N
+        # check scale call: https://github.com/pytorch/pytorch/blob/d25acac357ff8663a7787e57e6bc5e69987a8f9a/aten/src/ATen/native/cuda/Blas.cpp#L1530
+        # failure: https://github.com/pytorch/pytorch/blob/d25acac357ff8663a7787e57e6bc5e69987a8f9a/aten/src/ATen/native/cuda/Blas.cpp#L1458-L1461
+
+        # To solve this, I changed axiswise_dim to 1, so the scale shape becomes:
+        # B_scale shape: (B, 1, N)
+        # squeeze B_scale to be 2D for the 3D tensor B in _scaled_grouped_mm
+        # B_scale shape: (B, N)
+        # This passes the check in _scaled_grouped_mm
+
+        # TODO: allowing axiswise_dim to be 1 breaks assumptions in torchao,
+        # so we need to either design long-term support for this, or change the requirement in torch._scaled_grouped_mm

        # Store what we need for backward.
        ctx.save_for_backward(A, B)
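
The shape bookkeeping in these comments can be checked with a few lines of plain PyTorch. Below is a minimal, shape-only sketch, not the torchao implementation: the sizes G, M, K, N are arbitrary assumptions, and an abs-max reduction stands in for the real float8 scale computation. It shows why a B_scale produced with axiswise_dim=-1 squeezes to (B, K) and trips the quoted size check, while axiswise_dim=1 squeezes to (B, N) and passes.

import torch

# Illustrative sizes only; G plays the role of the group/batch dim called "B" in the comments above.
G, M, K, N = 4, 16, 32, 64
A = torch.randn(M, K)
B = torch.randn(G, K, N)

# A is scaled rowwise over K: (M, K) -> (M, 1), then squeezed to (M,).
A_scale = A.abs().amax(dim=-1, keepdim=True)   # (M, 1)
A_scale = A_scale.squeeze(-1)                  # (M,)

# axiswise_dim=-1 reduces over N: (G, K, N) -> (G, K, 1) -> squeeze -> (G, K)
B_scale_last = B.abs().amax(dim=-1, keepdim=True).squeeze(-1)   # (G, K)

# axiswise_dim=1 reduces over K: (G, K, N) -> (G, 1, N) -> squeeze -> (G, N)
B_scale_dim1 = B.abs().amax(dim=1, keepdim=True).squeeze(1)     # (G, N)

# The quoted check, scale.size(1) == mat.size(1 + dim) with dim=1 for B,
# requires B_scale.size(1) == B.size(2) == N.
print(B_scale_last.size(1) == B.size(2))   # False: K != N, so the (G, K) scale is rejected
print(B_scale_dim1.size(1) == B.size(2))   # True: the (G, N) scale passes

Whether flipping axiswise_dim to 1 is the right long-term fix, as opposed to relaxing the requirement in torch._scaled_grouped_mm, is exactly the open question raised in the TODO above.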
