Skip to content

Commit 681e6fd

Browse files
authored
Poseidon GKR: Move the matrix multiplications out of the sumchecks (#84)
* wip * full rounds OK * fix weird memory slowdown * display proof size GKR / PCS * partial rounds OK * allow N_COMMITED_CUBES = 0 * fix * graphs * branch * clippy * typo --------- Co-authored-by: Tom Wambsgans <[email protected]>
1 parent de09723 commit 681e6fd

File tree

17 files changed

+654
-552
lines changed

17 files changed

+654
-552
lines changed

Cargo.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/lean_prover/src/common.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ use poseidon_circuit::{PoseidonGKRLayers, default_cube_layers};
99
use crate::*;
1010
use lean_vm::*;
1111

12-
pub(crate) const N_COMMITED_CUBES_P16: usize = KOALABEAR_RC16_INTERNAL.len() - 1;
13-
pub(crate) const N_COMMITED_CUBES_P24: usize = KOALABEAR_RC24_INTERNAL.len() - 1;
12+
pub(crate) const N_COMMITED_CUBES_P16: usize = KOALABEAR_RC16_INTERNAL.len() - 2;
13+
pub(crate) const N_COMMITED_CUBES_P24: usize = KOALABEAR_RC24_INTERNAL.len() - 2;
1414

1515
pub fn get_base_dims(
1616
n_cycles: usize,

crates/lean_prover/witness_generation/src/poseidon_tables.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use p3_field::PrimeCharacteristicRing;
66
use p3_koala_bear::{KoalaBearInternalLayerParameters, KoalaBearParameters};
77
use p3_monty_31::InternalLayerBaseParameters;
88
use poseidon_circuit::{PoseidonGKRLayers, PoseidonWitness, generate_poseidon_witness};
9+
use tracing::instrument;
910
use utils::{padd_with_zero_to_next_power_of_two, transposed_par_iter_mut};
1011

1112
pub fn all_poseidon_16_indexes(poseidons_16: &[WitnessPoseidon16]) -> [Vec<F>; 3] {
@@ -84,6 +85,7 @@ pub fn full_poseidon_indexes_poly(
8485
all_poseidon_indexes
8586
}
8687

88+
#[instrument(skip_all)]
8789
pub fn generate_poseidon_witness_helper<
8890
const WIDTH: usize,
8991
const N_COMMITED_CUBES: usize,
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
use multilinear_toolkit::prelude::*;
2+
use p3_field::ExtensionField;
3+
4+
use crate::{EF, F};
5+
6+
#[derive(Debug)]
7+
pub struct CompressionComputation<const WIDTH: usize> {
8+
pub compressed_output: usize,
9+
}
10+
11+
impl<NF: ExtensionField<F>, const WIDTH: usize> SumcheckComputation<NF, EF>
12+
for CompressionComputation<WIDTH>
13+
where
14+
EF: ExtensionField<NF>,
15+
{
16+
fn degree(&self) -> usize {
17+
2
18+
}
19+
20+
fn eval(&self, point: &[NF], alpha_powers: &[EF]) -> EF {
21+
debug_assert_eq!(point.len(), WIDTH + 1);
22+
let mut res = EF::ZERO;
23+
let compressed = point[WIDTH];
24+
for i in 0..self.compressed_output {
25+
res += alpha_powers[i] * point[i];
26+
}
27+
for i in self.compressed_output..WIDTH {
28+
res += alpha_powers[i] * point[i] * (EF::ONE - compressed);
29+
}
30+
31+
res
32+
}
33+
}
34+
35+
impl<const WIDTH: usize> SumcheckComputationPacked<EF> for CompressionComputation<WIDTH> {
36+
fn degree(&self) -> usize {
37+
2
38+
}
39+
40+
fn eval_packed_base(&self, point: &[PFPacking<EF>], alpha_powers: &[EF]) -> EFPacking<EF> {
41+
debug_assert_eq!(point.len(), WIDTH + 1);
42+
let mut res = EFPacking::<EF>::ZERO;
43+
let compressed = point[WIDTH];
44+
for i in 0..self.compressed_output {
45+
res += EFPacking::<EF>::from(alpha_powers[i]) * point[i];
46+
}
47+
for i in self.compressed_output..WIDTH {
48+
res += EFPacking::<EF>::from(alpha_powers[i])
49+
* point[i]
50+
* (PFPacking::<EF>::ONE - compressed);
51+
}
52+
53+
res
54+
}
55+
56+
fn eval_packed_extension(&self, point: &[EFPacking<EF>], alpha_powers: &[EF]) -> EFPacking<EF> {
57+
debug_assert_eq!(point.len(), WIDTH + 1);
58+
let mut res = EFPacking::<EF>::ZERO;
59+
let compressed = point[WIDTH];
60+
for i in 0..self.compressed_output {
61+
res += point[i] * alpha_powers[i];
62+
}
63+
for i in self.compressed_output..WIDTH {
64+
res += point[i] * (EFPacking::<EF>::ONE - compressed) * alpha_powers[i];
65+
}
66+
67+
res
68+
}
69+
}

crates/poseidon_circuit/src/gkr_layers/full_round.rs

Lines changed: 16 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,144 +1,55 @@
1-
use std::array;
2-
31
use multilinear_toolkit::prelude::*;
42
use p3_field::ExtensionField;
5-
use p3_koala_bear::{
6-
GenericPoseidon2LinearLayersKoalaBear, KoalaBearInternalLayerParameters, KoalaBearParameters,
7-
};
3+
use p3_koala_bear::{KoalaBearInternalLayerParameters, KoalaBearParameters};
84
use p3_monty_31::InternalLayerBaseParameters;
9-
use p3_poseidon2::GenericPoseidon2LinearLayers;
105

116
use crate::{EF, F};
127

138
#[derive(Debug)]
14-
pub struct FullRoundComputation<const WIDTH: usize, const FIRST: bool> {
15-
pub constants: [F; WIDTH],
16-
pub compressed_output: Option<usize>,
17-
}
9+
pub struct FullRoundComputation<const WIDTH: usize> {}
1810

19-
impl<NF: ExtensionField<F>, const WIDTH: usize, const FIRST: bool> SumcheckComputation<NF, EF>
20-
for FullRoundComputation<WIDTH, FIRST>
11+
impl<NF: ExtensionField<F>, const WIDTH: usize> SumcheckComputation<NF, EF>
12+
for FullRoundComputation<WIDTH>
2113
where
2214
KoalaBearInternalLayerParameters: InternalLayerBaseParameters<KoalaBearParameters, WIDTH>,
2315
EF: ExtensionField<NF>,
2416
{
2517
fn degree(&self) -> usize {
26-
3 + self.compressed_output.is_some() as usize
18+
3
2719
}
2820

2921
fn eval(&self, point: &[NF], alpha_powers: &[EF]) -> EF {
30-
debug_assert_eq!(
31-
point.len(),
32-
WIDTH
33-
+ if self.compressed_output.is_some() {
34-
1
35-
} else {
36-
0
37-
}
38-
);
39-
let mut buff: [NF; WIDTH] = array::from_fn(|j| point[j]);
40-
if FIRST {
41-
GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(&mut buff);
42-
}
43-
buff.iter_mut().enumerate().for_each(|(j, val)| {
44-
*val = (*val + self.constants[j]).cube();
45-
});
46-
GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(&mut buff);
22+
debug_assert_eq!(point.len(), WIDTH);
4723
let mut res = EF::ZERO;
48-
if let Some(compression_output_width) = self.compressed_output {
49-
let compressed = point[WIDTH];
50-
for i in 0..compression_output_width {
51-
res += alpha_powers[i] * buff[i];
52-
}
53-
for i in compression_output_width..WIDTH {
54-
res += alpha_powers[i] * buff[i] * (EF::ONE - compressed);
55-
}
56-
} else {
57-
for i in 0..WIDTH {
58-
res += alpha_powers[i] * buff[i];
59-
}
24+
for i in 0..WIDTH {
25+
res += alpha_powers[i] * point[i].cube();
6026
}
6127
res
6228
}
6329
}
6430

65-
impl<const WIDTH: usize, const FIRST: bool> SumcheckComputationPacked<EF>
66-
for FullRoundComputation<WIDTH, FIRST>
31+
impl<const WIDTH: usize> SumcheckComputationPacked<EF> for FullRoundComputation<WIDTH>
6732
where
6833
KoalaBearInternalLayerParameters: InternalLayerBaseParameters<KoalaBearParameters, WIDTH>,
6934
{
7035
fn degree(&self) -> usize {
71-
3 + self.compressed_output.is_some() as usize
36+
3
7237
}
7338

7439
fn eval_packed_base(&self, point: &[PFPacking<EF>], alpha_powers: &[EF]) -> EFPacking<EF> {
75-
debug_assert_eq!(
76-
point.len(),
77-
WIDTH
78-
+ if self.compressed_output.is_some() {
79-
1
80-
} else {
81-
0
82-
}
83-
);
84-
let mut buff: [PFPacking<EF>; WIDTH] = array::from_fn(|j| point[j]);
85-
if FIRST {
86-
GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(&mut buff);
87-
}
88-
buff.iter_mut().enumerate().for_each(|(j, val)| {
89-
*val = (*val + self.constants[j]).cube();
90-
});
91-
GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(&mut buff);
40+
debug_assert_eq!(point.len(), WIDTH);
9241
let mut res = EFPacking::<EF>::ZERO;
93-
if let Some(compression_output_width) = self.compressed_output {
94-
let compressed = point[WIDTH];
95-
for i in 0..compression_output_width {
96-
res += EFPacking::<EF>::from(alpha_powers[i]) * buff[i];
97-
}
98-
for i in compression_output_width..WIDTH {
99-
res += EFPacking::<EF>::from(alpha_powers[i])
100-
* buff[i]
101-
* (PFPacking::<EF>::ONE - compressed);
102-
}
103-
} else {
104-
for j in 0..WIDTH {
105-
res += EFPacking::<EF>::from(alpha_powers[j]) * buff[j];
106-
}
42+
for i in 0..WIDTH {
43+
res += EFPacking::<EF>::from(alpha_powers[i]) * point[i].cube();
10744
}
10845
res
10946
}
11047

11148
fn eval_packed_extension(&self, point: &[EFPacking<EF>], alpha_powers: &[EF]) -> EFPacking<EF> {
112-
debug_assert_eq!(
113-
point.len(),
114-
WIDTH
115-
+ if self.compressed_output.is_some() {
116-
1
117-
} else {
118-
0
119-
}
120-
);
121-
let mut buff: [EFPacking<EF>; WIDTH] = array::from_fn(|j| point[j]);
122-
if FIRST {
123-
GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(&mut buff);
124-
}
125-
buff.iter_mut().enumerate().for_each(|(j, val)| {
126-
*val = (*val + PFPacking::<EF>::from(self.constants[j])).cube();
127-
});
128-
GenericPoseidon2LinearLayersKoalaBear::external_linear_layer(&mut buff);
49+
debug_assert_eq!(point.len(), WIDTH);
12950
let mut res = EFPacking::<EF>::ZERO;
130-
if let Some(compression_output_width) = self.compressed_output {
131-
let compressed = point[WIDTH];
132-
for i in 0..compression_output_width {
133-
res += buff[i] * alpha_powers[i];
134-
}
135-
for i in compression_output_width..WIDTH {
136-
res += buff[i] * (EFPacking::<EF>::ONE - compressed) * alpha_powers[i];
137-
}
138-
} else {
139-
for j in 0..WIDTH {
140-
res += buff[j] * alpha_powers[j];
141-
}
51+
for i in 0..WIDTH {
52+
res += point[i].cube() * alpha_powers[i];
14253
}
14354
res
14455
}

crates/poseidon_circuit/src/gkr_layers/mod.rs

Lines changed: 23 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ pub use partial_round::*;
77
mod batch_partial_rounds;
88
pub use batch_partial_rounds::*;
99

10+
mod compression;
11+
pub use compression::*;
12+
1013
use p3_koala_bear::{
1114
KOALABEAR_RC16_EXTERNAL_FINAL, KOALABEAR_RC16_EXTERNAL_INITIAL, KOALABEAR_RC16_INTERNAL,
1215
KOALABEAR_RC24_EXTERNAL_FINAL, KOALABEAR_RC24_EXTERNAL_INITIAL, KOALABEAR_RC24_INTERNAL,
@@ -16,11 +19,11 @@ use crate::F;
1619

1720
#[derive(Debug)]
1821
pub struct PoseidonGKRLayers<const WIDTH: usize, const N_COMMITED_CUBES: usize> {
19-
pub initial_full_round: FullRoundComputation<WIDTH, true>,
20-
pub initial_full_rounds_remaining: Vec<FullRoundComputation<WIDTH, false>>,
21-
pub batch_partial_rounds: BatchPartialRounds<WIDTH, N_COMMITED_CUBES>,
22-
pub partial_rounds_remaining: Vec<PartialRoundComputation<WIDTH>>,
23-
pub final_full_rounds: Vec<FullRoundComputation<WIDTH, false>>,
22+
pub initial_full_rounds: Vec<[F; WIDTH]>,
23+
pub batch_partial_rounds: Option<BatchPartialRounds<WIDTH, N_COMMITED_CUBES>>,
24+
pub partial_rounds_remaining: Vec<F>,
25+
pub final_full_rounds: Vec<[F; WIDTH]>,
26+
pub compressed_output: Option<usize>,
2427
}
2528

2629
impl<const WIDTH: usize, const N_COMMITED_CUBES: usize> PoseidonGKRLayers<WIDTH, N_COMMITED_CUBES> {
@@ -54,43 +57,26 @@ impl<const WIDTH: usize, const N_COMMITED_CUBES: usize> PoseidonGKRLayers<WIDTH,
5457
final_constants: &[[F; WIDTH]],
5558
compressed_output: Option<usize>,
5659
) -> Self {
57-
let initial_full_round = FullRoundComputation {
58-
constants: initial_constants[0],
59-
compressed_output: None,
60-
};
61-
let initial_full_rounds_remaining = initial_constants[1..]
62-
.iter()
63-
.map(|&constants| FullRoundComputation {
64-
constants,
65-
compressed_output: None,
66-
})
67-
.collect::<Vec<_>>();
68-
let batch_partial_rounds = BatchPartialRounds {
69-
constants: internal_constants[..N_COMMITED_CUBES].try_into().unwrap(),
70-
last_constant: internal_constants[N_COMMITED_CUBES],
60+
assert!(N_COMMITED_CUBES < internal_constants.len() - 1); // TODO we could go up to internal_constants.len() in theory
61+
let initial_full_rounds = initial_constants.to_vec();
62+
let (batch_partial_rounds, partial_rounds_remaining) = if N_COMMITED_CUBES == 0 {
63+
(None, internal_constants.to_vec())
64+
} else {
65+
(
66+
Some(BatchPartialRounds {
67+
constants: internal_constants[..N_COMMITED_CUBES].try_into().unwrap(),
68+
last_constant: internal_constants[N_COMMITED_CUBES],
69+
}),
70+
internal_constants[N_COMMITED_CUBES + 1..].to_vec(),
71+
)
7172
};
72-
let partial_rounds_remaining = internal_constants[N_COMMITED_CUBES + 1..]
73-
.iter()
74-
.map(|&constant| PartialRoundComputation { constant })
75-
.collect::<Vec<_>>();
76-
let final_full_rounds = final_constants
77-
.iter()
78-
.enumerate()
79-
.map(|(i, &constants)| FullRoundComputation {
80-
constants,
81-
compressed_output: if i == final_constants.len() - 1 {
82-
compressed_output
83-
} else {
84-
None
85-
},
86-
})
87-
.collect::<Vec<_>>();
73+
let final_full_rounds = final_constants.to_vec();
8874
Self {
89-
initial_full_round,
90-
initial_full_rounds_remaining,
75+
initial_full_rounds,
9176
batch_partial_rounds,
9277
partial_rounds_remaining,
9378
final_full_rounds,
79+
compressed_output,
9480
}
9581
}
9682
}

0 commit comments

Comments
 (0)