Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 113 additions & 11 deletions coprocessor/fhevm-engine/zkproof-worker/src/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use fhevm_engine_common::tfhe_ops::current_ciphertext_version;
use serial_test::serial;
use test_harness::db_utils::ACL_CONTRACT_ADDR;

Expand Down Expand Up @@ -58,6 +59,15 @@ async fn test_verify_empty_input_list() {
assert!(utils::is_valid(&pool, request_id, max_retries)
.await
.unwrap());

let handles = utils::wait_for_handles(&pool, request_id, max_retries)
.await
.unwrap();
assert!(handles.is_empty());
assert!(utils::fetch_stored_ciphertexts(&pool, &handles)
.await
.unwrap()
.is_empty());
}

#[tokio::test]
Expand Down Expand Up @@ -89,18 +99,110 @@ async fn test_max_input_index() {

// Test with highest number of inputs - 255
let inputs = vec![utils::ZkInput::U64(2); MAX_INPUT_INDEX as usize + 1];
assert!(utils::is_valid(
let request_id = utils::insert_proof(
&pool,
utils::insert_proof(
&pool,
102,
&utils::generate_zk_pok_with_inputs(&pool, &aux.1, &inputs).await,
&aux.0
)
.await
.expect("valid db insert"),
5000
102,
&utils::generate_zk_pok_with_inputs(&pool, &aux.1, &inputs).await,
&aux.0,
)
.await
.expect("non-expired db query"));
.expect("valid db insert");
assert!(utils::is_valid(&pool, request_id, 5000)
.await
.expect("non-expired db query"));

let handles = utils::wait_for_handles(&pool, request_id, 5000)
.await
.expect("wait for handles");
assert_eq!(handles.len(), MAX_INPUT_INDEX as usize + 1);
assert_eq!(handles.first().expect("first handle")[21], 0);
assert_eq!(handles.last().expect("last handle")[21], MAX_INPUT_INDEX);
assert_eq!(
&handles.last().expect("last handle")[22..30],
&aux.0.chain_id.as_u64().to_be_bytes()
);
assert_eq!(
handles.last().expect("last handle")[31],
current_ciphertext_version() as u8
);
}

#[tokio::test]
#[serial(db)]
async fn test_verify_proof_rerandomises_ciphertexts_before_storage() {
let (pool_mngr, _instance) = utils::setup().await.expect("valid setup");
let pool = pool_mngr.pool();

let aux: (crate::auxiliary::ZkData, [u8; 92]) =
utils::aux_fixture(ACL_CONTRACT_ADDR.to_owned());
let inputs = vec![
utils::ZkInput::Bool(true),
utils::ZkInput::U8(42),
utils::ZkInput::U16(12345),
utils::ZkInput::U32(67890),
utils::ZkInput::U64(1234567890),
];
let zk_pok = utils::generate_zk_pok_with_inputs(&pool, &aux.1, &inputs).await;
let request_id = utils::insert_proof(&pool, 103, &zk_pok, &aux.0)
.await
.unwrap();

assert!(utils::is_valid(&pool, request_id, 1000).await.unwrap());

let handles = utils::wait_for_handles(&pool, request_id, 1000)
.await
.unwrap();
assert_eq!(handles.len(), inputs.len());
for (idx, handle) in handles.iter().enumerate() {
assert_eq!(handle.len(), 32);
assert_eq!(handle[21], idx as u8);
assert_eq!(&handle[22..30], &aux.0.chain_id.as_u64().to_be_bytes());
assert_eq!(handle[31], current_ciphertext_version() as u8);
}

let stored = utils::fetch_stored_ciphertexts(&pool, &handles)
.await
.unwrap();
assert_eq!(stored.len(), inputs.len());
assert_eq!(
stored
.iter()
.map(|ct| ct.input_blob_index)
.collect::<Vec<_>>(),
(0..inputs.len() as i32).collect::<Vec<_>>()
);
assert_eq!(
stored
.iter()
.map(|ct| ct.handle.as_slice())
.collect::<Vec<_>>(),
handles
.iter()
.map(|handle| handle.as_slice())
.collect::<Vec<_>>()
);

let baseline = utils::compress_inputs_without_rerandomization(&pool, &zk_pok)
.await
.unwrap();
assert_eq!(baseline.len(), stored.len());
assert!(
stored
.iter()
.zip(&baseline)
.all(|(stored_ct, baseline_ct)| stored_ct.ciphertext != *baseline_ct),
"stored ciphertexts should differ from the pre-rerandomization compression"
);

let decrypted = utils::decrypt_ciphertexts(&pool, &handles).await.unwrap();
assert_eq!(
decrypted
.iter()
.map(|result| result.value.clone())
.collect::<Vec<_>>(),
inputs
.iter()
.map(|input| input.cleartext())
.collect::<Vec<_>>()
);
}
159 changes: 158 additions & 1 deletion coprocessor/fhevm-engine/zkproof-worker/src/tests/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@ use fhevm_engine_common::chain_id::ChainId;
use fhevm_engine_common::crs::CrsCache;
use fhevm_engine_common::db_keys::DbKeyCache;
use fhevm_engine_common::pg_pool::PostgresPoolManager;
use fhevm_engine_common::utils::safe_serialize;
use fhevm_engine_common::tfhe_ops::{current_ciphertext_version, extract_ct_list};
use fhevm_engine_common::types::SupportedFheCiphertexts;
use fhevm_engine_common::utils::{safe_deserialize_conformant, safe_serialize};
use sqlx::Row;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use test_harness::instance::{DBInstance, ImportMode};
use tfhe::integer::ciphertext::IntegerProvenCompactCiphertextListConformanceParams;
use tokio::sync::RwLock;
use tokio::time::sleep;

Expand Down Expand Up @@ -84,6 +88,147 @@ pub(crate) async fn is_valid(
Ok(false)
}

#[derive(Debug)]
pub(crate) struct StoredCiphertext {
pub(crate) handle: Vec<u8>,
pub(crate) ciphertext: Vec<u8>,
pub(crate) ciphertext_type: i16,
pub(crate) input_blob_index: i32,
}

#[derive(Debug, PartialEq, Eq)]
pub(crate) struct DecryptionResult {
pub(crate) output_type: i16,
pub(crate) value: String,
}

pub(crate) async fn wait_for_handles(
pool: &sqlx::PgPool,
zk_proof_id: i64,
max_retries: usize,
) -> Result<Vec<Vec<u8>>, sqlx::Error> {
for _ in 0..max_retries {
sleep(Duration::from_millis(100)).await;
let row = sqlx::query("SELECT verified, handles FROM verify_proofs WHERE zk_proof_id = $1")
.bind(zk_proof_id)
.fetch_one(pool)
.await?;

let verified: Option<bool> = row.try_get("verified")?;
if !matches!(verified, Some(true)) {
continue;
}

let handles: Option<Vec<u8>> = row.try_get("handles")?;
let handles = handles.unwrap_or_default();
assert_eq!(handles.len() % 32, 0);

return Ok(handles.chunks(32).map(|chunk| chunk.to_vec()).collect());
}

Ok(vec![])
}

pub(crate) async fn fetch_stored_ciphertexts(
pool: &sqlx::PgPool,
handles: &[Vec<u8>],
) -> Result<Vec<StoredCiphertext>, sqlx::Error> {
if handles.is_empty() {
return Ok(vec![]);
}

let rows = sqlx::query(
"
SELECT handle, ciphertext, ciphertext_type, input_blob_index
FROM ciphertexts
WHERE handle = ANY($1::BYTEA[])
AND ciphertext_version = $2
ORDER BY input_blob_index ASC
",
)
.bind(handles)
.bind(current_ciphertext_version())
.fetch_all(pool)
.await?;

rows.into_iter()
.map(|row| {
Ok(StoredCiphertext {
handle: row.try_get("handle")?,
ciphertext: row.try_get("ciphertext")?,
ciphertext_type: row.try_get("ciphertext_type")?,
input_blob_index: row.try_get("input_blob_index")?,
})
})
.collect()
}

pub(crate) async fn decrypt_ciphertexts(
pool: &sqlx::PgPool,
handles: &[Vec<u8>],
) -> anyhow::Result<Vec<DecryptionResult>> {
let stored = fetch_stored_ciphertexts(pool, handles).await?;
let db_key_cache = DbKeyCache::new(MAX_CACHED_KEYS).expect("create db key cache");
let key = db_key_cache.fetch_latest(pool).await?;

tokio::task::spawn_blocking(move || {
let client_key = key.cks.expect("client key available in tests");
tfhe::set_server_key(key.sks);

stored
.into_iter()
.map(|ct| {
let deserialized = SupportedFheCiphertexts::decompress_no_memcheck(
ct.ciphertext_type,
&ct.ciphertext,
)
.expect("valid compressed ciphertext");
DecryptionResult {
output_type: ct.ciphertext_type,
value: deserialized.decrypt(&client_key),
}
})
.collect::<Vec<_>>()
})
.await
.map_err(anyhow::Error::from)
}

pub(crate) async fn compress_inputs_without_rerandomization(
pool: &sqlx::PgPool,
raw_ct: &[u8],
) -> anyhow::Result<Vec<Vec<u8>>> {
let db_key_cache = DbKeyCache::new(MAX_CACHED_KEYS).expect("create db key cache");
let latest_key = db_key_cache.fetch_latest(pool).await?;
let latest_crs = CrsCache::load(pool)
.await?
.get_latest()
.cloned()
.expect("latest CRS");

let verified_list: tfhe::ProvenCompactCiphertextList = safe_deserialize_conformant(
raw_ct,
&IntegerProvenCompactCiphertextListConformanceParams::from_public_key_encryption_parameters_and_crs_parameters(
latest_key.pks.parameters(),
&latest_crs.crs,
),
)?;

if verified_list.is_empty() {
return Ok(vec![]);
}

tokio::task::spawn_blocking(move || {
tfhe::set_server_key(latest_key.sks);
let expanded = verified_list.expand_without_verification()?;
let cts = extract_ct_list(&expanded)?;
cts.into_iter()
.map(|ct| ct.compress().map_err(anyhow::Error::from))
.collect()
})
.await?
}

#[derive(Debug, Clone)]
pub(crate) enum ZkInput {
Bool(bool),
Expand All @@ -93,6 +238,18 @@ pub(crate) enum ZkInput {
U64(u64),
}

impl ZkInput {
pub(crate) fn cleartext(&self) -> String {
match self {
Self::Bool(value) => value.to_string(),
Self::U8(value) => value.to_string(),
Self::U16(value) => value.to_string(),
Self::U32(value) => value.to_string(),
Self::U64(value) => value.to_string(),
}
}
}

pub(crate) async fn generate_zk_pok_with_inputs(
pool: &sqlx::PgPool,
aux_data: &[u8],
Expand Down
Loading
Loading