|
| 1 | +use num_traits::One; |
| 2 | +use stwo_prover::constraint_framework::{logup::LookupElements, EvalAtRow}; |
| 3 | + |
| 4 | +use nexus_vm::{riscv::BuiltinOpcode, WORD_SIZE}; |
| 5 | + |
| 6 | +use crate::{ |
| 7 | + column::Column::{self, *}, |
| 8 | + components::MAX_LOOKUP_TUPLE_SIZE, |
| 9 | + trace::{ |
| 10 | + eval::{trace_eval, TraceEval}, |
| 11 | + sidenote::SideNote, |
| 12 | + BoolWord, ProgramStep, Traces, Word, |
| 13 | + }, |
| 14 | + traits::{ExecuteChip, MachineChip}, |
| 15 | +}; |
| 16 | + |
| 17 | +use super::add::{self}; |
| 18 | + |
| 19 | +pub struct ExecutionResult { |
| 20 | + pub diff_bytes: Word, |
| 21 | + pub borrow_bits: BoolWord, |
| 22 | + pub pc_next: Word, |
| 23 | + pub carry_bits: BoolWord, |
| 24 | +} |
| 25 | + |
| 26 | +pub struct BgeuChip; |
| 27 | + |
| 28 | +impl ExecuteChip for BgeuChip { |
| 29 | + type ExecutionResult = ExecutionResult; |
| 30 | + |
| 31 | + fn execute(program_step: &ProgramStep) -> Self::ExecutionResult { |
| 32 | + let value_a = program_step.get_value_a(); |
| 33 | + let value_b = program_step.get_value_b(); |
| 34 | + let imm = program_step.get_value_c().0; |
| 35 | + let pc = program_step.step.pc.to_le_bytes(); |
| 36 | + |
| 37 | + let (diff_bytes, borrow_bits) = super::sub::subtract_with_borrow(value_a, value_b); |
| 38 | + |
| 39 | + // ltu_flag is equal to borrow_bit[3] |
| 40 | + let (pc_next, carry_bits) = if borrow_bits[3] { |
| 41 | + // a < b is true: pc_next = pc + 4 |
| 42 | + add::add_with_carries(pc, 4u32.to_le_bytes()) |
| 43 | + } else { |
| 44 | + // a >= b is true: pc_next = pc + imm |
| 45 | + add::add_with_carries(pc, imm) |
| 46 | + }; |
| 47 | + |
| 48 | + ExecutionResult { |
| 49 | + diff_bytes, |
| 50 | + borrow_bits, |
| 51 | + pc_next, |
| 52 | + carry_bits, |
| 53 | + } |
| 54 | + } |
| 55 | +} |
| 56 | + |
| 57 | +impl MachineChip for BgeuChip { |
| 58 | + fn fill_main_trace( |
| 59 | + traces: &mut Traces, |
| 60 | + row_idx: usize, |
| 61 | + vm_step: &ProgramStep, |
| 62 | + _side_note: &mut SideNote, |
| 63 | + ) { |
| 64 | + if !matches!( |
| 65 | + vm_step.step.instruction.opcode.builtin(), |
| 66 | + Some(BuiltinOpcode::BGEU) |
| 67 | + ) { |
| 68 | + return; |
| 69 | + } |
| 70 | + |
| 71 | + let ExecutionResult { |
| 72 | + diff_bytes, |
| 73 | + borrow_bits, |
| 74 | + pc_next, |
| 75 | + carry_bits, |
| 76 | + } = Self::execute(vm_step); |
| 77 | + |
| 78 | + traces.fill_columns(row_idx, diff_bytes, Column::Helper1); |
| 79 | + traces.fill_columns(row_idx, borrow_bits, Column::BorrowFlag); |
| 80 | + |
| 81 | + // Fill valueA |
| 82 | + traces.fill_columns(row_idx, vm_step.get_value_a(), Column::ValueA); |
| 83 | + |
| 84 | + // Fill PcNext and CarryFlag, since Pc and Immediate are filled to the main trace in CPU. |
| 85 | + traces.fill_columns(row_idx, pc_next, Column::PcNext); |
| 86 | + traces.fill_columns(row_idx, carry_bits, Column::CarryFlag); |
| 87 | + } |
| 88 | + |
| 89 | + fn add_constraints<E: EvalAtRow>( |
| 90 | + eval: &mut E, |
| 91 | + trace_eval: &TraceEval<E>, |
| 92 | + _lookup_elements: &LookupElements<MAX_LOOKUP_TUPLE_SIZE>, |
| 93 | + ) { |
| 94 | + let modulus = E::F::from(256u32.into()); |
| 95 | + let (value_a, _) = trace_eval!(trace_eval, ValueA); |
| 96 | + let (value_b, _) = trace_eval!(trace_eval, ValueB); |
| 97 | + let (value_c, _) = trace_eval!(trace_eval, ValueC); |
| 98 | + let (pc, _) = trace_eval!(trace_eval, Column::Pc); |
| 99 | + let (carry_bits, _) = trace_eval!(trace_eval, Column::CarryFlag); |
| 100 | + let (borrow_bits, _) = trace_eval!(trace_eval, Column::BorrowFlag); |
| 101 | + let (diff_bytes, _) = trace_eval!(trace_eval, Column::Helper1); |
| 102 | + let (pc_next, _) = trace_eval!(trace_eval, Column::PcNext); |
| 103 | + let ([is_bgeu], _) = trace_eval!(trace_eval, Column::IsBgeu); |
| 104 | + let ltu_flag = borrow_bits[3].clone(); |
| 105 | + |
| 106 | + // is_bgeu・(a_val_1 - b_val_1 - h1_1 + borrow_1・2^8) = 0 |
| 107 | + // is_bgeu・(a_val_2 - b_val_2 - h1_2 + borrow_2・2^8 - borrow_1) = 0 |
| 108 | + // is_bgeu・(a_val_3 - b_val_3 - h1_3 + borrow_3・2^8 - borrow_2) = 0 |
| 109 | + // is_bgeu・(a_val_4 - b_val_4 - h1_4 + ltu_flag・2^8 - borrow_3) = 0 |
| 110 | + eval.add_constraint( |
| 111 | + is_bgeu.clone() |
| 112 | + * (value_a[0].clone() - value_b[0].clone() - diff_bytes[0].clone() |
| 113 | + + borrow_bits[0].clone() * modulus.clone()), |
| 114 | + ); |
| 115 | + for i in 1..WORD_SIZE { |
| 116 | + eval.add_constraint( |
| 117 | + is_bgeu.clone() |
| 118 | + * (value_a[i].clone() - value_b[i].clone() - diff_bytes[i].clone() |
| 119 | + + borrow_bits[i].clone() * modulus.clone() |
| 120 | + - borrow_bits[i - 1].clone()), |
| 121 | + ); |
| 122 | + } |
| 123 | + |
| 124 | + // is_bgeu・( (1 - ltu_flag)・c_val_1 + ltu_flag・4 + pc_1 - carry_1·2^8 - pc_next_1) =0 |
| 125 | + // is_bgeu・( (1 - ltu_flag)・c_val_2 + pc_2 + carry_1 - carry_2·2^8 - pc_next_2) = 0 |
| 126 | + // is_bgeu・( (1 - ltu_flag)・c_val_3 + pc_3 + carry_2 - carry_3·2^8 - pc_next_3) = 0 |
| 127 | + // is_bgeu・( (1 - ltu_flag)・c_val_4 + pc_4 + carry_3 - carry_4·2^8 - pc_next_4) = 0 |
| 128 | + eval.add_constraint( |
| 129 | + is_bgeu.clone() |
| 130 | + * ((E::F::one() - ltu_flag.clone()) * value_c[0].clone() |
| 131 | + + ltu_flag.clone() * E::F::from(4u32.into()) |
| 132 | + + pc[0].clone() |
| 133 | + - carry_bits[0].clone() * modulus.clone() |
| 134 | + - pc_next[0].clone()), |
| 135 | + ); |
| 136 | + for i in 1..WORD_SIZE { |
| 137 | + eval.add_constraint( |
| 138 | + is_bgeu.clone() |
| 139 | + * ((E::F::one() - ltu_flag.clone()) * value_c[i].clone() |
| 140 | + + pc[i].clone() |
| 141 | + + carry_bits[i - 1].clone() |
| 142 | + - carry_bits[i].clone() * modulus.clone() |
| 143 | + - pc_next[i].clone()), |
| 144 | + ); |
| 145 | + } |
| 146 | + } |
| 147 | +} |
| 148 | + |
| 149 | +#[cfg(test)] |
| 150 | +mod test { |
| 151 | + use crate::{ |
| 152 | + chips::{AddChip, CpuChip, RegisterMemCheckChip, SubChip}, |
| 153 | + test_utils::assert_chip, |
| 154 | + trace::{program::iter_program_steps, PreprocessedTraces}, |
| 155 | + }; |
| 156 | + |
| 157 | + use super::*; |
| 158 | + use nexus_vm::{ |
| 159 | + riscv::{BasicBlock, BuiltinOpcode, Instruction, InstructionType, Opcode}, |
| 160 | + trace::k_trace_direct, |
| 161 | + }; |
| 162 | + |
| 163 | + const LOG_SIZE: u32 = PreprocessedTraces::MIN_LOG_SIZE; |
| 164 | + |
| 165 | + #[rustfmt::skip] |
| 166 | + fn setup_basic_block_ir() -> Vec<BasicBlock> { |
| 167 | + let basic_block = BasicBlock::new(vec![ |
| 168 | + // Set x10 = 1 |
| 169 | + Instruction::new(Opcode::from(BuiltinOpcode::ADDI), 10, 0, 1, InstructionType::IType), |
| 170 | + // Set x1 = 10 |
| 171 | + Instruction::new(Opcode::from(BuiltinOpcode::ADDI), 1, 0, 10, InstructionType::IType), |
| 172 | + // Set x2 = 20 |
| 173 | + Instruction::new(Opcode::from(BuiltinOpcode::ADDI), 2, 0, 20, InstructionType::IType), |
| 174 | + // Set x3 = 10 (same as x1) |
| 175 | + Instruction::new(Opcode::from(BuiltinOpcode::ADDI), 3, 0, 10, InstructionType::IType), |
| 176 | + // Set x4 = -10 |
| 177 | + Instruction::new(Opcode::from(BuiltinOpcode::SUB), 4, 0, 1, InstructionType::RType), |
| 178 | + // Set x5 = 0xFFFFFFFF (max unsigned value) |
| 179 | + Instruction::new(Opcode::from(BuiltinOpcode::SUB), 5, 0, 10, InstructionType::RType), |
| 180 | + |
| 181 | + // Case 1: BGEU with equal values (should branch) |
| 182 | + // BGEU x1, x3, 0xff (should branch as x1 >= x3 is true) |
| 183 | + Instruction::new(Opcode::from(BuiltinOpcode::BGEU), 1, 3, 12, InstructionType::BType), |
| 184 | + |
| 185 | + // Unimpl instructions to fill the gap (trigger error when executed) |
| 186 | + Instruction::unimpl(), |
| 187 | + Instruction::unimpl(), |
| 188 | + |
| 189 | + // Case 2: BGEU with different values (should not branch) |
| 190 | + // BGEU x1, x2, 12 (should not branch as x1 >= x2 is false) |
| 191 | + Instruction::new(Opcode::from(BuiltinOpcode::BGEU), 1, 2, 0xff, InstructionType::BType), |
| 192 | + |
| 193 | + |
| 194 | + // Case 3: BGEU with zero and non-zero (should not branch) |
| 195 | + // BGEU x0, x1, 8 (should not branch as x0 >= x1 is false) |
| 196 | + Instruction::new(Opcode::from(BuiltinOpcode::BGEU), 0, 1, 0xff, InstructionType::BType), |
| 197 | + |
| 198 | + // Case 4: BGEU with zero and zero (should branch) |
| 199 | + // BGEU x0, x0, 0xff (should branch as x0 >= x0 is true) |
| 200 | + Instruction::new(Opcode::from(BuiltinOpcode::BGEU), 0, 0, 12, InstructionType::BType), |
| 201 | + |
| 202 | + // Unimpl instructions to fill the gap (trigger error when executed) |
| 203 | + Instruction::unimpl(), |
| 204 | + Instruction::unimpl(), |
| 205 | + |
| 206 | + // Case 5: BGEU with negative and positive values (should branch) |
| 207 | + // BGEU x4, x1, 0xff (should branch as 0xfffffff6 >= 10 unsigned) |
| 208 | + Instruction::new(Opcode::from(BuiltinOpcode::BGEU), 4, 1, 12, InstructionType::BType), |
| 209 | + |
| 210 | + // Unimpl instructions to fill the gap (trigger error when executed) |
| 211 | + Instruction::unimpl(), |
| 212 | + Instruction::unimpl(), |
| 213 | + |
| 214 | + // Case 6: BGEU with max unsigned value and zero (should branch) |
| 215 | + // BGEU x5, x0, 0xff (should branch as 0xFFFFFFFF >= 0) |
| 216 | + Instruction::new(Opcode::from(BuiltinOpcode::BGEU), 5, 0, 12, InstructionType::BType), |
| 217 | + |
| 218 | + // Unimpl instructions to fill the gap (trigger error when executed) |
| 219 | + Instruction::unimpl(), |
| 220 | + Instruction::unimpl(), |
| 221 | + |
| 222 | + // Case 7: BGEU with zero and max unsigned value (should not branch) |
| 223 | + // BGEU x0, x5, 12 (should not branch as 0 >= 0xFFFFFFFF is false) |
| 224 | + Instruction::new(Opcode::from(BuiltinOpcode::BGEU), 0, 5, 0xff, InstructionType::BType), |
| 225 | + ]); |
| 226 | + vec![basic_block] |
| 227 | + } |
| 228 | + |
| 229 | + #[test] |
| 230 | + fn test_k_trace_constrained_bgeu_instructions() { |
| 231 | + type Chips = (CpuChip, AddChip, SubChip, BgeuChip, RegisterMemCheckChip); |
| 232 | + let basic_block = setup_basic_block_ir(); |
| 233 | + let k = 1; |
| 234 | + |
| 235 | + // Get traces from VM K-Trace interface |
| 236 | + let vm_traces = k_trace_direct(&basic_block, k).expect("Failed to create trace"); |
| 237 | + |
| 238 | + // Trace circuit |
| 239 | + let mut traces = Traces::new(LOG_SIZE); |
| 240 | + let mut side_note = SideNote::default(); |
| 241 | + let program_steps = iter_program_steps(&vm_traces, traces.num_rows()); |
| 242 | + |
| 243 | + // We iterate each block in the trace for each instruction |
| 244 | + for (row_idx, program_step) in program_steps.enumerate() { |
| 245 | + Chips::fill_main_trace(&mut traces, row_idx, &program_step, &mut side_note); |
| 246 | + } |
| 247 | + |
| 248 | + let mut preprocessed_column = PreprocessedTraces::empty(LOG_SIZE); |
| 249 | + preprocessed_column.fill_is_first(); |
| 250 | + preprocessed_column.fill_is_first32(); |
| 251 | + preprocessed_column.fill_row_idx(); |
| 252 | + preprocessed_column.fill_timestamps(); |
| 253 | + assert_chip::<Chips>(traces, Some(preprocessed_column)); |
| 254 | + } |
| 255 | +} |
0 commit comments