|
| 1 | +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. |
| 2 | + |
| 3 | +# pyre-strict |
| 4 | +import logging |
| 5 | +from dataclasses import dataclass |
| 6 | +from typing import List, Tuple |
| 7 | + |
| 8 | +import numpy as np |
| 9 | +import torch |
| 10 | +from privacy_guard.analysis.base_analysis_output import BaseAnalysisOutput |
| 11 | +from privacy_guard.analysis.lia.lia_analysis_input import LIAAnalysisInput |
| 12 | +from privacy_guard.analysis.mia.analysis_node import AnalysisNode |
| 13 | +from privacy_guard.analysis.mia.mia_results import MIAResults |
| 14 | +from tqdm import tqdm |
| 15 | + |
# Module-level logger following the standard `getLogger(__name__)` convention.
logger: logging.Logger = logging.getLogger(__name__)

# Alias mapping a timed-phase name to its elapsed time.
# NOTE(review): assumed to back `_timer_stats` populated via AnalysisNode's
# timer; units presumably seconds — confirm against AnalysisNode.timer.
TimerStats = dict[str, float]
| 19 | + |
| 20 | + |
@dataclass
class LIAAnalysisOutput(BaseAnalysisOutput):
    """
    A dataclass to encapsulate the outputs of LIAAnalysisNode.

    Epsilon values are reported as lower/upper bounds (LB/UB) derived from
    bootstrap confidence intervals, both at the single worst-case error
    threshold and per-threshold across all thresholds evaluated.
    """

    # Epsilon upper bound: the highest UB across all error thresholds.
    eps: float
    # Lower bound associated with the UB epsilon above (same threshold).
    eps_lb: float
    # Attack accuracy and AUC, each with a bootstrap confidence interval.
    accuracy: float
    accuracy_ci: List[float]  # [lower, upper] confidence interval
    auc: float
    auc_ci: List[float]  # [lower, upper] confidence interval
    # Error rate at which the maximum eps upper bound is achieved.
    error_rate_at_max_eps: float
    # Per-threshold eps LB and UB where eps = max(eps_tpr, eps_fpr).
    eps_max_bounds: Tuple[
        List[float], List[float]
    ]  # (LB list, UB list), one entry per error threshold
    # Per-threshold eps LB and UB computed from TPR thresholds only.
    eps_at_tpr_bounds: Tuple[
        List[float], List[float]
    ]  # (LB list, UB list)
    # Per-threshold eps LB and UB computed from FPR thresholds only.
    eps_at_fpr_bounds: Tuple[
        List[float], List[float]
    ]  # (LB list, UB list)
    # Number of samples per game instantiation.
    data_size: int
    # Mean of the input labels (y0) and of the two prediction arrays.
    label_mean: float
    prediction_mean: float
    prediction_y1_generation_mean: float
| 49 | + |
| 50 | + |
class LIAAnalysisNode(AnalysisNode):
    """Computes label inference attack (LIA) privacy metrics.

    For each game instantiation the node scores every sample, bootstraps the
    resulting train/test score sets, converts the attack metrics into epsilon
    bounds via MIAResults, and aggregates everything into a LIAAnalysisOutput.
    """

    def __init__(
        self,
        analysis_input: LIAAnalysisInput,
        delta: float,
        num_bootstrap_resampling_times: int = 10,
        cap_eps: bool = True,
        show_progress: bool = False,
        with_timer: bool = False,
        power: float = 0.0,
        use_fnr_and_tnr: bool = False,
    ) -> None:
        """
        Args:
            analysis_input: LIA inputs (labels, predictions, per-game bits).
            delta: delta used when converting attack metrics to epsilon.
            num_bootstrap_resampling_times: bootstrap repetitions per game.
            cap_eps: whether to cap epsilon estimates (forwarded to
                MIAResults.compute_metrics_at_error_threshold).
            show_progress: show a tqdm progress bar and verbose metric logs.
            with_timer: enable timing of analysis phases.
            power: non-negative exponent for the score weighting term.
            use_fnr_and_tnr: also use FNR/TNR thresholds in the metric
                computation (forwarded as use_fnr_tnr).

        Raises:
            ValueError: if power is negative.
        """
        if power < 0:
            raise ValueError("Power used for score function must be non-negative")

        self._delta = delta
        self._num_bootstrap_resampling_times = num_bootstrap_resampling_times
        self._cap_eps = cap_eps
        self._show_progress = show_progress
        self._with_timer = with_timer
        self._analysis_input = analysis_input
        # Use the module-level TimerStats alias so the annotation stays in
        # sync with the rest of the file.
        self._timer_stats: TimerStats = {}
        self._power = power
        self._use_fnr_and_tnr = use_fnr_and_tnr

    def compute_scores(self, i: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute scores for a given game instantiation.

        Args:
            i (int): index of the game instantiation (i.e. the i-th game)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: scores for samples with training labels and reconstructed labels
        """

        true_bits = self._analysis_input.true_bits[i]

        received_labels = self._analysis_input.received_labels[i]
        y1_probs = self._analysis_input.predictions_y1_generation
        predictions = self._analysis_input.predictions

        # Probability assigned to the received label by each model/source.
        prob_train = np.where(received_labels == 1, predictions, 1 - predictions)
        prob_reconstruct = np.where(received_labels == 1, y1_probs, 1 - y1_probs)
        # Probability of the opposite label under the y1-generation model;
        # used as a confidence weight raised to self._power.
        prob_diff_label = np.where(received_labels == 1, 1 - y1_probs, y1_probs)

        scores = (prob_train - prob_reconstruct) * prob_diff_label**self._power

        # true_bits == 0 marks samples whose labels came from training;
        # true_bits == 1 marks reconstructed-label samples.
        scores_train = torch.tensor(scores[true_bits == 0])
        scores_test = torch.tensor(scores[true_bits == 1])

        return scores_train, scores_test

    def run_analysis(self) -> BaseAnalysisOutput:
        """Run LIA analysis across all game instantiations.

        Returns:
            LIAAnalysisOutput with epsilon bounds, accuracy/AUC estimates and
            dataset statistics.
        """

        error_thresholds = np.linspace(0.01, 1, 100)
        num_resampling = self._analysis_input.y1.shape[0]
        num_samples = self._analysis_input.y1.shape[1]

        # Run the attack for each game instance, with bootstrap resampling.
        all_metrics = []
        with self.timer("compute all metrics"):
            for i in tqdm(range(num_resampling), disable=not self._show_progress):
                scores_train, scores_test = self.compute_scores(i)
                train_size, test_size = scores_train.shape[0], scores_test.shape[0]
                # Draw equal-size bootstrap samples from both score sets.
                bootstrap_sample_size = min(train_size, test_size)
                for _ in range(self._num_bootstrap_resampling_times):
                    indices_train = AnalysisNode._compute_bootstrap_sample_indexes(
                        train_size, bootstrap_sample_size
                    )
                    indices_test = AnalysisNode._compute_bootstrap_sample_indexes(
                        test_size, bootstrap_sample_size
                    )
                    lia_results = MIAResults(
                        scores_train=scores_train[indices_train],
                        scores_test=scores_test[indices_test],
                    )

                    # metrics is a tuple:
                    # (accuracy, auc_value, eps_fpr_array, eps_tpr_array, eps_max_array)
                    metrics = lia_results.compute_metrics_at_error_threshold(
                        delta=self._delta,
                        error_threshold=error_thresholds,
                        cap_eps=self._cap_eps,
                        verbose=self._show_progress,
                        use_fnr_tnr=self._use_fnr_and_tnr,
                    )

                    all_metrics.append(metrics)

        all_accuracy_values = np.array([run[0] for run in all_metrics])
        all_auc_values = np.array([run[1] for run in all_metrics])
        all_eps_fpr_values = np.array([run[2] for run in all_metrics])
        all_eps_tpr_values = np.array([run[3] for run in all_metrics])
        all_eps_values = np.array([run[4] for run in all_metrics])

        # Per-error-threshold confidence bounds over all bootstrap runs.
        eps_lb_per_threshold, eps_ub_per_threshold = self._compute_ci(all_eps_values)
        # The reported epsilon is the maximum upper bound across thresholds.
        idx = np.argmax(eps_ub_per_threshold)

        error_rate_at_max_eps = float(error_thresholds[idx])

        eps_max_ub = eps_ub_per_threshold[idx]
        eps_lb_at_max_ub = eps_lb_per_threshold[idx]

        # Compute lb/ub for accuracy and auc (already ndarrays; the previous
        # np.array(...) re-wrapping was redundant).
        accuracy_lb, accuracy_ub = self._compute_ci(all_accuracy_values)
        auc_lb, auc_ub = self._compute_ci(all_auc_values)

        # Compute lb/ub for eps computed using only TPR or only FPR thresholds
        eps_tpr_lb, eps_tpr_ub = self._compute_ci(all_eps_tpr_values)
        eps_fpr_lb, eps_fpr_ub = self._compute_ci(all_eps_fpr_values)

        # Cast NumPy scalars to builtin floats so the output matches the
        # float / List[float] field types declared on LIAAnalysisOutput
        # (previously only eps/eps_lb were cast; the rest leaked np.float64).
        return LIAAnalysisOutput(
            eps=float(eps_max_ub),
            eps_lb=float(eps_lb_at_max_ub),
            accuracy=float(np.mean(all_accuracy_values)),
            accuracy_ci=[float(accuracy_lb[0]), float(accuracy_ub[0])],
            auc=float(np.mean(all_auc_values)),
            auc_ci=[float(auc_lb[0]), float(auc_ub[0])],
            error_rate_at_max_eps=error_rate_at_max_eps,
            eps_max_bounds=(
                [float(v) for v in eps_lb_per_threshold],
                [float(v) for v in eps_ub_per_threshold],
            ),
            eps_at_tpr_bounds=(
                [float(v) for v in eps_tpr_lb],
                [float(v) for v in eps_tpr_ub],
            ),
            eps_at_fpr_bounds=(
                [float(v) for v in eps_fpr_lb],
                [float(v) for v in eps_fpr_ub],
            ),
            data_size=num_samples,
            label_mean=float(np.mean(self._analysis_input.y0)),
            prediction_mean=float(np.mean(self._analysis_input.predictions)),
            prediction_y1_generation_mean=float(
                np.mean(self._analysis_input.predictions_y1_generation)
            ),
        )
0 commit comments