Skip to content

Commit 0509b28

Browse files
authored
feat(coprocessor): enable GPU backend in SnS worker (#1801)
* feat(coprocessor): enable GPU support in SnS worker * chore(coprocessor): logs if gpu or cpu backend is enabled * chore(coprocessor): enable test-harness/gpu feature * chore(coprocessor): update sns_pk in setup_test_user * chore(coprocessor): log if gpu or cpu backend is enabled, tfhe-worker * chore(coprocessor): disable s3 asserts in SnS unit-testing on GPU backend * chore(coprocessor): scope gpu testing to gpu-dependent crates
1 parent 923dd06 commit 0509b28

File tree

8 files changed

+102
-13
lines changed

8 files changed

+102
-13
lines changed

.github/workflows/coprocessor-gpu-tests.yml

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,18 @@ jobs:
185185
env:
186186
HARDHAT_NETWORK: hardhat
187187

188-
- name: Run tests on GPU
188+
- name: Run GPU tests for the worker services.
189189
run: |
190-
DATABASE_URL=postgresql://postgres:postgres@localhost:5432/coprocessor cargo test --release --features=gpu -- --test-threads=1
190+
export DATABASE_URL=postgresql://postgres:postgres@localhost:5432/coprocessor
191+
cargo test \
192+
-p tfhe-worker \
193+
-p sns-worker \
194+
--release \
195+
--features=gpu \
196+
-- \
197+
--test-threads=1
191198
working-directory: coprocessor/fhevm-engine
192199

193-
194200
teardown-instance:
195201
name: coprocessor-gpu-tests/teardown
196202
if: ${{ always() && needs.setup-instance.result == 'success' }}

coprocessor/fhevm-engine/fhevm-engine-common/src/utils.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,27 @@ impl FromStr for DatabaseURL {
244244
Ok(Self(s.to_owned()))
245245
}
246246
}
247+
248+
/// Logs whether the GPU backend is enabled or not.
249+
pub fn log_backend() -> bool {
250+
log_backend_impl()
251+
}
252+
253+
#[cfg(feature = "gpu")]
254+
fn log_backend_impl() -> bool {
255+
use tfhe::core_crypto::gpu::{get_number_of_gpus, get_number_of_sms};
256+
let num_gpus = get_number_of_gpus();
257+
let streaming_multiprocessors = get_number_of_sms();
258+
tracing::info!(
259+
num_gpus,
260+
streaming_multiprocessors,
261+
"GPU feature is enabled"
262+
);
263+
true
264+
}
265+
266+
#[cfg(not(feature = "gpu"))]
267+
fn log_backend_impl() -> bool {
268+
tracing::info!("GPU feature is disabled, using CPU backend");
269+
false
270+
}

coprocessor/fhevm-engine/sns-worker/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ name = "sns_worker"
5050
path = "src/bin/sns_worker.rs"
5151

5252
[features]
53+
gpu = ["tfhe/gpu", "fhevm-engine-common/gpu", "test-harness/gpu"]
5354
test_decrypt_128 = []
5455
test_s3_use_handle_as_key = []
5556

coprocessor/fhevm-engine/sns-worker/src/keyset.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@ pub(crate) async fn fetch_keyset(
2323
}
2424

2525
info!(tenant_api_key, "Cache miss");
26+
2627
let Some((client_key, server_key)) = fetch_keys(pool, tenant_api_key).await? else {
2728
return Ok(None);
2829
};
30+
2931
let key_set: KeySet = KeySet {
3032
client_key,
3133
server_key,
@@ -45,21 +47,35 @@ pub(crate) async fn fetch_keyset(
4547
pub async fn fetch_keys(
4648
pool: &PgPool,
4749
tenant_api_key: &String,
48-
) -> anyhow::Result<Option<(Option<tfhe::ClientKey>, tfhe::ServerKey)>> {
50+
) -> anyhow::Result<Option<(Option<tfhe::ClientKey>, crate::ServerKey)>> {
4951
let blob = read_keys_from_large_object(
5052
pool,
5153
tenant_api_key,
5254
"sns_pk",
5355
SKS_KEY_WITH_NOISE_SQUASHING_SIZE,
5456
)
5557
.await?;
56-
info!(bytes_len = blob.len(), "Retrieved sns_pk");
58+
info!(
59+
bytes_len = blob.len(),
60+
"Fetched sns_pk/sks_ns bytes from LOB"
61+
);
5762
if blob.is_empty() {
5863
return Ok(None);
5964
}
6065

66+
#[cfg(not(feature = "gpu"))]
6167
let server_key: tfhe::ServerKey = safe_deserialize_sns_key(&blob)?;
6268

69+
#[cfg(feature = "gpu")]
70+
let server_key = {
71+
let compressed_server_key: tfhe::CompressedServerKey = safe_deserialize_sns_key(&blob)?;
72+
info!("Deserialized sns_pk/sks_ns to CompressedServerKey");
73+
74+
let server_key = compressed_server_key.decompress_to_gpu();
75+
info!("Decompressed sns_pk/sks_ns to CudaServerKey");
76+
server_key
77+
};
78+
6379
// Optionally retrieve the ClientKey for testing purposes
6480
let client_key = fetch_client_key(pool, tenant_api_key).await?;
6581
Ok(Some((client_key, server_key)))

coprocessor/fhevm-engine/sns-worker/src/lib.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ use fhevm_engine_common::{
2727
utils::{to_hex, DatabaseURL},
2828
};
2929
use futures::join;
30-
use serde::{Deserialize, Serialize};
3130
use sqlx::{Postgres, Transaction};
3231
use thiserror::Error;
3332
use tokio::{
@@ -51,10 +50,16 @@ pub const UPLOAD_QUEUE_SIZE: usize = 20;
5150
pub const SAFE_SER_LIMIT: u64 = 1024 * 1024 * 66;
5251
pub type InternalEvents = Option<tokio::sync::mpsc::Sender<&'static str>>;
5352

54-
#[derive(Serialize, Deserialize, Clone)]
53+
#[cfg(feature = "gpu")]
54+
type ServerKey = tfhe::CudaServerKey;
55+
#[cfg(not(feature = "gpu"))]
56+
type ServerKey = tfhe::ServerKey;
57+
58+
#[derive(Clone)]
5559
pub struct KeySet {
56-
pub server_key: tfhe::ServerKey,
60+
/// Optional ClientKey for decrypting on testing
5761
pub client_key: Option<tfhe::ClientKey>,
62+
pub server_key: ServerKey,
5863
}
5964

6065
#[derive(Clone)]
@@ -117,7 +122,7 @@ pub struct Config {
117122
pub pg_auto_explain_with_min_duration: Option<Duration>,
118123
}
119124

120-
#[derive(Debug, Clone, Copy, Default)]
125+
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
121126
pub enum SchedulePolicy {
122127
Sequential,
123128
#[default]
@@ -501,7 +506,8 @@ pub async fn run_all(
501506
mpsc::channel::<UploadJob>(10 * config.s3.max_concurrent_uploads as usize);
502507

503508
let rayon_threads = rayon::current_num_threads();
504-
info!(config = %config, rayon_threads, "Starting SNS worker");
509+
let gpu_enabled = fhevm_engine_common::utils::log_backend();
510+
info!(gpu_enabled, rayon_threads, config = %config, "Starting SNS worker");
505511

506512
if !config.service_name.is_empty() {
507513
if let Err(err) = telemetry::setup_otlp(&config.service_name) {

coprocessor/fhevm-engine/sns-worker/src/tests/mod.rs

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ async fn run_batch_computations(
231231

232232
#[tokio::test]
233233
#[serial(db)]
234+
#[cfg(not(feature = "gpu"))]
234235
async fn test_lifo_mode() {
235236
init_tracing();
236237

@@ -311,6 +312,7 @@ async fn test_lifo_mode() {
311312

312313
#[tokio::test]
313314
#[serial(db)]
315+
#[cfg(not(feature = "gpu"))]
314316
async fn test_garbage_collect() {
315317
init_tracing();
316318

@@ -433,7 +435,15 @@ async fn setup(enable_compression: bool) -> anyhow::Result<TestEnvironment> {
433435
.await?;
434436

435437
// Set up S3 storage
436-
let (s3_instance, s3_client) = setup_localstack(&conf).await?;
438+
let (s3_instance, s3_client) = if cfg!(feature = "gpu") {
439+
info!("GPU feature is enabled, avoid testing S3-related functionality");
440+
(
441+
None,
442+
aws_sdk_s3::Client::new(&aws_config::load_defaults(BehaviorVersion::latest()).await),
443+
)
444+
} else {
445+
setup_localstack(&conf).await?
446+
};
437447

438448
let token = db_instance.parent_token.child_token();
439449
let config: Config = conf.clone();
@@ -684,6 +694,7 @@ async fn assert_ciphertext128(
684694
}
685695

686696
/// Asserts that ciphertext exists in S3
697+
#[cfg(not(feature = "gpu"))]
687698
async fn assert_ciphertext_uploaded(
688699
test_env: &TestEnvironment,
689700
bucket: &String,
@@ -700,7 +711,18 @@ async fn assert_ciphertext_uploaded(
700711
.await;
701712
}
702713

714+
#[cfg(feature = "gpu")]
715+
async fn assert_ciphertext_uploaded(
716+
_test_env: &TestEnvironment,
717+
_bucket: &String,
718+
_handle: &Vec<u8>,
719+
_expected_ct_len: Option<i64>,
720+
) {
721+
// No-op when GPU feature is enabled
722+
}
723+
703724
/// Asserts that the number of ciphertext128 objects in S3 matches the expected count
725+
#[cfg(not(feature = "gpu"))]
704726
async fn assert_ciphertext_s3_object_count(
705727
test_env: &TestEnvironment,
706728
bucket: &String,
@@ -710,6 +732,15 @@ async fn assert_ciphertext_s3_object_count(
710732
.await;
711733
}
712734

735+
#[cfg(feature = "gpu")]
736+
async fn assert_ciphertext_s3_object_count(
737+
_te: &TestEnvironment,
738+
_bucket: &String,
739+
_expected_count: i64,
740+
) {
741+
// No-op when GPU feature is enabled
742+
}
743+
713744
fn build_test_config(url: DatabaseURL, enable_compression: bool) -> Config {
714745
let batch_limit = std::env::var("BATCH_LIMIT")
715746
.ok()

coprocessor/fhevm-engine/test-harness/src/db_utils.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,9 @@ pub async fn setup_test_user(
138138
pool: &sqlx::PgPool,
139139
with_sns_pk: bool,
140140
) -> Result<(), Box<dyn std::error::Error>> {
141+
let gpu_enabled = cfg!(feature = "gpu");
142+
info!(gpu_enabled, "Setting up test user...");
143+
141144
let (sks, cks, pks, pp, sns_pk) = if !cfg!(feature = "gpu") {
142145
(
143146
"../fhevm-keys/sks",
@@ -152,7 +155,7 @@ pub async fn setup_test_user(
152155
"../fhevm-keys/gpu-cks",
153156
"../fhevm-keys/gpu-pks",
154157
"../fhevm-keys/gpu-pp",
155-
"../fhevm-keys/sns_pk",
158+
"../fhevm-keys/gpu-csks",
156159
)
157160
};
158161
let sks = tokio::fs::read(sks).await.expect("can't read sks key");

coprocessor/fhevm-engine/tfhe-worker/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,9 @@ pub async fn async_main(
8181
}
8282

8383
if args.run_bg_worker {
84-
info!(target: "async_main", "Initializing background worker");
84+
let gpu_enabled = fhevm_engine_common::utils::log_backend();
85+
info!(target: "async_main", gpu_enabled, "Initializing background worker");
86+
8587
set.spawn(tfhe_worker::run_tfhe_worker(
8688
args.clone(),
8789
health_check.clone(),

0 commit comments

Comments
 (0)