@@ -100,6 +100,7 @@ void mxfp8_bmm_impl(TensorView mat1, TensorView mat2, TensorView mat1Scale, Tens
100100 CHECK_INPUT_AND_TYPE (mat2Scale, SF_DTYPE);
101101
102102 int64_t m, n, k, b;
103+ // Scale validation for swizzled (1D) and non-swizzled (2D) layouts.
103104 if (mat1.ndim () == 2 ) {
104105 TVM_FFI_ICHECK_EQ (mat2.ndim (), 2 ) << " mat2 must be a matrix" ;
105106 // mat2 is passed as b.T, but TensorView reads underlying storage as [N, K]
@@ -130,6 +131,89 @@ void mxfp8_bmm_impl(TensorView mat1, TensorView mat2, TensorView mat1Scale, Tens
130131 TVM_FFI_LOG_AND_THROW (NotImplementedError) << " mat1 must be a matrix or a batch of matrices" ;
131132 }
132133
// Number of consecutive elements that share one scale factor (MXFP8 block size).
constexpr int64_t sfVecSize = 32;

// Returns how many scale factors are needed to cover `dim` elements, i.e.
// ceil(dim / sfVecSize). Nothing is captured: sfVecSize is a compile-time
// constant whose value is only read, so no capture default is required.
auto scale_len = [](int64_t dim) { return (dim + sfVecSize - 1) / sfVecSize; };
// Returns the element count expected for a 1D (swizzled) scale tensor that
// holds `rows` x `cols` scale factors: `rows` is rounded up to a multiple of
// 128, `cols` to a multiple of 4, and the two padded extents are multiplied.
// The lambda reads no enclosing state, so it captures nothing.
auto swizzled_len = [](int64_t rows, int64_t cols) {
  // Padding granularity applied to each extent by this length computation.
  constexpr int64_t kRowMultiple = 128;
  constexpr int64_t kColMultiple = 4;
  // Rounds `value` up to the nearest multiple of `multiple` (multiple > 0).
  const auto pad_up = [](int64_t value, int64_t multiple) {
    return (value + multiple - 1) / multiple * multiple;
  };
  return pad_up(rows, kRowMultiple) * pad_up(cols, kColMultiple);
};
144+
// Validate the shapes of mat1Scale and mat2Scale against the problem sizes
// (m, n, k, b). For each scale tensor: a 1D tensor is compared against the
// padded length returned by swizzled_len; any other rank must be exactly 2D
// and is compared dimension by dimension. Every failed check reports the
// expected and the actual value.
145+ if (mat1.ndim () == 2 ) {
// Single-matrix case. k_scales = ceil(k / sfVecSize) scale factors per row.
146+ const int64_t k_scales = scale_len (k);
147+ if (mat1Scale.ndim () == 1 ) {
// 1D (swizzled) mat1Scale: length must equal the padded m x k_scales extent.
148+ int64_t expected = swizzled_len (m, k_scales);
149+ TVM_FFI_ICHECK_EQ (mat1Scale.size (0 ), expected)
150+ << " mxfp8_bmm_impl: mat1Scale size mismatch, expected " << expected << " , got "
151+ << mat1Scale.size (0 );
152+ } else {
// Otherwise mat1Scale must be 2D with shape (m, k_scales).
153+ TVM_FFI_ICHECK_EQ (mat1Scale.ndim (), 2 )
154+ << " mxfp8_bmm_impl: mat1Scale must be 1D (swizzled) or 2D (non-swizzled), got "
155+ << mat1Scale.ndim ();
156+ TVM_FFI_ICHECK_EQ (mat1Scale.size (0 ), m)
157+ << " mxfp8_bmm_impl: mat1Scale size mismatch, expected " << m << " , got "
158+ << mat1Scale.size (0 );
159+ TVM_FFI_ICHECK_EQ (mat1Scale.size (1 ), k_scales)
160+ << " mxfp8_bmm_impl: mat1Scale size mismatch, expected " << k_scales << " , got "
161+ << mat1Scale.size (1 );
162+ }
163+
164+ if (mat2Scale.ndim () == 1 ) {
// 1D (swizzled) mat2Scale: length must equal the padded n x k_scales extent.
165+ int64_t expected = swizzled_len (n, k_scales);
166+ TVM_FFI_ICHECK_EQ (mat2Scale.size (0 ), expected)
167+ << " mxfp8_bmm_impl: mat2Scale size mismatch, expected " << expected << " , got "
168+ << mat2Scale.size (0 );
169+ } else {
// Otherwise mat2Scale must be 2D with shape (n, k_scales).
170+ TVM_FFI_ICHECK_EQ (mat2Scale.ndim (), 2 )
171+ << " mxfp8_bmm_impl: mat2Scale must be 1D (swizzled) or 2D (non-swizzled), got "
172+ << mat2Scale.ndim ();
173+ TVM_FFI_ICHECK_EQ (mat2Scale.size (0 ), n)
174+ << " mxfp8_bmm_impl: mat2Scale size mismatch, expected " << n << " , got "
175+ << mat2Scale.size (0 );
176+ TVM_FFI_ICHECK_EQ (mat2Scale.size (1 ), k_scales)
177+ << " mxfp8_bmm_impl: mat2Scale size mismatch, expected " << k_scales << " , got "
178+ << mat2Scale.size (1 );
179+ }
180+ } else {
// mat1 is not 2D here; presumably the batched case with batch count b —
// the rank handling lives in earlier code not fully visible in this view.
181+ const int64_t k_scales = scale_len (k);
182+ if (mat1Scale.ndim () == 1 ) {
// NOTE(review): this pads b * m as a single row extent rather than padding
// each batch's m separately; the two differ whenever m is not a multiple of
// 128 — confirm against the layout the scale producer actually emits.
183+ int64_t expected = swizzled_len (b * m, k_scales);
184+ TVM_FFI_ICHECK_EQ (mat1Scale.size (0 ), expected)
185+ << " mxfp8_bmm_impl: mat1Scale size mismatch, expected " << expected << " , got "
186+ << mat1Scale.size (0 );
187+ } else {
// NOTE(review): the 2D check here requires shape (b, scale_len(m)), i.e.
// (b, ceil(m / sfVecSize)), and never compares anything against k_scales,
// unlike the single-matrix branch which requires (m, k_scales). Confirm the
// intended non-swizzled batched layout (possibly 3D) before relying on this.
188+ TVM_FFI_ICHECK_EQ (mat1Scale.ndim (), 2 )
189+ << " mxfp8_bmm_impl: mat1Scale must be 1D (swizzled) or 2D (non-swizzled), got "
190+ << mat1Scale.ndim ();
191+ TVM_FFI_ICHECK_EQ (mat1Scale.size (0 ), b)
192+ << " mxfp8_bmm_impl: mat1Scale batch size mismatch, expected " << b << " , got "
193+ << mat1Scale.size (0 );
194+ TVM_FFI_ICHECK_EQ (mat1Scale.size (1 ), scale_len (m))
195+ << " mxfp8_bmm_impl: mat1Scale size mismatch, expected " << scale_len (m) << " , got "
196+ << mat1Scale.size (1 );
197+ }
198+
199+ if (mat2Scale.ndim () == 1 ) {
// Same structure as the mat1Scale check above, with n in place of m; the
// same review questions apply to the b * n padding and the 2D shape below.
200+ int64_t expected = swizzled_len (b * n, k_scales);
201+ TVM_FFI_ICHECK_EQ (mat2Scale.size (0 ), expected)
202+ << " mxfp8_bmm_impl: mat2Scale size mismatch, expected " << expected << " , got "
203+ << mat2Scale.size (0 );
204+ } else {
205+ TVM_FFI_ICHECK_EQ (mat2Scale.ndim (), 2 )
206+ << " mxfp8_bmm_impl: mat2Scale must be 1D (swizzled) or 2D (non-swizzled), got "
207+ << mat2Scale.ndim ();
208+ TVM_FFI_ICHECK_EQ (mat2Scale.size (0 ), b)
209+ << " mxfp8_bmm_impl: mat2Scale batch size mismatch, expected " << b << " , got "
210+ << mat2Scale.size (0 );
211+ TVM_FFI_ICHECK_EQ (mat2Scale.size (1 ), scale_len (n))
212+ << " mxfp8_bmm_impl: mat2Scale size mismatch, expected " << scale_len (n) << " , got "
213+ << mat2Scale.size (1 );
214+ }
215+ }
133217 // No heuristic for now, we rely on the autotuner to select the best tactic.
134218 if (tactic == -1 ) {
135219 tactic = 0 ;
0 commit comments