@@ -58,25 +58,29 @@ def forward(
             float8_recipe_name == Float8LinearRecipeName.ROWWISE
         ), "Only rowwise scaling is supported by torch._scaled_grouped_mm."
 
-        # perform dynamic float8 quantization using the given recipe, if specified
+
         assert 2 <= A.ndim <= 3, "A must be 2D or 3D"
-        assert B.ndim == 3, "B must be 3D"
+        assert 2 <= B.ndim <= 3, "B must be 2D or 3D"
 
         # Dim 1 of B must match the final dim of A.
-        assert B.size(1) == A.size(-1), "Dim 1 of B must match the final dim of A"
+        assert A.size(-1) == B.size(-2), f"shape {A.shape} and {B.shape} are not compatible for _scaled_grouped_mm"
 
         # offsets are required for 2D A tensor, otherwise it should be None.
-        if A.ndim == 2:
-            assert offs is not None, "offs must be specified for 2D A tensor"
+        if A.ndim == 2 or B.ndim == 2:
+            assert offs is not None, "offs must be specified for 2D tensor"
         else:
-            assert offs is None, "offs must not be specified for 3D A tensor"
+            assert offs is None, "offs must not be specified for 3D tensor"
 
         # TODO: pad dims to be multiples of 16, as required by torch._scaled_grouped_mm.
 
         # Fetch float8 config from specified recipe name.
         float8_config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
 
         # Convert high precision input tensor to float8.
+        # A shape: (M, K) or (B, M, K)
+        # A_scale shape: (M, 1) or (B, M, 1)
+        # torch._scaled_grouped_mm requires scales without any empty dims, so squeeze A_scale.
+        # A_scale shape: (M,) or (B, M)
         A_fp8 = hp_tensor_to_float8_dynamic(
             A,
             float8_config.cast_config_input.target_dtype,
@@ -88,12 +92,14 @@ def forward(
             ),
             round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
         )
-        # A shape: (M, K)
-        # A_scale shape: (M, 1)
-        # squeeze A_scale to be 1D for 2D parent tensor, as required in _scaled_grouped_mm
-        # A_scale shape: (M,)
+        A_fp8_scale = A_fp8._scale.squeeze()
 
         # Convert high precision weight tensor to float8.
+        # B shape: (K, N) or (B, K, N)
+        # B scales must be computed rowwise keeping the outer/final dim, so:
+        # B_scale shape: (1, N) or (B, 1, N)
+        # torch._scaled_grouped_mm requires scales without any empty dims, so squeeze B_scale.
+        # B_scale shape: (N,) or (B, N)
         B_fp8 = hp_tensor_to_float8_dynamic(
             B,
             float8_config.cast_config_input.target_dtype,
@@ -105,20 +111,12 @@ def forward(
             ),
             round_scales_to_power_of_2=float8_config.round_scales_to_power_of_2,
         )
-        # B shape: (B, 1, N)
-        # B scales must be computed along the outer/final dim, so B_scale shape: (B, 1, N)
-        # squeeze B_scale to be 2D for parent 3D tensor, as required in _scaled_grouped_mm
-        # B scale shape: (B, N)
 
         # Store what we need for backward.
         ctx.save_for_backward(A, B)
         ctx.float_config = float8_config
         ctx.offs = offs
 
-        # For rowwise scaling, torch._scaled_grouped_mm requires scales without any empty dims.
-        A_fp8._scale = A_fp8._scale.squeeze()
-        B_fp8._scale = B_fp8._scale.squeeze()
-
         # Perform scaled grouped GEMM and return result.
         return torch._scaled_grouped_mm(
             A_fp8._data,
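
For reference, a minimal sketch of the input contract enforced by the new asserts. The helper name check_scaled_grouped_mm_inputs and the example shapes and offsets below are illustrative assumptions, not code from this PR.

import torch

def check_scaled_grouped_mm_inputs(A, B, offs):
    # Same validation the diff adds: each operand may be 2D or 3D,
    # the contraction dims must line up, and offs is only used when a
    # 2D operand has to be split into groups.
    assert 2 <= A.ndim <= 3, "A must be 2D or 3D"
    assert 2 <= B.ndim <= 3, "B must be 2D or 3D"
    assert A.size(-1) == B.size(-2), (
        f"shape {A.shape} and {B.shape} are not compatible for _scaled_grouped_mm"
    )
    if A.ndim == 2 or B.ndim == 2:
        assert offs is not None, "offs must be specified for 2D tensor"
    else:
        assert offs is None, "offs must not be specified for 3D tensor"

# Example: 2D activations grouped along dim 0, with one 3D weight per group.
A = torch.randn(64, 128, dtype=torch.bfloat16)            # (M, K)
B = torch.randn(4, 128, 256, dtype=torch.bfloat16)        # (num_groups, K, N)
offs = torch.tensor([8, 24, 40, 64], dtype=torch.int32)   # illustrative group boundaries
check_scaled_grouped_mm_inputs(A, B, offs)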