Skip to content

Commit 83185c9

Browse files
committed
switch to SecurityConfig in gpu_prover
1 parent 8747a49 commit 83185c9

File tree

9 files changed

+120
-65
lines changed

9 files changed

+120
-65
lines changed

gpu_prover/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ prover = { workspace = true }
2323
riscv_transpiler = { workspace = true, features = ["jit"] }
2424
setups = { workspace = true }
2525
trace_and_split = { workspace = true }
26-
verifier_common = { workspace = true }
26+
verifier_common = { workspace = true, features = ["proof_utils"] }
2727
worker = { workspace = true }
2828
itertools = { workspace = true }
2929
log = { workspace = true }

gpu_prover/src/circuit_type.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::machine_type::MachineType;
22
use prover::definitions::OPTIMAL_FOLDING_PROPERTIES;
3+
use prover::prover_stages::ProofSecurityConfig;
34
use setups::{
45
add_sub_lui_auipc_mop, bigint_with_control, blake2_with_compression, inits_and_teardowns,
56
jump_branch_slt, keccak_special5, load_store_subword_only, load_store_word_only, mul_div,
@@ -75,6 +76,13 @@ impl CircuitType {
7576
// Self::Unrolled(unrolled_type) => unrolled_type.get_tree_cap_size(),
7677
// }
7778
}
79+
80+
pub fn get_security_config(&self) -> ProofSecurityConfig {
81+
match self {
82+
Self::Delegation(delegation_type) => delegation_type.get_security_config(),
83+
Self::Unrolled(unrolled_type) => unrolled_type.get_security_config(),
84+
}
85+
}
7886
}
7987

8088
#[repr(u32)]
@@ -155,6 +163,18 @@ impl DelegationCircuitType {
155163
MachineType::Reduced => &[DelegationCircuitType::Blake2WithCompression],
156164
}
157165
}
166+
167+
pub fn get_security_config(&self) -> ProofSecurityConfig {
168+
match self {
169+
Self::BigIntWithControl => {
170+
get_security_config::<{ bigint_with_control::DOMAIN_SIZE }>()
171+
}
172+
Self::Blake2WithCompression => {
173+
get_security_config::<{ blake2_with_compression::DOMAIN_SIZE }>()
174+
}
175+
Self::KeccakSpecial5 => get_security_config::<{ keccak_special5::DOMAIN_SIZE }>(),
176+
}
177+
}
158178
}
159179

160180
impl From<u16> for DelegationCircuitType {
@@ -272,6 +292,17 @@ impl UnrolledCircuitType {
272292
.iter()
273293
.map(|id| DelegationCircuitType::from(*id as u16))
274294
}
295+
296+
pub fn get_security_config(&self) -> ProofSecurityConfig {
297+
match self {
298+
Self::InitsAndTeardowns => {
299+
get_security_config::<{ inits_and_teardowns::DOMAIN_SIZE }>()
300+
}
301+
Self::Memory(circuit_type) => circuit_type.get_security_config(),
302+
Self::NonMemory(circuit_type) => circuit_type.get_security_config(),
303+
Self::Unified => get_security_config::<{ unified_reduced_machine::DOMAIN_SIZE }>(),
304+
}
305+
}
275306
}
276307

277308
#[repr(u32)]
@@ -361,6 +392,17 @@ impl UnrolledMemoryCircuitType {
361392
},
362393
}
363394
}
395+
396+
pub fn get_security_config(&self) -> ProofSecurityConfig {
397+
match self {
398+
Self::LoadStoreSubwordOnly => {
399+
get_security_config::<{ load_store_subword_only::DOMAIN_SIZE }>()
400+
}
401+
Self::LoadStoreWordOnly => {
402+
get_security_config::<{ load_store_word_only::DOMAIN_SIZE }>()
403+
}
404+
}
405+
}
364406
}
365407

366408
#[repr(u32)]
@@ -516,6 +558,18 @@ impl UnrolledNonMemoryCircuitType {
516558
Self::ShiftBinaryCsr => 4,
517559
}
518560
}
561+
562+
pub fn get_security_config(&self) -> ProofSecurityConfig {
563+
match self {
564+
Self::AddSubLuiAuipcMop => {
565+
get_security_config::<{ add_sub_lui_auipc_mop::DOMAIN_SIZE }>()
566+
}
567+
Self::JumpBranchSlt => get_security_config::<{ jump_branch_slt::DOMAIN_SIZE }>(),
568+
Self::MulDiv => get_security_config::<{ mul_div::DOMAIN_SIZE }>(),
569+
Self::MulDivUnsigned => get_security_config::<{ mul_div_unsigned::DOMAIN_SIZE }>(),
570+
Self::ShiftBinaryCsr => get_security_config::<{ shift_binary_csr::DOMAIN_SIZE }>(),
571+
}
572+
}
519573
}
520574

521575
#[inline(always)]
@@ -534,3 +588,21 @@ pub const fn get_tree_cap_size_for_log_domain_size(log_domain_size: u32) -> usiz
534588
pub const fn get_log_tree_cap_size_for_log_domain_size(log_domain_size: u32) -> usize {
535589
OPTIMAL_FOLDING_PROPERTIES[log_domain_size as usize].total_caps_size_log2
536590
}
591+
592+
const fn get_num_foldings<const DOMAIN_SIZE: usize>() -> usize {
593+
assert!(DOMAIN_SIZE.is_power_of_two());
594+
OPTIMAL_FOLDING_PROPERTIES[DOMAIN_SIZE.trailing_zeros() as usize]
595+
.folding_sequence
596+
.len()
597+
}
598+
599+
fn get_security_config<const DOMAIN_SIZE: usize>() -> ProofSecurityConfig
600+
where
601+
[(); get_num_foldings::<DOMAIN_SIZE>()]:,
602+
{
603+
assert!(DOMAIN_SIZE.is_power_of_two());
604+
let config = verifier_common::SizedProofSecurityConfig::<{
605+
get_num_foldings::<DOMAIN_SIZE>()
606+
}>::worst_case_config();
607+
config.for_prover()
608+
}

gpu_prover/src/execution/gpu_worker.rs

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ use std::ffi::CStr;
2323
use std::mem;
2424
use std::ops::Deref;
2525
use std::process::exit;
26-
use verifier_common::num_queries_for_security_params;
2726

2827
pub fn get_gpu_worker_func(
2928
device_id: i32,
@@ -245,11 +244,7 @@ fn gpu_worker(
245244
CircuitType::Delegation(delegation) => Some(delegation as u16),
246245
CircuitType::Unrolled(_) => None,
247246
};
248-
let num_queries = num_queries_for_security_params(
249-
verifier_common::SECURITY_BITS,
250-
verifier_common::POW_BITS,
251-
log_lde_factor as usize,
252-
);
247+
let security_config = circuit_type.get_security_config();
253248
let trees_cache_mode = get_trees_cache_mode(circuit_type, &context);
254249
trace!("BATCH[{batch_id}] GPU_WORKER[{device_id}] producing proof for circuit {circuit_type:?}[{sequence_id}]");
255250
let job = prove(
@@ -264,8 +259,7 @@ fn gpu_worker(
264259
&precomputations.lde_precomputations,
265260
delegation_processing_type,
266261
precomputations.lde_precomputations.lde_factor,
267-
num_queries,
268-
verifier_common::SECURITY_BITS,
262+
&security_config,
269263
None,
270264
false,
271265
trees_cache_mode,

gpu_prover/src/prover/proof.rs

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use prover::definitions::{
2626
};
2727
use prover::prover_stages::cached_data::ProverCachedData;
2828
use prover::prover_stages::unrolled_prover::UnrolledModeProof;
29-
use prover::prover_stages::{ProofPowChallenges, ProofPowConfig};
29+
use prover::prover_stages::{ProofPowChallenges, ProofSecurityConfig};
3030
use prover::transcript::Seed;
3131
use std::sync::Arc;
3232

@@ -81,8 +81,7 @@ pub(crate) fn prove<'a, A: GoodAllocator>(
8181
lde_precomputations: &LdePrecomputations<impl GoodAllocator>,
8282
delegation_processing_type: Option<u16>,
8383
lde_factor: usize,
84-
num_queries: usize,
85-
security_bits: usize,
84+
security_config: &ProofSecurityConfig,
8685
external_pow_challenges: Option<ProofPowChallenges>,
8786
recompute_cosets: bool,
8887
trees_cache_mode: TreesCacheMode,
@@ -103,8 +102,6 @@ pub(crate) fn prove<'a, A: GoodAllocator>(
103102
assert!(trace_len.is_power_of_two());
104103
let log_domain_size = trace_len.trailing_zeros();
105104
let optimal_folding = OPTIMAL_FOLDING_PROPERTIES[log_domain_size as usize];
106-
let optimal_folding_sequence_len = optimal_folding.folding_sequence.len();
107-
let pow_config = ProofPowConfig::worst_case_config(security_bits, optimal_folding_sequence_len);
108105
let delegation_processing_type = delegation_processing_type.unwrap_or_default();
109106
let cached_data_values = ProverCachedData::new(
110107
&circuit,
@@ -194,7 +191,7 @@ pub(crate) fn prove<'a, A: GoodAllocator>(
194191
stage_2_range.start(stream)?;
195192
stage_2_output.generate(
196193
&mut seed,
197-
&pow_config,
194+
security_config,
198195
&external_pow_challenges,
199196
&circuit,
200197
is_unrolled,
@@ -213,7 +210,7 @@ pub(crate) fn prove<'a, A: GoodAllocator>(
213210
stage_3_range.start(stream)?;
214211
let mut stage_3_output = StageThreeOutput::new(
215212
&mut seed,
216-
&pow_config,
213+
security_config,
217214
&external_pow_challenges,
218215
&circuit,
219216
is_unrolled,
@@ -238,7 +235,7 @@ pub(crate) fn prove<'a, A: GoodAllocator>(
238235
stage_4_range.start(stream)?;
239236
let mut stage_4_output = StageFourOutput::new(
240237
&mut seed,
241-
&pow_config,
238+
security_config,
242239
&external_pow_challenges,
243240
&circuit,
244241
is_unrolled,
@@ -262,13 +259,12 @@ pub(crate) fn prove<'a, A: GoodAllocator>(
262259
stage_5_range.start(stream)?;
263260
let stage_5_output = StageFiveOutput::new(
264261
&mut seed,
265-
&pow_config,
262+
security_config,
266263
&external_pow_challenges,
267264
&mut stage_4_output,
268265
log_domain_size,
269266
log_lde_factor,
270267
&optimal_folding,
271-
num_queries,
272268
&lde_precomputations,
273269
&mut callbacks,
274270
context,
@@ -281,7 +277,7 @@ pub(crate) fn prove<'a, A: GoodAllocator>(
281277
let fri_queries_pow_range = device_tracing::Range::new("fri_queries_pow")?;
282278
fri_queries_pow_range.start(stream)?;
283279
let mut fri_queries_pow_challenge = unsafe { context.alloc_host_uninit::<u64>() };
284-
let fri_queries_pow_bits = pow_config.fri_queries_pow_bits;
280+
let fri_queries_pow_bits = security_config.fri_queries_pow_bits;
285281
assert_ne!(fri_queries_pow_bits, 0);
286282
search_pow_challenge(
287283
&mut seed,
@@ -310,7 +306,7 @@ pub(crate) fn prove<'a, A: GoodAllocator>(
310306
&stage_5_output,
311307
log_domain_size,
312308
log_lde_factor,
313-
num_queries,
309+
security_config.num_queries,
314310
&optimal_folding,
315311
&mut callbacks,
316312
context,

gpu_prover/src/prover/stage_2.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use era_cudart::slice::DeviceSlice;
2525
use field::{Field, FieldExtension};
2626
use prover::definitions::Transcript;
2727
use prover::prover_stages::cached_data::ProverCachedData;
28-
use prover::prover_stages::{ProofPowChallenges, ProofPowConfig};
28+
use prover::prover_stages::{ProofPowChallenges, ProofSecurityConfig};
2929
use prover::transcript::Seed;
3030
use std::slice;
3131

@@ -79,7 +79,7 @@ impl StageTwoOutput {
7979
pub fn generate(
8080
&mut self,
8181
seed: &mut HostAllocation<Seed>,
82-
pow_config: &ProofPowConfig,
82+
security_config: &ProofSecurityConfig,
8383
external_challenges: &Option<ProofPowChallenges>,
8484
circuit: &CompiledCircuitArtifact<BF>,
8585
is_unrolled: bool,
@@ -100,7 +100,7 @@ impl StageTwoOutput {
100100
let stream = context.get_exec_stream();
101101
let seed_accessor = seed.get_mut_accessor();
102102
let mut pow_challenge = unsafe { context.alloc_host_uninit::<u64>() };
103-
let pow_bits = pow_config.lookup_pow_bits;
103+
let pow_bits = security_config.lookup_pow_bits;
104104
search_pow_challenge(
105105
seed,
106106
&mut pow_challenge,

gpu_prover/src/prover/stage_3.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use field::FieldExtension;
2020
use prover::definitions::AuxArgumentsBoundaryValues;
2121
use prover::prover_stages::cached_data::ProverCachedData;
2222
use prover::prover_stages::stage3::AlphaPowersLayout;
23-
use prover::prover_stages::{ProofPowChallenges, ProofPowConfig};
23+
use prover::prover_stages::{ProofPowChallenges, ProofSecurityConfig};
2424
use prover::transcript::Seed;
2525
use std::alloc::Global;
2626
use std::slice;
@@ -34,7 +34,7 @@ pub(crate) struct StageThreeOutput {
3434
impl StageThreeOutput {
3535
pub fn new(
3636
seed: &mut HostAllocation<Seed>,
37-
pow_config: &ProofPowConfig,
37+
security_config: &ProofSecurityConfig,
3838
external_challenges: &Option<ProofPowChallenges>,
3939
circuit: &Arc<CompiledCircuitArtifact<BF>>,
4040
is_unrolled: bool,
@@ -69,7 +69,7 @@ impl StageThreeOutput {
6969
let stream = context.get_exec_stream();
7070
let seed_accessor = seed.get_mut_accessor();
7171
let mut pow_challenge = unsafe { context.alloc_host_uninit::<u64>() };
72-
let pow_bits = pow_config.quotient_alpha_pow_bits;
72+
let pow_bits = security_config.quotient_alpha_pow_bits;
7373
search_pow_challenge(
7474
seed,
7575
&mut pow_challenge,

gpu_prover/src/prover/stage_4.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ use field::{Field, FieldExtension};
3030
use itertools::Itertools;
3131
use prover::definitions::FoldingDescription;
3232
use prover::prover_stages::cached_data::ProverCachedData;
33-
use prover::prover_stages::{ProofPowChallenges, ProofPowConfig, Transcript};
33+
use prover::prover_stages::{ProofPowChallenges, ProofSecurityConfig, Transcript};
3434
use prover::transcript::Seed;
3535
use std::ops::DerefMut;
3636
use std::slice;
@@ -46,7 +46,7 @@ pub(crate) struct StageFourOutput {
4646
impl StageFourOutput {
4747
pub fn new(
4848
seed: &mut HostAllocation<Seed>,
49-
pow_config: &ProofPowConfig,
49+
security_config: &ProofSecurityConfig,
5050
external_challenges: &Option<ProofPowChallenges>,
5151
circuit: &Arc<CompiledCircuitArtifact<BF>>,
5252
is_unrolled: bool,
@@ -89,7 +89,7 @@ impl StageFourOutput {
8989
}
9090
let stream = context.get_exec_stream();
9191
let mut quotient_z_pow_challenge = unsafe { context.alloc_host_uninit::<u64>() };
92-
let quotient_z_pow_bits = pow_config.quotient_z_pow_bits;
92+
let quotient_z_pow_bits = security_config.quotient_z_pow_bits;
9393
search_pow_challenge(
9494
seed,
9595
&mut quotient_z_pow_challenge,
@@ -212,7 +212,7 @@ impl StageFourOutput {
212212
};
213213
callbacks.schedule(commit_values_at_z, stream)?;
214214
let mut deep_poly_alpha_pow_challenge = unsafe { context.alloc_host_uninit::<u64>() };
215-
let deep_poly_alpha_pow_bits = pow_config.deep_poly_alpha_pow_bits;
215+
let deep_poly_alpha_pow_bits = security_config.deep_poly_alpha_pow_bits;
216216
search_pow_challenge(
217217
seed,
218218
&mut deep_poly_alpha_pow_challenge,

gpu_prover/src/prover/stage_5.rs

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use fft::{
1919
use field::{Field, FieldExtension, Mersenne31Field};
2020
use itertools::Itertools;
2121
use prover::definitions::{FoldingDescription, Transcript};
22-
use prover::prover_stages::{ProofPowChallenges, ProofPowConfig};
22+
use prover::prover_stages::{ProofPowChallenges, ProofSecurityConfig};
2323
use prover::transcript::Seed;
2424
use std::iter;
2525

@@ -48,13 +48,12 @@ pub(crate) struct StageFiveOutput {
4848
impl StageFiveOutput {
4949
pub fn new<'a>(
5050
seed: &mut HostAllocation<Seed>,
51-
pow_config: &ProofPowConfig,
51+
security_config: &ProofSecurityConfig,
5252
external_challenges: &Option<ProofPowChallenges>,
5353
stage_4_output: &mut StageFourOutput,
5454
log_domain_size: u32,
5555
log_lde_factor: u32,
5656
folding_description: &FoldingDescription,
57-
num_queries: usize,
5857
lde_precomputations: &LdePrecomputations<impl GoodAllocator>,
5958
callbacks: &mut Callbacks<'a>,
6059
context: &ProverContext,
@@ -64,7 +63,7 @@ impl StageFiveOutput {
6463
let lde_factor = 1usize << log_lde_factor;
6564
let mut log_current_domain_size = log_domain_size;
6665
assert_eq!(
67-
pow_config.foldings_pow_bits.len(),
66+
security_config.foldings_pow_bits.len(),
6867
folding_description.folding_sequence.len()
6968
);
7069
let oracles_count = folding_description.folding_sequence.len() - 1;
@@ -107,7 +106,7 @@ impl StageFiveOutput {
107106
&fri_oracles[i - 1].ldes
108107
};
109108
let mut pow_challenge = unsafe { context.alloc_host_uninit::<u64>() };
110-
let pow_bits = pow_config.foldings_pow_bits[i];
109+
let pow_bits = security_config.foldings_pow_bits[i];
111110
search_pow_challenge(
112111
seed,
113112
&mut pow_challenge,
@@ -154,7 +153,10 @@ impl StageFiveOutput {
154153
}
155154
d_challenges.free();
156155
let expose_all_leafs = if i == oracles_count - 1 {
157-
let log_bound = num_queries.next_power_of_two().trailing_zeros();
156+
let log_bound = security_config
157+
.num_queries
158+
.next_power_of_two()
159+
.trailing_zeros();
158160
log_num_leafs + 1 - log_lde_factor <= log_bound
159161
} else {
160162
false
@@ -246,7 +248,7 @@ impl StageFiveOutput {
246248
);
247249
let final_monomials = {
248250
let mut pow_challenge = unsafe { context.alloc_host_uninit::<u64>() };
249-
let pow_bits = pow_config.foldings_pow_bits[oracles_count];
251+
let pow_bits = security_config.foldings_pow_bits[oracles_count];
250252
search_pow_challenge(
251253
seed,
252254
&mut pow_challenge,

0 commit comments

Comments
 (0)