Skip to content

Commit 6ee4f9c

Browse files
lucamelis authored and meta-codesync[bot] committed
Online LiRA NaN filtering and squared_error as a new score type (#112)
Summary: Pull Request resolved: #112
* Filter out NaN values after merging holdout train and test (plus a change to the return type of the LiRA run_attack method).
* Added squared_error as a new score type.
Reviewed By: iden-kalemaj
Differential Revision: D95680183
fbshipit-source-id: e796317d35f7f574e982b77b144ef48eeb28356f
1 parent bf5a147 commit 6ee4f9c

File tree

2 files changed

+107
-2
lines changed

2 files changed

+107
-2
lines changed

privacy_guard/attacks/lira_attack.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,20 @@
1313
# limitations under the License.
1414

1515
# pyre-strict
16+
import logging
1617
from typing import Tuple, Union
1718

1819
import pandas as pd
1920
from pandas import Series
20-
from privacy_guard.analysis.base_analysis_input import BaseAnalysisInput
2121
from privacy_guard.analysis.mia.aggregate_analysis_input import (
2222
AggregateAnalysisInput,
2323
AggregationType,
2424
)
2525
from privacy_guard.attacks.base_attack import BaseAttack
2626
from scipy.stats import norm
2727

28+
logger: logging.Logger = logging.getLogger(__name__)
29+
2830

2931
class LiraAttack(BaseAttack):
3032
"""
@@ -162,7 +164,7 @@ def _get_std_dev(self) -> Tuple[Union[float, Series], Union[float, Series]]:
162164
raise ValueError(f"{self.std_dev_type} is not a valid std_dev type.")
163165
return std_in, std_out
164166

165-
def run_attack(self) -> BaseAnalysisInput:
167+
def run_attack(self) -> AggregateAnalysisInput:
166168
"""
167169
Run lira attack on the shadows and original models.
168170
@@ -207,6 +209,15 @@ def run_attack(self) -> BaseAnalysisInput:
207209
self.df_test_merge.score_orig, self.df_test_merge.score_mean, std_out
208210
)
209211

212+
logger.info(
213+
f"before NaN removal for logpdf results: train {self.df_train_merge.shape} and test {self.df_test_merge.shape}"
214+
)
215+
self.df_train_merge = self.df_train_merge.dropna(subset=["score"])
216+
self.df_test_merge = self.df_test_merge.dropna(subset=["score"])
217+
logger.info(
218+
f"after NaN removal for logpdf results: train {self.df_train_merge.shape} and test {self.df_test_merge.shape}"
219+
)
220+
210221
if not (self.online_attack or self.offline_shadows_evals_in):
211222
# this corresponds to the case of offline shadows evals on the hold out test set
212223
self.df_train_merge["score"] = -self.df_train_merge["score"]

privacy_guard/attacks/tests/test_lira_attack.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import unittest
1919

20+
import numpy as np
2021
import pandas as pd
2122
from privacy_guard.analysis.mia.aggregate_analysis_input import (
2223
AggregateAnalysisInput,
@@ -327,3 +328,96 @@ def test_get_std_dev_invalid_type(self) -> None:
327328
attack._get_std_dev()
328329

329330
self.assertIn("is not a valid std_dev type", str(context.exception))
331+
332+
def test_run_attack_drops_nan_rows_in_train(self) -> None:
    """Rows whose train score becomes NaN must be removed by run_attack.

    Two rows of the train frame get a NaN ``score_orig`` so that the
    logpdf-based score computed by the attack is NaN for them; the test
    frame is passed through untouched.
    """
    # Arrange: blank out score_orig on two rows of a private copy of the fixture.
    train_df = self.df_train_merge.copy()
    for row_label in ("0", "2"):
        train_df.loc[row_label, "score_orig"] = np.nan

    lira = LiraAttack(
        df_train_merge=train_df,
        df_test_merge=self.df_train_merge,
        row_aggregation=AggregationType.MAX,
        use_fixed_variance=True,
        user_id_key=self.user_id_key,
        online_attack=True,
    )

    # Act
    result = lira.run_attack()

    # Assert: the two NaN rows are gone from train, test keeps every row.
    self.assertIsInstance(result, AggregateAnalysisInput)
    # Redundant at runtime after the check above; presumably kept so the
    # static type checker narrows ``result``.
    assert isinstance(result, AggregateAnalysisInput)
    self.assertEqual(len(result.df_train_merge), 3)
    self.assertEqual(len(result.df_test_merge), 5)
356+
357+
def test_run_attack_drops_nan_rows_in_test(self) -> None:
    """Rows whose test score becomes NaN must be removed by run_attack.

    One row of the test frame gets a NaN ``score_orig`` so that the
    logpdf-based score computed by the attack is NaN for it; the train
    frame is passed through untouched.
    """
    # Arrange: blank out score_orig on one row of a private copy of the fixture.
    test_df = self.df_train_merge.copy()
    test_df.loc["1", "score_orig"] = np.nan

    lira = LiraAttack(
        df_train_merge=self.df_train_merge,
        df_test_merge=test_df,
        row_aggregation=AggregationType.MAX,
        use_fixed_variance=True,
        user_id_key=self.user_id_key,
    )

    # Act
    result = lira.run_attack()

    # Assert: train keeps every row, the single NaN row is gone from test.
    self.assertIsInstance(result, AggregateAnalysisInput)
    # Redundant at runtime after the check above; presumably kept so the
    # static type checker narrows ``result``.
    assert isinstance(result, AggregateAnalysisInput)
    self.assertEqual(len(result.df_train_merge), 5)
    self.assertEqual(len(result.df_test_merge), 4)
379+
380+
def test_run_attack_drops_nan_rows_online_attack(self) -> None:
    """In online-attack mode, rows with a NaN score must be removed.

    NaN is injected into ``score_mean_in`` for one train row and into
    ``score_mean_out`` for another, so that the logpdf-based score is NaN
    for both rows.
    """
    # Arrange: poison one "in" mean and one "out" mean on a copy of the fixture.
    train_df = self.df_train_merge.copy()
    nan_cells = (("0", "score_mean_in"), ("3", "score_mean_out"))
    for row_label, column in nan_cells:
        train_df.loc[row_label, column] = np.nan

    lira = LiraAttack(
        df_train_merge=train_df,
        df_test_merge=self.df_train_merge,
        row_aggregation=AggregationType.MAX,
        use_fixed_variance=True,
        online_attack=True,
        user_id_key=self.user_id_key,
    )

    # Act
    result = lira.run_attack()

    # Assert: the two NaN rows are gone from train, test keeps every row.
    self.assertIsInstance(result, AggregateAnalysisInput)
    # Redundant at runtime after the check above; presumably kept so the
    # static type checker narrows ``result``.
    assert isinstance(result, AggregateAnalysisInput)
    self.assertEqual(len(result.df_train_merge), 3)
    self.assertEqual(len(result.df_test_merge), 5)
404+
405+
def test_run_attack_no_nan_preserves_all_rows(self) -> None:
    """With NaN-free inputs, run_attack must not drop any row."""
    # Arrange: the unmodified fixture serves as both the train and test frame.
    clean_df = self.df_train_merge
    lira = LiraAttack(
        df_train_merge=clean_df,
        df_test_merge=clean_df,
        row_aggregation=AggregationType.MAX,
        use_fixed_variance=True,
        user_id_key=self.user_id_key,
    )

    # Act
    result = lira.run_attack()

    # Assert: both frames keep all of their rows.
    self.assertIsInstance(result, AggregateAnalysisInput)
    # Redundant at runtime after the check above; presumably kept so the
    # static type checker narrows ``result``.
    assert isinstance(result, AggregateAnalysisInput)
    self.assertEqual(len(result.df_train_merge), 5)
    self.assertEqual(len(result.df_test_merge), 5)

0 commit comments

Comments
 (0)