Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion field/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mod array;
mod batch_inverse;
mod exponentiation;
pub mod extension;
mod field;
pub mod field;
mod helpers;
mod packed;

Expand Down
5 changes: 3 additions & 2 deletions keccak-air/examples/prove_baby_bear_keccak.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use tracing_forest::ForestLayer;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Registry};
use p3_field::AbstractField;

const NUM_HASHES: usize = 680;

Expand Down Expand Up @@ -54,7 +55,7 @@ fn main() -> Result<(), VerificationError> {
type Challenger = DuplexChallenger<Val, Perm, 16>;

let fri_config = FriConfig {
log_blowup: 1,
log_blowup: 2,
num_queries: 100,
proof_of_work_bits: 16,
mmcs: challenge_mmcs,
Expand Down Expand Up @@ -84,6 +85,6 @@ fn main() -> Result<(), VerificationError> {
&KeccakAir {},
&mut challenger,
&proof,
&RowMajorMatrix::new(vec![], 0),
&RowMajorMatrix::new(vec![Val::zero()], 1),
)
}
6 changes: 3 additions & 3 deletions keccak-air/examples/prove_baby_bear_poseidon2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use p3_challenger::DuplexChallenger;
use p3_commit::ExtensionMmcs;
use p3_dft::Radix2DitParallel;
use p3_field::extension::BinomialExtensionField;
use p3_field::Field;
use p3_field::{AbstractField, Field};
use p3_fri::{FriConfig, TwoAdicFriPcs, TwoAdicFriPcsConfig};
use p3_keccak_air::{generate_trace_rows, KeccakAir};
use p3_matrix::dense::RowMajorMatrix;
Expand Down Expand Up @@ -54,7 +54,7 @@ fn main() -> Result<(), VerificationError> {
type Challenger = DuplexChallenger<Val, Perm, 16>;

let fri_config = FriConfig {
log_blowup: 1,
log_blowup: 2,
num_queries: 100,
proof_of_work_bits: 16,
mmcs: challenge_mmcs,
Expand Down Expand Up @@ -84,6 +84,6 @@ fn main() -> Result<(), VerificationError> {
&KeccakAir {},
&mut challenger,
&proof,
&RowMajorMatrix::new(vec![], 0),
&RowMajorMatrix::new(vec![Val::zero()], 1),
)
}
5 changes: 3 additions & 2 deletions keccak-air/examples/prove_goldilocks_keccak.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use tracing_forest::ForestLayer;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, Registry};
use p3_field::AbstractField;

const NUM_HASHES: usize = 680;

Expand Down Expand Up @@ -53,7 +54,7 @@ fn main() -> Result<(), VerificationError> {
type Challenger = DuplexChallenger<Val, Perm, 8>;

let fri_config = FriConfig {
log_blowup: 1,
log_blowup: 2,
num_queries: 100,
proof_of_work_bits: 16,
mmcs: challenge_mmcs,
Expand Down Expand Up @@ -83,6 +84,6 @@ fn main() -> Result<(), VerificationError> {
&KeccakAir {},
&mut challenger,
&proof,
&RowMajorMatrix::new(vec![], 0),
&RowMajorMatrix::new(vec![Val::zero()], 1),
)
}
24 changes: 5 additions & 19 deletions keccak-air/src/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,10 @@ impl<AB: AirBuilder> Air<AB> for KeccakAir {
let local: &KeccakCols<AB::Var> = main.row_slice(0).borrow();
let next: &KeccakCols<AB::Var> = main.row_slice(1).borrow();

// The export flag must be 0 or 1.
builder.assert_bool(local.export);

// If this is not the final step, the export flag must be off.
let final_step = local.step_flags[NUM_ROUNDS - 1];
let not_final_step = AB::Expr::one() - final_step;
builder
.when(not_final_step.clone())
.assert_zero(local.export);

// If this is not the final step, the local and next preimages must match.
for y in 0..5 {
for x in 0..5 {
for limb in 0..U64_LIMBS {
builder
.when_transition()
.when(not_final_step.clone())
.assert_eq(local.preimage[y][x][limb], next.preimage[y][x][limb]);
}
}
}

// C'[x, z] = xor(C[x, z], C[x - 1, z], C[x + 1, z - 1]).
for x in 0..5 {
Expand Down Expand Up @@ -115,7 +98,8 @@ impl<AB: AirBuilder> Air<AB> for KeccakAir {
let computed_limb = (limb * BITS_PER_LIMB..(limb + 1) * BITS_PER_LIMB)
.rev()
.fold(AB::Expr::zero(), |acc, z| acc.double() + get_bit(z));
builder.assert_eq(computed_limb, local.a_prime_prime[y][x][limb]);
builder
.assert_eq(computed_limb, local.a_prime_prime[y][x][limb]);
}
}
}
Expand Down Expand Up @@ -149,7 +133,8 @@ impl<AB: AirBuilder> Air<AB> for KeccakAir {
..(limb + 1) * BITS_PER_LIMB)
.rev()
.fold(AB::Expr::zero(), |acc, z| acc.double() + get_xored_bit(z));
builder.assert_eq(
builder
.assert_eq(
computed_a_prime_prime_prime_0_0_limb,
a_prime_prime_prime_0_0_limb,
);
Expand All @@ -168,5 +153,6 @@ impl<AB: AirBuilder> Air<AB> for KeccakAir {
}
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of the below can go away when we unify the columns as discussed elsewhere in this review.


}
}
69 changes: 51 additions & 18 deletions keccak-air/src/columns.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use core::borrow::{Borrow, BorrowMut};
use core::array;
use core::fmt::{Debug, Formatter, Result};
use core::mem::{size_of, transmute};

use p3_util::indices_arr;
Expand All @@ -14,21 +16,10 @@ use crate::{NUM_ROUNDS, RATE_LIMBS, U64_LIMBS};
/// convention of `x, y, z` order, but it has the benefit that input lists map to AIR columns in a
/// nicer way.
#[repr(C)]
pub(crate) struct KeccakCols<T> {
pub struct KeccakCols<T> {
/// The `i`th value is set to 1 if we are in the `i`th round, otherwise 0.
pub step_flags: [T; NUM_ROUNDS],

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could unify the postimage columns with the a_prime_prime columns, since these are never both used in the same row. They effectively serve the same purpose, and by unifying these, we can get rid of the .when(not_final_row) conditions. Doing the latter may reduce the degree of the STARK from 4 to 3.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, @CarloModicaPortfolio says we can unify step_flags[NUM_ROUNDS-1] with export.

/// A register which indicates if a row should be exported, i.e. included in a multiset equality
/// argument. Should be 1 only for certain rows which are final steps, i.e. with
/// `step_flags[23] = 1`.
pub export: T,

/// Permutation inputs, stored in y-major order.
pub preimage: [[[T; U64_LIMBS]; 5]; 5],

/// Permutation outputs, stored in y-major order.
pub postimage: [[[T; U64_LIMBS]; 5]; 5],

pub a: [[[T; U64_LIMBS]; 5]; 5],

/// ```ignore
Expand Down Expand Up @@ -62,6 +53,37 @@ pub(crate) struct KeccakCols<T> {
pub a_prime_prime_prime_0_0_limbs: [T; U64_LIMBS],
}

impl<T: Default> Default for KeccakCols<T> {
fn default() -> Self {
Self {
step_flags: array::from_fn(|_| T::default()),
a: array::from_fn(|_| array::from_fn(|_| array::from_fn(|_| T::default()))),
c: array::from_fn(|_| array::from_fn(|_| T::default())),
c_prime: array::from_fn(|_| array::from_fn(|_| T::default())),
a_prime: array::from_fn(|_| array::from_fn(|_| array::from_fn(|_| T::default()))),
a_prime_prime: array::from_fn(|_| array::from_fn(|_| array::from_fn(|_| T::default()))),
a_prime_prime_0_0_bits: array::from_fn(|_| T::default()),
a_prime_prime_prime_0_0_limbs: array::from_fn(|_| T::default()),
}
}
}

impl<T: Debug> Debug for KeccakCols<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
f.debug_struct("KeccakCols")
.field("step_flags", &self.step_flags)
.field("a", &self.a)
.field("c", &self.c)
.field("c_prime", &self.c_prime)
.field("a_prime", &self.a_prime)
.field("a_prime_prime", &self.a_prime_prime)
.field("a_prime_prime_0_0_bits", &self.a_prime_prime_0_0_bits)
.field("a_prime_prime_prime_0_0_limbs", &self.a_prime_prime_prime_0_0_limbs)
.finish()
}
}


impl<T: Copy> KeccakCols<T> {
pub fn b(&self, x: usize, y: usize, z: usize) -> T {
debug_assert!(x < 5);
Expand Down Expand Up @@ -93,30 +115,41 @@ impl<T: Copy> KeccakCols<T> {
}
}

pub fn input_limb(i: usize) -> usize {
debug_assert!(i < RATE_LIMBS);
impl<T: Clone> Clone for KeccakCols<T> {
fn clone(&self) -> Self {
Self {
step_flags: self.step_flags.clone(),
a: self.a.clone(),
c: self.c.clone(),
c_prime: self.c_prime.clone(),
a_prime: self.a_prime.clone(),
a_prime_prime: self.a_prime_prime.clone(),
a_prime_prime_0_0_bits: self.a_prime_prime_0_0_bits.clone(),
a_prime_prime_prime_0_0_limbs: self.a_prime_prime_prime_0_0_limbs.clone(),
}
}
}

pub fn input_limb(i: usize) -> usize {
let i_u64 = i / U64_LIMBS;
let limb_index = i % U64_LIMBS;

// The 5x5 state is treated as y-major, as per the Keccak spec.
let y = i_u64 / 5;
let x = i_u64 % 5;

KECCAK_COL_MAP.preimage[y][x][limb_index]
KECCAK_COL_MAP.a[y][x][limb_index]
}

pub fn output_limb(i: usize) -> usize {
debug_assert!(i < RATE_LIMBS);

let i_u64 = i / U64_LIMBS;
let limb_index = i % U64_LIMBS;

// The 5x5 state is treated as y-major, as per the Keccak spec.
let y = i_u64 / 5;
let x = i_u64 % 5;

KECCAK_COL_MAP.postimage[y][x][limb_index]
KECCAK_COL_MAP.a_prime_prime_prime(y, x, limb_index)
}

pub(crate) const NUM_KECCAK_COLS: usize = size_of::<KeccakCols<u8>>();
Expand Down
6 changes: 3 additions & 3 deletions keccak-air/src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,10 @@ const RC_BITS: [[u8; 64]; 24] = [
],
];

pub(crate) const fn rc_value_limb(round: usize, limb: usize) -> u16 {
(RC[round] >> (limb * BITS_PER_LIMB)) as u16
pub const fn rc_value_limb(round: usize, limb: usize) -> u8 {
(RC[round] >> (limb * BITS_PER_LIMB)) as u8
}

pub(crate) const fn rc_value_bit(round: usize, bit_index: usize) -> u8 {
pub const fn rc_value_bit(round: usize, bit_index: usize) -> u8 {
RC_BITS[round][bit_index]
}
33 changes: 10 additions & 23 deletions keccak-air/src/generation.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use alloc::vec;
use alloc::vec::Vec;
use core::iter;

use p3_field::PrimeField64;
use p3_field::PrimeField32;
use p3_matrix::dense::RowMajorMatrix;
use tracing::instrument;

Expand All @@ -12,7 +11,7 @@ use crate::logic::{andn, xor};
use crate::{BITS_PER_LIMB, NUM_ROUNDS, U64_LIMBS};

#[instrument(name = "generate Keccak trace", skip_all)]
pub fn generate_trace_rows<F: PrimeField64>(inputs: Vec<[u64; 25]>) -> RowMajorMatrix<F> {
pub fn generate_trace_rows<F: PrimeField32>(inputs: Vec<[u64; 25]>) -> RowMajorMatrix<F> {
let num_rows = (inputs.len() * NUM_ROUNDS).next_power_of_two();
let mut trace =
RowMajorMatrix::new(vec![F::zero(); num_rows * NUM_KECCAK_COLS], NUM_KECCAK_COLS);
Expand All @@ -30,26 +29,13 @@ pub fn generate_trace_rows<F: PrimeField64>(inputs: Vec<[u64; 25]>) -> RowMajorM
}

/// `rows` will normally consist of 24 rows, with an exception for the final row.
fn generate_trace_rows_for_perm<F: PrimeField64>(rows: &mut [KeccakCols<F>], input: [u64; 25]) {
// Populate the preimage for each row.
for row in rows.iter_mut() {
for y in 0..5 {
for x in 0..5 {
let input_xy = input[y * 5 + x];
for limb in 0..U64_LIMBS {
row.preimage[y][x][limb] =
F::from_canonical_u64((input_xy >> (16 * limb)) & 0xFFFF);
}
}
}
}

fn generate_trace_rows_for_perm<F: PrimeField32>(rows: &mut [KeccakCols<F>], input: [u64; 25]) {
// Populate the round input for the first round.
for y in 0..5 {
for x in 0..5 {
let input_xy = input[y * 5 + x];
for limb in 0..U64_LIMBS {
rows[0].a[y][x][limb] = F::from_canonical_u64((input_xy >> (16 * limb)) & 0xFFFF);
rows[0].a[y][x][limb] = F::from_canonical_u64((input_xy >> (8 * limb)) & 0xFF);
}
}
}
Expand All @@ -70,7 +56,7 @@ fn generate_trace_rows_for_perm<F: PrimeField64>(rows: &mut [KeccakCols<F>], inp
}
}

fn generate_trace_row_for_round<F: PrimeField64>(row: &mut KeccakCols<F>, round: usize) {
fn generate_trace_row_for_round<F: PrimeField32>(row: &mut KeccakCols<F>, round: usize) {
row.step_flags[round] = F::one();

// Populate C[x] = xor(A[x, 0], A[x, 1], A[x, 2], A[x, 3], A[x, 4]).
Expand All @@ -79,7 +65,7 @@ fn generate_trace_row_for_round<F: PrimeField64>(row: &mut KeccakCols<F>, round:
let limb = z / BITS_PER_LIMB;
let bit_in_limb = z % BITS_PER_LIMB;
let a = (0..5).map(|y| {
let a_limb = row.a[y][x][limb].as_canonical_u64() as u16;
let a_limb = row.a[y][x][limb].as_canonical_u64();
((a_limb >> bit_in_limb) & 1) != 0
});
row.c[x][z] = F::from_bool(a.fold(false, |acc, x| acc ^ x));
Expand All @@ -106,7 +92,7 @@ fn generate_trace_row_for_round<F: PrimeField64>(row: &mut KeccakCols<F>, round:
for z in 0..64 {
let limb = z / BITS_PER_LIMB;
let bit_in_limb = z % BITS_PER_LIMB;
let a_limb = row.a[y][x][limb].as_canonical_u64() as u16;
let a_limb = row.a[y][x][limb].as_canonical_u64() as u8;
let a_bit = F::from_bool(((a_limb >> bit_in_limb) & 1) != 0);
row.a_prime[y][x][z] = xor([a_bit, row.c[x][z], row.c_prime[x][z]]);
}
Expand Down Expand Up @@ -152,6 +138,7 @@ fn generate_trace_row_for_round<F: PrimeField64>(row: &mut KeccakCols<F>, round:
for limb in 0..U64_LIMBS {
let rc_lo = rc_value_limb(round, limb);
row.a_prime_prime_prime_0_0_limbs[limb] =
F::from_canonical_u16(row.a_prime_prime[0][0][limb].as_canonical_u64() as u16 ^ rc_lo);
F::from_canonical_u8(row.a_prime_prime[0][0][limb].as_canonical_u64() as u8 ^ rc_lo);
}
}

}
12 changes: 6 additions & 6 deletions keccak-air/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@ extern crate alloc;

mod air;
mod columns;
mod constants;
mod generation;
mod logic;
pub mod constants;
pub mod generation;
pub mod logic;
mod round_flags;

pub use air::*;
pub use columns::*;
pub use generation::*;

const NUM_ROUNDS: usize = 24;
const BITS_PER_LIMB: usize = 16;
const U64_LIMBS: usize = 64 / BITS_PER_LIMB;
pub const NUM_ROUNDS: usize = 24;
pub const BITS_PER_LIMB: usize = 8;
pub const U64_LIMBS: usize = 64 / BITS_PER_LIMB;
const RATE_BITS: usize = 1088;
const RATE_LIMBS: usize = RATE_BITS / BITS_PER_LIMB;
Loading
Loading