@@ -26,17 +26,21 @@ def _assert_cosine_similarity(
2626 if use_float :
2727 reference = reference .float ()
2828 result = result .float ()
29+
30+ # Check cosine similarity between reference and result
2931 cos_sim = F .cosine_similarity (
3032 reference .reshape (- 1 ), result .reshape (- 1 ), dim = 0
3133 ).item ()
34+
3235 if context :
3336 message = (
3437 f"{ context } Cosine similarity { cos_sim :.4f} is too low "
35- f"(expected > { min_cos_sim } )."
38+ f"(expected > { min_cos_sim } , { is_sf_swizzled_layout = } )."
3639 )
3740 else :
3841 message = (
39- f"Cosine similarity { cos_sim :.4f} is too low (expected > { min_cos_sim } )"
42+ f"Cosine similarity { cos_sim :.4f} is too low "
43+ f"(expected > { min_cos_sim } , { is_sf_swizzled_layout = } )."
4044 )
4145 assert cos_sim > min_cos_sim , message
4246 return cos_sim
@@ -113,7 +117,7 @@ def _prepare_mxfp8_tensors(input_bf16, weight_bf16, is_sf_swizzled_layout):
113117
114118@pytest .mark .parametrize ("m" , [128 , 256 , 512 , 1024 ])
115119@pytest .mark .parametrize ("n" , [128 , 256 , 512 , 1024 ])
116- @pytest .mark .parametrize ("k" , [128 , 256 , 512 , 1024 , 2048 , 2560 ])
120+ @pytest .mark .parametrize ("k" , [128 , 256 , 512 , 1024 , 2048 , 2560 , 3200 ])
117121@pytest .mark .parametrize ("is_sf_swizzled_layout" , [True , False ])
118122@pytest .mark .parametrize ("input_dtype" , [torch .bfloat16 ])
119123@pytest .mark .parametrize ("out_dtype" , [torch .bfloat16 , torch .float16 ])
@@ -136,7 +140,7 @@ def test_mm_mxfp8(
136140
137141
138142@pytest .mark .parametrize ("m" , [128 , 256 , 512 , 1024 , 2048 , 4096 ])
139- @pytest .mark .parametrize ("n" , [4096 , 8192 , 12288 , 16384 ])
143+ @pytest .mark .parametrize ("n" , [2688 , 4096 , 5376 , 8192 , 12288 , 16384 ])
140144@pytest .mark .parametrize ("k" , [4096 , 8192 ])
141145@pytest .mark .parametrize ("is_sf_swizzled_layout" , [True , False ])
142146@pytest .mark .parametrize ("input_dtype" , [torch .bfloat16 ])
@@ -158,6 +162,30 @@ def test_mm_mxfp8_large_dimensions(
158162 )
159163
160164
@pytest.mark.parametrize(
    "m,n,k",
    [
        (32, 4096, 4096),
        (32, 2688, 1856),
        (32, 1856, 2688),
        (32, 2688, 4096),
        (32, 5376, 4096),
    ],
)
def test_mm_mxfp8_small_m(m, n, k):
    """Exercise the MXFP8 matmul path on small-M problem shapes.

    Runs the shared ``_run_mm_mxfp8`` harness with bf16 in/out, the
    cutlass backend, a caller-provided output buffer, and auto-tuning
    disabled, using the swizzled scale-factor layout.
    """
    # Positional harness arguments: (m, n, k, input_dtype,
    # is_sf_swizzled_layout, out_dtype, backend).
    # is_sf_swizzled_layout=True — swizzled scales are the intended fast path.
    gemm_args = (m, n, k, torch.bfloat16, True, torch.bfloat16, "cutlass")
    _run_mm_mxfp8(*gemm_args, auto_tuning=False, provide_out=True)
187+
188+
161189def _skip_if_unsupported ():
162190 compute_capability = get_compute_capability (torch .device ("cuda" ))
163191 if compute_capability [0 ] in [11 , 12 ]:
0 commit comments