File tree 1 file changed +4
-1
lines changed
torchao/prototype/grouped_mm
1 file changed +4
-1
lines changed Original file line number Diff line number Diff line change @@ -24,7 +24,7 @@ def _grouped_scaled_mm(
24
24
25
25
Args:
26
26
A (torch.Tensor): The first input tensor, which can be 2D or 3D.
27
- B (torch.Tensor): The second input tensor which must be 3D.
27
+ B (torch.Tensor): The second input tensor which must be 3D. Dim 1 of B must match the final dim of A.
28
28
float8_recipe (Float8LinearRecipeName): The recipe to use for dynamic float8 quantization.
29
29
offs (Optional[torch.Tensor]): The offsets to use to mark the starting index of each group. This
30
30
is required when 2D A tensor is used, otherwise it should be None.
@@ -62,6 +62,9 @@ def forward(
62
62
assert 2 <= A .ndim <= 3 , "A must be 2D or 3D"
63
63
assert B .ndim == 3 , "B must be 3D"
64
64
65
+ # Dim 1 of B must match the final dim of A.
66
+ assert B .size (1 ) == A .size (- 1 ), "Dim 1 of B must match the final dim of A"
67
+
65
68
# offsets are required for 2D A tensor, otherwise it should be None.
66
69
if A .ndim == 2 :
67
70
assert offs is not None , "offs must be specified for 2D A tensor"
You can’t perform that action at this time.
0 commit comments