Skip to content

Commit f6531f0

Browse files
authored
Fix an OpGraph bug (#201)
Previously, when we merge `A` and `B`, the new node's users were set to the users of `B` (i.e., `C`), which is wrong. The new node's users should be `C` and `D`. This PR is a fix. ``` A -> B -> C -> D | ^ | | +--------------+ ```
1 parent 22410dc commit f6531f0

File tree

2 files changed

+52
-5
lines changed

2 files changed

+52
-5
lines changed

ark/sched/sched_opgraph.cc

+2-5
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@
88
#include "logging.h"
99
#include "model.h"
1010

11-
using namespace std;
12-
1311
#define DEBUG_OPGRAPH 0
1412
#define OPGRAPH_DEBUG(...) \
1513
do { \
@@ -293,8 +291,7 @@ void OpGraph::recursive_merge(std::list<std::unique_ptr<OpNode>> &nodes,
293291
continue;
294292
}
295293
}
296-
// The candidate has only one user. Merge the two nodes.
297-
294+
// We can merge the two nodes.
298295
// Merge `boundary_node` into `merge_candidate`.
299296
OPGRAPH_DEBUG(" merge: ", merge_candidate->get_name(), " -> ",
300297
boundary_node->get_name());
@@ -314,7 +311,7 @@ void OpGraph::recursive_merge(std::list<std::unique_ptr<OpNode>> &nodes,
314311
producer->users.insert(merge_candidate);
315312
merge_candidate->producers.insert(producer);
316313
}
317-
merge_candidate->users = boundary_node->users;
314+
merge_candidate->users.erase(boundary_node);
318315

319316
// Remove `boundary_node` from `nodes`.
320317
auto it =

ark/sched/sched_opgraph_test.cc

+50
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,55 @@ ark::unittest::State test_sched_opgraph() {
319319
return ark::unittest::SUCCESS;
320320
}
321321

322+
ark::unittest::State test_sched_opgraph_dependent_inputs() {
323+
ark::Model m;
324+
325+
ark::Tensor *ones = m.tensor({256, 256}, ark::FP16);
326+
ark::Tensor *x0 = m.scale(m.scale(ones, 2), 2);
327+
ark::Tensor *x1 = m.scale(m.scale(x0, 2), 2);
328+
329+
ark::Tensor *x2 = m.mul(ones, x1);
330+
ark::Tensor *x3 = m.mul(ones, x1);
331+
ark::Tensor *x4 = m.mul(x2, x3);
332+
ark::Tensor *y = m.add(x0, x4);
333+
334+
ark::OpGraph graph(m);
335+
UNITTEST_EQ(graph.get_nodes().size(), 4);
336+
auto nodes_iter = graph.get_nodes().begin();
337+
auto node = (nodes_iter++)->get();
338+
UNITTEST_EQ(node->ops.size(), 4);
339+
UNITTEST_EQ(node->ops[1]->outputs[0], x0);
340+
UNITTEST_EQ(node->ops[3]->outputs[0], x1);
341+
UNITTEST_EQ(node->users.size(), 3);
342+
UNITTEST_EQ(node->producers.size(), 0);
343+
node = (nodes_iter++)->get();
344+
UNITTEST_EQ(node->ops.size(), 1);
345+
UNITTEST_EQ(node->ops[0]->outputs[0], x2);
346+
UNITTEST_EQ(node->ops[0]->inputs[0], ones);
347+
UNITTEST_EQ(node->ops[0]->inputs[1], x1);
348+
UNITTEST_EQ(node->users.size(), 1);
349+
UNITTEST_EQ(node->producers.size(), 1);
350+
node = (nodes_iter++)->get();
351+
UNITTEST_EQ(node->ops.size(), 1);
352+
UNITTEST_EQ(node->ops[0]->outputs[0], x3);
353+
UNITTEST_EQ(node->ops[0]->inputs[0], ones);
354+
UNITTEST_EQ(node->ops[0]->inputs[1], x1);
355+
UNITTEST_EQ(node->users.size(), 1);
356+
UNITTEST_EQ(node->producers.size(), 1);
357+
node = (nodes_iter++)->get();
358+
UNITTEST_EQ(node->ops.size(), 2);
359+
UNITTEST_EQ(node->ops[0]->outputs[0], x4);
360+
UNITTEST_EQ(node->ops[0]->inputs[0], x2);
361+
UNITTEST_EQ(node->ops[0]->inputs[1], x3);
362+
UNITTEST_EQ(node->ops[1]->outputs[0], y);
363+
UNITTEST_EQ(node->ops[1]->inputs[0], x0);
364+
UNITTEST_EQ(node->ops[1]->inputs[1], x4);
365+
UNITTEST_EQ(node->users.size(), 0);
366+
UNITTEST_EQ(node->producers.size(), 3);
367+
368+
return ark::unittest::SUCCESS;
369+
}
370+
322371
ark::unittest::State test_sched_opgraph_noop() {
323372
ark::Model model;
324373
model.tensor({1}, ark::FP32);
@@ -564,6 +613,7 @@ ark::unittest::State test_sched_opgraph_all_reduce() {
564613
int main() {
565614
ark::init();
566615
UNITTEST(test_sched_opgraph);
616+
UNITTEST(test_sched_opgraph_dependent_inputs);
567617
UNITTEST(test_sched_opgraph_noop);
568618
UNITTEST(test_sched_opgraph_identity);
569619
UNITTEST(test_sched_opgraph_sharding);

0 commit comments

Comments
 (0)