@@ -1622,6 +1622,128 @@ TEST_F(GraphTransformationTests, FusePadWithMaxPoolOpsetLessThan11) {
1622
1622
}
1623
1623
}
1624
1624
1625
+ TEST_F (GraphTransformationTests, FusePadWithAvgPool) {
1626
+ constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER " fusion/fuse-pad-avgpool.onnx" ;
1627
+
1628
+ std::shared_ptr<Model> p_model;
1629
+ ASSERT_STATUS_OK (Model::Load (model_uri, p_model, nullptr , *logger_));
1630
+ Graph& graph = p_model->MainGraph ();
1631
+
1632
+ std::vector<int64_t > expected_pads;
1633
+ GraphViewer graphViewer (graph);
1634
+ for (auto & node_index : graphViewer.GetNodesInTopologicalOrder ()) {
1635
+ auto & node = *graph.GetNode (node_index);
1636
+ if (node.OpType () == " Pad" ) {
1637
+ auto const & pads_proto = node.GetAttributes ().at (" pads" ).ints ();
1638
+ gsl::span<const int64_t > pads_values = gsl::make_span (pads_proto.data (), pads_proto.size ());
1639
+ expected_pads.resize (pads_values.size () - 4 );
1640
+ for (uint32_t pads_index = 2 , index = 0 ; pads_index < pads_values.size () / 2 ; pads_index++, index ++) {
1641
+ expected_pads[index ] = pads_values[pads_index];
1642
+ expected_pads[index + (expected_pads.size () / 2 )] = pads_values[pads_index + (pads_values.size () / 2 )];
1643
+ }
1644
+ }
1645
+ }
1646
+
1647
+ onnxruntime::GraphTransformerManager graph_transformation_mgr{5 };
1648
+ auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>(" RuleTransformerL1" );
1649
+ ASSERT_STATUS_OK (rule_transformer_L1->Register (std::make_unique<PadFusion>()));
1650
+ ASSERT_STATUS_OK (graph_transformation_mgr.Register (std::move (rule_transformer_L1), TransformerLevel::Level1));
1651
+
1652
+ ASSERT_STATUS_OK (graph_transformation_mgr.ApplyTransformers (graph, TransformerLevel::Level1, *logger_));
1653
+
1654
+ std::map<std::string, int > op_to_count = CountOpsInGraph (graph);
1655
+ ASSERT_EQ (op_to_count[" Pad" ], 0 );
1656
+ ASSERT_EQ (op_to_count[" AveragePool" ], 1 );
1657
+
1658
+ for (auto & node : graph.Nodes ()) {
1659
+ if (node.OpType () == " AveragePool" ) {
1660
+ auto const & child_pads = node.GetAttributes ().at (" pads" ).ints ();
1661
+ auto const & count_include_pad = node.GetAttributes ().at (" count_include_pad" );
1662
+ ASSERT_NE (count_include_pad.i (), 0 ) << " fusion should ensure count_include_pad!=0" ;
1663
+ ASSERT_EQ (child_pads.size (), static_cast <int32_t >(expected_pads.size ()))
1664
+ << " fusion should produce the same size of pads integer as the AvgPool node" ;
1665
+ for (uint32_t index = 0 ; index < expected_pads.size (); index ++) {
1666
+ ASSERT_EQ (expected_pads[index ], child_pads.Get (index ))
1667
+ << " fusion does not produce correct padding value" ;
1668
+ }
1669
+ }
1670
+ }
1671
+ }
1672
+
1673
+ TEST_F (GraphTransformationTests, FusePadWithAvgPoolWithPad) {
1674
+ constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER " fusion/fuse-pad-avgpool_with_pad.onnx" ;
1675
+
1676
+ std::shared_ptr<Model> p_model;
1677
+ ASSERT_STATUS_OK (Model::Load (model_uri, p_model, nullptr , *logger_));
1678
+ Graph& graph = p_model->MainGraph ();
1679
+
1680
+ std::vector<int64_t > expected_pads;
1681
+ GraphViewer graphViewer (graph);
1682
+ for (auto & node_index : graphViewer.GetNodesInTopologicalOrder ()) {
1683
+ auto & node = *graph.GetNode (node_index);
1684
+ if (node.OpType () == " Pad" ) {
1685
+ auto const & pads_proto = node.GetAttributes ().at (" pads" ).ints ();
1686
+ gsl::span<const int64_t > pads_values = gsl::make_span (pads_proto.data (), pads_proto.size ());
1687
+ expected_pads.resize (pads_values.size () - 4 );
1688
+
1689
+ for (uint32_t pads_index = 2 , index = 0 ; pads_index < pads_values.size () / 2 ; pads_index++, index ++) {
1690
+ expected_pads[index ] = pads_values[pads_index];
1691
+ expected_pads[index + (expected_pads.size () / 2 )] = pads_values[pads_index + (pads_values.size () / 2 )];
1692
+ }
1693
+ } else if (node.OpType () == " AveragePool" ) {
1694
+ auto const & child_pads = node.GetAttributes ().at (" pads" ).ints ();
1695
+ for (uint32_t index = 0 ; index < expected_pads.size (); index ++) {
1696
+ expected_pads[index ] += child_pads.Get (index );
1697
+ }
1698
+ }
1699
+ }
1700
+
1701
+ onnxruntime::GraphTransformerManager graph_transformation_mgr{5 };
1702
+ auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>(" RuleTransformerL1" );
1703
+ ASSERT_STATUS_OK (rule_transformer_L1->Register (std::make_unique<PadFusion>()));
1704
+ ASSERT_STATUS_OK (graph_transformation_mgr.Register (std::move (rule_transformer_L1), TransformerLevel::Level1));
1705
+
1706
+ ASSERT_STATUS_OK (graph_transformation_mgr.ApplyTransformers (graph, TransformerLevel::Level1, *logger_));
1707
+
1708
+ std::map<std::string, int > op_to_count = CountOpsInGraph (graph);
1709
+ ASSERT_EQ (op_to_count[" Pad" ], 0 );
1710
+ ASSERT_EQ (op_to_count[" AveragePool" ], 1 );
1711
+
1712
+ for (auto & node : graph.Nodes ()) {
1713
+ if (node.OpType () == " AveragePool" ) {
1714
+ auto const & child_pads = node.GetAttributes ().at (" pads" ).ints ();
1715
+ auto const & count_include_pad = node.GetAttributes ().at (" count_include_pad" );
1716
+ ASSERT_NE (count_include_pad.i (), 0 ) << " fusion should ensure count_include_pad!=0" ;
1717
+ ASSERT_EQ (child_pads.size (), static_cast <int32_t >(expected_pads.size ()))
1718
+ << " fusion should produce the same size of pads integer as the AvgPool node" ;
1719
+ for (uint32_t index = 0 ; index < expected_pads.size (); index ++) {
1720
+ ASSERT_EQ (expected_pads[index ], child_pads.Get (index ))
1721
+ << " fusion does not produce correct padding value" ;
1722
+ }
1723
+ }
1724
+ }
1725
+ }
1726
+
1727
+ // should not fuse
1728
+ TEST_F (GraphTransformationTests, FusePadWithAvgPoolWithPadNoInclude) {
1729
+ constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER " fusion/fuse-pad-avgpool_with_pad-nofuse.onnx" ;
1730
+
1731
+ std::shared_ptr<Model> p_model;
1732
+ ASSERT_STATUS_OK (Model::Load (model_uri, p_model, nullptr , *logger_));
1733
+ Graph& graph = p_model->MainGraph ();
1734
+
1735
+ onnxruntime::GraphTransformerManager graph_transformation_mgr{5 };
1736
+ auto rule_transformer_L1 = std::make_unique<RuleBasedGraphTransformer>(" RuleTransformerL1" );
1737
+ ASSERT_STATUS_OK (rule_transformer_L1->Register (std::make_unique<PadFusion>()));
1738
+ ASSERT_STATUS_OK (graph_transformation_mgr.Register (std::move (rule_transformer_L1), TransformerLevel::Level1));
1739
+
1740
+ ASSERT_STATUS_OK (graph_transformation_mgr.ApplyTransformers (graph, TransformerLevel::Level1, *logger_));
1741
+
1742
+ std::map<std::string, int > op_to_count = CountOpsInGraph (graph);
1743
+ ASSERT_EQ (op_to_count[" Pad" ], 1 );
1744
+ ASSERT_EQ (op_to_count[" AveragePool" ], 1 );
1745
+ }
1746
+
1625
1747
TEST_F (GraphTransformationTests, FuseMatmulBNWithInBetweenNodes) {
1626
1748
constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER " fusion/fuse-matmul-bn-with-reshape.onnx" ;
1627
1749
0 commit comments