Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions privacy_guard/attacks/lira_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,20 @@
# limitations under the License.

# pyre-strict
import logging
from typing import Tuple, Union

import pandas as pd
from pandas import Series
from privacy_guard.analysis.base_analysis_input import BaseAnalysisInput
from privacy_guard.analysis.mia.aggregate_analysis_input import (
AggregateAnalysisInput,
AggregationType,
)
from privacy_guard.attacks.base_attack import BaseAttack
from scipy.stats import norm

# Module-level logger; run_attack uses it to report the train/test frame
# shapes before and after rows with NaN scores are dropped.
logger: logging.Logger = logging.getLogger(__name__)


class LiraAttack(BaseAttack):
"""
Expand Down Expand Up @@ -162,7 +164,7 @@ def _get_std_dev(self) -> Tuple[Union[float, Series], Union[float, Series]]:
raise ValueError(f"{self.std_dev_type} is not a valid std_dev type.")
return std_in, std_out

def run_attack(self) -> BaseAnalysisInput:
def run_attack(self) -> AggregateAnalysisInput:
"""
Run lira attack on the shadows and original models.

Expand Down Expand Up @@ -207,6 +209,15 @@ def run_attack(self) -> BaseAnalysisInput:
self.df_test_merge.score_orig, self.df_test_merge.score_mean, std_out
)

logger.info(
f"before NaN removal for logpdf results: train {self.df_train_merge.shape} and test {self.df_test_merge.shape}"
)
self.df_train_merge = self.df_train_merge.dropna(subset=["score"])
self.df_test_merge = self.df_test_merge.dropna(subset=["score"])
logger.info(
f"after NaN removal for logpdf results: train {self.df_train_merge.shape} and test {self.df_test_merge.shape}"
)

if not (self.online_attack or self.offline_shadows_evals_in):
# this corresponds to the case of offline shadows evals on the hold out test set
self.df_train_merge["score"] = -self.df_train_merge["score"]
Expand Down
94 changes: 94 additions & 0 deletions privacy_guard/attacks/tests/test_lira_attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import unittest

import numpy as np
import pandas as pd
from privacy_guard.analysis.mia.aggregate_analysis_input import (
AggregateAnalysisInput,
Expand Down Expand Up @@ -327,3 +328,96 @@ def test_get_std_dev_invalid_type(self) -> None:
attack._get_std_dev()

self.assertIn("is not a valid std_dev type", str(context.exception))

def test_run_attack_drops_nan_rows_in_train(self) -> None:
    """run_attack must discard train rows whose logpdf score comes out as NaN.

    Two rows of the train frame get a NaN ``score_orig``; the test frame is
    the untouched fixture. After the attack the train frame is expected to
    hold 3 rows and the test frame 5.
    """
    # Arrange: poison score_orig on two train rows so their score becomes NaN.
    poisoned_train = self.df_train_merge.copy()
    for row_label in ("0", "2"):
        poisoned_train.loc[row_label, "score_orig"] = np.nan

    lira = LiraAttack(
        df_train_merge=poisoned_train,
        df_test_merge=self.df_train_merge,
        row_aggregation=AggregationType.MAX,
        use_fixed_variance=True,
        user_id_key=self.user_id_key,
        online_attack=True,
    )

    # Act
    result = lira.run_attack()

    # Assert: the two poisoned train rows are gone, the test frame is intact.
    self.assertIsInstance(result, AggregateAnalysisInput)
    # Bare assert kept next to the unittest check (presumably for static
    # type narrowing).
    assert isinstance(result, AggregateAnalysisInput)
    remaining_train_rows = len(result.df_train_merge)
    self.assertEqual(remaining_train_rows, 3)
    remaining_test_rows = len(result.df_test_merge)
    self.assertEqual(remaining_test_rows, 5)

def test_run_attack_drops_nan_rows_in_test(self) -> None:
    """run_attack must discard test rows whose logpdf score comes out as NaN.

    One row of the test frame gets a NaN ``score_orig``; the train frame is
    the untouched fixture. After the attack the train frame is expected to
    hold 5 rows and the test frame 4.
    """
    # Arrange: poison score_orig on a single test row.
    poisoned_test = self.df_train_merge.copy()
    poisoned_test.loc["1", "score_orig"] = np.nan

    lira = LiraAttack(
        df_train_merge=self.df_train_merge,
        df_test_merge=poisoned_test,
        row_aggregation=AggregationType.MAX,
        use_fixed_variance=True,
        user_id_key=self.user_id_key,
    )

    # Act
    result = lira.run_attack()

    # Assert: the train frame is intact, one row is missing from the test frame.
    self.assertIsInstance(result, AggregateAnalysisInput)
    # Bare assert kept next to the unittest check (presumably for static
    # type narrowing).
    assert isinstance(result, AggregateAnalysisInput)
    remaining_train_rows = len(result.df_train_merge)
    self.assertEqual(remaining_train_rows, 5)
    remaining_test_rows = len(result.df_test_merge)
    self.assertEqual(remaining_test_rows, 4)

def test_run_attack_drops_nan_rows_online_attack(self) -> None:
    """In online-attack mode, run_attack must drop rows with NaN scores.

    One train row gets a NaN ``score_mean_in`` and another a NaN
    ``score_mean_out``; the test frame is the untouched fixture. After the
    attack the train frame is expected to hold 3 rows and the test frame 5.
    """
    # Arrange: poison one shadow-mean column on each of two train rows.
    poisoned_train = self.df_train_merge.copy()
    nan_cells = (("0", "score_mean_in"), ("3", "score_mean_out"))
    for row_label, column in nan_cells:
        poisoned_train.loc[row_label, column] = np.nan

    lira = LiraAttack(
        df_train_merge=poisoned_train,
        df_test_merge=self.df_train_merge,
        row_aggregation=AggregationType.MAX,
        use_fixed_variance=True,
        online_attack=True,
        user_id_key=self.user_id_key,
    )

    # Act
    result = lira.run_attack()

    # Assert: the two poisoned train rows are gone, the test frame is intact.
    self.assertIsInstance(result, AggregateAnalysisInput)
    # Bare assert kept next to the unittest check (presumably for static
    # type narrowing).
    assert isinstance(result, AggregateAnalysisInput)
    remaining_train_rows = len(result.df_train_merge)
    self.assertEqual(remaining_train_rows, 3)
    remaining_test_rows = len(result.df_test_merge)
    self.assertEqual(remaining_test_rows, 5)

def test_run_attack_no_nan_preserves_all_rows(self) -> None:
    """With NaN-free input, run_attack must not drop any rows.

    The fixture frame is used unchanged for both train and test; both
    resulting frames are expected to still hold 5 rows.
    """
    # Arrange: the same clean fixture frame serves as train and test data.
    clean_frame = self.df_train_merge
    lira = LiraAttack(
        df_train_merge=clean_frame,
        df_test_merge=clean_frame,
        row_aggregation=AggregationType.MAX,
        use_fixed_variance=True,
        user_id_key=self.user_id_key,
    )

    # Act
    result = lira.run_attack()

    # Assert: nothing was removed from either frame.
    self.assertIsInstance(result, AggregateAnalysisInput)
    # Bare assert kept next to the unittest check (presumably for static
    # type narrowing).
    assert isinstance(result, AggregateAnalysisInput)
    remaining_train_rows = len(result.df_train_merge)
    self.assertEqual(remaining_train_rows, 5)
    remaining_test_rows = len(result.df_test_merge)
    self.assertEqual(remaining_test_rows, 5)
Loading