Skip to content

Commit 5b51bac

Browse files
committed
Add shape validation for mat1Scale and mat2Scale before kernel launch
1 parent 2dd82ea commit 5b51bac

1 file changed

Lines changed: 84 additions & 0 deletions

File tree

csrc/mxfp8_gemm_cutlass.cu

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ void mxfp8_bmm_impl(TensorView mat1, TensorView mat2, TensorView mat1Scale, Tens
100100
CHECK_INPUT_AND_TYPE(mat2Scale, SF_DTYPE);
101101

102102
int64_t m, n, k, b;
103+
// Scale tensors are validated further below, for both swizzled (1D) and non-swizzled (2D) layouts.
103104
if (mat1.ndim() == 2) {
104105
TVM_FFI_ICHECK_EQ(mat2.ndim(), 2) << "mat2 must be a matrix";
105106
// mat2 is passed as b.T, but TensorView reads underlying storage as [N, K]
@@ -130,6 +131,89 @@ void mxfp8_bmm_impl(TensorView mat1, TensorView mat2, TensorView mat1Scale, Tens
130131
TVM_FFI_LOG_AND_THROW(NotImplementedError) << "mat1 must be a matrix or a batch of matrices";
131132
}
132133

134+
constexpr int64_t sfVecSize = 32; // MXFP8 block size
135+
auto scale_len = [&](int64_t dim) { return (dim + sfVecSize - 1) / sfVecSize; };
136+
auto swizzled_len = [&](int64_t rows, int64_t cols) {
137+
auto pad_up = [](int64_t value, int64_t multiple) {
138+
return (value + multiple - 1) / multiple * multiple;
139+
};
140+
int64_t padded_rows = pad_up(rows, 128);
141+
int64_t padded_cols = pad_up(cols, 4);
142+
return padded_rows * padded_cols;
143+
};
144+
145+
if (mat1.ndim() == 2) {
146+
const int64_t k_scales = scale_len(k);
147+
if (mat1Scale.ndim() == 1) {
148+
int64_t expected = swizzled_len(m, k_scales);
149+
TVM_FFI_ICHECK_EQ(mat1Scale.size(0), expected)
150+
<< "mxfp8_bmm_impl: mat1Scale size mismatch, expected " << expected << ", got "
151+
<< mat1Scale.size(0);
152+
} else {
153+
TVM_FFI_ICHECK_EQ(mat1Scale.ndim(), 2)
154+
<< "mxfp8_bmm_impl: mat1Scale must be 1D (swizzled) or 2D (non-swizzled), got "
155+
<< mat1Scale.ndim();
156+
TVM_FFI_ICHECK_EQ(mat1Scale.size(0), m)
157+
<< "mxfp8_bmm_impl: mat1Scale size mismatch, expected " << m << ", got "
158+
<< mat1Scale.size(0);
159+
TVM_FFI_ICHECK_EQ(mat1Scale.size(1), k_scales)
160+
<< "mxfp8_bmm_impl: mat1Scale size mismatch, expected " << k_scales << ", got "
161+
<< mat1Scale.size(1);
162+
}
163+
164+
if (mat2Scale.ndim() == 1) {
165+
int64_t expected = swizzled_len(n, k_scales);
166+
TVM_FFI_ICHECK_EQ(mat2Scale.size(0), expected)
167+
<< "mxfp8_bmm_impl: mat2Scale size mismatch, expected " << expected << ", got "
168+
<< mat2Scale.size(0);
169+
} else {
170+
TVM_FFI_ICHECK_EQ(mat2Scale.ndim(), 2)
171+
<< "mxfp8_bmm_impl: mat2Scale must be 1D (swizzled) or 2D (non-swizzled), got "
172+
<< mat2Scale.ndim();
173+
TVM_FFI_ICHECK_EQ(mat2Scale.size(0), n)
174+
<< "mxfp8_bmm_impl: mat2Scale size mismatch, expected " << n << ", got "
175+
<< mat2Scale.size(0);
176+
TVM_FFI_ICHECK_EQ(mat2Scale.size(1), k_scales)
177+
<< "mxfp8_bmm_impl: mat2Scale size mismatch, expected " << k_scales << ", got "
178+
<< mat2Scale.size(1);
179+
}
180+
} else {
181+
const int64_t k_scales = scale_len(k);
182+
if (mat1Scale.ndim() == 1) {
183+
int64_t expected = swizzled_len(b * m, k_scales);
184+
TVM_FFI_ICHECK_EQ(mat1Scale.size(0), expected)
185+
<< "mxfp8_bmm_impl: mat1Scale size mismatch, expected " << expected << ", got "
186+
<< mat1Scale.size(0);
187+
} else {
188+
TVM_FFI_ICHECK_EQ(mat1Scale.ndim(), 2)
189+
<< "mxfp8_bmm_impl: mat1Scale must be 1D (swizzled) or 2D (non-swizzled), got "
190+
<< mat1Scale.ndim();
191+
TVM_FFI_ICHECK_EQ(mat1Scale.size(0), b)
192+
<< "mxfp8_bmm_impl: mat1Scale batch size mismatch, expected " << b << ", got "
193+
<< mat1Scale.size(0);
194+
TVM_FFI_ICHECK_EQ(mat1Scale.size(1), scale_len(m))
195+
<< "mxfp8_bmm_impl: mat1Scale size mismatch, expected " << scale_len(m) << ", got "
196+
<< mat1Scale.size(1);
197+
}
198+
199+
if (mat2Scale.ndim() == 1) {
200+
int64_t expected = swizzled_len(b * n, k_scales);
201+
TVM_FFI_ICHECK_EQ(mat2Scale.size(0), expected)
202+
<< "mxfp8_bmm_impl: mat2Scale size mismatch, expected " << expected << ", got "
203+
<< mat2Scale.size(0);
204+
} else {
205+
TVM_FFI_ICHECK_EQ(mat2Scale.ndim(), 2)
206+
<< "mxfp8_bmm_impl: mat2Scale must be 1D (swizzled) or 2D (non-swizzled), got "
207+
<< mat2Scale.ndim();
208+
TVM_FFI_ICHECK_EQ(mat2Scale.size(0), b)
209+
<< "mxfp8_bmm_impl: mat2Scale batch size mismatch, expected " << b << ", got "
210+
<< mat2Scale.size(0);
211+
TVM_FFI_ICHECK_EQ(mat2Scale.size(1), scale_len(n))
212+
<< "mxfp8_bmm_impl: mat2Scale size mismatch, expected " << scale_len(n) << ", got "
213+
<< mat2Scale.size(1);
214+
}
215+
}
216+
133217
// No heuristic for now, we rely on the autotuner to select the best tactic.
134218
if (tactic == -1) {
135219
tactic = 0;

0 commit comments

Comments
 (0)