diff --git a/coprocessor/fhevm-engine/db-migration/migrations/20260128095635_remove_tenants.sql b/coprocessor/fhevm-engine/db-migration/migrations/20260128095635_remove_tenants.sql index c6f1a035c8..8c95a6000d 100644 --- a/coprocessor/fhevm-engine/db-migration/migrations/20260128095635_remove_tenants.sql +++ b/coprocessor/fhevm-engine/db-migration/migrations/20260128095635_remove_tenants.sql @@ -80,6 +80,12 @@ ALTER TABLE ciphertexts DROP CONSTRAINT ciphertexts_pkey; ALTER TABLE ciphertexts DROP COLUMN tenant_id; ALTER TABLE ciphertexts ADD PRIMARY KEY (handle, ciphertext_version); +-- ciphertexts128.tenant_id no longer needed. +ALTER TABLE ciphertexts128 DROP CONSTRAINT ciphertexts128_pkey; +ALTER TABLE ciphertexts128 DROP COLUMN tenant_id; +ALTER TABLE ciphertexts128 ADD PRIMARY KEY (handle); +DROP INDEX IF EXISTS idx_ciphertexts128_handle; + -- computations.tenant_id no longer needed. ALTER TABLE computations DROP CONSTRAINT computations_pkey; DROP INDEX IF EXISTS idx_computations_pk; @@ -87,6 +93,10 @@ ALTER TABLE computations DROP COLUMN tenant_id; ALTER TABLE computations ADD PRIMARY KEY (output_handle, transaction_id); -- pbs_computations.tenant_id no longer needed. +ALTER TABLE pbs_computations ADD COLUMN host_chain_id BIGINT DEFAULT NULL; +UPDATE pbs_computations SET host_chain_id = (SELECT chain_id FROM keys WHERE tenant_id = pbs_computations.tenant_id); +ALTER TABLE pbs_computations ALTER COLUMN host_chain_id SET NOT NULL; +ALTER TABLE pbs_computations ADD CONSTRAINT pbs_computations_host_chain_id_positive CHECK (host_chain_id > 0); ALTER TABLE pbs_computations DROP CONSTRAINT pbs_computations_pkey; ALTER TABLE pbs_computations DROP COLUMN tenant_id; ALTER TABLE pbs_computations ADD PRIMARY KEY (handle); diff --git a/coprocessor/fhevm-engine/sns-worker/README.md b/coprocessor/fhevm-engine/sns-worker/README.md index 25750dba77..edf9906680 100644 --- a/coprocessor/fhevm-engine/sns-worker/README.md +++ b/coprocessor/fhevm-engine/sns-worker/README.md @@ -17,19 +17,19 @@ Upon receiving a notification, it mainly does the following steps: Runs sns-executor. See also `src/bin/utils/daemon_cli.rs` - + ## Running a SnS Worker -### The SnS key can be retrieved from the Large Objects table (pg_largeobject). Before running a worker, the sns_pk should be imported into tenants tables as shown below. If tenants table is not in use, then keys can be passed with CLI param --keys_file_path +### The SnS key can be retrieved from the Large Objects table (pg_largeobject). Before running a worker, the sns_pk should be imported into the keys table as shown below. If the keys table is not in use, then keys can be passed with CLI param --keys_file_path ```sql -- Example query to import sns_pk from fhevm-keys/sns_pk -- Import the sns_pk into the Large Object storage sns_pk_loid := lo_import('../fhevm-keys/sns_pk'); --- Update the tenants table with the new Large Object OID -UPDATE tenants +-- Update the keys table with the new Large Object OID +UPDATE keys SET sns_pk = sns_pk_loid -WHERE tenant_id = 1; +WHERE sequence_number = (SELECT sequence_number FROM keys ORDER BY sequence_number DESC LIMIT 1); ``` ### Multiple workers can be launched independently to perform 128-PBS computations. @@ -37,11 +37,13 @@ WHERE tenant_id = 1; # Run a single instance of the worker DATABASE_URL=postgresql://postgres:postgres@localhost:5432/coprocessor \ cargo run --release -- \ ---tenant-api-key "a1503fb6-d79b-4e9e-826d-44cf262f3e05" \ --pg-listen-channels "event_pbs_computations" "event_ciphertext_computed" \ --pg-notify-channel "event_pbs_computed" \ ``` +Notes: +- `host_chain_id` is read directly from `pbs_computations`/`ciphertext_digest` rows. + ## Testing - Using `Postgres` docker image @@ -59,4 +61,3 @@ COPROCESSOR_TEST_LOCALHOST_RESET=1 cargo test --release -- --nocapture # Then, on every run COPROCESSOR_TEST_LOCALHOST=1 cargo test --release ``` - diff --git a/coprocessor/fhevm-engine/sns-worker/src/aws_upload.rs b/coprocessor/fhevm-engine/sns-worker/src/aws_upload.rs index ba6376aba4..d0f2878987 100644 --- a/coprocessor/fhevm-engine/sns-worker/src/aws_upload.rs +++ b/coprocessor/fhevm-engine/sns-worker/src/aws_upload.rs @@ -119,10 +119,9 @@ async fn run_uploader_loop( UploadJob::DatabaseLock(item) => { if let Err(err) = sqlx::query!( "SELECT * FROM ciphertext_digest - WHERE handle = $2 AND tenant_id = $1 AND + WHERE handle = $1 AND (ciphertext128 IS NULL OR ciphertext IS NULL) FOR UPDATE SKIP LOCKED", - item.tenant_id, item.handle ) .fetch_one(trx.as_mut()) @@ -234,7 +233,6 @@ async fn upload_ciphertexts( info!( handle = handle_as_hex, len = ?ByteSize::b(ct128_bytes.len() as u64), - tenant_id = task.tenant_id, "Uploading ct128" ); @@ -286,7 +284,6 @@ async fn upload_ciphertexts( info!( handle = handle_as_hex, len = ?ByteSize::b(ct64_compressed.len() as u64), - tenant_id = task.tenant_id, "Uploading ct64", ); @@ -400,7 +397,7 @@ async fn fetch_pending_uploads( limit: i64, ) -> Result, ExecutionError> { let rows = sqlx::query!( - "SELECT tenant_id, handle, ciphertext, ciphertext128, ciphertext128_format, transaction_id + "SELECT handle, ciphertext, ciphertext128, ciphertext128_format, transaction_id, host_chain_id, key_id FROM ciphertext_digest WHERE ciphertext IS NULL OR ciphertext128 IS NULL FOR UPDATE SKIP LOCKED @@ -423,8 +420,7 @@ async fn fetch_pending_uploads( // Fetch missing ciphertext if ciphertext_digest.is_none() { if let Ok(row) = sqlx::query!( - "SELECT ciphertext FROM ciphertexts WHERE tenant_id = $1 AND handle = $2;", - row.tenant_id, + "SELECT ciphertext FROM ciphertexts WHERE handle = $1;", handle ) .fetch_optional(db_pool) @@ -441,8 +437,7 @@ async fn fetch_pending_uploads( // Fetch missing ciphertext128 if ciphertext128_digest.is_none() { if let Ok(row) = sqlx::query!( - "SELECT ciphertext FROM ciphertexts128 WHERE tenant_id = $1 AND handle = $2;", - row.tenant_id, + "SELECT ciphertext FROM ciphertexts128 WHERE handle = $1;", handle ) .fetch_optional(db_pool) @@ -484,7 +479,8 @@ async fn fetch_pending_uploads( if !ct64_compressed.is_empty() || !is_ct128_empty { let item = HandleItem { - tenant_id: row.tenant_id, + host_chain_id: row.host_chain_id, + key_id: row.key_id, handle: handle.clone(), ct64_compressed, ct128: Arc::new(ct128), diff --git a/coprocessor/fhevm-engine/sns-worker/src/bin/sns_worker.rs b/coprocessor/fhevm-engine/sns-worker/src/bin/sns_worker.rs index 89662b7408..5d99cefdb6 100644 --- a/coprocessor/fhevm-engine/sns-worker/src/bin/sns_worker.rs +++ b/coprocessor/fhevm-engine/sns-worker/src/bin/sns_worker.rs @@ -19,7 +19,6 @@ fn construct_config() -> Config { let db_url = args.database_url.clone().unwrap_or_default(); Config { - tenant_api_key: args.tenant_api_key, service_name: args.service_name, metrics: SNSMetricsConfig { addr: args.metrics_addr, diff --git a/coprocessor/fhevm-engine/sns-worker/src/bin/utils/daemon_cli.rs b/coprocessor/fhevm-engine/sns-worker/src/bin/utils/daemon_cli.rs index 58ce0145bc..afede09b11 100644 --- a/coprocessor/fhevm-engine/sns-worker/src/bin/utils/daemon_cli.rs +++ b/coprocessor/fhevm-engine/sns-worker/src/bin/utils/daemon_cli.rs @@ -11,10 +11,6 @@ use tracing::Level; #[derive(Parser, Debug, Clone)] #[command(version, about, long_about = None)] pub struct Args { - /// Tenant API key - #[arg(long)] - pub tenant_api_key: String, - /// Work items batch size #[arg(long, default_value_t = 4)] pub work_items_batch_size: u32, diff --git a/coprocessor/fhevm-engine/sns-worker/src/executor.rs b/coprocessor/fhevm-engine/sns-worker/src/executor.rs index d81be6a651..37ed9052a5 100644 --- a/coprocessor/fhevm-engine/sns-worker/src/executor.rs +++ b/coprocessor/fhevm-engine/sns-worker/src/executor.rs @@ -1,5 +1,5 @@ use crate::aws_upload::check_is_ready; -use crate::keyset::fetch_keyset; +use crate::keyset::fetch_latest_keyset; use crate::metrics::SNS_LATENCY_OP_HISTOGRAM; use crate::metrics::TASK_EXECUTE_FAILURE_COUNTER; use crate::metrics::TASK_EXECUTE_SUCCESS_COUNTER; @@ -13,6 +13,7 @@ use crate::SchedulePolicy; use crate::UploadJob; use crate::{Config, ExecutionError}; use aws_sdk_s3::Client; +use fhevm_engine_common::db_keys::DbKeyId; use fhevm_engine_common::healthz_server::{HealthCheckService, HealthStatus, Version}; use fhevm_engine_common::pg_pool::PostgresPoolManager; use fhevm_engine_common::pg_pool::ServiceError; @@ -142,7 +143,7 @@ impl SwitchNSquashService { } pub async fn run(&self, pool_mngr: &PostgresPoolManager) { - let keys_cache: Arc>> = Arc::new(RwLock::new( + let keys_cache: Arc>> = Arc::new(RwLock::new( lru::LruCache::new(NonZeroUsize::new(10).unwrap()), )); @@ -174,19 +175,10 @@ impl SwitchNSquashService { async fn get_keyset( pool: PgPool, - keys_cache: Arc>>, - tenant_api_key: &String, -) -> Result, ExecutionError> { + keys_cache: Arc>>, +) -> Result, ExecutionError> { let _t = telemetry::tracer("fetch_keyset", &None); - { - let mut cache = keys_cache.write().await; - if let Some(keys) = cache.get(tenant_api_key) { - info!(tenant_api_key = tenant_api_key, "Keyset found in cache"); - return Ok(Some(keys.clone())); - } - } - let keys: Option = fetch_keyset(&keys_cache, &pool, tenant_api_key).await?; - Ok(keys) + fetch_latest_keyset(&keys_cache, &pool).await } /// Executes the worker logic for the SnS task. @@ -196,12 +188,11 @@ pub(crate) async fn run_loop( pool: PgPool, token: CancellationToken, last_active_at: Arc>, - keys_cache: Arc>>, + keys_cache: Arc>>, events_tx: InternalEvents, ) -> Result<(), ExecutionError> { update_last_active(last_active_at.clone()).await; - let tenant_api_key = &conf.tenant_api_key; let mut listener = PgListener::connect_with(&pool).await?; info!("Connected to PostgresDB"); @@ -209,7 +200,7 @@ pub(crate) async fn run_loop( .listen_all(conf.db.listen_channels.iter().map(|v| v.as_str())) .await?; - let mut keys = None; + let mut keys: Option<(DbKeyId, KeySet)> = None; let mut gc_ticker = interval(conf.db.cleanup_interval); let mut gc_timestamp = SystemTime::now(); let mut polling_ticker = interval(Duration::from_secs(conf.db.polling_interval.into())); @@ -218,27 +209,31 @@ pub(crate) async fn run_loop( // Continue looping until the service is cancelled or a critical error occurs update_last_active(last_active_at.clone()).await; - let Some(keys) = keys.as_ref() else { - keys = get_keyset(pool.clone(), keys_cache.clone(), tenant_api_key).await?; - if keys.is_some() { - info!(tenant_api_key = tenant_api_key, "Fetched keyset"); + let latest_keys = get_keyset(pool.clone(), keys_cache.clone()).await?; + if let Some((key_id, keyset)) = latest_keys { + let key_changed = keys + .as_ref() + .map(|(current_key_id, _)| current_key_id != &key_id) + .unwrap_or(true); + if key_changed { + info!(key_id = hex::encode(&key_id), "Fetched keyset"); // Notify that the keys are loaded if let Some(events_tx) = &events_tx { let _ = events_tx.try_send("event_keys_loaded"); } - } else { - warn!( - tenant_api_key = tenant_api_key, - "No keys available, retrying in 5 seconds" - ); - tokio::time::sleep(Duration::from_secs(5)).await; } - + keys = Some((key_id, keyset)); + } else { + warn!("No keys available, retrying in 5 seconds"); + tokio::time::sleep(Duration::from_secs(5)).await; if token.is_cancelled() { return Ok(()); } continue; - }; + } + + // keys is guaranteed by the branch above; panic here if that invariant ever regresses. + let (_, keys) = keys.as_ref().expect("keyset should be available"); let (maybe_remaining, _tasks_processed) = fetch_and_execute_sns_tasks(&pool, &tx, keys, &conf, &token) @@ -316,11 +311,10 @@ pub async fn garbage_collect(pool: &PgPool, limit: u32) -> Result<(), ExecutionE let rows_affected: u64 = sqlx::query!( " WITH uploaded_ct128 AS ( - SELECT c.tenant_id, c.handle + SELECT c.handle FROM ciphertexts128 c JOIN ciphertext_digest d - ON d.tenant_id = c.tenant_id - AND d.handle = c.handle + ON d.handle = c.handle WHERE d.ciphertext128 IS NOT NULL FOR UPDATE OF c SKIP LOCKED LIMIT $1 @@ -328,8 +322,7 @@ pub async fn garbage_collect(pool: &PgPool, limit: u32) -> Result<(), ExecutionE DELETE FROM ciphertexts128 c USING uploaded_ct128 r - WHERE c.tenant_id = r.tenant_id - AND c.handle = r.handle; + WHERE c.handle = r.handle; ", limit as i32 ) @@ -375,7 +368,7 @@ async fn fetch_and_execute_sns_tasks( let mut maybe_remaining = false; let tasks_processed; - if let Some(mut tasks) = query_sns_tasks(trx, conf.db.batch_limit, order).await? { + if let Some(mut tasks) = query_sns_tasks(trx, conf.db.batch_limit, order, &keys.key_id).await? { maybe_remaining = conf.db.batch_limit as usize == tasks.len(); tasks_processed = tasks.len(); @@ -423,6 +416,7 @@ pub async fn query_sns_tasks( db_txn: &mut Transaction<'_, Postgres>, limit: u32, order: Order, + key_id: &DbKeyId, ) -> Result>, ExecutionError> { let start_time = SystemTime::now(); @@ -460,13 +454,16 @@ pub async fn query_sns_tasks( let tasks = records .into_iter() .map(|record| { - let tenant_id: i32 = record.try_get("tenant_id")?; + let host_chain_id: i64 = record.try_get("host_chain_id")?; let handle: Vec = record.try_get("handle")?; let ciphertext: Vec = record.try_get("ciphertext")?; let transaction_id: Option> = record.try_get("transaction_id")?; Ok(HandleItem { - tenant_id, + // TODO: During key rotation, ensure all coprocessors pin the same key_id for a batch + // (e.g., via gateway coordination) to keep ciphertext_digest consistent. + key_id: key_id.clone(), + host_chain_id, handle: handle.clone(), ct64_compressed: Arc::new(ciphertext), ct128: Arc::new(BigCiphertext::default()), // to be computed @@ -644,12 +641,10 @@ async fn update_ciphertext128( let res = sqlx::query!( " INSERT INTO ciphertexts128 ( - tenant_id, handle, ciphertext ) - VALUES ($1, $2, $3)", - task.tenant_id, + VALUES ($1, $2)", task.handle, ciphertext128, ) diff --git a/coprocessor/fhevm-engine/sns-worker/src/keyset.rs b/coprocessor/fhevm-engine/sns-worker/src/keyset.rs index 4907fee7bd..b7d23468e6 100644 --- a/coprocessor/fhevm-engine/sns-worker/src/keyset.rs +++ b/coprocessor/fhevm-engine/sns-worker/src/keyset.rs @@ -1,4 +1,7 @@ -use fhevm_engine_common::{db_keys::read_keys_from_large_object, utils::safe_deserialize_sns_key}; +use fhevm_engine_common::{ + db_keys::{read_keys_from_large_object, DbKeyId}, + utils::safe_deserialize_sns_key, +}; use sqlx::{PgPool, Row}; use std::sync::Arc; use tokio::sync::RwLock; @@ -8,47 +11,52 @@ use crate::{ExecutionError, KeySet}; const SKS_KEY_WITH_NOISE_SQUASHING_SIZE: usize = 1_150 * 1_000_000; // ~1.1 GB -/// Retrieve the keyset from the database -pub(crate) async fn fetch_keyset( - cache: &Arc>>, - pool: &PgPool, - tenant_api_key: &String, -) -> Result, ExecutionError> { - let mut cache = cache.write().await; - if let Some(keys) = cache.get(tenant_api_key) { - info!(tenant_api_key, "Cache hit"); - return Ok(Some(keys.clone())); - } +async fn fetch_latest_key_id(pool: &PgPool) -> Result, ExecutionError> { + let record = sqlx::query( + "SELECT key_id, sequence_number FROM keys ORDER BY sequence_number DESC LIMIT 1", + ) + .fetch_optional(pool) + .await?; - info!(tenant_api_key, "Cache miss"); + if let Some(record) = record { + let key_id: DbKeyId = record.try_get("key_id")?; + let sequence_number: i64 = record.try_get("sequence_number")?; + Ok(Some((key_id, sequence_number))) + } else { + Ok(None) + } +} - let Some((client_key, server_key)) = fetch_keys(pool, tenant_api_key).await? else { +pub(crate) async fn fetch_latest_keyset( + cache: &Arc>>, + pool: &PgPool, +) -> Result, ExecutionError> { + let Some((key_id, _sequence_number)) = fetch_latest_key_id(pool).await? else { return Ok(None); }; - let key_set: KeySet = KeySet { - client_key, - server_key, - }; - - cache.push(tenant_api_key.clone(), key_set.clone()); - Ok(Some(key_set)) + let keyset = fetch_keyset_by_id(cache, pool, &key_id).await?; + Ok(keyset.map(|keys| (key_id, keys))) } -/// Retrieve both the ClientKey and ServerKey from the tenants table -/// -/// The ServerKey is stored in a large object (LOB) in the database. -/// ServerKey must be generated with enable_noise_squashing option. -/// -/// The ClientKey is stored in a bytea column and is optional. It's used only -/// for decrypting on testing. -pub async fn fetch_keys( +async fn fetch_keyset_by_id( + cache: &Arc>>, pool: &PgPool, - tenant_api_key: &String, -) -> anyhow::Result, crate::ServerKey)>> { + key_id: &DbKeyId, +) -> Result, ExecutionError> { + { + let mut cache = cache.write().await; + if let Some(keys) = cache.get(key_id) { + info!(key_id = hex::encode(key_id), "Cache hit"); + return Ok(Some(keys.clone())); + } + } + + info!(key_id = hex::encode(key_id), "Cache miss"); + let blob = read_keys_from_large_object( pool, - tenant_api_key, + key_id.clone(), "sns_pk", SKS_KEY_WITH_NOISE_SQUASHING_SIZE, ) @@ -75,24 +83,29 @@ pub async fn fetch_keys( }; // Optionally retrieve the ClientKey for testing purposes - let client_key = fetch_client_key(pool, tenant_api_key).await?; - Ok(Some((client_key, server_key))) + let client_key = fetch_client_key(pool, key_id).await?; + + let key_set = KeySet { + key_id: key_id.clone(), + client_key, + server_key, + }; + + let mut cache = cache.write().await; + cache.put(key_id.clone(), key_set.clone()); + Ok(Some(key_set)) } pub async fn fetch_client_key( pool: &PgPool, - tenant_api_key: &String, + key_id: &DbKeyId, ) -> anyhow::Result> { - if let Ok(keys) = sqlx::query( - " - SELECT cks_key FROM tenants - WHERE tenant_api_key = $1::uuid - ", - ) - .bind(tenant_api_key) - .fetch_one(pool) - .await - { + let keys = sqlx::query("SELECT cks_key FROM keys WHERE key_id = $1") + .bind(key_id) + .fetch_optional(pool) + .await?; + + if let Some(keys) = keys { if let Ok(cks) = keys.try_get::, _>(0) { if !cks.is_empty() { info!(bytes_len = cks.len(), "Retrieved cks"); diff --git a/coprocessor/fhevm-engine/sns-worker/src/lib.rs b/coprocessor/fhevm-engine/sns-worker/src/lib.rs index abbfb678bc..222a04faa0 100644 --- a/coprocessor/fhevm-engine/sns-worker/src/lib.rs +++ b/coprocessor/fhevm-engine/sns-worker/src/lib.rs @@ -19,6 +19,7 @@ use std::{ use aws_config::{retry::RetryConfig, timeout::TimeoutConfig, BehaviorVersion}; use aws_sdk_s3::{config::Builder, Client}; use fhevm_engine_common::{ + db_keys::DbKeyId, healthz_server::{self}, metrics_server, pg_pool::{PostgresPoolManager, ServiceError}, @@ -57,6 +58,7 @@ type ServerKey = tfhe::ServerKey; #[derive(Clone)] pub struct KeySet { + pub key_id: DbKeyId, /// Optional ClientKey for decrypting on testing pub client_key: Option, pub server_key: ServerKey, @@ -110,7 +112,6 @@ pub struct HealthCheckConfig { #[derive(Clone)] pub struct Config { - pub tenant_api_key: String, pub service_name: String, pub db: DBConfig, pub s3: S3Config, @@ -221,7 +222,8 @@ impl std::fmt::Display for Ciphertext128Format { #[derive(Clone)] pub struct HandleItem { - pub tenant_id: i32, + pub host_chain_id: i64, + pub key_id: DbKeyId, pub handle: Vec, /// Compressed 64-bit ciphertext @@ -248,9 +250,10 @@ impl HandleItem { db_txn: &mut Transaction<'_, Postgres>, ) -> Result<(), ExecutionError> { sqlx::query!( - "INSERT INTO ciphertext_digest (tenant_id, handle, transaction_id) - VALUES ($1, $2, $3) ON CONFLICT DO NOTHING", - self.tenant_id, + "INSERT INTO ciphertext_digest (host_chain_id, key_id, handle, transaction_id) + VALUES ($1, $2, $3, $4) ON CONFLICT DO NOTHING", + self.host_chain_id, + &self.key_id, self.handle, self.transaction_id, ) diff --git a/coprocessor/fhevm-engine/sns-worker/src/tests/mod.rs b/coprocessor/fhevm-engine/sns-worker/src/tests/mod.rs index 606d6c2aee..b6b3845516 100644 --- a/coprocessor/fhevm-engine/sns-worker/src/tests/mod.rs +++ b/coprocessor/fhevm-engine/sns-worker/src/tests/mod.rs @@ -6,6 +6,7 @@ use crate::{ }; use anyhow::{anyhow, Ok}; use aws_config::BehaviorVersion; +use fhevm_engine_common::db_keys::DbKeyId; use fhevm_engine_common::utils::{to_hex, DatabaseURL}; use serde::{Deserialize, Serialize}; use serial_test::serial; @@ -29,7 +30,6 @@ use tokio::{sync::mpsc, time::timeout}; use tracing::{info, Level}; const LISTEN_CHANNEL: &str = "sns_worker_chan"; -const TENANT_API_KEY: &str = "a1503fb6-d79b-4e9e-826d-44cf262f3e05"; static TRACING_INIT: OnceLock<()> = OnceLock::new(); @@ -133,23 +133,14 @@ async fn test_decryptable( if first_fhe_computation { // insert into ciphertexts insert_ciphertext64(pool, handle, ciphertext).await?; - insert_into_pbs_computations(pool, handle).await?; + insert_into_pbs_computations(pool, test_env.host_chain_id, handle).await?; } else { // insert into pbs_computations - insert_into_pbs_computations(pool, handle).await?; + insert_into_pbs_computations(pool, test_env.host_chain_id, handle).await?; insert_ciphertext64(pool, handle, ciphertext).await?; } - let tenant_id = get_tenant_id_from_db(pool, TENANT_API_KEY).await; - - assert_ciphertext128( - test_env, - tenant_id, - with_compression, - handle, - expected_result, - ) - .await?; + assert_ciphertext128(test_env, with_compression, handle, expected_result).await?; Ok(()) } @@ -174,7 +165,7 @@ async fn run_batch_computations( info!(batch_size, "Inserting ciphertexts ..."); let mut handles = Vec::new(); - let tenant_id = get_tenant_id_from_db(pool, TENANT_API_KEY).await; + let host_chain_id = test_env.host_chain_id; for i in 0..batch_size { let mut handle = base_handle.to_owned(); @@ -182,8 +173,8 @@ async fn run_batch_computations( // However the ciphertext64 will be the same handle[0] = (i >> 8) as u8; handle[1] = (i & 0xFF) as u8; - test_harness::db_utils::insert_ciphertext64(pool, tenant_id, &handle, ciphertext).await?; - test_harness::db_utils::insert_into_pbs_computations(pool, tenant_id, &handle).await?; + test_harness::db_utils::insert_ciphertext64(pool, &handle, ciphertext).await?; + test_harness::db_utils::insert_into_pbs_computations(pool, host_chain_id, &handle).await?; handles.push(handle); } @@ -204,14 +195,7 @@ async fn run_batch_computations( let test_env = test_env.clone(); let handle = handle.clone(); set.spawn(async move { - assert_ciphertext128( - &test_env, - tenant_id, - with_compression, - &handle, - expected_cleartext, - ) - .await + assert_ciphertext128(&test_env, with_compression, &handle, expected_cleartext).await }); } @@ -247,25 +231,31 @@ async fn test_lifo_mode() { const HANDLES_COUNT: usize = 30; const BATCH_SIZE: usize = 4; + let key_id: DbKeyId = vec![0u8; 32]; + let host_chain_id: i64 = 1; for i in 0..HANDLES_COUNT { // insert into ciphertexts test_harness::db_utils::insert_ciphertext64( &pool, - 1, &Vec::from([i as u8; 32]), &Vec::from([i as u8; 32]), ) .await .unwrap(); - test_harness::db_utils::insert_into_pbs_computations(&pool, 1, &Vec::from([i as u8; 32])) - .await - .unwrap(); + test_harness::db_utils::insert_into_pbs_computations( + &pool, + host_chain_id, + &Vec::from([i as u8; 32]), + ) + .await + .unwrap(); } let mut trx = pool.begin().await.unwrap(); - if let Result::Ok(Some(tasks)) = query_sns_tasks(&mut trx, BATCH_SIZE as u32, Order::Desc).await + if let Result::Ok(Some(tasks)) = + query_sns_tasks(&mut trx, BATCH_SIZE as u32, Order::Desc, &key_id).await { assert!( tasks.len() == BATCH_SIZE, @@ -288,7 +278,8 @@ async fn test_lifo_mode() { } let mut trx = pool.begin().await.unwrap(); - if let Result::Ok(Some(tasks)) = query_sns_tasks(&mut trx, BATCH_SIZE as u32, Order::Asc).await + if let Result::Ok(Some(tasks)) = + query_sns_tasks(&mut trx, BATCH_SIZE as u32, Order::Asc, &key_id).await { assert!( tasks.len() == BATCH_SIZE, @@ -331,17 +322,17 @@ async fn test_garbage_collect() { clean_up(&pool).await.unwrap(); - let tenant_id = 1; + let host_chain_id: i64 = 1; + let key_id: Vec = vec![0u8; 32]; for i in 0..HANDLES_COUNT { // insert into ciphertexts let mut handle = [0u8; 32]; handle[..4].copy_from_slice(&i.to_le_bytes()); let _ = sqlx::query!( - "INSERT INTO ciphertexts128(tenant_id, handle, ciphertext) - VALUES ($1, $2, $3) + "INSERT INTO ciphertexts128(handle, ciphertext) + VALUES ($1, $2) ON CONFLICT DO NOTHING;", - tenant_id, &handle, &[i as u8; 32], ) @@ -350,10 +341,11 @@ async fn test_garbage_collect() { .expect("insert into ciphertexts"); let _ = sqlx::query!( - "INSERT INTO ciphertext_digest(tenant_id, handle, ciphertext, ciphertext128 ) - VALUES ($1, $2, $3, $4) + "INSERT INTO ciphertext_digest(host_chain_id, key_id, handle, ciphertext, ciphertext128 ) + VALUES ($1, $2, $3, $4, $5) ON CONFLICT DO NOTHING;", - tenant_id, + host_chain_id, + &key_id, &handle, &[i as u8; 32], &[i as u8; 32], @@ -412,6 +404,8 @@ async fn test_garbage_collect() { struct TestEnvironment { pub pool: sqlx::PgPool, pub client_key: Option, + pub key_id: DbKeyId, + pub host_chain_id: i64, pub db_instance: DBInstance, pub s3_instance: Option>, // If None, the global LocalStack is used pub s3_client: aws_sdk_s3::Client, @@ -448,7 +442,9 @@ async fn setup(enable_compression: bool) -> anyhow::Result { let token = db_instance.parent_token.child_token(); let config: Config = conf.clone(); - let client_key: Option = fetch_client_key(&pool, &TENANT_API_KEY.to_owned()).await?; + let key_id = fetch_latest_key_id(&pool).await; + let host_chain_id = fetch_host_chain_id(&pool).await; + let client_key: Option = fetch_client_key(&pool, &key_id).await?; let (events_tx, mut events_rx) = mpsc::channel::<&'static str>(10); tokio::spawn(async move { @@ -468,6 +464,8 @@ async fn setup(enable_compression: bool) -> anyhow::Result { Ok(TestEnvironment { pool, client_key, + key_id, + host_chain_id, db_instance, s3_instance, s3_client, @@ -561,15 +559,18 @@ fn read_test_file(filename: &str) -> TestFile { serde_json::from_slice(&buffer).expect("Failed to deserialize") } -async fn get_tenant_id_from_db(pool: &sqlx::PgPool, tenant_api_key: &str) -> i32 { - let tenant_id: i32 = - sqlx::query_scalar("SELECT tenant_id FROM tenants WHERE tenant_api_key = $1::uuid") - .bind(tenant_api_key) - .fetch_one(pool) - .await - .expect("tenant_id"); +async fn fetch_latest_key_id(pool: &sqlx::PgPool) -> DbKeyId { + sqlx::query_scalar("SELECT key_id FROM keys ORDER BY sequence_number DESC LIMIT 1") + .fetch_one(pool) + .await + .expect("key_id") +} - tenant_id +async fn fetch_host_chain_id(pool: &sqlx::PgPool) -> i64 { + sqlx::query_scalar("SELECT chain_id FROM host_chains ORDER BY chain_id DESC LIMIT 1") + .fetch_one(pool) + .await + .expect("host_chain_id") } async fn insert_ciphertext64( @@ -577,8 +578,7 @@ async fn insert_ciphertext64( handle: &Vec, ciphertext: &Vec, ) -> anyhow::Result<()> { - let tenant_id = get_tenant_id_from_db(pool, TENANT_API_KEY).await; - test_harness::db_utils::insert_ciphertext64(pool, tenant_id, handle, ciphertext).await?; + test_harness::db_utils::insert_ciphertext64(pool, handle, ciphertext).await?; // Notify sns_worker sqlx::query("SELECT pg_notify($1, '')") @@ -591,10 +591,10 @@ async fn insert_ciphertext64( async fn insert_into_pbs_computations( pool: &sqlx::PgPool, + host_chain_id: i64, handle: &Vec, ) -> Result<(), anyhow::Error> { - let tenant_id = get_tenant_id_from_db(pool, TENANT_API_KEY).await; - test_harness::db_utils::insert_into_pbs_computations(pool, tenant_id, handle).await?; + test_harness::db_utils::insert_into_pbs_computations(pool, host_chain_id, handle).await?; // Notify sns_worker sqlx::query("SELECT pg_notify($1, '')") @@ -624,14 +624,13 @@ async fn clean_up(pool: &sqlx::PgPool) -> anyhow::Result<()> { /// It also checks that the ciphertext is uploaded to S3 if the feature is enabled. async fn assert_ciphertext128( test_env: &TestEnvironment, - tenant_id: i32, with_compression: bool, handle: &Vec, expected_value: i64, ) -> anyhow::Result<()> { let pool = &test_env.pool; let client_key = &test_env.client_key; - let ct = test_harness::db_utils::wait_for_ciphertext(pool, tenant_id, handle, 100).await?; + let ct = test_harness::db_utils::wait_for_ciphertext(pool, handle, 100).await?; info!("Ciphertext len: {:?}", ct.len()); @@ -753,7 +752,6 @@ fn build_test_config(url: DatabaseURL, enable_compression: bool) -> Config { .unwrap_or(SchedulePolicy::RayonParallel); Config { - tenant_api_key: TENANT_API_KEY.to_string(), db: DBConfig { url, listen_channels: vec![LISTEN_CHANNEL.to_string()], diff --git a/coprocessor/fhevm-engine/test-harness/src/db_utils.rs b/coprocessor/fhevm-engine/test-harness/src/db_utils.rs index 3b9ec2e43e..78b1d54af9 100644 --- a/coprocessor/fhevm-engine/test-harness/src/db_utils.rs +++ b/coprocessor/fhevm-engine/test-harness/src/db_utils.rs @@ -55,12 +55,14 @@ pub async fn insert_ciphertext64( pub async fn insert_into_pbs_computations( pool: &sqlx::PgPool, + host_chain_id: i64, handle: &Vec, ) -> Result<(), anyhow::Error> { let _ = query!( - "INSERT INTO pbs_computations(handle) VALUES($1) + "INSERT INTO pbs_computations(handle, host_chain_id) VALUES($1, $2) ON CONFLICT DO NOTHING;", handle, + host_chain_id, ) .execute(pool) .await @@ -99,14 +101,12 @@ pub async fn insert_ciphertext_digest( // Poll database until ciphertext128 of the specified handle is available pub async fn wait_for_ciphertext( pool: &sqlx::PgPool, - tenant_id: i32, handle: &Vec, retries: u64, ) -> anyhow::Result> { for retry in 0..retries { let record = sqlx::query!( - "SELECT ciphertext FROM ciphertexts128 WHERE tenant_id = $1 AND handle = $2", - tenant_id, + "SELECT ciphertext FROM ciphertexts128 WHERE handle = $1", handle ) .fetch_one(pool) diff --git a/coprocessor/fhevm-engine/test-harness/src/instance.rs b/coprocessor/fhevm-engine/test-harness/src/instance.rs index 1d9430e088..bdc0e47089 100644 --- a/coprocessor/fhevm-engine/test-harness/src/instance.rs +++ b/coprocessor/fhevm-engine/test-harness/src/instance.rs @@ -63,13 +63,13 @@ async fn setup_test_app_existing_localhost( if with_reset { info!("Resetting local database at {db_url}"); - let admin_db_url = db_url.to_string().replace("coprocessor", "postgres"); + let admin_db_url = db_url.as_str().replace("coprocessor", "postgres"); create_database(&admin_db_url, db_url.as_str(), mode).await?; } info!("Using existing local database at {db_url}"); - let _ = get_sns_pk_size(&sqlx::PgPool::connect(db_url.as_str()).await?, 12345).await; + let _ = get_sns_pk_size(&sqlx::PgPool::connect(db_url.as_str()).await?).await; Ok(DBInstance { _container: None, @@ -160,14 +160,18 @@ async fn create_database( Ok(()) } -pub async fn get_sns_pk_size(pool: &sqlx::PgPool, chain_id: i64) -> Result { - let row = sqlx::query("SELECT sns_pk FROM tenants WHERE chain_id = $1") - .bind(chain_id) - .fetch_one(pool) +pub async fn get_sns_pk_size(pool: &sqlx::PgPool) -> Result { + let row = sqlx::query("SELECT sns_pk FROM keys ORDER BY sequence_number DESC LIMIT 1") + .fetch_optional(pool) .await?; + let Some(row) = row else { + info!("No sns_pk found in keys"); + return Ok(0); + }; + let oid: Oid = row.try_get(0)?; - info!(oid = ?oid, chain_id, "Found sns_pk oid"); + info!(oid = ?oid, "Found sns_pk oid"); let row = sqlx::query_scalar( "SELECT COALESCE(SUM(octet_length(data))::bigint, 0) FROM pg_largeobject WHERE loid = $1", ) diff --git a/test-suite/fhevm/docker-compose/coprocessor-docker-compose.yml b/test-suite/fhevm/docker-compose/coprocessor-docker-compose.yml index d2a0102747..96264ad657 100644 --- a/test-suite/fhevm/docker-compose/coprocessor-docker-compose.yml +++ b/test-suite/fhevm/docker-compose/coprocessor-docker-compose.yml @@ -188,7 +188,6 @@ services: command: - sns_worker - --database-url=${DATABASE_URL} - - --tenant-api-key=${TENANT_API_KEY} - --pg-listen-channels - event_pbs_computations - event_ciphertext_computed