Skip to content

Commit bb8b2f2

Browse files
authored
[onert] Ease restriction on shape validation (#14950)
- Fix too strong shape restriction - Add test case for rank > 4 ONE-DCO-1.0-Signed-off-by: Chunseok Lee <chunseok.lee@samsung.com>
1 parent 21effac commit bb8b2f2

2 files changed

Lines changed: 28 additions & 3 deletions

File tree

runtime/onert/core/src/compiler/ShapeValidator.cc

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -563,9 +563,8 @@ void ShapeValidator::visit(const ir::operation::Gather &node)
563563
const auto &indices_shape = operands.at(indices_index).shape();
564564
const auto &ofm_shape = operands.at(ofm_index).shape();
565565

566-
OP_REQUIRES(ifm_shape.rank() <= 4);
567-
OP_REQUIRES(indices_shape.rank() <= 3);
568-
OP_REQUIRES(ofm_shape.rank() <= 4);
566+
// Since gather implementation is general enough, we do not restrict max rank
567+
OP_REQUIRES(ifm_shape.rank() + indices_shape.rank() - 1 == ofm_shape.rank());
569568
}
570569

571570
void ShapeValidator::visit(const ir::operation::DepthToSpace &node)

runtime/tests/nnfw_api/src/GenModelTests/one_op_tests/Gather.test.cc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,32 @@
1818

1919
#include "common.h"
2020

21+
TEST_F(GenModelTest, OneOp_Gather_rank5)
22+
{
23+
CircleGen cgen;
24+
25+
std::vector<int32_t> index_data{1};
26+
27+
auto index_buf = cgen.addBuffer(index_data);
28+
29+
int input = cgen.addTensor({{3, 1, 1, 2, 2}, circle::TensorType::TensorType_FLOAT32});
30+
int indice = cgen.addTensor({{1}, circle::TensorType::TensorType_INT32, index_buf});
31+
int output = cgen.addTensor({{1, 1, 1, 2, 2}, circle::TensorType::TensorType_FLOAT32});
32+
33+
cgen.addOperatorGather({{input, indice}, {output}}, 0 /*axis*/);
34+
cgen.setInputsAndOutputs({input}, {output});
35+
36+
_context = std::make_unique<GenModelTestContext>(cgen.finish());
37+
38+
TestCaseData tc;
39+
tc.addInput<float>({1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3});
40+
tc.addOutput<float>({2, 2, 2, 2});
41+
_context->addTestCase(tc);
42+
_context->setBackends({"cpu"});
43+
44+
SUCCEED();
45+
}
46+
2147
TEST_F(GenModelTest, OneOp_Gather_Q4_0)
2248
{
2349
CircleGen cgen;

0 commit comments

Comments
 (0)