|
1 | 1 | use blake2::{Blake2s256, Digest}; |
2 | 2 | use itertools::Itertools; |
| 3 | +use stwo::core::fields::m31::M31; |
3 | 4 | use stwo::core::vcs::blake2_hash::Blake2sHash; |
4 | 5 | use stwo::core::{fields::qm31::QM31, vcs::blake2_hash::reduce_to_m31}; |
5 | 6 |
|
6 | 7 | use crate::circuit::{Blake, BlakeGGate, M31ToU32, TripleXor}; |
7 | 8 | use crate::context::{Context, Var}; |
| 9 | +use crate::eval; |
8 | 10 | use crate::ivalue::{IValue, qm31_from_u32s}; |
9 | | -use crate::ops::Guess; |
| 11 | +use crate::ops::{Guess, from_partial_evals}; |
| 12 | +use crate::simd::Simd; |
10 | 13 |
|
11 | 14 | #[cfg(test)] |
12 | 15 | #[path = "blake_test.rs"] |
@@ -127,6 +130,150 @@ pub fn blake<Value: IValue>( |
127 | 130 | HashValue(out_var0, out_var1) |
128 | 131 | } |
129 | 132 |
|
| 133 | +/// Blake2s initialization vector (IV): the eight 32-bit words that seed the hash state.
| 134 | +const BLAKE2S_IV: [u32; 8] = [
| 135 | + 0x6A09E667, 0xBB67AE85, 0x3C6EF372, 0xA54FF53A, 0x510E527F, 0x9B05688C, 0x1F83D9AB, 0x5BE0CD19,
| 136 | +];
| 137 | + |
| 138 | +/// Blake2s message schedule (10 rounds × 16): row `r` orders the message words for round `r`.
| 139 | +const BLAKE_SIGMA: [[u8; 16]; 10] = [
| 140 | + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
| 141 | + [14, 10, 4, 8, 9, 15, 13, 6, 1, 12, 0, 2, 11, 7, 5, 3],
| 142 | + [11, 8, 12, 0, 5, 2, 15, 13, 10, 14, 3, 6, 7, 1, 9, 4],
| 143 | + [7, 9, 3, 1, 13, 12, 11, 14, 2, 6, 5, 10, 4, 0, 15, 8],
| 144 | + [9, 0, 5, 7, 2, 4, 10, 15, 14, 1, 11, 12, 6, 8, 3, 13],
| 145 | + [2, 12, 6, 10, 0, 11, 8, 3, 4, 13, 7, 5, 15, 14, 1, 9],
| 146 | + [12, 5, 1, 15, 14, 13, 4, 10, 0, 7, 6, 3, 9, 2, 8, 11],
| 147 | + [13, 11, 7, 14, 12, 1, 3, 9, 5, 0, 15, 4, 8, 6, 2, 10],
| 148 | + [6, 15, 14, 9, 11, 3, 0, 8, 12, 2, 13, 7, 1, 4, 10, 5],
| 149 | + [10, 2, 8, 4, 7, 6, 1, 5, 15, 11, 9, 14, 3, 12, 13, 0],
| 150 | +];
| 151 | + |
| 152 | +/// Working-state indices `(a, b, c, d)` sent to each `G` call of a round: 4 columns, 4 diagonals.
| 153 | +const G_STATE_INDICES: [(usize, usize, usize, usize); 8] = [
| 154 | + (0, 4, 8, 12),
| 155 | + (1, 5, 9, 13),
| 156 | + (2, 6, 10, 14),
| 157 | + (3, 7, 11, 15),
| 158 | + (0, 5, 10, 15),
| 159 | + (1, 6, 11, 12),
| 160 | + (2, 7, 8, 13),
| 161 | + (3, 4, 9, 14),
| 162 | +];
| 163 | + |
| 164 | +const N_G_CALLS_PER_ROUND: usize = 8; // One `G` call per `G_STATE_INDICES` entry.
| 165 | + |
| 166 | +#[inline]
| 167 | +fn u32_packed_constant<Value: IValue>(ctx: &mut Context<Value>, x: u32) -> Var {
| 168 | + ctx.constant(qm31_from_u32s(x & 0xffff, x >> 16, 0, 0)) // `x` as (low 16 bits, high 16 bits, 0, 0).
| 169 | +}
| 170 | + |
| 171 | +/// Adds a Blake2s hash of `input` to the circuit, built from the decomposed gates `m31_to_u32`,
| 172 | +/// `blake_g_gate` and `triple_xor`, and returns the two output variables as [`HashValue`].
| 173 | +///
| 174 | +/// NOTE: If `n_bytes` is not a multiple of 16, the caller must make sure that the unused trailing
| 175 | +/// bytes of the last [`QM31`] are zero.
| 176 | +/// For example, if `n_bytes` is 4, only the first coordinate of the [`QM31`] may be non-zero.
| 177 | +/// If `n_bytes` is 1, that coordinate must be < 256. Panics on an `input`/`n_bytes` mismatch.
| 178 | +pub fn blake_from_gates<Value: IValue>(
| 179 | + ctx: &mut Context<Value>,
| 180 | + input: &[Var],
| 181 | + n_bytes: usize,
| 182 | +) -> HashValue<Var> {
| 183 | + // Sanity check: `input` must hold exactly `ceil(n_bytes / 16)` [QM31] values (16 bytes each).
| 184 | + assert_eq!(input.len(), n_bytes.div_ceil(16));
| 185 | +
| 186 | + const BLOCK_BYTES: usize = 64;
| 187 | + const WORDS_PER_BLOCK: usize = 16;
| 188 | +
| 189 | + // Unpack each QM31 message chunk into four coordinates, each converted by `m31_to_u32`.
| 190 | + let mut message_u32s: Vec<Var> = Vec::new();
| 191 | + for &var in input {
| 192 | + let simd = Simd::from_packed(vec![var], 4);
| 193 | + for coord in 0..4 {
| 194 | + let comp = Simd::unpack_idx(ctx, &simd, coord);
| 195 | + message_u32s.push(m31_to_u32(ctx, comp));
| 196 | + }
| 197 | + }
| 198 | +
| 199 | + let n_blocks = std::cmp::max(1, n_bytes.div_ceil(BLOCK_BYTES)); // At least one block.
| 200 | + let total_words = n_blocks * WORDS_PER_BLOCK;
| 201 | + let zero_u32 = u32_packed_constant(ctx, 0);
| 202 | + while message_u32s.len() < total_words { // Zero-pad to a whole number of blocks.
| 203 | + message_u32s.push(zero_u32);
| 204 | + }
| 205 | +
| 206 | + // `h`: IV XORed with the parameter block (depth 1, fanout 1, digest length 32, key length 0).
| 207 | + let mut h: [Var; 8] = std::array::from_fn(|i| {
| 208 | + let iv_val = if i == 0 { BLAKE2S_IV[0] ^ 0x01010020 } else { BLAKE2S_IV[i] };
| 209 | + u32_packed_constant(ctx, iv_val)
| 210 | + });
| 211 | +
| 212 | + for block_idx in 0..n_blocks {
| 213 | + let block: [Var; WORDS_PER_BLOCK] =
| 214 | + std::array::from_fn(|i| message_u32s[block_idx * WORDS_PER_BLOCK + i]);
| 215 | + let t0 = std::cmp::min(n_bytes, (block_idx + 1) * BLOCK_BYTES) as u32;
| 216 | + let t1 = 0u32; // High counter word is fixed at 0; assumes `n_bytes` < 2^32 (TODO confirm).
| 217 | + let last = block_idx == n_blocks - 1; // Final block: inverts v[14] below.
| 218 | +
| 219 | + let prev_h = h; // Chaining value entering this block, for the feed-forward XOR below.
| 220 | +
| 221 | + let mut v: [Var; 16] = std::array::from_fn(|i| { // Working state: h || tweaked IV.
| 222 | + if i < 8 {
| 223 | + h[i]
| 224 | + } else {
| 225 | + let mut iv = BLAKE2S_IV[i - 8];
| 226 | + if i == 12 {
| 227 | + iv ^= t0; // Counter, low word.
| 228 | + }
| 229 | + if i == 13 {
| 230 | + iv ^= t1; // Counter, high word.
| 231 | + }
| 232 | + if i == 14 && last {
| 233 | + iv ^= 0xFFFF_FFFF; // Finalization flag.
| 234 | + }
| 235 | + u32_packed_constant(ctx, iv)
| 236 | + }
| 237 | + });
| 238 | +
| 239 | + for permutation in &BLAKE_SIGMA { // 10 rounds, 8 `G` calls each.
| 240 | + for g_idx in 0..N_G_CALLS_PER_ROUND {
| 241 | + let (ai, bi, ci, di) = G_STATE_INDICES[g_idx];
| 242 | + let (new_a, new_b, new_c, new_d) = blake_g_gate(
| 243 | + ctx,
| 244 | + v[ai],
| 245 | + v[bi],
| 246 | + v[ci],
| 247 | + v[di],
| 248 | + block[permutation[g_idx * 2] as usize],
| 249 | + block[permutation[g_idx * 2 + 1] as usize],
| 250 | + );
| 251 | + v[ai] = new_a;
| 252 | + v[bi] = new_b;
| 253 | + v[ci] = new_c;
| 254 | + v[di] = new_d;
| 255 | + }
| 256 | + }
| 257 | +
| 258 | + for i in 0..8 { // Feed-forward: h = h ^ v[0..8] ^ v[8..16].
| 259 | + h[i] = triple_xor(ctx, prev_h[i], v[i], v[i + 8]);
| 260 | + }
| 261 | + }
| 262 | +
| 263 | + let c_2_pow_16 = ctx.constant(M31::from(1u32 << 16).into());
| 264 | + let reduced: [Var; 8] = std::array::from_fn(|i| { // (low, high) pair -> low + high * 2^16.
| 265 | + let h_simd = Simd::from_packed(vec![h[i]], 2);
| 266 | + let low = Simd::unpack_idx(ctx, &h_simd, 0);
| 267 | + let high = Simd::unpack_idx(ctx, &h_simd, 1);
| 268 | + eval!(ctx, (low) + ((high) * (c_2_pow_16)))
| 269 | + });
| 270 | +
| 271 | + let out0 = from_partial_evals(ctx, [reduced[0], reduced[1], reduced[2], reduced[3]]);
| 272 | + let out1 = from_partial_evals(ctx, [reduced[4], reduced[5], reduced[6], reduced[7]]);
| 273 | +
| 274 | + HashValue(out0, out1)
| 275 | +}
| 276 | + |
130 | 277 | /// Adds a TripleXor gate to the circuit: XOR three u32 values encoded as QM31 `(u16, u16, 0, 0)` |
131 | 278 | /// and return the result in the same encoding. |
132 | 279 | pub fn triple_xor<Value: IValue>( |
|
0 commit comments