Skip to content

Commit 2fa0396

Browse files
committed
add blake in circuit + tests
1 parent d2c9dcd commit 2fa0396

2 files changed

Lines changed: 207 additions & 2 deletions

File tree

crates/circuits/src/blake.rs

Lines changed: 148 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
use blake2::{Blake2s256, Digest};
22
use itertools::Itertools;
3+
use stwo::core::fields::m31::M31;
34
use stwo::core::vcs::blake2_hash::Blake2sHash;
45
use stwo::core::{fields::qm31::QM31, vcs::blake2_hash::reduce_to_m31};
56

67
use crate::circuit::{Blake, BlakeGGate, M31ToU32, TripleXor};
78
use crate::context::{Context, Var};
9+
use crate::eval;
810
use crate::ivalue::{IValue, qm31_from_u32s};
9-
use crate::ops::Guess;
11+
use crate::ops::{Guess, from_partial_evals};
12+
use crate::simd::Simd;
1013

1114
#[cfg(test)]
1215
#[path = "blake_test.rs"]
@@ -127,6 +130,150 @@ pub fn blake<Value: IValue>(
127130
HashValue(out_var0, out_var1)
128131
}
129132

133+
/// Blake2s initialization vector (IV): eight 32-bit words.
///
/// `blake_from_gates` uses these words (word 0 XORed with the parameter block) as the initial
/// chaining value `h`, and again as the source of the upper half `v[8..16]` of each block's
/// working state.
134+
const BLAKE2S_IV: [u32; 8] = [
135+
0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
136+
];
137+
138+
/// Message permutations for Blake2s (10 rounds × 16 indices).
///
/// Row `r` is the message schedule of round `r`: the `g`-th `G` call of that round consumes the
/// block words at indices `row[2 * g]` and `row[2 * g + 1]`.
139+
const BLAKE_SIGMA: [[u8; 16]; 10] = [
140+
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
141+
[14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3],
142+
[11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4],
143+
[7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8],
144+
[9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13],
145+
[2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9],
146+
[12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11],
147+
[13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10],
148+
[6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5],
149+
[10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0],
150+
];
151+
152+
/// Indices `(a, b, c, d)` of the working-state words sent to the `G` function in each Blake2s
/// round.
///
/// Viewing the 16-word state as a row-major 4×4 matrix, the first four entries are its columns
/// and the last four are its diagonals.
153+
const G_STATE_INDICES: [(usize, usize, usize, usize); 8] = [
154+
(0, 4, 8, 12),
155+
(1, 5, 9, 13),
156+
(2, 6, 10, 14),
157+
(3, 7, 11, 15),
158+
(0, 5, 10, 15),
159+
(1, 6, 11, 12),
160+
(2, 7, 8, 13),
161+
(3, 4, 9, 14),
162+
];
163+
164+
/// Number of `G` calls performed in one Blake2s round (one per entry of `G_STATE_INDICES`).
const N_G_CALLS_PER_ROUND: usize = 8;
165+
166+
#[inline]
167+
fn u32_packed_constant<Value: IValue>(ctx: &mut Context<Value>, x: u32) -> Var {
168+
ctx.constant(qm31_from_u32s(x & 0xffff, x >> 16, 0, 0))
169+
}
170+
171+
/// Adds a Blake2s hash using decomposed gates (`m31_to_u32`, `blake_g_gate`, `triple_xor`) to the
172+
/// circuit, and returns the two output variables as [`HashValue`].
173+
///
174+
/// NOTE: If the number of bytes is not a multiple of 16, the caller must make sure that the
175+
/// remaining bytes are zero.
176+
/// For example, if `n_bytes` is 4, only the first coordinate of the [`QM31`] may be non-zero.
177+
/// If `n_bytes` is 1, that coordinate must be < 256.
178+
pub fn blake_from_gates<Value: IValue>(
179+
ctx: &mut Context<Value>,
180+
input: &[Var],
181+
n_bytes: usize,
182+
) -> HashValue<Var> {
183+
// Sanity check: check the number of bytes is consistent with the number of [QM31] values.
184+
assert_eq!(input.len(), n_bytes.div_ceil(16));
185+
186+
const BLOCK_BYTES: usize = 64;
187+
const WORDS_PER_BLOCK: usize = 16;
188+
189+
// Unpack each QM31 message chunk into four u32 limbs.
190+
let mut message_u32s: Vec<Var> = Vec::new();
191+
for &var in input {
192+
let simd = Simd::from_packed(vec![var], 4);
193+
for coord in 0..4 {
194+
let comp = Simd::unpack_idx(ctx, &simd, coord);
195+
message_u32s.push(m31_to_u32(ctx, comp));
196+
}
197+
}
198+
199+
let n_blocks = std::cmp::max(1, n_bytes.div_ceil(BLOCK_BYTES));
200+
let total_words = n_blocks * WORDS_PER_BLOCK;
201+
let zero_u32 = u32_packed_constant(ctx, 0);
202+
while message_u32s.len() < total_words {
203+
message_u32s.push(zero_u32);
204+
}
205+
206+
// `h`: IV XORed with the parameter block (depth 1, fanout 1, digest length 32, key length 0).
207+
let mut h: [Var; 8] = std::array::from_fn(|i| {
208+
let iv_val = if i == 0 { BLAKE2S_IV[0] ^ 0x01010020 } else { BLAKE2S_IV[i] };
209+
u32_packed_constant(ctx, iv_val)
210+
});
211+
212+
for block_idx in 0..n_blocks {
213+
let block: [Var; WORDS_PER_BLOCK] =
214+
std::array::from_fn(|i| message_u32s[block_idx * WORDS_PER_BLOCK + i]);
215+
let t0 = std::cmp::min(n_bytes, (block_idx + 1) * BLOCK_BYTES) as u32;
216+
let t1 = 0u32;
217+
let last = block_idx == n_blocks - 1;
218+
219+
let prev_h = h;
220+
221+
let mut v: [Var; 16] = std::array::from_fn(|i| {
222+
if i < 8 {
223+
h[i]
224+
} else {
225+
let mut iv = BLAKE2S_IV[i - 8];
226+
if i == 12 {
227+
iv ^= t0;
228+
}
229+
if i == 13 {
230+
iv ^= t1;
231+
}
232+
if i == 14 && last {
233+
iv ^= 0xFFFF_FFFF;
234+
}
235+
u32_packed_constant(ctx, iv)
236+
}
237+
});
238+
239+
for permutation in &BLAKE_SIGMA {
240+
for g_idx in 0..N_G_CALLS_PER_ROUND {
241+
let (ai, bi, ci, di) = G_STATE_INDICES[g_idx];
242+
let (new_a, new_b, new_c, new_d) = blake_g_gate(
243+
ctx,
244+
v[ai],
245+
v[bi],
246+
v[ci],
247+
v[di],
248+
block[permutation[g_idx * 2] as usize],
249+
block[permutation[g_idx * 2 + 1] as usize],
250+
);
251+
v[ai] = new_a;
252+
v[bi] = new_b;
253+
v[ci] = new_c;
254+
v[di] = new_d;
255+
}
256+
}
257+
258+
for i in 0..8 {
259+
h[i] = triple_xor(ctx, prev_h[i], v[i], v[i + 8]);
260+
}
261+
}
262+
263+
let c_2_pow_16 = ctx.constant(M31::from(1u32 << 16).into());
264+
let reduced: [Var; 8] = std::array::from_fn(|i| {
265+
let h_simd = Simd::from_packed(vec![h[i]], 2);
266+
let low = Simd::unpack_idx(ctx, &h_simd, 0);
267+
let high = Simd::unpack_idx(ctx, &h_simd, 1);
268+
eval!(ctx, (low) + ((high) * (c_2_pow_16)))
269+
});
270+
271+
let out0 = from_partial_evals(ctx, [reduced[0], reduced[1], reduced[2], reduced[3]]);
272+
let out1 = from_partial_evals(ctx, [reduced[4], reduced[5], reduced[6], reduced[7]]);
273+
274+
HashValue(out0, out1)
275+
}
276+
130277
/// Adds a TripleXor gate to the circuit: XOR three u32 values encoded as QM31 `(u16, u16, 0, 0)`
131278
/// and return the result in the same encoding.
132279
pub fn triple_xor<Value: IValue>(

crates/circuits/src/blake_test.rs

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use blake2::{Blake2s256, Digest};
22
use rstest::rstest;
33
use stwo::core::vcs::blake2_hash::reduce_to_m31;
44

5-
use crate::blake::{blake, qm31_from_bytes};
5+
use crate::blake::{blake, blake_from_gates, blake_qm31, qm31_from_bytes};
66
use crate::context::TraceContext;
77
use crate::ivalue::qm31_from_u32s;
88
use crate::ops::{Guess, eq, guess};
@@ -48,3 +48,61 @@ fn test_blake(#[case] wrong_output: bool) {
4848

4949
assert_eq!(context.is_circuit_valid(), !wrong_output);
5050
}
51+
52+
/// The decomposed-gate Blake2s must agree with the monolithic `blake` gate on a 66-byte message
/// (five input elements, so the message spills into a second 64-byte block).
#[test]
fn test_blake_from_gates_equal_old_blake() {
    let n_bytes = 66;
    let message = [
        qm31_from_u32s(1, 2, 3, 4),
        qm31_from_u32s(5, 6, 7, 8),
        qm31_from_u32s(9, 10, 11, 12),
        qm31_from_u32s(13, 14, 15, 16),
        qm31_from_u32s(17, 0, 0, 0),
    ];

    let mut ctx = TraceContext::default();
    let input_vars = message.guess(&mut ctx);

    // Hash the same variables with both implementations and constrain the digests to be equal.
    let monolithic = blake(&mut ctx, &input_vars, n_bytes);
    let decomposed = blake_from_gates(&mut ctx, &input_vars, n_bytes);
    eq(&mut ctx, monolithic.0, decomposed.0);
    eq(&mut ctx, monolithic.1, decomposed.1);

    ctx.finalize_guessed_vars();
    ctx.circuit.check_yields();
    assert!(ctx.is_circuit_valid());
}
76+
77+
/// Checks the decomposed-gate Blake2s against the out-of-circuit `blake_qm31` reference on a
/// full 64-byte (single block) message.
#[test]
fn test_blake_from_gates_independent() {
    let n_bytes = 64;
    let words: [u32; 16] = [
        930933030, 1766240503, 3660871006, 388409270, 1948594622, 3119396969, 3924579183,
        2089920034, 3857888532, 929304360, 1810891574, 860971754, 1822893775, 2008495810,
        2958962335, 2340515744,
    ];

    // Group the 16 words into four input elements of four words each.
    let input_values: [_; 4] = std::array::from_fn(|chunk| {
        let base = 4 * chunk;
        qm31_from_u32s(words[base], words[base + 1], words[base + 2], words[base + 3])
    });

    let reference = blake_qm31(&input_values, n_bytes);

    let mut ctx = TraceContext::default();
    let input_vars = input_values.guess(&mut ctx);
    let digest = blake_from_gates(&mut ctx, &input_vars, n_bytes);

    // Constrain the in-circuit digest to equal the reference value.
    let want0 = guess(&mut ctx, reference.0);
    let want1 = guess(&mut ctx, reference.1);
    eq(&mut ctx, digest.0, want0);
    eq(&mut ctx, digest.1, want1);

    ctx.finalize_guessed_vars();
    ctx.circuit.check_yields();
    assert!(ctx.is_circuit_valid());
}

0 commit comments

Comments
 (0)