diff --git a/privacy_guard/attacks/lira_attack.py b/privacy_guard/attacks/lira_attack.py index aa1ea37..311847e 100644 --- a/privacy_guard/attacks/lira_attack.py +++ b/privacy_guard/attacks/lira_attack.py @@ -13,11 +13,11 @@ # limitations under the License. # pyre-strict +import logging from typing import Tuple, Union import pandas as pd from pandas import Series -from privacy_guard.analysis.base_analysis_input import BaseAnalysisInput from privacy_guard.analysis.mia.aggregate_analysis_input import ( AggregateAnalysisInput, AggregationType, @@ -25,6 +25,8 @@ from privacy_guard.attacks.base_attack import BaseAttack from scipy.stats import norm +logger: logging.Logger = logging.getLogger(__name__) + class LiraAttack(BaseAttack): """ @@ -162,7 +164,7 @@ def _get_std_dev(self) -> Tuple[Union[float, Series], Union[float, Series]]: raise ValueError(f"{self.std_dev_type} is not a valid std_dev type.") return std_in, std_out - def run_attack(self) -> BaseAnalysisInput: + def run_attack(self) -> AggregateAnalysisInput: """ Run lira attack on the shadows and original models. @@ -207,6 +209,15 @@ def run_attack(self) -> BaseAnalysisInput: self.df_test_merge.score_orig, self.df_test_merge.score_mean, std_out ) + logger.info( + f"before NaN removal for logpdf results: train {self.df_train_merge.shape} and test {self.df_test_merge.shape}" + ) + self.df_train_merge = self.df_train_merge.dropna(subset=["score"]) + self.df_test_merge = self.df_test_merge.dropna(subset=["score"]) + logger.info( + f"after NaN removal for logpdf results: train {self.df_train_merge.shape} and test {self.df_test_merge.shape}" + ) + if not (self.online_attack or self.offline_shadows_evals_in): # this corresponds to the case of offline shadows evals on the hold out test set self.df_train_merge["score"] = -self.df_train_merge["score"] diff --git a/privacy_guard/attacks/tests/test_lira_attack.py b/privacy_guard/attacks/tests/test_lira_attack.py index f29e5d3..6047685 100644 --- a/privacy_guard/attacks/tests/test_lira_attack.py +++ b/privacy_guard/attacks/tests/test_lira_attack.py @@ -17,6 +17,7 @@ import unittest +import numpy as np import pandas as pd from privacy_guard.analysis.mia.aggregate_analysis_input import ( AggregateAnalysisInput, @@ -327,3 +328,96 @@ def test_get_std_dev_invalid_type(self) -> None: attack._get_std_dev() self.assertIn("is not a valid std_dev type", str(context.exception)) + + def test_run_attack_drops_nan_rows_in_train(self) -> None: + """Test that run_attack drops rows with NaN values in df_train_merge after logpdf computation.""" + # Setup: create training data with NaN in score_orig so logpdf produces NaN + df_train_with_nan = self.df_train_merge.copy() + df_train_with_nan.loc["0", "score_orig"] = np.nan + df_train_with_nan.loc["2", "score_orig"] = np.nan + + attack = LiraAttack( + df_train_merge=df_train_with_nan, + df_test_merge=self.df_train_merge, + row_aggregation=AggregationType.MAX, + use_fixed_variance=True, + user_id_key=self.user_id_key, + online_attack=True, + ) + + # Execute + analysis_input = attack.run_attack() + + # Assert: 2 NaN rows dropped from train, test unchanged + self.assertIsInstance(analysis_input, AggregateAnalysisInput) + assert isinstance(analysis_input, AggregateAnalysisInput) + self.assertEqual(len(analysis_input.df_train_merge), 3) + self.assertEqual(len(analysis_input.df_test_merge), 5) + + def test_run_attack_drops_nan_rows_in_test(self) -> None: + """Test that run_attack drops rows with NaN values in df_test_merge after logpdf computation.""" + # Setup: create test data with NaN in score_orig so logpdf produces NaN + df_test_with_nan = self.df_train_merge.copy() + df_test_with_nan.loc["1", "score_orig"] = np.nan + + attack = LiraAttack( + df_train_merge=self.df_train_merge, + df_test_merge=df_test_with_nan, + row_aggregation=AggregationType.MAX, + use_fixed_variance=True, + user_id_key=self.user_id_key, + ) + + # Execute + analysis_input = attack.run_attack() + + # Assert: train unchanged, 1 NaN row dropped from test + self.assertIsInstance(analysis_input, AggregateAnalysisInput) + assert isinstance(analysis_input, AggregateAnalysisInput) + self.assertEqual(len(analysis_input.df_train_merge), 5) + self.assertEqual(len(analysis_input.df_test_merge), 4) + + def test_run_attack_drops_nan_rows_online_attack(self) -> None: + """Test that run_attack drops NaN rows for online attack mode.""" + # Setup: create data with NaN in score_mean_in to produce NaN in logpdf + df_train_with_nan = self.df_train_merge.copy() + df_train_with_nan.loc["0", "score_mean_in"] = np.nan + df_train_with_nan.loc["3", "score_mean_out"] = np.nan + + attack = LiraAttack( + df_train_merge=df_train_with_nan, + df_test_merge=self.df_train_merge, + row_aggregation=AggregationType.MAX, + use_fixed_variance=True, + online_attack=True, + user_id_key=self.user_id_key, + ) + + # Execute + analysis_input = attack.run_attack() + + # Assert: 2 NaN rows dropped from train, test unchanged + self.assertIsInstance(analysis_input, AggregateAnalysisInput) + assert isinstance(analysis_input, AggregateAnalysisInput) + self.assertEqual(len(analysis_input.df_train_merge), 3) + self.assertEqual(len(analysis_input.df_test_merge), 5) + + def test_run_attack_no_nan_preserves_all_rows(self) -> None: + """Test that run_attack preserves all rows when no NaN values are present.""" + # Setup: use clean data (no NaN) + attack = LiraAttack( + df_train_merge=self.df_train_merge, + df_test_merge=self.df_train_merge, + row_aggregation=AggregationType.MAX, + use_fixed_variance=True, + user_id_key=self.user_id_key, + ) + + # Execute + analysis_input = attack.run_attack() + + # Assert: all rows preserved + self.assertIsInstance(analysis_input, AggregateAnalysisInput) + assert isinstance(analysis_input, AggregateAnalysisInput) + self.assertEqual(len(analysis_input.df_train_merge), 5) + self.assertEqual(len(analysis_input.df_test_merge), 5)