Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
8a74c09
wip
robik75 Jan 23, 2026
9318b8b
benches
robik75 Jan 30, 2026
979e5de
wip
robik75 Feb 2, 2026
ff72c7b
fix(gpu_prover): fix blake2s tests
robik75 Feb 2, 2026
0b76678
wip
robik75 Feb 3, 2026
8105df3
compiles
Feb 4, 2026
00a4c9f
add device context with twiddles
Feb 5, 2026
0b2437b
stage for L4 test
Feb 5, 2026
3bee484
stage for ncu source
Feb 5, 2026
6bf2fa5
stage for full source profile
Feb 5, 2026
1290006
try kernel with multiple blocks per SM
Feb 6, 2026
5d7c9a0
coalesced experiment
Feb 10, 2026
7262513
register pipeline experiment
Feb 10, 2026
7accb9b
another experiment
Feb 10, 2026
3ceef15
pipeline compute and LDGSTS
Feb 10, 2026
05e8eda
producer-consumer kernel works
Feb 11, 2026
7de3426
fix ldgsts alignment
Feb 11, 2026
990a662
use cmem for all warp-uniform twiddles
Feb 13, 2026
c52931b
logn=24 with split cmem and smem twiddles works
Feb 16, 2026
bcaffed
logn=23 works
Feb 17, 2026
4cfa512
first 16 stages of 3-pass work, also faster than sppark on L4 :)
Feb 20, 2026
4dfd8b9
stage for 5090
Feb 20, 2026
2152dab
wip, debugging dataflow
Feb 20, 2026
2581e93
final stages experiments
Feb 21, 2026
1453be2
3 stage works, 7-8% faster than sppark on 5090
Feb 22, 2026
3c1a63a
first cut at logn=21,22,23 with same pattern
Feb 22, 2026
c1644ce
three-pass works for logn=21,22,23,24, also three and two-pass suppor…
Feb 23, 2026
afb2914
wip: monomials to evals
Feb 24, 2026
6b9b3bd
add option to write chunk-transposed monomials
Feb 26, 2026
a0ef88e
wip: monomials to evals
Feb 27, 2026
83591da
remove experiment graveyard. first b2n kernel works.
Feb 28, 2026
0826f08
fix silly bugs. 3 pass monomials to evals works end to end
Mar 1, 2026
b021738
fix output chaining in 3-pass, rustfmt
Mar 2, 2026
65f54ae
wip: 2-pass monomials to evals
Mar 2, 2026
3b88d42
2-pass monomials to evals dataflow works
Mar 3, 2026
4d8ce60
2-pass monomials to evals works for logn=23,24
Mar 3, 2026
7fb7f49
hacks
Mar 3, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ members = [
"execution_utils",

"gpu_prover",
"gpu_prover_new",
"gpu_prover_test",

"poseidon2",
Expand Down Expand Up @@ -171,6 +172,7 @@ execution_utils = { path = "./execution_utils", default-features = false }
witness_eval_generator = { path = "./witness_eval_generator", default-features = false }
gpu_witness_eval_generator = { path = "./gpu_witness_eval_generator" }
gpu_prover = { path = "./gpu_prover" }
gpu_prover_new = { path = "./gpu_prover_new" }

serde = { version = "1", default-features = false, features = ["derive", "alloc"] }
clap = { version = "4.5.21", features = ["derive"] }
Expand Down
4 changes: 2 additions & 2 deletions circuit_defs/circuit_common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ pub fn run_for_witness() -> () {
pub fn u32_from_field_elems(src: &[Mersenne31Field; 2]) -> u32 {
use field::PrimeField;

let low = u16::try_from(src[0].as_u64_reduced()).expect("read value is not 16 bit long") as u32;
let low = u16::try_from(src[0].as_u32_reduced()).expect("read value is not 16 bit long") as u32;
let high =
u16::try_from(src[1].as_u64_reduced()).expect("read value is not 16 bit long") as u32;
u16::try_from(src[1].as_u32_reduced()).expect("read value is not 16 bit long") as u32;
low + (high << 16)
}

Expand Down
147 changes: 82 additions & 65 deletions fft/src/column_major/naive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ pub fn serial_ct_ntt_natural_to_bitreversed<F: Field, E: Field + FieldExtension<
distance /= 2;
}

let mut stage = 1;
while num_groups < n {
debug_assert!(num_groups > 1);
let mut k = 0;
Expand Down Expand Up @@ -162,6 +163,8 @@ pub fn serial_ct_ntt_natural_to_bitreversed<F: Field, E: Field + FieldExtension<
pairs_per_group /= 2;
num_groups *= 2;
distance /= 2;
if stage == 23 { break; }
stage += 1;
}
}

Expand All @@ -184,6 +187,8 @@ pub fn serial_ct_ntt_bitreversed_to_natural<F: Field, E: Field + FieldExtension<
let mut num_groups = n / 2;
let mut distance = 1;

let mut stage = 0;

while num_groups > 1 {
// println!("num_groups: {:?}", num_groups);
debug_assert!(num_groups > 1);
Expand All @@ -205,6 +210,10 @@ pub fn serial_ct_ntt_bitreversed_to_natural<F: Field, E: Field + FieldExtension<
a[j + distance] = tmp;
a[j].add_assign(&v);

// if k == 0 && j == 0 {
// println!("from cpu: {} {} {}", u, v, s);
// }

j += 1;
}

Expand All @@ -214,6 +223,11 @@ pub fn serial_ct_ntt_bitreversed_to_natural<F: Field, E: Field + FieldExtension<
pairs_per_group *= 2;
num_groups /= 2;
distance *= 2;

stage += 1;
// if stage == 18 {
// break;
// }
}

{
Expand Down Expand Up @@ -287,69 +301,72 @@ pub fn cache_friendly_ntt_natural_to_bitreversed<F: Field, E: Field + FieldExten
round += 1;
}

while round < cache_friendly_round {
debug_assert!(num_groups > 1);
let mut k = 0;
while k < num_groups {
let idx_1 = k * pairs_per_group * 2;
let idx_2 = idx_1 + pairs_per_group;
let s = omegas_bit_reversed[k];

let mut j = idx_1;
while j < idx_2 {
let mut u = a[j];
let mut v = a[j + distance];
v.mul_assign_by_base(&s);
u.sub_assign(&v);
a[j + distance] = u;
a[j].add_assign(&v);
j += 1;
}
k += 1;
}

pairs_per_group /= 2;
num_groups *= 2;
distance /= 2;
round += 1;
}
let mut cache_bunch = 0;
while cache_bunch < num_groups {
// num_groups=128 // round loop
let mut pairs_per_group_in_cache = pairs_per_group;
let mut distance_in_cache = distance;
let mut num_groups_in_cache = 1;
let num_rounds_in_cache = log_n - round; // 17

let mut round = 0;
while round < num_rounds_in_cache {
// experiment

let mut k = 0;
while k < num_groups_in_cache {
// group loop
let idx_1 = cache_bunch * pairs_per_group * 2 + k * pairs_per_group_in_cache * 2;
let idx_2 = idx_1 + pairs_per_group_in_cache;
let s = omegas_bit_reversed[cache_bunch * num_groups_in_cache + k];

let mut j = idx_1;
while j < idx_2 {
let mut u = a[j];
let mut v = a[j + distance_in_cache];
v.mul_assign_by_base(&s);
u.sub_assign(&v);
a[j + distance_in_cache] = u;
a[j].add_assign(&v);

j += 1;
}
k += 1;
}
pairs_per_group_in_cache /= 2;
num_groups_in_cache *= 2;
distance_in_cache /= 2;
round += 1;
}
cache_bunch += 1;
}
// while round < cache_friendly_round {
// debug_assert!(num_groups > 1);
// let mut k = 0;
// while k < num_groups {
// let idx_1 = k * pairs_per_group * 2;
// let idx_2 = idx_1 + pairs_per_group;
// let s = omegas_bit_reversed[k];

// let mut j = idx_1;
// while j < idx_2 {
// let mut u = a[j];
// let mut v = a[j + distance];
// v.mul_assign_by_base(&s);
// u.sub_assign(&v);
// a[j + distance] = u;
// a[j].add_assign(&v);
// j += 1;
// }
// k += 1;
// }

// pairs_per_group /= 2;
// num_groups *= 2;
// distance /= 2;
// round += 1;
// if round >= 2 {
// break;
// }
// }
// let mut cache_bunch = 0;
// while cache_bunch < num_groups {
// // num_groups=128 // round loop
// let mut pairs_per_group_in_cache = pairs_per_group;
// let mut distance_in_cache = distance;
// let mut num_groups_in_cache = 1;
// let num_rounds_in_cache = log_n - round; // 17

// let mut round = 0;
// while round < num_rounds_in_cache {
// // experiment

// let mut k = 0;
// while k < num_groups_in_cache {
// // group loop
// let idx_1 = cache_bunch * pairs_per_group * 2 + k * pairs_per_group_in_cache * 2;
// let idx_2 = idx_1 + pairs_per_group_in_cache;
// let s = omegas_bit_reversed[cache_bunch * num_groups_in_cache + k];

// let mut j = idx_1;
// while j < idx_2 {
// let mut u = a[j];
// let mut v = a[j + distance_in_cache];
// v.mul_assign_by_base(&s);
// u.sub_assign(&v);
// a[j + distance_in_cache] = u;
// a[j].add_assign(&v);

// j += 1;
// }
// k += 1;
// }
// pairs_per_group_in_cache /= 2;
// num_groups_in_cache *= 2;
// distance_in_cache /= 2;
// round += 1;
// }
// cache_bunch += 1;
// }
}
17 changes: 13 additions & 4 deletions gpu_prover/src/blake2s.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,8 @@ mod tests {
use crate::ops_simple::set_to_zero;
use crate::utils::GetChunksCount;

const USE_REDUCED_BLAKE2_ROUNDS: bool = true;

fn verify_leaves(values: &[BF], results: &[Digest], log_rows_per_hash: u32) {
let count = results.len();
let values_len = values.len();
Expand Down Expand Up @@ -351,9 +353,13 @@ mod tests {
.collect_vec();
block.copy_from_slice(&chunk);
if i == blocks_count - 1 {
state.absorb_final_block::<false>(&block, block_len, &mut expected);
state.absorb_final_block::<USE_REDUCED_BLAKE2_ROUNDS>(
&block,
block_len,
&mut expected,
);
} else {
state.absorb::<false>(&block);
state.absorb::<USE_REDUCED_BLAKE2_ROUNDS>(&block);
}
}
let actual = results[i];
Expand All @@ -376,7 +382,10 @@ mod tests {
.try_into()
.unwrap();
let mut expected = Digest::default();
Blake2sState::compress_two_to_one::<false>(&state, &mut expected);
Blake2sState::compress_two_to_one::<USE_REDUCED_BLAKE2_ROUNDS>(
&state,
&mut expected,
);
assert_eq!(expected, actual);
});
}
Expand Down Expand Up @@ -613,7 +622,7 @@ mod tests {
block[STATE_SIZE] = h_result[0] as u32;
block[STATE_SIZE + 1] = (h_result[0] >> 32) as u32;
let mut digest = Digest::default();
state.absorb_final_block::<true>(&block, STATE_SIZE + 2, &mut digest);
state.absorb_final_block::<USE_REDUCED_BLAKE2_ROUNDS>(&block, STATE_SIZE + 2, &mut digest);
assert!(digest[0].leading_zeros() >= BITS_COUNT);
}
}
6 changes: 3 additions & 3 deletions gpu_prover/src/prover/arg_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ impl DelegationRequestMetadata {
timestamp_col: timestamp_columns.start() as u32,
memory_timestamp_high_from_circuit_idx,
delegation_type_col: layout.delegation_type.start() as u32,
in_cycle_write_idx: BF::from_u64_unchecked(layout.in_cycle_write_index as u64),
in_cycle_write_idx: BF::from_u32_unchecked(layout.in_cycle_write_index as u32),
abi_mem_offset_high_col,
has_abi_mem_offset_high,
}
Expand Down Expand Up @@ -1107,7 +1107,7 @@ impl RegisterAndIndirectAccesses {
read_value,
register_index,
} => {
let address_low = BF::from_u64_unchecked(register_index as u64);
let address_low = BF::from_u32_unchecked(register_index);
let mut gamma_plus_one_plus_address_low_contribution =
challenges.address_low_challenge.clone();
gamma_plus_one_plus_address_low_contribution.mul_assign_by_base(&address_low);
Expand All @@ -1127,7 +1127,7 @@ impl RegisterAndIndirectAccesses {
write_value,
register_index,
} => {
let address_low = BF::from_u64_unchecked(register_index as u64);
let address_low = BF::from_u32_unchecked(register_index);
let mut gamma_plus_one_plus_address_low_contribution =
challenges.address_low_challenge.clone();
gamma_plus_one_plus_address_low_contribution.mul_assign_by_base(&address_low);
Expand Down
6 changes: 3 additions & 3 deletions gpu_prover/src/prover/stage_3_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,7 @@ pub(super) fn prepare_async_challenge_data(
assert_eq!(helpers.capacity(), MAX_HELPER_VALUES);
let decompression_factor = flat_generic_constraints_metadata.decompression_factor;
let decompression_factor_inv = decompression_factor.inverse().expect("must exist");
let two = BF::from_u64_unchecked(2);
let two = BF::from_u32_unchecked(2);
let lookup_linearization_challenges = &lookup_challenges.linearization_challenges;
let lookup_gamma = lookup_challenges.gamma;
let lookup_gamma_squared = *lookup_gamma.clone().square();
Expand Down Expand Up @@ -1197,7 +1197,7 @@ pub(super) fn prepare_async_challenge_data(
let write_timestamp_low_constant = *mc
.timestamp_low_challenge
.clone()
.mul_assign_by_base(&BF::from_u64_unchecked(i as u64));
.mul_assign_by_base(&BF::from_u32_unchecked(i as u32));
numerator_constant.add_assign(&write_timestamp_low_constant);
if !is_unrolled {
let write_timestamp_high_constant = *mc
Expand Down Expand Up @@ -1324,7 +1324,7 @@ pub(super) fn prepare_async_challenge_data(
} else {
assert_eq!(j == 0, indirect_access.offset_constant == 0);
}
let offset = BF::from_u64_unchecked(indirect_access.offset_constant as u64);
let offset = BF::from_u32_unchecked(indirect_access.offset_constant);
let mut constant = *mc
.address_low_challenge
.clone()
Expand Down
2 changes: 1 addition & 1 deletion gpu_prover/src/prover/stage_3_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ impl<
let internal_constants_helper_idx = helpers.len();
match lookup_set.table_index {
TableIndex::Constant(table_type) => {
let id = BF::from_u64_unchecked(table_type.to_table_id() as u64);
let id = BF::from_u32_unchecked(table_type.to_table_id());
helpers.push(
*table_id_challenge
.clone()
Expand Down
3 changes: 2 additions & 1 deletion gpu_prover/src/witness/placeholder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,8 @@ impl From<CSPlaceholder> for Placeholder {
Placeholder::DelegationIndirectAccessVariableOffset {
variable_index: variable_index as u32,
}
}
},
cs::cs::placeholder::Placeholder::ExecutorFamilyMaskBit { .. } => todo!()
}
}
}
39 changes: 39 additions & 0 deletions gpu_prover_new/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
[package]
name = "gpu_prover_new"
version.workspace = true
edition.workspace = true
authors.workspace = true
homepage.workspace = true
repository.workspace = true
license.workspace = true
keywords.workspace = true
categories.workspace = true
build = "build/main.rs"

[dependencies]
fft = { workspace = true }
field = { workspace = true }
worker = { workspace = true }

era_cudart = "0.154"
era_cudart_sys = "0.154"
itertools = "*"
log = "0.4.29"
rayon = "*"

[build-dependencies]
cmake = "0.1"
era_cudart_sys = "0.154"

[dev-dependencies]
env_logger = "0.11"
era_criterion_cuda = "0.2"
criterion = "0.5"
itertools = "0.14"
rand = "0.9"
serial_test = "3"
blake2s_u32 = { workspace = true }

[[bench]]
name = "field"
harness = false
Loading