4141
4242
4343class AttributeInferenceProtection (Component ):
44+ """Attribute Inference Protection privacy metric.
45+
46+ Simulates an attribute inference attack: given quasi-identifier columns,
47+ can an adversary use synthetic nearest-neighbors to predict the remaining
48+ attributes of a training record? A higher score indicates better
49+ protection (lower prediction accuracy).
50+
51+ See Also:
52+ https://arxiv.org/abs/2501.03941 -- Synthetic Data Privacy Metrics.
53+ """
54+
4455 name : str = Field (default = "Attribute Inference Protection" )
45- col_accuracy_df : pd .DataFrame | None = Field (default = None )
56+ col_accuracy_df : pd .DataFrame | None = Field (
57+ default = None , description = "Per-column prediction risk scores and grades."
58+ )
4659
4760 model_config = ConfigDict (arbitrary_types_allowed = True )
4861
4962 @cached_property
5063 def jinja_context (self ) -> dict [str , str ]:
64+ """Template context with the attribute-inference bar chart figure."""
5165 d = super ().jinja_context
5266 d ["anchor_link" ] = "#aia"
5367 if self .col_accuracy_df is not None and not self .col_accuracy_df .empty :
@@ -62,6 +76,7 @@ def jinja_context(self) -> dict[str, str]:
6276 def from_evaluation_dataset (
6377 evaluation_dataset : EvaluationDataset , config : SafeSynthesizerParameters | None = None
6478 ) -> AttributeInferenceProtection :
79+ """Run the attribute inference attack and return the protection score."""
6580 if not faiss_available :
6681 logger .info ("FAISS is not available, skipping Attribute Inference Attack." )
6782 return AttributeInferenceProtection (score = EvaluationScore ())
@@ -180,7 +195,7 @@ def _is_really_categorical(column: str) -> bool:
180195
181196 @staticmethod
182197 def _divide_tabular_text (df : pd .DataFrame , text_fields : list ) -> tuple [pd .DataFrame , pd .DataFrame ]:
183- """Takes a dataframe and divides it into two dataframes, one with the text fields and one with the tabular fields """
198+ """Split a dataframe into tabular-only and text-only subsets. """
184199 tabular_fields = []
185200 for col in df .columns :
186201 if col not in text_fields :
@@ -192,9 +207,7 @@ def _divide_tabular_text(df: pd.DataFrame, text_fields: list) -> tuple[pd.DataFr
192207
193208 @staticmethod
194209 def _embed_text (df : pd .DataFrame , embedder ) -> pd .DataFrame :
195- """Takes a dataframe of text fields, finds the embeddings for each
196- and then averages the embeddings into one embedding and returns a dataframe with just that
197- """
210+ """Embed each text column and average into a single embedding per row."""
198211 embeddings = {}
199212 for col in df .columns :
200213 data = df [col ].to_list ()
@@ -267,7 +280,7 @@ def _get_synth_nn(
267280 if len (text_columns ) == 0 :
268281 # Create the faiss index on the synthetic data
269282 dim = df_synth_norm .shape [1 ]
270- index = faiss .IndexFlatL2 (dim ) # ty: ignore[unresolved-attribute, possibly-unbound-attribute]
283+ index = faiss .IndexFlatL2 (dim ) # ty: ignore[possibly-unbound-attribute]
271284
272285 # This usage matches documentation. Specifying n= and x= parameters as
273286 # the type annotation for IndexFlatL2.add suggests seems unnecessary, possibly related
@@ -288,15 +301,15 @@ def _get_synth_nn(
288301 df_train_embeddings = AttributeInferenceProtection ._embed_text (df_train_text , embedder )
289302 df_synth_embeddings = AttributeInferenceProtection ._embed_text (df_synth_text , embedder )
290303 hits = util .semantic_search (
291- np .array (list (df_train_embeddings ["embedding" ])),
292- np .array (list (df_synth_embeddings ["embedding" ])),
304+ np .array (list (df_train_embeddings ["embedding" ])), # ty: ignore[invalid-argument-type]
305+ np .array (list (df_synth_embeddings ["embedding" ])), # ty: ignore[invalid-argument-type]
293306 top_k = k ,
294307 )
295308 synth_rows = pd .DataFrame ()
296309 for i in range (k ):
297310 corpus_id = hits [0 ][i ]["corpus_id" ]
298311 synth_rows = pd .concat (
299- [synth_rows , pd .DataFrame ([df_synth .iloc [corpus_id ]])],
312+ [synth_rows , pd .DataFrame ([df_synth .iloc [int ( corpus_id ) ]])],
300313 ignore_index = True ,
301314 )
302315
@@ -310,8 +323,8 @@ def _get_synth_nn(
310323 df_synth_embeddings = AttributeInferenceProtection ._embed_text (df_synth_text , embedder )
311324 search_synth_k = min (1000 , len (df_synth_embeddings ))
312325 hits = util .semantic_search (
313- np .array (list (df_train_embeddings ["embedding" ])),
314- np .array (list (df_synth_embeddings ["embedding" ])),
326+ np .array (list (df_train_embeddings ["embedding" ])), # ty: ignore[invalid-argument-type]
327+ np .array (list (df_synth_embeddings ["embedding" ])), # ty: ignore[invalid-argument-type]
315328 top_k = search_synth_k ,
316329 )
317330 synth_NN = pd .DataFrame ()
@@ -324,12 +337,12 @@ def _get_synth_nn(
324337 dist = 1 - sim
325338 text_dist [i ] = dist
326339 corpus_ids .append (corpus_id )
327- synth_NN = pd .concat ([synth_NN , pd .DataFrame ([df_synth_norm .iloc [corpus_id ]])], ignore_index = True )
340+ synth_NN = pd .concat ([synth_NN , pd .DataFrame ([df_synth_norm .iloc [int ( corpus_id ) ]])], ignore_index = True )
328341
329342 # Now get the tabular similarity for these 1000 NN
330343
331344 dim = synth_NN .shape [1 ]
332- index = faiss .IndexFlatL2 (dim ) # ty: ignore[unresolved-attribute, possibly-unbound-attribute]
345+ index = faiss .IndexFlatL2 (dim ) # ty: ignore[possibly-unbound-attribute]
333346 index .add (np .float32 (np .ascontiguousarray (np .array (synth_NN )))) # ty: ignore[missing-argument]
334347 dists , indexes = index .search (np .float32 (np .ascontiguousarray (np .array (df_train_norm ))), search_synth_k ) # ty: ignore[missing-argument]
335348 # Scale the Euclidean distance to [0,1]
@@ -372,6 +385,20 @@ def _aia(
372385 df_synth : pd .DataFrame ,
373386 quasi_identifier_count : int ,
374387 ) -> tuple [EvaluationScore , pd .DataFrame | None ]:
388+ """Core attribute inference attack implementation.
389+
390+ Iterates over random quasi-identifier subsets, finds nearest
391+ synthetic neighbors, and measures attribute prediction accuracy
392+ weighted by column entropy.
393+
394+ Args:
395+ df_train: Training dataframe.
396+ df_synth: Synthetic dataframe.
397+ quasi_identifier_count: Number of columns to use as quasi-identifiers.
398+
399+ Returns:
400+ Tuple of (overall protection score, per-column accuracy dataframe).
401+ """
375402 ias = EvaluationScore (grade = PrivacyGrade .UNAVAILABLE )
376403 col_accuracy_df = None
377404 if quasi_identifier_count is None :
@@ -408,7 +435,7 @@ def _aia(
408435 nominal_columns = list (df_train .select_dtypes (include = ["object" , "category" , "bool" ]).columns )
409436 numeric_columns = [column for column in df_train .columns if column not in nominal_columns ]
410437
411- # Now seperate out the text columns from the nominal
438+ # Now separate out the text columns from the nominal
412439
413440 text_columns = []
414441 for col in nominal_columns :
@@ -531,7 +558,7 @@ def _aia(
531558 # Lat/lon values inspired this. Text must be dist .35 or less
532559 for column in predict_columns :
533560 synth_val = synth_values [column ]
534- train_val = train_row_all .iloc [0 ][column ]
561+ train_val = train_row_all .iloc [0 ][column ] # ty: ignore[invalid-argument-type]
535562
536563 if pd .isna (train_val ):
537564 continue
0 commit comments