@@ -32,13 +32,13 @@ pub struct TxNode {
3232 // transaction)
3333 pub results : Vec < Handle > ,
3434 pub transaction_id : Handle ,
35- pub is_error : bool ,
35+ pub is_uncomputable : bool ,
3636 pub intermediate_handles : Vec < Handle > ,
3737}
3838impl TxNode {
3939 pub fn build ( & mut self , mut operations : Vec < DFGOp > , transaction_id : & Handle ) -> Result < ( ) > {
4040 self . transaction_id = transaction_id. clone ( ) ;
41- self . is_error = false ;
41+ self . is_uncomputable = false ;
4242 // Gather all handles produced within the transaction
4343 let mut produced_handles: HashMap < Handle , usize > = HashMap :: new ( ) ;
4444 for ( index, op) in operations. iter ( ) . enumerate ( ) {
@@ -175,31 +175,32 @@ impl DFTxGraph {
175175 result : Result < ( SupportedFheCiphertexts , i16 , Vec < u8 > ) > ,
176176 edges : & Dag < ( ) , TxEdge > ,
177177 ) -> Result < ( ) > {
178- if let Some ( producer) = self . allowed_map . get_mut ( handle) {
179- for edge in edges. edges_directed ( * producer, Direction :: Outgoing ) {
180- let dependent_tx_index = edge. target ( ) ;
181- let dependent_tx = self
182- . graph
183- . node_weight_mut ( dependent_tx_index)
184- . ok_or ( SchedulerError :: DataflowGraphError ) ?;
185- if let Ok ( ref result) = result {
178+ if let Some ( producer) = self . allowed_map . get ( handle) . cloned ( ) {
179+ if let Ok ( ref result) = result {
180+ // Traverse immediate dependents and add this result as an input
181+ for edge in edges. edges_directed ( producer, Direction :: Outgoing ) {
182+ let dependent_tx_index = edge. target ( ) ;
183+ let dependent_tx = self
184+ . graph
185+ . node_weight_mut ( dependent_tx_index)
186+ . ok_or ( SchedulerError :: DataflowGraphError ) ?;
186187 dependent_tx
187188 . inputs
188189 . entry ( handle. to_vec ( ) )
189190 . and_modify ( |v| * v = Some ( DFGTxInput :: Value ( result. 0 . clone ( ) ) ) ) ;
190- } else {
191- // If the output is an error, all dependent
192- // transactions are blocked
193- dependent_tx. is_error = true ;
194191 }
192+ } else {
193+ // If this result was an error, mark this transaction
194+ // and all its dependents as uncomputable, we will
195+ // skip them during scheduling
196+ self . set_uncomputable ( producer, edges) ?;
195197 }
198+ // Finally add the output (either error or compressed
199+ // ciphertext) to the graph's outputs
196200 let producer_tx = self
197201 . graph
198- . node_weight_mut ( * producer)
202+ . node_weight_mut ( producer)
199203 . ok_or ( SchedulerError :: DataflowGraphError ) ?;
200- if result. is_err ( ) {
201- producer_tx. is_error = true ;
202- }
203204 self . results . push ( DFGTxResult {
204205 transaction_id : producer_tx. transaction_id . clone ( ) ,
205206 handle : handle. to_vec ( ) ,
@@ -208,13 +209,31 @@ impl DFTxGraph {
208209 }
209210 Ok ( ( ) )
210211 }
212+ // Set a node as uncomputable and recursively traverse graph to
213+ // set its dependents as uncomputable as well
214+ fn set_uncomputable (
215+ & mut self ,
216+ tx_node_index : NodeIndex ,
217+ edges : & Dag < ( ) , TxEdge > ,
218+ ) -> Result < ( ) > {
219+ let tx_node = self
220+ . graph
221+ . node_weight_mut ( tx_node_index)
222+ . ok_or ( SchedulerError :: DataflowGraphError ) ?;
223+ tx_node. is_uncomputable = true ;
224+ for edge in edges. edges_directed ( tx_node_index, Direction :: Outgoing ) {
225+ let dependent_tx_index = edge. target ( ) ;
226+ self . set_uncomputable ( dependent_tx_index, edges) ?;
227+ }
228+ Ok ( ( ) )
229+ }
211230 pub fn get_results ( & mut self ) -> Vec < DFGTxResult > {
212231 std:: mem:: take ( & mut self . results )
213232 }
214233 pub fn get_intermediate_handles ( & mut self ) -> Vec < ( Handle , Handle ) > {
215234 let mut res = vec ! [ ] ;
216235 for tx in self . graph . node_weights_mut ( ) {
217- if !tx. is_error {
236+ if !tx. is_uncomputable {
218237 res. append (
219238 & mut ( std:: mem:: take ( & mut tx. intermediate_handles ) )
220239 . into_iter ( )
@@ -254,13 +273,11 @@ pub struct DFGResult {
254273pub type OpEdge = u8 ;
255274pub struct OpNode {
256275 opcode : i32 ,
257- result : DFGTaskResult ,
258276 result_handle : Handle ,
259277 inputs : Vec < DFGTaskInput > ,
260278 #[ cfg( feature = "gpu" ) ]
261279 locality : i32 ,
262280 is_allowed : bool ,
263- is_needed : bool ,
264281}
265282impl std:: fmt:: Debug for OpNode {
266283 fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
@@ -291,7 +308,6 @@ impl OpNode {
291308pub struct DFGraph {
292309 pub graph : Dag < OpNode , OpEdge > ,
293310}
294-
295311impl DFGraph {
296312 pub fn add_node (
297313 & mut self ,
@@ -302,13 +318,11 @@ impl DFGraph {
302318 ) -> NodeIndex {
303319 self . graph . add_node ( OpNode {
304320 opcode,
305- result : None ,
306321 result_handle : rh,
307322 inputs,
308323 #[ cfg( feature = "gpu" ) ]
309324 locality : -1 ,
310325 is_allowed,
311- is_needed : is_allowed,
312326 } )
313327 }
314328 pub fn add_dependence (
@@ -327,54 +341,4 @@ impl DFGraph {
327341 . map_err ( |_| SchedulerError :: CyclicDependence ) ?;
328342 Ok ( ( ) )
329343 }
330-
331- // fn is_needed(&self, index: usize) -> bool {
332- // let node_index = NodeIndex::new(index);
333- // let node = self.graph.node_weight(node_index).unwrap();
334- // if node.is_allowed || node.is_needed {
335- // true
336- // } else {
337- // for edge in self.graph.edges_directed(node_index, Direction::Outgoing) {
338- // // If any outgoing dependence is needed, so is this node
339- // if self.is_needed(edge.target().index()) {
340- // return true;
341- // }
342- // }
343- // false
344- // }
345- // }
346-
347- // pub fn finalize(&mut self) {
348- // // Traverse in reverse order and mark nodes as needed as the
349- // // graph order is roughly computable, so allowed nodes should
350- // // generally be later in the graph.
351- // for index in (0..self.graph.node_count()).rev() {
352- // if self.is_needed(index) {
353- // let node = self.graph.node_weight_mut(NodeIndex::new(index)).unwrap();
354- // node.is_needed = true;
355- // }
356- // }
357- // // Prune graph of all unneeded nodes and edges
358- // let mut unneeded_nodes = Vec::new();
359- // for index in 0..self.graph.node_count() {
360- // let node_index = NodeIndex::new(index);
361- // let Some(node) = self.graph.node_weight(node_index) else {
362- // continue;
363- // };
364- // if !node.is_needed {
365- // unneeded_nodes.push(index);
366- // }
367- // }
368- // unneeded_nodes.sort();
369- // // Remove unneeded nodes and their edges
370- // for index in unneeded_nodes.iter().rev() {
371- // let node_index = NodeIndex::new(*index);
372- // let Some(node) = self.graph.node_weight(node_index) else {
373- // continue;
374- // };
375- // if !node.is_needed {
376- // self.graph.remove_node(node_index);
377- // }
378- // }
379- // }
380344}
0 commit comments