Skip to content

Commit 39b76af

Browse files
committed
fix(coprocessor): update error and uncomputable handling
1 parent 9aec40d commit 39b76af

File tree

5 files changed

+173
-196
lines changed

5 files changed

+173
-196
lines changed

coprocessor/fhevm-engine/scheduler/Cargo.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,3 @@ fhevm-engine-common = { path = "../fhevm-engine-common" }
2323
[features]
2424
nightly-avx512 = ["tfhe/nightly-avx512"]
2525
gpu = ["tfhe/gpu"]
26-
rerandomise = []

coprocessor/fhevm-engine/scheduler/src/dfg.rs

Lines changed: 38 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ pub struct TxNode {
3232
// transaction)
3333
pub results: Vec<Handle>,
3434
pub transaction_id: Handle,
35-
pub is_error: bool,
35+
pub is_uncomputable: bool,
3636
pub intermediate_handles: Vec<Handle>,
3737
}
3838
impl TxNode {
3939
pub fn build(&mut self, mut operations: Vec<DFGOp>, transaction_id: &Handle) -> Result<()> {
4040
self.transaction_id = transaction_id.clone();
41-
self.is_error = false;
41+
self.is_uncomputable = false;
4242
// Gather all handles produced within the transaction
4343
let mut produced_handles: HashMap<Handle, usize> = HashMap::new();
4444
for (index, op) in operations.iter().enumerate() {
@@ -175,31 +175,32 @@ impl DFTxGraph {
175175
result: Result<(SupportedFheCiphertexts, i16, Vec<u8>)>,
176176
edges: &Dag<(), TxEdge>,
177177
) -> Result<()> {
178-
if let Some(producer) = self.allowed_map.get_mut(handle) {
179-
for edge in edges.edges_directed(*producer, Direction::Outgoing) {
180-
let dependent_tx_index = edge.target();
181-
let dependent_tx = self
182-
.graph
183-
.node_weight_mut(dependent_tx_index)
184-
.ok_or(SchedulerError::DataflowGraphError)?;
185-
if let Ok(ref result) = result {
178+
if let Some(producer) = self.allowed_map.get(handle).cloned() {
179+
if let Ok(ref result) = result {
180+
// Traverse immediate dependents and add this result as an input
181+
for edge in edges.edges_directed(producer, Direction::Outgoing) {
182+
let dependent_tx_index = edge.target();
183+
let dependent_tx = self
184+
.graph
185+
.node_weight_mut(dependent_tx_index)
186+
.ok_or(SchedulerError::DataflowGraphError)?;
186187
dependent_tx
187188
.inputs
188189
.entry(handle.to_vec())
189190
.and_modify(|v| *v = Some(DFGTxInput::Value(result.0.clone())));
190-
} else {
191-
// If the output is an error, all dependent
192-
// transactions are blocked
193-
dependent_tx.is_error = true;
194191
}
192+
} else {
193+
// If this result was an error, mark this transaction
194+
// and all its dependents as uncomputable, we will
195+
// skip them during scheduling
196+
self.set_uncomputable(producer, edges)?;
195197
}
198+
// Finally add the output (either error or compressed
199+
// ciphertext) to the graph's outputs
196200
let producer_tx = self
197201
.graph
198-
.node_weight_mut(*producer)
202+
.node_weight_mut(producer)
199203
.ok_or(SchedulerError::DataflowGraphError)?;
200-
if result.is_err() {
201-
producer_tx.is_error = true;
202-
}
203204
self.results.push(DFGTxResult {
204205
transaction_id: producer_tx.transaction_id.clone(),
205206
handle: handle.to_vec(),
@@ -208,13 +209,31 @@ impl DFTxGraph {
208209
}
209210
Ok(())
210211
}
212+
// Set a node as uncomputable and recursively traverse graph to
213+
// set its dependents as uncomputable as well
214+
fn set_uncomputable(
215+
&mut self,
216+
tx_node_index: NodeIndex,
217+
edges: &Dag<(), TxEdge>,
218+
) -> Result<()> {
219+
let tx_node = self
220+
.graph
221+
.node_weight_mut(tx_node_index)
222+
.ok_or(SchedulerError::DataflowGraphError)?;
223+
tx_node.is_uncomputable = true;
224+
for edge in edges.edges_directed(tx_node_index, Direction::Outgoing) {
225+
let dependent_tx_index = edge.target();
226+
self.set_uncomputable(dependent_tx_index, edges)?;
227+
}
228+
Ok(())
229+
}
211230
pub fn get_results(&mut self) -> Vec<DFGTxResult> {
212231
std::mem::take(&mut self.results)
213232
}
214233
pub fn get_intermediate_handles(&mut self) -> Vec<(Handle, Handle)> {
215234
let mut res = vec![];
216235
for tx in self.graph.node_weights_mut() {
217-
if !tx.is_error {
236+
if !tx.is_uncomputable {
218237
res.append(
219238
&mut (std::mem::take(&mut tx.intermediate_handles))
220239
.into_iter()
@@ -254,13 +273,11 @@ pub struct DFGResult {
254273
pub type OpEdge = u8;
255274
pub struct OpNode {
256275
opcode: i32,
257-
result: DFGTaskResult,
258276
result_handle: Handle,
259277
inputs: Vec<DFGTaskInput>,
260278
#[cfg(feature = "gpu")]
261279
locality: i32,
262280
is_allowed: bool,
263-
is_needed: bool,
264281
}
265282
impl std::fmt::Debug for OpNode {
266283
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -291,7 +308,6 @@ impl OpNode {
291308
pub struct DFGraph {
292309
pub graph: Dag<OpNode, OpEdge>,
293310
}
294-
295311
impl DFGraph {
296312
pub fn add_node(
297313
&mut self,
@@ -302,13 +318,11 @@ impl DFGraph {
302318
) -> NodeIndex {
303319
self.graph.add_node(OpNode {
304320
opcode,
305-
result: None,
306321
result_handle: rh,
307322
inputs,
308323
#[cfg(feature = "gpu")]
309324
locality: -1,
310325
is_allowed,
311-
is_needed: is_allowed,
312326
})
313327
}
314328
pub fn add_dependence(
@@ -327,54 +341,4 @@ impl DFGraph {
327341
.map_err(|_| SchedulerError::CyclicDependence)?;
328342
Ok(())
329343
}
330-
331-
// fn is_needed(&self, index: usize) -> bool {
332-
// let node_index = NodeIndex::new(index);
333-
// let node = self.graph.node_weight(node_index).unwrap();
334-
// if node.is_allowed || node.is_needed {
335-
// true
336-
// } else {
337-
// for edge in self.graph.edges_directed(node_index, Direction::Outgoing) {
338-
// // If any outgoing dependence is needed, so is this node
339-
// if self.is_needed(edge.target().index()) {
340-
// return true;
341-
// }
342-
// }
343-
// false
344-
// }
345-
// }
346-
347-
// pub fn finalize(&mut self) {
348-
// // Traverse in reverse order and mark nodes as needed as the
349-
// // graph order is roughly computable, so allowed nodes should
350-
// // generally be later in the graph.
351-
// for index in (0..self.graph.node_count()).rev() {
352-
// if self.is_needed(index) {
353-
// let node = self.graph.node_weight_mut(NodeIndex::new(index)).unwrap();
354-
// node.is_needed = true;
355-
// }
356-
// }
357-
// // Prune graph of all unneeded nodes and edges
358-
// let mut unneeded_nodes = Vec::new();
359-
// for index in 0..self.graph.node_count() {
360-
// let node_index = NodeIndex::new(index);
361-
// let Some(node) = self.graph.node_weight(node_index) else {
362-
// continue;
363-
// };
364-
// if !node.is_needed {
365-
// unneeded_nodes.push(index);
366-
// }
367-
// }
368-
// unneeded_nodes.sort();
369-
// // Remove unneeded nodes and their edges
370-
// for index in unneeded_nodes.iter().rev() {
371-
// let node_index = NodeIndex::new(*index);
372-
// let Some(node) = self.graph.node_weight(node_index) else {
373-
// continue;
374-
// };
375-
// if !node.is_needed {
376-
// self.graph.remove_node(node_index);
377-
// }
378-
// }
379-
// }
380344
}

0 commit comments

Comments
 (0)