Skip to content

Commit fdb808b

Browse files
fix: protect against weak Fiat-Shamir (#54)
* fix: protect against weak Fiat-Shamir * add log * feat: switch to row hash to avoid DFT overhead * feat: switch to bincode for stability * use inner pattern for code quality * Revert "feat: switch to bincode for stability" vkey size increases a lot This reverts commit 1e1f7ad.
1 parent 3bf6092 commit fdb808b

File tree

10 files changed

+70
-8
lines changed

10 files changed

+70
-8
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ async-trait = "0.1.83"
108108
getset = "0.1.3"
109109
rand = { version = "0.8.5", default-features = false }
110110
hex = { version = "0.4.3", default-features = false }
111+
bitcode = "0.6.5"
111112

112113
# default-features = false for no_std
113114
itertools = { version = "0.14.0", default-features = false }

crates/stark-backend/Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ derive-new.workspace = true
2828
metrics = { workspace = true, optional = true }
2929
cfg-if.workspace = true
3030
thiserror.workspace = true
31-
async-trait.workspace = true
3231
rustc-hash.workspace = true
32+
bitcode = { workspace = true, features = ["serde"] }
3333

3434
[target.'cfg(unix)'.dependencies]
3535
tikv-jemallocator = { version = "0.6", optional = true }

crates/stark-backend/src/keygen/mod.rs

+25-2
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@ use std::{collections::HashMap, iter::zip, sync::Arc};
22

33
use itertools::Itertools;
44
use p3_commit::Pcs;
5-
use p3_field::{Field, FieldExtensionAlgebra};
6-
use p3_matrix::Matrix;
5+
use p3_field::{Field, FieldAlgebra, FieldExtensionAlgebra};
6+
use p3_matrix::{dense::RowMajorMatrix, Matrix};
77
use tracing::instrument;
8+
use types::MultiStarkVerifyingKey0;
89

910
use crate::{
1011
air_builders::symbolic::{get_symbolic_builder, SymbolicRapBuilder},
@@ -195,11 +196,33 @@ impl<'a, SC: StarkGenericConfig> MultiStarkKeygenBuilder<'a, SC> {
195196
threshold: log_up_security_params.max_interaction_count,
196197
});
197198

199+
let pre_vk: MultiStarkVerifyingKey0<SC> = MultiStarkVerifyingKey0 {
200+
per_air: pk_per_air.iter().map(|pk| pk.vk.clone()).collect(),
201+
trace_height_constraints: trace_height_constraints.clone(),
202+
log_up_pow_bits: log_up_security_params.log_up_pow_bits,
203+
};
204+
// To protect against weak Fiat-Shamir, we hash the "pre"-verifying key and include it in the
205+
// final verifying key. This just needs to commit to the verifying key and does not need to be
206+
// verified by the verifier, so we just use bincode to serialize it.
207+
let vk_bytes = bitcode::serialize(&pre_vk).unwrap();
208+
tracing::info!("pre-vkey: {} bytes", vk_bytes.len());
209+
// Purely to get type compatibility and convenience, we hash using pcs.commit as a single row
210+
let vk_as_row = RowMajorMatrix::new_row(
211+
vk_bytes
212+
.into_iter()
213+
.map(Val::<SC>::from_canonical_u8)
214+
.collect(),
215+
);
216+
let pcs = self.config.pcs();
217+
let deg_1_domain = pcs.natural_domain_for_degree(1);
218+
let (vk_pre_hash, _) = pcs.commit(vec![(deg_1_domain, vk_as_row)]);
219+
198220
MultiStarkProvingKey {
199221
per_air: pk_per_air,
200222
trace_height_constraints,
201223
max_constraint_degree: self.max_constraint_degree,
202224
log_up_pow_bits: log_up_security_params.log_up_pow_bits,
225+
vk_pre_hash,
203226
}
204227
}
205228
}

crates/stark-backend/src/keygen/types.rs

+25
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,22 @@ pub struct StarkVerifyingKey<Val, Com> {
8484
deserialize = "Com<SC>: Deserialize<'de>"
8585
))]
8686
pub struct MultiStarkVerifyingKey<SC: StarkGenericConfig> {
87+
/// All parts of the verifying key needed by the verifier, except
88+
/// the `pre_hash` used to initialize the Fiat-Shamir transcript.
89+
pub inner: MultiStarkVerifyingKey0<SC>,
90+
/// The hash of all other parts of the verifying key. The Fiat-Shamir hasher will
91+
/// initialize by observing this hash.
92+
pub pre_hash: Com<SC>,
93+
}
94+
95+
/// Everything in [MultiStarkVerifyingKey] except the `pre_hash` used to initialize the Fiat-Shamir transcript.
96+
#[derive(Derivative, Serialize, Deserialize)]
97+
#[derivative(Clone(bound = "Com<SC>: Clone"))]
98+
#[serde(bound(
99+
serialize = "Com<SC>: Serialize",
100+
deserialize = "Com<SC>: Deserialize<'de>"
101+
))]
102+
pub struct MultiStarkVerifyingKey0<SC: StarkGenericConfig> {
87103
pub per_air: Vec<StarkVerifyingKey<Val<SC>, Com<SC>>>,
88104
pub trace_height_constraints: Vec<LinearConstraint>,
89105
pub log_up_pow_bits: usize,
@@ -129,6 +145,8 @@ pub struct MultiStarkProvingKey<SC: StarkGenericConfig> {
129145
/// Maximum degree of constraints across all AIRs
130146
pub max_constraint_degree: usize,
131147
pub log_up_pow_bits: usize,
148+
/// See [MultiStarkVerifyingKey]
149+
pub vk_pre_hash: Com<SC>,
132150
}
133151

134152
impl<Val, Com> StarkVerifyingKey<Val, Com> {
@@ -148,6 +166,13 @@ impl<Val, Com> StarkVerifyingKey<Val, Com> {
148166
impl<SC: StarkGenericConfig> MultiStarkProvingKey<SC> {
149167
pub fn get_vk(&self) -> MultiStarkVerifyingKey<SC> {
150168
MultiStarkVerifyingKey {
169+
inner: self.get_vk0(),
170+
pre_hash: self.vk_pre_hash.clone(),
171+
}
172+
}
173+
174+
fn get_vk0(&self) -> MultiStarkVerifyingKey0<SC> {
175+
MultiStarkVerifyingKey0 {
151176
per_air: self.per_air.iter().map(|pk| pk.vk.clone()).collect(),
152177
trace_height_constraints: self.trace_height_constraints.clone(),
153178
log_up_pow_bits: self.log_up_pow_bits,

crates/stark-backend/src/keygen/view.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,19 @@ pub(crate) struct MultiStarkVerifyingKeyView<'a, Val, Com> {
1111
/// Trace height constraints are *not* filtered by AIR. When computing the dot product, this
1212
/// will be indexed into by air_id.
1313
pub trace_height_constraints: &'a [LinearConstraint],
14+
pub pre_hash: Com,
1415
}
1516

1617
impl<SC: StarkGenericConfig> MultiStarkVerifyingKey<SC> {
1718
/// Returns a view with all airs.
1819
pub(crate) fn full_view(&self) -> MultiStarkVerifyingKeyView<Val<SC>, Com<SC>> {
19-
self.view(&(0..self.per_air.len()).collect_vec())
20+
self.view(&(0..self.inner.per_air.len()).collect_vec())
2021
}
2122
pub(crate) fn view(&self, air_ids: &[usize]) -> MultiStarkVerifyingKeyView<Val<SC>, Com<SC>> {
2223
MultiStarkVerifyingKeyView {
23-
per_air: air_ids.iter().map(|&id| &self.per_air[id]).collect(),
24-
trace_height_constraints: &self.trace_height_constraints,
24+
per_air: air_ids.iter().map(|&id| &self.inner.per_air[id]).collect(),
25+
trace_height_constraints: &self.inner.trace_height_constraints,
26+
pre_hash: self.pre_hash.clone(),
2527
}
2628
}
2729
}

crates/stark-backend/src/prover/coordinator.rs

+2
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ where
8484
#[cfg(feature = "bench-metrics")]
8585
let start = std::time::Instant::now();
8686
assert!(mpk.validate(&ctx), "Invalid proof input");
87+
self.challenger.observe(mpk.vk_pre_hash.clone());
8788

8889
let num_air = ctx.per_air.len();
8990
info!(num_air);
@@ -307,6 +308,7 @@ impl<'a, PB: ProverBackend> DeviceMultiStarkProvingKey<'a, PB> {
307308
MultiStarkVerifyingKeyView::new(
308309
self.per_air.iter().map(|pk| pk.vk).collect(),
309310
&self.trace_height_constraints,
311+
self.vk_pre_hash.clone(),
310312
)
311313
}
312314
}

crates/stark-backend/src/prover/cpu/mod.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,12 @@ where
443443
}
444444
})
445445
.collect();
446-
DeviceMultiStarkProvingKey::new(air_ids, per_air, mpk.trace_height_constraints.clone())
446+
DeviceMultiStarkProvingKey::new(
447+
air_ids,
448+
per_air,
449+
mpk.trace_height_constraints.clone(),
450+
mpk.vk_pre_hash.clone(),
451+
)
447452
}
448453
fn transport_matrix_to_device(
449454
&self,

crates/stark-backend/src/prover/types.rs

+3
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,22 @@ pub struct DeviceMultiStarkProvingKey<'a, PB: ProverBackend> {
1919
/// Each [LinearConstraint] is indexed by AIR ID.
2020
/// **Caution**: the linear constraints are **not** filtered for only the AIRs appearing in `per_air`.
2121
pub trace_height_constraints: Vec<LinearConstraint>,
22+
pub vk_pre_hash: PB::Commitment,
2223
}
2324

2425
impl<'a, PB: ProverBackend> DeviceMultiStarkProvingKey<'a, PB> {
2526
pub fn new(
2627
air_ids: Vec<usize>,
2728
per_air: Vec<DeviceStarkProvingKey<'a, PB>>,
2829
trace_height_constraints: Vec<LinearConstraint>,
30+
vk_pre_hash: PB::Commitment,
2931
) -> Self {
3032
assert_eq!(air_ids.len(), per_air.len());
3133
Self {
3234
air_ids,
3335
per_air,
3436
trace_height_constraints,
37+
vk_pre_hash,
3538
}
3639
}
3740
}

crates/stark-backend/src/verifier/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ impl<'c, SC: StarkGenericConfig> MultiTraceStarkVerifier<'c, SC> {
5757
mvk: &MultiStarkVerifyingKeyView<Val<SC>, Com<SC>>,
5858
proof: &Proof<SC>,
5959
) -> Result<(), VerificationError> {
60+
challenger.observe(mvk.pre_hash.clone());
6061
// Enforce trace height linear inequalities
6162
for constraint in mvk.trace_height_constraints {
6263
let sum = proof

crates/stark-backend/tests/interaction/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ fn test_interaction_trace_height_constraints() {
4646
keygen_builder.add_air(Arc::new(sender_air_2));
4747
keygen_builder.add_air(Arc::new(sender_air_3));
4848
let pk = keygen_builder.generate_pk();
49-
let vk = pk.get_vk();
49+
let vk = pk.get_vk().inner;
5050

5151
assert_eq!(vk.trace_height_constraints.len(), 3);
5252

0 commit comments

Comments
 (0)