Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
285 changes: 83 additions & 202 deletions kms-connector/Cargo.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions kms-connector/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ tfhe = "=1.4.0-alpha.3"
# External dependencies #
#####################################################################
actix-web = "=4.11.0"
alloy = { version = "=1.0.38", default-features = false, features = [
alloy = { version = "=1.1.2", default-features = false, features = [
"essentials",
"json-rpc",
"provider-debug-api",
"provider-ws",
"reqwest-rustls-tls",
"signer-aws",
"std",
Expand Down
1 change: 0 additions & 1 deletion kms-connector/crates/tx-sender/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ tracing.workspace = true
tracing-opentelemetry.workspace = true

[dev-dependencies]
alloy = { workspace = true, features = ["json-rpc"] }
bc2wrap.workspace = true
connector-utils = { workspace = true, features = ["tests"] }
rstest.workspace = true
Expand Down
171 changes: 34 additions & 137 deletions kms-connector/crates/tx-sender/src/core/tx_sender.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@ use crate::{
};
use alloy::{
hex,
providers::{
PendingTransactionError, Provider, RootProvider, ext::DebugApi, fillers::TxFiller,
},
providers::{PendingTransactionError, Provider, ext::DebugApi},
rpc::types::{
TransactionReceipt, TransactionRequest,
trace::geth::{CallConfig, GethDebugTracingOptions},
Expand All @@ -18,8 +16,7 @@ use alloy::{
};
use anyhow::anyhow;
use connector_utils::{
conn::{WalletGatewayProviderFillers, connect_to_db, connect_to_gateway_with_wallet},
provider::NonceManagedProvider,
conn::{WalletGatewayProvider, connect_to_db, connect_to_gateway_with_wallet},
tasks::spawn_with_limit,
types::{
CrsgenResponse, KeygenResponse, KmsResponse, KmsResponseKind, PrepKeygenResponse,
Expand All @@ -39,31 +36,29 @@ use tracing::{debug, error, info, warn};
use tracing_opentelemetry::OpenTelemetrySpanExt;

/// Struct sending stored KMS Core's responses to the Gateway.
pub struct TransactionSender<L, F, P>
pub struct TransactionSender<L, P>
where
F: TxFiller,
P: Provider,
{
/// The entity used to collect stored KMS Core's responses.
response_picker: L,

/// The entity responsible to send transaction to the Gateway.
inner: TransactionSenderInner<F, P>,
inner: TransactionSenderInner<P>,

/// The database pool for where the KMS Core's responses are stored.
db_pool: Pool<Postgres>,
}

impl<L, F, P> TransactionSender<L, F, P>
impl<L, P> TransactionSender<L, P>
where
L: KmsResponsePicker,
F: TxFiller + 'static,
P: Provider + Clone + 'static,
{
/// Creates a new `TransactionSender` instance.
pub fn new(
response_picker: L,
inner: TransactionSenderInner<F, P>,
inner: TransactionSenderInner<P>,
db_pool: Pool<Postgres>,
) -> Self {
Self {
Expand Down Expand Up @@ -116,7 +111,7 @@ where
/// Handles a response coming from the KMS Core.
#[tracing::instrument(skip(inner, db_pool, cancel_token), fields(response = %response.kind))]
async fn forward_response(
inner: TransactionSenderInner<F, P>,
inner: TransactionSenderInner<P>,
db_pool: Pool<Postgres>,
response: KmsResponse,
cancel_token: CancellationToken,
Expand All @@ -135,7 +130,7 @@ where
}
}

impl TransactionSender<DbKmsResponsePicker, WalletGatewayProviderFillers, RootProvider> {
impl TransactionSender<DbKmsResponsePicker, WalletGatewayProvider> {
/// Creates a new `TransactionSender` instance from a valid `Config`.
pub async fn from_config(config: Config) -> anyhow::Result<(Self, State)> {
let db_pool = connect_to_db(&config.database_url, config.database_pool_size).await?;
Expand Down Expand Up @@ -168,14 +163,13 @@ impl TransactionSender<DbKmsResponsePicker, WalletGatewayProviderFillers, RootPr
}

/// The internal struct used to send transaction to the Gateway.
pub struct TransactionSenderInner<F, P>
pub struct TransactionSenderInner<P>
where
F: TxFiller,
P: Provider,
{
provider: NonceManagedProvider<F, P>,
decryption_contract: DecryptionInstance<NonceManagedProvider<F, P>>,
kms_generation_contract: KMSGenerationInstance<NonceManagedProvider<F, P>>,
provider: P,
decryption_contract: DecryptionInstance<P>,
kms_generation_contract: KMSGenerationInstance<P>,
config: TransactionSenderInnerConfig,
}

Expand All @@ -187,15 +181,14 @@ pub struct TransactionSenderInnerConfig {
pub gas_multiplier_percent: usize,
}

impl<F, P> TransactionSenderInner<F, P>
impl<P> TransactionSenderInner<P>
where
F: TxFiller,
P: Provider,
{
pub fn new(
provider: NonceManagedProvider<F, P>,
decryption_contract: DecryptionInstance<NonceManagedProvider<F, P>>,
kms_generation_contract: KMSGenerationInstance<NonceManagedProvider<F, P>>,
provider: P,
decryption_contract: DecryptionInstance<P>,
kms_generation_contract: KMSGenerationInstance<P>,
inner_config: TransactionSenderInnerConfig,
) -> Self {
Self {
Expand Down Expand Up @@ -419,9 +412,8 @@ where
}
}

impl<F, P> Clone for TransactionSenderInner<F, P>
impl<P> Clone for TransactionSenderInner<P>
where
F: TxFiller,
P: Provider + Clone,
{
fn clone(&self) -> Self {
Expand Down Expand Up @@ -481,20 +473,11 @@ impl From<PendingTransactionError> for Error {
mod tests {
use super::*;
use alloy::{
network::{Ethereum, IntoWallet, Network, TransactionBuilder},
primitives::Address,
providers::{
Identity, ProviderBuilder, SendableTx,
fillers::{FillProvider, FillerControlFlow},
mock::Asserter,
},
providers::{Identity, ProviderBuilder, fillers::FillProvider, mock::Asserter},
rpc::{json_rpc::ErrorPayload, types::trace::geth::GethTrace},
transports::TransportResult,
};
use connector_utils::{
config::KmsWallet,
tests::rand::{rand_signature, rand_u256},
};
use connector_utils::tests::rand::{rand_signature, rand_u256};
use serde::de::DeserializeOwned;
use serde_json::value::RawValue;
use std::fs::File;
Expand All @@ -504,15 +487,9 @@ mod tests {
async fn test_send_tx_out_of_gas() -> anyhow::Result<()> {
// Create a mocked `alloy::Provider`
let asserter = Asserter::new();
let mock_provider = NonceManagedProvider::new(
FillProvider::new(
ProviderBuilder::new()
.disable_recommended_fillers()
.connect_mocked_client(asserter.clone()),
MockFiller {},
),
Address::default(),
);
let mock_provider = ProviderBuilder::new()
.disable_recommended_fillers()
.connect_mocked_client(asserter.clone());

// Used to mock all RPC responses of transaction sending operation
let test_data_dir = test_data_dir();
Expand Down Expand Up @@ -555,10 +532,7 @@ mod tests {
#[tokio::test]
async fn test_disable_reverted_tx_tracing() {
let asserter = Asserter::new();
let mock_provider = NonceManagedProvider::new(
ProviderBuilder::new().connect_mocked_client(asserter.clone()),
Address::default(),
);
let mock_provider = ProviderBuilder::new().connect_mocked_client(asserter.clone());
let inner_sender = TransactionSenderInner::new(
mock_provider.clone(),
DecryptionInstance::new(Address::default(), mock_provider.clone()),
Expand Down Expand Up @@ -586,15 +560,9 @@ mod tests {
async fn test_error_decryption_not_requested() -> anyhow::Result<()> {
// Create a mocked `alloy::Provider`
let asserter = Asserter::new();
let mock_provider = NonceManagedProvider::new(
FillProvider::new(
ProviderBuilder::new()
.disable_recommended_fillers()
.connect_mocked_client(asserter.clone()),
MockFiller {},
),
Address::default(),
);
let mock_provider = ProviderBuilder::new()
.disable_recommended_fillers()
.connect_mocked_client(asserter.clone());

// Used to mock all RPC responses of transaction sending operation
let estimate_gas: usize = 21000;
Expand Down Expand Up @@ -640,15 +608,9 @@ mod tests {
async fn test_error_not_kms_tx_sender() -> anyhow::Result<()> {
// Create a mocked `alloy::Provider`
let asserter = Asserter::new();
let mock_provider = NonceManagedProvider::new(
FillProvider::new(
ProviderBuilder::new()
.disable_recommended_fillers()
.connect_mocked_client(asserter.clone()),
MockFiller {},
),
Address::default(),
);
let mock_provider = ProviderBuilder::new()
.disable_recommended_fillers()
.connect_mocked_client(asserter.clone());

// Used to mock all RPC responses of transaction sending operation
let estimate_gas: usize = 21000;
Expand Down Expand Up @@ -694,14 +656,11 @@ mod tests {
async fn test_error_not_kms_signer() -> anyhow::Result<()> {
// Create a mocked `alloy::Provider`
let asserter = Asserter::new();
let mock_provider = NonceManagedProvider::new(
FillProvider::new(
ProviderBuilder::new()
.disable_recommended_fillers()
.connect_mocked_client(asserter.clone()),
Identity,
),
Address::default(),
let mock_provider = FillProvider::new(
ProviderBuilder::new()
.disable_recommended_fillers()
.connect_mocked_client(asserter.clone()),
Identity,
);

// Used to mock all RPC responses of transaction sending operation
Expand Down Expand Up @@ -750,66 +709,4 @@ mod tests {
fn test_data_dir() -> String {
format!("{}/tests/data/tx_out_of_gas", env!("CARGO_MANIFEST_DIR"))
}

/// A filler that mocks gas estimation and signing of the transactions
#[derive(Clone, Debug)]
struct MockFiller;

impl TxFiller<Ethereum> for MockFiller {
type Fillable = ();

fn status(&self, tx: &<Ethereum as Network>::TransactionRequest) -> FillerControlFlow {
if tx.from().is_none() {
return FillerControlFlow::Ready;
}

match tx.complete_preferred() {
Ok(_) => FillerControlFlow::Ready,
Err(e) => FillerControlFlow::Missing(vec![("Wallet", e)]),
}
}

fn fill_sync(&self, _tx: &mut SendableTx<Ethereum>) {}

async fn prepare<P>(
&self,
_provider: &P,
_tx: &<Ethereum as Network>::TransactionRequest,
) -> TransportResult<Self::Fillable>
where
P: Provider<Ethereum>,
{
Ok(())
}

async fn fill(
&self,
_fillable: Self::Fillable,
tx: SendableTx<Ethereum>,
) -> TransportResult<SendableTx<Ethereum>> {
let mut builder = match tx {
SendableTx::Builder(builder) => builder,
_ => return Ok(tx),
};

let chain_id = 54321;
let wallet = KmsWallet::from_private_key_str(
"0x3f45b129a7fd099146e9fe63851a71646231f7743c712695f3b2d2bf0e41c774",
Some(chain_id),
)
.unwrap()
.into_wallet();
builder.set_gas_limit(21000);
builder.set_max_fee_per_gas(10);
builder.set_max_priority_fee_per_gas(10);
builder.set_chain_id(chain_id);
builder.set_nonce(0);
let envelope = builder
.build(&wallet)
.await
.map_err(RpcError::local_usage)?;

Ok(SendableTx::Envelope(envelope))
}
}
}
1 change: 1 addition & 0 deletions kms-connector/crates/utils/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ testcontainers = { workspace = true, optional = true }
toml = { workspace = true, optional = true }

[dev-dependencies]
alloy = { workspace = true, features = ["provider-anvil-node"] }
serial_test.workspace = true
tempfile.workspace = true
toml.workspace = true
Expand Down
14 changes: 11 additions & 3 deletions kms-connector/crates/utils/src/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use alloy::{
WalletFiller,
},
},
transports::http::reqwest::Url,
transports::http::reqwest::{self, Url},
};
use anyhow::anyhow;
use sqlx::{Pool, Postgres, postgres::PgPoolOptions};
Expand All @@ -24,6 +24,9 @@ pub const CONNECTION_RETRY_NUMBER: usize = 5;
/// The delay between two connection attempts.
pub const CONNECTION_RETRY_DELAY: Duration = Duration::from_secs(2);

/// Timeout for requests to the Gateway's RPC node.
const REQUEST_TIMEOUT: Duration = Duration::from_mins(1);

/// Tries to establish the connection with Postgres database.
pub async fn connect_to_db(db_url: &str, db_pool_size: u32) -> anyhow::Result<Pool<Postgres>> {
for i in 1..=CONNECTION_RETRY_NUMBER {
Expand Down Expand Up @@ -55,7 +58,8 @@ type DefaultFillers = JoinFill<
pub type GatewayProvider = FillProvider<JoinFill<DefaultFillers, ChainIdFiller>, RootProvider>;

/// The default `alloy::Provider` used to interact with the Gateway using a wallet.
pub type WalletGatewayProvider = NonceManagedProvider<WalletGatewayProviderFillers, RootProvider>;
pub type WalletGatewayProvider =
NonceManagedProvider<FillProvider<WalletGatewayProviderFillers, RootProvider>>;
pub type WalletGatewayProviderFillers = JoinFill<
JoinFill<JoinFill<Identity, ChainIdFiller>, FillersWithoutNonceManagement>,
WalletFiller<EthereumWallet>,
Expand Down Expand Up @@ -107,7 +111,11 @@ where

let gateway_url =
Url::from_str(gateway_url).map_err(|e| anyhow!("Invalid Gateway URL: {e}"))?;
let provider = provider_builder_new().connect_http(gateway_url);
let client = reqwest::ClientBuilder::new()
.timeout(REQUEST_TIMEOUT)
.build()?;
let provider = provider_builder_new().connect_reqwest(client, gateway_url);

info!("Connected to Gateway's RPC node successfully");
Ok(provider)
}
Expand Down
Loading