-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathanalysis_node.py
More file actions
448 lines (383 loc) · 17.7 KB
/
analysis_node.py
File metadata and controls
448 lines (383 loc) · 17.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pyre-strict
import logging
import time
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass
from typing import List, Optional
import numpy as np
import torch
from numpy.typing import NDArray
from privacy_guard.analysis.base_analysis_input import BaseAnalysisInput
from privacy_guard.analysis.base_analysis_node import BaseAnalysisNode
from privacy_guard.analysis.base_analysis_output import BaseAnalysisOutput
from privacy_guard.analysis.mia.mia_results import MIAResults
from tqdm import tqdm
# Logger for this module; records dataset sizes, eps_cp and optional timer stats.
logger: logging.Logger = logging.getLogger(name=__name__)
# Name of a timed section -> elapsed time in seconds.
TimerStats = dict[str, float]
@dataclass
class AnalysisNodeOutput(BaseAnalysisOutput):
    """
    Container for the metrics produced by ``AnalysisNode.run_analysis``.

    All confidence intervals (CI) below are 95% bootstrap intervals computed over
    the node's bootstrap resamples.

    Attributes:
        eps (float): Epsilon at the selected TPR threshold (TPR=1% by default, or
            ``tpr_target`` when one is given). Taken from the CI upper bound when the
            node's ``use_upper_bound`` is True, otherwise from the CI lower bound.
        eps_lb (float): Epsilon at the same TPR threshold, always from the CI lower bound.
        auc (float): Mean area under the ROC curve (AUC) of the attack across resamples.
        auc_ci (list[float]): CI for the AUC, as [lower_bound, upper_bound].
        eps_fpr_max_ub (float): Maximum of ``eps_fpr_ub``, ignoring NaN entries.
        eps_fpr_lb (list[float]): CI lower-bound epsilon at each FPR threshold.
        eps_fpr_ub (list[float]): CI upper-bound epsilon at each FPR threshold.
        eps_tpr_lb (list[float]): CI lower-bound epsilon at each TPR threshold.
        eps_tpr_ub (list[float]): CI upper-bound epsilon at each TPR threshold.
        eps_max_lb (list[float]): CI lower-bound epsilon when taking the max epsilon over
            the TPR and FPR thresholds (and also TNR/FNR if ``use_fnr_tnr`` is set to True).
        eps_max_ub (list[float]): CI upper-bound epsilon when taking the max epsilon over
            the TPR and FPR thresholds (and also TNR/FNR if ``use_fnr_tnr`` is set to True).
        eps_cp (float): Empirical epsilon computed using the Clopper-Pearson CI method.
        accuracy (float): Mean accuracy of the attack across resamples.
        accuracy_ci (list[float]): CI for the accuracy, as [lower_bound, upper_bound].
        data_size (dict[str, int]): Sizes keyed by "train_size", "test_size" and
            "bootstrap_size".
        tpr_target (float | None): Custom TPR target used for ``eps``; None in legacy
            (1%-grid) mode.
        tpr_idx (int | None): Index of the threshold used for ``eps`` within the
            error-threshold grid (0 in legacy mode).
    """

    # --- headline epsilons and AUC ---
    eps: float
    eps_lb: float
    auc: float
    auc_ci: list[float]
    # --- per-threshold epsilon bounds ---
    eps_fpr_max_ub: float
    eps_fpr_lb: list[float]
    eps_fpr_ub: list[float]
    eps_tpr_lb: list[float]
    eps_tpr_ub: list[float]
    eps_max_lb: list[float]
    eps_max_ub: list[float]
    eps_cp: float
    # --- accuracy ---
    accuracy: float
    accuracy_ci: list[float]
    # --- dataset sizes ---
    data_size: dict[str, int]
    # --- TPR target (None in legacy mode) and the grid index used for eps ---
    tpr_target: float | None
    tpr_idx: int | None
class AnalysisNode(BaseAnalysisNode):
    """
    AnalysisNode class for PrivacyGuard, which computes a general set of output metrics
    required to evaluate the performance of a privacy attack.

    Calculates privacy eval metrics by computing epsilons at the upper bound (or lower
    bound, see ``use_upper_bound``) of a 95% bootstrap confidence interval, using a score
    threshold such that the adversary has ~1% TPR (or ``tpr_target`` TPR when given).

    Args:
        analysis_input: AnalysisInput object containing the training (members) and
            testing (non-members) dataframes; both are read through their "score" column.
        delta: delta parameter in (epsilon, delta)-differential privacy, close to 0.
        n_users_for_eval: number of users to use for computing epsilon with the
            Clopper-Pearson method. Must be nonnegative here; run_analysis additionally
            requires the effective count (also capped at half of each dataframe) to be
            positive.
        use_upper_bound: whether ``eps`` is taken from the upper bound of the CI
            (otherwise from the lower bound).
        num_bootstrap_resampling_times: number of times to resample the training and
            testing data for computing bootstrap confidence intervals.
        cap_eps: whether to cap large epsilon values to log(size of scores).
        use_fnr_tnr: whether to use FNR and TNR in addition to FPR and TPR error
            thresholds in the eps_max_array computation.
        show_progress: whether to show a tqdm progress bar over bootstrap resamples.
        with_timer: whether to time analysis stages and log the timings.
        tpr_target: optional target TPR for computing epsilon, in [0.01, 1.0]. If None
            (default), uses the legacy grid of 100 thresholds at 1% intervals. If
            specified, uses the fine-grained grid determined by tpr_threshold_width, and
            the target must fall on that grid.
        tpr_threshold_width: width (step size) between TPR thresholds in fine-grained
            mode; must be positive and evenly divide 0.99. It only shapes the grid when
            tpr_target is specified, but is validated regardless. Default 0.0025 (0.25%).
            The grid always runs from 0.01 to 1.0 with 0.99 / width + 1 points.

    Raises:
        ValueError: on a negative n_users_for_eval, an out-of-range tpr_target, or an
            invalid tpr_threshold_width.
    """

    def __init__(
        self,
        analysis_input: BaseAnalysisInput,
        delta: float,
        n_users_for_eval: int,
        use_upper_bound: bool = True,
        num_bootstrap_resampling_times: int = 1000,
        cap_eps: bool = True,
        use_fnr_tnr: bool = False,
        show_progress: bool = False,
        with_timer: bool = False,
        tpr_target: Optional[float] = None,
        tpr_threshold_width: float = 0.0025,
    ) -> None:
        self._delta = delta
        self._n_users_for_eval = n_users_for_eval
        self._num_bootstrap_resampling_times = num_bootstrap_resampling_times
        self._show_progress = show_progress
        self._with_timer = with_timer
        self._cap_eps = cap_eps
        self._use_fnr_tnr = use_fnr_tnr
        self._use_upper_bound = use_upper_bound
        self._tpr_target = tpr_target
        self._tpr_threshold_width = tpr_threshold_width
        self._num_thresholds: int
        self._timer_stats: TimerStats = {}

        if self._n_users_for_eval < 0:
            raise ValueError(
                'Input to AnalysisNode "n_users_for_eval" must be nonnegative'
            )

        if self._tpr_target is not None:
            assert isinstance(self._tpr_target, float)
            # Written as a positive range test so that NaN is rejected as well.
            if not 0.01 <= self._tpr_target <= 1.0:
                raise ValueError(
                    'Input to AnalysisNode "tpr_target" must be between 0.01 and 1.0'
                )

        if self._tpr_threshold_width <= 0:
            raise ValueError(
                'Input to AnalysisNode "tpr_threshold_width" must be positive'
            )
        if not np.isclose(
            0.99 / self._tpr_threshold_width,
            round(0.99 / self._tpr_threshold_width),
        ):
            raise ValueError(
                'Input to AnalysisNode "tpr_threshold_width" must evenly divide 0.99. '
                "Valid examples: 0.001, 0.0025, 0.003, 0.005, 0.01"
            )

        # Determine the error-threshold grid.
        if self._tpr_target is None:
            # Legacy: 1% intervals (0.01 to 1.0, 100 thresholds).
            self._num_thresholds = 100
            self._tpr_threshold_width = 0.01
        else:
            self._num_thresholds = self._num_thresholds_for_width(tpr_threshold_width)
        self._error_thresholds: NDArray[np.floating] = np.linspace(
            0.01, 1.0, self._num_thresholds
        )
        super().__init__(analysis_input=analysis_input)

    @staticmethod
    def _num_thresholds_for_width(tpr_threshold_width: float) -> int:
        """
        Return the number of points in the TPR grid [0.01, 1.0] with the given step width.

        When 0.99 / width is an integer up to floating-point error, the ratio is rounded,
        so a quotient such as 329.99999999999994 does not silently drop the last
        threshold. Otherwise the ratio is truncated, as before.
        """
        ratio = (1.0 - 0.01) / tpr_threshold_width
        nearest = round(ratio)
        # Tight tolerance: only absorb float rounding noise, not genuinely uneven widths.
        steps = nearest if np.isclose(ratio, nearest, rtol=1e-9, atol=0.0) else int(ratio)
        return int(steps) + 1

    @staticmethod
    def get_tpr_index(
        tpr_target: float | None,
        tpr_threshold_width: float = 0.0025,
    ) -> int:
        """
        Convert TPR target to array index in the error_thresholds grid.

        Uses np.isclose to handle floating-point precision issues with np.linspace.

        Args:
            tpr_target: Target TPR value. If None, returns 0 (legacy behavior).
            tpr_threshold_width: Width between TPR thresholds.

        Returns:
            Index into the error_thresholds array.

        Raises:
            ValueError: if tpr_target does not align with the threshold grid.
        """
        if tpr_target is None:
            return 0
        num_thresholds = AnalysisNode._num_thresholds_for_width(tpr_threshold_width)
        error_thresholds = np.linspace(0.01, 1.0, num_thresholds)
        matches = np.flatnonzero(np.isclose(error_thresholds, tpr_target))
        if matches.size > 0:
            return int(matches[0])
        raise ValueError(
            f"tpr_target={tpr_target} does not align with the error_thresholds array. "
            f"Nearest value is {error_thresholds[np.argmin(np.abs(error_thresholds - tpr_target))]}. "
            f"Adjust tpr_target and tpr_threshold_width so that tpr_target falls on the threshold grid."
        )

    def _calculate_one_off_eps(self) -> float:
        """
        Compute epsilon with the Clopper-Pearson method on a fixed prefix of users.

        Uses the first ``min(n_users_for_eval, len(train) // 2, len(test) // 2)`` rows
        (by position) of each "score" column.

        Returns:
            The eps_cp value returned by ``MIAResults.compute_acc_auc_ci_epsilon``.

        Raises:
            AssertionError: if that prefix size is not positive (e.g. n_users_for_eval
                is 0 or a dataframe has fewer than 2 rows).
        """
        score_train = self.analysis_input.df_train_user["score"]
        score_test = self.analysis_input.df_test_user["score"]
        train_size, test_size = score_train.shape[0], score_test.shape[0]

        num_users_for_eps_cp_eval = min(
            self._n_users_for_eval,
            train_size // 2,
            test_size // 2,
        )
        assert 0 < num_users_for_eps_cp_eval < min(train_size, test_size), (
            "Not enough users to compute eps_cp: "
            f"n_users_for_eval={self._n_users_for_eval}, "
            f"train_size={train_size}, test_size={test_size}"
        )

        # .iloc makes the prefix positional regardless of the dataframe's index labels.
        loss_train = torch.from_numpy(
            score_train.iloc[:num_users_for_eps_cp_eval].to_numpy()
        )
        loss_test = torch.from_numpy(
            score_test.iloc[:num_users_for_eps_cp_eval].to_numpy()
        )
        results = MIAResults(loss_train, loss_test)
        # One-off accuracy & AUC & CI for epsilon; only the epsilon is kept.
        _, _, eps_cp = results.compute_acc_auc_ci_epsilon(self._delta)
        return eps_cp

    @staticmethod
    # pyrefly: ignore [bad-specialization]
    def _compute_ci(array: NDArray[float], axis: int = 0) -> tuple[NDArray, NDArray]:
        """
        Compute 95% percentile confidence intervals (used for eps, auc, accuracy).

        Sorts along ``axis`` and picks the order statistics at indexes
        ``max(int(0.025 * n) - 1, 0)`` and ``int(0.975 * n)``, where n is the length
        of that axis. A 1-D input yields scalars, which are wrapped into
        one-element arrays so callers can always index ``[0]``.
        An empty axis raises IndexError.
        """
        sorted_array = np.sort(array, axis=axis)
        axis_length = sorted_array.shape[axis]
        lower_idx = max(int(0.025 * axis_length) - 1, 0)
        upper_idx = int(0.975 * axis_length)
        lower_bound = np.take(sorted_array, lower_idx, axis=axis)
        upper_bound = np.take(sorted_array, upper_idx, axis=axis)
        if np.isscalar(lower_bound):
            lower_bound = np.array([lower_bound])
            upper_bound = np.array([upper_bound])
        return lower_bound, upper_bound

    @staticmethod
    def _compute_bootstrap_sample_indexes(
        num_users: int,
        sample_size: int,
    ) -> list[int]:
        """
        Compute bootstrap indexes by sampling with replacement from [0, num_users).

        Args:
            num_users (int): number of users for indexes 0..num_users-1
            sample_size (int): number of samples among the user indexes

        Returns:
            The indexes (with duplicates). The value is actually a numpy integer array
            from np.random.randint, hence the pyrefly suppression.
        """
        # pyrefly: ignore [bad-return]
        return np.random.randint(0, num_users, sample_size)

    def run_analysis(self) -> BaseAnalysisOutput:
        """
        Computes analysis outputs based on the input dataframes.

        Overrides "BaseAnalysisNode::run_analysis". First computes one-off metrics
        (the Clopper-Pearson epsilon). Then uses "_make_metrics_array" to build
        per-resample metrics, each computed from a bootstrap resample of the train and
        test scores. The number of resamples is self._num_bootstrap_resampling_times.
        These are reduced to means and 95% confidence intervals and returned.

        Returns:
            AnalysisNodeOutput dataclass with fields:
            "eps": epsilon at the selected TPR threshold (1% by default, or tpr_target),
                from the CI upper bound if use_upper_bound is True, else the lower bound
            "eps_lb": epsilon at the same TPR threshold from the CI lower bound
            "eps_fpr_max_ub": NaN-ignoring max of eps_fpr_ub
            "eps_fpr_lb", "eps_fpr_ub": epsilon at various false positive rates
            "eps_tpr_lb", "eps_tpr_ub": epsilon at various true positive rates
            "eps_max_lb", "eps_max_ub": max of epsilon at various true positive rates
                and false positive rates
            "eps_cp": epsilon calculated via Clopper-Pearson confidence interval
            "accuracy", "accuracy_ci": accuracy values
            "auc", "auc_ci": area under ROC curve values
            "data_size": dictionary with keys "train_size", "test_size", "bootstrap_size"
            "tpr_target", "tpr_idx": the TPR target and its index in the threshold grid
        """
        score_train = self.analysis_input.df_train_user["score"]
        score_test = self.analysis_input.df_test_user["score"]
        train_size, test_size = score_train.shape[0], score_test.shape[0]
        logger.info(f"Train/Test unique users: {train_size}/{test_size}")

        eps_cp = self._calculate_one_off_eps()
        logger.info(f"Epsilon CP: {eps_cp}")

        # Same per-side resample size as used inside _make_metrics_array.
        bootstrap_sample_size = min(train_size, test_size)

        with self.timer("make_metrics_array"):
            metrics_array = self._make_metrics_array()

        accuracy = np.array([run[0] for run in metrics_array])
        auc = np.array([run[1] for run in metrics_array])
        eps_fpr = np.array([run[2] for run in metrics_array])
        eps_tpr = np.array([run[3] for run in metrics_array])
        eps_max = np.array([run[4] for run in metrics_array])

        # CI bounds with 95% confidence, per threshold for the epsilon arrays.
        accuracy_lb, accuracy_ub = self._compute_ci(accuracy)
        auc_lb, auc_ub = self._compute_ci(auc)
        eps_fpr_lb, eps_fpr_ub = self._compute_ci(eps_fpr)
        eps_tpr_lb, eps_tpr_ub = self._compute_ci(eps_tpr)
        eps_max_lb, eps_max_ub = self._compute_ci(eps_max)

        eps_tpr_boundary = eps_tpr_ub if self._use_upper_bound else eps_tpr_lb
        tpr_idx = AnalysisNode.get_tpr_index(
            self._tpr_target, self._tpr_threshold_width
        )

        outputs = AnalysisNodeOutput(
            eps=eps_tpr_boundary[tpr_idx],  # epsilon at the selected TPR threshold
            eps_lb=eps_tpr_lb[tpr_idx],
            eps_fpr_max_ub=np.nanmax(eps_fpr_ub),
            eps_fpr_lb=list(eps_fpr_lb),
            eps_fpr_ub=list(eps_fpr_ub),
            eps_tpr_lb=list(eps_tpr_lb),
            eps_tpr_ub=list(eps_tpr_ub),
            eps_max_lb=list(eps_max_lb),
            eps_max_ub=list(eps_max_ub),
            eps_cp=eps_cp,
            accuracy=accuracy.mean(),
            accuracy_ci=[accuracy_lb[0], accuracy_ub[0]],
            auc=auc.mean(),
            auc_ci=[auc_lb[0], auc_ub[0]],
            data_size={
                "train_size": train_size,
                "test_size": test_size,
                "bootstrap_size": bootstrap_sample_size,
            },
            tpr_target=self._tpr_target,
            tpr_idx=tpr_idx,
        )
        if self._with_timer:
            logger.info(f"Timer stats: {self.get_timer_stats()}")
        return outputs

    def _make_metrics_array(
        self,
    ) -> list[
        tuple[
            np.float64, np.float64, list[np.float64], list[np.float64], list[np.float64]
        ]
    ]:
        """
        Compute attack metrics on self._num_bootstrap_resampling_times bootstrap resamples.

        Each resample draws min(train_size, test_size) indexes with replacement from the
        train scores and, independently, from the test scores. It then evaluates
        ``MIAResults.compute_metrics_at_error_threshold`` at every threshold in
        self._error_thresholds.

        Returns:
            List of tuples, one per resample, with elements
            (accuracy,
             auc_value,
             eps_fpr_array,
             eps_tpr_array,
             eps_max_array)
        """
        score_train = self.analysis_input.df_train_user["score"]
        score_test = self.analysis_input.df_test_user["score"]
        loss_train = torch.from_numpy(score_train.to_numpy())
        loss_test = torch.from_numpy(score_test.to_numpy())

        train_size, test_size = score_train.shape[0], score_test.shape[0]
        bootstrap_sample_size = min(train_size, test_size)

        # Train indexes are drawn before test indexes in every iteration, which keeps
        # the RNG consumption order fixed.
        metrics_array = [
            MIAResults(
                loss_train[
                    self._compute_bootstrap_sample_indexes(
                        train_size, bootstrap_sample_size
                    )
                ],
                loss_test[
                    self._compute_bootstrap_sample_indexes(
                        test_size, bootstrap_sample_size
                    )
                ],
            ).compute_metrics_at_error_threshold(
                self._delta,
                error_threshold=self._error_thresholds,
                cap_eps=self._cap_eps,
                use_fnr_tnr=self._use_fnr_tnr,
                verbose=False,
            )
            for _ in tqdm(
                range(self._num_bootstrap_resampling_times),
                disable=not self._show_progress,
            )
        ]
        return metrics_array

    @contextmanager
    def timer(self, name: str) -> Generator[None, None, None]:
        """
        Context manager that records how long its body takes under ``name``.

        Only active when the node was created with ``with_timer=True``; otherwise it
        is a no-op. The duration is measured with ``time.perf_counter`` (monotonic, so
        unaffected by system clock changes), in seconds. Nothing is recorded if the
        body raises.
        """
        if not self._with_timer:
            yield
            return
        start = time.perf_counter()
        yield
        self._timer_stats[name] = time.perf_counter() - start

    def get_timer_stats(self) -> TimerStats:
        """
        Return the recorded timings (section name -> seconds).

        The internal dict itself is returned, not a copy.
        """
        return self._timer_stats