Skip to content

Commit a6fa349

Browse files
AztecBotAztecBotMirandaWoodfcarreiro
authored
feat: merge-train/avm (#15940)
See [merge-train-readme.md](https://github.com/AztecProtocol/aztec-packages/blob/next/.github/workflows/merge-train-readme.md). BEGIN_COMMIT_OVERRIDE feat(avm)!: ALU MUL (#15880) fix(avm): catch out of gas in execution loop (#15994) END_COMMIT_OVERRIDE --------- Co-authored-by: AztecBot <[email protected]> Co-authored-by: Miranda Wood <[email protected]> Co-authored-by: Facundo <[email protected]>
1 parent 052658f commit a6fa349

26 files changed

+1054
-152
lines changed

barretenberg/cpp/pil/vm2/alu.pil

Lines changed: 100 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ namespace alu;
99
pol commit sel;
1010

1111
pol commit sel_op_add;
12+
pol commit sel_op_mul;
1213
pol commit sel_op_eq;
1314
pol commit sel_op_lt;
1415
pol commit sel_op_lte;
@@ -37,24 +38,28 @@ pol commit cf;
3738
pol commit helper1;
3839

3940
// maximum bits the number can hold (i.e. 8 for a u8):
40-
// TODO(MW): Now unused since we redirect the LT/LTE range checks to GT gadget - remove?
4141
pol commit max_bits;
4242
// maximum value the number can hold (i.e. 255 for a u8), we 'mod' by max_value + 1
4343
pol commit max_value;
4444
// we need a selector to conditionally lookup ff_gt when inputs a, b are fields:
4545
pol commit sel_is_ff;
46+
// we need a selector to conditionally perform u128 multiplication:
47+
pol commit sel_is_u128;
4648

4749
pol IS_NOT_FF = 1 - sel_is_ff;
50+
pol IS_NOT_U128 = 1 - sel_is_u128;
4851

4952
sel * (1 - sel) = 0;
5053
cf * (1 - cf) = 0;
5154
sel_is_ff * (1 - sel_is_ff) = 0;
55+
sel_is_u128 * (1 - sel_is_u128) = 0;
5256

5357
// TODO: Consider to gate with (1 - sel_tag_err) for op_id. This might help us remove the (1 - sel_tag_err)
5458
// in various operation relations below.
5559
// Note that the op_ids below represent a binary decomposition (see constants_gen.pil):
5660
#[OP_ID_CHECK]
5761
op_id = sel_op_add * constants.AVM_EXEC_OP_ID_ALU_ADD
62+
+ sel_op_mul * constants.AVM_EXEC_OP_ID_ALU_MUL
5863
+ sel_op_eq * constants.AVM_EXEC_OP_ID_ALU_EQ
5964
+ sel_op_lt * constants.AVM_EXEC_OP_ID_ALU_LT
6065
+ sel_op_lte * constants.AVM_EXEC_OP_ID_ALU_LTE
@@ -78,18 +83,28 @@ execution.sel_execute_alu {
7883

7984
// IS_FF CHECKING
8085

81-
// TODO(MW): remove this and check for all (i.e. replace with sel), just being lazy for now. For add, we don't care, for lt we need to differentiate.
8286
pol CHECK_TAG_FF = sel_op_lt + sel_op_lte + sel_op_not;
8387
// We prove that sel_is_ff == 1 <==> ia_tag == MEM_TAG_FF
8488
pol TAG_FF_DIFF = ia_tag - constants.MEM_TAG_FF;
8589
pol commit tag_ff_diff_inv;
8690
#[TAG_IS_FF]
8791
CHECK_TAG_FF * (TAG_FF_DIFF * (sel_is_ff * (1 - tag_ff_diff_inv) + tag_ff_diff_inv) + sel_is_ff - 1) = 0;
8892

93+
// IS_U128 CHECKING
94+
95+
pol CHECK_TAG_U128 = sel_op_mul;
96+
// We prove that sel_is_u128 == 1 <==> ia_tag == MEM_TAG_U128
97+
pol TAG_U128_DIFF = ia_tag - constants.MEM_TAG_U128;
98+
pol commit tag_u128_diff_inv;
99+
#[TAG_IS_U128]
100+
CHECK_TAG_U128 * (TAG_U128_DIFF * (sel_is_u128 * (1 - tag_u128_diff_inv) + tag_u128_diff_inv) + sel_is_u128 - 1) = 0;
101+
102+
// Note: if we never need sel_is_ff and sel_is_u128 in the same op, can combine the above checks into one
103+
89104
// TAG CHECKING
90105

91106
// Will become e.g. sel_op_add * ia_tag + (comparison ops) * MEM_TAG_U1 + ....
92-
pol EXPECTED_C_TAG = (sel_op_add + sel_op_truncate) * ia_tag + (sel_op_eq + sel_op_lt + sel_op_lte) * constants.MEM_TAG_U1;
107+
pol EXPECTED_C_TAG = (sel_op_add + sel_op_truncate + sel_op_mul) * ia_tag + (sel_op_eq + sel_op_lt + sel_op_lte) * constants.MEM_TAG_U1;
93108

94109
// The tag of c is generated by the opcode and is never wrong.
95110
// Gating with (1 - sel_tag_err) is necessary because when an error occurs, we have to set the tag to 0,
@@ -123,6 +138,81 @@ sel_op_add * (1 - sel_op_add) = 0;
123138
#[ALU_ADD]
124139
sel_op_add * (1 - sel_tag_err) * (ia + ib - ic - cf * (max_value + 1)) = 0;
125140

141+
// MUL
142+
143+
sel_op_mul * (1 - sel_op_mul) = 0;
144+
145+
pol commit c_hi;
146+
147+
// MUL - non u128
148+
149+
#[ALU_MUL_NON_U128]
150+
sel_op_mul * IS_NOT_U128 * (1 - sel_tag_err)
151+
* (
152+
ia * ib
153+
- ic
154+
- (max_value + 1) * c_hi
155+
) = 0;
156+
157+
// MUL - u128
158+
159+
pol commit sel_mul_u128;
160+
// sel_op_mul & sel_is_u128:
161+
sel_mul_u128 - sel_is_u128 * sel_op_mul = 0;
162+
163+
// Taken from vm1:
164+
// We express a, b in 64-bit slices: a = a_l + a_h * 2^64
165+
// b = b_l + b_h * 2^64
166+
// => a * b = a_l * b_l + (a_h * b_l + a_l * b_h) * 2^64 + (a_h * b_h) * 2^128 = c_hi_full * 2^128 + c
167+
// => the 'top bits' are given by (c_hi_full - (a_h * b_h)) * 2^128
168+
// We can show for a 64 bit c_hi = c_hi_full - (a_h * b_h) % 2^64 that:
169+
// a_l * b_l + (a_h * b_l + a_l * b_h) * 2^64 = c_hi * 2^128 + c
170+
// Equivalently (cf = 0 if a_h & b_h = 0):
171+
// a * b_l + a_l * b_h * 2^64 = (cf * 2^64 + c_hi) * 2^128 + c
172+
// => no need for a_h in final relation
173+
174+
pol commit a_lo;
175+
pol commit a_hi;
176+
pol commit b_lo;
177+
pol commit b_hi;
178+
pol TWO_POW_64 = 2 ** 64;
179+
180+
#[A_MUL_DECOMPOSITION]
181+
sel_mul_u128 * (ia - (a_lo + TWO_POW_64 * a_hi)) = 0;
182+
#[B_MUL_DECOMPOSITION]
183+
sel_mul_u128 * (ib - (b_lo + TWO_POW_64 * b_hi)) = 0;
184+
185+
#[ALU_MUL_U128]
186+
sel_mul_u128 * (1 - sel_tag_err)
187+
* (
188+
ia * b_lo + a_lo * b_hi * TWO_POW_64 // a * b without the hi bits
189+
- ic // c_lo
190+
- (max_value + 1) * (cf * TWO_POW_64 + c_hi) // c_hi * 2^128 + (cf ? 2^192 : 0)
191+
) = 0;
192+
193+
// TODO: Once lookups support expression in tuple, we can inline constant_64 into the lookup.
194+
// Note: only used for MUL, so gated by sel_op_mul
195+
pol commit constant_64;
196+
sel_op_mul * (64 - constant_64) = 0;
197+
198+
#[RANGE_CHECK_MUL_U128_A_LO]
199+
sel_mul_u128 { a_lo, constant_64 } in range_check.sel { range_check.value, range_check.rng_chk_bits };
200+
201+
#[RANGE_CHECK_MUL_U128_A_HI]
202+
sel_mul_u128 { a_hi, constant_64 } in range_check.sel { range_check.value, range_check.rng_chk_bits };
203+
204+
#[RANGE_CHECK_MUL_U128_B_LO]
205+
sel_mul_u128 { b_lo, constant_64 } in range_check.sel { range_check.value, range_check.rng_chk_bits };
206+
207+
#[RANGE_CHECK_MUL_U128_B_HI]
208+
sel_mul_u128 { b_hi, constant_64 } in range_check.sel { range_check.value, range_check.rng_chk_bits };
209+
210+
// No need to range_check c_hi for cases other than u128 because we know a and b's size from the tags and have looked
211+
// up max_value. i.e. we cannot provide a malicious c, c_hi such that a + b - c_hi * 2^n = c passes for n < 128.
212+
// No need to range_check c_lo = ic because the memory write will ensure ic <= max_value.
213+
#[RANGE_CHECK_MUL_U128_C_HI]
214+
sel_mul_u128 { c_hi, constant_64 } in range_check.sel { range_check.value, range_check.rng_chk_bits };
215+
126216
// EQ
127217

128218
sel_op_eq * (1 - sel_op_eq) = 0;
@@ -280,21 +370,21 @@ sel_op_truncate = sel_trunc_non_trivial + sel_trunc_trivial;
280370
#[TRUNC_TRIVIAL_CASE]
281371
sel_trunc_trivial * (ia - ic) = 0;
282372

283-
pol commit lo_128; // 128-bit low limb of ia.
284-
pol commit hi_128; // 128-bit high limb of ia.
373+
// NOTE: reusing a_lo and a_hi columns from MUL in TRUNC:
374+
// For truncate, a_lo = 128-bit low limb of ia and a_hi = 128-bit high limb of ia.
285375
pol commit mid;
286376

287377
#[LARGE_TRUNC_CANONICAL_DEC]
288-
sel_trunc_gte_128 { ia, lo_128, hi_128 }
378+
sel_trunc_gte_128 { ia, a_lo, a_hi }
289379
in
290380
ff_gt.sel_dec { ff_gt.a, ff_gt.a_lo, ff_gt.a_hi };
291381

292382
#[SMALL_TRUNC_VAL_IS_LO]
293-
sel_trunc_lt_128 * (lo_128 - ia) = 0;
383+
sel_trunc_lt_128 * (a_lo - ia) = 0;
294384

295-
// lo_128 = ic + mid * 2^ia_tag_bits, where 2^ia_tag_bits is max_value + 1.
385+
// a_lo = ic + mid * 2^ia_tag_bits, where 2^ia_tag_bits is max_value + 1.
296386
#[TRUNC_LO_128_DECOMPOSITION]
297-
sel_trunc_non_trivial * (ic + mid * (max_value + 1) - lo_128) = 0;
387+
sel_trunc_non_trivial * (ic + mid * (max_value + 1) - a_lo) = 0;
298388

299389
// TODO: Once lookups support expression in tuple, we can inline mid_bits into the lookup.
300390
pol commit mid_bits;
@@ -305,4 +395,4 @@ mid_bits = sel_trunc_non_trivial * (128 - max_bits);
305395
// is supported by our range_check gadget.
306396
// No need to range_check ic because the memory write will ensure ic <= max_value.
307397
#[RANGE_CHECK_TRUNC_MID]
308-
sel_trunc_non_trivial {mid, mid_bits} in range_check.sel { range_check.value, range_check.rng_chk_bits };
398+
sel_trunc_non_trivial { mid, mid_bits } in range_check.sel { range_check.value, range_check.rng_chk_bits };

barretenberg/cpp/src/barretenberg/vm2/common/instruction_spec.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,13 @@ const std::unordered_map<ExecutionOpCode, ExecInstructionSpec> EXEC_INSTRUCTION_
431431
.add_inputs({ /*a*/ RegisterInfo::ANY_TAG,
432432
/*b*/ RegisterInfo::ANY_TAG })
433433
.add_output(/*c*/) } },
434+
{ ExecutionOpCode::MUL,
435+
{ .num_addresses = 3,
436+
.gas_cost = { .opcode_gas = AVM_MUL_BASE_L2_GAS, .base_da = 0, .dyn_l2 = 0, .dyn_da = 0 },
437+
.register_info = RegisterInfo()
438+
.add_inputs({ /*a*/ RegisterInfo::ANY_TAG,
439+
/*b*/ RegisterInfo::ANY_TAG })
440+
.add_output(/*c*/) } },
434441
{ ExecutionOpCode::EQ,
435442
{ .num_addresses = 3,
436443
.gas_cost = { .opcode_gas = AVM_EQ_BASE_L2_GAS, .base_da = 0, .dyn_l2 = 0, .dyn_da = 0 },

0 commit comments

Comments
 (0)