Skip to content

Commit 235d105

Browse files
du1204 authored and meta-codesync[bot] committed
Update target TPR of prod AIA (#102)
Summary: Pull Request resolved: #102

This diff adds support for configurable TPR (True Positive Rate) targets in MIA.

**Key Changes:**
- Added `tpr_target` parameter: allows users to specify a custom TPR threshold (the default behavior remains 1% if not specified).
- Added `tpr_threshold_width` parameter: controls the granularity of TPR thresholds (default: 0.25%).
- Added input validation for `tpr_threshold_width` (must be positive and evenly divide 99%, e.g., 0.1%, 0.25%, 0.3%, 0.5%, etc.).
- Added unit tests for the changes in this diff:
  - `test_get_tpr_index_none_target` verifies legacy behavior (returns index 0 when `tpr_target=None`).
  - `test_get_tpr_index_with_target` verifies that `get_tpr_index` returns the correct index.
  - `test_tpr_threshold_width_positive_validation` verifies that a ValueError is raised when `tpr_threshold_width <= 0`.
  - `test_tpr_threshold_width_divisibility_validation` verifies that a ValueError is raised when `tpr_threshold_width` doesn't evenly divide 0.99.

**Backward Compatibility:**
- Fully backward compatible: existing configs without `tpr_target` continue to use the original 1% TPR with 100 thresholds.

Reviewed By: lucamelis

Differential Revision: D92439701

fbshipit-source-id: 644a452ff11e61e1523310e7b32bf7d6ef95ca5f
1 parent 146108f commit 235d105

File tree

4 files changed

+208
-10
lines changed

4 files changed

+208
-10
lines changed

privacy_guard/analysis/mia/analysis_node.py

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from collections.abc import Generator
1717
from contextlib import contextmanager
1818
from dataclasses import dataclass
19-
from typing import List
19+
from typing import List, Optional
2020

2121
import numpy as np
2222
import torch
@@ -53,6 +53,7 @@ class AnalysisNodeOutput(BaseAnalysisOutput):
5353
auc (float): Mean area under the curve (AUC) of the attack.
5454
auc_ci (List[float]): Confidence interval for the AUC, represented as [lower_bound, upper_bound].
5555
data_size (dict[str, int]): Size of the training, test dataset and bootstrap sample size.
56+
tpr_target (float): Target TPR used for computing epsilon.
5657
"""
5758

5859
# Empirical epsilons
@@ -73,6 +74,9 @@ class AnalysisNodeOutput(BaseAnalysisOutput):
7374
auc_ci: List[float]
7475
# Dataset sizes
7576
data_size: dict[str, int]
77+
# TPR target and index (only set when custom tpr_target is provided)
78+
tpr_target: Optional[float]
79+
tpr_idx: Optional[int]
7680

7781

7882
class AnalysisNode(BaseAnalysisNode):
@@ -93,6 +97,11 @@ class AnalysisNode(BaseAnalysisNode):
9397
use_fnr_tnr: boolean for whether to use FNR and TNR in addition to FPR and TPR error thresholds in eps_max_array computation.
9498
show_progress: boolean for whether to show tqdm progress bar
9599
with_timer: boolean for whether to show timer for analysis node
100+
tpr_target: Optional target TPR for computing epsilon. If None (default), uses legacy 1% intervals.
101+
If specified, uses fine-grained intervals determined by tpr_threshold_width.
102+
tpr_threshold_width: Width (step size) between TPR thresholds for fine-grained mode.
103+
Only used when tpr_target is specified. Default 0.0025 (0.25%).
104+
Start is always fixed at 0.01. num_thresholds = int((1.0 - 0.01) / width) + 1.
96105
"""
97106

98107
def __init__(
@@ -106,6 +115,8 @@ def __init__(
106115
use_fnr_tnr: bool = False,
107116
show_progress: bool = False,
108117
with_timer: bool = False,
118+
tpr_target: Optional[float] = None,
119+
tpr_threshold_width: float = 0.0025,
109120
) -> None:
110121
self._delta = delta
111122
self._n_users_for_eval = n_users_for_eval
@@ -117,15 +128,59 @@ def __init__(
117128

118129
self._use_upper_bound = use_upper_bound
119130

131+
self._tpr_target = tpr_target
132+
self._tpr_threshold_width = tpr_threshold_width
133+
self._num_thresholds: int
134+
120135
self._timer_stats: TimerStats = {}
121136

122137
if self._n_users_for_eval < 0:
123138
raise ValueError(
124139
'Input to AnalysisNode "n_users_for_eval" must be nonnegative'
125140
)
126141

142+
if self._tpr_target is not None:
143+
assert isinstance(self._tpr_target, float)
144+
if self._tpr_target < 0.01 or self._tpr_target > 1.0:
145+
raise ValueError(
146+
'Input to AnalysisNode "tpr_target" must be between 0.01 and 1.0'
147+
)
148+
149+
if self._tpr_threshold_width <= 0:
150+
raise ValueError(
151+
'Input to AnalysisNode "tpr_threshold_width" must be positive'
152+
)
153+
154+
if not np.isclose(
155+
0.99 / self._tpr_threshold_width,
156+
round(0.99 / self._tpr_threshold_width),
157+
):
158+
raise ValueError(
159+
'Input to AnalysisNode "tpr_threshold_width" must evenly divide 0.99. '
160+
"Valid examples: 0.001, 0.0025, 0.003, 0.005, 0.01"
161+
)
162+
163+
# Determine num_thresholds based on tpr_target
164+
if self._tpr_target is None:
165+
# Legacy: 1% intervals (0.01 to 1.0, 100 thresholds)
166+
self._num_thresholds = 100
167+
self._tpr_threshold_width = 0.01
168+
else:
169+
self._tpr_threshold_width = tpr_threshold_width
170+
self._num_thresholds = int((1.0 - 0.01) / tpr_threshold_width) + 1
171+
172+
self._error_thresholds: NDArray[np.floating] = np.linspace(
173+
0.01, 1.0, self._num_thresholds
174+
)
175+
127176
super().__init__(analysis_input=analysis_input)
128177

178+
def _get_tpr_index(self) -> int:
179+
"""Convert TPR target to array index."""
180+
if self._tpr_target is None:
181+
return 0 # Legacy behavior: TPR = 1% is at index 0
182+
return int(np.where(self._error_thresholds == self._tpr_target)[0][0])
183+
129184
def _calculate_one_off_eps(self) -> float:
130185
df_train_user = self.analysis_input.df_train_user
131186
df_test_user = self.analysis_input.df_test_user
@@ -253,9 +308,10 @@ def run_analysis(self) -> BaseAnalysisOutput:
253308

254309
eps_tpr_boundary = eps_tpr_ub if self._use_upper_bound else eps_tpr_lb
255310

311+
tpr_idx = self._get_tpr_index()
256312
outputs = AnalysisNodeOutput(
257-
eps=eps_tpr_boundary[0], # epsilon at TPR=1% UB threshold
258-
eps_lb=eps_tpr_lb[0],
313+
eps=eps_tpr_boundary[tpr_idx], # epsilon at specified TPR threshold
314+
eps_lb=eps_tpr_lb[tpr_idx],
259315
eps_fpr_max_ub=np.nanmax(eps_fpr_ub),
260316
eps_fpr_lb=list(eps_fpr_lb),
261317
eps_fpr_ub=list(eps_fpr_ub),
@@ -273,6 +329,8 @@ def run_analysis(self) -> BaseAnalysisOutput:
273329
"test_size": test_size,
274330
"bootstrap_size": bootstrap_sample_size,
275331
},
332+
tpr_target=self._tpr_target,
333+
tpr_idx=tpr_idx,
276334
)
277335

278336
if self._with_timer:
@@ -313,8 +371,6 @@ def _make_metrics_array(
313371

314372
bootstrap_sample_size = min(train_size, test_size)
315373

316-
error_thresholds = np.linspace(0.01, 1, 100)
317-
318374
metrics_array = [
319375
MIAResults(
320376
loss_train[
@@ -329,7 +385,7 @@ def _make_metrics_array(
329385
],
330386
).compute_metrics_at_error_threshold(
331387
self._delta,
332-
error_threshold=error_thresholds,
388+
error_threshold=self._error_thresholds,
333389
cap_eps=self._cap_eps,
334390
use_fnr_tnr=self._use_fnr_tnr,
335391
verbose=False,

privacy_guard/analysis/mia/parallel_analysis_node.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import os
1616
import tempfile
1717
from concurrent.futures import ProcessPoolExecutor
18+
from typing import Optional
1819

1920
import numpy as np
2021
import torch
@@ -41,6 +42,8 @@ class ParallelAnalysisNode(AnalysisNode):
4142
use_upper_bound: boolean for whether to compute epsilon at the upper-bound of CI
4243
use_fnr_tnr: boolean for whether to use FNR and TNR in addition to FPR and TPR error thresholds in eps_max_array computation.
4344
with_timer: boolean for whether to show timer for analysis node
45+
tpr_target: Optional target TPR for computing epsilon. If None (default), uses legacy 1% intervals.
46+
tpr_threshold_width: Width between TPR thresholds for fine-grained mode. Default 0.0025.
4447
"""
4548

4649
def __init__(
@@ -53,6 +56,8 @@ def __init__(
5356
num_bootstrap_resampling_times: int = 1000,
5457
use_fnr_tnr: bool = False,
5558
with_timer: bool = False,
59+
tpr_target: Optional[float] = None,
60+
tpr_threshold_width: float = 0.0025,
5661
) -> None:
5762
super().__init__(
5863
analysis_input=analysis_input,
@@ -62,6 +67,8 @@ def __init__(
6267
num_bootstrap_resampling_times=num_bootstrap_resampling_times,
6368
use_fnr_tnr=use_fnr_tnr,
6469
with_timer=with_timer,
70+
tpr_target=tpr_target,
71+
tpr_threshold_width=tpr_threshold_width,
6572
)
6673
self._eps_computation_tasks_num = eps_computation_tasks_num
6774

@@ -87,7 +94,6 @@ def _compute_metrics_array(
8794
loss_test: torch.Tensor = torch.load(test_filename, weights_only=True)
8895
train_size, test_size = loss_train.shape[0], loss_test.shape[0]
8996
bootstrap_sample_size = min(train_size, test_size)
90-
error_thresholds = np.linspace(0.01, 1, 100)
9197
metrics_results = []
9298

9399
try:
@@ -107,7 +113,7 @@ def _compute_metrics_array(
107113

108114
metrics_result = mia_results.compute_metrics_at_error_threshold(
109115
self._delta,
110-
error_threshold=error_thresholds,
116+
error_threshold=self._error_thresholds,
111117
use_fnr_tnr=self._use_fnr_tnr,
112118
verbose=False,
113119
)
@@ -221,9 +227,10 @@ def run_analysis(self) -> AnalysisNodeOutput:
221227

222228
eps_tpr_boundary = eps_tpr_ub if self._use_upper_bound else eps_tpr_lb
223229

230+
tpr_idx = self._get_tpr_index()
224231
outputs = AnalysisNodeOutput(
225-
eps=eps_tpr_boundary[0], # epsilon at TPR=1% UB threshold
226-
eps_lb=eps_tpr_lb[0],
232+
eps=eps_tpr_boundary[tpr_idx], # epsilon at specified TPR threshold
233+
eps_lb=eps_tpr_lb[tpr_idx],
227234
eps_fpr_max_ub=np.nanmax(eps_fpr_ub),
228235
eps_fpr_lb=list(eps_fpr_lb),
229236
eps_fpr_ub=list(eps_fpr_ub),
@@ -243,6 +250,8 @@ def run_analysis(self) -> AnalysisNodeOutput:
243250
"test_size": test_size,
244251
"bootstrap_size": bootstrap_sample_size,
245252
},
253+
tpr_target=self._tpr_target,
254+
tpr_idx=tpr_idx,
246255
)
247256

248257
if self._with_timer:

privacy_guard/analysis/tests/test_analysis_node.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,3 +496,117 @@ def test_use_fnr_tnr_parameter_comparison(self) -> None:
496496
outputs_false["accuracy"], outputs_true["accuracy"], places=10
497497
)
498498
self.assertAlmostEqual(outputs_false["auc"], outputs_true["auc"], places=10)
499+
500+
def test_get_tpr_index_none_target(self) -> None:
    """Without a tpr_target, _get_tpr_index falls back to index 0 (legacy)."""
    node = AnalysisNode(
        analysis_input=self.analysis_input,
        delta=0.000001,
        n_users_for_eval=100,
        num_bootstrap_resampling_times=10,
        tpr_target=None,
    )
    legacy_index = node._get_tpr_index()
    self.assertEqual(legacy_index, 0)
510+
511+
def test_get_tpr_index_with_target(self) -> None:
    """Each grid value used as tpr_target must map back to its own index."""
    # Rebuild the grid the node is expected to use for a 0.0025 step.
    grid_size = int((1.0 - 0.01) / 0.0025) + 1
    grid = np.linspace(0.01, 1.0, grid_size)

    # First, last and a few interior positions of the grid.
    for expected_idx in (0, 6, 36, 196, grid_size - 1):
        target = grid[expected_idx]
        node = AnalysisNode(
            analysis_input=self.analysis_input,
            delta=0.000001,
            n_users_for_eval=100,
            num_bootstrap_resampling_times=10,
            tpr_target=target,
            tpr_threshold_width=0.0025,
        )
        actual_idx = node._get_tpr_index()
        self.assertEqual(
            actual_idx,
            expected_idx,
            msg=f"tpr_target={target}: expected index {expected_idx}, got {actual_idx}",
        )
536+
537+
def test_tpr_threshold_width_positive_validation(self) -> None:
    """A zero or negative tpr_threshold_width is rejected with ValueError."""
    # Zero first, then a negative width, each must fail construction.
    for bad_width in (0, -0.01):
        with self.assertRaisesRegex(ValueError, "must be positive"):
            AnalysisNode(
                analysis_input=self.analysis_input,
                delta=0.000001,
                n_users_for_eval=100,
                num_bootstrap_resampling_times=10,
                tpr_threshold_width=bad_width,
            )
556+
557+
def test_tpr_threshold_width_divisibility_validation(self) -> None:
    """A width that does not evenly divide 0.99 is rejected with ValueError."""
    node_kwargs = {
        "analysis_input": self.analysis_input,
        "delta": 0.000001,
        "n_users_for_eval": 100,
        "num_bootstrap_resampling_times": 10,
        # 0.99 / 0.02 = 49.5, which is not a whole number of steps.
        "tpr_threshold_width": 0.02,
    }
    with self.assertRaisesRegex(ValueError, "must evenly divide 0.99"):
        AnalysisNode(**node_kwargs)
567+
568+
def test_tpr_target_range_validation(self) -> None:
    """A tpr_target outside [0.01, 1.0] is rejected with ValueError."""
    # One value below the lower bound, then one above the upper bound.
    for out_of_range in (0.005, 1.5):
        with self.assertRaisesRegex(ValueError, "must be between 0.01 and 1.0"):
            AnalysisNode(
                analysis_input=self.analysis_input,
                delta=0.000001,
                n_users_for_eval=100,
                num_bootstrap_resampling_times=10,
                tpr_target=out_of_range,
            )
587+
588+
def test_error_thresholds_array_creation(self) -> None:
    """The _error_thresholds grid has the expected length in both modes."""
    # Legacy mode (no target): 100 thresholds.
    legacy_node = AnalysisNode(
        analysis_input=self.analysis_input,
        delta=0.000001,
        n_users_for_eval=100,
        num_bootstrap_resampling_times=10,
        tpr_target=None,
    )
    legacy_len = len(legacy_node._error_thresholds)
    self.assertEqual(legacy_len, 100)

    # Fine-grained mode: length follows from the configured step width.
    fine_node = AnalysisNode(
        analysis_input=self.analysis_input,
        delta=0.000001,
        n_users_for_eval=100,
        num_bootstrap_resampling_times=10,
        tpr_target=0.01,
        tpr_threshold_width=0.0025,
    )
    fine_len = len(fine_node._error_thresholds)
    self.assertEqual(fine_len, int(0.99 / 0.0025) + 1)

privacy_guard/analysis/tests/test_parallel_analysis_node.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,3 +308,22 @@ def test_use_fnr_tnr_parameter(self) -> None:
308308
self.assertGreater(
309309
len(outputs_false["eps_tpr_ub"]), len(outputs_true["eps_tpr_ub"])
310310
)
311+
312+
def test_tpr_target_parameter(self) -> None:
    """ParallelAnalysisNode resolves tpr_target to the matching threshold."""
    node = ParallelAnalysisNode(
        analysis_input=self.analysis_input,
        delta=0.000001,
        n_users_for_eval=100,
        num_bootstrap_resampling_times=10,
        eps_computation_tasks_num=2,
        tpr_target=0.025,
        tpr_threshold_width=0.0025,
    )
    # The threshold stored at the resolved index must equal the target.
    resolved_idx = node._get_tpr_index()
    resolved_threshold = node._error_thresholds[resolved_idx]
    self.assertAlmostEqual(resolved_threshold, 0.025, places=10)

0 commit comments

Comments
 (0)