@@ -62,6 +62,84 @@ static void dumpCmd(const Command* cmd) {
6262    MNN_PRINT (" }\n "  );
6363}
6464
65+ void  mergeConvolutionAndPrelu (Node* root, MNNForwardType forwardType){
66+     if  (root->cmd ->op  != nullptr  && root->cmd ->op ->type () == OpType_Convolution && root->succ .size () == 1 ) {
67+         auto  child = root->succ [0 ];
68+         if (child->cmd ->op ->type () == OpType_PReLU){
69+             if (root->cmd ->op ->externalPath () != nullptr ){
70+                 return ;
71+             }
72+             std::shared_ptr<Command> cmdPlugin;
73+             auto  inputs = root->cmd ->inputs ;
74+             auto  outputs = root->cmd ->outputs ;
75+             auto  convOp = root->cmd ->op ->main_as_Convolution2D ();
76+             if (convOp->quanParameter () != nullptr  || convOp->symmetricQuan () != nullptr  || convOp->sparseParameter () != nullptr  || convOp->external () != nullptr  || convOp->common ()->outputCount () != child->cmd ->op ->main_as_PRelu ()->slopeCount ()){
77+                 return ;
78+             }
79+             std::unique_ptr<OpT> fuseOp (new  OpT);
80+             fuseOp->type  = OpType_Extra;
81+             fuseOp->name  = root->cmd ->op ->name ()->str ();
82+             ExtraT* extra_param = new  ExtraT;
83+             extra_param->type  = " ExtraConvolution2DPrelu"  ;
84+             extra_param->attr .resize (2 );
85+             //  copy convolution2D param
86+             AttributeT* convAtr = new  AttributeT;
87+             BlobT* convParamBlob = new  BlobT;
88+             {
89+                 std::unique_ptr<Convolution2DT> convolutionParam (convOp->UnPack ());
90+                 flatbuffers::FlatBufferBuilder builder;
91+                 auto  lastOffset = Convolution2D::Pack (builder, convolutionParam.get ());
92+                 builder.Finish (lastOffset);
93+                 
94+                 const  uint8_t * buffer_ptr = builder.GetBufferPointer ();
95+                 const  size_t  size = builder.GetSize ();
96+                 convParamBlob->uint8s .resize (size);
97+                 ::memcpy (convParamBlob->uint8s.data(), buffer_ptr, size);
98+             }
99+             convAtr->tensor .reset (convParamBlob);
100+             extra_param->attr [0 ].reset (convAtr);
101+             
102+             //  copy prelu param
103+             AttributeT* preluAtr = new  AttributeT;
104+             BlobT* preluParamBlob = new  BlobT;
105+             {
106+                 std::unique_ptr<PReluT> preluParam (child->cmd ->op ->main_as_PRelu ()->UnPack ());
107+                 flatbuffers::FlatBufferBuilder builder;
108+                 auto  lastOffset = PRelu::Pack (builder, preluParam.get ());
109+                 builder.Finish (lastOffset);
110+                 const  uint8_t * buffer_ptr = builder.GetBufferPointer ();
111+                 const  size_t  size = builder.GetSize ();
112+                 preluParamBlob->uint8s .resize (size);
113+                 ::memcpy (preluParamBlob->uint8s.data(), buffer_ptr, size);
114+             }
115+             preluAtr->tensor .reset (preluParamBlob);
116+             extra_param->attr [1 ].reset (preluAtr);
117+             
118+             fuseOp->main .type   = OpParameter_Extra;
119+             fuseOp->main .value  = extra_param;
120+             flatbuffers::FlatBufferBuilder builder;
121+             auto  lastOffset = Op::Pack (builder, fuseOp.get ());
122+             builder.Finish (lastOffset);
123+             cmdPlugin = GeometryComputerUtils::makeCommand (builder, inputs, outputs);
124+             
125+             root->cmd ->op  = cmdPlugin->op ;
126+             root->cmd ->inputs  = cmdPlugin->inputs ;
127+             root->cmd ->outputs  = cmdPlugin->outputs ;
128+             root->cmd ->buffer  = cmdPlugin->buffer ;
129+             child->cmd ->op  = nullptr ;
130+             child->cmd ->buffer .reset ();
131+             for (auto  &childNode : child->succ ){
132+                 for (auto  &input : childNode->cmd ->inputs ){
133+                     if (input == child->cmd ->outputs [0 ]){
134+                         input = root->cmd ->outputs [0 ];
135+                     }
136+                 }
137+             }
138+             root->succ  = child->succ ;
139+         }
140+     }
141+ }
142+ 
65143//  is legal fused type
66144bool  isLegal (Command* cmd, MNNForwardType forwardType) {
67145    auto  type = cmd->op ->type ();
@@ -369,6 +447,20 @@ bool opFuse(std::vector<Schedule::OpCacheInfo>& infos, MNNForwardType type, Back
369447            graph.push_back (std::move (node));
370448        }
371449    }
450+     
451+     if (type == MNN_FORWARD_OPENCL){
452+         for (int  i = 0 ; i < graph.size (); ++i){
453+             mergeConvolutionAndPrelu (graph[i].get (), type);
454+         }
455+         for (auto  iter = graph.begin (); iter != graph.end ();){
456+             if (iter->get ()->cmd ->op  == nullptr ){
457+                 iter = graph.erase (iter);
458+             }else {
459+                 ++iter;
460+             }
461+         }
462+     }
463+     
372464    std::queue<Node*> postDominateNodeQueue;
373465    //  build dominate tree
374466    for  (int  i = static_cast <int >(graph.size ()) - 1 ; i >= 0 ; i--) {
0 commit comments