forked from facebookresearch/PrivacyGuard
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlira_attack.py
More file actions
233 lines (205 loc) · 9.2 KB
/
lira_attack.py
File metadata and controls
233 lines (205 loc) · 9.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pyre-strict
import logging
from typing import Tuple, Union
import pandas as pd
from pandas import Series
from privacy_guard.analysis.mia.aggregate_analysis_input import (
AggregateAnalysisInput,
AggregationType,
)
from privacy_guard.attacks.base_attack import BaseAttack
from scipy.stats import norm
# Module-level logger; LiraAttack.run_attack logs the train/test frame shapes
# before and after dropping rows whose LiRA score is NaN.
logger: logging.Logger = logging.getLogger(__name__)
class LiraAttack(BaseAttack):
"""
This is an implementation of an MIA attack
In the LiRA attack, there is a target model (orig) that contains the users in the
hold_out_train set.
There also is a set of N reference (shadow) models that do not
contain these hold_out_train users (in the offline attack) and a further set of N reference shadow models that do contain these hold_out_train users (in the online attack).
For each user, scores in the shadow tables are aggregated and then combined with the ones in the orig table to generate a final score
"""
def __init__(
self,
df_train_merge: pd.DataFrame,
df_test_merge: pd.DataFrame,
row_aggregation: AggregationType,
user_id_key: str = "user_id",
use_fixed_variance: bool = False,
std_dev_type: str = "global",
online_attack: bool = False,
offline_shadows_evals_in: bool = False,
) -> None:
"""
args:
df_train_merge: training data dataframe
df_test_merge: test data dataframe
has columns "score_orig" from the orig table
also has "score_mean" and "score_std" from the shadow tables
row_aggregation: specifies user aggregation strategy
user_id_key: key corresponding to user id, used for aggregating scores.
used_fixed_variance: whether to use fixed variance or not,
normalizing using the orig scores of the attack.
std_dev_type: specifies the type of standard deviation to be used in the attack calculations.
online_attack: indicates whether the attack is an online attack. It defaults to False, indicating that the attack is offline.
offline_shadows_evals_in: specifies whether the offline shadow evaluations are included in the attack.
It defaults to False, indicating that offline shadow evaluations are not included unless specified otherwise.
Returns:
AnalysisInput has:
start_eval_ds: in sql query, where to start the date range
end_eval_ds: in sql query, where to end the date range
start_eval_ts_hour: int = 0
end_eval_ts_hour: int = 23
users_intersect_eval_window: number of distinct user timestamps allowed in window
apply_hard_cut: whether to apply hard cut
eval_output_table_settings: dict of configuration for output table
"""
self.df_train_merge = df_train_merge
self.df_test_merge = df_test_merge
self.row_aggregation: AggregationType = row_aggregation
self.user_id_key = user_id_key
self.std_dev_type = std_dev_type
self.online_attack = online_attack
self.offline_shadows_evals_in = offline_shadows_evals_in
self.use_fixed_variance = use_fixed_variance
def _get_std_dev(self) -> Tuple[Union[float, Series], Union[float, Series]]:
"""
Get the std dev for the in and out scores.
Returns:
std_in: std dev for the in scores
std_out: std dev for the out scores
"""
std_in, std_out = 0.0, 0.0
match self.std_dev_type:
case "global":
std_in = std_out = (
pd.concat(
[
self.df_train_merge.score_orig,
self.df_test_merge.score_orig,
]
)
).std()
case "shadows_in":
if self.online_attack:
std_in = std_out = pd.concat(
[
self.df_train_merge.score_std_in,
self.df_test_merge.score_std_in,
]
).mean()
else:
# offline case where std dev is computed on the hold out test set
std_in = std_out = pd.concat(
[
self.df_train_merge.score_std,
self.df_test_merge.score_std,
]
).mean()
case "shadows_out":
if self.online_attack:
std_in = std_out = pd.concat(
[
self.df_train_merge.score_std_out,
self.df_test_merge.score_std_out,
]
).mean()
else:
# offline case where std dev is computed on the hold out test set
std_in = std_out = pd.concat(
[
self.df_train_merge.score_std,
self.df_test_merge.score_std,
]
).mean()
case "mix":
if not self.online_attack:
raise ValueError(
"mix std dev type is only supported for online attacks"
)
std_in = pd.concat(
[
self.df_train_merge.score_std_in,
self.df_test_merge.score_std_in,
]
).mean()
std_out = pd.concat(
[
self.df_train_merge.score_std_out,
self.df_test_merge.score_std_out,
]
).mean()
case _:
raise ValueError(f"{self.std_dev_type} is not a valid std_dev type.")
return std_in, std_out
def run_attack(self) -> AggregateAnalysisInput:
"""
Run lira attack on the shadows and original models.
Returns:
AggregateAnalysisInput: input for analysis with train and testing datasets
"""
std_in, std_out = self._get_std_dev()
if self.online_attack:
self.df_train_merge["score"] = norm.logpdf(
self.df_train_merge.score_orig,
self.df_train_merge.score_mean_in,
std_in if self.use_fixed_variance else self.df_train_merge.score_std_in,
) - norm.logpdf(
self.df_train_merge.score_orig,
self.df_train_merge.score_mean_out,
std_out
if self.use_fixed_variance
else self.df_train_merge.score_std_out,
)
self.df_test_merge["score"] = norm.logpdf(
self.df_test_merge.score_orig,
self.df_test_merge.score_mean_in,
std_in if self.use_fixed_variance else self.df_test_merge.score_std_in,
) - norm.logpdf(
self.df_test_merge.score_orig,
self.df_test_merge.score_mean_out,
std_out
if self.use_fixed_variance
else self.df_test_merge.score_std_out,
)
else:
self.df_train_merge["score"] = norm.logpdf(
self.df_train_merge.score_orig,
self.df_train_merge.score_mean,
std_in,
)
self.df_test_merge["score"] = norm.logpdf(
self.df_test_merge.score_orig, self.df_test_merge.score_mean, std_out
)
logger.info(
f"before NaN removal for logpdf results: train {self.df_train_merge.shape} and test {self.df_test_merge.shape}"
)
self.df_train_merge = self.df_train_merge.dropna(subset=["score"])
self.df_test_merge = self.df_test_merge.dropna(subset=["score"])
logger.info(
f"after NaN removal for logpdf results: train {self.df_train_merge.shape} and test {self.df_test_merge.shape}"
)
if not (self.online_attack or self.offline_shadows_evals_in):
# this corresponds to the case of offline shadows evals on the hold out test set
self.df_train_merge["score"] = -self.df_train_merge["score"]
self.df_test_merge["score"] = -self.df_test_merge["score"]
analysis_input = AggregateAnalysisInput(
row_aggregation=self.row_aggregation,
df_train_merge=self.df_train_merge,
df_test_merge=self.df_test_merge,
user_id_key=self.user_id_key,
)
return analysis_input