Skip to content

Commit 38d2144

Browse files
committed
optimize WHIR recursion program for memory usage
1 parent 44c3ced commit 38d2144

File tree

2 files changed

+79
-42
lines changed

2 files changed

+79
-42
lines changed

crates/rec_aggregation/recursion_program.lean_lang

Lines changed: 78 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -377,18 +377,21 @@ fn merkle_verif_batch_dynamic(n_paths, leaves_digests, leave_positions, root, he
377377
if height == MERKLE_HEIGHT_0 {
378378
merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_0);
379379
return;
380-
}
381-
if height == MERKLE_HEIGHT_1 {
382-
merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_1);
383-
return;
384-
}
385-
if height == MERKLE_HEIGHT_2 {
386-
merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_2);
387-
return;
388-
}
389-
if height == MERKLE_HEIGHT_3 {
390-
merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_3);
391-
return;
380+
} else {
381+
if height == MERKLE_HEIGHT_1 {
382+
merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_1);
383+
return;
384+
} else {
385+
if height == MERKLE_HEIGHT_2 {
386+
merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_2);
387+
return;
388+
} else {
389+
if height == MERKLE_HEIGHT_3 {
390+
merkle_verif_batch_const(n_paths, leaves_digests, leave_positions, root, MERKLE_HEIGHT_3);
391+
return;
392+
}
393+
}
394+
}
392395
}
393396

394397
print(12345555);
@@ -502,10 +505,22 @@ fn dot_product_ee_dynamic(a, b, res, n) {
502505
dot_product_ee(a, b, res, 16);
503506
return;
504507
}
508+
509+
dot_product_ee_dynamic_helper_1(a, b, res, n);
510+
return;
511+
}
512+
513+
fn dot_product_ee_dynamic_helper_1(a, b, res, n) {
505514
if n == NUM_QUERIES_0 {
506515
dot_product_ee(a, b, res, NUM_QUERIES_0);
507516
return;
508517
}
518+
519+
dot_product_ee_dynamic_helper_2(a, b, res, n);
520+
return;
521+
}
522+
523+
fn dot_product_ee_dynamic_helper_2(a, b, res, n) {
509524
if n == NUM_QUERIES_1 {
510525
dot_product_ee(a, b, res, NUM_QUERIES_1);
511526
return;
@@ -514,6 +529,12 @@ fn dot_product_ee_dynamic(a, b, res, n) {
514529
dot_product_ee(a, b, res, NUM_QUERIES_2);
515530
return;
516531
}
532+
533+
dot_product_ee_dynamic_helper_3(a, b, res, n);
534+
return;
535+
}
536+
537+
fn dot_product_ee_dynamic_helper_3(a, b, res, n) {
517538
if n == NUM_QUERIES_3 {
518539
dot_product_ee(a, b, res, NUM_QUERIES_3);
519540
return;
@@ -522,6 +543,12 @@ fn dot_product_ee_dynamic(a, b, res, n) {
522543
dot_product_ee(a, b, res, NUM_QUERIES_0 + 1);
523544
return;
524545
}
546+
547+
dot_product_ee_dynamic_helper_4(a, b, res, n);
548+
return;
549+
}
550+
551+
fn dot_product_ee_dynamic_helper_4(a, b, res, n) {
525552
if n == NUM_QUERIES_1 + 1 {
526553
dot_product_ee(a, b, res, NUM_QUERIES_1 + 1);
527554
return;
@@ -530,6 +557,12 @@ fn dot_product_ee_dynamic(a, b, res, n) {
530557
dot_product_ee(a, b, res, NUM_QUERIES_2 + 1);
531558
return;
532559
}
560+
561+
dot_product_ee_dynamic_helper_5(a, b, res, n);
562+
return;
563+
}
564+
565+
fn dot_product_ee_dynamic_helper_5(a, b, res, n) {
533566
if n == NUM_QUERIES_3 + 1 {
534567
dot_product_ee(a, b, res, NUM_QUERIES_3 + 1);
535568
return;
@@ -687,21 +720,26 @@ fn pow(a, b) -> 1 {
687720
}
688721

689722
fn sample_bits_dynamic(fs_state, n_samples, K) -> 2 {
723+
var new_fs_state;
724+
var sampled_bits;
690725
if n_samples == NUM_QUERIES_0 {
691726
new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_0, K);
692727
return new_fs_state, sampled_bits;
693-
}
694-
if n_samples == NUM_QUERIES_1 {
695-
new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_1, K);
696-
return new_fs_state, sampled_bits;
697-
}
698-
if n_samples == NUM_QUERIES_2 {
699-
new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_2, K);
700-
return new_fs_state, sampled_bits;
701-
}
702-
if n_samples == NUM_QUERIES_3 {
703-
new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_3, K);
704-
return new_fs_state, sampled_bits;
728+
} else {
729+
if n_samples == NUM_QUERIES_1 {
730+
new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_1, K);
731+
return new_fs_state, sampled_bits;
732+
} else {
733+
if n_samples == NUM_QUERIES_2 {
734+
new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_2, K);
735+
return new_fs_state, sampled_bits;
736+
} else {
737+
if n_samples == NUM_QUERIES_3 {
738+
new_fs_state, sampled_bits = sample_bits_const(fs_state, NUM_QUERIES_3, K);
739+
return new_fs_state, sampled_bits;
740+
}
741+
}
742+
}
705743
}
706744
print(n_samples);
707745
print(999333);
@@ -861,26 +899,25 @@ fn fs_sample_helper(fs_state, n, res) -> 1 {
861899
new_fs_state[2] = fs_state[2];
862900
new_fs_state[3] = output_buffer_size - n;
863901
return new_fs_state;
864-
}
902+
} else {
903+
// duplexing
904+
l_r = malloc_vec(2);
905+
poseidon16(fs_state[1], fs_state[2], l_r, PERMUTATION);
906+
new_fs_state = malloc(4);
907+
new_fs_state[0] = fs_state[0];
908+
new_fs_state[1] = l_r;
909+
new_fs_state[2] = l_r + 1;
910+
new_fs_state[3] = 8; // output_buffer_size
865911

866-
// duplexing
867-
l_r = malloc_vec(2);
868-
poseidon16(fs_state[1], fs_state[2], l_r, PERMUTATION);
869-
new_fs_state = malloc(4);
870-
new_fs_state[0] = fs_state[0];
871-
new_fs_state[1] = l_r;
872-
new_fs_state[2] = l_r + 1;
873-
new_fs_state[3] = 8; // output_buffer_size
912+
remaining = n - output_buffer_size;
913+
if remaining == 0 {
914+
return new_fs_state;
915+
}
874916

875-
remaining = n - output_buffer_size;
876-
if remaining == 0 {
877-
return new_fs_state;
917+
shifted_res = res + output_buffer_size;
918+
final_res = fs_sample_helper(new_fs_state, remaining, shifted_res);
919+
return final_res;
878920
}
879-
880-
shifted_res = res + output_buffer_size;
881-
final_res = fs_sample_helper(new_fs_state, remaining, shifted_res);
882-
return final_res;
883-
884921
}
885922

886923
fn fs_hint(fs_state, n) -> 2 {

crates/rec_aggregation/src/recursion.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ pub fn run_whir_recursion_benchmark(tracing: bool, n_recursions: usize) {
155155
(&public_input, &[]),
156156
whir_config_builder(),
157157
no_vec_runtime_memory,
158-
false,
158+
true,
159159
(&vec![], &vec![]), // TODO precompute poseidons
160160
merkle_path_hints,
161161
);

0 commit comments

Comments
 (0)