Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion coprocessor/fhevm-engine/sns-worker/src/aws_upload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ async fn try_resubmit(
for task in jobs {
select! {
_ = tasks.send(task.clone()) => {
info!(handle = to_hex(task.handle()), "resubmitted");
info!(handle = to_hex(task.handle()), "upload-task, resubmitted");
},
_ = token.cancelled() => {
return Ok(());
Expand Down
184 changes: 162 additions & 22 deletions coprocessor/fhevm-engine/sns-worker/src/tests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ async fn test_fhe_ciphertext128_with_compression() {
#[serial(db)]
async fn test_batch_execution() {
const WITH_COMPRESSION: bool = true;
let test_env = setup(WITH_COMPRESSION).await.expect("valid setup");
let mut test_env = setup(WITH_COMPRESSION).await.expect("valid setup");
let tf: TestFile = read_test_file("ciphertext64.json");

let batch_size = std::env::var("BATCH_SIZE")
Expand All @@ -88,7 +88,7 @@ async fn test_batch_execution() {
info!("Batch size: {}", batch_size);

run_batch_computations(
&test_env,
&mut test_env,
&tf.handle,
batch_size,
&tf.ciphertext64.clone(),
Expand All @@ -99,6 +99,49 @@ async fn test_batch_execution() {
.expect("run_batch_computations should succeed");
}

/// Tests batch execution of SnS computations with S3 retry logic.
/// Inserts a batch of identical ciphertext64 entries with unique handles,
/// pauses the LocalStack instance to simulate S3 unavailability,
/// triggers the SNS worker to convert them,
/// then unpauses LocalStack and verifies that all ct128 are uploaded
#[tokio::test]
#[serial(db)]
async fn test_batch_execution_with_s3_retry() {
const WITH_COMPRESSION: bool = true;
let mut test_env = setup(WITH_COMPRESSION).await.expect("valid setup");
let tf: TestFile = read_test_file("ciphertext64.bin");

// Pause LocalStack to simulate S3 unavailability
test_env
.s3_instance
.as_ref()
.unwrap()
.container
.pause()
.await
.expect("Pause LocalStack");

info!("Paused LocalStack to test S3 retry logic");

let batch_size = std::env::var("BATCH_SIZE")
.ok()
.and_then(|v| v.parse::<u16>().ok())
.unwrap_or(100);

info!("Batch size: {}", batch_size);

run_batch_and_unpause_s3(
&mut test_env,
&tf.handle,
batch_size,
&tf.ciphertext64.clone(),
tf.cleartext,
WITH_COMPRESSION,
)
.await
.expect("run_batch_and_unpause_s3 should succeed");
}

#[tokio::test]
#[ignore = "disabled in CI"]
async fn test_fhe_ciphertext128_no_compression() {
Expand Down Expand Up @@ -148,23 +191,24 @@ async fn test_decryptable(
with_compression,
handle,
expected_result,
true,
)
.await?;

Ok(())
}

async fn run_batch_computations(
test_env: &TestEnvironment,
test_env: &mut TestEnvironment,
base_handle: &[u8],
batch_size: u16,
ciphertext: &Vec<u8>,
expected_cleartext: i64,
with_compression: bool,
) -> anyhow::Result<()> {
let pool = &test_env.pool;
let bucket128 = &test_env.conf.s3.bucket_ct128;
let bucket64 = &test_env.conf.s3.bucket_ct64;
let bucket128 = &test_env.conf.s3.bucket_ct128.clone();
let bucket64 = &test_env.conf.s3.bucket_ct64.clone();

clean_up(pool).await?;

Expand Down Expand Up @@ -210,6 +254,7 @@ async fn run_batch_computations(
with_compression,
&handle,
expected_cleartext,
true,
)
.await
});
Expand All @@ -229,6 +274,90 @@ async fn run_batch_computations(
anyhow::Result::<()>::Ok(())
}

/// Runs batch computations while LocalStack S3 is paused, then unpauses it
async fn run_batch_and_unpause_s3(
test_env: &mut TestEnvironment,
base_handle: &[u8],
batch_size: u16,
ciphertext: &Vec<u8>,
expected_cleartext: i64,
with_compression: bool,
) -> anyhow::Result<()> {
let pool = &test_env.pool;
let bucket128 = &test_env.conf.s3.bucket_ct128.clone();
let bucket64 = &test_env.conf.s3.bucket_ct64.clone();
info!(batch_size, "Inserting ciphertexts ...");

let mut handles = Vec::new();
let tenant_id = get_tenant_id_from_db(pool, TENANT_API_KEY).await;
for i in 0..batch_size {
let mut handle = base_handle.to_owned();

// Modify first two bytes of the handle to make it unique
// However the ciphertext64 will be the same
handle[0] = (i >> 8) as u8;
handle[1] = (i & 0xFF) as u8;
test_harness::db_utils::insert_ciphertext64(pool, tenant_id, &handle, ciphertext).await?;
test_harness::db_utils::insert_into_pbs_computations(pool, tenant_id, &handle).await?;
handles.push(handle);
}

info!(batch_size, "Inserted batch");

// Send notification only after the batch was fully inserted
sqlx::query("SELECT pg_notify($1, '')")
.bind(LISTEN_CHANNEL)
.execute(pool)
.await?;

info!("Sent pg_notify to SnS worker");

// Wait a bit to ensure that the SNS worker has converted all ciphertexts
tokio::time::sleep(Duration::from_secs(10)).await;

info!("Unpause LocalStack ...");
test_env
.s3_instance
.as_ref()
.unwrap()
.container
.unpause()
.await
.expect("LocalStack unpaused");

info!("LocalStack unpaused");

let start = std::time::Instant::now();
let mut set = tokio::task::JoinSet::new();
for handle in handles.iter() {
let test_env = test_env.clone();
let handle = handle.clone();
set.spawn(async move {
assert_ciphertext128(
&test_env,
tenant_id,
with_compression,
&handle,
expected_cleartext,
true,
)
.await
});
}

while let Some(res) = set.join_next().await {
res??;
}

let elapsed = start.elapsed();
info!(elapsed = ?elapsed, batch_size, "Batch execution completed");

// Assert that all ciphertext128 objects are uploaded to S3
assert_ciphertext_s3_object_count(test_env, bucket128, batch_size as i64).await;
assert_ciphertext_s3_object_count(test_env, bucket64, batch_size as i64).await;

anyhow::Result::<()>::Ok(())
}
#[tokio::test]
#[serial(db)]
#[cfg(not(feature = "gpu"))]
Expand Down Expand Up @@ -450,6 +579,8 @@ async fn setup(enable_compression: bool) -> anyhow::Result<TestEnvironment> {

let client_key: Option<ClientKey> = fetch_client_key(&pool, &TENANT_API_KEY.to_owned()).await?;

clean_up(&pool.clone()).await?;

let (events_tx, mut events_rx) = mpsc::channel::<&'static str>(10);
tokio::spawn(async move {
crate::run_all(config, token, Some(events_tx))
Expand Down Expand Up @@ -486,7 +617,8 @@ async fn setup_localstack(
if std::env::var("TEST_GLOBAL_LOCALSTACK").unwrap_or("0".to_string()) == "1" {
(None, LOCALSTACK_PORT)
} else {
let localstack_instance = Arc::new(test_harness::localstack::start_localstack().await?);
let localstack_instance: Arc<LocalstackContainer> =
Arc::new(test_harness::localstack::start_localstack().await?);
let host_port = localstack_instance.host_port;
(Some(localstack_instance), host_port)
};
Expand Down Expand Up @@ -609,7 +741,12 @@ async fn insert_into_pbs_computations(
async fn clean_up(pool: &sqlx::PgPool) -> anyhow::Result<()> {
truncate_tables(
pool,
vec!["pbs_computations", "ciphertexts", "ciphertext_digest"],
vec![
"pbs_computations",
"ciphertexts",
"ciphertexts128",
"ciphertext_digest",
],
)
.await?;

Expand All @@ -628,10 +765,11 @@ async fn assert_ciphertext128(
with_compression: bool,
handle: &Vec<u8>,
expected_value: i64,
enable_s3_assert: bool,
) -> anyhow::Result<()> {
let pool = &test_env.pool;
let client_key = &test_env.client_key;
let ct = test_harness::db_utils::wait_for_ciphertext(pool, tenant_id, handle, 100).await?;
let ct = test_harness::db_utils::wait_for_ciphertext(pool, tenant_id, handle, 1000).await?;

info!("Ciphertext len: {:?}", ct.len());

Expand Down Expand Up @@ -678,16 +816,18 @@ async fn assert_ciphertext128(

#[cfg(feature = "test_s3_use_handle_as_key")]
{
info!("Asserting ciphertext uploaded to S3");

assert_ciphertext_uploaded(
test_env,
&test_env.conf.s3.bucket_ct128,
handle,
Some(ct.len() as i64),
)
.await;
assert_ciphertext_uploaded(test_env, &test_env.conf.s3.bucket_ct64, handle, None).await;
if enable_s3_assert {
info!("Asserting ciphertext uploaded to S3");

assert_ciphertext_uploaded(
test_env,
&test_env.conf.s3.bucket_ct128,
handle,
Some(ct.len() as i64),
)
.await;
assert_ciphertext_uploaded(test_env, &test_env.conf.s3.bucket_ct64, handle, None).await;
}
}

Ok(())
Expand Down Expand Up @@ -759,7 +899,7 @@ fn build_test_config(url: DatabaseURL, enable_compression: bool) -> Config {
listen_channels: vec![LISTEN_CHANNEL.to_string()],
notify_channel: "fhevm".to_string(),
batch_limit,
gc_batch_limit: 0,
gc_batch_limit: 0, // Disable automatic garbage collection in tests
polling_interval: 60000,
cleanup_interval: Duration::from_hours(10),
max_connections: 5,
Expand All @@ -771,11 +911,11 @@ fn build_test_config(url: DatabaseURL, enable_compression: bool) -> Config {
bucket_ct64: "ct64".to_owned(),
max_concurrent_uploads: 2000,
retry_policy: S3RetryPolicy {
max_retries_per_upload: 100,
max_retries_per_upload: 1,
max_backoff: Duration::from_secs(10),
max_retries_timeout: Duration::from_secs(120),
max_retries_timeout: Duration::from_secs(3),
recheck_duration: Duration::from_secs(2),
regular_recheck_duration: Duration::from_secs(120),
regular_recheck_duration: Duration::from_secs(3),
},
},
service_name: "".to_owned(),
Expand Down
7 changes: 6 additions & 1 deletion coprocessor/fhevm-engine/test-harness/src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,12 @@ async fn setup_test_app_existing_localhost(

if with_reset {
info!("Resetting local database at {db_url}");
let admin_db_url = db_url.to_string().replace("coprocessor", "postgres");
let admin_db_url = db_url
.as_str()
.to_string()
.replace("coprocessor", "postgres");

info!("Admin DB URL: {admin_db_url}");
create_database(&admin_db_url, db_url.as_str(), mode).await?;
}

Expand Down
3 changes: 2 additions & 1 deletion coprocessor/fhevm-engine/test-harness/src/localstack.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::net::TcpListener;
use std::{net::TcpListener, time::Duration};

use alloy::signers::k256::pkcs8::EncodePrivateKey;
use aws_config::BehaviorVersion;
Expand Down Expand Up @@ -28,6 +28,7 @@ pub async fn start_localstack() -> anyhow::Result<LocalstackContainer> {
.with_exposed_port(LOCALSTACK_PORT.into())
.with_wait_for(WaitFor::message_on_stdout("Ready."))
.with_mapped_port(host_port, LOCALSTACK_PORT.into())
.with_startup_timeout(Duration::from_hours(2))
.start()
.await?;
Ok(LocalstackContainer {
Expand Down
Loading