Skip to content

Commit d61f392

Browse files
authored
chore: refactor keccak into subfunctions (#18)
1 parent ea2bf0e commit d61f392

File tree

1 file changed

+56
-35
lines changed

1 file changed

+56
-35
lines changed

src/keccak256.nr

Lines changed: 56 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ mod benchmarks;
44

55
use std::hash::keccakf1600;
66
use std::runtime::is_unconstrained;
7+
use std::static_assert;
78

89
global BLOCK_SIZE_IN_BYTES: u32 = 136; //(1600 - BITS * 2) / WORD_SIZE;
910
global WORD_SIZE: u32 = 8; // Limbs are made up of u64s so 8 bytes each.
@@ -30,15 +31,26 @@ pub fn keccak256<let N: u32>(input: [u8; N], message_size: u32) -> [u8; 32] {
3031
}
3132

3233
//1. format_input_lanes and apply padding
33-
let max_blocks = (N + BLOCK_SIZE_IN_BYTES) / BLOCK_SIZE_IN_BYTES;
3434
let real_max_blocks = (message_size + BLOCK_SIZE_IN_BYTES) / BLOCK_SIZE_IN_BYTES;
3535

3636
// Apply Keccak padding (0x01 after message, 0x80 at block end)
3737
apply_keccak_padding(&mut block_bytes, message_size, real_max_blocks);
3838

39+
let block_array = convert_to_u64_array(block_bytes);
40+
41+
let state = apply_keccak_permutations(block_array, real_max_blocks);
42+
43+
//3. sponge_squeeze
44+
read_hash_from_state(state)
45+
}
46+
47+
fn convert_to_u64_array<let N: u32>(input: [u8; N]) -> [u64; N / WORD_SIZE] {
48+
static_assert(
49+
N % WORD_SIZE == 0,
50+
"Byte array is expected to cleanly divide into chunks",
51+
);
3952
// populate a vector of 64-bit limbs from our byte array
40-
let mut sliced_buffer =
41-
[0; (((N / BLOCK_SIZE_IN_BYTES) + 1) * BLOCK_SIZE_IN_BYTES) / WORD_SIZE];
53+
let mut sliced_buffer = [0; N / WORD_SIZE];
4254
for i in 0..sliced_buffer.len() {
4355
let limb_start = WORD_SIZE * i;
4456

@@ -49,7 +61,7 @@ pub fn keccak256<let N: u32>(input: [u8; N], message_size: u32) -> [u8; 32] {
4961
WORD_SIZE,
5062
|i: u32| {
5163
quote {
52-
sliced += v * (block_bytes[limb_start + $i] as Field);
64+
sliced += v * (input[limb_start + $i] as Field);
5365
v *= 256;
5466
}
5567
},
@@ -58,29 +70,36 @@ pub fn keccak256<let N: u32>(input: [u8; N], message_size: u32) -> [u8; 32] {
5870
sliced.assert_max_bit_size::<64>();
5971
sliced_buffer[i] = sliced as u64;
6072
}
73+
sliced_buffer
74+
}
6175

76+
fn apply_keccak_permutations<let N: u32>(
77+
flattened_blocks_array: [u64; N],
78+
num_blocks: u32,
79+
) -> [u64; NUM_KECCAK_LANES] {
6280
//2. sponge_absorb
6381
let mut state: [u64; NUM_KECCAK_LANES] = [0; NUM_KECCAK_LANES];
64-
// `real_max_blocks` is guaranteed to at least be `1`
82+
// `num_blocks` is guaranteed to at least be `1`
6583
// We peel out the first block as to avoid a conditional inside of the loop.
6684
// Otherwise, a dynamic predicate can cause a blowup in a constrained runtime.
6785
unroll_loop!(
6886
0u32,
6987
LIMBS_PER_BLOCK,
7088
|i: u32| {
7189
quote {
72-
state[$i] = sliced_buffer[$i];
90+
state[$i] = flattened_blocks_array[$i];
7391
}
7492
},
7593
);
7694
state = keccakf1600(state);
7795

78-
let state = if is_unconstrained() {
96+
let max_blocks = N / LIMBS_PER_BLOCK;
97+
if is_unconstrained() {
7998
// When in an unconstrained runtime we can take advantage of runtime loop bounds,
8099
// thus allowing us to simplify the loop body.
81-
for i in 1..real_max_blocks {
100+
for i in 1..num_blocks {
82101
for j in 0..LIMBS_PER_BLOCK {
83-
state[j] = state[j] ^ sliced_buffer[i * LIMBS_PER_BLOCK + j];
102+
state[j] = state[j] ^ flattened_blocks_array[i * LIMBS_PER_BLOCK + j];
84103
}
85104
state = keccakf1600(state);
86105
}
@@ -89,20 +108,42 @@ pub fn keccak256<let N: u32>(input: [u8; N], message_size: u32) -> [u8; 32] {
89108
} else {
90109
// We store the intermediate states in an array to avoid having a dynamic predicate
91110
// inside the loop, which can cause a blowup in a constrained runtime.
92-
let mut intermediate_states = [state; (N + BLOCK_SIZE_IN_BYTES) / BLOCK_SIZE_IN_BYTES + 1];
111+
let mut intermediate_states = [state; N / LIMBS_PER_BLOCK + 1];
93112
for i in 1..max_blocks {
94113
let mut previous_state = intermediate_states[i - 1];
95114
for j in 0..LIMBS_PER_BLOCK {
96-
previous_state[j] = previous_state[j] ^ sliced_buffer[i * LIMBS_PER_BLOCK + j];
115+
previous_state[j] =
116+
previous_state[j] ^ flattened_blocks_array[i * LIMBS_PER_BLOCK + j];
97117
}
98118
intermediate_states[i] = keccakf1600(previous_state);
99119
}
100120

101-
// We can then take the state as of `real_max_blocks`, ignoring later permutations.
102-
intermediate_states[real_max_blocks - 1]
103-
};
121+
// We can then take the state as of `num_blocks`, ignoring later permutations.
122+
intermediate_states[num_blocks - 1]
123+
}
124+
}
104125

105-
//3. sponge_squeeze
126+
// Apply Keccak padding to the block_bytes array
127+
// Append 0x01 after message, then 0x80 at end of block
128+
// If both padding bytes collide at the same byte, combine them as 0x81
129+
#[inline_always]
130+
pub(crate) fn apply_keccak_padding<let BLOCK_BYTES: u32>(
131+
block_bytes: &mut [u8; BLOCK_BYTES],
132+
message_size: u32,
133+
real_max_blocks: u32,
134+
) {
135+
let real_blocks_bytes = real_max_blocks * BLOCK_SIZE_IN_BYTES;
136+
137+
if message_size == real_blocks_bytes - 1 {
138+
// Combine both padding bits: 0x01 | 0x80 = 0x81
139+
block_bytes[message_size] = 0x81;
140+
} else {
141+
block_bytes[message_size] = 0x01;
142+
block_bytes[real_blocks_bytes - 1] = 0x80;
143+
}
144+
}
145+
146+
fn read_hash_from_state(state: [u64; NUM_KECCAK_LANES]) -> [u8; 32] {
106147
let mut result = [0; 32];
107148
unroll_loop!(
108149
0u32,
@@ -126,26 +167,6 @@ pub fn keccak256<let N: u32>(input: [u8; N], message_size: u32) -> [u8; 32] {
126167
result
127168
}
128169

129-
// Apply Keccak padding to the block_bytes array
130-
// Append 0x01 after message, then 0x80 at end of block
131-
// If both padding bytes collide at the same byte, combine them as 0x81
132-
#[inline_always]
133-
pub(crate) fn apply_keccak_padding<let BLOCK_BYTES: u32>(
134-
block_bytes: &mut [u8; BLOCK_BYTES],
135-
message_size: u32,
136-
real_max_blocks: u32,
137-
) {
138-
let real_blocks_bytes = real_max_blocks * BLOCK_SIZE_IN_BYTES;
139-
140-
if message_size == real_blocks_bytes - 1 {
141-
// Combine both padding bits: 0x01 | 0x80 = 0x81
142-
block_bytes[message_size] = 0x81;
143-
} else {
144-
block_bytes[message_size] = 0x01;
145-
block_bytes[real_blocks_bytes - 1] = 0x80;
146-
}
147-
}
148-
149170
comptime fn unroll_loop(start: u32, end: u32, body: fn(u32) -> Quoted) -> Quoted {
150171
let mut iterations: [Quoted] = &[];
151172
for i in start..end {

0 commit comments

Comments
 (0)