@@ -88,6 +88,10 @@ def forward(
),
round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
)
+ # A shape: (M, K)
+ # A_scale shape: (M, 1)
+ # squeeze A_scale to be 1D for the 2D operand of _scaled_grouped_mm
+ # A_scale shape: (M,)
# Convert high precision weight tensor to float8.
B_fp8 = hp_tensor_to_float8_dynamic(
@@ -101,6 +105,23 @@ def forward(
),
round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
)
+ # B shape: (B, K, N) => compatible for matmul with A: (M, K) @ (B, K, N) = (B, M, N)
+ # B_scale shape: (B, K, 1) (using axiswise_dim=-1)
+ # squeeze B_scale to be 2D for the 3D operand of _scaled_grouped_mm
+ # B_scale shape: (B, K)
+
+ # This fails the check in _scaled_grouped_mm, "scale.size(1) == mat.size(1 + dim)" with dim=1 for matrix B, because K != N.
+ # scale check: https://github.com/pytorch/pytorch/blob/d25acac357ff8663a7787e57e6bc5e69987a8f9a/aten/src/ATen/native/cuda/Blas.cpp#L1530
+ # failure: https://github.com/pytorch/pytorch/blob/d25acac357ff8663a7787e57e6bc5e69987a8f9a/aten/src/ATen/native/cuda/Blas.cpp#L1458-L1461
+
+ # To solve this, I changed axiswise_dim to 1, so the scale shape becomes:
+ # B_scale shape: (B, 1, N)
+ # squeeze B_scale to be 2D for the 3D operand of _scaled_grouped_mm
+ # B_scale shape: (B, N)
+ # This passes the check in _scaled_grouped_mm.
+
+ # TODO: allowing axiswise_dim to be 1 breaks assumptions in torchao,
+ # so we need to either design long-term support for this or change the requirement in torch._scaled_grouped_mm.
# Store what we need for backward.
ctx.save_for_backward(A, B)
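As a sanity check on the comments above, here is a minimal, standalone Python sketch of the scale-shape bookkeeping (not part of the diff). The sizes M, K, N and the group count E are made up, the group dimension is named E instead of B to avoid clashing with the weight tensor name, and the _scaled_grouped_mm scale check is mimicked with plain asserts rather than by calling the private op.

import torch

M, K, N, E = 16, 32, 64, 4        # hypothetical sizes; E = number of groups/experts

A = torch.randn(M, K)             # 2D activation
A_scale = torch.randn(M, 1)       # row-wise scale, one value per row of A
A_scale = A_scale.squeeze(-1)     # (M,): 1D scale for the 2D operand

B = torch.randn(E, K, N)          # 3D weight, one (K, N) matrix per group

# axiswise_dim=-1 produces an (E, K, 1) scale, squeezed to (E, K).
# The check quoted above, scale.size(1) == mat.size(1 + dim) with dim=1,
# compares against mat.size(2) == N, so an (E, K) scale is rejected when K != N.
B_scale_neg1 = torch.randn(E, K, 1).squeeze(-1)   # (E, K)
assert B_scale_neg1.size(1) != B.size(2)          # K != N -> check fails

# axiswise_dim=1 produces an (E, 1, N) scale, squeezed to (E, N), which passes.
B_scale_dim1 = torch.randn(E, 1, N).squeeze(1)    # (E, N)
assert B_scale_dim1.size(1) == B.size(2)          # N == N -> check passes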