Skip to content

Commit 8bc157a

Browse files
committed
Build broadcast qm31.
1 parent 6a11158 commit 8bc157a

2 files changed

Lines changed: 103 additions & 0 deletions

File tree

crates/circuits/src/finalize_constants.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,21 @@ fn finalize_constants_with_min_base(context: &mut Context<impl IValue>, min_base
8888
// Decompose M31 constants not in the chain by expressing them in base `m31_base`.
8989
decompose_m31_constants(context, &mut m31_constants, &mut m31_cache, m31_base);
9090
assert!(m31_constants.is_empty());
91+
92+
// Deal with the QM31 constants.
93+
// Build `i` and `i * u` to get the qm31_basis [i, u, iu].
94+
let two = Var { idx: *m31_cache.get(&2.into()).unwrap() };
95+
let i_var = eval!(context, ((u_var) * (u_var)) - (two));
96+
let iu_var = eval!(context, (i_var) * (u_var));
97+
let qm31_basis: [Var; 3] = [i_var, u_var, iu_var];
98+
// Build the broadcast QM31 constants, i.e. constants of the form (x, x, x, x), x != 0.
99+
decompose_broadcast_constants(
100+
context,
101+
&mut qm31_constants,
102+
&mut m31_cache,
103+
m31_base,
104+
qm31_basis,
105+
);
91106
}
92107

93108
/// Finds the largest integer N such that all values in [0, N] are present as constants.
@@ -208,3 +223,32 @@ fn build_m31_from_base(
208223
assert!(m31_cache.contains_key(&val));
209224
assert!(!m31_constants.contains_key(&val));
210225
}
226+
227+
fn decompose_broadcast_constants(
228+
context: &mut Context<impl IValue>,
229+
qm31_constants: &mut IndexMap<QM31, Var>,
230+
m31_cache: &mut HashMap<M31, usize>,
231+
base: M31,
232+
qm31_basis: [Var; 3],
233+
) {
234+
let [i_var, u_var, iu_var] = qm31_basis;
235+
let one = context.one();
236+
let ones = eval!(context, ((one) + (i_var)) + ((u_var) + (iu_var)));
237+
238+
qm31_constants.retain(|qm31_value, qm31_var| {
239+
let is_broadcast = qm31_value.to_m31_array().iter().tuple_windows().all(|(x, y)| x == y);
240+
if !is_broadcast {
241+
return true;
242+
}
243+
let m31_value = qm31_value.0.0;
244+
// If m31_value is not in the cache, add it.
245+
if !m31_cache.contains_key(&m31_value) {
246+
build_m31_from_base(context, m31_cache, &mut IndexMap::new(), base, m31_value);
247+
}
248+
let m31_idx = *m31_cache.get(&m31_value).unwrap();
249+
// Add a gate m31_val * (1, 1, 1, 1) = qm31_var.
250+
context.circuit.mul.push(Mul { in0: m31_idx, in1: ones.idx, out: qm31_var.idx });
251+
// Remove the element from qm31_constants.
252+
false
253+
});
254+
}

crates/circuits/src/finalize_constants_test.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,13 @@ fn test_plus_one_chain_topology() {
4242
[4] = [6] + [1]
4343
[7] = [4] + [1]
4444
[8] = [7] + [1]
45+
[12] = [1] + [10]
46+
[13] = [2] + [11]
47+
[14] = [12] + [13]
48+
[10] = [9] - [3]
4549
[5] = [2] * [1]
50+
[9] = [2] * [2]
51+
[11] = [10] * [2]
4652
[5] = [2]
4753
output [2]
4854
"#]]
@@ -78,8 +84,14 @@ fn test_large_m31_decomposition() {
7884
[8] = [7] + [1]
7985
[9] = [8] + [5]
8086
[3] = [10] + [5]
87+
[14] = [1] + [12]
88+
[15] = [2] + [13]
89+
[16] = [14] + [15]
90+
[12] = [11] - [5]
8191
[4] = [2] * [1]
8292
[10] = [9] * [8]
93+
[11] = [2] * [2]
94+
[13] = [12] * [2]
8395
[4] = [2]
8496
output [2]
8597
"#]]
@@ -92,3 +104,50 @@ fn test_large_m31_decomposition() {
92104
context.circuit.check_yields();
93105
context.validate_circuit();
94106
}
107+
108+
#[test]
109+
fn test_broadcast_decomposition() {
110+
let mut context = TraceContext::default();
111+
// Add `u`.
112+
// TODO(Leo): remove this once `u` is added to the default constants.
113+
context.constant(qm31_from_u32s(0, 0, 1, 0));
114+
// Broadcast constant (11, 11, 11, 11) — should be yielded as 11 * (1, 1, 1, 1). Since 11 is
115+
// outside the chain (`min_base = 5`), the M31 factor 11 is itself built via base
116+
// decomposition: 11 = 2 * 5 + 1.
117+
context.constant(qm31_from_u32s(11, 11, 11, 11));
118+
finalize_constants_with_min_base(&mut context, 5);
119+
120+
// The plus-one chain populates [5]..=[8] for values 2..=5. The QM31 basis allocates [9] = u*u,
121+
// [10] = u² - 2 = i, [11] = i*u = iu. The ones vector is built as ([1] + [10]) + ([2] + [11])
122+
// yielding wires [12], [13], [14]. Then the M31 factor 11 is decomposed in base 5:
123+
// [15] = [5] * [8] (= 2 * 5 = 10) and [16] = [15] + [1] (= 11). Finally the broadcast is
124+
// yielded by [3] = [16] * [14] (11 * ones).
125+
expect![[r#"
126+
[0] = [0] + [0]
127+
[2] = [2] + [0]
128+
[1] = [1] + [0]
129+
[5] = [1] + [1]
130+
[6] = [5] + [1]
131+
[7] = [6] + [1]
132+
[8] = [7] + [1]
133+
[12] = [1] + [10]
134+
[13] = [2] + [11]
135+
[14] = [12] + [13]
136+
[16] = [15] + [1]
137+
[10] = [9] - [5]
138+
[4] = [2] * [1]
139+
[9] = [2] * [2]
140+
[11] = [10] * [2]
141+
[15] = [5] * [8]
142+
[3] = [16] * [14]
143+
[4] = [2]
144+
output [2]
145+
"#]]
146+
.assert_eq(&format!("{:?}", context.circuit));
147+
148+
assert_eq!(context.get(Var { idx: 14 }), qm31_from_u32s(1, 1, 1, 1));
149+
assert_eq!(context.get(Var { idx: 16 }), M31::from(11u32).into());
150+
assert_eq!(context.get(Var { idx: 3 }), qm31_from_u32s(11, 11, 11, 11));
151+
context.circuit.check_yields();
152+
context.validate_circuit();
153+
}

0 commit comments

Comments
 (0)