Skip to content

Commit bf3af12

Browse files
committed
Build broadcast qm31.
1 parent 12270af commit bf3af12

2 files changed

Lines changed: 123 additions & 0 deletions

File tree

crates/circuits/src/finalize_constants.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use stwo::core::fields::qm31::QM31;
77

88
use crate::circuit::{Add, Mul};
99
use crate::context::{Context, Var};
10+
use crate::eval;
1011
use crate::ivalue::{IValue, qm31_from_u32s};
1112
use crate::ops::{add, output};
1213

@@ -83,6 +84,23 @@ fn finalize_constants_with_min_base(context: &mut Context<impl IValue>, min_base
8384
// Decompose M31 constants not in the chain by expressing them in base `m31_base`.
8485
decompose_m31_constants(context, &mut m31_constants, &mut m31_cache, m31_base);
8586
assert!(m31_constants.is_empty());
87+
88+
// Deal with the QM31 constants.
89+
// Build `i` and `i * u` to get the qm31_basis [i, u, iu].
90+
// We know `2` is already produced because `min_base >= 2`.
91+
let two = Var { idx: *m31_cache.get(&2.into()).unwrap() };
92+
// `u * u = 2 + i`.
93+
let i_var = eval!(context, ((u_var) * (u_var)) - (two));
94+
let iu_var = eval!(context, (i_var) * (u_var));
95+
let qm31_basis: [Var; 3] = [i_var, u_var, iu_var];
96+
// Build the broadcast QM31 constants, i.e. constants of the form (x, x, x, x), x != 0.
97+
decompose_broadcast_constants(
98+
context,
99+
&mut qm31_constants,
100+
&mut m31_cache,
101+
m31_base,
102+
qm31_basis,
103+
);
86104
}
87105

88106
/// Finds the largest integer N such that all values in [0, N] are present as constants.
@@ -210,3 +228,50 @@ fn build_m31_from_base(
210228
assert!(m31_cache.contains_key(&val));
211229
assert!(!m31_constants.contains_key(&val));
212230
}
231+
232+
/// Yields and constrains every "broadcast" QM31 constant in `qm31_constants` (i.e. constants of
233+
/// the form `(x, x, x, x)` with `x != 0`) by expressing them as `x * (1, 1, 1, 1)`.
234+
///
235+
/// For each broadcast constant, the M31 scalar `x` is retrieved from `m31_cache` (and built via
236+
/// `build_m31_from_base` if missing), and a single Mul gate `x * (1,1,1,1) = (x,x,x,x)` yields and
237+
/// constrains the constant's wire.
238+
///
239+
/// Non-broadcast entries of `qm31_constants` are left untouched; broadcast entries are removed
240+
/// once yielded.
241+
fn decompose_broadcast_constants(
242+
context: &mut Context<impl IValue>,
243+
qm31_constants: &mut IndexMap<QM31, Var>,
244+
m31_cache: &mut HashMap<M31, usize>,
245+
base: M31,
246+
qm31_basis: [Var; 3],
247+
) {
248+
let one = context.one();
249+
let [i_var, u_var, iu_var] = qm31_basis;
250+
// Build and constrain the wire corresponding to (1, 1, 1, 1).
251+
let ones_value = qm31_from_u32s(1, 1, 1, 1);
252+
let ones_var = if let Some(var) = qm31_constants.swap_remove(&ones_value) {
253+
var
254+
} else {
255+
context.new_var(IValue::from_qm31(ones_value))
256+
};
257+
let one_plus_i = add(context, one, i_var);
258+
let u_plus_iu = add(context, u_var, iu_var);
259+
context.circuit.add.push(Add { in0: one_plus_i.idx, in1: u_plus_iu.idx, out: ones_var.idx });
260+
261+
qm31_constants.retain(|qm31_value, qm31_var| {
262+
let is_broadcast = qm31_value.to_m31_array().iter().tuple_windows().all(|(x, y)| x == y);
263+
if !is_broadcast {
264+
return true;
265+
}
266+
let m31_value = qm31_value.0.0;
267+
// If m31_value is not in the cache, add it.
268+
if !m31_cache.contains_key(&m31_value) {
269+
build_m31_from_base(context, m31_cache, &mut IndexMap::new(), base, m31_value);
270+
}
271+
let m31_idx = *m31_cache.get(&m31_value).unwrap();
272+
// Add a gate m31_val * (1, 1, 1, 1) = qm31_var.
273+
context.circuit.mul.push(Mul { in0: m31_idx, in1: ones_var.idx, out: qm31_var.idx });
274+
// Remove the element from qm31_constants.
275+
false
276+
});
277+
}

crates/circuits/src/finalize_constants_test.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,13 @@ fn test_plus_one_chain_topology() {
4141
[4] = [5] + [1]
4242
[6] = [4] + [1]
4343
[7] = [6] + [1]
44+
[12] = [1] + [9]
45+
[13] = [2] + [10]
46+
[11] = [12] + [13]
47+
[9] = [8] - [3]
4448
[2] = [2] * [1]
49+
[8] = [2] * [2]
50+
[10] = [9] * [2]
4551
output [2]
4652
"#]]
4753
.assert_eq(&format!("{:?}", context.circuit));
@@ -75,8 +81,14 @@ fn test_large_m31_decomposition() {
7581
[7] = [6] + [1]
7682
[8] = [7] + [4]
7783
[3] = [9] + [4]
84+
[14] = [1] + [11]
85+
[15] = [2] + [12]
86+
[13] = [14] + [15]
87+
[11] = [10] - [4]
7888
[2] = [2] * [1]
7989
[9] = [8] * [7]
90+
[10] = [2] * [2]
91+
[12] = [11] * [2]
8092
output [2]
8193
"#]]
8294
.assert_eq(&format!("{:?}", context.circuit));
@@ -88,3 +100,49 @@ fn test_large_m31_decomposition() {
88100
context.circuit.check_yields();
89101
context.validate_circuit();
90102
}
103+
104+
#[test]
105+
fn test_broadcast_decomposition() {
106+
let mut context = TraceContext::default();
107+
// Add `u`.
108+
// TODO(Leo): remove this once `u` is added to the default constants.
109+
context.constant(qm31_from_u32s(0, 0, 1, 0));
110+
// Broadcast constant (11, 11, 11, 11) — should be yielded as 11 * (1, 1, 1, 1). Since 11 is
111+
// outside the chain (`min_base = 5`), the M31 factor 11 is itself built via base
112+
// decomposition: 11 = 2 * 5 + 1.
113+
context.constant(qm31_from_u32s(11, 11, 11, 11));
114+
finalize_constants_with_min_base(&mut context, 5);
115+
116+
// The plus-one chain populates [4]..=[7] for values 2..=5. The QM31 basis allocates
117+
// [8] = u*u, [9] = u² - 2 = i, [10] = i*u = iu. The ones vector (1, 1, 1, 1) lands in [11],
118+
// built as ([1] + [9]) + ([2] + [10]), with the partial sums in [12] and [13]. Then the M31
119+
// factor 11 is decomposed in base 5 (11 = 2*5 + 1): [14] = [4] * [7] (= 2 * 5 = 10) and
120+
// [15] = [14] + [1] (= 11). Finally the broadcast is yielded by [3] = [15] * [11]
121+
// (11 * ones).
122+
expect![[r#"
123+
[0] = [0] + [0]
124+
[1] = [1] + [0]
125+
[4] = [1] + [1]
126+
[5] = [4] + [1]
127+
[6] = [5] + [1]
128+
[7] = [6] + [1]
129+
[12] = [1] + [9]
130+
[13] = [2] + [10]
131+
[11] = [12] + [13]
132+
[15] = [14] + [1]
133+
[9] = [8] - [4]
134+
[2] = [2] * [1]
135+
[8] = [2] * [2]
136+
[10] = [9] * [2]
137+
[14] = [4] * [7]
138+
[3] = [15] * [11]
139+
output [2]
140+
"#]]
141+
.assert_eq(&format!("{:?}", context.circuit));
142+
143+
assert_eq!(context.get(Var { idx: 11 }), qm31_from_u32s(1, 1, 1, 1));
144+
assert_eq!(context.get(Var { idx: 15 }), M31::from(11u32).into());
145+
assert_eq!(context.get(Var { idx: 3 }), qm31_from_u32s(11, 11, 11, 11));
146+
context.circuit.check_yields();
147+
context.validate_circuit();
148+
}

0 commit comments

Comments
 (0)