Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 33 additions & 29 deletions prover/src/extensions/ram_init_final.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,20 @@ impl FrameworkEvalExt for RamInitFinalEval {
}

impl RamInitFinalEval {
/// Converts WORD_SIZE bytes into a pair of 16-bit values (low, high)
/// where low = bytes[0] + bytes[1] * 256
/// and high = bytes[WORD_SIZE_HALVED] + bytes[WORD_SIZE_HALVED + 1] * 256
fn bytes_to_word_halves<E: EvalAtRow>(bytes: &[E::F]) -> (E::F, E::F) {
assert_eq!(bytes.len(), WORD_SIZE, "Expected {} bytes", WORD_SIZE);

let low = bytes[0].clone()
+ bytes[1].clone() * E::F::from((1 << 8).into());
let high = bytes[WORD_SIZE_HALVED].clone()
+ bytes[WORD_SIZE_HALVED + 1].clone() * E::F::from((1 << 8).into());

(low, high)
}

fn constrain_add_initial_values<E: EvalAtRow>(
&self,
eval: &mut E,
Expand All @@ -161,12 +175,9 @@ impl RamInitFinalEval {
preprocessed_init_value: E::F,
ram_init_final_flag: E::F,
) {
// Build tuple: [addr_low, addr_high, init_value, counter_zeros...]
let mut tuple = vec![];
// Build the tuple from the RAM address bytes.
let addr_low = ram_init_final_addr[0].clone()
+ ram_init_final_addr[1].clone() * E::F::from((1 << 8).into());
let addr_high = ram_init_final_addr[2].clone()
+ ram_init_final_addr[3].clone() * E::F::from((1 << 8).into());
let (addr_low, addr_high) = Self::bytes_to_word_halves(ram_init_final_addr);
tuple.push(addr_low);
tuple.push(addr_high);
// Add the product of preprocessed init flag and value.
Expand All @@ -192,19 +203,13 @@ impl RamInitFinalEval {
ram_init_final_flag: E::F,
) {
let mut tuple = vec![];
let addr_low = ram_init_final_addr[0].clone()
+ ram_init_final_addr[1].clone() * E::F::from((1 << 8).into());
let addr_high = ram_init_final_addr[2].clone()
+ ram_init_final_addr[3].clone() * E::F::from((1 << 8).into());
let (addr_low, addr_high) = Self::bytes_to_word_halves(ram_init_final_addr);
tuple.push(addr_low);
tuple.push(addr_high);

tuple.push(ram_final_value);

let counter_low = ram_final_counter[0].clone()
+ ram_final_counter[1].clone() * E::F::from((1 << 8).into());
let counter_high = ram_final_counter[2].clone()
+ ram_final_counter[3].clone() * E::F::from((1 << 8).into());
let (counter_low, counter_high) = Self::bytes_to_word_halves(ram_final_counter);
tuple.push(counter_low);
tuple.push(counter_high);

Expand Down Expand Up @@ -344,6 +349,18 @@ impl BuiltInExtension for RamInitFinal {
}

impl RamInitFinal {
/// Packed version for SimdBackend: converts WORD_SIZE bytes from columns at given row
/// into a pair of 16-bit values (low, high)
fn bytes_to_word_halves_packed(byte_cols: &[BaseColumn], vec_row: usize) -> (PackedBaseField, PackedBaseField) {
assert_eq!(byte_cols.len(), WORD_SIZE, "Expected {} byte columns", WORD_SIZE);

let shift = PackedBaseField::broadcast((1 << 8).into());
let low = byte_cols[0].data[vec_row] + byte_cols[1].data[vec_row] * shift;
let high = byte_cols[WORD_SIZE_HALVED].data[vec_row]
+ byte_cols[WORD_SIZE_HALVED + 1].data[vec_row] * shift;

(low, high)
}
fn preprocessed_columns(log_size: u32, program_trace_ref: ProgramTraceRef) -> Vec<BaseColumn> {
let total_len = program_trace_ref.init_memory.len()
+ program_trace_ref.exit_code.len()
Expand Down Expand Up @@ -542,12 +559,7 @@ impl RamInitFinal {
// Add (address, value, 0)
for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
let mut tuple = vec![];
let addr_low = ram_init_final_addr[0].data[vec_row]
+ ram_init_final_addr[1].data[vec_row]
* PackedBaseField::broadcast((1 << 8).into());
let addr_high = ram_init_final_addr[2].data[vec_row]
+ ram_init_final_addr[3].data[vec_row]
* PackedBaseField::broadcast((1 << 8).into());
let (addr_low, addr_high) = Self::bytes_to_word_halves_packed(ram_init_final_addr, vec_row);
tuple.push(addr_low);
tuple.push(addr_high);

Expand Down Expand Up @@ -586,21 +598,13 @@ impl RamInitFinal {
for vec_row in 0..(1 << (log_size - LOG_N_LANES)) {
let mut tuple = vec![];

let addr_low = ram_init_final_addr[0].data[vec_row]
+ ram_init_final_addr[1].data[vec_row]
* PackedBaseField::broadcast((1 << 8).into());
let addr_high = ram_init_final_addr[2].data[vec_row]
+ ram_init_final_addr[3].data[vec_row]
* PackedBaseField::broadcast((1 << 8).into());
let (addr_low, addr_high) = Self::bytes_to_word_halves_packed(ram_init_final_addr, vec_row);
tuple.push(addr_low);
tuple.push(addr_high);

tuple.push(ram_final_value.data[vec_row]);

let counter_low = ram_final_counter[0].data[vec_row]
+ ram_final_counter[1].data[vec_row] * PackedBaseField::broadcast((1 << 8).into());
let counter_high = ram_final_counter[2].data[vec_row]
+ ram_final_counter[3].data[vec_row] * PackedBaseField::broadcast((1 << 8).into());
let (addr_low, addr_high) = Self::bytes_to_word_halves_packed(ram_init_final_addr, vec_row);
tuple.push(counter_low);
tuple.push(counter_high);

Expand Down