@@ -319,6 +319,55 @@ ark::unittest::State test_sched_opgraph() {
319
319
return ark::unittest::SUCCESS;
320
320
}
321
321
322
+ ark::unittest::State test_sched_opgraph_dependent_inputs () {
323
+ ark::Model m;
324
+
325
+ ark::Tensor *ones = m.tensor ({256 , 256 }, ark::FP16);
326
+ ark::Tensor *x0 = m.scale (m.scale (ones, 2 ), 2 );
327
+ ark::Tensor *x1 = m.scale (m.scale (x0, 2 ), 2 );
328
+
329
+ ark::Tensor *x2 = m.mul (ones, x1);
330
+ ark::Tensor *x3 = m.mul (ones, x1);
331
+ ark::Tensor *x4 = m.mul (x2, x3);
332
+ ark::Tensor *y = m.add (x0, x4);
333
+
334
+ ark::OpGraph graph (m);
335
+ UNITTEST_EQ (graph.get_nodes ().size (), 4 );
336
+ auto nodes_iter = graph.get_nodes ().begin ();
337
+ auto node = (nodes_iter++)->get ();
338
+ UNITTEST_EQ (node->ops .size (), 4 );
339
+ UNITTEST_EQ (node->ops [1 ]->outputs [0 ], x0);
340
+ UNITTEST_EQ (node->ops [3 ]->outputs [0 ], x1);
341
+ UNITTEST_EQ (node->users .size (), 3 );
342
+ UNITTEST_EQ (node->producers .size (), 0 );
343
+ node = (nodes_iter++)->get ();
344
+ UNITTEST_EQ (node->ops .size (), 1 );
345
+ UNITTEST_EQ (node->ops [0 ]->outputs [0 ], x2);
346
+ UNITTEST_EQ (node->ops [0 ]->inputs [0 ], ones);
347
+ UNITTEST_EQ (node->ops [0 ]->inputs [1 ], x1);
348
+ UNITTEST_EQ (node->users .size (), 1 );
349
+ UNITTEST_EQ (node->producers .size (), 1 );
350
+ node = (nodes_iter++)->get ();
351
+ UNITTEST_EQ (node->ops .size (), 1 );
352
+ UNITTEST_EQ (node->ops [0 ]->outputs [0 ], x3);
353
+ UNITTEST_EQ (node->ops [0 ]->inputs [0 ], ones);
354
+ UNITTEST_EQ (node->ops [0 ]->inputs [1 ], x1);
355
+ UNITTEST_EQ (node->users .size (), 1 );
356
+ UNITTEST_EQ (node->producers .size (), 1 );
357
+ node = (nodes_iter++)->get ();
358
+ UNITTEST_EQ (node->ops .size (), 2 );
359
+ UNITTEST_EQ (node->ops [0 ]->outputs [0 ], x4);
360
+ UNITTEST_EQ (node->ops [0 ]->inputs [0 ], x2);
361
+ UNITTEST_EQ (node->ops [0 ]->inputs [1 ], x3);
362
+ UNITTEST_EQ (node->ops [1 ]->outputs [0 ], y);
363
+ UNITTEST_EQ (node->ops [1 ]->inputs [0 ], x0);
364
+ UNITTEST_EQ (node->ops [1 ]->inputs [1 ], x4);
365
+ UNITTEST_EQ (node->users .size (), 0 );
366
+ UNITTEST_EQ (node->producers .size (), 3 );
367
+
368
+ return ark::unittest::SUCCESS;
369
+ }
370
+
322
371
ark::unittest::State test_sched_opgraph_noop () {
323
372
ark::Model model;
324
373
model.tensor ({1 }, ark::FP32);
@@ -564,6 +613,7 @@ ark::unittest::State test_sched_opgraph_all_reduce() {
564
613
int main () {
565
614
ark::init ();
566
615
UNITTEST (test_sched_opgraph);
616
+ UNITTEST (test_sched_opgraph_dependent_inputs);
567
617
UNITTEST (test_sched_opgraph_noop);
568
618
UNITTEST (test_sched_opgraph_identity);
569
619
UNITTEST (test_sched_opgraph_sharding);
0 commit comments