Commit 72a9b9f

update comments
1 parent 4f385e5 commit 72a9b9f

1 file changed: +5 -18 lines changed


torchao/prototype/grouped_mm/__init__.py

@@ -90,7 +90,7 @@ def forward(
     )
     # A shape: (M, K)
     # A_scale shape: (M,1)
-    # squeeze A_scale to be 1D for 2D tensor _scaled_grouped_mm
+    # squeeze A_scale to be 1D for 2D parent tensor, as required in _scaled_grouped_mm
     # A_scale shape: (M,)

     # Convert high precision weight tensor to float8.
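
Aside (not part of the diff): a minimal shape sketch of the A-side scale handling the updated comment describes. Plain torch only; the amax-based scale computation, the float8_e4m3fn dtype, and the sizes M and K are illustrative assumptions, not the torchao implementation.

import torch

# Rowwise float8 scales for the 2D operand A: one scale per row, shape (M, 1),
# then squeezed to 1D (M,) as described in the comment above for _scaled_grouped_mm.
M, K = 16, 32
A = torch.randn(M, K)

# Illustrative rowwise scale computation (amax over K, keepdim gives (M, 1)).
A_scale = A.abs().amax(dim=1, keepdim=True) / torch.finfo(torch.float8_e4m3fn).max

# Squeeze to 1D: (M, 1) -> (M,).
A_scale_1d = A_scale.squeeze(-1)
assert A_scale_1d.shape == (M,)
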
@@ -105,23 +105,10 @@ def forward(
     ),
     round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
     )
-    # B shape: (B,K,N) => this is compatible for matmul with A shape: (M,K) @ (B,K,N) = (B,M,N)
-    # B_scale shape: (B,K,1) => (using axiswise_dim=-1)
-    # squeeze A_scale to be 2D for 3D tensor in _scaled_grouped_mm
-    # B_scale shape: (B,K)
-
-    # This fails the check in _scaled_grouped_mm here: "scale.size(1) == mat.size(1 + dim)" where dim=1 for matrix B, because K != N
-    # check scale call: https://github.com/pytorch/pytorch/blob/d25acac357ff8663a7787e57e6bc5e69987a8f9a/aten/src/ATen/native/cuda/Blas.cpp#L1530
-    # failure: https://github.com/pytorch/pytorch/blob/d25acac357ff8663a7787e57e6bc5e69987a8f9a/aten/src/ATen/native/cuda/Blas.cpp#L1458-L1461
-
-    # To solve this, I changed axiswise_dim to 1, so scale shape becomes:
-    # B_scale shape: (B,1,N)
-    # squeeze A_scale to be 2D for 3D tensor in _scaled_grouped_mm
-    # B_scale shape: (B,N)
-    # This passes the check in _scaled_grouped_mm
-
-    # TODO: allowing axiswise_dim to be 1 breaks assumptions in torchao,
-    # so we need to either design long term support for this, or change the requirement in torch._scaled_grouped_mm
+    # B shape: (B, K, N)
+    # B scales must be computed along the outer/final dim, so B_scale shape: (B, 1, N)
+    # squeeze B_scale to be 2D for parent 3D tensor, as required in _scaled_grouped_mm
+    # B_scale shape: (B, N)

     # Store what we need for backward.
     ctx.save_for_backward(A, B)
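
Aside (not part of the diff): a shape sketch of the B-side scale layout the new comments describe, and of the "scale.size(1) == mat.size(1 + dim)" check quoted in the removed comments. Plain torch only; the amax-based scaling, float8_e4m3fn, and the sizes G, K, N are illustrative assumptions, not the torchao implementation (the group/batch dim called "B" in the comments is named G here to avoid clashing with the tensor name).

import torch

# For the 3D weight B of shape (B, K, N), scales are computed along the
# outer/final dim, giving shape (B, 1, N), then squeezed to 2D (B, N).
# Per the removed comments, _scaled_grouped_mm checks
# scale.size(1) == mat.size(1 + dim) with dim=1 for B, i.e. the squeezed
# scale's second dim must equal N, not K.
G, K, N = 4, 32, 64
B = torch.randn(G, K, N)

# Illustrative scale per (group, column): reduce over K (dim=1), keepdim gives (G, 1, N).
B_scale = B.abs().amax(dim=1, keepdim=True) / torch.finfo(torch.float8_e4m3fn).max
assert B_scale.shape == (G, 1, N)

# Squeeze to 2D for the 3D parent tensor: (G, 1, N) -> (G, N).
B_scale_2d = B_scale.squeeze(1)
assert B_scale_2d.shape == (G, N)
assert B_scale_2d.size(1) == B.size(2)  # matches the check described above
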
