Skip to content

Commit fd339d5

Browse files
authored
[onert] Support dynamic shapes in DepthwiseConv2D (#15338)
This commit extends `DepthwiseConv2D` operator to support dynamic input and output shapes. ONE-DCO-1.0-Signed-off-by: ragmani <ragmani0216@gmail.com>
1 parent 4db9d9e commit fd339d5

2 files changed

Lines changed: 26 additions & 0 deletions

File tree

runtime/onert/core/include/exec/DynamicShapeInferer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class DynamicShapeInferer : public ir::OperationVisitor
5252
void visit(const ir::operation::Comparison &op) override;
5353
void visit(const ir::operation::Concat &op) override;
5454
void visit(const ir::operation::Conv2D &op) override;
55+
void visit(const ir::operation::DepthwiseConv2D &op) override;
5556
void visit(const ir::operation::ElementwiseActivation &op) override;
5657
void visit(const ir::operation::ElementwiseBinary &op) override;
5758
void visit(const ir::operation::ElementwiseUnary &op) override;

runtime/onert/core/src/exec/DynamicShapeInferer.cc

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,31 @@ void DynamicShapeInferer::visit(const ir::operation::Conv2D &op)
337337
assert(output->buffer() != nullptr);
338338
}
339339

340+
void DynamicShapeInferer::visit(const ir::operation::DepthwiseConv2D &op)
341+
{
342+
// check if input is not dynamic
343+
auto input_ind = op.getInputs().at(ir::operation::DepthwiseConv2D::INPUT);
344+
auto input = _tensor_registry->getITensor(input_ind);
345+
346+
auto ker_ind = op.getInputs().at(ir::operation::DepthwiseConv2D::KERNEL);
347+
auto ker = _tensor_registry->getITensor(ker_ind);
348+
349+
if ((!input->is_dynamic()) && (!ker->is_dynamic()))
350+
return;
351+
352+
ir::Shape input_shape = input->getShape();
353+
ir::Shape ker_shape = ker->getShape();
354+
355+
auto output_ind = op.getOutputs().at(0);
356+
auto output = _tensor_registry->getITensor(output_ind);
357+
358+
ir::Shape output_shape =
359+
shape_inference::inferDepthwiseConv2DShape(input_shape, ker_shape, op.param());
360+
361+
output->applyShape(output_shape);
362+
assert(output->buffer() != nullptr);
363+
}
364+
340365
void DynamicShapeInferer::visit(const ir::operation::ElementwiseActivation &op)
341366
{
342367
handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::ElementwiseActivation::INPUT));

0 commit comments

Comments
 (0)