@@ -40,7 +40,7 @@ int detail::InTracing::tracing_counter{0};
4040int detail::RetainGraph::tracing_counter{0 };
4141
4242array eval_impl (std::vector<array> outputs, bool async) {
43- std::queue <array> tape;
43+ std::vector <array> tape;
4444
4545 // stream events to use for synchronization
4646 std::unordered_map<uint32_t , Event> events;
@@ -64,7 +64,9 @@ array eval_impl(std::vector<array> outputs, bool async) {
6464 events.emplace (stream.index , Event{stream});
6565
6666 {
67- std::unordered_set<std::uintptr_t > cache;
67+ // Record the degree of each input
68+ std::unordered_map<std::uintptr_t , int > cache;
69+
6870 std::stack<std::pair<std::reference_wrapper<array>, int >> dfs;
6971 dfs.emplace (synchronizer, 0 );
7072 while (!dfs.empty ()) {
@@ -104,50 +106,82 @@ array eval_impl(std::vector<array> outputs, bool async) {
104106 }
105107 }
106108
107- if (cache.find (in.id ()) == cache.end ()) {
109+ // All siblings have the same degree
110+ auto cache_it = cache.find (in.id ());
111+ if (cache_it == cache.end ()) {
108112 dfs.emplace (in, 0 );
109- cache.insert (in.id ());
113+ cache.insert ({in.id (), 1 });
114+ for (auto & s : in.siblings ()) {
115+ cache.insert ({s.id (), 1 });
116+ }
117+ } else {
118+ cache_it->second ++;
110119 for (auto & s : in.siblings ()) {
111- cache. insert ( s.id ()) ;
120+ cache[ s.id ()]++ ;
112121 }
113122 }
114123 continue ;
115124 }
116-
117- // All inputs are done being processed, process this array
118125 if ((a.status () != array::Status::unscheduled) && !a.is_tracer () &&
119126 a.has_primitive ()) {
120127 // If the array is evaluated and is no longer a tracer, detach it
121128 a.detach ();
122- } else if (a.status () == array::Status::unscheduled) {
123- tape.push (a);
124- // Lookup corresponding event and increment counter
125- auto & stream = a.primitive ().stream ();
126- auto e = events.find (stream.index );
127- if (e == events.end ()) {
128- e = events.emplace (stream.index , Event{stream}).first ;
129+ }
130+ dfs.pop ();
131+ }
132+
133+ // Build the tape in BFS order
134+ tape.push_back (synchronizer);
135+ for (int i = 0 ; !cache.empty () && i < tape.size (); ++i) {
136+ auto & a = tape[i];
137+ for (auto & in : a.inputs ()) {
138+ if (in.status () != array::Status::unscheduled) {
139+ continue ;
129140 }
130- e->second .set_value (e->second .value () + 1 );
131- a.attach_event (e->second );
132- for (auto & s : a.siblings ()) {
133- s.attach_event (e->second );
141+ auto it = cache.find (in.id ());
142+ it->second -= 1 ;
143+
144+ if (it->second != 0 ) {
145+ for (auto & s : in.siblings ()) {
146+ cache[s.id ()] -= 1 ;
147+ }
148+ continue ;
149+ }
150+
151+ // Remove input and siblings from cache
152+ cache.erase (it);
153+ for (auto & s : in.siblings ()) {
154+ cache.erase (s.id ());
134155 }
156+
157+ tape.push_back (in);
135158 }
136- dfs.pop ();
137159 }
138160 }
139161
140162 while (!tape.empty ()) {
141- auto arr = std::move (tape.front ());
142- tape.pop ();
163+ auto arr = std::move (tape.back ());
164+ tape.pop_back ();
165+
166+ auto stream = arr.primitive ().stream ();
167+
168+ // Lookup corresponding event and increment counter
169+ auto e = events.find (stream.index );
170+ if (e == events.end ()) {
171+ e = events.emplace (stream.index , Event{stream}).first ;
172+ }
173+ e->second .set_value (e->second .value () + 1 );
174+ arr.attach_event (e->second );
175+ for (auto & s : arr.siblings ()) {
176+ s.attach_event (e->second );
177+ }
143178
144179 // Set the status of the array and siblings.
145180 arr.set_status (array::Status::scheduled);
146181 for (auto & s : arr.siblings ()) {
147182 s.set_status (array::Status::scheduled);
148183 }
149184
150- auto stream = arr.primitive ().stream ();
151185 std::vector<std::shared_future<void >> arr_deps;
152186 bool signal = needs_signal.find (arr.id ()) != needs_signal.end ();
153187
0 commit comments