@@ -9,6 +9,7 @@ namespace alu;
 pol commit sel;
 
 pol commit sel_op_add;
+pol commit sel_op_mul;
 pol commit sel_op_eq;
 pol commit sel_op_lt;
 pol commit sel_op_lte;
@@ -37,24 +38,28 @@ pol commit cf;
 pol commit helper1;
 
 // maximum bits the number can hold (i.e. 8 for a u8):
-// TODO(MW): Now unused since we redirect the LT/LTE range checks to GT gadget - remove?
 pol commit max_bits;
 // maximum value the number can hold (i.e. 255 for a u8), we 'mod' by max_value + 1
 pol commit max_value;
 // we need a selector to conditionally lookup ff_gt when inputs a, b are fields:
 pol commit sel_is_ff;
+// we need a selector to conditionally perform u128 multiplication:
+pol commit sel_is_u128;
 
 pol IS_NOT_FF = 1 - sel_is_ff;
+pol IS_NOT_U128 = 1 - sel_is_u128;
 
 sel * (1 - sel) = 0;
 cf * (1 - cf) = 0;
 sel_is_ff * (1 - sel_is_ff) = 0;
+sel_is_u128 * (1 - sel_is_u128) = 0;
 
 // TODO: Consider to gate with (1 - sel_tag_err) for op_id. This might help us remove the (1 - sel_tag_err)
 // in various operation relations below.
 // Note that the op_ids below represent a binary decomposition (see constants_gen.pil):
 #[OP_ID_CHECK]
 op_id = sel_op_add * constants.AVM_EXEC_OP_ID_ALU_ADD
+      + sel_op_mul * constants.AVM_EXEC_OP_ID_ALU_MUL
       + sel_op_eq * constants.AVM_EXEC_OP_ID_ALU_EQ
       + sel_op_lt * constants.AVM_EXEC_OP_ID_ALU_LT
       + sel_op_lte * constants.AVM_EXEC_OP_ID_ALU_LTE
@@ -78,18 +83,28 @@ execution.sel_execute_alu {
 
 // IS_FF CHECKING
 
-// TODO(MW): remove this and check for all (i.e. replace with sel), just being lazy for now. For add, we don't care, for lt we need to differentiate.
 pol CHECK_TAG_FF = sel_op_lt + sel_op_lte + sel_op_not;
 // We prove that sel_is_ff == 1 <==> ia_tag == MEM_TAG_FF
 pol TAG_FF_DIFF = ia_tag - constants.MEM_TAG_FF;
 pol commit tag_ff_diff_inv;
 #[TAG_IS_FF]
 CHECK_TAG_FF * (TAG_FF_DIFF * (sel_is_ff * (1 - tag_ff_diff_inv) + tag_ff_diff_inv) + sel_is_ff - 1) = 0;
 
+// IS_U128 CHECKING
+
+pol CHECK_TAG_U128 = sel_op_mul;
+// We prove that sel_is_u128 == 1 <==> ia_tag == MEM_TAG_U128
+pol TAG_U128_DIFF = ia_tag - constants.MEM_TAG_U128;
+pol commit tag_u128_diff_inv;
+#[TAG_IS_U128]
+CHECK_TAG_U128 * (TAG_U128_DIFF * (sel_is_u128 * (1 - tag_u128_diff_inv) + tag_u128_diff_inv) + sel_is_u128 - 1) = 0;
+
+// Note: if we never need sel_is_ff and sel_is_u128 in the same op, can combine the above checks into one
+
 // TAG CHECKING
 
 // Will become e.g. sel_op_add * ia_tag + (comparison ops) * MEM_TAG_U1 + ....
-pol EXPECTED_C_TAG = (sel_op_add + sel_op_truncate) * ia_tag + (sel_op_eq + sel_op_lt + sel_op_lte) * constants.MEM_TAG_U1;
+pol EXPECTED_C_TAG = (sel_op_add + sel_op_truncate + sel_op_mul ) * ia_tag + (sel_op_eq + sel_op_lt + sel_op_lte) * constants.MEM_TAG_U1;
 
 // The tag of c is generated by the opcode and is never wrong.
 // Gating with (1 - sel_tag_err) is necessary because when an error occurs, we have to set the tag to 0,
@@ -123,6 +138,81 @@ sel_op_add * (1 - sel_op_add) = 0;
 #[ALU_ADD]
 sel_op_add * (1 - sel_tag_err) * (ia + ib - ic - cf * (max_value + 1)) = 0;
 
+// MUL
+
+sel_op_mul * (1 - sel_op_mul) = 0;
+
+pol commit c_hi;
+
+// MUL - non u128
+
+#[ALU_MUL_NON_U128]
+sel_op_mul * IS_NOT_U128 * (1 - sel_tag_err)
+    * (
+        ia * ib
+        - ic
+        - (max_value + 1) * c_hi
+    ) = 0;
+
+// MUL - u128
+
+pol commit sel_mul_u128;
+// sel_op_mul & sel_is_u128:
+sel_mul_u128 - sel_is_u128 * sel_op_mul = 0;
+
+// Taken from vm1:
+// We express a, b in 64-bit slices: a = a_l + a_h * 2^64
+//                                   b = b_l + b_h * 2^64
+// => a * b = a_l * b_l + (a_h * b_l + a_l * b_h) * 2^64 + (a_h * b_h) * 2^128 = c_hi_full * 2^128 + c
+// => the 'top bits' are given by (c_hi_full - (a_h * b_h)) * 2^128
+// We can show for a 64 bit c_hi = c_hi_full - (a_h * b_h) % 2^64 that:
+// a_l * b_l + (a_h * b_l + a_l * b_h) * 2^64 = c_hi * 2^128 + c
+// Equivalently (cf = 0 if a_h & b_h = 0):
+// a * b_l + a_l * b_h * 2^64 = (cf * 2^64 + c_hi) * 2^128 + c
+// => no need for a_h in final relation
+
+pol commit a_lo;
+pol commit a_hi;
+pol commit b_lo;
+pol commit b_hi;
+pol TWO_POW_64 = 2 ** 64;
+
+#[A_MUL_DECOMPOSITION]
+sel_mul_u128 * (ia - (a_lo + TWO_POW_64 * a_hi)) = 0;
+#[B_MUL_DECOMPOSITION]
+sel_mul_u128 * (ib - (b_lo + TWO_POW_64 * b_hi)) = 0;
+
+#[ALU_MUL_U128]
+sel_mul_u128 * (1 - sel_tag_err)
+    * (
+        ia * b_lo + a_lo * b_hi * TWO_POW_64 // a * b without the hi bits
+        - ic // c_lo
+        - (max_value + 1) * (cf * TWO_POW_64 + c_hi) // c_hi * 2^128 + (cf ? 2^192 : 0)
+    ) = 0;
+
+// TODO: Once lookups support expression in tuple, we can inline constant_64 into the lookup.
+// Note: only used for MUL, so gated by sel_op_mul
+pol commit constant_64;
+sel_op_mul * (64 - constant_64) = 0;
+
+#[RANGE_CHECK_MUL_U128_A_LO]
+sel_mul_u128 { a_lo, constant_64 } in range_check.sel { range_check.value, range_check.rng_chk_bits };
+
+#[RANGE_CHECK_MUL_U128_A_HI]
+sel_mul_u128 { a_hi, constant_64 } in range_check.sel { range_check.value, range_check.rng_chk_bits };
+
+#[RANGE_CHECK_MUL_U128_B_LO]
+sel_mul_u128 { b_lo, constant_64 } in range_check.sel { range_check.value, range_check.rng_chk_bits };
+
+#[RANGE_CHECK_MUL_U128_B_HI]
+sel_mul_u128 { b_hi, constant_64 } in range_check.sel { range_check.value, range_check.rng_chk_bits };
+
+// No need to range_check c_hi for cases other than u128 because we know a and b's size from the tags and have looked
+// up max_value. i.e. we cannot provide a malicious c, c_hi such that a + b - c_hi * 2^n = c passes for n < 128.
+// No need to range_check c_lo = ic because the memory write will ensure ic <= max_value.
+#[RANGE_CHECK_MUL_U128_C_HI]
+sel_mul_u128 { c_hi, constant_64 } in range_check.sel { range_check.value, range_check.rng_chk_bits };
+
 // EQ
 
 sel_op_eq * (1 - sel_op_eq) = 0;
@@ -280,21 +370,21 @@ sel_op_truncate = sel_trunc_non_trivial + sel_trunc_trivial;
 #[TRUNC_TRIVIAL_CASE]
 sel_trunc_trivial * (ia - ic) = 0;
 
-pol commit lo_128; // 128-bit low limb of ia.
-pol commit hi_128; // 128-bit high limb of ia.
+// NOTE: reusing a_lo and a_hi columns from MUL in TRUNC:
+// For truncate, a_lo = 128-bit low limb of ia and a_hi = 128-bit high limb of ia.
 pol commit mid;
 
 #[LARGE_TRUNC_CANONICAL_DEC]
-sel_trunc_gte_128 { ia, lo_128, hi_128 }
+sel_trunc_gte_128 { ia, a_lo, a_hi }
 in
 ff_gt.sel_dec { ff_gt.a, ff_gt.a_lo, ff_gt.a_hi };
 
 #[SMALL_TRUNC_VAL_IS_LO]
-sel_trunc_lt_128 * (lo_128 - ia) = 0;
+sel_trunc_lt_128 * (a_lo - ia) = 0;
 
-// lo_128 = ic + mid * 2^ia_tag_bits, where 2^ia_tag_bits is max_value + 1.
+// a_lo = ic + mid * 2^ia_tag_bits, where 2^ia_tag_bits is max_value + 1.
 #[TRUNC_LO_128_DECOMPOSITION]
-sel_trunc_non_trivial * (ic + mid * (max_value + 1) - lo_128 ) = 0;
+sel_trunc_non_trivial * (ic + mid * (max_value + 1) - a_lo ) = 0;
 
 // TODO: Once lookups support expression in tuple, we can inline mid_bits into the lookup.
 pol commit mid_bits;
@@ -305,4 +395,4 @@ mid_bits = sel_trunc_non_trivial * (128 - max_bits);
 // is supported by our range_check gadget.
 // No need to range_check ic because the memory write will ensure ic <= max_value.
 #[RANGE_CHECK_TRUNC_MID]
-sel_trunc_non_trivial {mid, mid_bits} in range_check.sel { range_check.value, range_check.rng_chk_bits };
+sel_trunc_non_trivial { mid, mid_bits } in range_check.sel { range_check.value, range_check.rng_chk_bits };