@@ -128,27 +128,38 @@ class ConvConvTranspose : public mlir::OpConversionPattern<mlir::ONNXConvTranspo
128128 // create output_shape constant
129129 mlir::Value output_shape;
130130 mlir::SmallVector<int32_t , 4 > os_i32;
131+ mlir::SmallVector<int64_t , 4 > os_i64;
131132 {
132- int32_t hin = static_cast <int32_t >(inshape[2 ]);
133- int32_t win = static_cast <int32_t >(inshape[3 ]);
134- int32_t hfs = static_cast <int32_t >(filtershape[2 ]);
135- int32_t wfs = static_cast <int32_t >(filtershape[3 ]);
136- int32_t hout = (hin - 1 ) * stride_h + dilation_h * (hfs - 1 ) + output_padding_h + 1 ;
137- int32_t wout = (win - 1 ) * stride_w + dilation_w * (wfs - 1 ) + output_padding_w + 1 ;
138- int32_t nin = static_cast <int32_t >(inshape[0 ]);
139- int32_t ofs = static_cast <int32_t >(filtershape[1 ]);
140- os_i32.push_back (nin);
141- os_i32.push_back (hout);
142- os_i32.push_back (wout);
143- os_i32.push_back (ofs); // from IOHW
133+ int64_t dyn = mlir::ShapedType::kDynamic ;
134+ int64_t hin = inshape[2 ];
135+ int64_t win = inshape[3 ];
136+ int64_t hfs = filtershape[2 ];
137+ int64_t wfs = filtershape[3 ];
138+ int64_t hout = dyn;
139+ int64_t wout = dyn;
140+ int64_t nin = dyn;
141+ int64_t ofs = filtershape[1 ];
142+
143+ if (!mlir::ShapedType::isDynamic (inshape[0 ]))
144+ nin = inshape[0 ];
145+ if (!mlir::ShapedType::isDynamic (inshape[2 ]))
146+ hout = (hin - 1 ) * stride_h + dilation_h * (hfs - 1 ) + output_padding_h + 1 ;
147+ if (!mlir::ShapedType::isDynamic (inshape[3 ]))
148+ wout = (win - 1 ) * stride_w + dilation_w * (wfs - 1 ) + output_padding_w + 1 ;
149+
150+ os_i64 = {nin, hout, wout, ofs};
151+ os_i32.push_back (static_cast <int32_t >(nin));
152+ os_i32.push_back (static_cast <int32_t >(hout));
153+ os_i32.push_back (static_cast <int32_t >(wout));
154+ os_i32.push_back (static_cast <int32_t >(ofs)); // from IOHW
144155
145156 mlir::Location shape_loc = mlir::NameLoc::get (rewriter.getStringAttr (op_name + " /shape" ));
146157 mlir::Type i32 = rewriter.getI32Type ();
147158 mlir::RankedTensorType ostype = RankedTensorType::get ({4 }, i32 );
148159 output_shape = rewriter.create <ConstOp>(shape_loc, DenseIntElementsAttr::get (ostype, os_i32));
149160 }
150161
151- mlir::SmallVector<int64_t > trconv2d_shape ({os_i32 [0 ], os_i32 [1 ], os_i32 [2 ], os_i32 [3 ]});
162+ mlir::SmallVector<int64_t > trconv2d_shape ({os_i64 [0 ], os_i64 [1 ], os_i64 [2 ], os_i64 [3 ]});
152163 auto trconv_output_type = mlir::RankedTensorType::get (trconv2d_shape, outtype.getElementType ());
153164 mlir::Value trconv2d = rewriter.create <TransposeConvOp>(
154165 opLoc, trconv_output_type, output_shape, filter_tran, pre_tran, bias,
@@ -173,10 +184,10 @@ class ConvConvTranspose : public mlir::OpConversionPattern<mlir::ONNXConvTranspo
173184
174185 mlir::Location ss_loc = mlir::NameLoc::get (rewriter.getStringAttr (op_name + " /slice/size" ));
175186 mlir::SmallVector<int32_t , 4 > size_i32;
176- size_i32.push_back (os_i32 [0 ]);
177- size_i32.push_back (os_i32 [1 ] - 2 * padsValue[0 ]);
178- size_i32.push_back (os_i32 [2 ] - 2 * padsValue[1 ]);
179- size_i32.push_back (os_i32 [3 ]);
187+ size_i32.push_back (static_cast < int32_t >(os_i64 [0 ]) );
188+ size_i32.push_back (static_cast < int32_t >(os_i64 [1 ]) - 2 * padsValue[0 ]);
189+ size_i32.push_back (static_cast < int32_t >(os_i64 [2 ]) - 2 * padsValue[1 ]);
190+ size_i32.push_back (static_cast < int32_t >(os_i64 [3 ]) );
180191 auto sizeConst =
181192 rewriter.create <ConstOp>(ss_loc, DenseIntElementsAttr::get (bstype, size_i32));
182193
0 commit comments