Skip to content

Commit b0032a6

Browse files
authored
Merge pull request #32 from HashirA123/Probe
Probe Implementation
2 parents aaf9232 + 14cef2c commit b0032a6

File tree

13 files changed

+715
-11
lines changed

13 files changed

+715
-11
lines changed

experiments/experimental_setup.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,5 +138,7 @@ recourse_methods:
138138
hyperparams:
139139
loss_type: "BCE"
140140
binary_cat_features: True
141+
probe:
142+
hyperparams:
141143
roar:
142144
hyperparams:

experiments/results.csv

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1415,6 +1415,34 @@ claproar,twomoon,linear,0.0,2.2577137270829443e-08,3.667492287085482e-16,1.87680
14151415
claproar,twomoon,linear,0.0,1.6213392628472434e-08,1.431988503656087e-16,1.0531753025233572e-08,0.0,0.0,,,
14161416
claproar,twomoon,linear,0.0,3.312241828035134e-08,7.202236947752843e-16,2.5826099814274528e-08,0.0,0.0,,,
14171417
claproar,twomoon,linear,0.0,3.2146713291325575e-08,6.212848124877535e-16,2.330451231991049e-08,0.0,0.0,,,
1418+
probe,adult,linear,51.0,2.5574682458217044,0.20724095760647737,0.10291039943695068,2.0,51.0,0.0,1.0,11.27795516
1419+
probe,adult,linear,48.0,1.6209024338863478,0.0782517097074313,0.06244194507598877,2.0,51.0,,,
1420+
probe,adult,linear,51.0,6.151970284269187,1.3522899002686382,0.268756240606308,2.0,48.0,,,
1421+
probe,adult,linear,47.0,4.1338594518437715,0.6257325925713797,0.1863815188407898,2.0,44.0,,,
1422+
probe,adult,linear,49.0,6.529014715911816,1.5907548243066623,0.2905552387237549,2.0,27.0,,,
1423+
probe,compass,linear,7.0,1.0493243297463968,0.15758896097912403,0.15593880414962769,5.0,3.0,0.0,1.0,3.9818648800000007
1424+
probe,compass,linear,7.0,1.1968591064214706,0.20529106263232566,0.178464874625206,5.0,3.0,,,
1425+
probe,compass,linear,7.0,0.5305286708631014,0.040413185530297796,0.0809311717748642,5.0,6.0,,,
1426+
probe,compass,linear,7.0,2.068292945623398,0.6289491666450724,0.3217114806175232,5.0,2.0,,,
1427+
probe,compass,linear,6.0,0.11717507370600574,0.002380937481501453,0.024646831676363945,5.0,7.0,,,
1428+
probe,credit,linear,20.0,0.7460355385040021,0.03599241548631901,0.05706929787993431,5.0,18.0,0.2666666666666667,1.0,4.403598066666665
1429+
probe,credit,linear,18.0,0.2715397661723585,0.005983706032008009,0.028109369799494743,5.0,20.0,,,
1430+
probe,credit,linear,18.0,0.2347176402264215,0.004717934816094677,0.025857295840978622,5.0,20.0,,,
1431+
probe,german,linear,4.0,0.34145124321402176,0.02916159435279554,0.08767480622319615,2.0,2.0,0.19999999999999996,1.0,3.3755328400000026
1432+
probe,german,linear,4.0,0.3343678011115066,0.027964667531387127,0.08589213144253283,2.0,2.0,,,
1433+
probe,german,linear,4.0,0.267133860467632,0.017855097803688207,0.06914474771303289,2.0,2.0,,,
1434+
probe,german,linear,4.0,0.3044240432907651,0.023182781515892512,0.07841214362312765,2.0,2.0,,,
1435+
probe,german,linear,4.0,0.2473267066867472,0.015308234751386422,0.06424275391242085,2.0,3.0,,,
1436+
probe,mortgage,linear,2.0,2.596893806724318,3.3737499230649304,1.3286230641796841,0.0,0.0,0.0,1.0,7.078433219999999
1437+
probe,mortgage,linear,2.0,2.7662955305674433,3.8262065210593192,1.3854972163486525,0.0,0.0,,,
1438+
probe,mortgage,linear,2.0,2.9121862424407814,4.241009224725797,1.473339416403974,0.0,0.0,,,
1439+
probe,mortgage,linear,2.0,2.904149937359671,4.217532586486467,1.4677139768545209,0.0,0.0,,,
1440+
probe,mortgage,linear,2.0,3.148416728152398,4.960202167161096,1.6185830196025777,0.0,0.0,,,
1441+
probe,boston_housing,linear,12.0,1.3520136007298742,0.16837960058355547,0.13424224549151464,0.0,9.0,0.0,1.0,3.862365679999999
1442+
probe,boston_housing,linear,11.0,1.007249276669714,0.09299927052223386,0.09963742976493017,0.0,9.0,,,
1443+
probe,boston_housing,linear,13.0,1.3172484222248864,0.15939193355054326,0.12997258850868726,0.0,9.0,,,
1444+
probe,boston_housing,linear,12.0,1.2995529090837012,0.15490400031856688,0.12785010134920172,0.0,9.0,,,
1445+
probe,boston_housing,linear,12.0,1.0059221718260045,0.09273649552678494,0.09945072011744727,0.0,9.0,,,
14181446
roar,adult,linear,5.0,9.618374680646278,18.55040773192179,2.03277587890625,1.0,5.0,0.05999999999999994,1.0,1.0735000799999999
14191447
roar,adult,linear,5.0,8.90917690170632,15.907446315902012,1.8774079084396362,1.0,5.0,,,
14201448
roar,adult,linear,5.0,14.060748848458747,39.60930977247996,2.9490909576416016,1.0,5.0,,,

experiments/run_experiment.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,8 @@ def initialize_recourse_method(
164164
return Revise(mlmodel, data, hyperparams)
165165
elif "wachter" in method:
166166
return Wachter(mlmodel, hyperparams)
167+
elif method == "probe":
168+
return Probe(mlmodel, hyperparams)
167169
elif method == "roar":
168170
return Roar(mlmodel, hyperparams)
169171
else:
@@ -195,7 +197,7 @@ def create_parser():
195197
-r, --recourse_method: Specifies recourse methods for the experiment.
196198
Default: ["dice", "cchvae", "cem", "cem_vae", "clue", "cruds", "face_knn", "face_epsilon", "gs", "mace", "revise", "wachter"].
197199
Choices: ["dice", "ar", "causal_recourse", "cchvae", "cem", "cem_vae", "claproar", "clue", "cruds", "face_knn", "face_epsilon", "feature_tweak",
198-
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "roar"].
200+
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "roar", "probe"].
199201
-n, --number_of_samples: Specifies the number of instances per dataset.
200202
Default: 20.
201203
-s, --train_split: Specifies the split of the available data used for training.
@@ -284,6 +286,7 @@ def create_parser():
284286
"mace",
285287
"revise",
286288
"wachter",
289+
"probe",
287290
"roar",
288291
],
289292
help="Recourse methods for experiment",
@@ -364,6 +367,7 @@ def create_parser():
364367
"gravitational",
365368
"wachter",
366369
"revise",
370+
"probe",
367371
"roar",
368372
]
369373
sklearn_methods = ["feature_tweak", "focus", "mace"]

methods/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
Gravitational,
1818
Greedy,
1919
GrowingSpheres,
20+
Probe,
2021
Revise,
2122
Roar,
2223
Wachter,

methods/catalog/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .greedy import Greedy
1515
from .growing_spheres import GrowingSpheres
1616
from .mace import MACE
17+
from .probe import Probe
1718
from .revise import Revise
1819
from .roar import Roar
1920
from .wachter import Wachter

methods/catalog/probe/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# flake8: noqa
2+
3+
from .model import Probe
4+
from .reproduce import test_probe
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# flake8: noqa
2+
3+
from .probe import probe_recourse
Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
import datetime
2+
from typing import List, Optional
3+
4+
import numpy as np
5+
import torch
6+
import torch.distributions.normal as normal_distribution
7+
import torch.optim as optim
8+
from torch import nn
9+
from torch.autograd import Variable
10+
from torch.distributions.multivariate_normal import MultivariateNormal
11+
12+
from methods.processing import reconstruct_encoding_constraints
13+
14+
"""
15+
This file contains the implementation of the Probe method, along with required helper functions
16+
"""
17+
18+
DECISION_THRESHOLD = 0.5
19+
20+
# Mean and variance for rectified normal distribution:
21+
# see in here : http://journal-sfds.fr/article/view/669
22+
23+
24+
def compute_jacobian(inputs, output):
    """
    Jacobian of ``output`` with respect to ``inputs``.

    :param inputs: Batch X Size (e.g. Depth X Width X Height); must have
        ``requires_grad=True`` so autograd can differentiate through it
    :param output: Batch X Classes
    :return: jacobian: Batch X Classes X Size
    """
    assert inputs.requires_grad
    # Delegates to the shared gradient helper; no intermediate needed.
    return gradient(output, inputs)
33+
34+
35+
def gradient(y, x, grad_outputs=None):
    """Compute dy/dx @ grad_outputs.

    Parameters
    ----------
    y: output tensor (any shape) with a graph connecting it to ``x``
    x: input tensor with ``requires_grad=True``
    grad_outputs: optional weighting tensor for the vector-Jacobian product;
        defaults to ones shaped like ``y`` (the scalar 1 when ``y`` is 0-dim)

    Returns
    -------
    Tensor shaped like ``x`` holding dy/dx @ grad_outputs, with
    ``create_graph=True`` so higher-order gradients can flow through it.
    """
    if grad_outputs is None:
        # torch.ones_like matches y's shape, dtype and device. The previous
        # default, torch.tensor(1, device=y.device), was an int64 0-dim
        # tensor, which mismatches float outputs and breaks for non-scalar y.
        grad_outputs = torch.ones_like(y)
    grad = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0]
    return grad
41+
42+
43+
def compute_invalidation_rate_closed(torch_model, x, sigma2):
    """Closed-form approximation of the recourse invalidation rate.

    Linearizes the model logit around ``x`` (first-order Taylor expansion),
    so the logit of a Gaussian-perturbed input is itself Gaussian; the
    invalidation rate is then 1 - Phi(logit / (sqrt(sigma2) * ||J||_2)).

    Parameters
    ----------
    torch_model: classifier returning class probabilities; column 1 is the
        favorable class (assumes a single instance, batch size 1 —
        NOTE(review): confirm against callers)
    x: input tensor with ``requires_grad=True`` (needed for the Jacobian)
    sigma2: scalar tensor, variance of the isotropic Gaussian input noise

    Returns
    -------
    Scalar tensor in [0, 1]: the approximate invalidation rate.
    """
    # Compute input into CDF: the logit of the favorable class.
    prob = torch_model(x)
    logit_x = torch.log(prob[0][1] / prob[0][0])
    # (Removed unused `Sigma2 = sigma2 * torch.eye(x.shape[0])` local.)
    jacobian_x = compute_jacobian(x, logit_x).reshape(-1)
    # Std of the linearized logit under N(0, sigma2 * I) input noise.
    denom = torch.sqrt(sigma2) * torch.norm(jacobian_x, 2)
    arg = logit_x / denom

    # Evaluate standard Gaussian cdf
    normal = normal_distribution.Normal(loc=0.0, scale=1.0)
    normal_cdf = normal.cdf(arg)

    # Get invalidation rate
    ir = 1 - normal_cdf

    return ir
60+
61+
62+
def perturb_sample(x, n_samples, sigma2):
    """Draw ``n_samples`` Gaussian perturbations of the instance ``x``.

    ``x`` is a (1, d) tensor; the result is an (n_samples, d) tensor whose
    rows are x + eps with eps ~ N(0, sigma2 * I).
    """
    # Tile the single instance into n_samples identical rows.
    tiled = x.repeat(n_samples, 1)
    # Isotropic covariance for the additive noise.
    covariance = torch.eye(x.shape[1]) * sigma2
    noise_dist = MultivariateNormal(
        loc=torch.zeros(x.shape[1]), covariance_matrix=covariance
    )
    noise = noise_dist.sample((n_samples,))

    return tiled + noise
72+
73+
74+
def reparametrization_trick(mu, sigma2, n_samples):
    """Sample from N(mu, sigma2 * I) via the reparametrization trick.

    Draws standard-normal noise and shifts/scales it (sample = mu +
    sqrt(sigma2) * eps) so gradients can flow back into ``mu``.

    ``mu`` is a (1, d) tensor; the result is an (n_samples, d) tensor.
    """
    scale = torch.sqrt(sigma2)
    standard_normal = MultivariateNormal(
        loc=torch.zeros(mu.shape[1]), covariance_matrix=torch.eye(mu.shape[1])
    )
    # Standard Gaussian noise, one row per sample.
    noise = standard_normal.sample((n_samples,))
    # Broadcast mu across the rows; multiplying by an explicit ones tensor
    # keeps the exact arithmetic of the original formulation.
    broadcast = torch.ones_like(noise)

    return mu.reshape(-1) * broadcast + scale * noise
85+
86+
87+
def compute_invalidation_rate(torch_model, random_samples):
    """Monte-Carlo estimate of the recourse invalidation rate.

    Classifies every perturbed sample and returns the fraction whose
    favorable-class score (column 1) drops to or below the 0.5 threshold,
    i.e. the fraction of samples on which the recourse is invalidated.
    """
    positive_scores = torch_model(random_samples)[:, 1]
    # 1.0 where the perturbed sample is still classified favorably.
    still_valid = (positive_scores > 0.5).float()
    return 1 - still_valid.mean(dim=0)
92+
93+
94+
def probe_recourse(
    torch_model,
    x: np.ndarray,
    cat_feature_indices: List[int],
    binary_cat_features: bool = True,
    feature_costs: Optional[List[float]] = None,
    lr: float = 0.07,
    lambda_param: float = 5,
    y_target: List[float] = [0.45, 0.55],
    n_iter: int = 500,
    t_max_min: float = 1.0,
    norm: int = 1,
    clamp: bool = False,
    loss_type: str = "MSE",
    invalidation_target: float = 0.45,
    inval_target_eps: float = 0.005,
    noise_variance: float = 0.01,
) -> np.ndarray:
    """
    Generates a probabilistically robust counterfactual example (PROBE) for
    input instance x: a Wachter-style gradient search whose loss additionally
    penalizes the closed-form invalidation rate of the candidate under
    Gaussian input noise.

    Parameters
    ----------
    torch_model: black-box-model to discover
    x: factual to explain
    cat_feature_indices: list of positions of categorical features in x
    binary_cat_features: If true, the encoding of x is done by drop_if_binary
    feature_costs: List with costs per feature
    lr: learning rate for gradient descent
    lambda_param: weight factor for feature_cost
    y_target: List of one-hot-encoded target class
    n_iter: maximum number of iterations per outer round
    t_max_min: maximum time of search (minutes)
    norm: L-norm to calculate cost
    clamp: If true, feature values will be clamped to (0, 1)
    loss_type: String for loss function (MSE or BCE)
    invalidation_target: target invalidation rate
    inval_target_eps: tolerance around the invalidation target
    noise_variance: variance of the normal distribution for sampling

    Returns
    -------
    Counterfactual example as np.ndarray
    """
    # device = "cpu" # for simplicity and to avoid Runtime error.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    torch_model = torch_model.to(device)
    # returns counterfactual instance
    torch.manual_seed(0)  # fixed seed makes the noise sampling reproducible
    noise_variance = torch.tensor(noise_variance)

    # if feature_costs is not None:
    #     feature_costs = torch.from_numpy(feature_costs).float().to(device)

    # print("x:", x)

    # NOTE(review): assumes torch_model returns class probabilities of shape
    # (1, 2) with column 1 the favorable class — confirm against callers.
    x = torch.from_numpy(x).float().to(device)
    y_target = torch.tensor(y_target).float().to(device)
    lamb = torch.tensor(lambda_param).float().to(device)
    # x_new is used for gradient search in optimizing process
    x_new = Variable(x.clone(), requires_grad=True)
    # x_new_enc is a copy of x_new with reconstructed encoding constraints of x_new
    # such that categorical data is either 0 or 1

    # x_new_enc = reconstruct_encoding_constraints( #TODO: check if this is needed here, i believe that the encoding is done in the model prediction
    #     x_new, cat_feature_indices, binary_cat_features
    # )

    optimizer = optim.Adam([x_new], lr, amsgrad=True)
    softmax = nn.Softmax()

    if loss_type == "MSE":
        loss_fn = torch.nn.MSELoss()
        f_x_new = softmax(torch_model(x_new))[:, 1]
    else:
        loss_fn = torch.nn.BCELoss()
        f_x_new = torch_model(x_new)[:, 1]

    t0 = datetime.datetime.now()
    t_max = datetime.timedelta(minutes=t_max_min)

    costs = []  # cost of each successful candidate found so far
    ces = []  # the corresponding candidate counterfactuals

    random_samples = reparametrization_trick(x_new, noise_variance, n_samples=1000)
    invalidation_rate = compute_invalidation_rate(torch_model, random_samples)

    # Outer loop: keep searching (relaxing the cost weight `lamb` each round)
    # until the candidate is classified favorably AND sufficiently robust.
    while (f_x_new <= DECISION_THRESHOLD) or (
        invalidation_rate > invalidation_target + inval_target_eps
    ):
        # it = 0
        for it in range(n_iter):
            # while invalidation_target >= 0.5 and it < n_iter:

            optimizer.zero_grad()
            # x_new_enc = reconstruct_encoding_constraints(
            #     x_new, cat_feature_indices, binary_cat_features
            # )
            # use x_new_enc for prediction results to ensure constraints
            # f_x_new = softmax(torch_model(x_new))[:, 1]
            f_x_new_binary = torch_model(x_new).squeeze(axis=0)

            # Proximity cost: distance between candidate and factual.
            cost = (
                torch.dist(x_new, x, norm)
                # if feature_costs is None
                # else torch.norm(feature_costs * (x_new - x), norm)
            )

            # Compute Invalidation loss
            # output_mean, output_std = compute_output_dist_suff_statistics(torch_model, x_new,
            #                                                               noise_variance=noise_variance)

            # normal = normal_distribution.Normal(loc=0.0, scale=1.0)
            # ratio = torch.divide(output_mean, output_std)
            # normal_cdf = normal.cdf(ratio)
            # invalidation_rate = 1 - normal_cdf

            # invalidation_rate = compute_invalidation_rate(torch_model, random_samples)
            invalidation_rate_c = compute_invalidation_rate_closed(
                torch_model, x_new, noise_variance
            )

            # Compute & update losses
            loss_invalidation = invalidation_rate_c - invalidation_target
            # Hinge loss: only penalize exceeding the target rate.
            loss_invalidation[loss_invalidation < 0] = 0

            loss = (
                3 * loss_invalidation + loss_fn(f_x_new_binary, y_target) + lamb * cost
            )
            loss.backward()
            optimizer.step()

            # Re-estimate robustness of the updated candidate by Monte Carlo.
            random_samples = reparametrization_trick(
                x_new, noise_variance, n_samples=10000
            )
            invalidation_rate = compute_invalidation_rate(torch_model, random_samples)

            # x_pertub = perturb_sample(x_new, sigma2=noise_variance, n_samples=10000)
            # pred = 1 - torch_model(x_pertub)[:, 1]
            # invalidation_rate_empirical = torch.mean(pred)

            # print('-----------------------------------------')
            # print('IR empirical', invalidation_rate_empirical)
            # print('IR from loss', invalidation_rate)
            # print('IR loss', loss_invalidation)

            # clamp potential CF
            if clamp:
                # NOTE(review): clamp_ acts on a clone, so x_new itself is
                # never modified — confirm whether this is intentional.
                x_new.clone().clamp_(0, 1)
            # it += 1

            # x_new_enc = reconstruct_encoding_constraints(
            #     x_new, cat_feature_indices, binary_cat_features
            # )
            # f_x_new = torch_model(x_new_enc)[:, 1]
            f_x_new = torch_model(x_new)[:, 1]

            # Success: favorably classified and within the robustness target.
            if (f_x_new > DECISION_THRESHOLD) and (
                invalidation_rate < invalidation_target + inval_target_eps
            ):
                print("--------------------------------------")
                print("invalidation rate:", invalidation_rate)
                # print('emp invalidation rate', invalidation_rate_empirical)
                print("cost:", cost)
                print("classifier output:", f_x_new_binary)

                costs.append(cost)
                ces.append(x_new)

                break

        # Relax the cost weight so the next round may move further from x.
        lamb -= 0.10

        if datetime.datetime.now() - t0 > t_max:
            print("Timeout")
            break

    if not ces:
        print(
            "No Counterfactual Explanation Found at that Target Rate - Try Different Target"
        )
        return x_new.cpu().detach().numpy().squeeze(axis=0)
    else:
        print("Counterfactual Explanation Found")
        # Of all successful candidates, return the one with the lowest cost.
        costs = torch.tensor(costs)
        min_idx = int(torch.argmin(costs).numpy())
        x_new_enc = ces[min_idx]

        # print("x_prime ", x_new_enc.cpu().detach().numpy().squeeze(axis=0))

        return x_new_enc.cpu().detach().numpy().squeeze(axis=0)

0 commit comments

Comments
 (0)