Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ void DropQDQNodesRules(SelectorActionRegistry& qdq_selector_action_registry) {
std::unique_ptr<NodeSelector> selector_no_16bit_and_positive_scale =
std::make_unique<QDQ::DropQDQNodesSelector>(false, true, false, providers);
qdq_selector_action_registry.RegisterSelectorAndAction(drop_action_no_int16_and_positive_scale_name,
{{"MaxPool", {12}},
{{"MaxPool", {12, 22}},
{"ReduceMax", {}},
{"ReduceMin", {}}},
std::move(selector_no_16bit_and_positive_scale),
Expand Down
51 changes: 51 additions & 0 deletions onnxruntime/test/optimizer/qdq_transformer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,57 @@ TEST(QDQTransformerTests, ReshapeDropQDQ) {
RunReshapeDropQDQTestCase<uint16_t>({1, 3, 2, 2}, {1, 12}, false, 21); // Use int16 ONNX QDQ ops
}

// Runs a test case that checks if Q/DQ nodes are dropped from DQ -> MaxPool -> Q.
template <typename QuantType>
static void RunMaxPoolDropQDQTestCase(bool use_contrib_qdq = false,
int opset = 12) {
auto build_test_case = [use_contrib_qdq](ModelTestBuilder& builder) {
constexpr QuantType qmin = std::numeric_limits<QuantType>::min();
constexpr QuantType qmax = std::numeric_limits<QuantType>::max();

const std::vector<int64_t> input_shape = {1, 17, 17, 3};
auto* input_arg = builder.MakeInput<QuantType>(input_shape, qmin, qmax);
auto* output_arg = builder.MakeOutput();
QuantType zero_point = 1 + (qmax + qmin) / 2;

// add DequantizeLinear
auto* input_arg_dq = builder.MakeIntermediate();
builder.AddDequantizeLinearNode<QuantType>(input_arg, .003f, zero_point, input_arg_dq, use_contrib_qdq);

// add MaxPool
auto* maxpool_output = builder.MakeIntermediate();
Node& maxpool_node = builder.AddNode("MaxPool", {input_arg_dq}, {maxpool_output});
maxpool_node.AddAttribute("auto_pad", "VALID");
maxpool_node.AddAttribute("kernel_shape", std::vector<int64_t>({2, 2}));

// add QuantizeLinear
builder.AddQuantizeLinearNode<QuantType>(maxpool_output, .003f, zero_point, output_arg, use_contrib_qdq);
};

auto check_graph = [use_contrib_qdq](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
EXPECT_EQ(op_to_count["MaxPool"], 1);
EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0);
EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0);
};

TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, opset);
}

// Checks that Q/DQ nodes are dropped from DQ -> MaxPool -> Q. Uses 8-bit Q/DQ ops.
TEST(QDQTransformerTests, MaxPoolDropQDQ) {
// Opset 12
RunMaxPoolDropQDQTestCase<int8_t>();
RunMaxPoolDropQDQTestCase<int8_t>(true); // Use com.microsoft QDQ ops
RunMaxPoolDropQDQTestCase<uint8_t>();
RunMaxPoolDropQDQTestCase<uint8_t>(true); // Use com.microsoft QDQ ops

// Opset 22
RunMaxPoolDropQDQTestCase<int8_t>(false, 22);
RunMaxPoolDropQDQTestCase<uint8_t>(false, 22);
}

// Runs a test case that checks if Q/DQ nodes are *not* dropped from DQ -> MaxPool -> Q if the quantization scale is
// negative.
template <typename QuantType>
Expand Down