From b67d4ab4f7bd35e98f694ec73f17d7f5d798dace Mon Sep 17 00:00:00 2001 From: "Mayeul@Zama" <69792125+mayeul-zama@users.noreply.github.com> Date: Wed, 22 Jan 2025 11:37:08 +0100 Subject: [PATCH] fix(shortint): add ciphertext_modulus_after_packing_ks to compression parameters --- .../algorithms/lwe_packing_keyswitch.rs | 29 +++++----- tfhe/src/integer/gpu/client_key/radix.rs | 2 +- .../parameters/list_compression.rs | 57 ++++++++++++++++++- .../compressed_server_keys.rs | 2 +- .../shortint/list_compression/compression.rs | 37 ++++++++---- .../shortint/list_compression/server_keys.rs | 7 ++- .../shortint/parameters/list_compression.rs | 3 + .../parameters/v0_10/list_compression.rs | 3 +- 8 files changed, 104 insertions(+), 36 deletions(-) diff --git a/tfhe/src/core_crypto/algorithms/lwe_packing_keyswitch.rs b/tfhe/src/core_crypto/algorithms/lwe_packing_keyswitch.rs index 2eaf4e4b4b..5df4530a69 100644 --- a/tfhe/src/core_crypto/algorithms/lwe_packing_keyswitch.rs +++ b/tfhe/src/core_crypto/algorithms/lwe_packing_keyswitch.rs @@ -135,12 +135,11 @@ pub fn keyswitch_lwe_ciphertext_into_glwe_ciphertext, +} + +impl Upgrade for CompressionParametersV0 { + type Error = Infallible; + + fn upgrade(self) -> Result { + let Self { + br_level, + br_base_log, + packing_ks_level, + packing_ks_base_log, + packing_ks_polynomial_size, + packing_ks_glwe_dimension, + lwe_per_glwe, + storage_log_modulus, + packing_ks_key_noise_distribution, + } = self; + + Ok(CompressionParameters { + br_level, + br_base_log, + packing_ks_level, + packing_ks_base_log, + packing_ks_polynomial_size, + packing_ks_glwe_dimension, + lwe_per_glwe, + storage_log_modulus, + packing_ks_key_noise_distribution, + ciphertext_modulus_after_packing_ks: CiphertextModulus::new_native(), + }) + } +} #[derive(VersionsDispatch)] pub enum CompressionParametersVersions { - V0(CompressionParameters), + V0(CompressionParametersV0), + V1(CompressionParameters), } diff --git a/tfhe/src/shortint/list_compression/compressed_server_keys.rs b/tfhe/src/shortint/list_compression/compressed_server_keys.rs index f93770827c..41a059f131 100644 --- a/tfhe/src/shortint/list_compression/compressed_server_keys.rs +++ b/tfhe/src/shortint/list_compression/compressed_server_keys.rs @@ -103,7 +103,7 @@ impl ClientKey { params.packing_ks_base_log, params.packing_ks_level, params.packing_ks_key_noise_distribution, - self.parameters.ciphertext_modulus(), + params.ciphertext_modulus_after_packing_ks, &mut engine.seeder, ) }); diff --git a/tfhe/src/shortint/list_compression/compression.rs b/tfhe/src/shortint/list_compression/compression.rs index 8e8a6134e3..9ae4cea4bd 100644 --- a/tfhe/src/shortint/list_compression/compression.rs +++ b/tfhe/src/shortint/list_compression/compression.rs @@ -2,8 +2,9 @@ use super::{CompressionKey, DecompressionKey}; use crate::core_crypto::prelude::compressed_modulus_switched_glwe_ciphertext::CompressedModulusSwitchedGlweCiphertext; use crate::core_crypto::prelude::{ extract_lwe_sample_from_glwe_ciphertext, - par_keyswitch_lwe_ciphertext_list_and_pack_in_glwe_ciphertext, CiphertextCount, GlweCiphertext, - LweCiphertext, LweCiphertextCount, LweCiphertextList, MonomialDegree, + par_keyswitch_lwe_ciphertext_list_and_pack_in_glwe_ciphertext, CiphertextCount, + CiphertextModulus, GlweCiphertext, LweCiphertext, LweCiphertextCount, LweCiphertextList, + MonomialDegree, }; use crate::shortint::ciphertext::CompressedCiphertextList; use crate::shortint::engine::ShortintEngine; @@ -11,7 +12,7 @@ use crate::shortint::parameters::{CarryModulus, MessageModulus, NoiseLevel}; use crate::shortint::server_key::{ apply_programmable_bootstrap, generate_lookup_table_with_encoding, unchecked_scalar_mul_assign, }; -use crate::shortint::{Ciphertext, CiphertextModulus, MaxNoiseLevel}; +use crate::shortint::{Ciphertext, MaxNoiseLevel}; use rayon::iter::ParallelIterator; use rayon::slice::ParallelSlice; @@ -25,7 +26,9 @@ impl CompressionKey { let lwe_pksk = &self.packing_key_switching_key; let polynomial_size = lwe_pksk.output_polynomial_size(); - let ciphertext_modulus = lwe_pksk.ciphertext_modulus(); + + let out_ciphertext_modulus = lwe_pksk.ciphertext_modulus(); + let glwe_size = lwe_pksk.output_glwe_size(); let lwe_size = lwe_pksk.input_key_lwe_dimension().to_lwe_size(); @@ -43,6 +46,7 @@ impl CompressionKey { let message_modulus = first_ct.message_modulus; let carry_modulus = first_ct.carry_modulus; let pbs_order = first_ct.pbs_order; + let in_ciphertext_modulus = first_ct.ct.ciphertext_modulus(); assert!( message_modulus.0 <= carry_modulus.0, @@ -86,6 +90,12 @@ impl CompressionKey { "All ciphertexts do not have the same pbs order" ); + assert_eq!( + in_ciphertext_modulus, + ct.ct.ciphertext_modulus(), + "All ciphertexts do not have the same ciphertext modulus" + ); + let mut ct = ct.clone(); let max_noise_level = MaxNoiseLevel::new((ct.noise_level() * message_modulus.0).get()); @@ -94,12 +104,12 @@ impl CompressionKey { list.extend(ct.ct.as_ref()); } - let list = LweCiphertextList::from_container(list, lwe_size, ciphertext_modulus); + let list = LweCiphertextList::from_container(list, lwe_size, in_ciphertext_modulus); let bodies_count = LweCiphertextCount(ct_list.len()); let mut out = - GlweCiphertext::new(0, glwe_size, polynomial_size, ciphertext_modulus); + GlweCiphertext::new(0, glwe_size, polynomial_size, out_ciphertext_modulus); par_keyswitch_lwe_ciphertext_list_and_pack_in_glwe_ciphertext( lwe_pksk, &list, &mut out, @@ -120,7 +130,7 @@ impl CompressionKey { pbs_order, lwe_per_glwe, count, - ciphertext_modulus, + ciphertext_modulus: out_ciphertext_modulus, } } } @@ -147,6 +157,9 @@ impl DecompressionKey { ))); } + let in_ciphertext_modulus = packed.ciphertext_modulus; + let out_ciphertext_modulus = CiphertextModulus::new_native(); + let encryption_cleartext_modulus = packed.message_modulus.0 * packed.carry_modulus.0; // We multiply by message_modulus during compression so the actual modulus for the // compression is smaller @@ -157,7 +170,7 @@ impl DecompressionKey { let decompression_rescale = generate_lookup_table_with_encoding( self.out_glwe_size(), self.out_polynomial_size(), - packed.ciphertext_modulus, + out_ciphertext_modulus, // Input moduli are the effective compression ones effective_compression_message_modulus, effective_compression_carry_modulus, @@ -172,7 +185,6 @@ impl DecompressionKey { ); let polynomial_size = packed.modulus_switched_glwe_ciphertext_list[0].polynomial_size(); - let ciphertext_modulus = packed.ciphertext_modulus; let glwe_dimension = packed.modulus_switched_glwe_ciphertext_list[0].glwe_dimension(); let lwe_per_glwe = packed.lwe_per_glwe.0; @@ -187,7 +199,7 @@ impl DecompressionKey { let monomial_degree = MonomialDegree(index % lwe_per_glwe); - let mut intermediate_lwe = LweCiphertext::new(0, lwe_size, ciphertext_modulus); + let mut intermediate_lwe = LweCiphertext::new(0, lwe_size, in_ciphertext_modulus); extract_lwe_sample_from_glwe_ciphertext( &packed_glwe, @@ -198,7 +210,7 @@ impl DecompressionKey { let mut output_br = LweCiphertext::new( 0, self.blind_rotate_key.output_lwe_dimension().to_lwe_size(), - ciphertext_modulus, + out_ciphertext_modulus, ); ShortintEngine::with_thread_local_mut(|engine| { @@ -231,6 +243,7 @@ impl DecompressionKey { #[cfg(test)] mod test { use super::*; + use crate::shortint::list_compression::CompressionPrivateKeys; use crate::shortint::parameters::list_compression::COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64; use crate::shortint::{gen_keys, ClientKey}; @@ -241,7 +254,7 @@ mod test { // Generate the client key and the server key: let (cks, _sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); - let private_compression_key: crate::shortint::list_compression::CompressionPrivateKeys = + let private_compression_key: CompressionPrivateKeys = cks.new_compression_private_key(COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64); let (compression_key, decompression_key) = diff --git a/tfhe/src/shortint/list_compression/server_keys.rs b/tfhe/src/shortint/list_compression/server_keys.rs index 44d15ad500..9fba686673 100644 --- a/tfhe/src/shortint/list_compression/server_keys.rs +++ b/tfhe/src/shortint/list_compression/server_keys.rs @@ -65,7 +65,7 @@ impl ClientKey { params.packing_ks_base_log, params.packing_ks_level, params.packing_ks_key_noise_distribution, - self.parameters.ciphertext_modulus(), + params.ciphertext_modulus_after_packing_ks, &mut engine.encryption_generator, ) }); @@ -116,6 +116,7 @@ pub struct CompressionConformanceParameters { pub packing_ks_base_log: DecompositionBaseLog, pub packing_ks_polynomial_size: PolynomialSize, pub packing_ks_glwe_dimension: GlweDimension, + pub ciphertext_modulus_after_packing_ks: CiphertextModulus, pub lwe_per_glwe: LweCiphertextCount, pub storage_log_modulus: CiphertextModulusLog, pub uncompressed_polynomial_size: PolynomialSize, @@ -132,6 +133,8 @@ impl From<(PBSParameters, CompressionParameters)> for CompressionConformancePara packing_ks_base_log: compression_params.packing_ks_base_log, packing_ks_polynomial_size: compression_params.packing_ks_polynomial_size, packing_ks_glwe_dimension: compression_params.packing_ks_glwe_dimension, + ciphertext_modulus_after_packing_ks: compression_params + .ciphertext_modulus_after_packing_ks, lwe_per_glwe: compression_params.lwe_per_glwe, storage_log_modulus: compression_params.storage_log_modulus, uncompressed_polynomial_size: pbs_params.polynomial_size(), @@ -159,7 +162,7 @@ impl ParameterSetConformant for CompressionKey { .to_equivalent_lwe_dimension(parameter_set.uncompressed_polynomial_size), output_glwe_size: parameter_set.packing_ks_glwe_dimension.to_glwe_size(), output_polynomial_size: parameter_set.packing_ks_polynomial_size, - ciphertext_modulus: parameter_set.cipherext_modulus, + ciphertext_modulus: parameter_set.ciphertext_modulus_after_packing_ks, }; packing_key_switching_key.is_conformant(¶ms) diff --git a/tfhe/src/shortint/parameters/list_compression.rs b/tfhe/src/shortint/parameters/list_compression.rs index 6e34507b11..d575009028 100644 --- a/tfhe/src/shortint/parameters/list_compression.rs +++ b/tfhe/src/shortint/parameters/list_compression.rs @@ -1,5 +1,6 @@ use tfhe_versionable::Versionize; +use super::CiphertextModulus; use crate::core_crypto::prelude::{CiphertextModulusLog, LweCiphertextCount}; use crate::shortint::backward_compatibility::parameters::list_compression::CompressionParametersVersions; use crate::shortint::parameters::{ @@ -20,6 +21,7 @@ pub struct CompressionParameters { pub lwe_per_glwe: LweCiphertextCount, pub storage_log_modulus: CiphertextModulusLog, pub packing_ks_key_noise_distribution: DynamicDistribution, + pub ciphertext_modulus_after_packing_ks: CiphertextModulus, } pub const COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64: CompressionParameters = @@ -36,4 +38,5 @@ pub const V0_11_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64: CompressionPa lwe_per_glwe: LweCiphertextCount(256), storage_log_modulus: CiphertextModulusLog(12), packing_ks_key_noise_distribution: DynamicDistribution::new_t_uniform(43), + ciphertext_modulus_after_packing_ks: CiphertextModulus::new_native(), }; diff --git a/tfhe/src/shortint/parameters/v0_10/list_compression.rs b/tfhe/src/shortint/parameters/v0_10/list_compression.rs index 7f47eecb64..1ff009229a 100644 --- a/tfhe/src/shortint/parameters/v0_10/list_compression.rs +++ b/tfhe/src/shortint/parameters/v0_10/list_compression.rs @@ -1,4 +1,4 @@ -use crate::core_crypto::prelude::{CiphertextModulusLog, LweCiphertextCount}; +use crate::core_crypto::prelude::{CiphertextModulus, CiphertextModulusLog, LweCiphertextCount}; use crate::shortint::parameters::{ CompressionParameters, DecompositionBaseLog, DecompositionLevelCount, DynamicDistribution, GlweDimension, PolynomialSize, @@ -15,4 +15,5 @@ pub const V0_10_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64: CompressionPa lwe_per_glwe: LweCiphertextCount(256), storage_log_modulus: CiphertextModulusLog(12), packing_ks_key_noise_distribution: DynamicDistribution::new_t_uniform(42), + ciphertext_modulus_after_packing_ks: CiphertextModulus::new_native(), };