@@ -90,7 +90,7 @@ def forward(
)
# A shape: (M, K)
# A_scale shape: (M,1)
- # squeeze A_scale to be 1D for 2D tensor _scaled_grouped_mm
+ # squeeze A_scale to be 1D for the 2D parent tensor, as required by _scaled_grouped_mm
# A_scale shape: (M,)
# Convert high precision weight tensor to float8.
@@ -105,23 +105,10 @@ def forward(
),
round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
)
- # B shape: (B,K,N) => this is compatible for matmul with A shape: (M,K) @ (B,K,N) = (B,M,N)
- # B_scale shape: (B,K,1) => (using axiswise_dim=-1)
- # squeeze A_scale to be 2D for 3D tensor in _scaled_grouped_mm
- # B_scale shape: (B,K)
-
- # This fails the check in _scaled_grouped_mm here: "scale.size(1) == mat.size(1 + dim)" where dim=1 for matrix B, because K != N
- # check scale call: https://github.com/pytorch/pytorch/blob/d25acac357ff8663a7787e57e6bc5e69987a8f9a/aten/src/ATen/native/cuda/Blas.cpp#L1530
- # failure: https://github.com/pytorch/pytorch/blob/d25acac357ff8663a7787e57e6bc5e69987a8f9a/aten/src/ATen/native/cuda/Blas.cpp#L1458-L1461
-
- # To solve this, I changed axiswise_dim to 1, so scale shape becomes:
- # B_scale shape: (B,1,N)
- # squeeze A_scale to be 2D for 3D tensor in _scaled_grouped_mm
- # B_scale shape: (B,N)
- # This passes the check in _scaled_grouped_mm
-
- # TODO: allowing axiswise_dim to be 1 breaks assumptions in torchao,
- # so we need to either design long term support for this, or change the requirement in torch._scaled_grouped_mm
+ # B shape: (B, K, N)
+ # B scales are computed along dim 1 (the K dim), so B_scale shape: (B, 1, N)
+ # squeeze B_scale to be 2D for the 3D parent tensor, as required by _scaled_grouped_mm
+ # B_scale shape: (B, N)
# Store what we need for backward.
ctx.save_for_backward(A, B)
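
For reference, a minimal shape sketch of the layout the new comments describe, assuming a CUDA device with float8 support (e.g. H100). torch._scaled_grouped_mm is a private PyTorch API, so the exact signature and layout checks may vary across versions; the sizes G, M, K, N, the dummy unit scales, and the equal-sized groups below are hypothetical.

import torch

# Sketch only: requires a recent PyTorch build and a float8-capable GPU.
G, M, K, N = 4, 256, 128, 64  # hypothetical group count and sizes

# 2D A: row-major float8, rows partitioned into G groups; one scale per row.
A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
A_scale = torch.ones(M, device="cuda")  # (M, 1) squeezed to (M,)

# 3D B: one float8 weight matrix per group, transposed so the last two dims
# are column-major (a common layout requirement for these kernels).
B = torch.randn(G, N, K, device="cuda").to(torch.float8_e4m3fn).transpose(-2, -1)
B_scale = torch.ones(G, N, device="cuda")  # (B, 1, N) squeezed to (B, N)

# Cumulative int32 offsets marking where each group of A's rows ends.
offs = torch.arange(M // G, M + 1, M // G, device="cuda", dtype=torch.int32)

out = torch._scaled_grouped_mm(A, B, A_scale, B_scale, offs=offs,
                               out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([256, 64])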