Skip to content

Commit b3cc405

Browse files
committed
Add finalize_constants file.
1 parent ad259de commit b3cc405

3 files changed

Lines changed: 196 additions & 0 deletions

File tree

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
use std::collections::HashMap;
2+
3+
use indexmap::IndexMap;
4+
use itertools::Itertools;
5+
use stwo::core::fields::m31::M31;
6+
use stwo::core::fields::qm31::QM31;
7+
8+
use crate::circuit::Add;
9+
use crate::context::{Context, Var};
10+
use crate::eval;
11+
use crate::ivalue::{IValue, qm31_from_u32s};
12+
use crate::ops::{add, eq, output};
13+
14+
#[cfg(test)]
15+
#[path = "finalize_constants_test.rs"]
16+
mod test;
17+
18+
/// Wraps `finalize_constants_with_min_base` and calls it with a default value. The main reason
19+
/// is to make testing easier by choosing a smaller minimum base.
20+
// TODO(Leo): remove allow once integrated in the main flow.
21+
#[allow(unused)]
22+
fn finalize_constants(context: &mut Context<impl IValue>) {
23+
const DEFAULT_MIN_BASE: usize = 256;
24+
finalize_constants_with_min_base(context, DEFAULT_MIN_BASE);
25+
}
26+
27+
/// Yields and constrains every constant in `context.constants()` via arithmetic gates,
28+
/// All constants are derived from the QM31 extension element `u = (0, 0, 1, 0)` by using:
29+
/// - A `+1` chain for consecutive M31 integer constants.
30+
/// - Base decomposition (with a dynamic base B) for larger M31 values.
31+
/// - Broadcast optimization for constants of the form `(x, x, x, x)`.
32+
/// - QM31 basis combination (`i`, `u`, `iu`) for general extension-field constants.
33+
///
34+
/// # Notes
35+
///
36+
/// The `context.constants()` are tracked in two `IndexMap`s — `m31_constants` for values of the
37+
/// form `(x, 0, 0, 0)` and `qm31_constants` for everything else. As each constant is yielded by a
38+
/// gate (or, for `u`, by the public-output logup term), it is removed from its map. At the end
39+
/// of `finalize_constants` both maps must be empty,
40+
///
41+
/// `m31_cache` maps each M31 value constructed in the process to its Var idx. Subsequent
42+
/// decomposition steps (M31 limbs of QM31 constants, broadcast factors, etc.) reuse cached entries
43+
/// instead of rebuilding them, so each distinct M31 value gets at most one yield gate.
44+
///
45+
/// `IndexMap` is used (rather than `HashMap`) so that iteration order is deterministic.
46+
fn finalize_constants_with_min_base(context: &mut Context<impl IValue>, min_base: usize) {
47+
assert!(min_base >= 2);
48+
let mut m31_constants = IndexMap::<M31, Var>::new();
49+
let mut qm31_constants = IndexMap::<QM31, Var>::new();
50+
let mut m31_cache = HashMap::<M31, usize>::new();
51+
// Populate the maps.
52+
context.constants().iter().for_each(|(val, var)| {
53+
if let [x, M31(0), M31(0), M31(0)] = val.to_m31_array() {
54+
m31_constants.insert(x, *var);
55+
} else {
56+
qm31_constants.insert(*val, *var);
57+
}
58+
});
59+
let target_consecutive = find_max_consecutive(&m31_constants).max(min_base);
60+
61+
// Yield and constrain the `zero` wire by adding a gate x + x = x.
62+
let zero_idx = context.zero().idx;
63+
context.circuit.add.push(Add { in0: zero_idx, in1: zero_idx, out: zero_idx });
64+
m31_cache.insert(M31(0), m31_constants.swap_remove(&M31(0)).unwrap().idx);
65+
66+
// Yield the `u` wire by adding a trivial gate x + 0, then add it to the outputs. The constraint
67+
// on the wire comes from the next verifier which will need to add a logup_use_term with value
68+
// (0,0,1,0) to its public logup sum.
69+
let u_var = context.u();
70+
context.circuit.add.push(Add { in0: u_var.idx, in1: zero_idx, out: u_var.idx });
71+
output(context, u_var);
72+
qm31_constants.swap_remove(&qm31_from_u32s(0, 0, 1, 0));
73+
74+
// Yield the `one` wire by adding a trivial gate x + 0 = x, then constrain it by adding a gate x
75+
// * u = u.
76+
let one = context.one();
77+
context.circuit.add.push(Add { in0: one.idx, in1: zero_idx, out: one.idx });
78+
let u_times_one = eval!(context, (u_var) * (one));
79+
eq(context, u_times_one, u_var);
80+
m31_cache.insert(M31(1), m31_constants.swap_remove(&M31(1)).unwrap().idx);
81+
82+
// Build the +1 chain for consecutive M31 constants.
83+
build_plus_one_chain(context, &mut m31_constants, &mut m31_cache, target_consecutive);
84+
}
85+
86+
/// Finds the largest integer N such that all values in [0, N] are present as constants.
87+
///
88+
/// # Panics
89+
///
90+
/// Panics if `m31_constants` doesn't contain zero.
91+
fn find_max_consecutive(m31_constants: &IndexMap<M31, Var>) -> usize {
92+
assert!(m31_constants.contains_key(&M31(0)));
93+
let m31_values = m31_constants.keys().map(|k| k.0).sorted();
94+
// After sorting, a consecutive run from 0 satisfies m31_values[i] == i.
95+
let n_consecutive =
96+
m31_values.enumerate().position(|(i, v)| i != v as usize).unwrap_or(m31_constants.len());
97+
// The assert at the beginning ensures that n_consecutive > 0, so this subtraction does not
98+
// overflow.
99+
n_consecutive - 1
100+
}
101+
102+
/// Builds the +1 chain: Add gates for 1+1=2, 2+1=3, ..., up to `target_consecutive`.
103+
///
104+
/// For each value, if a constant with that M31 value was requested, the Add gate outputs
105+
/// directly to the reserved Var idx. Otherwise a fresh Var is allocated.
106+
fn build_plus_one_chain(
107+
context: &mut Context<impl IValue>,
108+
m31_constants: &mut IndexMap<M31, Var>,
109+
m31_cache: &mut HashMap<M31, usize>,
110+
target_consecutive: usize,
111+
) {
112+
let one_idx = context.one().idx;
113+
let mut prev_var = context.one();
114+
115+
for val in 2..=target_consecutive {
116+
let next_var = if let Some(v) = m31_constants.swap_remove(&M31::from(val)) {
117+
context.circuit.add.push(Add { in0: prev_var.idx, in1: one_idx, out: v.idx });
118+
v
119+
} else {
120+
add(context, prev_var, context.one())
121+
};
122+
m31_cache.insert(val.into(), next_var.idx);
123+
prev_var = next_var;
124+
}
125+
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
use expect_test::expect;
2+
use stwo::core::fields::m31::M31;
3+
4+
use super::*;
5+
use crate::context::TraceContext;
6+
7+
#[test]
8+
fn test_no_constants_beyond_defaults() {
9+
let mut context = TraceContext::default();
10+
// Add `u`.
11+
// TODO(Leo): remove this once `u` is added to the default constants.
12+
context.constant(qm31_from_u32s(0, 0, 1, 0));
13+
finalize_constants(&mut context);
14+
context.circuit.check_yields();
15+
context.validate_circuit();
16+
}
17+
18+
#[test]
19+
fn test_small_consecutive_m31_constants() {
20+
let mut context = TraceContext::default();
21+
// Add `u`.
22+
// TODO(Leo): remove this once `u` is added to the default constants.
23+
context.constant(qm31_from_u32s(0, 0, 1, 0));
24+
for i in 0u32..10 {
25+
context.constant(i.into());
26+
}
27+
finalize_constants(&mut context);
28+
context.circuit.check_yields();
29+
context.validate_circuit();
30+
}
31+
32+
#[test]
33+
fn test_plus_one_chain_topology() {
34+
let mut context = TraceContext::default();
35+
context.constant(qm31_from_u32s(0, 0, 1, 0));
36+
context.constant(M31::from(2u32).into());
37+
context.constant(M31::from(4u32).into());
38+
let m31_constants = IndexMap::from([
39+
(0.into(), Var { idx: 0 }),
40+
(1.into(), Var { idx: 1 }),
41+
(2.into(), Var { idx: 3 }),
42+
(4.into(), Var { idx: 4 }),
43+
]);
44+
assert_eq!(find_max_consecutive(&m31_constants), 2);
45+
// `min_base = 6` and `find_max_consecutive` returns 2 (gap at 3), so the chain runs 2..=6.
46+
finalize_constants_with_min_base(&mut context, 6);
47+
48+
expect![[r#"
49+
[0] = [0] + [0]
50+
[2] = [2] + [0]
51+
[1] = [1] + [0]
52+
[3] = [1] + [1]
53+
[6] = [3] + [1]
54+
[4] = [6] + [1]
55+
[7] = [4] + [1]
56+
[8] = [7] + [1]
57+
[5] = [2] * [1]
58+
[5] = [2]
59+
output [2]
60+
"#]]
61+
.assert_eq(&format!("{:?}", context.circuit));
62+
63+
// The chain populated fresh vars [6], [7], [8] with values 3, 5, 6 respectively.
64+
assert_eq!(context.get(Var { idx: 6 }), M31::from(3u32).into());
65+
assert_eq!(context.get(Var { idx: 7 }), M31::from(5u32).into());
66+
assert_eq!(context.get(Var { idx: 8 }), M31::from(6u32).into());
67+
68+
context.circuit.check_yields();
69+
context.validate_circuit();
70+
}

crates/circuits/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ pub mod blake;
44
pub mod circuit;
55
pub mod context;
66
pub mod extract_bits;
7+
pub mod finalize_constants;
78
pub mod ivalue;
89
pub mod ops;
910
pub mod simd;

0 commit comments

Comments
 (0)