Skip to content

Commit 8e88e30

Browse files
authored
BFS graph evaluation order (#1525)
* bfs order * try fix event issue
1 parent 0eb56d5 commit 8e88e30

File tree

1 file changed

+56
-22
lines changed

1 file changed

+56
-22
lines changed

mlx/transforms.cpp

Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ int detail::InTracing::tracing_counter{0};
4040
int detail::RetainGraph::tracing_counter{0};
4141

4242
array eval_impl(std::vector<array> outputs, bool async) {
43-
std::queue<array> tape;
43+
std::vector<array> tape;
4444

4545
// stream events to use for synchronization
4646
std::unordered_map<uint32_t, Event> events;
@@ -64,7 +64,9 @@ array eval_impl(std::vector<array> outputs, bool async) {
6464
events.emplace(stream.index, Event{stream});
6565

6666
{
67-
std::unordered_set<std::uintptr_t> cache;
67+
// Record the degree of each input
68+
std::unordered_map<std::uintptr_t, int> cache;
69+
6870
std::stack<std::pair<std::reference_wrapper<array>, int>> dfs;
6971
dfs.emplace(synchronizer, 0);
7072
while (!dfs.empty()) {
@@ -104,50 +106,82 @@ array eval_impl(std::vector<array> outputs, bool async) {
104106
}
105107
}
106108

107-
if (cache.find(in.id()) == cache.end()) {
109+
// All siblings have the same degree
110+
auto cache_it = cache.find(in.id());
111+
if (cache_it == cache.end()) {
108112
dfs.emplace(in, 0);
109-
cache.insert(in.id());
113+
cache.insert({in.id(), 1});
114+
for (auto& s : in.siblings()) {
115+
cache.insert({s.id(), 1});
116+
}
117+
} else {
118+
cache_it->second++;
110119
for (auto& s : in.siblings()) {
111-
cache.insert(s.id());
120+
cache[s.id()]++;
112121
}
113122
}
114123
continue;
115124
}
116-
117-
// All inputs are done being processed, process this array
118125
if ((a.status() != array::Status::unscheduled) && !a.is_tracer() &&
119126
a.has_primitive()) {
120127
// If the array is evaluated and is no longer a tracer, detach it
121128
a.detach();
122-
} else if (a.status() == array::Status::unscheduled) {
123-
tape.push(a);
124-
// Lookup corresponding event and increment counter
125-
auto& stream = a.primitive().stream();
126-
auto e = events.find(stream.index);
127-
if (e == events.end()) {
128-
e = events.emplace(stream.index, Event{stream}).first;
129+
}
130+
dfs.pop();
131+
}
132+
133+
// Build the tape in BFS order
134+
tape.push_back(synchronizer);
135+
for (int i = 0; !cache.empty() && i < tape.size(); ++i) {
136+
auto& a = tape[i];
137+
for (auto& in : a.inputs()) {
138+
if (in.status() != array::Status::unscheduled) {
139+
continue;
129140
}
130-
e->second.set_value(e->second.value() + 1);
131-
a.attach_event(e->second);
132-
for (auto& s : a.siblings()) {
133-
s.attach_event(e->second);
141+
auto it = cache.find(in.id());
142+
it->second -= 1;
143+
144+
if (it->second != 0) {
145+
for (auto& s : in.siblings()) {
146+
cache[s.id()] -= 1;
147+
}
148+
continue;
149+
}
150+
151+
// Remove input and siblings from cache
152+
cache.erase(it);
153+
for (auto& s : in.siblings()) {
154+
cache.erase(s.id());
134155
}
156+
157+
tape.push_back(in);
135158
}
136-
dfs.pop();
137159
}
138160
}
139161

140162
while (!tape.empty()) {
141-
auto arr = std::move(tape.front());
142-
tape.pop();
163+
auto arr = std::move(tape.back());
164+
tape.pop_back();
165+
166+
auto stream = arr.primitive().stream();
167+
168+
// Lookup corresponding event and increment counter
169+
auto e = events.find(stream.index);
170+
if (e == events.end()) {
171+
e = events.emplace(stream.index, Event{stream}).first;
172+
}
173+
e->second.set_value(e->second.value() + 1);
174+
arr.attach_event(e->second);
175+
for (auto& s : arr.siblings()) {
176+
s.attach_event(e->second);
177+
}
143178

144179
// Set the status of the array and siblings.
145180
arr.set_status(array::Status::scheduled);
146181
for (auto& s : arr.siblings()) {
147182
s.set_status(array::Status::scheduled);
148183
}
149184

150-
auto stream = arr.primitive().stream();
151185
std::vector<std::shared_future<void>> arr_deps;
152186
bool signal = needs_signal.find(arr.id()) != needs_signal.end();
153187

0 commit comments

Comments
 (0)