diff --git a/coprocessor/fhevm-engine/zkproof-worker/src/tests/mod.rs b/coprocessor/fhevm-engine/zkproof-worker/src/tests/mod.rs index 5fdbd6d50f..43161b5944 100644 --- a/coprocessor/fhevm-engine/zkproof-worker/src/tests/mod.rs +++ b/coprocessor/fhevm-engine/zkproof-worker/src/tests/mod.rs @@ -1,3 +1,4 @@ +use fhevm_engine_common::tfhe_ops::current_ciphertext_version; use serial_test::serial; use test_harness::db_utils::ACL_CONTRACT_ADDR; @@ -58,6 +59,15 @@ async fn test_verify_empty_input_list() { assert!(utils::is_valid(&pool, request_id, max_retries) .await .unwrap()); + + let handles = utils::wait_for_handles(&pool, request_id, max_retries) + .await + .unwrap(); + assert!(handles.is_empty()); + assert!(utils::fetch_stored_ciphertexts(&pool, &handles) + .await + .unwrap() + .is_empty()); } #[tokio::test] @@ -89,18 +99,110 @@ async fn test_max_input_index() { // Test with highest number of inputs - 255 let inputs = vec![utils::ZkInput::U64(2); MAX_INPUT_INDEX as usize + 1]; - assert!(utils::is_valid( + let request_id = utils::insert_proof( &pool, - utils::insert_proof( - &pool, - 102, - &utils::generate_zk_pok_with_inputs(&pool, &aux.1, &inputs).await, - &aux.0 - ) - .await - .expect("valid db insert"), - 5000 + 102, + &utils::generate_zk_pok_with_inputs(&pool, &aux.1, &inputs).await, + &aux.0, ) .await - .expect("non-expired db query")); + .expect("valid db insert"); + assert!(utils::is_valid(&pool, request_id, 5000) + .await + .expect("non-expired db query")); + + let handles = utils::wait_for_handles(&pool, request_id, 5000) + .await + .expect("wait for handles"); + assert_eq!(handles.len(), MAX_INPUT_INDEX as usize + 1); + assert_eq!(handles.first().expect("first handle")[21], 0); + assert_eq!(handles.last().expect("last handle")[21], MAX_INPUT_INDEX); + assert_eq!( + &handles.last().expect("last handle")[22..30], + &aux.0.chain_id.as_u64().to_be_bytes() + ); + assert_eq!( + handles.last().expect("last handle")[31], + current_ciphertext_version() as u8 + ); +} + +#[tokio::test] +#[serial(db)] +async fn test_verify_proof_rerandomises_ciphertexts_before_storage() { + let (pool_mngr, _instance) = utils::setup().await.expect("valid setup"); + let pool = pool_mngr.pool(); + + let aux: (crate::auxiliary::ZkData, [u8; 92]) = + utils::aux_fixture(ACL_CONTRACT_ADDR.to_owned()); + let inputs = vec![ + utils::ZkInput::Bool(true), + utils::ZkInput::U8(42), + utils::ZkInput::U16(12345), + utils::ZkInput::U32(67890), + utils::ZkInput::U64(1234567890), + ]; + let zk_pok = utils::generate_zk_pok_with_inputs(&pool, &aux.1, &inputs).await; + let request_id = utils::insert_proof(&pool, 103, &zk_pok, &aux.0) + .await + .unwrap(); + + assert!(utils::is_valid(&pool, request_id, 1000).await.unwrap()); + + let handles = utils::wait_for_handles(&pool, request_id, 1000) + .await + .unwrap(); + assert_eq!(handles.len(), inputs.len()); + for (idx, handle) in handles.iter().enumerate() { + assert_eq!(handle.len(), 32); + assert_eq!(handle[21], idx as u8); + assert_eq!(&handle[22..30], &aux.0.chain_id.as_u64().to_be_bytes()); + assert_eq!(handle[31], current_ciphertext_version() as u8); + } + + let stored = utils::fetch_stored_ciphertexts(&pool, &handles) + .await + .unwrap(); + assert_eq!(stored.len(), inputs.len()); + assert_eq!( + stored + .iter() + .map(|ct| ct.input_blob_index) + .collect::>(), + (0..inputs.len() as i32).collect::>() + ); + assert_eq!( + stored + .iter() + .map(|ct| ct.handle.as_slice()) + .collect::>(), + handles + .iter() + .map(|handle| handle.as_slice()) + .collect::>() + ); + + let baseline = utils::compress_inputs_without_rerandomization(&pool, &zk_pok) + .await + .unwrap(); + assert_eq!(baseline.len(), stored.len()); + assert!( + stored + .iter() + .zip(&baseline) + .all(|(stored_ct, baseline_ct)| stored_ct.ciphertext != *baseline_ct), + "stored ciphertexts should differ from the pre-rerandomization compression" + ); + + let decrypted = utils::decrypt_ciphertexts(&pool, &handles).await.unwrap(); + assert_eq!( + decrypted + .iter() + .map(|result| result.value.clone()) + .collect::>(), + inputs + .iter() + .map(|input| input.cleartext()) + .collect::>() + ); } diff --git a/coprocessor/fhevm-engine/zkproof-worker/src/tests/utils.rs b/coprocessor/fhevm-engine/zkproof-worker/src/tests/utils.rs index 14522f9949..6b38cafaab 100644 --- a/coprocessor/fhevm-engine/zkproof-worker/src/tests/utils.rs +++ b/coprocessor/fhevm-engine/zkproof-worker/src/tests/utils.rs @@ -2,10 +2,14 @@ use fhevm_engine_common::chain_id::ChainId; use fhevm_engine_common::crs::CrsCache; use fhevm_engine_common::db_keys::DbKeyCache; use fhevm_engine_common::pg_pool::PostgresPoolManager; -use fhevm_engine_common::utils::safe_serialize; +use fhevm_engine_common::tfhe_ops::{current_ciphertext_version, extract_ct_list}; +use fhevm_engine_common::types::SupportedFheCiphertexts; +use fhevm_engine_common::utils::{safe_deserialize_conformant, safe_serialize}; +use sqlx::Row; use std::sync::Arc; use std::time::{Duration, SystemTime}; use test_harness::instance::{DBInstance, ImportMode}; +use tfhe::integer::ciphertext::IntegerProvenCompactCiphertextListConformanceParams; use tokio::sync::RwLock; use tokio::time::sleep; @@ -84,6 +88,147 @@ pub(crate) async fn is_valid( Ok(false) } +#[derive(Debug)] +pub(crate) struct StoredCiphertext { + pub(crate) handle: Vec, + pub(crate) ciphertext: Vec, + pub(crate) ciphertext_type: i16, + pub(crate) input_blob_index: i32, +} + +#[derive(Debug, PartialEq, Eq)] +pub(crate) struct DecryptionResult { + pub(crate) output_type: i16, + pub(crate) value: String, +} + +pub(crate) async fn wait_for_handles( + pool: &sqlx::PgPool, + zk_proof_id: i64, + max_retries: usize, +) -> Result>, sqlx::Error> { + for _ in 0..max_retries { + sleep(Duration::from_millis(100)).await; + let row = sqlx::query("SELECT verified, handles FROM verify_proofs WHERE zk_proof_id = $1") + .bind(zk_proof_id) + .fetch_one(pool) + .await?; + + let verified: Option = row.try_get("verified")?; + if !matches!(verified, Some(true)) { + continue; + } + + let handles: Option> = row.try_get("handles")?; + let handles = handles.unwrap_or_default(); + assert_eq!(handles.len() % 32, 0); + + return Ok(handles.chunks(32).map(|chunk| chunk.to_vec()).collect()); + } + + Ok(vec![]) +} + +pub(crate) async fn fetch_stored_ciphertexts( + pool: &sqlx::PgPool, + handles: &[Vec], +) -> Result, sqlx::Error> { + if handles.is_empty() { + return Ok(vec![]); + } + + let rows = sqlx::query( + " + SELECT handle, ciphertext, ciphertext_type, input_blob_index + FROM ciphertexts + WHERE handle = ANY($1::BYTEA[]) + AND ciphertext_version = $2 + ORDER BY input_blob_index ASC + ", + ) + .bind(handles) + .bind(current_ciphertext_version()) + .fetch_all(pool) + .await?; + + rows.into_iter() + .map(|row| { + Ok(StoredCiphertext { + handle: row.try_get("handle")?, + ciphertext: row.try_get("ciphertext")?, + ciphertext_type: row.try_get("ciphertext_type")?, + input_blob_index: row.try_get("input_blob_index")?, + }) + }) + .collect() +} + +pub(crate) async fn decrypt_ciphertexts( + pool: &sqlx::PgPool, + handles: &[Vec], +) -> anyhow::Result> { + let stored = fetch_stored_ciphertexts(pool, handles).await?; + let db_key_cache = DbKeyCache::new(MAX_CACHED_KEYS).expect("create db key cache"); + let key = db_key_cache.fetch_latest(pool).await?; + + tokio::task::spawn_blocking(move || { + let client_key = key.cks.expect("client key available in tests"); + tfhe::set_server_key(key.sks); + + stored + .into_iter() + .map(|ct| { + let deserialized = SupportedFheCiphertexts::decompress_no_memcheck( + ct.ciphertext_type, + &ct.ciphertext, + ) + .expect("valid compressed ciphertext"); + DecryptionResult { + output_type: ct.ciphertext_type, + value: deserialized.decrypt(&client_key), + } + }) + .collect::>() + }) + .await + .map_err(anyhow::Error::from) +} + +pub(crate) async fn compress_inputs_without_rerandomization( + pool: &sqlx::PgPool, + raw_ct: &[u8], +) -> anyhow::Result>> { + let db_key_cache = DbKeyCache::new(MAX_CACHED_KEYS).expect("create db key cache"); + let latest_key = db_key_cache.fetch_latest(pool).await?; + let latest_crs = CrsCache::load(pool) + .await? + .get_latest() + .cloned() + .expect("latest CRS"); + + let verified_list: tfhe::ProvenCompactCiphertextList = safe_deserialize_conformant( + raw_ct, + &IntegerProvenCompactCiphertextListConformanceParams::from_public_key_encryption_parameters_and_crs_parameters( + latest_key.pks.parameters(), + &latest_crs.crs, + ), + )?; + + if verified_list.is_empty() { + return Ok(vec![]); + } + + tokio::task::spawn_blocking(move || { + tfhe::set_server_key(latest_key.sks); + let expanded = verified_list.expand_without_verification()?; + let cts = extract_ct_list(&expanded)?; + cts.into_iter() + .map(|ct| ct.compress().map_err(anyhow::Error::from)) + .collect() + }) + .await? +} + #[derive(Debug, Clone)] pub(crate) enum ZkInput { Bool(bool), @@ -93,6 +238,18 @@ pub(crate) enum ZkInput { U64(u64), } +impl ZkInput { + pub(crate) fn cleartext(&self) -> String { + match self { + Self::Bool(value) => value.to_string(), + Self::U8(value) => value.to_string(), + Self::U16(value) => value.to_string(), + Self::U32(value) => value.to_string(), + Self::U64(value) => value.to_string(), + } + } +} + pub(crate) async fn generate_zk_pok_with_inputs( pool: &sqlx::PgPool, aux_data: &[u8], diff --git a/coprocessor/fhevm-engine/zkproof-worker/src/verifier.rs b/coprocessor/fhevm-engine/zkproof-worker/src/verifier.rs index 9b3aa37f4a..077b1b1bea 100644 --- a/coprocessor/fhevm-engine/zkproof-worker/src/verifier.rs +++ b/coprocessor/fhevm-engine/zkproof-worker/src/verifier.rs @@ -7,7 +7,7 @@ use fhevm_engine_common::host_chains::HostChainsCache; use fhevm_engine_common::pg_pool::{PostgresPoolManager, ServiceError}; use fhevm_engine_common::telemetry; use fhevm_engine_common::tfhe_ops::{current_ciphertext_version, extract_ct_list}; -use fhevm_engine_common::types::SupportedFheCiphertexts; +use fhevm_engine_common::types::{FhevmError, SupportedFheCiphertexts}; use fhevm_engine_common::utils::safe_deserialize_conformant; use sha3::Digest; @@ -16,6 +16,7 @@ use sqlx::{postgres::PgListener, PgPool, Row}; use sqlx::{Postgres, Transaction}; use std::str::FromStr; use tfhe::integer::ciphertext::IntegerProvenCompactCiphertextListConformanceParams; +use tfhe::ReRandomizationContext; use tokio::sync::RwLock; use tokio::task::JoinSet; @@ -38,6 +39,8 @@ const EVENT_CIPHERTEXT_COMPUTED: &str = "event_ciphertext_computed"; const RAW_CT_HASH_DOMAIN_SEPARATOR: [u8; 8] = *b"ZK-w_rct"; const HANDLE_HASH_DOMAIN_SEPARATOR: [u8; 8] = *b"ZK-w_hdl"; +const RERANDOMISATION_DOMAIN_SEPARATOR: [u8; 8] = *b"ZKw_Rrnd"; +const COMPACT_PUBLIC_ENCRYPTION_DOMAIN_SEPARATOR: [u8; 8] = *b"TFHE_Enc"; pub(crate) struct Ciphertext { handle: Vec, @@ -415,16 +418,29 @@ pub(crate) fn verify_proof( let mut cts = expand_verified_list(request_id, &verified_list) .inspect_err(telemetry::set_current_span_error)?; - // Step 3: Create ciphertext handles + // Step 3: Compute blob hash and set re-randomization metadata on all ciphertexts let mut h = Keccak256::new(); h.update(RAW_CT_HASH_DOMAIN_SEPARATOR); h.update(raw_ct); let blob_hash = h.finalize().to_vec(); + let handles: Vec> = cts + .iter_mut() + .enumerate() + .map(|(idx, ct)| set_ciphertext_metadata(&blob_hash, idx, ct, aux_data)) + .collect::, ExecutionError>>() + .inspect_err(telemetry::set_current_span_error)?; + + // Step 4: Re-randomize all ciphertexts before compression + re_randomise_ciphertexts(&mut cts, &blob_hash, &key.pks) + .inspect_err(telemetry::set_current_span_error)?; + + // Step 5: Compress and build final ciphertext records let cts = cts .iter_mut() + .zip(handles) .enumerate() - .map(|(idx, ct)| create_ciphertext(request_id, &blob_hash, idx, ct, aux_data)) + .map(|(idx, (ct, handle))| finalize_ciphertext(request_id, handle, idx, ct, aux_data)) .collect::, ExecutionError>>() .inspect_err(telemetry::set_current_span_error)?; @@ -505,19 +521,14 @@ fn expand_verified_list( Ok(cts) } -/// Creates a ciphertext -#[tracing::instrument(skip_all, fields( - ct_type = tracing::field::Empty, - ct_idx = ct_idx, - chain_id = %aux_data.chain_id, -))] -fn create_ciphertext( - request_id: i64, +/// Computes the handle hash and sets re-randomization metadata on a ciphertext. +/// Returns the full 256-bit handle hash (before index/chain/type/version are patched in). +fn set_ciphertext_metadata( blob_hash: &[u8], ct_idx: usize, the_ct: &mut SupportedFheCiphertexts, aux_data: &auxiliary::ZkData, -) -> Result { +) -> Result, ExecutionError> { if ct_idx > MAX_INPUT_INDEX as usize { return Err(ExecutionError::TooManyInputs(ct_idx)); } @@ -534,12 +545,54 @@ fn create_ciphertext( .into_array(), ); handle_hash.update(chain_id_bytes); - let mut handle = handle_hash.finalize().to_vec(); + let handle = handle_hash.finalize().to_vec(); assert_eq!(handle.len(), 32); // Add the full 256bit hash as re-randomization metadata, NOT the // truncated hash of the handle the_ct.add_re_randomization_metadata(&handle); + + Ok(handle) +} + +/// Re-randomizes all ciphertexts using the compact public key. +#[tracing::instrument(name = "rerandomise_cts", skip_all)] +fn re_randomise_ciphertexts( + cts: &mut [SupportedFheCiphertexts], + blob_hash: &[u8], + cpk: &tfhe::CompactPublicKey, +) -> Result<(), ExecutionError> { + let mut re_rand_context = ReRandomizationContext::new( + RERANDOMISATION_DOMAIN_SEPARATOR, + [blob_hash], + COMPACT_PUBLIC_ENCRYPTION_DOMAIN_SEPARATOR, + ); + for ct in cts.iter() { + ct.add_to_re_randomization_context(&mut re_rand_context); + } + let mut seed_gen = re_rand_context.finalize(); + for ct in cts.iter_mut() { + let seed = seed_gen + .next_seed() + .map_err(FhevmError::ReRandomisationError)?; + ct.re_randomise(cpk, seed)?; + } + Ok(()) +} + +/// Compresses the ciphertext and builds the final Ciphertext record with patched handle. +#[tracing::instrument(skip_all, fields( + ct_type = tracing::field::Empty, + ct_idx = ct_idx, + chain_id = %aux_data.chain_id, +))] +fn finalize_ciphertext( + request_id: i64, + mut handle: Vec, + ct_idx: usize, + the_ct: &mut SupportedFheCiphertexts, + aux_data: &auxiliary::ZkData, +) -> Result { let serialized_type = the_ct.type_num(); let compressed = the_ct.compress()?;