@@ -3,7 +3,7 @@ pub mod cuda {
     use benchmark::params_aliases::BENCH_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
     use benchmark::utilities::{write_to_json, OperatorType};
     use criterion::{black_box, Criterion};
-    use tfhe::core_crypto::gpu::CudaStreams;
+    use tfhe::core_crypto::gpu::{check_valid_cuda_malloc, CudaStreams};
     use tfhe::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
     use tfhe::integer::gpu::CudaServerKey;
     use tfhe::integer::keycache::KEY_CACHE;
@@ -106,38 +106,52 @@ pub mod cuda {
             let sks = CudaServerKey::new(&cpu_cks, &streams);
             let cks = RadixClientKey::from((cpu_cks, 1));

-            let ct_key = cks.encrypt_2u128_for_aes_ctr_256(key_hi, key_lo);
-
-            let ct_iv = cks.encrypt_u128_for_aes_ctr(iv);
-
-            let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams);
-            let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams);
-
-            let round_keys = sks.key_expansion_256(&d_key, &streams);
-
-            println!("{bench_id}");
-            bench_group.bench_function(&bench_id, |b| {
-                b.iter(|| {
-                    black_box(sks.aes_256_encrypt(
-                        &d_iv,
-                        &round_keys,
-                        0,
-                        NUM_AES_INPUTS,
-                        SBOX_PARALLELISM,
-                        &streams,
-                    ));
-                })
-            });
-
-            write_to_json::<u64, _>(
-                &bench_id,
-                atomic_param,
-                param.name(),
-                "aes_256_encryption",
-                &OperatorType::Atomic,
-                aes_block_op_bit_size,
-                vec![atomic_param.message_modulus().0.ilog2(); aes_block_op_bit_size as usize],
-            );
+            //
+            // Memory checks
+            //
+            let gpu_index = streams.gpu_indexes[0];
+
+            let key_expansion_size = sks.get_key_expansion_256_size_on_gpu(&streams);
+            let aes_encrypt_size =
+                sks.get_aes_encrypt_size_on_gpu(NUM_AES_INPUTS, SBOX_PARALLELISM, &streams);
+
+            if check_valid_cuda_malloc(key_expansion_size, gpu_index)
+                && check_valid_cuda_malloc(aes_encrypt_size, gpu_index)
+            {
+                let ct_key = cks.encrypt_2u128_for_aes_ctr_256(key_hi, key_lo);
+                let ct_iv = cks.encrypt_u128_for_aes_ctr(iv);
+
+                let d_key = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_key, &streams);
+                let d_iv = CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_iv, &streams);
+
+                let round_keys = sks.key_expansion_256(&d_key, &streams);
+
+                println!("{bench_id}");
+                bench_group.bench_function(&bench_id, |b| {
+                    b.iter(|| {
+                        black_box(sks.aes_256_encrypt(
+                            &d_iv,
+                            &round_keys,
+                            0,
+                            NUM_AES_INPUTS,
+                            SBOX_PARALLELISM,
+                            &streams,
+                        ));
+                    })
+                });
+
+                write_to_json::<u64, _>(
+                    &bench_id,
+                    atomic_param,
+                    param.name(),
+                    "aes_256_encryption",
+                    &OperatorType::Atomic,
+                    aes_block_op_bit_size,
+                    vec![atomic_param.message_modulus().0.ilog2(); aes_block_op_bit_size as usize],
+                );
+            } else {
+                println!("{} skipped: Not enough memory in GPU", bench_id);
+            }
         }

         bench_group.finish();
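The pattern this diff introduces is: query the expected device allocation sizes up front, then gate the whole benchmark body on `check_valid_cuda_malloc` so an oversized case is reported as skipped instead of aborting on a CUDA out-of-memory error. A minimal sketch of that gating logic, assuming the `tfhe::core_crypto::gpu` items used above (`check_valid_cuda_malloc`, `CudaStreams`, `streams.gpu_indexes`) behave as in the diff; the helper name `run_if_gpu_memory_allows` and its signature are hypothetical, not part of the change:

```rust
use tfhe::core_crypto::gpu::{check_valid_cuda_malloc, CudaStreams};

/// Hypothetical helper (not part of the diff): run `body` only when every
/// requested allocation fits on the first GPU attached to `streams`,
/// otherwise report the benchmark as skipped.
fn run_if_gpu_memory_allows(
    bench_id: &str,
    streams: &CudaStreams,
    required_sizes: &[u64],
    body: impl FnOnce(),
) {
    // Same device the benchmark targets: the first index of the stream set.
    let gpu_index = streams.gpu_indexes[0];

    if required_sizes
        .iter()
        .all(|&size| check_valid_cuda_malloc(size, gpu_index))
    {
        body();
    } else {
        println!("{bench_id} skipped: Not enough memory in GPU");
    }
}
```

In the benchmark above, the two sizes returned by `get_key_expansion_256_size_on_gpu` and `get_aes_encrypt_size_on_gpu` would be the `required_sizes`, and the encryption, key expansion, `bench_function` call, and `write_to_json` would form the `body` closure.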