Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Poseidon2 half output #2514

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
Split fixed and executor working.
lvella committed Mar 4, 2025
commit 69fe10256b77f701b2cb837b7a385249737dc8e7
10 changes: 5 additions & 5 deletions plonky3/src/params/poseidon2/goldilocks/powdr_accel_impl.rs
Original file line number Diff line number Diff line change
@@ -3,8 +3,8 @@ use p3_field::AbstractField;
use p3_goldilocks::Goldilocks;
use p3_symmetric::CryptographicPermutation;
use powdr_riscv_runtime::{
goldilocks::{extract_opaque_vec8, Goldilocks as PowdrGoldilocks, OpaqueGoldilocks},
hash::{poseidon2_gl, poseidon2_gl_inplace},
goldilocks::{extract_opaque_vec, Goldilocks as PowdrGoldilocks, OpaqueGoldilocks},
hash::poseidon2_gl_inplace,
};

#[derive(Clone, Copy, Debug)]
@@ -21,10 +21,10 @@ impl p3_symmetric::Permutation<[Goldilocks; 8]> for Permutation {
// canonical representation internally, so it is safe to cast between their
// array's pointers.
let input = unsafe { &*(&input as *const _ as *const [PowdrGoldilocks; 8]) };
let input = input.map(|x| OpaqueGoldilocks::from(x));
let output = poseidon2_gl(&input);
let mut state = input.map(|x| OpaqueGoldilocks::from(x));
poseidon2_gl_inplace(&mut state);

extract_opaque_vec8(&output).map(|x| Goldilocks::from_canonical_u64(x))
extract_opaque_vec::<8>(&state).map(|x| Goldilocks::from_canonical_u64(x))
}

fn permute_mut(&self, data: &mut [Goldilocks; 8]) {
11 changes: 10 additions & 1 deletion riscv-executor/src/lib.rs
Original file line number Diff line number Diff line change
@@ -2438,7 +2438,16 @@ impl<F: FieldElement> Executor<'_, '_, F> {
.try_into()
.unwrap();

let output_half = self.proc.get_reg_mem(args[2].u()).u();

let result = poseidon2_gl::poseidon2_gl(&inputs);
let result = match output_half {
0 => &result[0..0],
1 => &result[0..4],
2 => &result[4..8],
3 => &result[0..8],
_ => unreachable!(),
};

let output_ptr = self.proc.get_reg_mem(args[1].u()).u();
assert!(is_multiple_of_4(output_ptr));
@@ -2668,7 +2677,7 @@ impl<F: FieldElement> Executor<'_, '_, F> {
let output_ptr = self.proc.get_reg_mem(args[1].u()).u();
assert!(is_multiple_of_4(output_ptr));

let result = (0..8)
let result = (0..4)
.flat_map(|i| {
let v = self.proc.get_mem(input_ptr + i * 4, 0, 0).into_fe();
let v = v.to_integer().try_into_u64().unwrap();
10 changes: 8 additions & 2 deletions riscv-runtime/Cargo.toml
Original file line number Diff line number Diff line change
@@ -8,8 +8,14 @@ homepage = "https://powdr.org"
repository = "https://github.com/powdr-labs/powdr"

[dependencies]
serde = { version = "1.0", default-features = false, features = ["alloc", "derive", "rc"] }
serde_cbor = { version = "0.11.2", default-features = false, features = ["alloc"] }
serde = { version = "1.0", default-features = false, features = [
"alloc",
"derive",
"rc",
] }
serde_cbor = { version = "0.11.2", default-features = false, features = [
"alloc",
] }
powdr-riscv-syscalls = { path = "../riscv-syscalls", version = "0.1.4" }
getrandom = { version = "0.2", features = ["custom"], optional = true }
spin = "0.9"
20 changes: 17 additions & 3 deletions riscv-runtime/src/goldilocks.rs
Original file line number Diff line number Diff line change
@@ -44,10 +44,24 @@ impl From<Goldilocks> for OpaqueGoldilocks {
}

/// Extract the Goldilocks values from the OpaqueGoldilocks values.
pub fn extract_opaque_vec8(vec: &[OpaqueGoldilocks; 8]) -> [u64; 8] {
///
/// The array size must be a multiple of 4.
pub fn extract_opaque_vec<const N: usize>(vec: &[OpaqueGoldilocks; N]) -> [u64; N] {
assert_eq!(N % 4, 0);
unsafe {
let mut output: MaybeUninit<[u64; 8]> = MaybeUninit::uninit();
ecall!(Syscall::SplitGLVec, in("a0") vec, in("a1") output.as_mut_ptr());
let mut output: MaybeUninit<[u64; N]> = MaybeUninit::uninit();

let input_ptr = vec.as_ptr();
let output_ptr = (*output.as_mut_ptr()).as_mut_ptr();

for i in 0..(N / 4) {
ecall!(
Syscall::SplitGLVec,
in("a0") input_ptr.add(i * 4),
in("a1") output_ptr.add(i * 4)
);
}

output.assume_init()
}
}
21 changes: 16 additions & 5 deletions riscv-runtime/src/hash.rs
Original file line number Diff line number Diff line change
@@ -27,15 +27,26 @@ pub fn poseidon_gl(data: &mut [Goldilocks; 12]) -> &[Goldilocks; 4] {
/// Perform one Poseidon2 permutation with 8 Goldilocks field elements in-place.
pub fn poseidon2_gl_inplace(data: &mut [OpaqueGoldilocks; 8]) {
unsafe {
ecall!(Syscall::Poseidon2GL, in("a0") data, in("a1") data);
ecall!(Syscall::Poseidon2GL, in("a0") data, in("a1") data, in("a2") 3);
}
}

/// Perform one Poseidon2 permutation with 8 Goldilocks field elements.
pub fn poseidon2_gl(data: &[OpaqueGoldilocks; 8]) -> [OpaqueGoldilocks; 8] {
#[repr(u32)]
pub enum Poseidon2OutputHalf {
//None = 0,
FirstHalf = 1,
SecondHalf = 2,
//Full = 3,
}
Comment on lines +35 to +40
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pub enum Poseidon2OutputHalf {
//None = 0,
FirstHalf = 1,
SecondHalf = 2,
//Full = 3,
}
pub enum Poseidon2OutputHalf {
FirstHalf = 1,
SecondHalf = 2,
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These comments act as a sort of documentation... do you really want them gone?


/// Perform one Poseidon2 compression with 8 Goldilocks field elements.
pub fn poseidon2_gl_compression(
data: &[OpaqueGoldilocks; 8],
output_half: Poseidon2OutputHalf,
) -> [OpaqueGoldilocks; 4] {
unsafe {
let mut output: MaybeUninit<[OpaqueGoldilocks; 8]> = MaybeUninit::uninit();
ecall!(Syscall::Poseidon2GL, in("a0") data, in("a1") output.as_mut_ptr());
let mut output: MaybeUninit<[OpaqueGoldilocks; 4]> = MaybeUninit::uninit();
ecall!(Syscall::Poseidon2GL, in("a0") data, in("a1") output.as_mut_ptr(), in("a2") output_half as u32);
output.assume_init()
}
}
2 changes: 1 addition & 1 deletion riscv/src/large_field/runtime.rs
Original file line number Diff line number Diff line change
@@ -291,7 +291,7 @@ impl Runtime {
);

self.add_submachine(
"std::machines::split::split_gl_vec::SplitGLVec8",
"std::machines::split::split_gl_vec::SplitGLVec4",
None,
"split_gl_vec",
vec!["memory", "split_gl", "MIN_DEGREE", "MAIN_MAX_DEGREE"],
49 changes: 22 additions & 27 deletions riscv/tests/riscv_data/poseidon2_gl_via_coprocessor/src/main.rs
Original file line number Diff line number Diff line change
@@ -2,46 +2,44 @@
#![no_std]

use powdr_riscv_runtime::{
goldilocks::{extract_opaque_vec8, Goldilocks, OpaqueGoldilocks, PRIME},
hash::{poseidon2_gl, poseidon2_gl_inplace},
goldilocks::{extract_opaque_vec, Goldilocks, OpaqueGoldilocks, PRIME},
hash::{poseidon2_gl_compression, poseidon2_gl_inplace, Poseidon2OutputHalf},
};

#[no_mangle]
fn main() {
let i = [OpaqueGoldilocks::from(0); 8];
let h = extract_opaque_vec8(&poseidon2_gl(&i));
let h = extract_opaque_vec::<4>(&poseidon2_gl_compression(
&i,
Poseidon2OutputHalf::FirstHalf,
));
assert_eq!(h[0], 14905565590733827480);
assert_eq!(h[1], 640905753703258831);
assert_eq!(h[2], 4579128623722792381);
assert_eq!(h[3], 158153743058056413);
assert_eq!(h[4], 5905145432652609062);
assert_eq!(h[5], 9814446752588696081);
assert_eq!(h[6], 13759450385053274731);
assert_eq!(h[7], 2402148582355896469);

let i = [OpaqueGoldilocks::from(1); 8];
let h = extract_opaque_vec8(&poseidon2_gl(&i));
assert_eq!(h[0], 18201552556563266798);
assert_eq!(h[1], 6814935789744812745);
assert_eq!(h[2], 5947349602629011250);
assert_eq!(h[3], 15482468195247053191);
assert_eq!(h[4], 2971437633000883992);
assert_eq!(h[5], 9752341516515962403);
assert_eq!(h[6], 15477293561177957600);
assert_eq!(h[7], 13574628582471329853);
let h = extract_opaque_vec::<4>(&poseidon2_gl_compression(
&i,
Poseidon2OutputHalf::SecondHalf,
));
assert_eq!(h[0], 2971437633000883992);
assert_eq!(h[1], 9752341516515962403);
assert_eq!(h[2], 15477293561177957600);
assert_eq!(h[3], 13574628582471329853);

let minus_one = PRIME - 1;
let i = [OpaqueGoldilocks::from(Goldilocks::new(minus_one)); 8];
let h = extract_opaque_vec8(&poseidon2_gl(&i));
let h = extract_opaque_vec::<4>(&poseidon2_gl_compression(
&i,
Poseidon2OutputHalf::FirstHalf,
));
assert_eq!(h[0], 13601391594672984423);
assert_eq!(h[1], 7799837486760213030);
assert_eq!(h[2], 4721195013230721931);
assert_eq!(h[3], 6190752424007146655);
assert_eq!(h[4], 5006958669091947377);
assert_eq!(h[5], 716937639216173272);
assert_eq!(h[6], 10656923966581845557);
assert_eq!(h[7], 6633446230068695780);

// Also test the inplace version
let mut i = [
923978,
235763497586,
@@ -53,7 +51,9 @@ fn main() {
2087,
]
.map(|x| OpaqueGoldilocks::from(Goldilocks::new(x)));
let h = extract_opaque_vec8(&poseidon2_gl(&i));
poseidon2_gl_inplace(&mut i);

let h = extract_opaque_vec::<8>(&i);
assert_eq!(h[0], 14498150941209346562);
assert_eq!(h[1], 8038616707062714447);
assert_eq!(h[2], 17242548914990530484);
@@ -62,9 +62,4 @@ fn main() {
assert_eq!(h[5], 12505236434419724338);
assert_eq!(h[6], 3134668969942435695);
assert_eq!(h[7], 1912726109528180442);

// Also test the inplace version
poseidon2_gl_inplace(&mut i);
let h_inplace = extract_opaque_vec8(&i);
assert_eq!(h, h_inplace);
}
33 changes: 6 additions & 27 deletions std/machines/split/split_gl_vec.asm
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@ use std::array;
use std::machines::large_field::memory::Memory;
use super::split_gl::SplitGL;

machine SplitGLVec8(mem: Memory, split_gl: SplitGL) with
machine SplitGLVec4(mem: Memory, split_gl: SplitGL) with
latch: latch,
call_selectors: sel,
{
@@ -12,8 +12,8 @@ machine SplitGLVec8(mem: Memory, split_gl: SplitGL) with
// Is this a used row?
let is_used = array::sum(sel);

// Reads 8 memory words from input_addr as field elements at time_step
// and writes 16 memory words to output_addr with u32s representing
// Reads 4 memory words from input_addr as field elements at time_step
// and writes 8 memory words to output_addr with u32s representing
// the decomposed field elements, in little-endian (i.e., the
// lower word address is the least significant limb), in time_step + 1.
//
@@ -24,31 +24,23 @@ machine SplitGLVec8(mem: Memory, split_gl: SplitGL) with
let output_addr;
let time_step;

let input: col[8];
let input: col[4];

// TODO: when link is available inside functions, we can turn this into array operations.
link if is_used ~> input[0] = mem.mload(input_addr + 0, time_step);
link if is_used ~> input[1] = mem.mload(input_addr + 4, time_step);
link if is_used ~> input[2] = mem.mload(input_addr + 8, time_step);
link if is_used ~> input[3] = mem.mload(input_addr + 12, time_step);
link if is_used ~> input[4] = mem.mload(input_addr + 16, time_step);
link if is_used ~> input[5] = mem.mload(input_addr + 20, time_step);
link if is_used ~> input[6] = mem.mload(input_addr + 24, time_step);
link if is_used ~> input[7] = mem.mload(input_addr + 28, time_step);

// Split the output into high and low limbs
let output_low: col[8];
let output_high: col[8];
let output_low: col[4];
let output_high: col[4];

// TODO: turn this into array operations
link if is_used ~> (output_low[0], output_high[0]) = split_gl.split(input[0]);
link if is_used ~> (output_low[1], output_high[1]) = split_gl.split(input[1]);
link if is_used ~> (output_low[2], output_high[2]) = split_gl.split(input[2]);
link if is_used ~> (output_low[3], output_high[3]) = split_gl.split(input[3]);
link if is_used ~> (output_low[4], output_high[4]) = split_gl.split(input[4]);
link if is_used ~> (output_low[5], output_high[5]) = split_gl.split(input[5]);
link if is_used ~> (output_low[6], output_high[6]) = split_gl.split(input[6]);
link if is_used ~> (output_low[7], output_high[7]) = split_gl.split(input[7]);

// TODO: turn this into array operations
link if is_used ~> mem.mstore(output_addr + 0, time_step + 1, output_low[0]);
@@ -62,17 +54,4 @@ machine SplitGLVec8(mem: Memory, split_gl: SplitGL) with

link if is_used ~> mem.mstore(output_addr + 24, time_step + 1, output_low[3]);
link if is_used ~> mem.mstore(output_addr + 28, time_step + 1, output_high[3]);

link if is_used ~> mem.mstore(output_addr + 32, time_step + 1, output_low[4]);
link if is_used ~> mem.mstore(output_addr + 36, time_step + 1, output_high[4]);

link if is_used ~> mem.mstore(output_addr + 40, time_step + 1, output_low[5]);
link if is_used ~> mem.mstore(output_addr + 44, time_step + 1, output_high[5]);

link if is_used ~> mem.mstore(output_addr + 48, time_step + 1, output_low[6]);
link if is_used ~> mem.mstore(output_addr + 52, time_step + 1, output_high[6]);

link if is_used ~> mem.mstore(output_addr + 56, time_step + 1, output_low[7]);
link if is_used ~> mem.mstore(output_addr + 60, time_step + 1, output_high[7]);

}
6 changes: 3 additions & 3 deletions test_data/std/poseidon2_gl_test.asm
Original file line number Diff line number Diff line change
@@ -16,9 +16,9 @@ machine Main with degree: main_degree {
reg ADDR1[<=];
reg ADDR2[<=];

// Increase the time step by 4 in each row, so that the poseidon machine
// Increase the time step by 2 in each row, so that the poseidon machine
// can read in the given time step and write in the next time step.
col fixed STEP(i) { 4 * i };
col fixed STEP(i) { 2 * i };
Byte2 byte2;
Memory memory(byte2, memory_degree, memory_degree);

@@ -27,7 +27,7 @@ machine Main with degree: main_degree {

Poseidon2GL poseidon2(memory, poseidon2_degree, poseidon2_degree);
instr poseidon2 ADDR1, ADDR2, X1 ->
link ~> poseidon2.permute(ADDR1, STEP, ADDR2, STEP + 2, X1);
link ~> poseidon2.permute(ADDR1, STEP, ADDR2, STEP + 1, X1);

col witness val;
instr assert_eq ADDR1, X1 ->
38 changes: 5 additions & 33 deletions test_data/std/split_gl_vec_test.asm
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::machines::split::ByteCompare;
use std::machines::split::split_gl::SplitGL;
use std::machines::split::split_gl_vec::SplitGLVec8;
use std::machines::split::split_gl_vec::SplitGLVec4;
use std::machines::large_field::memory::Memory;
use std::machines::range::Byte2;

@@ -20,7 +20,7 @@ machine Main with degree: main_degree {
ByteCompare byte_compare;
SplitGL split_machine(byte_compare, split_degree, split_degree);

SplitGLVec8 split_vec_machine(memory, split_machine, split_vec_degree, split_vec_degree);
SplitGLVec4 split_vec_machine(memory, split_machine, split_vec_degree, split_vec_degree);

col fixed STEP(i) { 2 * i };

@@ -39,20 +39,16 @@ machine Main with degree: main_degree {

function main {
// Store 8 field elements sequentially in memory
mstore 100, 0;
mstore 100, 1;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was there a reason for this change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there was, I forgot. It feels like it had something to do with the removal of the 0x0000000100000002 case below.

mstore 104, 0xffffffff00000000;
mstore 108, 0xfffffffeffffffff;
mstore 112, 0xabcdef0123456789;
mstore 116, 0x0000000100000002;
mstore 120, 0x0000000300000004;
mstore 124, 0x0000000500000006;
mstore 128, 0x0000000700000008;

// Split the previously stored field elements
split 100, 200;

// Assert the field elements are what was written
assert_eq 200, 0;
assert_eq 200, 1;
assert_eq 204, 0;

assert_eq 208, 0;
@@ -64,23 +60,11 @@ machine Main with degree: main_degree {
assert_eq 224, 0x23456789;
assert_eq 228, 0xabcdef01;

assert_eq 232, 0x00000002;
assert_eq 236, 0x00000001;

assert_eq 240, 0x00000004;
assert_eq 244, 0x00000003;

assert_eq 248, 0x00000006;
assert_eq 252, 0x00000005;

assert_eq 256, 0x00000008;
assert_eq 260, 0x00000007;

// Same split, but now overlaping the input and output
split 100, 104;

// Assert the field elements are what was written
assert_eq 104, 0;
assert_eq 104, 1;
assert_eq 108, 0;

assert_eq 112, 0;
@@ -92,18 +76,6 @@ machine Main with degree: main_degree {
assert_eq 128, 0x23456789;
assert_eq 132, 0xabcdef01;

assert_eq 136, 0x00000002;
assert_eq 140, 0x00000001;

assert_eq 144, 0x00000004;
assert_eq 148, 0x00000003;

assert_eq 152, 0x00000006;
assert_eq 156, 0x00000005;

assert_eq 160, 0x00000008;
assert_eq 164, 0x00000007;

return;
}
}