diff --git a/src/lang/error.rs b/src/lang/error.rs index bdd1c453..944988c7 100644 --- a/src/lang/error.rs +++ b/src/lang/error.rs @@ -287,7 +287,8 @@ fn format_public_args_at_split_error( msg.push_str(&format!( " Statements {}-{} in this segment\n", - context.statement_range.0, context.statement_range.1 + context.statement_range.0, + context.statement_range.1 - 1 )); if !context.incoming_public.is_empty() { diff --git a/src/lang/frontend_ast_split.rs b/src/lang/frontend_ast_split.rs index 7e7279cb..0d17217f 100644 --- a/src/lang/frontend_ast_split.rs +++ b/src/lang/frontend_ast_split.rs @@ -15,10 +15,7 @@ //! We use a greedy algorithm to order the statements in a predicate to minimize //! the number of live wildcards at split boundaries. -use std::{ - cmp::Reverse, - collections::{HashMap, HashSet}, -}; +use std::{cmp::Reverse, collections::HashSet}; // SplittingError is now defined in error.rs pub use crate::lang::error::SplittingError; @@ -33,8 +30,8 @@ pub struct ChainLink { pub public_args_in: Vec, /// Private arguments used only in this link pub private_args: Vec, - /// Public arguments promoted to pass to next link (empty if last link) - pub public_args_out: Vec, + /// Private wildcards promoted to public for the next link (empty if last link) + pub promoted_wildcards: Vec, } /// Information about a single piece of a split predicate chain @@ -71,13 +68,6 @@ pub struct SplitResult { pub chain_info: Option, } -/// Wildcard usage information -#[derive(Debug, Clone)] -struct WildcardUsage { - /// Indices of statements using this wildcard - used_in_statements: HashSet, -} - /// Early validation: Check if predicate is fundamentally splittable pub fn validate_predicate_is_splittable(pred: &CustomPredicateDef) -> Result<(), SplittingError> { let public_args = pred.args.public_args.len(); @@ -121,26 +111,6 @@ pub fn split_predicate_if_needed( }) } -fn analyze_wildcards(statements: &[StatementTmpl]) -> HashMap { - let mut usage: HashMap = HashMap::new(); - - for (idx, stmt) in statements.iter().enumerate() { - let wildcards = collect_wildcards_from_statement(stmt); - - for wildcard in wildcards { - usage - .entry(wildcard.clone()) - .or_insert_with(|| WildcardUsage { - used_in_statements: HashSet::new(), - }) - .used_in_statements - .insert(idx); - } - } - - usage -} - /// Collect all wildcard names from a statement fn collect_wildcards_from_statement(stmt: &StatementTmpl) -> HashSet { let mut wildcards = HashSet::new(); @@ -172,7 +142,7 @@ struct OrderingResult { fn order_constraints_optimally( statements: Vec, - _usage: &HashMap, + public_args: &HashSet, ) -> OrderingResult { let n = statements.len(); @@ -190,8 +160,13 @@ fn order_constraints_optimally( let mut active_wildcards: HashSet = HashSet::new(); while !remaining.is_empty() { - let best_idx = - find_best_next_statement(&statements, &remaining, &active_wildcards, ordered.len()); + let best_idx = find_best_next_statement( + &statements, + &remaining, + &active_wildcards, + ordered.len(), + public_args, + ); remaining.remove(&best_idx); let stmt = &statements[best_idx]; @@ -200,14 +175,20 @@ fn order_constraints_optimally( reorder_map[best_idx] = ordered.len(); ordered.push(stmt.clone()); - // Update active wildcards + // Only track private wildcards in the active set — public args are always + // available at every boundary so their liveness is irrelevant to split cost. let stmt_wildcards = collect_wildcards_from_statement(stmt); - active_wildcards.extend(stmt_wildcards); + active_wildcards.extend( + stmt_wildcards + .into_iter() + .filter(|w| !public_args.contains(w)), + ); - // Remove wildcards no longer needed by remaining statements + // Remove private wildcards no longer needed by remaining statements let needed_later: HashSet<_> = remaining .iter() .flat_map(|&i| collect_wildcards_from_statement(&statements[i])) + .filter(|w| !public_args.contains(w)) .collect(); active_wildcards.retain(|w| needed_later.contains(w)); } @@ -225,30 +206,35 @@ fn compute_tie_breakers( active_wildcards: &HashSet, statements: &[StatementTmpl], remaining: &HashSet, + needed_later: &HashSet, + public_args: &HashSet, ) -> (usize, usize, i32) { - let stmt_wildcards = collect_wildcards_from_statement(stmt); + let all_wildcards = collect_wildcards_from_statement(stmt); + // Only consider private wildcards for tie-breaking metrics + let stmt_wildcards: HashSet<_> = all_wildcards + .into_iter() + .filter(|w| !public_args.contains(w)) + .collect(); - // Metric 1: Simplicity - prefer statements with fewer wildcards + // Metric 1: Simplicity - prefer statements with fewer private wildcards let simplicity = usize::MAX - stmt_wildcards.len(); - // Metric 2: Public closure - prefer statements that close active wildcards + // Metric 2: Closure - prefer statements that close active private wildcards // (wildcards that won't be needed by any remaining statements) - let needed_later: HashSet = remaining - .iter() - .flat_map(|&i| collect_wildcards_from_statement(&statements[i])) - .collect(); - let closes_count = stmt_wildcards .intersection(active_wildcards) .filter(|w| !needed_later.contains(*w)) .count(); // Metric 3: Fanout - prefer statements with lower future usage - // (number of remaining statements that use any wildcard from this statement) + // (number of remaining statements sharing private wildcards with this statement) let fanout = remaining .iter() .filter(|&&i| { - let other_wildcards = collect_wildcards_from_statement(&statements[i]); + let other_wildcards: HashSet<_> = collect_wildcards_from_statement(&statements[i]) + .into_iter() + .filter(|w| !public_args.contains(w)) + .collect(); !stmt_wildcards.is_disjoint(&other_wildcards) }) .count(); @@ -262,16 +248,33 @@ fn statement_selection_key( active_wildcards: &HashSet, remaining: &HashSet, approaching_split: bool, + public_args: &HashSet, ) -> (i32, (usize, usize, i32), Reverse) { + // Pre-compute needed_later once and share between primary score and tie-breakers. + // Exclude the candidate itself: we want to know what the *other* remaining statements + // need, so that wildcards used only by this candidate correctly appear as closeable. + let needed_later: HashSet = remaining + .iter() + .filter(|&&i| i != idx) + .flat_map(|&i| collect_wildcards_from_statement(&statements[i])) + .filter(|w| !public_args.contains(w)) + .collect(); + let primary_score = score_statement( + &statements[idx], + active_wildcards, + approaching_split, + public_args, + &needed_later, + ); + let tie_breakers = compute_tie_breakers( &statements[idx], active_wildcards, statements, remaining, - approaching_split, + &needed_later, + public_args, ); - let tie_breakers = - compute_tie_breakers(&statements[idx], active_wildcards, statements, remaining); // Final deterministic tie-breaker: prefer smaller original indices. // This avoids hash-iteration-dependent selection when scores are equal. @@ -284,6 +287,7 @@ fn find_best_next_statement( remaining: &HashSet, active_wildcards: &HashSet, ordered_count: usize, + public_args: &HashSet, ) -> usize { // Calculate distance to next split point let bucket_size = Params::max_custom_predicate_arity() - 1; // Reserve slot for chain call @@ -299,51 +303,66 @@ fn find_best_next_statement( active_wildcards, remaining, approaching_split, + public_args, ) }) .copied() .unwrap() } -/// Score a statement based on how well it minimizes liveness +/// Score a statement based on how well it minimizes private-wildcard liveness at boundaries. +/// `needed_later` is the set of private wildcards used by any remaining statement. fn score_statement( stmt: &StatementTmpl, active_wildcards: &HashSet, - statements: &[StatementTmpl], - remaining: &HashSet, approaching_split: bool, + public_args: &HashSet, + needed_later: &HashSet, ) -> i32 { - let stmt_wildcards = collect_wildcards_from_statement(stmt); + let all_wildcards = collect_wildcards_from_statement(stmt); + + // Only score based on private wildcards. Public args are always available at every + // split boundary — they never consume a promotion slot, so their liveness is free. + let stmt_wildcards: HashSet<_> = all_wildcards + .into_iter() + .filter(|w| !public_args.contains(w)) + .collect(); - // How many active wildcards does this reuse? + // Statements that touch only public args ("cheap" statements) waste a bucket slot + // that could be used to cluster private wildcards. Strongly defer them while any + // private-wildcard statements remain, so they fill leftover space at the end. + // `needed_later` is non-empty iff some remaining statement has a private wildcard. + if stmt_wildcards.is_empty() { + return if needed_later.is_empty() { + 0 + } else { + i32::MIN / 2 + }; + } + + // How many active private wildcards does this reuse? let reuse_count = stmt_wildcards.intersection(active_wildcards).count(); - // How many new wildcards does this introduce? + // How many new private wildcards does this introduce? let new_wildcard_count = stmt_wildcards.difference(active_wildcards).count(); - // After adding this statement, what would be active? + // Which of the projected-active wildcards are still needed after this statement? let mut projected_active = active_wildcards.clone(); - projected_active.extend(stmt_wildcards.clone()); - - // Which wildcards are still needed by other remaining statements? - let needed_later: HashSet = remaining - .iter() - .flat_map(|&i| collect_wildcards_from_statement(&statements[i])) - .collect(); - - // Wildcards we can close = active now but not needed later + projected_active.extend(stmt_wildcards); projected_active.retain(|w| needed_later.contains(w)); let still_active_count = projected_active.len(); - // Base score calculation - // - Prefer statements that reuse active wildcards (don't introduce new liveness) - // - Penalize introducing new wildcards (increases liveness) - // - Penalize keeping many wildcards active (higher liveness) + // Base score: + // +3 per reused wildcard — rewards clustering (wildcard already open, no new cost) + // -4 per new wildcard — penalises opening new live ranges + // -2 per still-live — penalises carrying many wildcards toward the boundary let base_score = (reuse_count * 3) as i32 - (new_wildcard_count * 4) as i32 - (still_active_count * 2) as i32; - // Look-ahead bonus: when approaching split, heavily favor closing wildcards + // When close to a split boundary, strongly reward statements that close wildcards + // (active.len() + new - still_active = number of wildcards resolved by this statement). + // Weight 10 >> max base-score magnitude to make closing the dominant factor. if approaching_split { let closes_count = active_wildcards.len() + new_wildcard_count - still_active_count; base_score + (closes_count * 10) as i32 @@ -375,8 +394,6 @@ fn calculate_live_wildcards( fn generate_refactor_suggestion( crossing_wildcards: &[String], ordered_statements: &[StatementTmpl], - _pos: usize, - _end: usize, ) -> Option { use crate::lang::error::RefactorSuggestion; @@ -445,13 +462,6 @@ fn split_into_chain( let original_name = pred.name.name.clone(); let conjunction = pred.conjunction_type; - let usage = analyze_wildcards(&pred.statements); - let real_statement_count = pred.statements.len(); - - let ordering_result = order_constraints_optimally(pred.statements, &usage); - let ordered_statements = ordering_result.statements; - let reorder_map = ordering_result.reorder_map; - let original_public_args: Vec = pred .args .public_args @@ -459,6 +469,14 @@ fn split_into_chain( .map(|id| id.name.clone()) .collect(); + let public_args_set: HashSet = original_public_args.iter().cloned().collect(); + + let real_statement_count = pred.statements.len(); + + let ordering_result = order_constraints_optimally(pred.statements, &public_args_set); + let ordered_statements = ordering_result.statements; + let reorder_map = ordering_result.reorder_map; + let mut chain_links = Vec::new(); let mut pos = 0; let mut incoming_public = original_public_args.clone(); @@ -501,8 +519,7 @@ fn split_into_chain( total_public, }; - let suggestion = - generate_refactor_suggestion(&new_promotions, &ordered_statements, pos, end); + let suggestion = generate_refactor_suggestion(&new_promotions, &ordered_statements); return Err(SplittingError::TooManyPublicArgsAtSplit { predicate: original_name.clone(), @@ -540,23 +557,60 @@ fn split_into_chain( }); } - let mut public_args_out: Vec = live_at_boundary.iter().cloned().collect(); - public_args_out.sort(); // Deterministic ordering - chain_links.push(ChainLink { statements: ordered_statements[pos..end].to_vec(), public_args_in: incoming_public.clone(), private_args, - public_args_out: public_args_out.clone(), + // new_promotions are already sorted and already filtered to exclude incoming_public + promoted_wildcards: new_promotions.clone(), }); pos = end; - // Next link's incoming public args = current incoming + newly promoted live wildcards - // Only add wildcards that aren't already in incoming_public to avoid duplicates - for wildcard in public_args_out { - if !incoming_set.contains(&wildcard) { - incoming_public.push(wildcard); + // Extend incoming_public for the next link with the newly promoted wildcards. + // new_promotions is already filtered to exclude incoming_set, so no dedup needed. + incoming_public.extend(new_promotions); + } + + // Backward pass: prune each continuation's public args to the minimal set needed. + // + // The forward pass accumulates incoming_public monotonically, so a continuation may + // inherit original public args that none of its statements (or downstream continuations) + // ever reference. A continuation must declare every public arg it receives, and the + // proof system constrains each declared arg - an arg that goes unused has no constraints + // and will not match the value the caller passes. + // + // Propagating from the last link backward ensures each continuation declares exactly the + // args it uses directly, plus any args its successor still needs. Link 0 (the original + // predicate) is left untouched - its public-arg signature is user-declared. + { + let num_links = chain_links.len(); + if num_links > 1 { + // Collect wildcards referenced by each link's statements once. + let link_wildcards: Vec> = chain_links + .iter() + .map(|link| { + link.statements + .iter() + .flat_map(collect_wildcards_from_statement) + .collect() + }) + .collect(); + + let last = num_links - 1; + + // Seed: last link retains only args it directly references. + chain_links[last] + .public_args_in + .retain(|a| link_wildcards[last].contains(a)); + + // Propagate backward through intermediate continuation links (skip link 0). + for i in (1..last).rev() { + let needed_downstream: HashSet = + chain_links[i + 1].public_args_in.iter().cloned().collect(); + chain_links[i] + .public_args_in + .retain(|a| link_wildcards[i].contains(a) || needed_downstream.contains(a)); } } } @@ -590,7 +644,7 @@ fn split_into_chain( let mut chain_predicates = generate_chain_predicates(&original_name, chain_links, conjunction, params)?; - validate_chain(&chain_predicates, params)?; + validate_chain(&chain_predicates, params); // Reverse so continuations come before callers in declaration order. // This ensures that when batched, continuations are in earlier batches @@ -633,7 +687,7 @@ fn generate_chain_predicates( }; // Create arguments for chain call: use next link's public_args_in - // which is the deduplicated union of current public_args_in and public_args_out + // which is current public_args_in extended with current promoted_wildcards let next_link = &chain_links[i + 1]; let chain_call_args: Vec = next_link .public_args_in @@ -665,10 +719,12 @@ fn generate_chain_predicates( }) .collect(); - // Build private args (private + promoted for next) + // Build private args: segment-local private wildcards, plus any wildcards being + // promoted to public for the next link (they must be declared here so the solver + // can bind them before passing them as public args to the continuation). let mut private_arg_names = link.private_args.clone(); if !is_last { - private_arg_names.extend(link.public_args_out.clone()); + private_arg_names.extend(link.promoted_wildcards.iter().cloned()); } let private_args = if private_arg_names.is_empty() { @@ -698,25 +754,34 @@ fn generate_chain_predicates( Ok(predicates) } -/// Phase 5: Validate the generated chain -/// -/// Note: We no longer check chain length against max_custom_batch_size since -/// chains can now span multiple batches thanks to multi-batch support. -fn validate_chain(chain: &[CustomPredicateDef], params: &Params) -> Result<(), SplittingError> { +/// Sanity-check the generated chain. All three constraints are enforced as proper errors +/// earlier in `split_into_chain`, so violations here indicate a bug in the algorithm. +fn validate_chain(chain: &[CustomPredicateDef], params: &Params) { for pred in chain { - // Each predicate should have ≤ max_statements - assert!(pred.statements.len() <= Params::max_custom_predicate_arity()); - - // Public args should fit - assert!(pred.args.public_args.len() <= Params::max_statement_args()); - - // Total args should fit + assert!( + pred.statements.len() <= Params::max_custom_predicate_arity(), + "chain link '{}' has {} statements, exceeds max {}", + pred.name.name, + pred.statements.len(), + Params::max_custom_predicate_arity(), + ); + assert!( + pred.args.public_args.len() <= Params::max_statement_args(), + "chain link '{}' has {} public args, exceeds max {}", + pred.name.name, + pred.args.public_args.len(), + Params::max_statement_args(), + ); let total = pred.args.public_args.len() + pred.args.private_args.as_ref().map_or(0, |v| v.len()); - assert!(total <= params.max_custom_predicate_wildcards); + assert!( + total <= params.max_custom_predicate_wildcards, + "chain link '{}' has {} total args, exceeds max {}", + pred.name.name, + total, + params.max_custom_predicate_wildcards, + ); } - - Ok(()) } #[cfg(test)] @@ -976,8 +1041,24 @@ mod tests { let remaining: HashSet = [0, 1].into_iter().collect(); let active_wildcards = HashSet::new(); - let key0 = statement_selection_key(0, &statements, &active_wildcards, &remaining, false); - let key1 = statement_selection_key(1, &statements, &active_wildcards, &remaining, false); + // A and B are the public args of tie_break(A, B) + let public_args: HashSet = ["A".to_string(), "B".to_string()].into_iter().collect(); + let key0 = statement_selection_key( + 0, + &statements, + &active_wildcards, + &remaining, + false, + &public_args, + ); + let key1 = statement_selection_key( + 1, + &statements, + &active_wildcards, + &remaining, + false, + &public_args, + ); assert_eq!(key0.0, key1.0, "Primary heuristic score should tie"); assert_eq!(key0.1, key1.1, "Secondary tie-breaker metrics should tie"); @@ -986,7 +1067,8 @@ mod tests { "Lower original index should win deterministic final tie-breaker" ); - let selected = find_best_next_statement(&statements, &remaining, &active_wildcards, 0); + let selected = + find_best_next_statement(&statements, &remaining, &active_wildcards, 0, &public_args); assert_eq!(selected, 0); } @@ -1084,7 +1166,7 @@ mod tests { assert!(error_msg.contains("3 crossing wildcards")); assert!(error_msg.contains("= 6 total")); assert!(error_msg.contains("exceeds max of 5")); - assert!(error_msg.contains("Statements 0-4")); + assert!(error_msg.contains("Statements 0-3")); assert!(error_msg.contains("Incoming public args: A, B, C")); assert!(error_msg.contains("Wildcards crossing this boundary: T1, T2, T3")); assert!(error_msg.contains("Suggestion:")); @@ -1144,25 +1226,82 @@ mod tests { ); } + // --- Regression tests --- + + /// Statements that reference only public args should be deferred until private-wildcard + /// statements have been clustered, so they don't consume bucket slots that would reduce + /// liveness at split boundaries. + /// + /// 4 public args, 7 statements: W1 used in stmts 0,1,4; W2 used in stmts 1,2,3; + /// stmts 5,6 reference only public args. The scoring correctly defers stmts 5,6, + /// yielding bucket0={0,1,2,3}, bucket1={4,5,6} with only W1 crossing (4+1=5 <= max). #[test] - fn test_refactor_suggestion_group_wildcards() { - // Test the "group wildcard usages" suggestion formatting - use crate::lang::error::RefactorSuggestion; + fn test_split_succeeds_with_four_public_args_and_public_only_statements() { + // Optimal split: bucket0={0,1,2,3}, bucket1={4,5,6} + // Only W1 crosses (used in 0,1 and 4), total = 4 public + 1 crossing = 5 ✓ + let input = r#" + pred(A, B, C, D, private: W1, W2) = AND( + Equal(W1["x"], A["v"]) + Equal(W2["y"], W1["x"]) + Equal(W2["z"], B["v"]) + Equal(C["r"], W2["y"]) + Equal(D["s"], W1["x"]) + Equal(A["out"], C["out"]) + Equal(B["out"], D["out"]) + ) + "#; - let suggestion = RefactorSuggestion::GroupWildcardUsages { - wildcards: vec!["T1".to_string(), "T2".to_string(), "T3".to_string()], - }; + let pred = parse_predicate(input); + let params = Params::default(); - let suggestion_text = suggestion.format(); + let result = split_predicate_if_needed(pred, ¶ms); + assert!( + result.is_ok(), + "Should find a valid split with ≤1 crossing wildcard, got: {:?}", + result.err() + ); + } - // Verify the suggestion formats correctly - assert!(suggestion_text.contains("Group operations for wildcards")); - assert!(suggestion_text.contains("T1, T2, T3")); - assert!(suggestion_text.contains("used across multiple segments")); + /// Continuation predicates should only declare the public args they actually use - + /// original public args that are not referenced in a continuation's statements or + /// any of its downstream continuations must be omitted from its signature. + #[test] + fn test_continuation_excludes_public_args_unused_after_split() { + // A is used only in the first segment; B is used only in the second segment. + // The continuation predicate (pred_1) must include B but not A. + let input = r#" + pred(A, B, private: T) = AND( + Equal(T["x"], A["val"]) + Equal(T["y"], 1) + Equal(T["z"], 2) + Equal(T["w"], 3) + Equal(B["r"], T["x"]) + Equal(B["s"], T["y"]) + ) + "#; - eprintln!( - "\n=== Example GroupWildcardUsages Suggestion ===\n{}\n", - suggestion_text + let pred = parse_predicate(input); + let params = Params::default(); + + let result = split_predicate_if_needed(pred, ¶ms).unwrap(); + // chain[0] is the continuation (_1 suffix), chain[1] is the original + let continuation = result + .predicates + .iter() + .find(|p| p.name.name == "pred_1") + .expect("Expected a pred_1 continuation predicate"); + + let cont_public: Vec<&str> = continuation + .args + .public_args + .iter() + .map(|id| id.name.as_str()) + .collect(); + + assert!( + !cont_public.contains(&"A"), + "Continuation should drop unused public arg 'A', got: {:?}", + cont_public ); } } diff --git a/src/middleware/statement.rs b/src/middleware/statement.rs index 4ed1a8d5..d3e05347 100644 --- a/src/middleware/statement.rs +++ b/src/middleware/statement.rs @@ -77,14 +77,14 @@ impl NativePredicate { | NativePredicate::MaxOf | NativePredicate::HashOf | NativePredicate::SetInsert - | NativePredicate::SetDelete => 3, + | NativePredicate::SetDelete + | NativePredicate::DictDelete + | NativePredicate::ContainerDelete => 3, NativePredicate::DictInsert | NativePredicate::DictUpdate - | NativePredicate::DictDelete | NativePredicate::ArrayUpdate | NativePredicate::ContainerInsert - | NativePredicate::ContainerUpdate - | NativePredicate::ContainerDelete => 4, + | NativePredicate::ContainerUpdate => 4, } } }