Skip to content

Commit c2d2f13

Browse files
committed
Account for automatically-inserted Contains statements
1 parent 316df30 commit c2d2f13

File tree

4 files changed

+209
-65
lines changed

4 files changed

+209
-65
lines changed

src/frontend/multi_pod/cost.rs

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,11 @@
66
use std::collections::BTreeSet;
77

88
use crate::{
9-
frontend::Operation,
10-
middleware::{CustomPredicateBatch, Hash, NativeOperation, OperationType, Params},
9+
frontend::{Operation, OperationArg},
10+
middleware::{
11+
CustomPredicateBatch, Hash, NativeOperation, OperationType, Params, RawValue, Statement,
12+
ValueRef,
13+
},
1114
};
1215

1316
/// Unique identifier for a custom predicate batch.
@@ -23,6 +26,44 @@ impl From<&CustomPredicateBatch> for CustomBatchId {
2326
}
2427
}
2528

29+
/// Unique identifier for an anchored key (dict, key) pair.
30+
///
31+
/// When a Contains statement is used as an argument to operations like gt(), eq(), etc.,
32+
/// the value is accessed via an "anchored key" - a reference to a specific key in a
33+
/// specific dictionary. Each unique anchored key used in a POD requires a Contains
34+
/// statement to be present in that POD (auto-inserted by MainPodBuilder if needed).
35+
///
36+
/// We use the raw values of the dict and key for comparison, as they uniquely identify
37+
/// the anchored key regardless of the specific Value types involved.
38+
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
39+
pub struct AnchoredKeyId {
40+
/// The dictionary root value (raw representation for Ord).
41+
pub dict: RawValue,
42+
/// The key within the dictionary (raw representation for Ord).
43+
pub key: RawValue,
44+
}
45+
46+
impl AnchoredKeyId {
47+
/// Create a new anchored key ID from raw values.
48+
pub fn new(dict: RawValue, key: RawValue) -> Self {
49+
Self { dict, key }
50+
}
51+
52+
/// Try to extract an anchored key ID from a Contains statement with all literal values.
53+
pub fn from_contains_statement(stmt: &Statement) -> Option<Self> {
54+
if let Statement::Contains(
55+
ValueRef::Literal(dict),
56+
ValueRef::Literal(key),
57+
ValueRef::Literal(_value),
58+
) = stmt
59+
{
60+
Some(Self::new(dict.raw(), key.raw()))
61+
} else {
62+
None
63+
}
64+
}
65+
}
66+
2667
/// Resource costs for a single statement/operation.
2768
///
2869
/// Each field corresponds to a resource with a per-POD limit in `Params`.
@@ -51,6 +92,14 @@ pub struct StatementCost {
5192
/// Custom predicate batches used (for batch cardinality constraint).
5293
/// Limit: `params.max_custom_predicate_batches` distinct batches per POD.
5394
pub custom_batch_ids: BTreeSet<CustomBatchId>,
95+
96+
/// Anchored keys referenced by this operation.
97+
///
98+
/// When a Contains statement with all literal values is used as an argument,
99+
/// the operation references an "anchored key" (dict, key pair). Each unique
100+
/// anchored key used in a POD incurs an additional Contains statement cost,
101+
/// as MainPodBuilder::add_entries_contains will auto-insert it if not already present.
102+
pub anchored_keys: BTreeSet<AnchoredKeyId>,
54103
}
55104

56105
impl StatementCost {
@@ -121,6 +170,18 @@ impl StatementCost {
121170
}
122171
}
123172

173+
// Extract anchored keys from operation arguments.
174+
// Any argument that is a Contains statement with all literal values
175+
// represents an anchored key reference that will require a Contains
176+
// statement in the POD (auto-inserted by MainPodBuilder if needed).
177+
for arg in &op.1 {
178+
if let OperationArg::Statement(stmt) = arg {
179+
if let Some(anchored_key) = AnchoredKeyId::from_contains_statement(stmt) {
180+
cost.anchored_keys.insert(anchored_key);
181+
}
182+
}
183+
}
184+
124185
cost
125186
}
126187
}

src/frontend/multi_pod/mod.rs

Lines changed: 97 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ mod cost;
1717
mod deps;
1818
mod solver;
1919

20-
use cost::{estimate_pod_count, StatementCost};
20+
use cost::{estimate_pod_count, AnchoredKeyId, StatementCost};
2121
use deps::{DependencyGraph, StatementSource};
2222
pub use solver::MultiPodSolution;
2323

@@ -255,6 +255,14 @@ impl MultiPodBuilder {
255255
.map(StatementCost::from_operation)
256256
.collect();
257257

258+
// Collect all unique anchored keys from the costs
259+
let all_anchored_keys: Vec<AnchoredKeyId> = costs
260+
.iter()
261+
.flat_map(|c| c.anchored_keys.iter().cloned())
262+
.collect::<std::collections::BTreeSet<_>>()
263+
.into_iter()
264+
.collect();
265+
258266
// Build external POD statement mapping (cache for reuse in build_single_pod)
259267
let external_pod_statements = self.build_external_statement_map();
260268
self.cached_external_map = Some(external_pod_statements);
@@ -274,6 +282,7 @@ impl MultiPodBuilder {
274282
output_public_indices: &self.output_public_indices,
275283
params: &self.params,
276284
max_pods: self.options.max_pods,
285+
all_anchored_keys: &all_anchored_keys,
277286
};
278287

279288
let solution = solver::solve(&input)?;
@@ -665,48 +674,72 @@ mod tests {
665674
fn test_cross_pod_dependencies() -> Result<()> {
666675
// Verifies that dependencies work correctly when statements span POD boundaries.
667676
//
668-
// Each pair forms a dependency: lt(a, b) proves a < b, then lt_to_ne derives a ≠ b.
677+
// Scenario: Verify properties of a user profile credential.
678+
// The profile contains multiple attributes, and we verify each meets a threshold.
679+
// Each verification creates a dependency chain:
680+
// dict_contains(profile, key, value) -> gt(value, threshold)
681+
//
669682
// When statements are split across PODs, the solver must:
670683
// 1. Ensure dependencies are available (either proved locally or public in earlier POD)
671684
// 2. Insert CopyStatements to bring dependencies into the POD that needs them
672685
//
673-
// Setup: 8 statements with max_priv=4 forces splitting across 2+ PODs.
686+
// Statement count: 12 user operations (11 private + 1 public), plus 6 anchored keys.
687+
// To force multiple PODs, we set max_priv_statements = 10 (< 17 effective statements).
674688
let params = Params {
675-
max_statements: 6,
689+
max_statements: 12,
676690
max_public_statements: 2,
677-
// Derived: max_priv_statements = 6 - 2 = 4
678-
// With 8 statements, need ceil(8/4) = 2 PODs minimum
691+
// Derived: max_priv_statements = 12 - 2 = 10
679692
max_input_pods: 2,
680-
max_input_pods_public_statements: 4,
693+
max_input_pods_public_statements: 14,
681694
..Params::default()
682695
};
683696
let vd_set = &*MOCK_VD_SET;
684697

685698
let mut builder = MultiPodBuilder::new(&params, vd_set);
686699

687-
// Create 4 dependency pairs - enough to force cross-POD dependencies
688-
// Pair 1: prove balance < limit, derive balance ≠ limit
689-
let balance_under_limit = builder.priv_op(FrontendOp::lt(1, 100))?;
690-
let _balance_not_at_limit = builder.priv_op(FrontendOp::lt_to_ne(balance_under_limit))?;
700+
// User profile credential with multiple attributes
701+
let profile = dict!({
702+
"age" => 25,
703+
"balance" => 1000,
704+
"reputation" => 85,
705+
"level" => 5,
706+
"credits" => 150,
707+
"score" => 72
708+
});
709+
710+
// Verify each attribute meets its threshold requirement
711+
// Each creates a dependency: dict_contains -> gt
712+
713+
// Verify age >= 18 (adult)
714+
let age = builder.priv_op(FrontendOp::dict_contains(profile.clone(), "age", 25))?;
715+
let _age_ok = builder.priv_op(FrontendOp::gt_eq(age, 18))?;
716+
717+
// Verify balance >= 100 (minimum balance)
718+
let balance = builder.priv_op(FrontendOp::dict_contains(profile.clone(), "balance", 1000))?;
719+
let _balance_ok = builder.priv_op(FrontendOp::gt_eq(balance, 100))?;
691720

692-
// Pair 2: prove age < max, derive age ≠ max
693-
let age_under_max = builder.priv_op(FrontendOp::lt(2, 200))?;
694-
let _age_not_at_max = builder.priv_op(FrontendOp::lt_to_ne(age_under_max))?;
721+
// Verify reputation >= 50 (trusted user)
722+
let reputation =
723+
builder.priv_op(FrontendOp::dict_contains(profile.clone(), "reputation", 85))?;
724+
let _reputation_ok = builder.priv_op(FrontendOp::gt_eq(reputation, 50))?;
695725

696-
// Pair 3: prove score < threshold, derive score ≠ threshold
697-
let score_under_threshold = builder.priv_op(FrontendOp::lt(3, 300))?;
698-
let _score_not_at_threshold =
699-
builder.priv_op(FrontendOp::lt_to_ne(score_under_threshold))?;
726+
// Verify level >= 3 (experienced user)
727+
let level = builder.priv_op(FrontendOp::dict_contains(profile.clone(), "level", 5))?;
728+
let _level_ok = builder.priv_op(FrontendOp::gt_eq(level, 3))?;
700729

701-
// Pair 4: prove level < cap, derive level ≠ cap (public output)
702-
let level_under_cap = builder.priv_op(FrontendOp::lt(4, 400))?;
703-
let _level_not_at_cap = builder.pub_op(FrontendOp::lt_to_ne(level_under_cap))?;
730+
// Verify credits >= 100 (has credits)
731+
let credits = builder.priv_op(FrontendOp::dict_contains(profile.clone(), "credits", 150))?;
732+
let _credits_ok = builder.priv_op(FrontendOp::gt_eq(credits, 100))?;
733+
734+
// Verify score >= 60 (passing score) - make this one public
735+
let score = builder.priv_op(FrontendOp::dict_contains(profile, "score", 72))?;
736+
let _score_ok = builder.pub_op(FrontendOp::gt_eq(score, 60))?;
704737

705738
let pod_count = {
706739
let solution = builder.solve()?;
707740
assert!(
708741
solution.pod_count >= 2,
709-
"Expected at least 2 PODs for 8 statements with max_priv=4, got {}",
742+
"Expected at least 2 PODs for 17 effective private statements with max_priv=10, got {}",
710743
solution.pod_count
711744
);
712745
solution.pod_count
@@ -1107,43 +1140,45 @@ mod tests {
11071140
}
11081141

11091142
#[test]
1110-
fn test_copy_statements_counted_in_statement_limit() -> Result<()> {
1111-
// Verifies that CopyStatements for cross-POD dependencies are counted
1112-
// toward the statement limit.
1143+
fn test_anchored_key_overhead_counted_in_statement_limit() -> Result<()> {
1144+
// Verifies that anchored key overhead is correctly counted toward statement limits.
1145+
//
1146+
// When a Contains statement is used as an argument to operations like gt(),
1147+
// it creates an "anchored key" reference. If the gt() is proved in a different
1148+
// POD than the original Contains, MainPodBuilder auto-inserts a local Contains
1149+
// statement for that anchored key. The solver must account for this overhead.
11131150
//
11141151
// Setup:
1115-
// - max_priv_statements = 2 (small limit)
1116-
// - Statement A with no deps (public, goes to POD 0)
1117-
// - Statements B, C, D all depend on A (private)
1152+
// - max_priv_statements = 4 (small limit)
1153+
// - Statement A: dict_contains (public, in POD 0)
1154+
// - Statement B: eq (public, in POD 0)
1155+
// - Statements C, D, E: gt(A, val) - each uses A as an anchored key
11181156
//
1119-
// Expected:
1120-
// - Solver should recognize that if B, C, D go to POD 1, it needs a CopyStatement for A
1121-
// - So POD 1 would have: CopyStatement(A) + B + C + D = 4 private statements
1122-
// - This exceeds max_priv_statements = 2, so solver should create more PODs
1157+
// The solver must account for the anchored key Contains statements that will
1158+
// be auto-inserted when gt operations are proved in PODs other than POD 0.
11231159

11241160
let params = Params {
1125-
max_statements: 4,
1126-
max_public_statements: 2, // max_priv_statements = 4 - 2 = 2
1161+
max_statements: 6,
1162+
max_public_statements: 2, // max_priv_statements = 6 - 2 = 4
11271163
..Params::default()
11281164
};
11291165
let vd_set = &*MOCK_VD_SET;
11301166

11311167
let mut builder = MultiPodBuilder::new(&params, vd_set);
11321168

1133-
// Statement 0: public, no deps - will be in POD 0
1134-
let stmt_a = builder.pub_op(FrontendOp::lt(1, 100))?;
1169+
// Statement A: public Contains - proved in POD 0
1170+
let dict = dict!({"x" => 100});
1171+
let stmt_a = builder.pub_op(FrontendOp::dict_contains(dict, "x", 100))?;
11351172

1136-
// Statements 1, 2, 3: private, all depend on statement 0
1137-
// With max_priv_statements = 2, these can't all fit in POD 0
1138-
// Solver must account for CopyStatement when distributing these
1139-
builder.priv_op(FrontendOp::lt_to_ne(stmt_a.clone()))?;
1140-
builder.priv_op(FrontendOp::lt_to_ne(stmt_a.clone()))?;
1141-
builder.priv_op(FrontendOp::lt_to_ne(stmt_a))?;
1142-
1143-
// Add another public statement for the output POD
1173+
// Statement B: another public statement in POD 0
11441174
builder.pub_op(FrontendOp::eq(200, 200))?;
11451175

1146-
// Solver should correctly account for CopyStatements and create enough PODs
1176+
// Statements C, D, E: each uses stmt_a as an anchored key
1177+
// When proved in a different POD, each needs a local Contains for the anchored key
1178+
builder.priv_op(FrontendOp::gt(stmt_a.clone(), 0))?;
1179+
builder.priv_op(FrontendOp::gt(stmt_a.clone(), 1))?;
1180+
builder.priv_op(FrontendOp::gt(stmt_a, 2))?;
1181+
11471182
let prover = MockProver {};
11481183
let result = builder.prove(&prover)?;
11491184

@@ -1162,12 +1197,18 @@ mod tests {
11621197
// Verifies that scenarios with both internal and external dependencies work
11631198
// when the total input count stays within max_input_pods.
11641199
//
1165-
// This is a sanity check that mixing internal and external POD dependencies
1166-
// works correctly when limits are respected.
1200+
// Setup:
1201+
// - 1 external POD with a public statement
1202+
// - 2 public dict_contains statements (uses anchored keys)
1203+
// - 2 private gt statements that reference the dict_contains via anchored keys
1204+
// - 1 private copy of the external POD's statement
1205+
//
1206+
// This tests that mixing internal POD dependencies (from earlier generated PODs)
1207+
// and external POD dependencies (from user-provided input PODs) works correctly.
11671208

11681209
let params = Params {
1169-
max_statements: 6,
1170-
max_public_statements: 3, // max_priv_statements = 3
1210+
max_statements: 10,
1211+
max_public_statements: 3, // max_priv_statements = 7
11711212
max_input_pods: 3, // Allow up to 3 inputs per POD
11721213
max_input_pods_public_statements: 10,
11731214
..Params::default()
@@ -1190,13 +1231,15 @@ mod tests {
11901231
let mut builder = MultiPodBuilder::new(&params, vd_set);
11911232
builder.add_pod(ext_pod);
11921233

1193-
// Output POD: public statements
1194-
let lt_0 = builder.pub_op(FrontendOp::lt(1, 100))?;
1195-
let lt_1 = builder.pub_op(FrontendOp::lt(2, 200))?;
1234+
// Output POD: public Contains statements
1235+
let dict0 = dict!({"x" => 100});
1236+
let dict1 = dict!({"y" => 200});
1237+
let contains_0 = builder.pub_op(FrontendOp::dict_contains(dict0, "x", 100))?;
1238+
let contains_1 = builder.pub_op(FrontendOp::dict_contains(dict1, "y", 200))?;
11961239

11971240
// Statements that depend on output POD
1198-
builder.priv_op(FrontendOp::lt_to_ne(lt_0))?;
1199-
builder.priv_op(FrontendOp::lt_to_ne(lt_1))?;
1241+
builder.priv_op(FrontendOp::gt(contains_0, 0))?;
1242+
builder.priv_op(FrontendOp::gt(contains_1, 0))?;
12001243

12011244
// Depend on external POD
12021245
builder.priv_op(FrontendOp::copy(stmt_ext))?;

0 commit comments

Comments
 (0)