@@ -197,3 +197,79 @@ TEST(ShapeRuleTest, reshape_by_newShape)
197197 ASSERT_EQ (2 , output_shape.dim (0 ).value ());
198198 ASSERT_EQ (12 , output_shape.dim (1 ).value ());
199199}
200+
201+ TEST (ShapeRuleTest, reshape_by_newShape_dynamic)
202+ {
203+ auto g = loco::make_graph ();
204+ auto node_reshape = g->nodes ()->create <luci::CircleReshape>();
205+ auto tensor_input = g->nodes ()->create <luci::CircleInput>();
206+ auto shape_dummy = g->nodes ()->create <luci::CircleOutputDummy>();
207+
208+ tensor_input->dtype (loco::DataType::S32);
209+ tensor_input->shape ({2 , 3 , 4 });
210+ tensor_input->shape_status (luci::ShapeStatus::VALID);
211+
212+ shape_dummy->dtype (loco::DataType::S32);
213+ shape_dummy->shape ({});
214+ shape_dummy->shape_status (luci::ShapeStatus::VALID);
215+
216+ node_reshape->tensor (tensor_input);
217+ node_reshape->shape (shape_dummy);
218+
219+ // reshape to {-1, 12}
220+ node_reshape->newShape ()->rank (2 );
221+ node_reshape->newShape ()->dim (0 ) = -1 ;
222+ node_reshape->newShape ()->dim (1 ) = 12 ;
223+
224+ loco::TensorShape output_shape;
225+ luci::sinf::Rule shape_inf_rule;
226+
227+ ASSERT_TRUE (shape_inf_rule.infer (node_reshape, output_shape));
228+
229+ ASSERT_EQ (2 , output_shape.rank ());
230+ ASSERT_FALSE (output_shape.dim (0 ).known ());
231+ ASSERT_TRUE (output_shape.dim (1 ).known ());
232+ ASSERT_EQ (12 , output_shape.dim (1 ).value ());
233+ }
234+
235+ TEST (ShapeRuleTest, merge_shape_from_newShape_and_input_node)
236+ {
237+ auto g = loco::make_graph ();
238+ auto node_reshape = g->nodes ()->create <luci::CircleReshape>();
239+ auto tensor_input = g->nodes ()->create <luci::CircleInput>();
240+ auto shape_by_input = g->nodes ()->create <luci::CircleConst>();
241+
242+ node_reshape->tensor (tensor_input);
243+
244+ tensor_input->dtype (loco::DataType::S32);
245+ tensor_input->shape ({2 , 3 , 4 , 5 });
246+ tensor_input->shape_status (luci::ShapeStatus::VALID);
247+
248+ shape_by_input->dtype (loco::DataType::S32);
249+ shape_by_input->size <loco::DataType::S32>(3 );
250+ shape_by_input->at <loco::DataType::S32>(0 ) = 2 ;
251+ shape_by_input->at <loco::DataType::S32>(1 ) = -1 ;
252+ shape_by_input->at <loco::DataType::S32>(2 ) = -1 ;
253+ shape_by_input->shape_status (luci::ShapeStatus::VALID);
254+
255+ node_reshape->tensor (tensor_input);
256+ node_reshape->shape (shape_by_input);
257+
258+ node_reshape->newShape ()->rank (3 );
259+ node_reshape->newShape ()->dim (0 ) = -1 ; // unknow here but pass by shape_by_input
260+ node_reshape->newShape ()->dim (1 ) = 12 ;
261+ node_reshape->newShape ()->dim (2 ) = 5 ;
262+
263+ loco::TensorShape output_shape;
264+ luci::sinf::Rule shape_inf_rule;
265+
266+ ASSERT_TRUE (shape_inf_rule.infer (node_reshape, output_shape));
267+
268+ ASSERT_EQ (3 , output_shape.rank ());
269+ ASSERT_TRUE (output_shape.dim (0 ).known ());
270+ EXPECT_EQ (2 , output_shape.dim (0 ).value ());
271+ ASSERT_TRUE (output_shape.dim (1 ).known ());
272+ EXPECT_EQ (12 , output_shape.dim (1 ).value ());
273+ ASSERT_TRUE (output_shape.dim (2 ).known ());
274+ EXPECT_EQ (5 , output_shape.dim (2 ).value ());
275+ }
0 commit comments