diff --git a/docs/metrics/metrics.md b/docs/metrics/metrics.md index 5a610edece..c248b2236d 100644 --- a/docs/metrics/metrics.md +++ b/docs/metrics/metrics.md @@ -228,11 +228,11 @@ Metrics for zkproof-worker are to be added in future releases, if/when needed. C - **Alarm**: If the counter is a flat line over a period of time, only for `event_type` `public_decryption_request` and `user_decryption_request`. - **Recommendation**: 0 for more than 1 minute, i.e. `increase(counter{event_type="..."}[1m]) == 0`. -#### Metric Name: `kms_connector_gw_listener_event_received_errors` +#### Metric Name: `kms_connector_gw_listener_event_listening_errors` - **Type**: Counter - **Labels**: - - `event_type`: see [description](#metric-name-kms_connector_gw_listener_event_received_counter) - - **Description**: Counts the number of errors encountered by the GW listener while receiving events. + - `contract`: can be used to filter by contract (decryption, kmsgeneration). + - **Description**: Counts the number of errors encountered by the GW listener while listening for events. - **Alarm**: If the counter increases over a period of time. - **Recommendation**: more than 60 failures in 1 minute, i.e. `sum(increase(counter[1m])) > 60`. diff --git a/kms-connector/.sqlx/query-08628a859406ce5a66dda3b9729847a3846a59691b8dde5f518c2cb40b687742.json b/kms-connector/.sqlx/query-08628a859406ce5a66dda3b9729847a3846a59691b8dde5f518c2cb40b687742.json new file mode 100644 index 0000000000..051611e8b8 --- /dev/null +++ b/kms-connector/.sqlx/query-08628a859406ce5a66dda3b9729847a3846a59691b8dde5f518c2cb40b687742.json @@ -0,0 +1,38 @@ +{ + "db_name": "PostgreSQL", + "query": "UPDATE last_block_polled SET block_number = $2, updated_at = $3 WHERE event_type = ANY($1::event_type[]) AND (block_number IS NULL OR block_number < $2)", + "describe": { + "columns": [], + "parameters": { + "Left": [ + { + "Custom": { + "name": "event_type[]", + "kind": { + "Array": { + "Custom": { + "name": "event_type", + "kind": { + "Enum": [ + "PublicDecryptionRequest", + "UserDecryptionRequest", + "PrepKeygenRequest", + "KeygenRequest", + "CrsgenRequest", + "PrssInit", + "KeyReshareSameSet" + ] + } + } + } + } + } + }, + "Int8", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "08628a859406ce5a66dda3b9729847a3846a59691b8dde5f518c2cb40b687742" +} diff --git a/kms-connector/.sqlx/query-7e6031fa69c6a6884bb9b3a7b2e19db0fef204f71ad8dde9bd8bc808b0cb25fd.json b/kms-connector/.sqlx/query-7e6031fa69c6a6884bb9b3a7b2e19db0fef204f71ad8dde9bd8bc808b0cb25fd.json deleted file mode 100644 index 8e8379f459..0000000000 --- a/kms-connector/.sqlx/query-7e6031fa69c6a6884bb9b3a7b2e19db0fef204f71ad8dde9bd8bc808b0cb25fd.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "UPDATE last_block_polled SET block_number = $2, updated_at = $3 WHERE event_type = $1 AND (block_number IS NULL OR block_number < $2)", - "describe": { - "columns": [], - "parameters": { - "Left": [ - { - "Custom": { - "name": "event_type", - "kind": { - "Enum": [ - "PublicDecryptionRequest", - "UserDecryptionRequest", - "PrepKeygenRequest", - "KeygenRequest", - "CrsgenRequest", - "PrssInit", - "KeyReshareSameSet" - ] - } - } - }, - "Int8", - "Timestamptz" - ] - }, - "nullable": [] - }, - "hash": "7e6031fa69c6a6884bb9b3a7b2e19db0fef204f71ad8dde9bd8bc808b0cb25fd" -} diff --git a/kms-connector/config/gw-listener.toml b/kms-connector/config/gw-listener.toml index 53e85a7650..772370fb48 100644 --- a/kms-connector/config/gw-listener.toml +++ b/kms-connector/config/gw-listener.toml @@ -42,14 +42,22 @@ database_url = "postgres://postgres:postgres@localhost/kms-connector" # ENV: KMS_CONNECTOR_TASK_LIMIT # task_limit = 1000 -# The polling interval for decryption requests (optional, defaults to 1s) +# The polling interval for decryption requests (optional, defaults to 500ms) # ENV: KMS_CONNECTOR_DECRYPTION_POLLING (format: https://docs.rs/humantime/latest/humantime/) -# decryption_polling = "1s" +# decryption_polling = "500ms" # The polling interval for key management requests (optional, defaults to 30s) # ENV: KMS_CONNECTOR_KMS_GENERATION_POLLING (format: https://docs.rs/humantime/latest/humantime/) # key_management_polling = "30s" +# Maximum number of blocks per eth_getLogs RPC call (optional, defaults to 100) +# ENV: KMS_CONNECTOR_GET_LOGS_BATCH_SIZE +# get_logs_batch_size = 100 + +# Maximum consecutive polling errors before the listener stops (optional, defaults to 20) +# ENV: KMS_CONNECTOR_MAX_CONSECUTIVE_POLLING_ERRORS +# max_consecutive_polling_errors = 20 + # Block number to start processing decryption events from (optional, defaults to latest block if not set) # ENV: KMS_CONNECTOR_DECRYPTION_FROM_BLOCK_NUMBER # decryption_from_block_number = 1234 diff --git a/kms-connector/crates/gw-listener/src/core/config.rs b/kms-connector/crates/gw-listener/src/core/config.rs index 5f0499259e..63f4fd0224 100644 --- a/kms-connector/crates/gw-listener/src/core/config.rs +++ b/kms-connector/crates/gw-listener/src/core/config.rs @@ -62,6 +62,14 @@ pub struct Config { #[serde(with = "humantime_serde", default = "default_key_management_polling")] pub key_management_polling: Duration, + /// The maximum number of blocks to fetch per `eth_getLogs` request. + #[serde(default = "default_get_logs_batch_size")] + pub get_logs_batch_size: u64, + + /// Maximum number of consecutive polling errors before stopping the loop. + #[serde(default = "default_max_consecutive_polling_errors")] + pub max_consecutive_polling_errors: usize, + /// Optional block number to start processing decryption events from. pub decryption_from_block_number: Option, /// Optional block number to start processing KMS operation events from. @@ -75,13 +83,21 @@ fn default_service_name() -> String { } fn default_decryption_polling() -> Duration { - Duration::from_secs(1) + Duration::from_millis(500) } fn default_key_management_polling() -> Duration { Duration::from_secs(30) } +fn default_get_logs_batch_size() -> u64 { + 100 +} + +fn default_max_consecutive_polling_errors() -> usize { + 20 +} + // Default implementation for testing purpose impl Default for Config { fn default() -> Self { @@ -101,6 +117,8 @@ impl Default for Config { healthcheck_timeout: default_healthcheck_timeout(), decryption_polling: default_decryption_polling(), key_management_polling: default_key_management_polling(), + get_logs_batch_size: default_get_logs_batch_size(), + max_consecutive_polling_errors: default_max_consecutive_polling_errors(), decryption_from_block_number: None, kms_operation_from_block_number: None, } @@ -125,6 +143,8 @@ mod tests { env::remove_var("KMS_CONNECTOR_DECRYPTION_CONTRACT__ADDRESS"); env::remove_var("KMS_CONNECTOR_KMS_GENERATION_CONTRACT__ADDRESS"); env::remove_var("KMS_CONNECTOR_SERVICE_NAME"); + env::remove_var("KMS_CONNECTOR_GET_LOGS_BATCH_SIZE"); + env::remove_var("KMS_CONNECTOR_MAX_CONSECUTIVE_POLLING_ERRORS"); } } @@ -207,15 +227,27 @@ mod tests { // Set an environment variable to override the file let gateway_chain_id = 77737; let service_name = "kms-connector-override"; + let get_logs_batch_size: u64 = 500; + let max_consecutive_polling_errors = 5; let mut expected_config = example_config.clone(); expected_config.gateway_chain_id = gateway_chain_id; expected_config.service_name = service_name.to_string(); + expected_config.get_logs_batch_size = get_logs_batch_size; + expected_config.max_consecutive_polling_errors = max_consecutive_polling_errors; unsafe { env::set_var( "KMS_CONNECTOR_GATEWAY_CHAIN_ID", gateway_chain_id.to_string(), ); env::set_var("KMS_CONNECTOR_SERVICE_NAME", service_name); + env::set_var( + "KMS_CONNECTOR_GET_LOGS_BATCH_SIZE", + get_logs_batch_size.to_string(), + ); + env::set_var( + "KMS_CONNECTOR_MAX_CONSECUTIVE_POLLING_ERRORS", + max_consecutive_polling_errors.to_string(), + ); } // Load config from both sources diff --git a/kms-connector/crates/gw-listener/src/core/gateway.rs b/kms-connector/crates/gw-listener/src/core/gateway.rs index debb64aea1..e46401e878 100644 --- a/kms-connector/crates/gw-listener/src/core/gateway.rs +++ b/kms-connector/crates/gw-listener/src/core/gateway.rs @@ -1,33 +1,59 @@ use crate::{ - core::{Config, publish::update_last_block_polled, publish_event}, - monitoring::metrics::{EVENT_RECEIVED_COUNTER, EVENT_RECEIVED_ERRORS}, + core::{Config, publish::publish_batch}, + monitoring::metrics::{EVENT_LISTENING_ERRORS, EVENT_RECEIVED_COUNTER}, }; use alloy::{ - contract::{Event, EventPoller}, network::Ethereum, - primitives::LogData, providers::Provider, rpc::types::{Filter, Log}, - sol_types::SolEvent, + sol_types::SolEventInterface, }; use anyhow::anyhow; use connector_utils::{ monitoring::otlp::PropagationContext, - tasks::spawn_with_limit, types::{GatewayEvent, GatewayEventKind, db::EventType}, }; use fhevm_gateway_bindings::{ - decryption::Decryption::{self, DecryptionInstance}, - kms_generation::KMSGeneration::{self, KMSGenerationInstance}, + decryption::Decryption::DecryptionEvents, kms_generation::KMSGeneration::KMSGenerationEvents, }; use sqlx::{Pool, Postgres, Row}; -use std::time::Duration; -use tokio::{select, task::JoinSet, time::timeout}; -use tokio_stream::StreamExt; +use tokio::{select, task::JoinSet}; use tokio_util::sync::CancellationToken; -use tracing::{error, info, warn}; +use tracing::{error, info, info_span, warn}; use tracing_opentelemetry::OpenTelemetrySpanExt; +const DECRYPTION_EVENT_TYPES: [EventType; 2] = [ + EventType::PublicDecryptionRequest, + EventType::UserDecryptionRequest, +]; + +const KMS_GENERATION_EVENT_TYPES: [EventType; 5] = [ + EventType::PrepKeygenRequest, + EventType::KeygenRequest, + EventType::CrsgenRequest, + EventType::PrssInit, + EventType::KeyReshareSameSet, +]; + +/// Identifies which contract is being polled. +/// +/// **Note:** The kms-connector is designed to listen to a specific set of events/contracts, +/// so listening to a new contract/event to monitor requires a code change and a new release. +#[derive(Clone, Copy)] +enum MonitoredContract { + Decryption, + KmsGeneration, +} + +impl std::fmt::Display for MonitoredContract { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MonitoredContract::Decryption => write!(f, "Decryption"), + MonitoredContract::KmsGeneration => write!(f, "KmsGeneration"), + } + } +} + /// Struct monitoring and storing Gateway's events. #[derive(Clone)] pub struct GatewayListener

@@ -37,11 +63,8 @@ where /// The database pool for storing Gateway's events. db_pool: Pool, - /// The Gateway's `Decryption` contract instance which is monitored. - decryption_contract: DecryptionInstance

, - - /// The Gateway's `KMSGeneration` contract instance which is monitored. - kms_generation_contract: KMSGenerationInstance

, + /// The Gateway RPC Provider. + provider: P, /// The configuration of the `GatewayListener`. config: Config, @@ -61,15 +84,9 @@ where config: &Config, cancel_token: CancellationToken, ) -> Self { - let decryption_contract = - Decryption::new(config.decryption_contract.address, provider.clone()); - let kms_generation_contract = - KMSGeneration::new(config.kms_generation_contract.address, provider); - Self { db_pool, - decryption_contract, - kms_generation_contract, + provider, config: config.clone(), cancel_token, } @@ -77,17 +94,11 @@ where /// Starts the `GatewayListener`. /// - /// Spawns and joins the `GatewayListener` event monitoring tasks. + /// Spawns two polling tasks: one for Decryption events and one for KMSGeneration events. pub async fn start(self) { let mut tasks = JoinSet::new(); - - tasks.spawn(self.clone().subscribe(EventType::PublicDecryptionRequest)); - tasks.spawn(self.clone().subscribe(EventType::UserDecryptionRequest)); - tasks.spawn(self.clone().subscribe(EventType::PrepKeygenRequest)); - tasks.spawn(self.clone().subscribe(EventType::KeygenRequest)); - tasks.spawn(self.clone().subscribe(EventType::CrsgenRequest)); - tasks.spawn(self.clone().subscribe(EventType::PrssInit)); - tasks.spawn(self.subscribe(EventType::KeyReshareSameSet)); + tasks.spawn(self.clone().poll_events(MonitoredContract::Decryption)); + tasks.spawn(self.poll_events(MonitoredContract::KmsGeneration)); while let Some(res) = tasks.join_next().await { if let Err(e) = res { @@ -97,233 +108,183 @@ where info!("GatewayListener stopped successfully!"); } - /// Subscribes to a particular set of events. + /// Polls a contract for events using `eth_getLogs`. /// - /// Each event received from the `event_filer` is then published in the DB. - pub async fn subscribe(self, event_type: EventType) { - let polling = match &event_type { - EventType::PublicDecryptionRequest | EventType::UserDecryptionRequest => { - self.config.decryption_polling + /// Cancels all other tasks on failure. + async fn poll_events(self, contract: MonitoredContract) { + select! { + biased; + _ = self.cancel_token.cancelled() => info!("{contract} polling cancelled..."), + result = self.run_poll_loop(contract) => if let Err(e) = result { + error!("{contract} polling failed: {e}"); } - _ => self.config.key_management_polling, - }; + } + self.cancel_token.cancel(); + } - let result = match &event_type { - EventType::PublicDecryptionRequest => { - let filter = self.decryption_contract.PublicDecryptionRequest_filter(); - self.subscribe_inner(event_type, filter, polling).await - } - EventType::UserDecryptionRequest => { - let filter = self.decryption_contract.UserDecryptionRequest_filter(); - self.subscribe_inner(event_type, filter, polling).await - } - EventType::PrepKeygenRequest => { - let filter = self.kms_generation_contract.PrepKeygenRequest_filter(); - self.subscribe_inner(event_type, filter, polling).await - } - EventType::KeygenRequest => { - let filter = self.kms_generation_contract.KeygenRequest_filter(); - self.subscribe_inner(event_type, filter, polling).await - } - EventType::CrsgenRequest => { - let filter = self.kms_generation_contract.CrsgenRequest_filter(); - self.subscribe_inner(event_type, filter, polling).await - } - EventType::PrssInit => { - let filter = self.kms_generation_contract.PRSSInit_filter(); - self.subscribe_inner(event_type, filter, polling).await - } - EventType::KeyReshareSameSet => { - let filter = self.kms_generation_contract.KeyReshareSameSet_filter(); - self.subscribe_inner(event_type, filter, polling).await - } + /// Polling loop to listen to both [`Decryption`] and [`KMSGeneration`] contracts. + async fn run_poll_loop(&self, contract: MonitoredContract) -> anyhow::Result<()> { + let (contract_address, poll_interval, from_block_config, event_types) = match contract { + MonitoredContract::Decryption => ( + self.config.decryption_contract.address, + self.config.decryption_polling, + self.config.decryption_from_block_number, + DECRYPTION_EVENT_TYPES.as_slice(), + ), + MonitoredContract::KmsGeneration => ( + self.config.kms_generation_contract.address, + self.config.key_management_polling, + self.config.kms_operation_from_block_number, + KMS_GENERATION_EVENT_TYPES.as_slice(), + ), }; - self.cancel_token.cancel(); // Cancel other event subscription tasks - if let Err(e) = result { - error!("{e}"); - } - } + let event_signatures = event_types + .iter() + .map(|e| e.signature_hash()) + .collect::>(); + let base_filter = Filter::new() + .address(contract_address) + .event_signature(event_signatures); - async fn subscribe_inner( - &self, - event_type: EventType, - event_filter: Event<&'_ P, E>, - poll_interval: Duration, - ) -> anyhow::Result<()> - where - E: Into + SolEvent + Send + Sync + 'static, - { - let mut last_block_polled = self.get_last_block_polled(event_type).await?; - let mut event_poller = event_filter - .watch() - .await - .map_err(|e| anyhow!("Failed to subscribe to {event_type} events: {e}"))?; - event_poller.poller = event_poller.poller.with_poll_interval(poll_interval); - info!("✓ Subscribed to {event_type} events"); - - let _ = self - .catchup_past_events::(&mut last_block_polled, event_type) - .await - .inspect_err(|e| warn!("Failed to catch up past {event_type} events: {e}")); + let mut from_block = self.get_start_block(from_block_config, event_types).await?; + info!("Started {contract} polling from block {from_block}"); - select! { - _ = self.process_events(event_type, event_poller, &mut last_block_polled) => (), - _ = self.cancel_token.cancelled() => info!("{event_type} subscription cancelled..."), + let mut ticker = tokio::time::interval(poll_interval); + let max_errors = self.config.max_consecutive_polling_errors; + let mut consecutive_errors: usize = 0; + loop { + ticker.tick().await; + match self + .fetch_and_publish(contract, base_filter.clone(), event_types, from_block) + .await + { + Ok((new_from_block, has_more)) => { + consecutive_errors = 0; + from_block = new_from_block; + if has_more { + ticker.reset_immediately(); + } + } + Err(e) => { + EVENT_LISTENING_ERRORS + .with_label_values(&[contract.to_string().to_lowercase()]) + .inc(); + consecutive_errors = consecutive_errors.saturating_add(1); + warn!("{contract} listening error: {e} ({consecutive_errors}/{max_errors})"); + if consecutive_errors >= max_errors { + anyhow::bail!("Too many consecutive errors for {contract}"); + } + } + } } - - // Use a timeout to ensure we are not preventing the `GatewayListener` from being shutdown - // if the `last_block_polled` update get stuck for some reason. - timeout( - LAST_BLOCK_POLLED_UPDATE_TIMEOUT, - update_last_block_polled(&self.db_pool, event_type, last_block_polled), - ) - .await??; - Ok(()) } - /// Catches events created before the event filter using `eth_getFilterLogs`. - async fn catchup_past_events( + /// Fetches logs for a block range, decodes them, and publishes them in a single transaction. + /// + /// Returns `(new_from_block, has_more_blocks)`. + async fn fetch_and_publish( &self, - last_block_polled: &mut Option, - event_type: EventType, - ) -> anyhow::Result<()> - where - E: Into + SolEvent + Send + Sync + 'static, - { - let catchup_from_block = match last_block_polled { - None => { - info!( - "No previously polled block for {event_type}; skipping catchup of past events." - ); - return Ok(()); - } - Some(block) => *block, - }; + contract: MonitoredContract, + base_filter: Filter, + event_types: &[EventType], + from_block: u64, + ) -> anyhow::Result<(u64, bool)> { + let current_block = self.provider.get_block_number().await?; + + if from_block > current_block { + return Ok((from_block, false)); + } - let contract_address = match event_type { - EventType::PublicDecryptionRequest | EventType::UserDecryptionRequest => { - self.decryption_contract.address() - } - _ => self.kms_generation_contract.address(), - }; + let to_block = std::cmp::min( + from_block.saturating_add(self.config.get_logs_batch_size.saturating_sub(1)), + current_block, + ); - let filter = Filter::new() - .address(*contract_address) - .event_signature(E::SIGNATURE_HASH) - .from_block(catchup_from_block); - let provider = self.decryption_contract.provider(); - - info!("Catching up {event_type} from {catchup_from_block}..."); - let mut event_count = 0; - let event_filter_id = provider.new_filter(&filter).await?; - let past_events = provider - .get_filter_logs(event_filter_id) - .await? - .into_iter() - .map(|log| { - decode_log::(&log).map(|event| { - event_count += 1; - (event, log) - }) - }); + let filter = base_filter.from_block(from_block).to_block(to_block); - for event in past_events { - self.spawn_event_handling(event_type, event, last_block_polled) - .await; - } + let logs = self.provider.get_logs(&filter).await?; + let events = Self::prepare_events(contract, logs)?; + publish_batch(&self.db_pool, events, event_types, to_block).await?; - info!( - "Successfully caught {event_count} {event_type} events from block {catchup_from_block}!" - ); - if let Err(e) = provider.uninstall_filter(event_filter_id).await { - warn!("Failed to uninstall {event_type} event catchup filter: {e}"); - } - Ok(()) + Ok((to_block.saturating_add(1), to_block < current_block)) } - /// Event processing loop. - async fn process_events( - &self, - event_type: EventType, - event_poller: EventPoller, - last_block_polled: &mut Option, - ) where - E: Into + SolEvent + Send + Sync + 'static, - { - let mut events = event_poller.into_stream(); - loop { - info!("Waiting for next {event_type}..."); - match events.next().await { - Some(event) => { - self.spawn_event_handling(event_type, event, last_block_polled) - .await + /// Decodes a log into a `GatewayEventKind`. + fn decode_log(contract: MonitoredContract, log: &Log) -> anyhow::Result { + match contract { + MonitoredContract::Decryption => { + let event = DecryptionEvents::decode_log(&log.inner) + .map_err(|e| anyhow!("Failed to decode Decryption event: {e}"))?; + match event.data { + DecryptionEvents::PublicDecryptionRequest(e) => Ok(e.into()), + DecryptionEvents::UserDecryptionRequest(e) => Ok(e.into()), + _ => Err(anyhow!("Unexpected Decryption event: {log:?}")), + } + } + MonitoredContract::KmsGeneration => { + let event = KMSGenerationEvents::decode_log(&log.inner) + .map_err(|e| anyhow!("Failed to decode KMSGeneration event: {e}"))?; + match event.data { + KMSGenerationEvents::PrepKeygenRequest(e) => Ok(e.into()), + KMSGenerationEvents::KeygenRequest(e) => Ok(e.into()), + KMSGenerationEvents::CrsgenRequest(e) => Ok(e.into()), + KMSGenerationEvents::PRSSInit(e) => Ok(e.into()), + KMSGenerationEvents::KeyReshareSameSet(e) => Ok(e.into()), + _ => Err(anyhow!("Unexpected KMSGeneration event: {log:?}")), } - None => break error!("Alloy Provider was dropped for {event_type}"), } } } - async fn spawn_event_handling( - &self, - event_type: EventType, - event: alloy::sol_types::Result<(E, Log)>, - last_block: &mut Option, - ) where - E: Into + SolEvent + Send + Sync + 'static, - { - match event { - Ok((event, log)) => { - *last_block = log.block_number; - EVENT_RECEIVED_COUNTER - .with_label_values(&[event_type.as_str()]) - .inc(); - - let db = self.db_pool.clone(); - spawn_with_limit(handle_gateway_event(db, event.into(), log)).await; - } - Err(err) => { - error!("Error while listening for {event_type} events: {err}"); - EVENT_RECEIVED_ERRORS - .with_label_values(&[event_type.as_str()]) - .inc(); - } + /// Decodes logs and prepares `GatewayEvent` structs with OTLP context and metrics. + fn prepare_events( + contract: MonitoredContract, + logs: Vec, + ) -> anyhow::Result> { + let mut events = Vec::with_capacity(logs.len()); + for log in logs { + let event_kind = Self::decode_log(contract, &log)?; + EVENT_RECEIVED_COUNTER + .with_label_values(&[EventType::from(&event_kind).as_str()]) + .inc(); + + let span = info_span!("handle_gateway_event", event = %event_kind); + let otlp_ctx = PropagationContext::inject(&span.context()); + events.push(GatewayEvent::new( + event_kind, + log.transaction_hash, + otlp_ctx, + )); } + Ok(events) } - /// Get the last block polled from config or DB. - async fn get_last_block_polled(&self, event_type: EventType) -> anyhow::Result> { - let from_block_number = match event_type { - EventType::PublicDecryptionRequest | EventType::UserDecryptionRequest => { - self.config.decryption_from_block_number - } - _ => self.config.kms_operation_from_block_number, - }; + /// Determines the block to start event listening from. + async fn get_start_block( + &self, + from_block_config: Option, + event_types: &[EventType], + ) -> anyhow::Result { + if let Some(from_block) = from_block_config { + info!("Found configured from_block_number ({from_block}) for polling"); + return Ok(from_block); + } - let last_block_polled = match from_block_number { - // Start polling event from the configured `from_block_number` if set - Some(from_block) => { - info!( - "Found configured `from_block_number` ({from_block}) for {event_type} subscriptions!" - ); - Some(from_block) + let mut min_last_processed_block: Option = None; + for &event_type in event_types { + if let Some(last) = self.get_last_block_polled_from_db(event_type).await? { + min_last_processed_block = match min_last_processed_block { + Some(current) => Some(std::cmp::min(current, last)), + None => Some(last), + }; } - // Start from `last_block_polled` stored in DB + 1 if not configured - None => self - .get_last_block_polled_from_db(event_type) - .await? - .map(|n| n + 1), - }; - - info!( - "Starting {} subscriptions from block {}...", - event_type, - last_block_polled - .map(|b| b.to_string()) - .unwrap_or_else(|| "latest".into()) - ); + } - Ok(last_block_polled) + match min_last_processed_block { + Some(last_block_polled) => Ok(last_block_polled.saturating_add(1)), + None => Ok(self.provider.get_block_number().await?), + } } async fn get_last_block_polled_from_db( @@ -346,88 +307,54 @@ where } } -/// Main function used to trace a single event handling across all Connector's services. -#[tracing::instrument(skip_all, fields(event = %event_kind))] -async fn handle_gateway_event(db_pool: Pool, event_kind: GatewayEventKind, log: Log) { - let event = GatewayEvent::new( - event_kind, - log.transaction_hash, - PropagationContext::inject(&tracing::Span::current().context()), - ); - if let Err(err) = publish_event(&db_pool, event, log.block_number).await { - error!("Failed to publish event: {err}"); - } -} - -fn decode_log(log: &Log) -> alloy::sol_types::Result { - let log_data: &LogData = log.as_ref(); - E::decode_raw_log(log_data.topics().iter().copied(), &log_data.data) -} - -/// The timeout we allow for the listener to store the last block polled in DB. -const LAST_BLOCK_POLLED_UPDATE_TIMEOUT: Duration = Duration::from_mins(5); - #[cfg(test)] mod tests { use super::*; - use alloy::{ - primitives::Address, - providers::{ - Identity, ProviderBuilder, RootProvider, - fillers::{ - BlobGasFiller, ChainIdFiller, FillProvider, GasFiller, JoinFill, NonceFiller, - }, - mock::Asserter, - }, - rpc::json_rpc::ErrorPayload, + use alloy::providers::{ + Identity, ProviderBuilder, RootProvider, + fillers::{BlobGasFiller, ChainIdFiller, FillProvider, GasFiller, JoinFill, NonceFiller}, + mock::Asserter, }; + use alloy::rpc::json_rpc::ErrorPayload; use connector_utils::tests::setup::{TestInstance, TestInstanceBuilder}; + use std::time::Duration; #[rstest::rstest] #[timeout(Duration::from_secs(90))] #[tokio::test] - async fn test_reset_filter_stops_listener() { + async fn test_consecutive_get_logs_error_stops_listener() { let (_test_instance, asserter, gw_listener) = test_setup(None).await; - asserter.push_failure(ErrorPayload { - code: -32000, - message: "filter not found".into(), - data: None, - }); + // Initial get_block_number succeeds + asserter.push_success(&100_u64); - gw_listener.subscribe(EventType::KeygenRequest).await; - } + for _ in 0..MAX_CONSECUTIVE_POLLING_ERRORS { + // Loop get_block_number succeeds + asserter.push_success(&101_u64); - #[rstest::rstest] - #[timeout(Duration::from_secs(90))] - #[tokio::test] - async fn test_failed_catchup_does_not_stop_listener() { - let (mut test_instance, asserter, gw_listener) = test_setup(Some(0)).await; - - asserter.push_failure(ErrorPayload { - code: -32002, - message: "request timed out".into(), - data: None, - }); - - let event_type = EventType::KeygenRequest; - tokio::spawn(gw_listener.subscribe(event_type)); - test_instance.wait_for_log("Failed to catch up").await; - test_instance - .wait_for_log(&format!("Waiting for next {event_type}")) - .await; + // get_logs fails + asserter.push_failure(ErrorPayload { + code: -32000, + message: "get logs error".into(), + data: None, + }); + } + + gw_listener.poll_events(MonitoredContract::Decryption).await; } #[rstest::rstest] #[timeout(Duration::from_secs(90))] #[tokio::test] - async fn test_listener_ended_by_end_of_any_task() { + async fn test_listener_ended_by_cancel_token() { let (mut test_instance, _asserter, gw_listener) = test_setup(None).await; - // Will stop because some subscription tasks will not be able to init their event filter - gw_listener.start().await; + gw_listener.cancel_token.cancel(); - test_instance.wait_for_log("Failed to subscribe to").await; + gw_listener.start().await; + test_instance + .wait_for_log("GatewayListener stopped successfully") + .await; } type MockProvider = FillProvider< @@ -438,23 +365,21 @@ mod tests { RootProvider, >; + const MAX_CONSECUTIVE_POLLING_ERRORS: usize = 2; + async fn test_setup( kms_operation_from_block_number: Option, ) -> (TestInstance, Asserter, GatewayListener) { let test_instance = TestInstanceBuilder::db_setup().await.unwrap(); - // Create a mocked `alloy::Provider` let asserter = Asserter::new(); let mock_provider = ProviderBuilder::new().connect_mocked_client(asserter.clone()); - // Used to mock response of `filter.watch()` operation - let mocked_eth_get_filter_changes_result = Address::default(); - asserter.push_success(&mocked_eth_get_filter_changes_result); - let config = Config { decryption_polling: Duration::from_millis(500), key_management_polling: Duration::from_millis(500), kms_operation_from_block_number, + max_consecutive_polling_errors: MAX_CONSECUTIVE_POLLING_ERRORS, ..Default::default() }; let listener = GatewayListener::new( diff --git a/kms-connector/crates/gw-listener/src/core/mod.rs b/kms-connector/crates/gw-listener/src/core/mod.rs index 9947c205c5..09b31fcb6c 100644 --- a/kms-connector/crates/gw-listener/src/core/mod.rs +++ b/kms-connector/crates/gw-listener/src/core/mod.rs @@ -7,4 +7,3 @@ mod publish; pub use config::Config; pub use gateway::GatewayListener; pub use gw_listener::EventListener; -pub use publish::publish_event; diff --git a/kms-connector/crates/gw-listener/src/core/publish.rs b/kms-connector/crates/gw-listener/src/core/publish.rs index 35ce973349..1b9cc30d0e 100644 --- a/kms-connector/crates/gw-listener/src/core/publish.rs +++ b/kms-connector/crates/gw-listener/src/core/publish.rs @@ -14,71 +14,65 @@ use fhevm_gateway_bindings::{ }, }; use sqlx::{ - Pool, Postgres, + PgExecutor, Pool, Postgres, postgres::PgQueryResult, types::chrono::{DateTime, Utc}, }; -use std::time::Duration; -use tracing::{debug, error, info, warn}; - -const INSERTION_RETRY_LIMIT: usize = 10; -const INSERTION_RETRY_INTERVAL: Duration = Duration::from_millis(10); +use tracing::{debug, info, warn}; +/// Inserts all events and updates the last block polled in a single transaction. +/// On failure, the transaction is rolled back automatically. #[tracing::instrument(skip_all)] -pub async fn publish_event( +pub async fn publish_batch( db_pool: &Pool, - event: GatewayEvent, - block_number: Option, + events: Vec, + event_types: &[EventType], + block_number: u64, ) -> anyhow::Result<()> { - for i in 1..=INSERTION_RETRY_LIMIT { - match publish_event_inner(db_pool, event.clone(), block_number).await { - Ok(()) => return Ok(()), - Err(e) => error!("Insertion attempt #{i}/{INSERTION_RETRY_LIMIT} failed: {e}"), - } - if i != INSERTION_RETRY_LIMIT { - tokio::time::sleep(INSERTION_RETRY_INTERVAL).await; - } + let mut tx = db_pool.begin().await?; + for event in events { + publish_event_inner(&mut *tx, event).await?; } - - Err(anyhow::anyhow!( - "Failed to publish {:?} event after {} attempts", - event.kind, - INSERTION_RETRY_LIMIT - )) + update_last_block_polled(&mut *tx, event_types, Some(block_number)).await?; + tx.commit().await?; + Ok(()) } -async fn publish_event_inner( - db_pool: &Pool, +async fn publish_event_inner<'e>( + executor: impl PgExecutor<'e>, event: GatewayEvent, - block_number: Option, ) -> anyhow::Result<()> { - info!(block_number, "Storing {:?} in DB...", event.kind); + info!("Storing {:?} in DB...", event.kind); - let event_type = (&event.kind).into(); let otlp_ctx = event.otlp_context; let tx_hash = event.tx_hash; let created_at = event.created_at; let query_result = match event.kind { GatewayEventKind::PublicDecryption(e) => { - publish_public_decryption(db_pool, e, tx_hash, created_at, otlp_ctx).await + publish_public_decryption(executor, e, tx_hash, created_at, otlp_ctx).await } GatewayEventKind::UserDecryption(e) => { - publish_user_decryption(db_pool, e, tx_hash, created_at, otlp_ctx).await + publish_user_decryption(executor, e, tx_hash, created_at, otlp_ctx).await } GatewayEventKind::PrepKeygen(e) => { - publish_prep_keygen_request(db_pool, e, tx_hash, created_at, otlp_ctx).await + let params_type: ParamsTypeDb = e.paramsType.try_into()?; + publish_prep_keygen_request(executor, e, params_type, tx_hash, created_at, otlp_ctx) + .await } GatewayEventKind::Keygen(e) => { - publish_keygen_request(db_pool, e, tx_hash, created_at, otlp_ctx).await + publish_keygen_request(executor, e, tx_hash, created_at, otlp_ctx).await } GatewayEventKind::Crsgen(e) => { - publish_crsgen_request(db_pool, e, tx_hash, created_at, otlp_ctx).await + let params_type: ParamsTypeDb = e.paramsType.try_into()?; + publish_crsgen_request(executor, e, params_type, tx_hash, created_at, otlp_ctx).await } GatewayEventKind::PrssInit(id) => { - publish_prss_init(db_pool, id, tx_hash, created_at, otlp_ctx).await + publish_prss_init(executor, id, tx_hash, created_at, otlp_ctx).await } GatewayEventKind::KeyReshareSameSet(e) => { - publish_key_reshare_same_set(db_pool, e, tx_hash, created_at, otlp_ctx).await + let params_type: ParamsTypeDb = e.paramsType.try_into()?; + publish_key_reshare_same_set(executor, e, params_type, tx_hash, created_at, otlp_ctx) + .await } } .map_err(|err| anyhow!("Failed to publish event: {err}"))?; @@ -89,12 +83,11 @@ async fn publish_event_inner( warn!("Unexpected query result while publishing event: {query_result:?}"); } - update_last_block_polled(db_pool, event_type, block_number).await?; Ok(()) } -async fn publish_public_decryption( - db_pool: &Pool, +async fn publish_public_decryption<'e>( + executor: impl PgExecutor<'e>, request: PublicDecryptionRequest, tx_hash: Option>, created_at: DateTime, @@ -118,13 +111,13 @@ async fn publish_public_decryption( created_at, bc2wrap::serialize(&otlp_ctx)?, ) - .execute(db_pool) + .execute(executor) .await .map_err(anyhow::Error::from) } -async fn publish_user_decryption( - db_pool: &Pool, +async fn publish_user_decryption<'e>( + executor: impl PgExecutor<'e>, request: UserDecryptionRequest, tx_hash: Option>, created_at: DateTime, @@ -151,19 +144,19 @@ async fn publish_user_decryption( created_at, bc2wrap::serialize(&otlp_ctx)?, ) - .execute(db_pool) + .execute(executor) .await .map_err(anyhow::Error::from) } -async fn publish_prep_keygen_request( - db_pool: &Pool, +async fn publish_prep_keygen_request<'e>( + executor: impl PgExecutor<'e>, request: PrepKeygenRequest, + params_type: ParamsTypeDb, tx_hash: Option>, created_at: DateTime, otlp_ctx: PropagationContext, ) -> anyhow::Result { - let params_type: ParamsTypeDb = request.paramsType.try_into()?; sqlx::query!( "INSERT INTO prep_keygen_requests(\ prep_keygen_id, epoch_id, params_type, tx_hash, created_at, otlp_context\ @@ -176,13 +169,13 @@ async fn publish_prep_keygen_request( created_at, bc2wrap::serialize(&otlp_ctx)?, ) - .execute(db_pool) + .execute(executor) .await .map_err(anyhow::Error::from) } -async fn publish_keygen_request( - db_pool: &Pool, +async fn publish_keygen_request<'e>( + executor: impl PgExecutor<'e>, request: KeygenRequest, tx_hash: Option>, created_at: DateTime, @@ -197,19 +190,19 @@ async fn publish_keygen_request( created_at, bc2wrap::serialize(&otlp_ctx)?, ) - .execute(db_pool) + .execute(executor) .await .map_err(anyhow::Error::from) } -async fn publish_crsgen_request( - db_pool: &Pool, +async fn publish_crsgen_request<'e>( + executor: impl PgExecutor<'e>, request: CrsgenRequest, + params_type: ParamsTypeDb, tx_hash: Option>, created_at: DateTime, otlp_ctx: PropagationContext, ) -> anyhow::Result { - let params_type: ParamsTypeDb = request.paramsType.try_into()?; sqlx::query!( "INSERT INTO crsgen_requests(\ crs_id, max_bit_length, params_type, tx_hash, created_at, otlp_context\ @@ -222,13 +215,13 @@ async fn publish_crsgen_request( created_at, bc2wrap::serialize(&otlp_ctx)?, ) - .execute(db_pool) + .execute(executor) .await .map_err(anyhow::Error::from) } -async fn publish_prss_init( - db_pool: &Pool, +async fn publish_prss_init<'e>( + executor: impl PgExecutor<'e>, id: U256, tx_hash: Option>, created_at: DateTime, @@ -242,19 +235,19 @@ async fn publish_prss_init( created_at, bc2wrap::serialize(&otlp_ctx)?, ) - .execute(db_pool) + .execute(executor) .await .map_err(anyhow::Error::from) } -async fn publish_key_reshare_same_set( - db_pool: &Pool, +async fn publish_key_reshare_same_set<'e>( + executor: impl PgExecutor<'e>, request: KeyReshareSameSet, + params_type: ParamsTypeDb, tx_hash: Option>, created_at: DateTime, otlp_ctx: PropagationContext, ) -> anyhow::Result { - let params_type: ParamsTypeDb = request.paramsType.try_into()?; sqlx::query!( "INSERT INTO key_reshare_same_set(\ prep_keygen_id, key_id, key_reshare_id, params_type, tx_hash, created_at, otlp_context\ @@ -268,41 +261,44 @@ async fn publish_key_reshare_same_set( created_at, bc2wrap::serialize(&otlp_ctx)?, ) - .execute(db_pool) + .execute(executor) .await .map_err(anyhow::Error::from) } -/// Updates the registered last block polled in DB. +/// Updates the registered last block polled in DB for the given event types. #[tracing::instrument(skip_all)] -pub async fn update_last_block_polled( - db_pool: &Pool, - event_type: EventType, +pub async fn update_last_block_polled<'e>( + executor: impl PgExecutor<'e>, + event_types: &[EventType], last_block_polled: Option, ) -> anyhow::Result<()> { info!( last_block_polled, - "Updating last block polled in DB for {event_type}" + "Updating last block polled in DB for {event_types:?}" ); let query_result = sqlx::query!( "UPDATE last_block_polled SET block_number = $2, updated_at = $3 \ - WHERE event_type = $1 AND (block_number IS NULL OR block_number < $2)", - event_type as EventType, + WHERE event_type = ANY($1::event_type[]) AND (block_number IS NULL OR block_number < $2)", + event_types as &[EventType], last_block_polled.map(|n| n as i64), Utc::now(), ) - .execute(db_pool) + .execute(executor) .await?; - if query_result.rows_affected() == 1 { + let rows_affected = query_result.rows_affected(); + if rows_affected > 0 { info!( last_block_polled, - "Last block polled for {event_type} was successfully updated!" + "Last block polled updated for {}/{} event types in {event_types:?}", + rows_affected, + event_types.len() ); } else { debug!( last_block_polled, - "Last block polled for {event_type} was not updated: {query_result:?}" + "Last block polled for {event_types:?} was not updated: {query_result:?}" ); } diff --git a/kms-connector/crates/gw-listener/src/monitoring/metrics.rs b/kms-connector/crates/gw-listener/src/monitoring/metrics.rs index 76d6d31dd8..7de4f7060e 100644 --- a/kms-connector/crates/gw-listener/src/monitoring/metrics.rs +++ b/kms-connector/crates/gw-listener/src/monitoring/metrics.rs @@ -10,11 +10,11 @@ pub static EVENT_RECEIVED_COUNTER: LazyLock = LazyLock::new(|| { .unwrap() }); -pub static EVENT_RECEIVED_ERRORS: LazyLock = LazyLock::new(|| { +pub static EVENT_LISTENING_ERRORS: LazyLock = LazyLock::new(|| { register_int_counter_vec!( - "kms_connector_gw_listener_event_received_errors", - "Number of errors encountered by the GatewayListener while receiving events", - &["event_type"] + "kms_connector_gw_listener_event_listening_errors", + "Number of errors encountered by the GatewayListener while listening for events", + &["contract"] ) .unwrap() }); diff --git a/kms-connector/crates/gw-listener/tests/block_tracking.rs b/kms-connector/crates/gw-listener/tests/block_tracking.rs index 7c8ef2f05c..0d1593759a 100644 --- a/kms-connector/crates/gw-listener/tests/block_tracking.rs +++ b/kms-connector/crates/gw-listener/tests/block_tracking.rs @@ -1,6 +1,6 @@ mod common; -use crate::common::{check_event_in_db, fetch_from_db, mock_event_on_gw, start_test_listener}; +use crate::common::{mock_event_on_gw, poll_db_for_event, start_test_listener}; use connector_utils::{tests::setup::TestInstanceBuilder, types::db::EventType}; use rstest::rstest; use std::time::Duration; @@ -67,12 +67,7 @@ async fn test_block_tracking(event_type: EventType) -> anyhow::Result<()> { start_test_listener(&mut test_instance, cancel_token.clone(), None).await; let (expected_event, _) = mock_event_on_gw(&test_instance, event_type).await?; - test_instance - .wait_for_log("Event successfully stored in DB!") - .await; - - let rows = fetch_from_db(test_instance.db(), event_type).await?; - check_event_in_db(&rows, expected_event)?; + poll_db_for_event(test_instance.db(), event_type, &expected_event).await?; info!("Event successfully stored! Stopping GatewayListener..."); cancel_token.cancel(); gw_listener_task?.await?; @@ -87,12 +82,7 @@ async fn test_block_tracking(event_type: EventType) -> anyhow::Result<()> { let gw_listener_task = start_test_listener(&mut test_instance, cancel_token.clone(), None).await; - test_instance - .wait_for_log("Event successfully stored in DB!") - .await; - - let rows = fetch_from_db(test_instance.db(), event_type).await?; - check_event_in_db(&rows, expected_event)?; + poll_db_for_event(test_instance.db(), event_type, &expected_event).await?; info!("Event successfully stored! Stopping GatewayListener..."); cancel_token.cancel(); diff --git a/kms-connector/crates/gw-listener/tests/catchup.rs b/kms-connector/crates/gw-listener/tests/catchup.rs index 89ad12bfcd..192d9c5459 100644 --- a/kms-connector/crates/gw-listener/tests/catchup.rs +++ b/kms-connector/crates/gw-listener/tests/catchup.rs @@ -1,6 +1,6 @@ mod common; -use crate::common::{check_event_in_db, fetch_from_db, mock_event_on_gw, start_test_listener}; +use crate::common::{mock_event_on_gw, poll_db_for_event, start_test_listener}; use connector_utils::{tests::setup::TestInstanceBuilder, types::db::EventType}; use rstest::rstest; use std::time::Duration; @@ -63,13 +63,11 @@ async fn test_catchup_from_block(event_type: EventType) -> anyhow::Result<()> { // Wait for two more anvil blocks so anvil is fully ready tokio::time::sleep(2 * test_instance.anvil_block_time()).await; - let mut nb_event = 1; let (event1, block1) = mock_event_on_gw(&test_instance, event_type).await?; assert!(block1.is_some()); let event2 = if !matches!(event_type, EventType::PrssInit) { let (event2, block2) = mock_event_on_gw(&test_instance, event_type).await?; assert_ne!(block1, block2); - nb_event += 1; Some(event2) } else { None @@ -79,16 +77,9 @@ async fn test_catchup_from_block(event_type: EventType) -> anyhow::Result<()> { let gw_listener_task = start_test_listener(&mut test_instance, cancel_token.clone(), block1).await; - for _ in 0..nb_event { - test_instance - .wait_for_log("Event successfully stored in DB!") - .await; - } - - let rows = fetch_from_db(test_instance.db(), event_type).await?; - check_event_in_db(&rows, event1)?; - if let Some(event2) = event2 { - check_event_in_db(&rows, event2)?; + poll_db_for_event(test_instance.db(), event_type, &event1).await?; + if let Some(ref event2) = event2 { + poll_db_for_event(test_instance.db(), event_type, event2).await?; } info!("Events successfully stored! Stopping GatewayListener..."); diff --git a/kms-connector/crates/gw-listener/tests/common/mod.rs b/kms-connector/crates/gw-listener/tests/common/mod.rs index e9967ba707..55a3ceb73b 100644 --- a/kms-connector/crates/gw-listener/tests/common/mod.rs +++ b/kms-connector/crates/gw-listener/tests/common/mod.rs @@ -32,7 +32,7 @@ use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use tracing::info; -const NB_EVENT_TYPE: usize = 7; +const NB_POLL_GROUPS: usize = 2; pub async fn start_test_listener( test_instance: &mut TestInstance, @@ -55,9 +55,9 @@ pub async fn start_test_listener( let listener_task = tokio::spawn(gw_listener.start()); - // Wait for all gw-listener event filters to be ready + 2 anvil blocks - for _ in 0..NB_EVENT_TYPE { - test_instance.wait_for_log("Subscribed to ").await; + // Wait for both polling tasks to start + 2 anvil blocks + for _ in 0..NB_POLL_GROUPS { + test_instance.wait_for_log("Started ").await; } tokio::time::sleep(2 * test_instance.anvil_block_time()).await; @@ -191,6 +191,26 @@ pub async fn fetch_from_db(db: &Pool, event_type: EventType) -> sqlx:: sqlx::query(query).fetch_all(db).await } +pub async fn poll_db_for_event( + db: &Pool, + event_type: EventType, + expected_event: &GatewayEventKind, +) -> anyhow::Result<()> { + let timeout = Duration::from_secs(30); + let poll_interval = Duration::from_millis(200); + let start = std::time::Instant::now(); + loop { + let rows = fetch_from_db(db, event_type).await?; + if check_event_in_db(&rows, expected_event.clone()).is_ok() { + return Ok(()); + } + if start.elapsed() > timeout { + anyhow::bail!("Timed out waiting for {event_type} event in DB"); + } + tokio::time::sleep(poll_interval).await; + } +} + pub fn check_event_in_db(rows: &[PgRow], event: GatewayEventKind) -> anyhow::Result<()> { match event { GatewayEventKind::PublicDecryption(e) => { diff --git a/kms-connector/crates/gw-listener/tests/integration_test.rs b/kms-connector/crates/gw-listener/tests/integration_test.rs index c55088fc89..f3520b8798 100644 --- a/kms-connector/crates/gw-listener/tests/integration_test.rs +++ b/kms-connector/crates/gw-listener/tests/integration_test.rs @@ -1,6 +1,6 @@ mod common; -use crate::common::{check_event_in_db, fetch_from_db, mock_event_on_gw, start_test_listener}; +use crate::common::{mock_event_on_gw, poll_db_for_event, start_test_listener}; use connector_utils::tests::setup::TestInstanceBuilder; use connector_utils::types::db::EventType; use rstest::rstest; @@ -64,12 +64,7 @@ async fn test_publish_event(event_type: EventType) -> anyhow::Result<()> { start_test_listener(&mut test_instance, cancel_token.clone(), None).await; let (expected_event, _) = mock_event_on_gw(&test_instance, event_type).await?; - test_instance - .wait_for_log("Event successfully stored in DB!") - .await; - - let rows = fetch_from_db(test_instance.db(), event_type).await?; - check_event_in_db(&rows, expected_event)?; + poll_db_for_event(test_instance.db(), event_type, &expected_event).await?; info!("Event successfully stored! Stopping GatewayListener..."); cancel_token.cancel(); diff --git a/kms-connector/crates/utils/src/types/db.rs b/kms-connector/crates/utils/src/types/db.rs index 50b91e3ec5..032e3e50fd 100644 --- a/kms-connector/crates/utils/src/types/db.rs +++ b/kms-connector/crates/utils/src/types/db.rs @@ -1,8 +1,19 @@ use crate::types::GatewayEventKind; -use alloy::primitives::{Address, U256}; +use alloy::{ + primitives::{Address, B256, U256}, + sol_types::SolEvent, +}; use anyhow::anyhow; use fhevm_gateway_bindings::{ - decryption::Decryption::SnsCiphertextMaterial, kms_generation::IKMSGeneration::KeyDigest, + decryption::Decryption::{ + PublicDecryptionRequest, SnsCiphertextMaterial, UserDecryptionRequest, + }, + kms_generation::{ + IKMSGeneration::KeyDigest, + KMSGeneration::{ + CrsgenRequest, KeyReshareSameSet, KeygenRequest, PRSSInit, PrepKeygenRequest, + }, + }, }; use sqlx::postgres::PgNotification; use std::{fmt::Display, str::FromStr}; @@ -203,6 +214,18 @@ impl EventType { EventType::KeyReshareSameSet => "key_reshare_same_set", } } + + pub fn signature_hash(&self) -> B256 { + match self { + EventType::PublicDecryptionRequest => PublicDecryptionRequest::SIGNATURE_HASH, + EventType::UserDecryptionRequest => UserDecryptionRequest::SIGNATURE_HASH, + EventType::PrepKeygenRequest => PrepKeygenRequest::SIGNATURE_HASH, + EventType::KeygenRequest => KeygenRequest::SIGNATURE_HASH, + EventType::CrsgenRequest => CrsgenRequest::SIGNATURE_HASH, + EventType::PrssInit => PRSSInit::SIGNATURE_HASH, + EventType::KeyReshareSameSet => KeyReshareSameSet::SIGNATURE_HASH, + } + } } // Postgres notifications