2121from ...evaluation .components .component import Component
2222from ...evaluation .data_model .evaluation_dataset import EvaluationDataset
2323from ...evaluation .data_model .evaluation_score import EvaluationScore , PrivacyGrade
24+ from ...evaluation .nearest_neighbors import NearestNeighborSearch
2425from ...observability import get_logger
2526from . import multi_modal_figures as figures
2627
27- faiss_available = False
28- try :
29- import faiss
30-
31- faiss_available = True
32- except (ImportError , ModuleNotFoundError ):
33- pass
34-
35-
3628logger = get_logger (__name__ )
3729
3830
@@ -79,9 +71,6 @@ def from_evaluation_dataset(
7971 evaluation_dataset : EvaluationDataset , config : SafeSynthesizerParameters | None = None
8072 ) -> MembershipInferenceProtection :
8173 """Run the membership inference attack and return the protection score."""
82- if not faiss_available :
83- return MembershipInferenceProtection (score = EvaluationScore ())
84-
8574 score , attack_sum_df , tps_values , fps_values = MembershipInferenceProtection .mia (
8675 df_train = evaluation_dataset .reference ,
8776 df_synth = evaluation_dataset .output ,
@@ -249,7 +238,7 @@ def _compute_mia(
249238 df_train_norm : pd .DataFrame ,
250239 df_test_norm : pd .DataFrame ,
251240 df_synth_norm : pd .DataFrame ,
252- index : faiss . IndexFlatL2 | None , # ty: ignore[possibly-unbound-attribute]
241+ nn_index : NearestNeighborSearch | None ,
253242 run : int ,
254243 text_cnt : int ,
255244 tabular_cnt : int ,
@@ -263,14 +252,14 @@ def _compute_mia(
263252
264253 Builds an attack dataset from a slice of training rows mixed with
265254 test rows, computes nearest-neighbor distances to the synthetic
266- data (text via semantic search, tabular via FAISS L2 ), and
255+ data (text via semantic search, tabular via L2 nearest neighbor ), and
267256 classifies each record as member or non-member.
268257
269258 Args:
270259 df_train_norm: Normalized training dataframe.
271260 df_test_norm: Normalized holdout (test) dataframe.
272261 df_synth_norm: Normalized synthetic dataframe.
273- index : Pre-built FAISS L2 index over the tabular columns of
262+ nn_index : Pre-built NearestNeighborSearch index over the tabular columns of
274263 the synthetic data, or ``None`` if no tabular columns exist.
275264 run: Zero-based run index controlling which training slice to use.
276265 text_cnt: Number of text columns in the dataset.
@@ -334,18 +323,16 @@ def _compute_mia(
334323 attacker_data_tabular = real_data .copy ()
335324 k = 1
336325
337- if index is None :
338- raise RuntimeError ("faiss index not provided for MIA calculation when expected." )
326+ if nn_index is None :
327+ raise RuntimeError ("Nearest neighbor index not provided for MIA calculation when expected." )
339328
340- # This usage matches documentation despite type annotation for
341- # IndexFlatL2.search, possibly related to swig handling that ty is
342- # not aware of. Similar for other calls for faiss indexes.
343- dists , indices = index .search (
344- np .float32 (np .ascontiguousarray (np .array (attacker_data_tabular ))),
345- len (df_synth_norm ),
346- ) # ty: ignore[missing-argument]
329+ # Use nearest neighbor search (torch GPU or sklearn CPU fallback) for distance calculation
330+ dists , indices = nn_index .kneighbors (
331+ np .ascontiguousarray (np .array (attacker_data_tabular )).astype (np .float32 ),
332+ n_neighbors = int (k ),
333+ )
347334 # Scale the Euclidean distance to [0,1]
348- dists = np . sqrt ( dists )
335+ # NearestNeighborSearch.kneighbors() returns L2 distance directly, not squared
349336 max_dist = np .amax (dists )
350337 if max_dist > 0 :
351338 dist_scaled = dists / max_dist
@@ -555,15 +542,14 @@ def mia(
555542 df_train_norm , df_test_norm , df_synth_norm = MembershipInferenceProtection ._normalize_onehot (
556543 df_train_use , df_test , df_synth
557544 )
558- # Create the faiss index on the synthetic tabular data
559- dim = df_synth_norm .shape [1 ]
560- index = faiss .IndexFlatL2 (dim ) # ty: ignore[possibly-unbound-attribute]
561- index .add (np .float32 (np .ascontiguousarray (np .array (df_synth_norm )))) # ty: ignore[missing-argument]
545+ # Create nearest neighbor index on the synthetic tabular data (torch GPU or sklearn CPU fallback)
546+ nn_index = NearestNeighborSearch (n_neighbors = len (df_synth_norm ))
547+ nn_index .fit (np .ascontiguousarray (np .array (df_synth_norm )).astype (np .float32 ))
562548 else :
563549 df_train_norm = pd .DataFrame ()
564550 df_test_norm = pd .DataFrame ()
565551 df_synth_norm = pd .DataFrame ()
566- index = None
552+ nn_index = None
567553
568554 # Create embeddings for text fields and combine the normalized tabular and the
569555 # new text embeddings into one dataframe.
@@ -588,7 +574,7 @@ def mia(
588574 df_train_norm ,
589575 df_test_norm ,
590576 df_synth_norm ,
591- index ,
577+ nn_index ,
592578 i ,
593579 text_cnt ,
594580 tabular_cnt ,
0 commit comments