
Commit 4af1315

refactor CheckInputs
1 parent dec8867 commit 4af1315

3 files changed: 166 additions & 50 deletions

onnxruntime/contrib_ops/cpu/moe/moe_helper.h

Lines changed: 106 additions & 36 deletions
@@ -35,31 +35,68 @@ struct MoEParameters {
 };
 namespace moe_helper {

+// Helper to check shape dimensions
+#define ASSERT_SHAPE_DIMENSION(shape_ptr, dim, name)                              \
+  if (shape_ptr != nullptr) {                                                     \
+    if (shape_ptr->NumDimensions() != dim) {                                      \
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input '", name,      \
+                             "' is expected to have ", dim, " dimensions, got ",  \
+                             shape_ptr->NumDimensions());                         \
+    }                                                                             \
+  }
+
+#define ASSERT_SHAPE_3D(shape_ptr, name) ASSERT_SHAPE_DIMENSION(shape_ptr, 3, name)
+
+#define CHECK_SHAPE(shape_ptr, name, ...)                                         \
+  if (shape_ptr != nullptr) {                                                     \
+    const TensorShape& expected_shape = make_shape(__VA_ARGS__);                  \
+    if (*shape_ptr != expected_shape) {                                           \
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input '", name,      \
+                             "' is expected to have shape ", expected_shape,      \
+                             ", got ", *shape_ptr);                               \
+    }                                                                             \
+  }
+
 template <typename Tensor>
 Status CheckInputs(MoEParameters& parameters,
-                   const Tensor* input,                // required
-                   const Tensor* router_probs,         // required
-                   const Tensor* fc1_experts_weights,  // required
-                   const Tensor* fc1_experts_bias,     // optional
-                   const Tensor* fc1_experts_scales,   // required for qMoE; NULL for MOE
-                   const Tensor* fc1_zero_points,      // optional, for qMoE
-                   const Tensor* fc2_experts_weights,  // required
-                   const Tensor* fc2_experts_bias,     // optional
-                   const Tensor* fc2_experts_scales,   // required for qMoE; NULL for MOE
-                   const Tensor* fc2_zero_points,      // optional, for qMoE
-                   const Tensor* fc3_experts_weights,  // optional
-                   const Tensor* fc3_experts_bias,     // optional
-                   const Tensor* fc3_experts_scales,   // required for qMoE; NULL for MOE
-                   const Tensor* fc3_zero_points,      // optional, for qMoE
-                   const int64_t pack_size,  // number of weights packed together (like 2 for uint4 packed to uint8)
+                   const Tensor* input,         // required
+                   const Tensor* router_probs,  // required
+                   const TensorShape* fc1_experts_weights_shape,
+                   const Tensor* fc1_experts_bias,    // optional
+                   const Tensor* fc1_experts_scales,  // required for qMoE; NULL for MOE
+                   const Tensor* fc1_zero_points,     // optional, for qMoE
+                   const TensorShape* fc2_experts_weights_shape,
+                   const Tensor* fc2_experts_bias,    // optional
+                   const Tensor* fc2_experts_scales,  // required for qMoE; NULL for MOE
+                   const Tensor* fc2_zero_points,     // optional, for qMoE
+                   const TensorShape* fc3_experts_weights_shape,
+                   const Tensor* fc3_experts_bias,    // optional
+                   const Tensor* fc3_experts_scales,  // required for qMoE; NULL for MOE
+                   const Tensor* fc3_zero_points,     // optional, for qMoE
+                   const int64_t pack_size,  // number of weights packed together (like 2 for uint4 packed to uint8)
                    const bool is_fused_swiglu,
                    const int64_t block_size = 0) {  // block size for block-wise quantization
   // Check dimensions of input to avoid input_dims index out of range. CHECK_TENSOR_SHAPE will verify each tensor later.
+  if (input == nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is required.");
+  }
   ASSERT_TENSOR_2D_OR_3D(input);
-  if (fc1_experts_weights) ASSERT_TENSOR_3D(fc1_experts_weights);
-  if (fc2_experts_weights) ASSERT_TENSOR_3D(fc2_experts_weights);
+
+  if (router_probs == nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'router_probs' is required.");
+  }
   ASSERT_TENSOR_2D(router_probs);

+  if (fc1_experts_weights_shape == nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'fc1_experts_weights' is required.");
+  }
+  ASSERT_SHAPE_3D(fc1_experts_weights_shape, "fc1_experts_weights");
+
+  if (fc2_experts_weights_shape == nullptr) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'fc2_experts_weights' is required.");
+  }
+  ASSERT_SHAPE_3D(fc2_experts_weights_shape, "fc2_experts_weights");
+
   const auto& input_dims = input->Shape().GetDims();
   const auto& router_probs_dims = router_probs->Shape().GetDims();

@@ -68,34 +105,35 @@ Status CheckInputs(MoEParameters& parameters,
   int64_t num_experts = router_probs_dims[1];

   int64_t local_num_experts;
-  if (fc1_experts_weights != nullptr) {
-    local_num_experts = fc1_experts_weights->Shape().GetDims()[0];
+  if (fc1_experts_weights_shape != nullptr) {
+    local_num_experts = fc1_experts_weights_shape->GetDims()[0];
   } else if (fc1_experts_scales != nullptr) {
     local_num_experts = fc1_experts_scales->Shape().GetDims()[0];
   } else {
-    // Fallback for non-quantized MoE without weights (should not happen in current code paths)
-    // or if only bias is provided?
-    local_num_experts = num_experts;
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Invalid MoE configuration: both fc1_experts_weights and fc1_experts_scales are null. "
+                           "At least one must be provided.");
   }

   int64_t inter_size;
-  if (fc2_experts_weights != nullptr) {
-    const auto& dims = fc2_experts_weights->Shape().GetDims();
+  if (fc2_experts_weights_shape != nullptr) {
+    const auto& dims = fc2_experts_weights_shape->GetDims();
     inter_size = (dims[1] * dims[2] * pack_size) / hidden_size;
   } else if (fc3_experts_scales != nullptr) {
     inter_size = fc3_experts_scales->Shape().GetDims()[1];
   } else if (fc1_experts_scales != nullptr) {
     int64_t fc1_inter_size = fc1_experts_scales->Shape().GetDims()[1];
     inter_size = is_fused_swiglu ? fc1_inter_size / 2 : fc1_inter_size;
   } else {
-    // Should not happen for valid QMoE calls
-    inter_size = 0;
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "Invalid MoE configuration: unable to infer inter_size because "
+                           "fc2_experts_weights, fc3_experts_scales, and fc1_experts_scales are all null.");
   }

   bool legacy_shape = false;
-  if (fc2_experts_weights != nullptr && fc1_experts_weights != nullptr) {
-    const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims();
-    const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims();
+  if (fc2_experts_weights_shape != nullptr && fc1_experts_weights_shape != nullptr) {
+    const auto& fc2_experts_weights_dims = fc2_experts_weights_shape->GetDims();
+    const auto& fc1_experts_weights_dims = fc1_experts_weights_shape->GetDims();
     legacy_shape = (hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size) ||
                    (hidden_size == inter_size && is_fused_swiglu && fc1_experts_weights_dims[1] == hidden_size);
   }
@@ -106,13 +144,13 @@ Status CheckInputs(MoEParameters& parameters,

   if (legacy_shape) {
     // legacy shape does not match column major memory layout. This is for backward compatibility.
-    if (fc1_experts_weights) CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, hidden_size, fc1_inter_size / pack_size);
-    if (fc2_experts_weights) CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, inter_size, hidden_size / pack_size);
-    if (fc3_experts_weights) CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, hidden_size, inter_size / pack_size);
+    CHECK_SHAPE(fc1_experts_weights_shape, "fc1_experts_weights", num_experts, hidden_size, fc1_inter_size / pack_size);
+    CHECK_SHAPE(fc2_experts_weights_shape, "fc2_experts_weights", num_experts, inter_size, hidden_size / pack_size);
+    CHECK_SHAPE(fc3_experts_weights_shape, "fc3_experts_weights", num_experts, hidden_size, inter_size / pack_size);
   } else {
-    if (fc1_experts_weights) CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, fc1_inter_size, hidden_size / pack_size);
-    if (fc2_experts_weights) CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, hidden_size, inter_size / pack_size);
-    if (fc3_experts_weights) CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, inter_size, hidden_size / pack_size);
+    CHECK_SHAPE(fc1_experts_weights_shape, "fc1_experts_weights", num_experts, fc1_inter_size, hidden_size / pack_size);
+    CHECK_SHAPE(fc2_experts_weights_shape, "fc2_experts_weights", num_experts, hidden_size, inter_size / pack_size);
+    CHECK_SHAPE(fc3_experts_weights_shape, "fc3_experts_weights", num_experts, inter_size, hidden_size / pack_size);
   }

   CHECK_TENSOR_SHAPE(router_probs, num_rows, num_experts);
@@ -194,9 +232,11 @@ Status CheckInputs(MoEParameters& parameters,
     }
   }

-  if (fc3_experts_weights == nullptr) {
+  if (fc3_experts_weights_shape == nullptr) {
+    // If fc3 weights are not provided, ensure no other fc3 parameters are provided
     ORT_ENFORCE(fc3_experts_bias == nullptr && fc3_experts_scales == nullptr && fc3_zero_points == nullptr);
   } else {
+    // If fc3 weights are provided, ensure scales logic is consistent
     ORT_ENFORCE(fc1_experts_scales == nullptr || fc3_experts_scales != nullptr);  // MOE no scale, or qMOE need scales
   }

@@ -226,6 +266,36 @@ Status CheckInputs(MoEParameters& parameters,
   return Status::OK();
 }

+template <typename Tensor>
+Status CheckInputs(MoEParameters& parameters,
+                   const Tensor* input,                // required
+                   const Tensor* router_probs,         // required
+                   const Tensor* fc1_experts_weights,  // required
+                   const Tensor* fc1_experts_bias,     // optional
+                   const Tensor* fc1_experts_scales,   // required for qMoE; NULL for MOE
+                   const Tensor* fc1_zero_points,      // optional, for qMoE
+                   const Tensor* fc2_experts_weights,  // required
+                   const Tensor* fc2_experts_bias,     // optional
+                   const Tensor* fc2_experts_scales,   // required for qMoE; NULL for MOE
+                   const Tensor* fc2_zero_points,      // optional, for qMoE
+                   const Tensor* fc3_experts_weights,  // optional
+                   const Tensor* fc3_experts_bias,     // optional
+                   const Tensor* fc3_experts_scales,   // required for qMoE; NULL for MOE
+                   const Tensor* fc3_zero_points,      // optional, for qMoE
+                   const int64_t pack_size,  // number of weights packed together (like 2 for uint4 packed to uint8)
+                   const bool is_fused_swiglu,
+                   const int64_t block_size = 0) {  // block size for block-wise quantization
+
+  const TensorShape* fc1_shape = (fc1_experts_weights != nullptr) ? &fc1_experts_weights->Shape() : nullptr;
+  const TensorShape* fc2_shape = (fc2_experts_weights != nullptr) ? &fc2_experts_weights->Shape() : nullptr;
+  const TensorShape* fc3_shape = (fc3_experts_weights != nullptr) ? &fc3_experts_weights->Shape() : nullptr;
+
+  return CheckInputs(parameters, input, router_probs, fc1_shape, fc1_experts_bias, fc1_experts_scales, fc1_zero_points,
+                     fc2_shape, fc2_experts_bias, fc2_experts_scales, fc2_zero_points,
+                     fc3_shape, fc3_experts_bias, fc3_experts_scales, fc3_zero_points,
+                     pack_size, is_fused_swiglu, block_size);
+}
+
 }  // namespace moe_helper
 }  // namespace contrib
 }  // namespace onnxruntime
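
Note: CHECK_SHAPE builds its expected shape through a make_shape helper that is not part of this hunk (CHECK_TENSOR_SHAPE presumably relies on the same helper elsewhere in moe_helper.h). A minimal sketch of what such a variadic helper could look like, assuming TensorShape is constructible from an initializer list of int64_t values; this is illustrative only, not the repository's definition:

// Illustrative sketch only; the real make_shape is defined elsewhere in moe_helper.h.
// Builds a TensorShape from a variadic list of dimension values.
template <typename... Dims>
inline TensorShape make_shape(Dims... dims) {
  return TensorShape({static_cast<int64_t>(dims)...});
}

With such a helper, CHECK_SHAPE(fc2_experts_weights_shape, "fc2_experts_weights", num_experts, hidden_size, inter_size / pack_size) expands its trailing arguments into a TensorShape and returns INVALID_ARGUMENT when the runtime shape differs, mirroring the Tensor-based CHECK_TENSOR_SHAPE.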

onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc

Lines changed: 56 additions & 11 deletions
@@ -471,11 +471,9 @@ Status QMoECPU<T>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr all

   // If scales are prepacked, they are constant initializers.
   if (input_idx == 3) {
-    has_prepacked_fc1_scales_ = true;
     return Status::OK();
   }
   if (input_idx == 6) {
-    has_prepacked_fc2_scales_ = true;
     return Status::OK();
   }

@@ -511,11 +509,33 @@ Status QMoECPU<T>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr all
     }
   }

+  if (input_idx == 2) {
+    fc1_shape_ = shape;
+  } else if (input_idx == 5) {
+    fc2_shape_ = shape;
+  } else if (input_idx == 8) {
+    fc3_shape_ = shape;
+  }
+
   if (prepacked_weights) {
     prepacked_weights->buffers_.push_back(std::move(packed_buffer));
     prepacked_weights->buffer_sizes_.push_back(packed_size);
     is_packed = true;

+    // Pack Shape (Buffer 1)
+    auto dims = shape.GetDims();
+    size_t rank_bytes = sizeof(int64_t);
+    size_t dims_bytes = dims.size() * sizeof(int64_t);
+    size_t shape_size = rank_bytes + dims_bytes;
+
+    auto shape_buffer = IAllocator::MakeUniquePtr<void>(alloc, shape_size);
+    int64_t* buffer_data = static_cast<int64_t*>(shape_buffer.get());
+    *buffer_data = static_cast<int64_t>(dims.size());
+    memcpy(buffer_data + 1, dims.data(), dims_bytes);
+
+    prepacked_weights->buffers_.push_back(std::move(shape_buffer));
+    prepacked_weights->buffer_sizes_.push_back(shape_size);
+
     // Try build MLAS Q4 cache if scales are available
     if (use_mlas_q4_gemm_) {
       const Tensor* scales_tensor = nullptr;
@@ -550,7 +570,7 @@ Status QMoECPU<T>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr all
                               alloc, cache_buffer)
                 .IsOK()) {
           // Store the size so we can verify later? Container holds size.
-          // We push it as a SECOND buffer.
+          // We push it as a THIRD buffer (Buffer 2) now.
           size_t cache_size = MlasQ4GemmPackBSize(qtype, static_cast<size_t>(rows), static_cast<size_t>(cols)) * static_cast<size_t>(num_experts);
           prepacked_weights->buffers_.push_back(std::move(cache_buffer));
           prepacked_weights->buffer_sizes_.push_back(cache_size);
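
The prepacked-shape handoff above serializes each weight's TensorShape into a flat int64_t buffer laid out as [rank, dim0, dim1, ...]: PrePack appends it as prepacked buffer 1 (after the packed weights at index 0 and before the optional MLAS Q4 cache at index 2), and the UseSharedPrePackedBuffers hunk below reads it back to restore fc1_shape_/fc2_shape_/fc3_shape_. A self-contained sketch of that round trip, using plain std::vector in place of the allocator-backed buffers, for illustration only:

#include <cstdint>
#include <cstring>
#include <vector>

// Serialize a shape as [rank, dim0, dim1, ...] in a flat int64_t buffer.
std::vector<int64_t> SerializeShape(const std::vector<int64_t>& dims) {
  std::vector<int64_t> buffer(1 + dims.size());
  buffer[0] = static_cast<int64_t>(dims.size());  // rank goes first
  std::memcpy(buffer.data() + 1, dims.data(), dims.size() * sizeof(int64_t));
  return buffer;
}

// Deserialize: read the rank, then copy that many dimensions back out.
std::vector<int64_t> DeserializeShape(const int64_t* buffer_data) {
  const int64_t rank = buffer_data[0];
  std::vector<int64_t> dims(static_cast<size_t>(rank));
  std::memcpy(dims.data(), buffer_data + 1, static_cast<size_t>(rank) * sizeof(int64_t));
  return dims;
}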
@@ -576,17 +596,38 @@ Status QMoECPU<T>::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepa
   if (input_idx == 2 && !prepacked_buffers.empty()) {
     packed_fc1_ = std::move(prepacked_buffers[0]);
     if (prepacked_buffers.size() > 1) {
-      packed_fc1_mlas_cache_ = std::move(prepacked_buffers[1]);
+      int64_t* buffer_data = static_cast<int64_t*>(prepacked_buffers[1].get());
+      int64_t rank = buffer_data[0];
+      std::vector<int64_t> dims(rank);
+      memcpy(dims.data(), buffer_data + 1, rank * sizeof(int64_t));
+      fc1_shape_ = TensorShape(dims);
+    }
+    if (prepacked_buffers.size() > 2) {
+      packed_fc1_mlas_cache_ = std::move(prepacked_buffers[2]);
     }
     used_shared_buffers = true;
   } else if (input_idx == 5 && !prepacked_buffers.empty()) {
     packed_fc2_ = std::move(prepacked_buffers[0]);
     if (prepacked_buffers.size() > 1) {
-      packed_fc2_mlas_cache_ = std::move(prepacked_buffers[1]);
+      int64_t* buffer_data = static_cast<int64_t*>(prepacked_buffers[1].get());
+      int64_t rank = buffer_data[0];
+      std::vector<int64_t> dims(rank);
+      memcpy(dims.data(), buffer_data + 1, rank * sizeof(int64_t));
+      fc2_shape_ = TensorShape(dims);
+    }
+    if (prepacked_buffers.size() > 2) {
+      packed_fc2_mlas_cache_ = std::move(prepacked_buffers[2]);
     }
     used_shared_buffers = true;
   } else if (input_idx == 8 && !prepacked_buffers.empty()) {
     packed_fc3_ = std::move(prepacked_buffers[0]);
+    if (prepacked_buffers.size() > 1) {
+      int64_t* buffer_data = static_cast<int64_t*>(prepacked_buffers[1].get());
+      int64_t rank = buffer_data[0];
+      std::vector<int64_t> dims(rank);
+      memcpy(dims.data(), buffer_data + 1, rank * sizeof(int64_t));
+      fc3_shape_ = TensorShape(dims);
+    }
     used_shared_buffers = true;
   }

@@ -635,17 +676,21 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
   const auto* fc2_zero_points = context->Input<Tensor>(12);
   const auto* fc3_zero_points = context->Input<Tensor>(13);

+  const TensorShape* fc1_shape_ptr = packed_fc1_ ? &fc1_shape_ : (fc1_experts_weights ? &fc1_experts_weights->Shape() : nullptr);
+  const TensorShape* fc2_shape_ptr = packed_fc2_ ? &fc2_shape_ : (fc2_experts_weights ? &fc2_experts_weights->Shape() : nullptr);
+  const TensorShape* fc3_shape_ptr = packed_fc3_ ? &fc3_shape_ : (fc3_experts_weights ? &fc3_experts_weights->Shape() : nullptr);
+
   MoEParameters moe_params;
   ORT_RETURN_IF_ERROR(moe_helper::CheckInputs<Tensor>(
       moe_params, input, router_probs,
-      fc1_experts_weights, fc1_experts_bias, fc1_scales, fc1_zero_points,
-      fc2_experts_weights, fc2_experts_bias, fc2_scales, fc2_zero_points,
-      fc3_experts_weights, fc3_experts_bias, fc3_scales, fc3_zero_points,
+      fc1_shape_ptr, fc1_experts_bias, fc1_scales, fc1_zero_points,
+      fc2_shape_ptr, fc2_experts_bias, fc2_scales, fc2_zero_points,
+      fc3_shape_ptr, fc3_experts_bias, fc3_scales, fc3_zero_points,
      expert_weight_bits_ == 4 ? 2 : 1,
      true,
      block_size_));

-  if (fc3_experts_weights || fc3_experts_bias || fc3_scales || fc3_zero_points) {
+  if (fc3_shape_ptr || fc3_experts_bias || fc3_scales || fc3_zero_points) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "FC3 gating is not yet implemented on CPU for QMoE");
   }

@@ -808,8 +853,8 @@ Status QMoECPU<T>::Compute(OpKernelContext* context) const {
   const bool is_fc1_block_wise = (fc1_scales_dims.size() == 3 && fc1_scales_dims[2] > 1);
   const bool is_fc2_block_wise = (fc2_scales_dims.size() == 3 && fc2_scales_dims[2] > 1);

-  const uint8_t* fc1_weights_data = (packed_fc1_ != nullptr) ? nullptr : fc1_experts_weights->Data<uint8_t>();
-  const uint8_t* fc2_weights_data = (packed_fc2_ != nullptr) ? nullptr : fc2_experts_weights->Data<uint8_t>();
+  const uint8_t* fc1_weights_data = (packed_fc1_ != nullptr) ? nullptr : fc1_experts_weights->template Data<uint8_t>();
+  const uint8_t* fc2_weights_data = (packed_fc2_ != nullptr) ? nullptr : fc2_experts_weights->template Data<uint8_t>();
   const T* fc1_scales_data = fc1_scales->Data<T>();
   const T* fc2_scales_data = fc2_scales->Data<T>();
   const T* fc1_bias_data = fc1_experts_bias ? fc1_experts_bias->Data<T>() : nullptr;
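
The Compute changes pick a shape source per FC layer: the cached fc*_shape_ when the weights were prepacked (the original tensor may no longer be available), otherwise the live tensor's shape, otherwise nullptr so the TensorShape-based CheckInputs reports the input as missing. A hypothetical helper expressing that selection, not part of this commit, with Tensor/TensorShape as in ONNX Runtime:

// Hypothetical convenience helper mirroring the selection logic in Compute:
// prefer the shape cached at PrePack time, fall back to the live tensor, else null.
inline const TensorShape* SelectWeightShape(bool is_prepacked,
                                            const TensorShape& cached_shape,
                                            const Tensor* live_tensor) {
  if (is_prepacked) {
    return &cached_shape;
  }
  return (live_tensor != nullptr) ? &live_tensor->Shape() : nullptr;
}

Usage would then read, for example, SelectWeightShape(packed_fc1_ != nullptr, fc1_shape_, fc1_experts_weights) in place of the inline ternaries above.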

onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h

Lines changed: 4 additions & 3 deletions
@@ -7,7 +7,6 @@
 #include "core/framework/op_kernel.h"
 #include "core/mlas/inc/mlas_q4.h"
 #include "contrib_ops/cpu/moe/moe_base_cpu.h"
-#include <mutex>
 #include <vector>

 namespace onnxruntime {
@@ -42,13 +41,15 @@ class QMoECPU final : public OpKernel, public MoEBaseCPU {
   int64_t block_size_;
   bool use_mlas_q4_gemm_{false};
   bool use_mlas_q4_gemm_overridden_{false};
-  bool has_prepacked_fc1_scales_{false};
-  bool has_prepacked_fc2_scales_{false};

   IAllocatorUniquePtr<void> packed_fc1_;
   IAllocatorUniquePtr<void> packed_fc2_;
   IAllocatorUniquePtr<void> packed_fc3_;

+  TensorShape fc1_shape_;
+  TensorShape fc2_shape_;
+  TensorShape fc3_shape_;
+
   IAllocatorUniquePtr<void> packed_fc1_mlas_cache_;
   IAllocatorUniquePtr<void> packed_fc2_mlas_cache_;
 };
