Skip to content

Commit 89e9581

Browse files
authored
[circle-mlir/pass] Revise ConvTransposeOp pass (#16457)
This updates the ONNX `ConvTranspose` operation pass to support dynamic-shape inputs. ONE-DCO-1.0-Signed-off-by: Seungho Henry Park <shs.park@samsung.com>
1 parent 540c16c commit 89e9581

File tree

1 file changed

+28
-17
lines changed

1 file changed

+28
-17
lines changed

circle-mlir/circle-mlir/lib/pass/src/ops/ConvTransposeOp.h

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -128,27 +128,38 @@ class ConvConvTranspose : public mlir::OpConversionPattern<mlir::ONNXConvTranspo
128128
// create output_shape constant
129129
mlir::Value output_shape;
130130
mlir::SmallVector<int32_t, 4> os_i32;
131+
mlir::SmallVector<int64_t, 4> os_i64;
131132
{
132-
int32_t hin = static_cast<int32_t>(inshape[2]);
133-
int32_t win = static_cast<int32_t>(inshape[3]);
134-
int32_t hfs = static_cast<int32_t>(filtershape[2]);
135-
int32_t wfs = static_cast<int32_t>(filtershape[3]);
136-
int32_t hout = (hin - 1) * stride_h + dilation_h * (hfs - 1) + output_padding_h + 1;
137-
int32_t wout = (win - 1) * stride_w + dilation_w * (wfs - 1) + output_padding_w + 1;
138-
int32_t nin = static_cast<int32_t>(inshape[0]);
139-
int32_t ofs = static_cast<int32_t>(filtershape[1]);
140-
os_i32.push_back(nin);
141-
os_i32.push_back(hout);
142-
os_i32.push_back(wout);
143-
os_i32.push_back(ofs); // from IOHW
133+
int64_t dyn = mlir::ShapedType::kDynamic;
134+
int64_t hin = inshape[2];
135+
int64_t win = inshape[3];
136+
int64_t hfs = filtershape[2];
137+
int64_t wfs = filtershape[3];
138+
int64_t hout = dyn;
139+
int64_t wout = dyn;
140+
int64_t nin = dyn;
141+
int64_t ofs = filtershape[1];
142+
143+
if (!mlir::ShapedType::isDynamic(inshape[0]))
144+
nin = inshape[0];
145+
if (!mlir::ShapedType::isDynamic(inshape[2]))
146+
hout = (hin - 1) * stride_h + dilation_h * (hfs - 1) + output_padding_h + 1;
147+
if (!mlir::ShapedType::isDynamic(inshape[3]))
148+
wout = (win - 1) * stride_w + dilation_w * (wfs - 1) + output_padding_w + 1;
149+
150+
os_i64 = {nin, hout, wout, ofs};
151+
os_i32.push_back(static_cast<int32_t>(nin));
152+
os_i32.push_back(static_cast<int32_t>(hout));
153+
os_i32.push_back(static_cast<int32_t>(wout));
154+
os_i32.push_back(static_cast<int32_t>(ofs)); // from IOHW
144155

145156
mlir::Location shape_loc = mlir::NameLoc::get(rewriter.getStringAttr(op_name + "/shape"));
146157
mlir::Type i32 = rewriter.getI32Type();
147158
mlir::RankedTensorType ostype = RankedTensorType::get({4}, i32);
148159
output_shape = rewriter.create<ConstOp>(shape_loc, DenseIntElementsAttr::get(ostype, os_i32));
149160
}
150161

151-
mlir::SmallVector<int64_t> trconv2d_shape({os_i32[0], os_i32[1], os_i32[2], os_i32[3]});
162+
mlir::SmallVector<int64_t> trconv2d_shape({os_i64[0], os_i64[1], os_i64[2], os_i64[3]});
152163
auto trconv_output_type = mlir::RankedTensorType::get(trconv2d_shape, outtype.getElementType());
153164
mlir::Value trconv2d = rewriter.create<TransposeConvOp>(
154165
opLoc, trconv_output_type, output_shape, filter_tran, pre_tran, bias,
@@ -173,10 +184,10 @@ class ConvConvTranspose : public mlir::OpConversionPattern<mlir::ONNXConvTranspo
173184

174185
mlir::Location ss_loc = mlir::NameLoc::get(rewriter.getStringAttr(op_name + "/slice/size"));
175186
mlir::SmallVector<int32_t, 4> size_i32;
176-
size_i32.push_back(os_i32[0]);
177-
size_i32.push_back(os_i32[1] - 2 * padsValue[0]);
178-
size_i32.push_back(os_i32[2] - 2 * padsValue[1]);
179-
size_i32.push_back(os_i32[3]);
187+
size_i32.push_back(static_cast<int32_t>(os_i64[0]));
188+
size_i32.push_back(static_cast<int32_t>(os_i64[1]) - 2 * padsValue[0]);
189+
size_i32.push_back(static_cast<int32_t>(os_i64[2]) - 2 * padsValue[1]);
190+
size_i32.push_back(static_cast<int32_t>(os_i64[3]));
180191
auto sizeConst =
181192
rewriter.create<ConstOp>(ss_loc, DenseIntElementsAttr::get(bstype, size_i32));
182193

0 commit comments

Comments
 (0)