Skip to content

Commit 58076a9

Browse files
committed
Add minor test improvements
1 parent 5b51bac commit 58076a9

1 file changed

Lines changed: 32 additions & 4 deletions

File tree

tests/gemm/test_mm_mxfp8.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,21 @@ def _assert_cosine_similarity(
2626
if use_float:
2727
reference = reference.float()
2828
result = result.float()
29+
30+
# Check cosine similarity between reference and result
2931
cos_sim = F.cosine_similarity(
3032
reference.reshape(-1), result.reshape(-1), dim=0
3133
).item()
34+
3235
if context:
3336
message = (
3437
f"{context} Cosine similarity {cos_sim:.4f} is too low "
35-
f"(expected > {min_cos_sim})."
38+
f"(expected > {min_cos_sim}, {is_sf_swizzled_layout=})."
3639
)
3740
else:
3841
message = (
39-
f"Cosine similarity {cos_sim:.4f} is too low (expected > {min_cos_sim})"
42+
f"Cosine similarity {cos_sim:.4f} is too low "
43+
f"(expected > {min_cos_sim}, {is_sf_swizzled_layout=})."
4044
)
4145
assert cos_sim > min_cos_sim, message
4246
return cos_sim
@@ -113,7 +117,7 @@ def _prepare_mxfp8_tensors(input_bf16, weight_bf16, is_sf_swizzled_layout):
113117

114118
@pytest.mark.parametrize("m", [128, 256, 512, 1024])
115119
@pytest.mark.parametrize("n", [128, 256, 512, 1024])
116-
@pytest.mark.parametrize("k", [128, 256, 512, 1024, 2048, 2560])
120+
@pytest.mark.parametrize("k", [128, 256, 512, 1024, 2048, 2560, 3200])
117121
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
118122
@pytest.mark.parametrize("input_dtype", [torch.bfloat16])
119123
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@@ -136,7 +140,7 @@ def test_mm_mxfp8(
136140

137141

138142
@pytest.mark.parametrize("m", [128, 256, 512, 1024, 2048, 4096])
139-
@pytest.mark.parametrize("n", [4096, 8192, 12288, 16384])
143+
@pytest.mark.parametrize("n", [2688, 4096, 5376, 8192, 12288, 16384])
140144
@pytest.mark.parametrize("k", [4096, 8192])
141145
@pytest.mark.parametrize("is_sf_swizzled_layout", [True, False])
142146
@pytest.mark.parametrize("input_dtype", [torch.bfloat16])
@@ -158,6 +162,30 @@ def test_mm_mxfp8_large_dimensions(
158162
)
159163

160164

165+
@pytest.mark.parametrize(
166+
"m,n,k",
167+
[
168+
(32, 4096, 4096),
169+
(32, 2688, 1856),
170+
(32, 1856, 2688),
171+
(32, 2688, 4096),
172+
(32, 5376, 4096),
173+
],
174+
)
175+
def test_mm_mxfp8_small_m(m, n, k):
176+
_run_mm_mxfp8(
177+
m,
178+
n,
179+
k,
180+
torch.bfloat16,
181+
True, # swizzled scales are the intended fast path
182+
torch.bfloat16,
183+
"cutlass",
184+
auto_tuning=False,
185+
provide_out=True,
186+
)
187+
188+
161189
def _skip_if_unsupported():
162190
compute_capability = get_compute_capability(torch.device("cuda"))
163191
if compute_capability[0] in [11, 12]:

0 commit comments

Comments (0)