Skip to content

Commit 4117a9e

Browse files
clean up
1 parent 4e04022 commit 4117a9e

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torchao/prototype/grouped_mm/test_grouped_mm.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ def validate_grouped_mm(
112112
float8_recipe_name: Float8LinearRecipeName,
113113
offs: Optional[torch.Tensor] = None,
114114
):
115-
assert isinstance(result, torch.Tensor)
116115
assert result.dtype == out_dtype
117116

118117
# Validate output by comparing the partition of the grouped scaled mm output
@@ -150,6 +149,8 @@ def validate_grouped_mm(
150149
A_list, B_list, A_scale_list, B_scale_list, result_list = [], [], [], [], []
151150
start = 0
152151

152+
# If A is 2D, we need to split it into parts based on offs, so we can perform
153+
# separate _scaled_mm calls for each part.
153154
if A.ndim == 2 and offs is not None:
154155
offs_cpu = offs.cpu()
155156
for i in range(n_groups):

0 commit comments

Comments
 (0)