|
17 | 17 |
|
18 | 18 | import unittest |
19 | 19 |
|
| 20 | +import numpy as np |
20 | 21 | import pandas as pd |
21 | 22 | from privacy_guard.analysis.mia.aggregate_analysis_input import ( |
22 | 23 | AggregateAnalysisInput, |
@@ -327,3 +328,96 @@ def test_get_std_dev_invalid_type(self) -> None: |
327 | 328 | attack._get_std_dev() |
328 | 329 |
|
329 | 330 | self.assertIn("is not a valid std_dev type", str(context.exception)) |
| 331 | + |
| 332 | + def test_run_attack_drops_nan_rows_in_train(self) -> None: |
| 333 | + """Test that run_attack drops rows with NaN values in df_train_merge after logpdf computation.""" |
| 334 | + # Setup: create training data with NaN in score_orig so logpdf produces NaN |
| 335 | + df_train_with_nan = self.df_train_merge.copy() |
| 336 | + df_train_with_nan.loc["0", "score_orig"] = np.nan |
| 337 | + df_train_with_nan.loc["2", "score_orig"] = np.nan |
| 338 | + |
| 339 | + attack = LiraAttack( |
| 340 | + df_train_merge=df_train_with_nan, |
| 341 | + df_test_merge=self.df_train_merge, |
| 342 | + row_aggregation=AggregationType.MAX, |
| 343 | + use_fixed_variance=True, |
| 344 | + user_id_key=self.user_id_key, |
| 345 | + online_attack=True, |
| 346 | + ) |
| 347 | + |
| 348 | + # Execute |
| 349 | + analysis_input = attack.run_attack() |
| 350 | + |
| 351 | + # Assert: 2 NaN rows dropped from train, test unchanged |
| 352 | + self.assertIsInstance(analysis_input, AggregateAnalysisInput) |
| 353 | + assert isinstance(analysis_input, AggregateAnalysisInput) |
| 354 | + self.assertEqual(len(analysis_input.df_train_merge), 3) |
| 355 | + self.assertEqual(len(analysis_input.df_test_merge), 5) |
| 356 | + |
| 357 | + def test_run_attack_drops_nan_rows_in_test(self) -> None: |
| 358 | + """Test that run_attack drops rows with NaN values in df_test_merge after logpdf computation.""" |
| 359 | + # Setup: create test data with NaN in score_orig so logpdf produces NaN |
| 360 | + df_test_with_nan = self.df_train_merge.copy() |
| 361 | + df_test_with_nan.loc["1", "score_orig"] = np.nan |
| 362 | + |
| 363 | + attack = LiraAttack( |
| 364 | + df_train_merge=self.df_train_merge, |
| 365 | + df_test_merge=df_test_with_nan, |
| 366 | + row_aggregation=AggregationType.MAX, |
| 367 | + use_fixed_variance=True, |
| 368 | + user_id_key=self.user_id_key, |
| 369 | + ) |
| 370 | + |
| 371 | + # Execute |
| 372 | + analysis_input = attack.run_attack() |
| 373 | + |
| 374 | + # Assert: train unchanged, 1 NaN row dropped from test |
| 375 | + self.assertIsInstance(analysis_input, AggregateAnalysisInput) |
| 376 | + assert isinstance(analysis_input, AggregateAnalysisInput) |
| 377 | + self.assertEqual(len(analysis_input.df_train_merge), 5) |
| 378 | + self.assertEqual(len(analysis_input.df_test_merge), 4) |
| 379 | + |
| 380 | + def test_run_attack_drops_nan_rows_online_attack(self) -> None: |
| 381 | + """Test that run_attack drops NaN rows for online attack mode.""" |
| 382 | + # Setup: create data with NaN in score_mean_in to produce NaN in logpdf |
| 383 | + df_train_with_nan = self.df_train_merge.copy() |
| 384 | + df_train_with_nan.loc["0", "score_mean_in"] = np.nan |
| 385 | + df_train_with_nan.loc["3", "score_mean_out"] = np.nan |
| 386 | + |
| 387 | + attack = LiraAttack( |
| 388 | + df_train_merge=df_train_with_nan, |
| 389 | + df_test_merge=self.df_train_merge, |
| 390 | + row_aggregation=AggregationType.MAX, |
| 391 | + use_fixed_variance=True, |
| 392 | + online_attack=True, |
| 393 | + user_id_key=self.user_id_key, |
| 394 | + ) |
| 395 | + |
| 396 | + # Execute |
| 397 | + analysis_input = attack.run_attack() |
| 398 | + |
| 399 | + # Assert: 2 NaN rows dropped from train, test unchanged |
| 400 | + self.assertIsInstance(analysis_input, AggregateAnalysisInput) |
| 401 | + assert isinstance(analysis_input, AggregateAnalysisInput) |
| 402 | + self.assertEqual(len(analysis_input.df_train_merge), 3) |
| 403 | + self.assertEqual(len(analysis_input.df_test_merge), 5) |
| 404 | + |
| 405 | + def test_run_attack_no_nan_preserves_all_rows(self) -> None: |
| 406 | + """Test that run_attack preserves all rows when no NaN values are present.""" |
| 407 | + # Setup: use clean data (no NaN) |
| 408 | + attack = LiraAttack( |
| 409 | + df_train_merge=self.df_train_merge, |
| 410 | + df_test_merge=self.df_train_merge, |
| 411 | + row_aggregation=AggregationType.MAX, |
| 412 | + use_fixed_variance=True, |
| 413 | + user_id_key=self.user_id_key, |
| 414 | + ) |
| 415 | + |
| 416 | + # Execute |
| 417 | + analysis_input = attack.run_attack() |
| 418 | + |
| 419 | + # Assert: all rows preserved |
| 420 | + self.assertIsInstance(analysis_input, AggregateAnalysisInput) |
| 421 | + assert isinstance(analysis_input, AggregateAnalysisInput) |
| 422 | + self.assertEqual(len(analysis_input.df_train_merge), 5) |
| 423 | + self.assertEqual(len(analysis_input.df_test_merge), 5) |
0 commit comments