Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,26 +76,6 @@ where
}
}

pub async fn check_decryption_not_already_done(
&self,
decryption_id: U256,
) -> Result<(), ProcessingError> {
let is_decryption_done = self
.decryption_contract
.isDecryptionDone(decryption_id)
.call()
.await
.map_err(|e| ProcessingError::Recoverable(anyhow::Error::from(e)))?;

if is_decryption_done {
return Err(ProcessingError::Irrecoverable(anyhow!(
"Decryption already done on the Gateway"
)));
}

Ok(())
}

#[tracing::instrument(skip_all)]
pub async fn check_ciphertexts_allowed_for_public_decryption(
&self,
Expand Down Expand Up @@ -217,38 +197,23 @@ where
delegator_address: Address,
) -> Result<(), ProcessingError> {
let handle_hex = hex::encode(handle);
let is_delegated_call = acl_contract.isHandleDelegatedForUserDecryption(
delegator_address,
user_address,
contract_address,
handle,
);
let delegator_allowed_call = acl_contract.isAllowed(handle, delegator_address);
let contract_allowed_call = acl_contract.isAllowed(handle, contract_address);

let (is_delegated, delegator_allowed, contract_allowed) = tokio::try_join!(
is_delegated_call.call(),
delegator_allowed_call.call(),
contract_allowed_call.call(),
)
.map_err(|e| ProcessingError::Recoverable(anyhow::Error::from(e)))?;
let is_delegated = acl_contract
.isHandleDelegatedForUserDecryption(
delegator_address,
user_address,
contract_address,
handle,
)
.call()
.await
.map_err(|e| ProcessingError::Recoverable(anyhow::Error::from(e)))?;

if !is_delegated {
return Err(ProcessingError::Recoverable(anyhow!(
"{user_address} is not a delegate of {delegator_address} for contract \
{contract_address} and handle {handle_hex}!",
)));
}
if !delegator_allowed {
return Err(ProcessingError::Recoverable(anyhow!(
"{delegator_address} is not allowed to decrypt {handle_hex}!",
)));
}
if !contract_allowed {
return Err(ProcessingError::Recoverable(anyhow!(
"{contract_address} is not allowed to decrypt {handle_hex}!",
)));
}

Ok(())
}
Expand Down Expand Up @@ -442,65 +407,10 @@ mod tests {
enum ExpectedOutcome {
Ok,
Recoverable,
#[allow(unused)]
Irrecoverable,
}

enum DecryptionReadyMock {
Failure(&'static str),
Success(bool),
}

#[rstest]
#[case::transport_error(
DecryptionReadyMock::Failure("Transport Error"),
ExpectedOutcome::Recoverable
)]
#[case::not_done(DecryptionReadyMock::Success(false), ExpectedOutcome::Ok)]
#[case::already_done(DecryptionReadyMock::Success(true), ExpectedOutcome::Irrecoverable)]
#[tokio::test]
async fn check_decryption_not_already_done(
#[case] mock_response: DecryptionReadyMock,
#[case] expected: ExpectedOutcome,
) {
let asserter = Asserter::new();
let mock_provider = ProviderBuilder::new()
.disable_recommended_fillers()
.connect_mocked_client(asserter.clone());
let acl_contracts_mock = HashMap::from([(
u64::default(),
ACL::new(Address::default(), mock_provider.clone()),
)]);

let config = Config::default();
let s3_service = S3Service::new(&config, mock_provider.clone(), reqwest::Client::new());
let decryption_processor = DecryptionProcessor::new(
&config,
MockContextManager,
mock_provider,
acl_contracts_mock,
s3_service,
);

match mock_response {
DecryptionReadyMock::Failure(msg) => asserter.push_failure_msg(msg),
DecryptionReadyMock::Success(val) => asserter.push_success(&val.abi_encode()),
}

let result = decryption_processor
.check_decryption_not_already_done(U256::ZERO)
.await;

match expected {
ExpectedOutcome::Ok => result.unwrap(),
ExpectedOutcome::Recoverable => {
assert!(matches!(result, Err(ProcessingError::Recoverable(_))))
}
ExpectedOutcome::Irrecoverable => {
assert!(matches!(result, Err(ProcessingError::Irrecoverable(_))))
}
}
}

enum PubDecryptACLMock {
Failure(&'static str),
Success(bool),
Expand Down Expand Up @@ -653,11 +563,7 @@ mod tests {

enum DelegatedUserDecryptACLMock {
Failure(&'static str),
Success {
is_delegated: bool,
delegator_allowed: bool,
contract_allowed: bool,
},
Success { is_delegated: bool },
}

#[rstest]
Expand All @@ -667,22 +573,12 @@ mod tests {
None
)]
#[case::allowed(
DelegatedUserDecryptACLMock::Success { is_delegated: true, delegator_allowed: true, contract_allowed: true },
DelegatedUserDecryptACLMock::Success { is_delegated: true },
ExpectedOutcome::Ok,
None
)]
#[case::delegator_allowed_contract_not_allowed(
DelegatedUserDecryptACLMock::Success { is_delegated: true, delegator_allowed: true, contract_allowed: false },
ExpectedOutcome::Recoverable,
Some("is not allowed to decrypt")
)]
#[case::delegator_not_allowed_contract_allowed(
DelegatedUserDecryptACLMock::Success { is_delegated: true, delegator_allowed: false, contract_allowed: true },
ExpectedOutcome::Recoverable,
Some("is not allowed to decrypt")
)]
#[case::not_delegated(
DelegatedUserDecryptACLMock::Success { is_delegated: false, delegator_allowed: true, contract_allowed: true },
DelegatedUserDecryptACLMock::Success { is_delegated: false },
ExpectedOutcome::Recoverable,
Some("is not a delegate of")
)]
Expand Down Expand Up @@ -725,14 +621,8 @@ mod tests {

match mock_response {
DelegatedUserDecryptACLMock::Failure(msg) => asserter.push_failure_msg(msg),
DelegatedUserDecryptACLMock::Success {
is_delegated,
delegator_allowed,
contract_allowed,
} => {
DelegatedUserDecryptACLMock::Success { is_delegated } => {
asserter.push_success(&is_delegated.abi_encode());
asserter.push_success(&delegator_allowed.abi_encode());
asserter.push_success(&contract_allowed.abi_encode());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,6 @@ impl<GP: Provider, HP: Provider, C: ContextManager> DbEventProcessor<GP, HP, C>
) -> Result<KmsGrpcRequest, ProcessingError> {
let request = match &event.kind {
GatewayEventKind::PublicDecryption(req) => {
self.decryption_processor
.check_decryption_not_already_done(req.decryptionId)
.await?;
self.decryption_processor
.check_ciphertexts_allowed_for_public_decryption(&req.snsCtMaterials)
.await?;
Expand Down
10 changes: 2 additions & 8 deletions kms-connector/crates/kms-worker/tests/acl.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
mod common;

use crate::common::{create_mock_user_decryption_request_tx, init_kms_worker};
use alloy::{
providers::{ProviderBuilder, mock::Asserter},
sol_types::SolValue,
};
use alloy::providers::{ProviderBuilder, mock::Asserter};
use connector_utils::{
tests::{
db::requests::{
Expand Down Expand Up @@ -45,10 +42,7 @@ async fn test_decryption_acl_failure(#[case] event_type: EventType) -> anyhow::R
.with_tx_hash(tx_hash);
for _ in 0..MAX_DECRYPTION_ATTEMPTS {
match event_type {
EventType::PublicDecryptionRequest => {
// Mocking isDecryptionDone returns false
asserter.push_success(&false.abi_encode());
}
EventType::PublicDecryptionRequest => (),
EventType::UserDecryptionRequest => {
// Mocking `get_transaction_by_hash` call result
let mock_tx = create_mock_user_decryption_request_tx(tx_hash, sns_ct.ctHandle)?;
Expand Down
17 changes: 5 additions & 12 deletions kms-connector/crates/kms-worker/tests/attempt_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,11 @@ async fn test_request_processing(#[case] event_type: EventType) -> anyhow::Resul
.with_sns_ct_materials(vec![sns_ct.clone()])
.with_tx_hash(tx_hash);
for attempt in 0..MAX_DECRYPTION_ATTEMPTS {
match event_type {
EventType::PublicDecryptionRequest => {
// Mocking isDecryptionDone returns false
asserter.push_success(&false.abi_encode());
}
EventType::UserDecryptionRequest => {
// Mocking `get_transaction_by_hash` call result
let mock_tx = create_mock_user_decryption_request_tx(tx_hash, sns_ct.ctHandle)?;
asserter.push_success(&mock_tx);
}
_ => (),
};
if matches!(event_type, EventType::UserDecryptionRequest) {
// Mocking `get_transaction_by_hash` call result
let mock_tx = create_mock_user_decryption_request_tx(tx_hash, sns_ct.ctHandle)?;
asserter.push_success(&mock_tx);
}

// First attempt, the copro URL is not cached yet
if attempt == 0 {
Expand Down
17 changes: 5 additions & 12 deletions kms-connector/crates/kms-worker/tests/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,11 @@ async fn test_decryption_context_not_found(#[case] event_type: EventType) -> any
.with_context_id(unknown_context_id);

for _ in 0..MAX_DECRYPTION_ATTEMPTS {
match event_type {
EventType::PublicDecryptionRequest => {
// Mocking isDecryptionDone returns false
asserter.push_success(&false.abi_encode());
}
EventType::UserDecryptionRequest => {
// Mocking `get_transaction_by_hash` call result
let mock_tx = create_mock_user_decryption_request_tx(tx_hash, sns_ct.ctHandle)?;
asserter.push_success(&mock_tx);
}
_ => panic!("Unexpected event type"),
};
if matches!(event_type, EventType::UserDecryptionRequest) {
// Mocking `get_transaction_by_hash` call result
let mock_tx = create_mock_user_decryption_request_tx(tx_hash, sns_ct.ctHandle)?;
asserter.push_success(&mock_tx);
}
}

let gateway_mock_provider = ProviderBuilder::new()
Expand Down
21 changes: 7 additions & 14 deletions kms-connector/crates/kms-worker/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,20 +70,13 @@ async fn test_processing_request(
let mut insert_options = InsertRequestOptions::new()
.with_already_sent(already_sent)
.with_sns_ct_materials(vec![sns_ct.clone()]);
match event_type {
EventType::PublicDecryptionRequest => {
// Mocking isDecryptionDone returns false
asserter.push_success(&false.abi_encode());
}
EventType::UserDecryptionRequest => {
// Mocking `get_transaction_by_hash` call result
let tx_hash = rand_digest();
let mock_tx = create_mock_user_decryption_request_tx(tx_hash, sns_ct.ctHandle)?;
insert_options = insert_options.with_tx_hash(tx_hash);
asserter.push_success(&mock_tx);
}
_ => (),
};
if matches!(event_type, EventType::UserDecryptionRequest) {
// Mocking `get_transaction_by_hash` call result
let tx_hash = rand_digest();
let mock_tx = create_mock_user_decryption_request_tx(tx_hash, sns_ct.ctHandle)?;
insert_options = insert_options.with_tx_hash(tx_hash);
asserter.push_success(&mock_tx);
}

let get_copro_call_response = Coprocessor {
s3BucketUrl: format!("{}/ct128", test_instance.s3_url()),
Expand Down
Loading