Skip to content

Commit c3c48b2

Browse files
committed
[luci-interpreter] Merge target shape from input node and attribute
This commit improve mechanism of shape inference for Reshape operator. If some dimension from input node is unknown we are trying to find such information in attribute. ONE-DCO-1.0-Signed-off-by: Mateusz Bencer <m.bencer@partner.samsung.com>
1 parent 0dfc2a5 commit c3c48b2

2 files changed

Lines changed: 104 additions & 2 deletions

File tree

compiler/luci/service/src/Nodes/CircleReshape.cpp

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,25 @@ luci::CircleNode *CloneNodeLet<CN::OPQR>::visit(const luci::CircleReshape *node)
6565

6666
namespace sinf
6767
{
68+
namespace
69+
{
70+
loco::TensorShape merge_shapes(const loco::TensorShape &base_shape,
71+
const loco::TensorShape &merged_shape)
72+
{
73+
loco::TensorShape result_shape = base_shape;
74+
if (base_shape.rank() == merged_shape.rank())
75+
{
76+
for (int axis = 0; axis < base_shape.rank(); ++axis)
77+
{
78+
if (!base_shape.dim(axis).known() && merged_shape.dim(axis).known())
79+
{
80+
result_shape.dim(axis) = merged_shape.dim(axis);
81+
}
82+
}
83+
}
84+
return result_shape;
85+
}
86+
} // namespace
6887

6988
loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
7089
{
@@ -154,7 +173,14 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
154173

155174
for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis)
156175
{
157-
shape_by_attr.dim(axis) = node->newShape()->dim(axis);
176+
if (node->newShape()->dim(axis) > 0)
177+
{
178+
shape_by_attr.dim(axis) = node->newShape()->dim(axis);
179+
}
180+
else
181+
{
182+
shape_by_attr.dim(axis).unset(); // unset means unknown dimension
183+
}
158184
}
159185
}
160186

@@ -165,7 +191,7 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
165191
INFO(l) << " shape_by_attr : " << shape_by_attr << std::endl;
166192
}
167193

168-
loco::TensorShape output_shape = shape_by_input;
194+
loco::TensorShape output_shape = merge_shapes(shape_by_input, shape_by_attr);
169195

170196
// One of the dimensions can have special value -1, meaning its actual value should be inferred.
171197
const auto input = loco::must_cast<luci::CircleNode *>(node->tensor());

compiler/luci/service/src/Nodes/CircleReshape.test.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,79 @@ TEST(ShapeRuleTest, reshape_by_newShape)
197197
ASSERT_EQ(2, output_shape.dim(0).value());
198198
ASSERT_EQ(12, output_shape.dim(1).value());
199199
}
200+
201+
TEST(ShapeRuleTest, reshape_by_newShape_dynamic)
202+
{
203+
auto g = loco::make_graph();
204+
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
205+
auto tensor_input = g->nodes()->create<luci::CircleInput>();
206+
auto shape_dummy = g->nodes()->create<luci::CircleOutputDummy>();
207+
208+
tensor_input->dtype(loco::DataType::S32);
209+
tensor_input->shape({2, 3, 4});
210+
tensor_input->shape_status(luci::ShapeStatus::VALID);
211+
212+
shape_dummy->dtype(loco::DataType::S32);
213+
shape_dummy->shape({});
214+
shape_dummy->shape_status(luci::ShapeStatus::VALID);
215+
216+
node_reshape->tensor(tensor_input);
217+
node_reshape->shape(shape_dummy);
218+
219+
// reshape to {-1, 12}
220+
node_reshape->newShape()->rank(2);
221+
node_reshape->newShape()->dim(0) = -1;
222+
node_reshape->newShape()->dim(1) = 12;
223+
224+
loco::TensorShape output_shape;
225+
luci::sinf::Rule shape_inf_rule;
226+
227+
ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape));
228+
229+
ASSERT_EQ(2, output_shape.rank());
230+
ASSERT_FALSE(output_shape.dim(0).known());
231+
ASSERT_TRUE(output_shape.dim(1).known());
232+
ASSERT_EQ(12, output_shape.dim(1).value());
233+
}
234+
235+
TEST(ShapeRuleTest, merge_shape_from_newShape_and_input_node)
236+
{
237+
auto g = loco::make_graph();
238+
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
239+
auto tensor_input = g->nodes()->create<luci::CircleInput>();
240+
auto shape_by_input = g->nodes()->create<luci::CircleConst>();
241+
242+
node_reshape->tensor(tensor_input);
243+
244+
tensor_input->dtype(loco::DataType::S32);
245+
tensor_input->shape({2, 3, 4, 5});
246+
tensor_input->shape_status(luci::ShapeStatus::VALID);
247+
248+
shape_by_input->dtype(loco::DataType::S32);
249+
shape_by_input->size<loco::DataType::S32>(3);
250+
shape_by_input->at<loco::DataType::S32>(0) = 2;
251+
shape_by_input->at<loco::DataType::S32>(1) = -1;
252+
shape_by_input->at<loco::DataType::S32>(2) = -1;
253+
shape_by_input->shape_status(luci::ShapeStatus::VALID);
254+
255+
node_reshape->tensor(tensor_input);
256+
node_reshape->shape(shape_by_input);
257+
258+
node_reshape->newShape()->rank(3);
259+
node_reshape->newShape()->dim(0) = -1; // unknow here but pass by shape_by_input
260+
node_reshape->newShape()->dim(1) = 12;
261+
node_reshape->newShape()->dim(2) = 5;
262+
263+
loco::TensorShape output_shape;
264+
luci::sinf::Rule shape_inf_rule;
265+
266+
ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape));
267+
268+
ASSERT_EQ(3, output_shape.rank());
269+
ASSERT_TRUE(output_shape.dim(0).known());
270+
EXPECT_EQ(2, output_shape.dim(0).value());
271+
ASSERT_TRUE(output_shape.dim(1).known());
272+
EXPECT_EQ(12, output_shape.dim(1).value());
273+
ASSERT_TRUE(output_shape.dim(2).known());
274+
EXPECT_EQ(5, output_shape.dim(2).value());
275+
}

0 commit comments

Comments
 (0)