Skip to content

Commit 7e81bfb

Browse files
committed
fix: replace faiss with torch + sklearn
* Use a torch implementation with an sklearn fallback

Signed-off-by: Matt Kornfield <mkornfield@nvidia.com>
1 parent 78102fd commit 7e81bfb

12 files changed

Lines changed: 949 additions & 127 deletions

pyproject.toml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ engine = [
9898
"pandas>=2.1.3, <3",
9999
"plotly",
100100
"ratelimit",
101+
"scikit-learn",
101102
"range_regex>=0.1.0",
102103
"tenacity==9.1.4",
103104
"tiktoken>=0.7.0,<1.0",
@@ -119,7 +120,6 @@ cpu = [
119120
"bitsandbytes==0.49.1",
120121
"flashinfer-python==0.6.1; sys_platform=='linux'",
121122
"flashinfer-cubin==0.6.1; sys_platform=='linux'",
122-
"faiss-cpu==1.13.2",
123123
"gliner",
124124
"kernels>=0.12.1",
125125
"peft",
@@ -141,7 +141,6 @@ cpu = [
141141
cu128 = [
142142
"accelerate",
143143
"bitsandbytes==0.49.1",
144-
"faiss-gpu-cu12==1.13.2; sys_platform == 'linux'",
145144
"flashinfer-python==0.6.1; sys_platform == 'linux'",
146145
"flashinfer-cubin==0.6.1; sys_platform == 'linux'",
147146
"flashinfer-jit-cache==0.6.1+cu128; sys_platform == 'linux'",
@@ -225,10 +224,9 @@ constraint-dependencies = ["torch==2.9.1", "regex==2025.07.34", "pandas<3"]
225224
flashinfer-jit-cache = [
226225
{ index = "flashinfer-jit-cache", marker = "sys_platform=='linux'", extra="cu128"},
227226
]
228-
nvidia-cublas-cu12 = [
227+
nvidia-cublas-cu12 = [
229228
{ index = "pytorch-cu128" },
230229
]
231-
232230
nvidia-cuda-cupti-cu12 = [
233231
{ index = "pytorch-cu128" },
234232
]
@@ -297,6 +295,7 @@ name = "nvidia-pypi-public"
297295
url = "https://pypi.nvidia.com"
298296
explicit = true
299297

298+
300299
[build-system]
301300
requires = ["hatchling", "uv-dynamic-versioning"]
302301
build-backend = "hatchling.build"

pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ markers =
2222
smollm2: SmolLM2 Hub download tests (used by Makefile for process isolation)
2323
unsloth: Unsloth backend tests (process-isolated from DP tests)
2424
noautouse: Marker to skip autouse fixtures for specific tests
25+
benchmark: Performance benchmark tests
2526

2627
# Note: Unit tests (testing single classes/functions with no infrastructure dependencies)
2728
# do not need markers and are the default test type.

src/nemo_safe_synthesizer/evaluation/components/attribute_inference_protection.py

Lines changed: 14 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,9 @@
2727
from ..components.component import Component
2828
from ..data_model.evaluation_dataset import EvaluationDataset
2929
from ..data_model.evaluation_score import EvaluationScore, PrivacyGrade
30+
from ..nearest_neighbors import NearestNeighborSearch
3031
from . import multi_modal_figures as figures
3132

32-
faiss_available = False
33-
try:
34-
import faiss
35-
36-
faiss_available = True
37-
except (ImportError, ModuleNotFoundError):
38-
pass
39-
4033
logger = get_logger(__name__)
4134

4235

@@ -77,10 +70,6 @@ def from_evaluation_dataset(
7770
evaluation_dataset: EvaluationDataset, config: SafeSynthesizerParameters | None = None
7871
) -> AttributeInferenceProtection:
7972
"""Run the attribute inference attack and return the protection score."""
80-
if not faiss_available:
81-
logger.info("FAISS is not available, skipping Attribute Inference Attack.")
82-
return AttributeInferenceProtection(score=EvaluationScore())
83-
8473
quasi_identifier_count = config.evaluation.quasi_identifier_count if config else QUASI_IDENTIFIER_COUNT
8574

8675
score, col_accuracy_df = AttributeInferenceProtection._aia(
@@ -276,23 +265,17 @@ def _get_synth_nn(
276265
df_train_use, df_synth_use
277266
)
278267

279-
# If all tabular, just use FAISS to get NN
268+
# If all tabular, use nearest neighbor search (torch CUDA or sklearn CPU fallback)
280269
if len(text_columns) == 0:
281-
# Create the faiss index on the synthetic data
282-
dim = df_synth_norm.shape[1]
283-
index = faiss.IndexFlatL2(dim) # ty: ignore[possibly-unbound-attribute]
284-
285-
# This usage matches documentation. Specifying n= and x= parameters as
286-
# the type annotation for IndexFlatL2.add suggests seems unnecessary, possibly related
287-
# to swig handling that ty is not aware of. Similar for other faiss calls below
288-
# which are using swig-generated code.
289-
index.add(np.float32(np.ascontiguousarray(np.array(df_synth_norm)))) # ty: ignore[missing-argument]
270+
# Create the nearest neighbors index on the synthetic data
271+
nn = NearestNeighborSearch(n_neighbors=k)
272+
nn.fit(np.ascontiguousarray(np.array(df_synth_norm)).astype(np.float32))
290273

291274
# Get nearest neighbors to this attack record
292-
_, indexes = index.search(np.float32(np.ascontiguousarray(np.array(df_train_norm))), k) # ty: ignore[missing-argument]
275+
_, indexes = nn.kneighbors(np.ascontiguousarray(np.array(df_train_norm)).astype(np.float32))
293276
synth_rows = pd.DataFrame()
294-
for index in indexes:
295-
synth_rows = pd.concat([synth_rows, df_synth.iloc[index].copy()])
277+
for idx_row in indexes:
278+
synth_rows = pd.concat([synth_rows, df_synth.iloc[idx_row].copy()])
296279
return synth_rows
297280

298281
# If all text, just use Sentence Transformer to get NN
@@ -339,23 +322,20 @@ def _get_synth_nn(
339322
corpus_ids.append(corpus_id)
340323
synth_NN = pd.concat([synth_NN, pd.DataFrame([df_synth_norm.iloc[int(corpus_id)]])], ignore_index=True)
341324

342-
# Now get the tabular similarity for these 1000 NN
343-
344-
dim = synth_NN.shape[1]
345-
index = faiss.IndexFlatL2(dim) # ty: ignore[possibly-unbound-attribute]
346-
index.add(np.float32(np.ascontiguousarray(np.array(synth_NN)))) # ty: ignore[missing-argument]
347-
dists, indexes = index.search(np.float32(np.ascontiguousarray(np.array(df_train_norm))), search_synth_k) # ty: ignore[missing-argument]
325+
# Now get the tabular similarity for these 1000 NN using nearest neighbor search
326+
nn = NearestNeighborSearch(n_neighbors=search_synth_k)
327+
nn.fit(np.ascontiguousarray(np.array(synth_NN)).astype(np.float32))
328+
dists, indexes = nn.kneighbors(np.ascontiguousarray(np.array(df_train_norm)).astype(np.float32))
348329
# Scale the Euclidean distance to [0,1]
349-
dists = np.sqrt(dists)
350330
max_dist = np.amax(dists)
351331
if max_dist > 0:
352332
dist_scaled = dists / max_dist
353333
else:
354334
dist_scaled = dists
355335
tab_dist = {}
356336
for i in range(search_synth_k):
357-
index = indexes[0][i]
358-
tab_dist[index] = dist_scaled[0][i]
337+
idx = indexes[0][i]
338+
tab_dist[idx] = dist_scaled[0][i]
359339

360340
# Now get the hybrid distance
361341

src/nemo_safe_synthesizer/evaluation/components/membership_inference_protection.py

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,10 @@
2121
from ...evaluation.components.component import Component
2222
from ...evaluation.data_model.evaluation_dataset import EvaluationDataset
2323
from ...evaluation.data_model.evaluation_score import EvaluationScore, PrivacyGrade
24+
from ...evaluation.nearest_neighbors import NearestNeighborSearch
2425
from ...observability import get_logger
2526
from . import multi_modal_figures as figures
2627

27-
faiss_available = False
28-
try:
29-
import faiss
30-
31-
faiss_available = True
32-
except (ImportError, ModuleNotFoundError):
33-
pass
34-
35-
3628
logger = get_logger(__name__)
3729

3830

@@ -79,9 +71,6 @@ def from_evaluation_dataset(
7971
evaluation_dataset: EvaluationDataset, config: SafeSynthesizerParameters | None = None
8072
) -> MembershipInferenceProtection:
8173
"""Run the membership inference attack and return the protection score."""
82-
if not faiss_available:
83-
return MembershipInferenceProtection(score=EvaluationScore())
84-
8574
score, attack_sum_df, tps_values, fps_values = MembershipInferenceProtection.mia(
8675
df_train=evaluation_dataset.reference,
8776
df_synth=evaluation_dataset.output,
@@ -249,7 +238,7 @@ def _compute_mia(
249238
df_train_norm: pd.DataFrame,
250239
df_test_norm: pd.DataFrame,
251240
df_synth_norm: pd.DataFrame,
252-
index: faiss.IndexFlatL2 | None, # ty: ignore[possibly-unbound-attribute]
241+
nn_index: NearestNeighborSearch | None,
253242
run: int,
254243
text_cnt: int,
255244
tabular_cnt: int,
@@ -263,14 +252,14 @@ def _compute_mia(
263252
264253
Builds an attack dataset from a slice of training rows mixed with
265254
test rows, computes nearest-neighbor distances to the synthetic
266-
data (text via semantic search, tabular via FAISS L2), and
255+
data (text via semantic search, tabular via L2 nearest neighbor), and
267256
classifies each record as member or non-member.
268257
269258
Args:
270259
df_train_norm: Normalized training dataframe.
271260
df_test_norm: Normalized holdout (test) dataframe.
272261
df_synth_norm: Normalized synthetic dataframe.
273-
index: Pre-built FAISS L2 index over the tabular columns of
262+
nn_index: Pre-built NearestNeighborSearch index over the tabular columns of
274263
the synthetic data, or ``None`` if no tabular columns exist.
275264
run: Zero-based run index controlling which training slice to use.
276265
text_cnt: Number of text columns in the dataset.
@@ -334,18 +323,16 @@ def _compute_mia(
334323
attacker_data_tabular = real_data.copy()
335324
k = 1
336325

337-
if index is None:
338-
raise RuntimeError("faiss index not provided for MIA calculation when expected.")
326+
if nn_index is None:
327+
raise RuntimeError("Nearest neighbor index not provided for MIA calculation when expected.")
339328

340-
# This usage matches documentation despite type annotation for
341-
# IndexFlatL2.search, possibly related to swig handling that ty is
342-
# not aware of. Similar for other calls for faiss indexes.
343-
dists, indices = index.search(
344-
np.float32(np.ascontiguousarray(np.array(attacker_data_tabular))),
345-
len(df_synth_norm),
346-
) # ty: ignore[missing-argument]
329+
# Use nearest neighbor search (torch GPU or sklearn CPU fallback) for distance calculation
330+
dists, indices = nn_index.kneighbors(
331+
np.ascontiguousarray(np.array(attacker_data_tabular)).astype(np.float32),
332+
n_neighbors=int(k),
333+
)
347334
# Scale the Euclidean distance to [0,1]
348-
dists = np.sqrt(dists)
335+
# NearestNeighborSearch.kneighbors() returns L2 distance directly, not squared
349336
max_dist = np.amax(dists)
350337
if max_dist > 0:
351338
dist_scaled = dists / max_dist
@@ -555,15 +542,14 @@ def mia(
555542
df_train_norm, df_test_norm, df_synth_norm = MembershipInferenceProtection._normalize_onehot(
556543
df_train_use, df_test, df_synth
557544
)
558-
# Create the faiss index on the synthetic tabular data
559-
dim = df_synth_norm.shape[1]
560-
index = faiss.IndexFlatL2(dim) # ty: ignore[possibly-unbound-attribute]
561-
index.add(np.float32(np.ascontiguousarray(np.array(df_synth_norm)))) # ty: ignore[missing-argument]
545+
# Create nearest neighbor index on the synthetic tabular data (torch GPU or sklearn CPU fallback)
546+
nn_index = NearestNeighborSearch(n_neighbors=len(df_synth_norm))
547+
nn_index.fit(np.ascontiguousarray(np.array(df_synth_norm)).astype(np.float32))
562548
else:
563549
df_train_norm = pd.DataFrame()
564550
df_test_norm = pd.DataFrame()
565551
df_synth_norm = pd.DataFrame()
566-
index = None
552+
nn_index = None
567553

568554
# Create embeddings for text fields and combine the normalized tabular and the
569555
# new text embeddings into one dataframe.
@@ -588,7 +574,7 @@ def mia(
588574
df_train_norm,
589575
df_test_norm,
590576
df_synth_norm,
591-
index,
577+
nn_index,
592578
i,
593579
text_cnt,
594580
tabular_cnt,

0 commit comments

Comments
 (0)