diff --git a/coprocessor/fhevm-engine/Cargo.lock b/coprocessor/fhevm-engine/Cargo.lock index c2bed048f2..60f10f1655 100644 --- a/coprocessor/fhevm-engine/Cargo.lock +++ b/coprocessor/fhevm-engine/Cargo.lock @@ -9049,6 +9049,7 @@ dependencies = [ "clap", "fhevm-engine-common", "hex", + "humantime", "lru 0.13.0", "rand 0.9.2", "serial_test", diff --git a/coprocessor/fhevm-engine/zkproof-worker/Cargo.toml b/coprocessor/fhevm-engine/zkproof-worker/Cargo.toml index 34c4360c06..03a7120689 100644 --- a/coprocessor/fhevm-engine/zkproof-worker/Cargo.toml +++ b/coprocessor/fhevm-engine/zkproof-worker/Cargo.toml @@ -22,6 +22,7 @@ bincode = { workspace = true } thiserror = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } +humantime = { workspace = true } # local dependencies fhevm-engine-common = { path = "../fhevm-engine-common" } diff --git a/coprocessor/fhevm-engine/zkproof-worker/src/bin/zkproof_worker.rs b/coprocessor/fhevm-engine/zkproof-worker/src/bin/zkproof_worker.rs index 6c26290cd2..0498ecf879 100644 --- a/coprocessor/fhevm-engine/zkproof-worker/src/bin/zkproof_worker.rs +++ b/coprocessor/fhevm-engine/zkproof-worker/src/bin/zkproof_worker.rs @@ -1,7 +1,8 @@ use clap::{command, Parser}; use fhevm_engine_common::healthz_server::HttpServer; use fhevm_engine_common::telemetry; -use std::sync::Arc; +use humantime::parse_duration; +use std::{sync::Arc, time::Duration}; use tokio::{join, task}; use tokio_util::sync::CancellationToken; use tracing::{error, info, Level}; @@ -26,6 +27,15 @@ pub struct Args { #[arg(long, default_value_t = 5)] pub pg_pool_connections: u32, + /// Postgres acquire timeout + /// A longer timeout could affect the healthz/liveness updates + #[arg(long, default_value = "15s", value_parser = parse_duration)] + pub pg_timeout: Duration, + + /// Postgres diagnostics: enable auto_explain extension + #[arg(long, value_parser = parse_duration)] + pub pg_auto_explain_with_min_duration: Option, + /// Postgres database url. If unspecified DATABASE_URL environment variable /// is used #[arg(long)] @@ -60,6 +70,8 @@ async fn main() { let args = parse_args(); tracing_subscriber::fmt() .json() + .with_current_span(true) + .with_span_list(false) .with_level(true) .with_max_level(args.log_level) .init(); @@ -76,6 +88,8 @@ async fn main() { pg_pool_connections: args.pg_pool_connections, pg_polling_interval: args.pg_polling_interval, worker_thread_count: args.worker_thread_count, + pg_timeout: args.pg_timeout, + pg_auto_explain_with_min_duration: args.pg_auto_explain_with_min_duration, }; if let Err(err) = telemetry::setup_otlp(&args.service_name) { @@ -84,7 +98,11 @@ async fn main() { } let cancel_token = CancellationToken::new(); - let service = ZkProofService::create(conf, cancel_token.child_token()).await; + let Some(service) = ZkProofService::create(conf, cancel_token.child_token()).await else { + error!("Failed to create zkproof service"); + std::process::exit(1); + }; + let service = Arc::new(service); let http_server = HttpServer::new( diff --git a/coprocessor/fhevm-engine/zkproof-worker/src/lib.rs b/coprocessor/fhevm-engine/zkproof-worker/src/lib.rs index 9b84aceb48..5e5927a0e0 100644 --- a/coprocessor/fhevm-engine/zkproof-worker/src/lib.rs +++ b/coprocessor/fhevm-engine/zkproof-worker/src/lib.rs @@ -4,9 +4,9 @@ pub mod auxiliary; mod tests; pub mod verifier; -use std::io; +use std::{io, time::Duration}; -use fhevm_engine_common::types::FhevmError; +use fhevm_engine_common::{pg_pool::ServiceError, types::FhevmError}; use thiserror::Error; /// The highest index of an input is 254, @@ -55,6 +55,17 @@ pub enum ExecutionError { TooManyInputs(usize), } +impl From for ServiceError { + fn from(err: ExecutionError) -> Self { + match err { + ExecutionError::DbError(e) => ServiceError::Database(e), + + // collapse everything else into InternalError + other => ServiceError::InternalError(other.to_string()), + } + } +} + #[derive(Default, Debug, Clone)] pub struct Config { pub database_url: String, @@ -62,6 +73,8 @@ pub struct Config { pub notify_database_channel: String, pub pg_pool_connections: u32, pub pg_polling_interval: u32, + pub pg_timeout: Duration, + pub pg_auto_explain_with_min_duration: Option, pub worker_thread_count: u32, } diff --git a/coprocessor/fhevm-engine/zkproof-worker/src/tests/mod.rs b/coprocessor/fhevm-engine/zkproof-worker/src/tests/mod.rs index 901b138ce9..ba4ca3c769 100644 --- a/coprocessor/fhevm-engine/zkproof-worker/src/tests/mod.rs +++ b/coprocessor/fhevm-engine/zkproof-worker/src/tests/mod.rs @@ -8,7 +8,8 @@ mod utils; #[tokio::test] #[serial(db)] async fn test_verify_proof() { - let (pool, _instance) = utils::setup().await.expect("valid setup"); + let (pool_mngr, _instance) = utils::setup().await.expect("valid setup"); + let pool = pool_mngr.pool(); // Generate Valid ZkPok let aux: (crate::auxiliary::ZkData, [u8; 92]) = @@ -42,7 +43,8 @@ async fn test_verify_proof() { #[tokio::test] #[serial(db)] async fn test_verify_empty_input_list() { - let (pool, _instance) = utils::setup().await.expect("valid setup"); + let (pool_mngr, _instance) = utils::setup().await.expect("valid setup"); + let pool = pool_mngr.pool(); let aux: (crate::auxiliary::ZkData, [u8; 92]) = utils::aux_fixture(ACL_CONTRACT_ADDR.to_owned()); @@ -61,7 +63,8 @@ async fn test_verify_empty_input_list() { #[tokio::test] #[serial(db)] async fn test_max_input_index() { - let (db, _instance) = utils::setup().await.expect("valid setup"); + let (pool_mngr, _instance) = utils::setup().await.expect("valid setup"); + let pool = pool_mngr.pool(); let aux: (crate::auxiliary::ZkData, [u8; 92]) = utils::aux_fixture(ACL_CONTRACT_ADDR.to_owned()); @@ -70,11 +73,11 @@ async fn test_max_input_index() { let inputs = vec![utils::ZkInput::U8(1); MAX_INPUT_INDEX as usize + 2]; assert!(!utils::is_valid( - &db, + &pool, utils::insert_proof( - &db, + &pool, 101, - &utils::generate_zk_pok_with_inputs(&db, &aux.1, &inputs).await, + &utils::generate_zk_pok_with_inputs(&pool, &aux.1, &inputs).await, &aux.0 ) .await @@ -87,11 +90,11 @@ async fn test_max_input_index() { // Test with highest number of inputs - 255 let inputs = vec![utils::ZkInput::U64(2); MAX_INPUT_INDEX as usize + 1]; assert!(utils::is_valid( - &db, + &pool, utils::insert_proof( - &db, + &pool, 102, - &utils::generate_zk_pok_with_inputs(&db, &aux.1, &inputs).await, + &utils::generate_zk_pok_with_inputs(&pool, &aux.1, &inputs).await, &aux.0 ) .await diff --git a/coprocessor/fhevm-engine/zkproof-worker/src/tests/utils.rs b/coprocessor/fhevm-engine/zkproof-worker/src/tests/utils.rs index 373c7fa274..57fc6e54e5 100644 --- a/coprocessor/fhevm-engine/zkproof-worker/src/tests/utils.rs +++ b/coprocessor/fhevm-engine/zkproof-worker/src/tests/utils.rs @@ -1,3 +1,4 @@ +use fhevm_engine_common::pg_pool::PostgresPoolManager; use fhevm_engine_common::{tenant_keys, utils::safe_serialize}; use std::sync::Arc; use std::time::{Duration, SystemTime}; @@ -7,7 +8,7 @@ use tokio::time::sleep; use crate::auxiliary::ZkData; -pub async fn setup() -> anyhow::Result<(sqlx::PgPool, DBInstance)> { +pub async fn setup() -> anyhow::Result<(PostgresPoolManager, DBInstance)> { let _ = tracing_subscriber::fmt().json().with_level(true).try_init(); let test_instance = test_harness::instance::setup_test_db(ImportMode::WithKeysNoSns) .await @@ -20,29 +21,39 @@ pub async fn setup() -> anyhow::Result<(sqlx::PgPool, DBInstance)> { pg_pool_connections: 10, pg_polling_interval: 60, worker_thread_count: 1, + pg_timeout: Duration::from_secs(15), + pg_auto_explain_with_min_duration: None, }; - let pool = sqlx::postgres::PgPoolOptions::new() - .max_connections(10) - .connect(&conf.database_url) - .await?; + let pool_mngr = PostgresPoolManager::connect_pool( + test_instance.parent_token.child_token(), + conf.database_url.as_str(), + conf.pg_timeout, + conf.pg_pool_connections, + Duration::from_secs(2), + conf.pg_auto_explain_with_min_duration, + ) + .await + .unwrap(); + + let pmngr = pool_mngr.clone(); sqlx::query("TRUNCATE TABLE verify_proofs") - .execute(&pool) + .execute(&pmngr.pool()) .await .unwrap(); let last_active_at = Arc::new(RwLock::new(SystemTime::now())); - let db_pool = pool.clone(); + tokio::spawn(async move { - crate::verifier::execute_verify_proofs_loop(db_pool, conf.clone(), last_active_at.clone()) + crate::verifier::execute_verify_proofs_loop(pmngr, conf.clone(), last_active_at.clone()) .await .unwrap(); }); sleep(Duration::from_secs(2)).await; - Ok((pool, test_instance)) + Ok((pool_mngr, test_instance)) } /// Checks if the proof is valid by querying the database continuously. diff --git a/coprocessor/fhevm-engine/zkproof-worker/src/verifier.rs b/coprocessor/fhevm-engine/zkproof-worker/src/verifier.rs index f64f13a0e2..8b264d13b6 100644 --- a/coprocessor/fhevm-engine/zkproof-worker/src/verifier.rs +++ b/coprocessor/fhevm-engine/zkproof-worker/src/verifier.rs @@ -1,4 +1,5 @@ use alloy_primitives::Address; +use fhevm_engine_common::pg_pool::{PostgresPoolManager, ServiceError}; use fhevm_engine_common::telemetry; use fhevm_engine_common::tenant_keys::TfheTenantKeys; use fhevm_engine_common::tenant_keys::{self, FetchTenantKeyResult}; @@ -10,7 +11,6 @@ use hex::encode; use lru::LruCache; use sha3::Digest; use sha3::Keccak256; -use sqlx::postgres::PgPoolOptions; use sqlx::{postgres::PgListener, PgPool, Row}; use sqlx::{Postgres, Transaction}; use std::num::NonZero; @@ -43,9 +43,8 @@ pub(crate) struct Ciphertext { } pub struct ZkProofService { - pool: PgPool, + pool_mngr: PostgresPoolManager, conf: Config, - _cancel_token: CancellationToken, // Timestamp of the last moment the service was active last_active_at: Arc>, @@ -53,7 +52,7 @@ pub struct ZkProofService { impl HealthCheckService for ZkProofService { async fn health_check(&self) -> HealthStatus { let mut status = HealthStatus::default(); - status.set_db_connected(&self.pool).await; + status.set_db_connected(&self.pool_mngr.pool()).await; status } @@ -79,31 +78,37 @@ impl HealthCheckService for ZkProofService { } impl ZkProofService { - pub async fn create(conf: Config, cancel_token: CancellationToken) -> ZkProofService { + pub async fn create(conf: Config, token: CancellationToken) -> Option { // Each worker needs at least 3 pg connections - let pool_connections = + let max_pool_connections = std::cmp::max(conf.pg_pool_connections, 3 * conf.worker_thread_count); let t = telemetry::tracer("init_service"); let _s = t.child_span("pg_connect"); - // DB Connection pool is shared amongst all workers - let pool = PgPoolOptions::new() - .max_connections(pool_connections) - .connect(&conf.database_url) - .await - .expect("valid db pool"); + let Some(pool_mngr) = PostgresPoolManager::connect_pool( + token.child_token(), + conf.database_url.as_str(), + conf.pg_timeout, + max_pool_connections, + Duration::from_secs(2), + conf.pg_auto_explain_with_min_duration, + ) + .await + else { + error!("Service was cancelled during Postgres pool initialization"); + return None; + }; - ZkProofService { - pool, + Some(ZkProofService { + pool_mngr, conf, - _cancel_token: cancel_token, last_active_at: Arc::new(RwLock::new(SystemTime::UNIX_EPOCH)), - } + }) } pub async fn run(&self) -> Result<(), ExecutionError> { execute_verify_proofs_loop( - self.pool.clone(), + self.pool_mngr.clone(), self.conf.clone(), self.last_active_at.clone(), ) @@ -113,7 +118,7 @@ impl ZkProofService { /// Executes the main loop for handling verify_proofs requests inserted in the /// database pub async fn execute_verify_proofs_loop( - pool: PgPool, + pool_mngr: PostgresPoolManager, conf: Config, last_active_at: Arc>, ) -> Result<(), ExecutionError> { @@ -129,20 +134,27 @@ pub async fn execute_verify_proofs_loop( telemetry::attribute(&mut s, "count", conf.worker_thread_count.to_string()); let mut task_set = JoinSet::new(); - for _ in 0..conf.worker_thread_count { + for index in 0..conf.worker_thread_count { let conf = conf.clone(); let tenant_key_cache = tenant_key_cache.clone(); - let pool = pool.clone(); let last_active_at = last_active_at.clone(); // Spawn a ZK-proof worker // All workers compete for zk-proof tasks queued in the 'verify_proof' table. - task_set.spawn(async move { - if let Err(err) = execute_worker(&conf, &pool, &tenant_key_cache, last_active_at).await - { - error!(error = %err, "executor failed with"); + let op = move |pool: PgPool, ct: CancellationToken| { + let tenant_key_cache = tenant_key_cache.clone(); + let last_active_at = last_active_at.clone(); + let conf = conf.clone(); + async move { + execute_worker(conf, pool, ct, tenant_key_cache, last_active_at) + .await + .map_err(ServiceError::from) } - }); + }; + + pool_mngr + .spawn_join_set_with_db_retry(op, &mut task_set, format!("worker_{}", index).as_str()) + .await; } telemetry::end_span(s); @@ -158,47 +170,46 @@ pub async fn execute_verify_proofs_loop( } async fn execute_worker( - conf: &Config, - pool: &sqlx::Pool, - tenant_key_cache: &Arc>>, + conf: Config, + pool: sqlx::Pool, + token: CancellationToken, + tenant_key_cache: Arc>>, last_active_at: Arc>, ) -> Result<(), ExecutionError> { - let mut listener = PgListener::connect_with(pool).await?; + update_last_active(last_active_at.clone()).await; + + let mut listener = PgListener::connect_with(&pool).await?; listener.listen(&conf.listen_database_channel).await?; let mut idle_event = interval(Duration::from_secs(conf.pg_polling_interval as u64)); loop { - if let Ok(mut value) = last_active_at.try_write() { - *value = SystemTime::now(); - } + update_last_active(last_active_at.clone()).await; - if let Err(e) = execute_verify_proof_routine(pool, tenant_key_cache, conf).await { - error!(target: "zkpok", error = %e, "Execution error"); - } else { - let count = get_remaining_tasks(pool).await?; - if get_remaining_tasks(pool).await? > 0 { - info!(target: "zkpok", {count}, "ZkPok tasks available"); - continue; - } + execute_verify_proof_routine(&pool, &tenant_key_cache, &conf).await?; + let count = get_remaining_tasks(&pool).await?; + if count > 0 { + info!({ count }, "zkproof requests available"); + continue; } select! { res = listener.try_recv() => { + let res = res?; match res { - Ok(None) => { - error!(target: "zkpok", "DB connection err"); - return Err(ExecutionError::LostDbConnection) - }, - Ok(_) => info!(target: "zkpok", "Received notification"), - Err(err) => { - error!(target: "zkpok", error = %err, "DB connection error"); - return Err(ExecutionError::LostDbConnection) + Some(notification) => info!( src = %notification.process_id(), "Received notification"), + None => { + error!("Connection lost"); + continue; }, }; }, _ = idle_event.tick() => { - debug!(target: "zkpok", "Polling timeout, rechecking for tasks"); + debug!("Polling timeout, rechecking for requests"); + }, + _ = token.cancelled() => { + info!("Cancellation requested, stopping worker"); + return Ok(()); } } } @@ -516,3 +527,8 @@ pub(crate) async fn insert_ciphertexts( .await?; Ok(()) } + +async fn update_last_active(last_active_at: Arc>) { + let mut value = last_active_at.write().await; + *value = SystemTime::now(); +}