Skip to content

Commit dc013a3

Browse files
check input dims are compatible
1 parent 80b7630 commit dc013a3

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

torchao/prototype/grouped_mm/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def _grouped_scaled_mm(
2424
2525
Args:
2626
A (torch.Tensor): The first input tensor, which can be 2D or 3D.
27-
B (torch.Tensor): The second input tensor which must be 3D.
27+
B (torch.Tensor): The second input tensor which must be 3D. Dim 1 of B must match the final dim of A.
2828
float8_recipe (Float8LinearRecipeName): The recipe to use for dynamic float8 quantization.
2929
offs (Optional[torch.Tensor]): The offsets to use to mark the starting index of each group. This
3030
is required when 2D A tensor is used, otherwise it should be None.
@@ -62,6 +62,9 @@ def forward(
6262
assert 2 <= A.ndim <= 3, "A must be 2D or 3D"
6363
assert B.ndim == 3, "B must be 3D"
6464

65+
# Dim 1 of B must match the final dim of A.
66+
assert B.size(1) == A.size(-1), "Dim 1 of B must match the final dim of A"
67+
6568
# offsets are required for 2D A tensor, otherwise it should be None.
6669
if A.ndim == 2:
6770
assert offs is not None, "offs must be specified for 2D A tensor"

0 commit comments

Comments
 (0)