Skip to content

Commit 4b468fd

Browse files
antoniupopisaacdecoded
authored andcommitted
feat(coprocessor): re-randomise input ciphertexts before first compression (#2073)
* feat(coprocessor): add re-randomisation of input ciphertexts * test(coprocessor): add regression tests for input re-randomisation
1 parent ce117f2 commit 4b468fd

File tree

3 files changed

+337
-25
lines changed

3 files changed

+337
-25
lines changed

coprocessor/fhevm-engine/zkproof-worker/src/tests/mod.rs

Lines changed: 113 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use fhevm_engine_common::tfhe_ops::current_ciphertext_version;
12
use serial_test::serial;
23
use test_harness::db_utils::ACL_CONTRACT_ADDR;
34

@@ -58,6 +59,15 @@ async fn test_verify_empty_input_list() {
5859
assert!(utils::is_valid(&pool, request_id, max_retries)
5960
.await
6061
.unwrap());
62+
63+
let handles = utils::wait_for_handles(&pool, request_id, max_retries)
64+
.await
65+
.unwrap();
66+
assert!(handles.is_empty());
67+
assert!(utils::fetch_stored_ciphertexts(&pool, &handles)
68+
.await
69+
.unwrap()
70+
.is_empty());
6171
}
6272

6373
#[tokio::test]
@@ -89,18 +99,110 @@ async fn test_max_input_index() {
8999

90100
// Test with highest number of inputs - 255
91101
let inputs = vec![utils::ZkInput::U64(2); MAX_INPUT_INDEX as usize + 1];
92-
assert!(utils::is_valid(
102+
let request_id = utils::insert_proof(
93103
&pool,
94-
utils::insert_proof(
95-
&pool,
96-
102,
97-
&utils::generate_zk_pok_with_inputs(&pool, &aux.1, &inputs).await,
98-
&aux.0
99-
)
100-
.await
101-
.expect("valid db insert"),
102-
5000
104+
102,
105+
&utils::generate_zk_pok_with_inputs(&pool, &aux.1, &inputs).await,
106+
&aux.0,
103107
)
104108
.await
105-
.expect("non-expired db query"));
109+
.expect("valid db insert");
110+
assert!(utils::is_valid(&pool, request_id, 5000)
111+
.await
112+
.expect("non-expired db query"));
113+
114+
let handles = utils::wait_for_handles(&pool, request_id, 5000)
115+
.await
116+
.expect("wait for handles");
117+
assert_eq!(handles.len(), MAX_INPUT_INDEX as usize + 1);
118+
assert_eq!(handles.first().expect("first handle")[21], 0);
119+
assert_eq!(handles.last().expect("last handle")[21], MAX_INPUT_INDEX);
120+
assert_eq!(
121+
&handles.last().expect("last handle")[22..30],
122+
&aux.0.chain_id.as_u64().to_be_bytes()
123+
);
124+
assert_eq!(
125+
handles.last().expect("last handle")[31],
126+
current_ciphertext_version() as u8
127+
);
128+
}
129+
130+
#[tokio::test]
131+
#[serial(db)]
132+
async fn test_verify_proof_rerandomises_ciphertexts_before_storage() {
133+
let (pool_mngr, _instance) = utils::setup().await.expect("valid setup");
134+
let pool = pool_mngr.pool();
135+
136+
let aux: (crate::auxiliary::ZkData, [u8; 92]) =
137+
utils::aux_fixture(ACL_CONTRACT_ADDR.to_owned());
138+
let inputs = vec![
139+
utils::ZkInput::Bool(true),
140+
utils::ZkInput::U8(42),
141+
utils::ZkInput::U16(12345),
142+
utils::ZkInput::U32(67890),
143+
utils::ZkInput::U64(1234567890),
144+
];
145+
let zk_pok = utils::generate_zk_pok_with_inputs(&pool, &aux.1, &inputs).await;
146+
let request_id = utils::insert_proof(&pool, 103, &zk_pok, &aux.0)
147+
.await
148+
.unwrap();
149+
150+
assert!(utils::is_valid(&pool, request_id, 1000).await.unwrap());
151+
152+
let handles = utils::wait_for_handles(&pool, request_id, 1000)
153+
.await
154+
.unwrap();
155+
assert_eq!(handles.len(), inputs.len());
156+
for (idx, handle) in handles.iter().enumerate() {
157+
assert_eq!(handle.len(), 32);
158+
assert_eq!(handle[21], idx as u8);
159+
assert_eq!(&handle[22..30], &aux.0.chain_id.as_u64().to_be_bytes());
160+
assert_eq!(handle[31], current_ciphertext_version() as u8);
161+
}
162+
163+
let stored = utils::fetch_stored_ciphertexts(&pool, &handles)
164+
.await
165+
.unwrap();
166+
assert_eq!(stored.len(), inputs.len());
167+
assert_eq!(
168+
stored
169+
.iter()
170+
.map(|ct| ct.input_blob_index)
171+
.collect::<Vec<_>>(),
172+
(0..inputs.len() as i32).collect::<Vec<_>>()
173+
);
174+
assert_eq!(
175+
stored
176+
.iter()
177+
.map(|ct| ct.handle.as_slice())
178+
.collect::<Vec<_>>(),
179+
handles
180+
.iter()
181+
.map(|handle| handle.as_slice())
182+
.collect::<Vec<_>>()
183+
);
184+
185+
let baseline = utils::compress_inputs_without_rerandomization(&pool, &zk_pok)
186+
.await
187+
.unwrap();
188+
assert_eq!(baseline.len(), stored.len());
189+
assert!(
190+
stored
191+
.iter()
192+
.zip(&baseline)
193+
.all(|(stored_ct, baseline_ct)| stored_ct.ciphertext != *baseline_ct),
194+
"stored ciphertexts should differ from the pre-rerandomization compression"
195+
);
196+
197+
let decrypted = utils::decrypt_ciphertexts(&pool, &handles).await.unwrap();
198+
assert_eq!(
199+
decrypted
200+
.iter()
201+
.map(|result| result.value.clone())
202+
.collect::<Vec<_>>(),
203+
inputs
204+
.iter()
205+
.map(|input| input.cleartext())
206+
.collect::<Vec<_>>()
207+
);
106208
}

coprocessor/fhevm-engine/zkproof-worker/src/tests/utils.rs

Lines changed: 158 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@ use fhevm_engine_common::chain_id::ChainId;
22
use fhevm_engine_common::crs::CrsCache;
33
use fhevm_engine_common::db_keys::DbKeyCache;
44
use fhevm_engine_common::pg_pool::PostgresPoolManager;
5-
use fhevm_engine_common::utils::safe_serialize;
5+
use fhevm_engine_common::tfhe_ops::{current_ciphertext_version, extract_ct_list};
6+
use fhevm_engine_common::types::SupportedFheCiphertexts;
7+
use fhevm_engine_common::utils::{safe_deserialize_conformant, safe_serialize};
8+
use sqlx::Row;
69
use std::sync::Arc;
710
use std::time::{Duration, SystemTime};
811
use test_harness::instance::{DBInstance, ImportMode};
12+
use tfhe::integer::ciphertext::IntegerProvenCompactCiphertextListConformanceParams;
913
use tokio::sync::RwLock;
1014
use tokio::time::sleep;
1115

@@ -84,6 +88,147 @@ pub(crate) async fn is_valid(
8488
Ok(false)
8589
}
8690

91+
#[derive(Debug)]
92+
pub(crate) struct StoredCiphertext {
93+
pub(crate) handle: Vec<u8>,
94+
pub(crate) ciphertext: Vec<u8>,
95+
pub(crate) ciphertext_type: i16,
96+
pub(crate) input_blob_index: i32,
97+
}
98+
99+
#[derive(Debug, PartialEq, Eq)]
100+
pub(crate) struct DecryptionResult {
101+
pub(crate) output_type: i16,
102+
pub(crate) value: String,
103+
}
104+
105+
pub(crate) async fn wait_for_handles(
106+
pool: &sqlx::PgPool,
107+
zk_proof_id: i64,
108+
max_retries: usize,
109+
) -> Result<Vec<Vec<u8>>, sqlx::Error> {
110+
for _ in 0..max_retries {
111+
sleep(Duration::from_millis(100)).await;
112+
let row = sqlx::query("SELECT verified, handles FROM verify_proofs WHERE zk_proof_id = $1")
113+
.bind(zk_proof_id)
114+
.fetch_one(pool)
115+
.await?;
116+
117+
let verified: Option<bool> = row.try_get("verified")?;
118+
if !matches!(verified, Some(true)) {
119+
continue;
120+
}
121+
122+
let handles: Option<Vec<u8>> = row.try_get("handles")?;
123+
let handles = handles.unwrap_or_default();
124+
assert_eq!(handles.len() % 32, 0);
125+
126+
return Ok(handles.chunks(32).map(|chunk| chunk.to_vec()).collect());
127+
}
128+
129+
Ok(vec![])
130+
}
131+
132+
pub(crate) async fn fetch_stored_ciphertexts(
133+
pool: &sqlx::PgPool,
134+
handles: &[Vec<u8>],
135+
) -> Result<Vec<StoredCiphertext>, sqlx::Error> {
136+
if handles.is_empty() {
137+
return Ok(vec![]);
138+
}
139+
140+
let rows = sqlx::query(
141+
"
142+
SELECT handle, ciphertext, ciphertext_type, input_blob_index
143+
FROM ciphertexts
144+
WHERE handle = ANY($1::BYTEA[])
145+
AND ciphertext_version = $2
146+
ORDER BY input_blob_index ASC
147+
",
148+
)
149+
.bind(handles)
150+
.bind(current_ciphertext_version())
151+
.fetch_all(pool)
152+
.await?;
153+
154+
rows.into_iter()
155+
.map(|row| {
156+
Ok(StoredCiphertext {
157+
handle: row.try_get("handle")?,
158+
ciphertext: row.try_get("ciphertext")?,
159+
ciphertext_type: row.try_get("ciphertext_type")?,
160+
input_blob_index: row.try_get("input_blob_index")?,
161+
})
162+
})
163+
.collect()
164+
}
165+
166+
pub(crate) async fn decrypt_ciphertexts(
167+
pool: &sqlx::PgPool,
168+
handles: &[Vec<u8>],
169+
) -> anyhow::Result<Vec<DecryptionResult>> {
170+
let stored = fetch_stored_ciphertexts(pool, handles).await?;
171+
let db_key_cache = DbKeyCache::new(MAX_CACHED_KEYS).expect("create db key cache");
172+
let key = db_key_cache.fetch_latest(pool).await?;
173+
174+
tokio::task::spawn_blocking(move || {
175+
let client_key = key.cks.expect("client key available in tests");
176+
tfhe::set_server_key(key.sks);
177+
178+
stored
179+
.into_iter()
180+
.map(|ct| {
181+
let deserialized = SupportedFheCiphertexts::decompress_no_memcheck(
182+
ct.ciphertext_type,
183+
&ct.ciphertext,
184+
)
185+
.expect("valid compressed ciphertext");
186+
DecryptionResult {
187+
output_type: ct.ciphertext_type,
188+
value: deserialized.decrypt(&client_key),
189+
}
190+
})
191+
.collect::<Vec<_>>()
192+
})
193+
.await
194+
.map_err(anyhow::Error::from)
195+
}
196+
197+
pub(crate) async fn compress_inputs_without_rerandomization(
198+
pool: &sqlx::PgPool,
199+
raw_ct: &[u8],
200+
) -> anyhow::Result<Vec<Vec<u8>>> {
201+
let db_key_cache = DbKeyCache::new(MAX_CACHED_KEYS).expect("create db key cache");
202+
let latest_key = db_key_cache.fetch_latest(pool).await?;
203+
let latest_crs = CrsCache::load(pool)
204+
.await?
205+
.get_latest()
206+
.cloned()
207+
.expect("latest CRS");
208+
209+
let verified_list: tfhe::ProvenCompactCiphertextList = safe_deserialize_conformant(
210+
raw_ct,
211+
&IntegerProvenCompactCiphertextListConformanceParams::from_public_key_encryption_parameters_and_crs_parameters(
212+
latest_key.pks.parameters(),
213+
&latest_crs.crs,
214+
),
215+
)?;
216+
217+
if verified_list.is_empty() {
218+
return Ok(vec![]);
219+
}
220+
221+
tokio::task::spawn_blocking(move || {
222+
tfhe::set_server_key(latest_key.sks);
223+
let expanded = verified_list.expand_without_verification()?;
224+
let cts = extract_ct_list(&expanded)?;
225+
cts.into_iter()
226+
.map(|ct| ct.compress().map_err(anyhow::Error::from))
227+
.collect()
228+
})
229+
.await?
230+
}
231+
87232
#[derive(Debug, Clone)]
88233
pub(crate) enum ZkInput {
89234
Bool(bool),
@@ -93,6 +238,18 @@ pub(crate) enum ZkInput {
93238
U64(u64),
94239
}
95240

241+
impl ZkInput {
242+
pub(crate) fn cleartext(&self) -> String {
243+
match self {
244+
Self::Bool(value) => value.to_string(),
245+
Self::U8(value) => value.to_string(),
246+
Self::U16(value) => value.to_string(),
247+
Self::U32(value) => value.to_string(),
248+
Self::U64(value) => value.to_string(),
249+
}
250+
}
251+
}
252+
96253
pub(crate) async fn generate_zk_pok_with_inputs(
97254
pool: &sqlx::PgPool,
98255
aux_data: &[u8],

0 commit comments

Comments
 (0)