Skip to content

Commit f6ff84e

Browse files
committed
Multi-batch splitting
1 parent 813a86c commit f6ff84e

File tree

10 files changed

+929
-369
lines changed

10 files changed

+929
-369
lines changed

examples/main_pod_points.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
8888
game_pk = game_pk,
8989
);
9090
println!("# custom predicate batch:{}", input);
91-
let batch = parse(&input, &params, &[])?.custom_batch;
91+
let batch = parse(&input, &params, &[])?
92+
.first_batch()
93+
.expect("Expected batch")
94+
.clone();
9295
let points_pred = batch.predicate_ref_by_name("points").unwrap();
9396
let over_9000_pred = batch.predicate_ref_by_name("over_9000").unwrap();
9497

src/backends/plonky2/mainpod/mod.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1179,7 +1179,9 @@ pub mod tests {
11791179
&[],
11801180
)
11811181
.unwrap()
1182-
.custom_batch;
1182+
.first_batch()
1183+
.unwrap()
1184+
.clone();
11831185
let mut builder = MainPodBuilder::new(&params, &DEFAULT_VD_SET);
11841186
let cpr = CustomPredicateRef { batch, index: 0 };
11851187
let eq_st = builder.priv_op(frontend::Operation::eq(1, 1)).unwrap();

src/examples/custom.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@ pub fn eth_dos_batch(params: &Params) -> Result<Arc<CustomPredicateBatch>> {
3232
eth_dos_ind(src, dst, distance)
3333
)
3434
"#;
35-
let batch = parse(input, params, &[]).expect("lang parse").custom_batch;
35+
let batch = parse(input, params, &[])
36+
.expect("lang parse")
37+
.first_batch()
38+
.expect("Expected batch")
39+
.clone();
3640
println!("a.0. {}", batch.predicates[0]);
3741
println!("a.1. {}", batch.predicates[1]);
3842
println!("a.2. {}", batch.predicates[2]);

src/frontend/mod.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1378,7 +1378,11 @@ pub mod tests {
13781378
Equal(b, 5)
13791379
)
13801380
"#;
1381-
let batch = parse(input, &params, &[]).unwrap().custom_batch;
1381+
let batch = parse(input, &params, &[])
1382+
.unwrap()
1383+
.first_batch()
1384+
.unwrap()
1385+
.clone();
13821386
let pred_test = batch.predicate_ref_by_name("Test").unwrap();
13831387

13841388
// Try to build with wrong type in 1st arg

src/lang/error.rs

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ pub enum LangError {
2222

2323
#[error("Lowering error: {0}")]
2424
Lowering(Box<LoweringError>),
25+
26+
#[error("Batching error: {0}")]
27+
Batching(Box<BatchingError>),
2528
}
2629

2730
/// Validation errors from frontend AST validation
@@ -90,14 +93,6 @@ pub enum ValidationError {
9093
/// Lowering errors from frontend AST lowering to middleware
9194
#[derive(Debug, thiserror::Error)]
9295
pub enum LoweringError {
93-
#[error("Too many custom predicates in batch '{batch_name}': {count} exceeds limit of {max}{}", if *.original_count != *.count { format!(" (started with {} predicates before automatic splitting)", original_count) } else { String::new() })]
94-
TooManyPredicates {
95-
batch_name: String,
96-
count: usize,
97-
max: usize,
98-
original_count: usize,
99-
},
100-
10196
#[error("Too many statements in predicate '{predicate}': {count} exceeds limit of {max}")]
10297
TooManyStatements {
10398
predicate: String,
@@ -127,6 +122,9 @@ pub enum LoweringError {
127122
#[error("Splitting error: {0}")]
128123
Splitting(#[from] SplittingError),
129124

125+
#[error("Batching error: {0}")]
126+
Batching(#[from] BatchingError),
127+
130128
#[error("Cannot lower document with validation errors")]
131129
ValidationErrors,
132130
}
@@ -235,6 +233,21 @@ fn format_public_args_at_split_error(
235233
msg
236234
}
237235

236+
/// Batching errors from multi-batch packing
237+
#[derive(Debug, thiserror::Error)]
238+
pub enum BatchingError {
239+
#[error("Forward cross-batch reference: predicate '{caller}' (batch {caller_batch}) calls '{callee}' (batch {callee_batch}). Move '{callee}' earlier or '{caller}' later.")]
240+
ForwardCrossBatchReference {
241+
caller: String,
242+
caller_batch: usize,
243+
callee: String,
244+
callee_batch: usize,
245+
},
246+
247+
#[error("Internal batching error: {message}")]
248+
Internal { message: String },
249+
}
250+
238251
/// Splitting errors from predicate splitting
239252
#[derive(Debug, thiserror::Error)]
240253
pub enum SplittingError {
@@ -271,13 +284,6 @@ pub enum SplittingError {
271284
max_allowed: usize,
272285
suggestion: Option<Box<RefactorSuggestion>>,
273286
},
274-
275-
#[error("Too many predicates in chain for '{predicate}': {count} exceeds batch limit of {max_allowed}")]
276-
TooManyPredicatesInChain {
277-
predicate: String,
278-
count: usize,
279-
max_allowed: usize,
280-
},
281287
}
282288

283289
impl From<ParseError> for LangError {
@@ -303,3 +309,9 @@ impl From<LoweringError> for LangError {
303309
LangError::Lowering(Box::new(err))
304310
}
305311
}
312+
313+
impl From<BatchingError> for LangError {
314+
fn from(err: BatchingError) -> Self {
315+
LangError::Batching(Box::new(err))
316+
}
317+
}

0 commit comments

Comments
 (0)