Skip to content

Commit 54c8c5e

Browse files
enzodimariaagnesLeroy
authored andcommitted
chore(gpu): no crash with aes benches if oom error
1 parent 164fc26 commit 54c8c5e

File tree

2 files changed

+94
-65
lines changed

2 files changed

+94
-65
lines changed

tfhe-benchmark/benches/integer/aes.rs

Lines changed: 47 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ pub mod cuda {
33
use benchmark::params_aliases::BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
44
use benchmark::utilities::{write_to_json, OperatorType};
55
use criterion::{black_box, Criterion};
6-
use tfhe::core_crypto::gpu::CudaStreams;
6+
use tfhe::core_crypto::gpu::{check_valid_cuda_malloc, CudaStreams};
77
use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
88
use tfhe::integer::gpu::CudaServerKey;
99
use tfhe::integer::keycache::KEY_CACHE;
@@ -102,37 +102,52 @@ pub mod cuda {
102102
let sks = CudaServerKey::new(&cpu_cks, &streams);
103103
let cks = RadixClientKey::from((cpu_cks, 1));
104104

105-
let ct_key = cks.encrypt_u128_for_aes_ctr(key);
106-
let ct_iv = cks.encrypt_u128_for_aes_ctr(iv);
107-
108-
let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams);
109-
let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams);
110-
111-
let round_keys = sks.key_expansion(&d_key, &streams);
112-
113-
println!("{bench_id}");
114-
bench_group.bench_function(&bench_id, |b| {
115-
b.iter(|| {
116-
black_box(sks.aes_encrypt(
117-
&d_iv,
118-
&round_keys,
119-
0,
120-
NUM_AES_INPUTS,
121-
SBOX_PARALLELISM,
122-
&streams,
123-
));
124-
})
125-
});
126-
127-
write_to_json::<u64, _>(
128-
&bench_id,
129-
atomic_param,
130-
param.name(),
131-
"aes_encryption",
132-
&OperatorType::Atomic,
133-
aes_op_bit_size,
134-
vec![atomic_param.message_modulus().0.ilog2(); aes_op_bit_size as usize],
135-
);
105+
//
106+
// Memory checks
107+
//
108+
let gpu_index = streams.gpu_indexes[0];
109+
110+
let key_expansion_size = sks.get_key_expansion_size_on_gpu(&streams);
111+
let aes_encrypt_size =
112+
sks.get_aes_encrypt_size_on_gpu(NUM_AES_INPUTS, SBOX_PARALLELISM, &streams);
113+
114+
if check_valid_cuda_malloc(key_expansion_size, gpu_index)
115+
&& check_valid_cuda_malloc(aes_encrypt_size, gpu_index)
116+
{
117+
let ct_key = cks.encrypt_u128_for_aes_ctr(key);
118+
let ct_iv = cks.encrypt_u128_for_aes_ctr(iv);
119+
120+
let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams);
121+
let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams);
122+
123+
let round_keys = sks.key_expansion(&d_key, &streams);
124+
125+
println!("{bench_id}");
126+
bench_group.bench_function(&bench_id, |b| {
127+
b.iter(|| {
128+
black_box(sks.aes_encrypt(
129+
&d_iv,
130+
&round_keys,
131+
0,
132+
NUM_AES_INPUTS,
133+
SBOX_PARALLELISM,
134+
&streams,
135+
));
136+
})
137+
});
138+
139+
write_to_json::<u64, _>(
140+
&bench_id,
141+
atomic_param,
142+
param.name(),
143+
"aes_encryption",
144+
&OperatorType::Atomic,
145+
aes_op_bit_size,
146+
vec![atomic_param.message_modulus().0.ilog2(); aes_op_bit_size as usize],
147+
);
148+
} else {
149+
println!("{} skipped: Not enough memory in GPU", bench_id);
150+
}
136151
}
137152

138153
bench_group.finish();

tfhe-benchmark/benches/integer/aes256.rs

Lines changed: 47 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ pub mod cuda {
33
use benchmark::params_aliases::BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
44
use benchmark::utilities::{write_to_json, OperatorType};
55
use criterion::{black_box, Criterion};
6-
use tfhe::core_crypto::gpu::CudaStreams;
6+
use tfhe::core_crypto::gpu::{check_valid_cuda_malloc, CudaStreams};
77
use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
88
use tfhe::integer::gpu::CudaServerKey;
99
use tfhe::integer::keycache::KEY_CACHE;
@@ -106,38 +106,52 @@ pub mod cuda {
106106
let sks = CudaServerKey::new(&cpu_cks, &streams);
107107
let cks = RadixClientKey::from((cpu_cks, 1));
108108

109-
let ct_key = cks.encrypt_2u128_for_aes_ctr_256(key_hi, key_lo);
110-
111-
let ct_iv = cks.encrypt_u128_for_aes_ctr(iv);
112-
113-
let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams);
114-
let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams);
115-
116-
let round_keys = sks.key_expansion_256(&d_key, &streams);
117-
118-
println!("{bench_id}");
119-
bench_group.bench_function(&bench_id, |b| {
120-
b.iter(|| {
121-
black_box(sks.aes_256_encrypt(
122-
&d_iv,
123-
&round_keys,
124-
0,
125-
NUM_AES_INPUTS,
126-
SBOX_PARALLELISM,
127-
&streams,
128-
));
129-
})
130-
});
131-
132-
write_to_json::<u64, _>(
133-
&bench_id,
134-
atomic_param,
135-
param.name(),
136-
"aes_256_encryption",
137-
&OperatorType::Atomic,
138-
aes_block_op_bit_size,
139-
vec![atomic_param.message_modulus().0.ilog2(); aes_block_op_bit_size as usize],
140-
);
109+
//
110+
// Memory checks
111+
//
112+
let gpu_index = streams.gpu_indexes[0];
113+
114+
let key_expansion_size = sks.get_key_expansion_256_size_on_gpu(&streams);
115+
let aes_encrypt_size =
116+
sks.get_aes_encrypt_size_on_gpu(NUM_AES_INPUTS, SBOX_PARALLELISM, &streams);
117+
118+
if check_valid_cuda_malloc(key_expansion_size, gpu_index)
119+
&& check_valid_cuda_malloc(aes_encrypt_size, gpu_index)
120+
{
121+
let ct_key = cks.encrypt_2u128_for_aes_ctr_256(key_hi, key_lo);
122+
let ct_iv = cks.encrypt_u128_for_aes_ctr(iv);
123+
124+
let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams);
125+
let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams);
126+
127+
let round_keys = sks.key_expansion_256(&d_key, &streams);
128+
129+
println!("{bench_id}");
130+
bench_group.bench_function(&bench_id, |b| {
131+
b.iter(|| {
132+
black_box(sks.aes_256_encrypt(
133+
&d_iv,
134+
&round_keys,
135+
0,
136+
NUM_AES_INPUTS,
137+
SBOX_PARALLELISM,
138+
&streams,
139+
));
140+
})
141+
});
142+
143+
write_to_json::<u64, _>(
144+
&bench_id,
145+
atomic_param,
146+
param.name(),
147+
"aes_256_encryption",
148+
&OperatorType::Atomic,
149+
aes_block_op_bit_size,
150+
vec![atomic_param.message_modulus().0.ilog2(); aes_block_op_bit_size as usize],
151+
);
152+
} else {
153+
println!("{} skipped: Not enough memory in GPU", bench_id);
154+
}
141155
}
142156

143157
bench_group.finish();

0 commit comments

Comments
 (0)