|
47 | 47 | //! |
48 | 48 | //! [`MainPodBuilder`]: crate::frontend::MainPodBuilder |
49 | 49 |
|
50 | | -use std::collections::{BTreeSet, HashMap}; |
| 50 | +use std::{ |
| 51 | + collections::{BTreeSet, HashMap}, |
| 52 | + fmt, |
| 53 | +}; |
51 | 54 |
|
52 | 55 | use crate::{ |
53 | 56 | frontend::{MainPod, MainPodBuilder, Operation, OperationArg}, |
@@ -237,61 +240,16 @@ impl SolvedMultiPod { |
237 | 240 | ) -> Result<MainPod> { |
238 | 241 | let mut builder = MainPodBuilder::new(&self.params, &self.vd_set); |
239 | 242 | let solution = &self.solution; |
240 | | - |
241 | | - let statements_in_this_pod: &Vec<usize> = &solution.pod_statements[pod_idx]; |
242 | | - let mut needed_external_pods: BTreeSet<usize> = BTreeSet::new(); |
243 | | - let mut needed_earlier_pods: BTreeSet<usize> = BTreeSet::new(); |
| 243 | + let statements_in_this_pod = &solution.pod_statements[pod_idx]; |
244 | 244 |
|
245 | 245 | // Step 1: Find which external and earlier PODs we need based on dependencies |
246 | | - for &stmt_idx in statements_in_this_pod { |
247 | | - for dep in &self.deps.statement_deps[stmt_idx] { |
248 | | - match dep { |
249 | | - StatementSource::Internal(dep_idx) => { |
250 | | - // Check if dependency is in an earlier generated POD |
251 | | - let mut found = false; |
252 | | - for earlier_pod_idx in 0..pod_idx { |
253 | | - if solution.pod_public_statements[earlier_pod_idx].contains(dep_idx) { |
254 | | - needed_earlier_pods.insert(earlier_pod_idx); |
255 | | - found = true; |
256 | | - break; |
257 | | - } |
258 | | - } |
259 | | - // If not found in earlier PODs, it must be local to this POD |
260 | | - if !found && !statements_in_this_pod.contains(dep_idx) { |
261 | | - unreachable!( |
262 | | - "Internal dependency {} for statement {} is neither local \ |
263 | | - nor public in any earlier POD (solver bug)", |
264 | | - dep_idx, stmt_idx |
265 | | - ); |
266 | | - } |
267 | | - } |
268 | | - StatementSource::External(pod_hash) => { |
269 | | - // Find which external POD has this hash |
270 | | - let ext_idx = self |
271 | | - .input_pods |
272 | | - .iter() |
273 | | - .position(|p| p.statements_hash() == *pod_hash); |
274 | | - match ext_idx { |
275 | | - Some(idx) => { |
276 | | - needed_external_pods.insert(idx); |
277 | | - } |
278 | | - None => { |
279 | | - unreachable!( |
280 | | - "External dependency with hash {:?} not found in input PODs", |
281 | | - pod_hash |
282 | | - ); |
283 | | - } |
284 | | - } |
285 | | - } |
286 | | - } |
287 | | - } |
288 | | - } |
| 246 | + let (needed_earlier_pods, needed_external_pods) = self.compute_pod_inputs(pod_idx); |
289 | 247 |
|
290 | 248 | // Step 2: Add input PODs to the builder |
291 | | - for &ext_idx in &needed_external_pods { |
| 249 | + for ext_idx in needed_external_pods { |
292 | 250 | builder.add_pod(self.input_pods[ext_idx].clone())?; |
293 | 251 | } |
294 | | - for &earlier_idx in &needed_earlier_pods { |
| 252 | + for earlier_idx in needed_earlier_pods { |
295 | 253 | builder.add_pod(earlier_pods[earlier_idx].clone())?; |
296 | 254 | } |
297 | 255 |
|
@@ -338,6 +296,132 @@ impl SolvedMultiPod { |
338 | 296 |
|
339 | 297 | Ok(pod) |
340 | 298 | } |
| 299 | + |
| 300 | + /// Compute which input PODs (internal and external) are needed for a given POD. |
| 301 | + /// |
| 302 | + /// Returns (internal_pod_indices, external_pod_indices). |
| 303 | + fn compute_pod_inputs(&self, pod_idx: usize) -> (BTreeSet<usize>, BTreeSet<usize>) { |
| 304 | + let solution = &self.solution; |
| 305 | + let statements_in_pod = &solution.pod_statements[pod_idx]; |
| 306 | + |
| 307 | + let mut internal_pods: BTreeSet<usize> = BTreeSet::new(); |
| 308 | + let mut external_pods: BTreeSet<usize> = BTreeSet::new(); |
| 309 | + |
| 310 | + for &stmt_idx in statements_in_pod { |
| 311 | + for dep in &self.deps.statement_deps[stmt_idx] { |
| 312 | + match dep { |
| 313 | + StatementSource::Internal(dep_idx) => { |
| 314 | + // Check if dependency is in an earlier POD (not local) |
| 315 | + if !statements_in_pod.contains(dep_idx) { |
| 316 | + let earlier_pod_idx = (0..pod_idx) |
| 317 | + .find(|earlier_pod_idx| { |
| 318 | + solution.pod_public_statements[*earlier_pod_idx] |
| 319 | + .contains(dep_idx) |
| 320 | + }) |
| 321 | + .expect("internal pod with dependency statement"); |
| 322 | + internal_pods.insert(earlier_pod_idx); |
| 323 | + } |
| 324 | + } |
| 325 | + StatementSource::External(pod_hash) => { |
| 326 | + let idx = self |
| 327 | + .input_pods |
| 328 | + .iter() |
| 329 | + .position(|p| p.statements_hash() == *pod_hash) |
| 330 | + .expect("external pod with dependency statement"); |
| 331 | + external_pods.insert(idx); |
| 332 | + } |
| 333 | + } |
| 334 | + } |
| 335 | + } |
| 336 | + |
| 337 | + assert!(internal_pods.len() + external_pods.len() <= self.params.max_input_pods); |
| 338 | + |
| 339 | + (internal_pods, external_pods) |
| 340 | + } |
| 341 | +} |
| 342 | + |
| 343 | +impl fmt::Display for SolvedMultiPod { |
| 344 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 345 | + let solution = &self.solution; |
| 346 | + let output_pod_idx = solution.pod_count.saturating_sub(1); |
| 347 | + |
| 348 | + // Header |
| 349 | + writeln!( |
| 350 | + f, |
| 351 | + "SolvedMultiPod: {} statements → {} PODs", |
| 352 | + self.statements.len(), |
| 353 | + solution.pod_count |
| 354 | + )?; |
| 355 | + |
| 356 | + if !self.input_pods.is_empty() { |
| 357 | + writeln!(f, " External input PODs: {}", self.input_pods.len())?; |
| 358 | + } |
| 359 | + |
| 360 | + writeln!(f)?; |
| 361 | + |
| 362 | + // Per-POD breakdown |
| 363 | + for pod_idx in 0..solution.pod_count { |
| 364 | + let is_output = pod_idx == output_pod_idx; |
| 365 | + let role = if is_output { "output" } else { "intermediate" }; |
| 366 | + |
| 367 | + writeln!(f, " POD {} ({}):", pod_idx, role)?; |
| 368 | + |
| 369 | + // Show input PODs |
| 370 | + let (internal_inputs, external_inputs) = self.compute_pod_inputs(pod_idx); |
| 371 | + if !internal_inputs.is_empty() || !external_inputs.is_empty() { |
| 372 | + let internal_str: Vec<String> = internal_inputs |
| 373 | + .iter() |
| 374 | + .map(|i| format!("POD {}", i)) |
| 375 | + .collect(); |
| 376 | + let external_str: Vec<String> = external_inputs |
| 377 | + .iter() |
| 378 | + .map(|i| format!("ext[{}]", i)) |
| 379 | + .collect(); |
| 380 | + let all_inputs: Vec<&str> = internal_str |
| 381 | + .iter() |
| 382 | + .map(|s| s.as_str()) |
| 383 | + .chain(external_str.iter().map(|s| s.as_str())) |
| 384 | + .collect(); |
| 385 | + writeln!( |
| 386 | + f, |
| 387 | + " inputs: {} (total: {})", |
| 388 | + all_inputs.join(", "), |
| 389 | + all_inputs.len() |
| 390 | + )?; |
| 391 | + } |
| 392 | + |
| 393 | + // Show statements |
| 394 | + let stmts = &solution.pod_statements[pod_idx]; |
| 395 | + let public_stmts = &solution.pod_public_statements[pod_idx]; |
| 396 | + |
| 397 | + for &stmt_idx in stmts { |
| 398 | + let stmt = &self.statements[stmt_idx]; |
| 399 | + let is_public = public_stmts.contains(&stmt_idx); |
| 400 | + let visibility = if is_public { "public" } else { "private" }; |
| 401 | + |
| 402 | + // Show dependencies for this statement |
| 403 | + let deps = &self.deps.statement_deps[stmt_idx]; |
| 404 | + let dep_str = if deps.is_empty() { |
| 405 | + String::new() |
| 406 | + } else { |
| 407 | + let dep_parts: Vec<String> = deps |
| 408 | + .iter() |
| 409 | + .map(|d| match d { |
| 410 | + StatementSource::Internal(i) => format!("stmt[{}]", i), |
| 411 | + StatementSource::External(_) => "ext".to_string(), |
| 412 | + }) |
| 413 | + .collect(); |
| 414 | + format!(" ← {}", dep_parts.join(", ")) |
| 415 | + }; |
| 416 | + |
| 417 | + writeln!(f, " [{}] {} [{}]{}", stmt_idx, stmt, visibility, dep_str)?; |
| 418 | + } |
| 419 | + |
| 420 | + writeln!(f)?; |
| 421 | + } |
| 422 | + |
| 423 | + Ok(()) |
| 424 | + } |
341 | 425 | } |
342 | 426 |
|
343 | 427 | impl MultiPodBuilder { |
|
0 commit comments