Skip to content

Commit b4da310

Browse files
committed
[luci] Use static dimensions from new_shape attribute
This PR adds use static dimension from new_shape attribute even if only some of them are static and the rest dynamic. ONE-DCO-1.0-Signed-off-by: Mateusz Bencer m.bencer@partner.samsung.com
1 parent 0dfc2a5 commit b4da310

2 files changed

Lines changed: 46 additions & 3 deletions

File tree

compiler/luci/service/src/Nodes/CircleReshape.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,7 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
120120
// for non-existing `shape`, we can use `newShape` if it's valid
121121
auto new_shape = node->newShape();
122122
auto rank = new_shape->rank();
123-
auto shape_dummy = dynamic_cast<luci::CircleOutputDummy *>(node->shape());
124-
if (shape_dummy && rank > 0)
123+
if (rank > 0)
125124
{
126125
is_static_shape = true;
127126
shape_by_input.rank(rank);
@@ -154,7 +153,14 @@ loco::TensorShape Algorithm::visit(const luci::CircleReshape *node)
154153

155154
for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis)
156155
{
157-
shape_by_attr.dim(axis) = node->newShape()->dim(axis);
156+
if (node->newShape()->dim(axis) > 0)
157+
{
158+
shape_by_attr.dim(axis) = node->newShape()->dim(axis);
159+
}
160+
else
161+
{
162+
shape_by_attr.dim(axis).unset(); // unset means unknown dimension
163+
}
158164
}
159165
}
160166

compiler/luci/service/src/Nodes/CircleReshape.test.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,3 +197,40 @@ TEST(ShapeRuleTest, reshape_by_newShape)
197197
ASSERT_EQ(2, output_shape.dim(0).value());
198198
ASSERT_EQ(12, output_shape.dim(1).value());
199199
}
200+
201+
TEST(ShapeRuleTest, reshape_by_newShape_dynamic)
202+
{
203+
auto g = loco::make_graph();
204+
auto node_reshape = g->nodes()->create<luci::CircleReshape>();
205+
auto tensor_input = g->nodes()->create<luci::CircleInput>();
206+
auto target_shape = g->nodes()->create<luci::CircleInput>();
207+
;
208+
209+
tensor_input->dtype(loco::DataType::S32);
210+
tensor_input->shape({2, 3, 4});
211+
tensor_input->shape_status(luci::ShapeStatus::VALID);
212+
213+
target_shape->dtype(loco::DataType::S32);
214+
target_shape->rank(1);
215+
target_shape->shape_status(luci::ShapeStatus::VALID);
216+
217+
node_reshape->tensor(tensor_input);
218+
node_reshape->shape(target_shape);
219+
220+
// reshape to {dynamic, 4, dynamic}
221+
node_reshape->newShape()->rank(3);
222+
node_reshape->newShape()->dim(0) = -1;
223+
node_reshape->newShape()->dim(1) = 4;
224+
node_reshape->newShape()->dim(2) = -1;
225+
226+
loco::TensorShape output_shape;
227+
luci::sinf::Rule shape_inf_rule;
228+
229+
ASSERT_TRUE(shape_inf_rule.infer(node_reshape, output_shape));
230+
231+
ASSERT_EQ(3, output_shape.rank());
232+
ASSERT_FALSE(output_shape.dim(0).known());
233+
ASSERT_TRUE(output_shape.dim(1).known());
234+
ASSERT_EQ(4, output_shape.dim(1).value());
235+
ASSERT_FALSE(output_shape.dim(2).known());
236+
}

0 commit comments

Comments
 (0)