@@ -35,31 +35,68 @@ struct MoEParameters {
 };
 namespace moe_helper {
 
+// Helper to check shape dimensions
+#define ASSERT_SHAPE_DIMENSION(shape_ptr, dim, name) \
+  if (shape_ptr != nullptr) { \
+    if (shape_ptr->NumDimensions() != dim) { \
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input '", name, \
+                             "' is expected to have ", dim, " dimensions, got ", \
+                             shape_ptr->NumDimensions()); \
+    } \
+  }
+
+#define ASSERT_SHAPE_3D(shape_ptr, name) ASSERT_SHAPE_DIMENSION(shape_ptr, 3, name)
+
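+// Helper to check that a shape matches the expected shape built from the given dimensions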
+#define CHECK_SHAPE(shape_ptr, name, ...) \
+  if (shape_ptr != nullptr) { \
+    const TensorShape& expected_shape = make_shape(__VA_ARGS__); \
+    if (*shape_ptr != expected_shape) { \
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input '", name, \
+                             "' is expected to have shape ", expected_shape, \
+                             ", got ", *shape_ptr); \
+    } \
+  }
+
 template <typename Tensor>
 Status CheckInputs(MoEParameters& parameters,
-                   const Tensor* input,                 // required
-                   const Tensor* router_probs,          // required
-                   const Tensor* fc1_experts_weights,   // required
-                   const Tensor* fc1_experts_bias,      // optional
-                   const Tensor* fc1_experts_scales,    // required for qMoE; NULL for MOE
-                   const Tensor* fc1_zero_points,       // optional, for qMoE
-                   const Tensor* fc2_experts_weights,   // required
-                   const Tensor* fc2_experts_bias,      // optional
-                   const Tensor* fc2_experts_scales,    // required for qMoE; NULL for MOE
-                   const Tensor* fc2_zero_points,       // optional, for qMoE
-                   const Tensor* fc3_experts_weights,   // optional
-                   const Tensor* fc3_experts_bias,      // optional
-                   const Tensor* fc3_experts_scales,    // required for qMoE; NULL for MOE
-                   const Tensor* fc3_zero_points,       // optional, for qMoE
-                   const int64_t pack_size,             // number of weights packed together (like 2 for uint4 packed to uint8)
+                   const Tensor* input,                           // required
+                   const Tensor* router_probs,                    // required
+                   const TensorShape* fc1_experts_weights_shape,  // required
+                   const Tensor* fc1_experts_bias,                // optional
+                   const Tensor* fc1_experts_scales,              // required for qMoE; NULL for MOE
+                   const Tensor* fc1_zero_points,                 // optional, for qMoE
+                   const TensorShape* fc2_experts_weights_shape,  // required
+                   const Tensor* fc2_experts_bias,                // optional
+                   const Tensor* fc2_experts_scales,              // required for qMoE; NULL for MOE
+                   const Tensor* fc2_zero_points,                 // optional, for qMoE
+                   const TensorShape* fc3_experts_weights_shape,  // optional
+                   const Tensor* fc3_experts_bias,                // optional
+                   const Tensor* fc3_experts_scales,              // required for qMoE; NULL for MOE
+                   const Tensor* fc3_zero_points,                 // optional, for qMoE
+                   const int64_t pack_size,                       // number of weights packed together (like 2 for uint4 packed to uint8)
                    const bool is_fused_swiglu,
                    const int64_t block_size = 0) {  // block size for block-wise quantization
   // Check dimensions of input to avoid input_dims index out of range. CHECK_TENSOR_SHAPE will verify each tensor later.
+  if (input == nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is required.");
+  }
   ASSERT_TENSOR_2D_OR_3D(input);
-  if (fc1_experts_weights) ASSERT_TENSOR_3D(fc1_experts_weights);
-  if (fc2_experts_weights) ASSERT_TENSOR_3D(fc2_experts_weights);
+
+  if (router_probs == nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'router_probs' is required.");
+  }
   ASSERT_TENSOR_2D(router_probs);
 
+  if (fc1_experts_weights_shape == nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'fc1_experts_weights' is required.");
+  }
+  ASSERT_SHAPE_3D(fc1_experts_weights_shape, "fc1_experts_weights");
+
+  if (fc2_experts_weights_shape == nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'fc2_experts_weights' is required.");
+  }
+  ASSERT_SHAPE_3D(fc2_experts_weights_shape, "fc2_experts_weights");
+
   const auto& input_dims = input->Shape().GetDims();
   const auto& router_probs_dims = router_probs->Shape().GetDims();
 
@@ -68,34 +105,35 @@ Status CheckInputs(MoEParameters& parameters,
   int64_t num_experts = router_probs_dims[1];
 
   int64_t local_num_experts;
-  if (fc1_experts_weights != nullptr) {
-    local_num_experts = fc1_experts_weights->Shape().GetDims()[0];
+  if (fc1_experts_weights_shape != nullptr) {
+    local_num_experts = fc1_experts_weights_shape->GetDims()[0];
   } else if (fc1_experts_scales != nullptr) {
     local_num_experts = fc1_experts_scales->Shape().GetDims()[0];
   } else {
-    // Fallback for non-quantized MoE without weights (should not happen in current code paths)
-    // or if only bias is provided?
-    local_num_experts = num_experts;
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Invalid MoE configuration: both fc1_experts_weights and fc1_experts_scales are null. "
+                           "At least one must be provided.");
   }
 
   int64_t inter_size;
-  if (fc2_experts_weights != nullptr) {
-    const auto& dims = fc2_experts_weights->Shape().GetDims();
+  if (fc2_experts_weights_shape != nullptr) {
+    const auto& dims = fc2_experts_weights_shape->GetDims();
     inter_size = (dims[1] * dims[2] * pack_size) / hidden_size;
   } else if (fc3_experts_scales != nullptr) {
     inter_size = fc3_experts_scales->Shape().GetDims()[1];
   } else if (fc1_experts_scales != nullptr) {
     int64_t fc1_inter_size = fc1_experts_scales->Shape().GetDims()[1];
     inter_size = is_fused_swiglu ? fc1_inter_size / 2 : fc1_inter_size;
   } else {
-    // Should not happen for valid QMoE calls
-    inter_size = 0;
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Invalid MoE configuration: unable to infer inter_size because "
+                           "fc2_experts_weights, fc3_experts_scales, and fc1_experts_scales are all null.");
   }
 
   bool legacy_shape = false;
-  if (fc2_experts_weights != nullptr && fc1_experts_weights != nullptr) {
-    const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims();
-    const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims();
+  if (fc2_experts_weights_shape != nullptr && fc1_experts_weights_shape != nullptr) {
+    const auto& fc2_experts_weights_dims = fc2_experts_weights_shape->GetDims();
+    const auto& fc1_experts_weights_dims = fc1_experts_weights_shape->GetDims();
     legacy_shape = (hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size) ||
                    (hidden_size == inter_size && is_fused_swiglu && fc1_experts_weights_dims[1] == hidden_size);
   }
@@ -106,13 +144,13 @@ Status CheckInputs(MoEParameters& parameters,
 
   if (legacy_shape) {
     // legacy shape does not match column major memory layout. This is for backward compatibility.
-    if (fc1_experts_weights) CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, hidden_size, fc1_inter_size / pack_size);
-    if (fc2_experts_weights) CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, inter_size, hidden_size / pack_size);
-    if (fc3_experts_weights) CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, hidden_size, inter_size / pack_size);
+    CHECK_SHAPE(fc1_experts_weights_shape, "fc1_experts_weights", num_experts, hidden_size, fc1_inter_size / pack_size);
+    CHECK_SHAPE(fc2_experts_weights_shape, "fc2_experts_weights", num_experts, inter_size, hidden_size / pack_size);
+    CHECK_SHAPE(fc3_experts_weights_shape, "fc3_experts_weights", num_experts, hidden_size, inter_size / pack_size);
   } else {
-    if (fc1_experts_weights) CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, fc1_inter_size, hidden_size / pack_size);
-    if (fc2_experts_weights) CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, hidden_size, inter_size / pack_size);
-    if (fc3_experts_weights) CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, inter_size, hidden_size / pack_size);
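+    // Standard layout: weights are expected as [num_experts, output_features, input_features / pack_size].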
+    CHECK_SHAPE(fc1_experts_weights_shape, "fc1_experts_weights", num_experts, fc1_inter_size, hidden_size / pack_size);
+    CHECK_SHAPE(fc2_experts_weights_shape, "fc2_experts_weights", num_experts, hidden_size, inter_size / pack_size);
+    CHECK_SHAPE(fc3_experts_weights_shape, "fc3_experts_weights", num_experts, inter_size, hidden_size / pack_size);
   }
 
   CHECK_TENSOR_SHAPE(router_probs, num_rows, num_experts);
@@ -194,9 +232,11 @@ Status CheckInputs(MoEParameters& parameters,
     }
   }
 
-  if (fc3_experts_weights == nullptr) {
+  if (fc3_experts_weights_shape == nullptr) {
+    // If fc3 weights are not provided, ensure no other fc3 parameters are provided.
     ORT_ENFORCE(fc3_experts_bias == nullptr && fc3_experts_scales == nullptr && fc3_zero_points == nullptr);
   } else {
+    // If fc3 weights are provided, ensure the scale inputs are consistent.
     ORT_ENFORCE(fc1_experts_scales == nullptr || fc3_experts_scales != nullptr);  // MoE has no scales; qMoE requires fc3 scales
   }
 
@@ -226,6 +266,36 @@ Status CheckInputs(MoEParameters& parameters,
   return Status::OK();
 }
 
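+// Convenience overload for callers that still hold weight Tensors: it extracts the weight
+// shapes (when the tensors are present) and forwards to the shape-based CheckInputs above.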
+template <typename Tensor>
+Status CheckInputs(MoEParameters& parameters,
+                   const Tensor* input,                 // required
+                   const Tensor* router_probs,          // required
+                   const Tensor* fc1_experts_weights,   // required
+                   const Tensor* fc1_experts_bias,      // optional
+                   const Tensor* fc1_experts_scales,    // required for qMoE; NULL for MOE
+                   const Tensor* fc1_zero_points,       // optional, for qMoE
+                   const Tensor* fc2_experts_weights,   // required
+                   const Tensor* fc2_experts_bias,      // optional
+                   const Tensor* fc2_experts_scales,    // required for qMoE; NULL for MOE
+                   const Tensor* fc2_zero_points,       // optional, for qMoE
+                   const Tensor* fc3_experts_weights,   // optional
+                   const Tensor* fc3_experts_bias,      // optional
+                   const Tensor* fc3_experts_scales,    // required for qMoE; NULL for MOE
+                   const Tensor* fc3_zero_points,       // optional, for qMoE
+                   const int64_t pack_size,             // number of weights packed together (like 2 for uint4 packed to uint8)
+                   const bool is_fused_swiglu,
+                   const int64_t block_size = 0) {  // block size for block-wise quantization
+  const TensorShape* fc1_shape = (fc1_experts_weights != nullptr) ? &fc1_experts_weights->Shape() : nullptr;
+  const TensorShape* fc2_shape = (fc2_experts_weights != nullptr) ? &fc2_experts_weights->Shape() : nullptr;
+  const TensorShape* fc3_shape = (fc3_experts_weights != nullptr) ? &fc3_experts_weights->Shape() : nullptr;
+
+  return CheckInputs(parameters, input, router_probs, fc1_shape, fc1_experts_bias, fc1_experts_scales, fc1_zero_points,
+                     fc2_shape, fc2_experts_bias, fc2_experts_scales, fc2_zero_points,
+                     fc3_shape, fc3_experts_bias, fc3_experts_scales, fc3_zero_points,
+                     pack_size, is_fused_swiglu, block_size);
+}
+
 }  // namespace moe_helper
 }  // namespace contrib
 }  // namespace onnxruntime