Skip to content

Commit 36ecf3f

Browse files
authored
Merge branch 'main' into feat-CFVAE-Support
2 parents f4af196 + b0032a6 commit 36ecf3f

File tree

13 files changed

+718
-14
lines changed

13 files changed

+718
-14
lines changed

experiments/experimental_setup.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,5 +142,7 @@ recourse_methods:
142142
hyperparams:
143143
encoded_size: 10
144144
train: True
145+
probe:
146+
hyperparams:
145147
roar:
146148
hyperparams:

experiments/results.csv

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1575,6 +1575,34 @@ cfvae,boston_housing,linear,11.0,1.1642179021592083,0.2220489567468797,0.3127339
15751575
cfvae,boston_housing,linear,12.0,1.0237826877700886,0.15038662312775705,0.24312054669597183,0.0,10.0,,,
15761576
cfvae,boston_housing,linear,13.0,2.354798590209251,1.2301028537404475,1.0,0.0,11.0,,,
15771577
cfvae,boston_housing,linear,12.0,2.148736169232081,1.2422881142210536,1.0,0.0,1.0,,,
1578+
probe,adult,linear,51.0,2.5574682458217044,0.20724095760647737,0.10291039943695068,2.0,51.0,0.0,1.0,11.27795516
1579+
probe,adult,linear,48.0,1.6209024338863478,0.0782517097074313,0.06244194507598877,2.0,51.0,,,
1580+
probe,adult,linear,51.0,6.151970284269187,1.3522899002686382,0.268756240606308,2.0,48.0,,,
1581+
probe,adult,linear,47.0,4.1338594518437715,0.6257325925713797,0.1863815188407898,2.0,44.0,,,
1582+
probe,adult,linear,49.0,6.529014715911816,1.5907548243066623,0.2905552387237549,2.0,27.0,,,
1583+
probe,compass,linear,7.0,1.0493243297463968,0.15758896097912403,0.15593880414962769,5.0,3.0,0.0,1.0,3.9818648800000007
1584+
probe,compass,linear,7.0,1.1968591064214706,0.20529106263232566,0.178464874625206,5.0,3.0,,,
1585+
probe,compass,linear,7.0,0.5305286708631014,0.040413185530297796,0.0809311717748642,5.0,6.0,,,
1586+
probe,compass,linear,7.0,2.068292945623398,0.6289491666450724,0.3217114806175232,5.0,2.0,,,
1587+
probe,compass,linear,6.0,0.11717507370600574,0.002380937481501453,0.024646831676363945,5.0,7.0,,,
1588+
probe,credit,linear,20.0,0.7460355385040021,0.03599241548631901,0.05706929787993431,5.0,18.0,0.2666666666666667,1.0,4.403598066666665
1589+
probe,credit,linear,18.0,0.2715397661723585,0.005983706032008009,0.028109369799494743,5.0,20.0,,,
1590+
probe,credit,linear,18.0,0.2347176402264215,0.004717934816094677,0.025857295840978622,5.0,20.0,,,
1591+
probe,german,linear,4.0,0.34145124321402176,0.02916159435279554,0.08767480622319615,2.0,2.0,0.19999999999999996,1.0,3.3755328400000026
1592+
probe,german,linear,4.0,0.3343678011115066,0.027964667531387127,0.08589213144253283,2.0,2.0,,,
1593+
probe,german,linear,4.0,0.267133860467632,0.017855097803688207,0.06914474771303289,2.0,2.0,,,
1594+
probe,german,linear,4.0,0.3044240432907651,0.023182781515892512,0.07841214362312765,2.0,2.0,,,
1595+
probe,german,linear,4.0,0.2473267066867472,0.015308234751386422,0.06424275391242085,2.0,3.0,,,
1596+
probe,mortgage,linear,2.0,2.596893806724318,3.3737499230649304,1.3286230641796841,0.0,0.0,0.0,1.0,7.078433219999999
1597+
probe,mortgage,linear,2.0,2.7662955305674433,3.8262065210593192,1.3854972163486525,0.0,0.0,,,
1598+
probe,mortgage,linear,2.0,2.9121862424407814,4.241009224725797,1.473339416403974,0.0,0.0,,,
1599+
probe,mortgage,linear,2.0,2.904149937359671,4.217532586486467,1.4677139768545209,0.0,0.0,,,
1600+
probe,mortgage,linear,2.0,3.148416728152398,4.960202167161096,1.6185830196025777,0.0,0.0,,,
1601+
probe,boston_housing,linear,12.0,1.3520136007298742,0.16837960058355547,0.13424224549151464,0.0,9.0,0.0,1.0,3.862365679999999
1602+
probe,boston_housing,linear,11.0,1.007249276669714,0.09299927052223386,0.09963742976493017,0.0,9.0,,,
1603+
probe,boston_housing,linear,13.0,1.3172484222248864,0.15939193355054326,0.12997258850868726,0.0,9.0,,,
1604+
probe,boston_housing,linear,12.0,1.2995529090837012,0.15490400031856688,0.12785010134920172,0.0,9.0,,,
1605+
probe,boston_housing,linear,12.0,1.0059221718260045,0.09273649552678494,0.09945072011744727,0.0,9.0,,,
15781606
roar,adult,linear,5.0,9.618374680646278,18.55040773192179,2.03277587890625,1.0,5.0,0.05999999999999994,1.0,1.0735000799999999
15791607
roar,adult,linear,5.0,8.90917690170632,15.907446315902012,1.8774079084396362,1.0,5.0,,,
15801608
roar,adult,linear,5.0,14.060748848458747,39.60930977247996,2.9490909576416016,1.0,5.0,,,

experiments/run_experiment.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,13 @@ def initialize_recourse_method(
162162
sum(mlmodel.get_mutable_mask())
163163
] + hyperparams["vae_params"]["layers"]
164164
return Revise(mlmodel, data, hyperparams)
165-
elif "wachter" in method:
165+
elif method == "wachter":
166166
return Wachter(mlmodel, hyperparams)
167-
elif "cfvae" in method:
167+
elif method == "cfvae":
168168
return CFVAE(mlmodel, hyperparams)
169-
elif "roar" in method:
169+
elif method == "probe":
170+
return Probe(mlmodel, hyperparams)
171+
elif method == "roar":
170172
return Roar(mlmodel, hyperparams)
171173
else:
172174
raise ValueError("Recourse method not known")
@@ -197,7 +199,7 @@ def create_parser():
197199
-r, --recourse_method: Specifies recourse methods for the experiment.
198200
Default: ["dice", "cchvae", "cem", "cem_vae", "clue", "cruds", "face_knn", "face_epsilon", "gs", "mace", "revise", "wachter"].
199201
Choices: ["dice", "ar", "causal_recourse", "cchvae", "cem", "cem_vae", "claproar", "clue", "cruds", "face_knn", "face_epsilon", "feature_tweak",
200-
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "roar"].
202+
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "roar", "probe"].
201203
-n, --number_of_samples: Specifies the number of instances per dataset.
202204
Default: 20.
203205
-s, --train_split: Specifies the split of the available data used for training.
@@ -288,6 +290,7 @@ def create_parser():
288290
"revise",
289291
"wachter",
290292
"cfvae",
293+
"probe",
291294
"roar",
292295
],
293296
help="Recourse methods for experiment",
@@ -369,6 +372,7 @@ def create_parser():
369372
"wachter",
370373
"revise",
371374
"cfvae",
375+
"probe",
372376
"roar",
373377
]
374378
sklearn_methods = ["feature_tweak", "focus", "mace"]

methods/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Gravitational,
1919
Greedy,
2020
GrowingSpheres,
21+
Probe,
2122
Revise,
2223
Roar,
2324
Wachter,

methods/catalog/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .greedy import Greedy
1616
from .growing_spheres import GrowingSpheres
1717
from .mace import MACE
18+
from .probe import Probe
1819
from .revise import Revise
1920
from .roar import Roar
2021
from .wachter import Wachter

methods/catalog/probe/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# flake8: noqa
2+
3+
from .model import Probe
4+
from .reproduce import test_probe
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# flake8: noqa
2+
3+
from .probe import probe_recourse
Lines changed: 286 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,286 @@
1+
import datetime
2+
from typing import List, Optional
3+
4+
import numpy as np
5+
import torch
6+
import torch.distributions.normal as normal_distribution
7+
import torch.optim as optim
8+
from torch import nn
9+
from torch.autograd import Variable
10+
from torch.distributions.multivariate_normal import MultivariateNormal
11+
12+
from methods.processing import reconstruct_encoding_constraints
13+
14+
"""
15+
This file contains the implementation of the Probe method, along with required helper functions
16+
"""
17+
18+
DECISION_THRESHOLD = 0.5
19+
20+
# Mean and variance for rectified normal distribution:
21+
# see in here : http://journal-sfds.fr/article/view/669
22+
23+
24+
def compute_jacobian(inputs, output):
    """Return the gradient of `output` with respect to `inputs` via autograd.

    :param inputs: Batch X Size (e.g. Depth X Width X Height)
    :param output: Batch X Classes
    :return: jacobian: Batch X Classes X Size
    """
    # The input must be part of the autograd graph, otherwise no
    # derivative can be taken with respect to it.
    assert inputs.requires_grad
    return gradient(output, inputs)
33+
34+
35+
def gradient(y, x, grad_outputs=None):
    """Compute dy/dx @ grad_outputs.

    :param y: output tensor (any shape)
    :param x: input tensor with requires_grad=True
    :param grad_outputs: optional weighting of dy; defaults to ones
    :return: gradient of y w.r.t. x, same shape as x
    """
    if grad_outputs is None:
        # ones_like matches y's dtype, shape and device; the previous
        # torch.tensor(1, ...) produced an int64 scalar, which can trip
        # autograd's dtype/shape check for float, non-scalar outputs.
        grad_outputs = torch.ones_like(y)
    # create_graph=True keeps the graph so this gradient can itself be
    # differentiated (needed when it feeds into a loss term).
    grad = torch.autograd.grad(y, [x], grad_outputs=grad_outputs, create_graph=True)[0]
    return grad
41+
42+
43+
def compute_invalidation_rate_closed(torch_model, x, sigma2):
    """Closed-form (first-order) invalidation rate of x under Gaussian input noise.

    Linearizes the model's logit around x and evaluates the probability that
    a perturbation eps ~ N(0, sigma2 * I) flips the prediction.

    :param torch_model: model returning class probabilities; prob[0][0]/prob[0][1]
        are read, so a single (1, d) input is assumed — TODO confirm with callers
    :param x: input tensor with requires_grad=True
    :param sigma2: scalar noise variance (tensor)
    :return: invalidation rate as a differentiable 0-dim tensor
    """
    # Log-odds of the positive class at x.
    prob = torch_model(x)
    logit_x = torch.log(prob[0][1] / prob[0][0])
    # NOTE: the original also built sigma2 * I here but never used it;
    # the closed form only needs the gradient norm below.
    jacobian_x = compute_jacobian(x, logit_x).reshape(-1)
    denom = torch.sqrt(sigma2) * torch.norm(jacobian_x, 2)
    arg = logit_x / denom

    # Evaluate standard Gaussian cdf at the normalized logit.
    normal = normal_distribution.Normal(loc=0.0, scale=1.0)
    normal_cdf = normal.cdf(arg)

    # Invalidation rate = probability the prediction does NOT survive the noise.
    ir = 1 - normal_cdf

    return ir
60+
61+
62+
def perturb_sample(x, n_samples, sigma2):
    """Return n_samples noisy copies of x: each row is x + eps, eps ~ N(0, sigma2 * I).

    :param x: (1, d) input tensor
    :param n_samples: number of perturbed copies to draw
    :param sigma2: scalar noise variance
    :return: (n_samples, d) tensor of perturbed inputs
    """
    dim = x.shape[1]
    # Tile the single instance into n_samples identical rows.
    tiled = x.repeat(n_samples, 1)
    # Draw isotropic Gaussian noise with covariance sigma2 * I.
    noise_dist = MultivariateNormal(
        loc=torch.zeros(dim), covariance_matrix=torch.eye(dim) * sigma2
    )
    noise = noise_dist.sample((n_samples,))
    return tiled + noise
72+
73+
74+
def reparametrization_trick(mu, sigma2, n_samples):
    """Sample n_samples points from N(mu, sigma2 * I) as mu + sqrt(sigma2) * eps.

    :param mu: (1, d) mean tensor (kept in the autograd graph)
    :param sigma2: scalar variance tensor
    :param n_samples: number of samples to draw
    :return: (n_samples, d) tensor of samples
    """
    dim = mu.shape[1]
    # Standard Gaussian base noise, rescaled below.
    base = MultivariateNormal(
        loc=torch.zeros(dim), covariance_matrix=torch.eye(dim)
    )
    eps = base.sample((n_samples,))
    scale = torch.sqrt(sigma2)
    # Broadcast mu across every noise row; gradients flow through mu only.
    samples = mu.reshape(-1) * torch.ones_like(eps) + scale * eps
    return samples
85+
86+
87+
def compute_invalidation_rate(torch_model, random_samples):
    """Monte-Carlo invalidation rate: fraction of samples NOT classified as class 1.

    :param torch_model: model returning an (N, 2) score/probability tensor
    :param random_samples: (N, d) batch of perturbed inputs
    :return: 0-dim tensor in [0, 1]
    """
    positive_scores = torch_model(random_samples)[:, 1]
    # 1 where the prediction survives the perturbation, 0 where it flips.
    survives = (positive_scores > 0.5).float()
    return 1 - survives.mean(dim=0)
92+
93+
94+
def probe_recourse(
    torch_model,
    x: np.ndarray,
    cat_feature_indices: List[int],
    binary_cat_features: bool = True,
    feature_costs: Optional[List[float]] = None,
    lr: float = 0.07,
    lambda_param: float = 5,
    y_target: Optional[List[float]] = None,
    n_iter: int = 500,
    t_max_min: float = 1.0,
    norm: int = 1,
    clamp: bool = False,
    loss_type: str = "MSE",
    invalidation_target: float = 0.45,
    inval_target_eps: float = 0.005,
    noise_variance: float = 0.01,
) -> np.ndarray:
    """
    Generates a counterfactual example for input instance x whose invalidation
    rate under Gaussian input noise is pushed below a target (PROBE-style recourse).

    Parameters
    ----------
    torch_model: black-box-model to discover
    x: factual to explain
    cat_feature_indices: list of positions of categorical features in x
        (currently unused — encoding reconstruction is commented out upstream)
    binary_cat_features: If true, the encoding of x is done by drop_if_binary
        (currently unused, see above)
    feature_costs: List with costs per feature (currently unused)
    lr: learning rate for gradient descent
    lambda_param: weight factor for feature_cost
    y_target: List of one-hot-encoded target class; defaults to [0.45, 0.55]
    n_iter: maximum number of iterations per lambda value
    t_max_min: maximum time of search (minutes)
    norm: L-norm to calculate cost
    clamp: If true, feature values will be clamped to (0, 1)
    loss_type: String for loss function (MSE or BCE)
    invalidation_target: target invalidation rate
    inval_target_eps: epsilon tolerance around the invalidation target
    noise_variance: variance of the normal distribution for sampling

    Returns
    -------
    Counterfactual example as np.ndarray
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    torch_model = torch_model.to(device)
    # Fixed seed so the Monte-Carlo invalidation estimates are reproducible.
    torch.manual_seed(0)
    noise_variance = torch.tensor(noise_variance)

    # Avoid a mutable default argument; None stands in for [0.45, 0.55].
    if y_target is None:
        y_target = [0.45, 0.55]

    x = torch.from_numpy(x).float().to(device)
    y_target = torch.tensor(y_target).float().to(device)
    lamb = torch.tensor(lambda_param).float().to(device)

    # x_new is the candidate counterfactual optimized by gradient descent
    # (leaf tensor; the deprecated Variable wrapper is not needed).
    x_new = x.clone().detach().requires_grad_(True)

    optimizer = optim.Adam([x_new], lr, amsgrad=True)
    # Explicit dim=1 matches the implicit choice for 2-D input and silences
    # the nn.Softmax deprecation warning.
    softmax = nn.Softmax(dim=1)

    if loss_type == "MSE":
        loss_fn = torch.nn.MSELoss()
        f_x_new = softmax(torch_model(x_new))[:, 1]
    else:
        loss_fn = torch.nn.BCELoss()
        f_x_new = torch_model(x_new)[:, 1]

    t0 = datetime.datetime.now()
    t_max = datetime.timedelta(minutes=t_max_min)

    costs = []
    ces = []

    # Initial empirical estimate of the invalidation rate around x_new.
    random_samples = reparametrization_trick(x_new, noise_variance, n_samples=1000)
    invalidation_rate = compute_invalidation_rate(torch_model, random_samples)

    # Keep searching until the candidate both flips the prediction and meets
    # the invalidation-rate target, or the time budget runs out.
    while (f_x_new <= DECISION_THRESHOLD) or (
        invalidation_rate > invalidation_target + inval_target_eps
    ):
        for it in range(n_iter):
            optimizer.zero_grad()

            f_x_new_binary = torch_model(x_new).squeeze(axis=0)

            # Proximity cost: distance of the candidate to the factual.
            cost = torch.dist(x_new, x, norm)

            # Differentiable closed-form invalidation rate used in the loss.
            invalidation_rate_c = compute_invalidation_rate_closed(
                torch_model, x_new, noise_variance
            )

            # Hinge loss: penalize only when above the target rate
            # (equivalent to the original masked assignment, but also safe
            # for 0-dim tensors).
            loss_invalidation = torch.clamp(
                invalidation_rate_c - invalidation_target, min=0
            )

            loss = (
                3 * loss_invalidation + loss_fn(f_x_new_binary, y_target) + lamb * cost
            )
            loss.backward()
            optimizer.step()

            # Re-estimate the empirical invalidation rate after the step.
            random_samples = reparametrization_trick(
                x_new, noise_variance, n_samples=10000
            )
            invalidation_rate = compute_invalidation_rate(torch_model, random_samples)

            # Clamp the potential CF to [0, 1]. The original clamped a
            # discarded clone (a no-op); clamping the leaf in-place must be
            # done outside the autograd graph.
            if clamp:
                with torch.no_grad():
                    x_new.clamp_(0, 1)

            f_x_new = torch_model(x_new)[:, 1]

            if (f_x_new > DECISION_THRESHOLD) and (
                invalidation_rate < invalidation_target + inval_target_eps
            ):
                print("--------------------------------------")
                print("invalidation rate:", invalidation_rate)
                print("cost:", cost)
                print("classifier output:", f_x_new_binary)

                costs.append(cost)
                ces.append(x_new)

                break

        # Relax the proximity weight before the next round of iterations.
        lamb -= 0.10

        if datetime.datetime.now() - t0 > t_max:
            print("Timeout")
            break

    if not ces:
        print(
            "No Counterfactual Explanation Found at that Target Rate - Try Different Target"
        )
        return x_new.cpu().detach().numpy().squeeze(axis=0)
    else:
        print("Counterfactual Explanation Found")
        # Return the cheapest counterfactual found across all rounds.
        costs = torch.tensor(costs)
        min_idx = int(torch.argmin(costs).numpy())
        x_new_enc = ces[min_idx]

        return x_new_enc.cpu().detach().numpy().squeeze(axis=0)

0 commit comments

Comments
 (0)