Skip to content

Commit 5dab819

Browse files
authored
Fix accidental inclusion of extra input PODs (#476) (#478)
* Fix issue with adding extra input PODs * Panic if input PODs are missing
1 parent 2bd99ef commit 5dab819

File tree

1 file changed

+134
-50
lines changed

1 file changed

+134
-50
lines changed

src/frontend/multi_pod/mod.rs

Lines changed: 134 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@
4747
//!
4848
//! [`MainPodBuilder`]: crate::frontend::MainPodBuilder
4949
50-
use std::collections::{BTreeSet, HashMap};
50+
use std::{
51+
collections::{BTreeSet, HashMap},
52+
fmt,
53+
};
5154

5255
use crate::{
5356
frontend::{MainPod, MainPodBuilder, Operation, OperationArg},
@@ -237,61 +240,16 @@ impl SolvedMultiPod {
237240
) -> Result<MainPod> {
238241
let mut builder = MainPodBuilder::new(&self.params, &self.vd_set);
239242
let solution = &self.solution;
240-
241-
let statements_in_this_pod: &Vec<usize> = &solution.pod_statements[pod_idx];
242-
let mut needed_external_pods: BTreeSet<usize> = BTreeSet::new();
243-
let mut needed_earlier_pods: BTreeSet<usize> = BTreeSet::new();
243+
let statements_in_this_pod = &solution.pod_statements[pod_idx];
244244

245245
// Step 1: Find which external and earlier PODs we need based on dependencies
246-
for &stmt_idx in statements_in_this_pod {
247-
for dep in &self.deps.statement_deps[stmt_idx] {
248-
match dep {
249-
StatementSource::Internal(dep_idx) => {
250-
// Check if dependency is in an earlier generated POD
251-
let mut found = false;
252-
for earlier_pod_idx in 0..pod_idx {
253-
if solution.pod_public_statements[earlier_pod_idx].contains(dep_idx) {
254-
needed_earlier_pods.insert(earlier_pod_idx);
255-
found = true;
256-
break;
257-
}
258-
}
259-
// If not found in earlier PODs, it must be local to this POD
260-
if !found && !statements_in_this_pod.contains(dep_idx) {
261-
unreachable!(
262-
"Internal dependency {} for statement {} is neither local \
263-
nor public in any earlier POD (solver bug)",
264-
dep_idx, stmt_idx
265-
);
266-
}
267-
}
268-
StatementSource::External(pod_hash) => {
269-
// Find which external POD has this hash
270-
let ext_idx = self
271-
.input_pods
272-
.iter()
273-
.position(|p| p.statements_hash() == *pod_hash);
274-
match ext_idx {
275-
Some(idx) => {
276-
needed_external_pods.insert(idx);
277-
}
278-
None => {
279-
unreachable!(
280-
"External dependency with hash {:?} not found in input PODs",
281-
pod_hash
282-
);
283-
}
284-
}
285-
}
286-
}
287-
}
288-
}
246+
let (needed_earlier_pods, needed_external_pods) = self.compute_pod_inputs(pod_idx);
289247

290248
// Step 2: Add input PODs to the builder
291-
for &ext_idx in &needed_external_pods {
249+
for ext_idx in needed_external_pods {
292250
builder.add_pod(self.input_pods[ext_idx].clone())?;
293251
}
294-
for &earlier_idx in &needed_earlier_pods {
252+
for earlier_idx in needed_earlier_pods {
295253
builder.add_pod(earlier_pods[earlier_idx].clone())?;
296254
}
297255

@@ -338,6 +296,132 @@ impl SolvedMultiPod {
338296

339297
Ok(pod)
340298
}
299+
300+
/// Compute which input PODs (internal and external) are needed for a given POD.
301+
///
302+
/// Returns (internal_pod_indices, external_pod_indices).
303+
fn compute_pod_inputs(&self, pod_idx: usize) -> (BTreeSet<usize>, BTreeSet<usize>) {
304+
let solution = &self.solution;
305+
let statements_in_pod = &solution.pod_statements[pod_idx];
306+
307+
let mut internal_pods: BTreeSet<usize> = BTreeSet::new();
308+
let mut external_pods: BTreeSet<usize> = BTreeSet::new();
309+
310+
for &stmt_idx in statements_in_pod {
311+
for dep in &self.deps.statement_deps[stmt_idx] {
312+
match dep {
313+
StatementSource::Internal(dep_idx) => {
314+
// Check if dependency is in an earlier POD (not local)
315+
if !statements_in_pod.contains(dep_idx) {
316+
let earlier_pod_idx = (0..pod_idx)
317+
.find(|earlier_pod_idx| {
318+
solution.pod_public_statements[*earlier_pod_idx]
319+
.contains(dep_idx)
320+
})
321+
.expect("internal pod with dependency statement");
322+
internal_pods.insert(earlier_pod_idx);
323+
}
324+
}
325+
StatementSource::External(pod_hash) => {
326+
let idx = self
327+
.input_pods
328+
.iter()
329+
.position(|p| p.statements_hash() == *pod_hash)
330+
.expect("external pod with dependency statement");
331+
external_pods.insert(idx);
332+
}
333+
}
334+
}
335+
}
336+
337+
assert!(internal_pods.len() + external_pods.len() <= self.params.max_input_pods);
338+
339+
(internal_pods, external_pods)
340+
}
341+
}
342+
343+
impl fmt::Display for SolvedMultiPod {
344+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
345+
let solution = &self.solution;
346+
let output_pod_idx = solution.pod_count.saturating_sub(1);
347+
348+
// Header
349+
writeln!(
350+
f,
351+
"SolvedMultiPod: {} statements → {} PODs",
352+
self.statements.len(),
353+
solution.pod_count
354+
)?;
355+
356+
if !self.input_pods.is_empty() {
357+
writeln!(f, " External input PODs: {}", self.input_pods.len())?;
358+
}
359+
360+
writeln!(f)?;
361+
362+
// Per-POD breakdown
363+
for pod_idx in 0..solution.pod_count {
364+
let is_output = pod_idx == output_pod_idx;
365+
let role = if is_output { "output" } else { "intermediate" };
366+
367+
writeln!(f, " POD {} ({}):", pod_idx, role)?;
368+
369+
// Show input PODs
370+
let (internal_inputs, external_inputs) = self.compute_pod_inputs(pod_idx);
371+
if !internal_inputs.is_empty() || !external_inputs.is_empty() {
372+
let internal_str: Vec<String> = internal_inputs
373+
.iter()
374+
.map(|i| format!("POD {}", i))
375+
.collect();
376+
let external_str: Vec<String> = external_inputs
377+
.iter()
378+
.map(|i| format!("ext[{}]", i))
379+
.collect();
380+
let all_inputs: Vec<&str> = internal_str
381+
.iter()
382+
.map(|s| s.as_str())
383+
.chain(external_str.iter().map(|s| s.as_str()))
384+
.collect();
385+
writeln!(
386+
f,
387+
" inputs: {} (total: {})",
388+
all_inputs.join(", "),
389+
all_inputs.len()
390+
)?;
391+
}
392+
393+
// Show statements
394+
let stmts = &solution.pod_statements[pod_idx];
395+
let public_stmts = &solution.pod_public_statements[pod_idx];
396+
397+
for &stmt_idx in stmts {
398+
let stmt = &self.statements[stmt_idx];
399+
let is_public = public_stmts.contains(&stmt_idx);
400+
let visibility = if is_public { "public" } else { "private" };
401+
402+
// Show dependencies for this statement
403+
let deps = &self.deps.statement_deps[stmt_idx];
404+
let dep_str = if deps.is_empty() {
405+
String::new()
406+
} else {
407+
let dep_parts: Vec<String> = deps
408+
.iter()
409+
.map(|d| match d {
410+
StatementSource::Internal(i) => format!("stmt[{}]", i),
411+
StatementSource::External(_) => "ext".to_string(),
412+
})
413+
.collect();
414+
format!(" ← {}", dep_parts.join(", "))
415+
};
416+
417+
writeln!(f, " [{}] {} [{}]{}", stmt_idx, stmt, visibility, dep_str)?;
418+
}
419+
420+
writeln!(f)?;
421+
}
422+
423+
Ok(())
424+
}
341425
}
342426

343427
impl MultiPodBuilder {

0 commit comments

Comments
 (0)