Skip to content

Commit 9d136e1

Browse files
authored
Merge branch 'main' into feat--CFRL-Support
2 parents 6fbc427 + 4b565db commit 9d136e1

File tree

16 files changed

+4053
-202
lines changed

16 files changed

+4053
-202
lines changed

experiments/experimental_setup.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,7 @@ recourse_methods:
159159
hyperparams:
160160
roar:
161161
hyperparams:
162+
rbr:
163+
hyperparams:
164+
train_data: None
165+
device: "cpu"

experiments/results.csv

Lines changed: 227 additions & 195 deletions
Large diffs are not rendered by default.

experiments/run_experiment.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,11 @@ def initialize_recourse_method(
172172
return Probe(mlmodel, hyperparams)
173173
elif method == "roar":
174174
return Roar(mlmodel, hyperparams)
175+
elif method == "rbr":
176+
hyperparams["train_data"] = data.df_train.drop(columns=["y"], axis=1)
177+
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
178+
hyperparams["device"] = dev
179+
return RBR(mlmodel, hyperparams)
175180
else:
176181
raise ValueError("Recourse method not known")
177182

@@ -200,9 +205,9 @@ def create_parser():
200205
Choices: ["mlp", "linear", "forest"].
201206
-r, --recourse_method: Specifies recourse methods for the experiment.
202207
Default: ["dice", "ar", "causal_recourse", "cchvae", "cem", "cem_vae", "claproar", "clue", "cruds", "face_knn", "face_epsilon", "feature_tweak",
203-
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "cfrl", "probe", "roar"].
208+
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "cfrl", "probe", "roar", "rbr"].
204209
Choices: ["dice", "ar", "causal_recourse", "cchvae", "cem", "cem_vae", "claproar", "clue", "cruds", "face_knn", "face_epsilon", "feature_tweak",
205-
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "cfrl", "probe", "roar"].
210+
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "cfrl", "probe", "roar", "rbr"].
206211
-n, --number_of_samples: Specifies the number of instances per dataset.
207212
Default: 20.
208213
-s, --train_split: Specifies the split of the available data used for training.
@@ -277,6 +282,7 @@ def create_parser():
277282
"cfrl",
278283
"probe",
279284
"roar",
285+
"rbr",
280286
],
281287
choices=[
282288
"dice",
@@ -302,6 +308,7 @@ def create_parser():
302308
"cfrl",
303309
"probe",
304310
"roar",
311+
"rbr",
305312
],
306313
help="Recourse methods for experiment",
307314
)
@@ -385,6 +392,7 @@ def create_parser():
385392
"cfrl",
386393
"probe",
387394
"roar",
395+
"rbr",
388396
]
389397
sklearn_methods = ["feature_tweak", "focus", "mace"]
390398

methods/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
CRUD,
1010
FOCUS,
1111
MACE,
12+
RBR,
1213
ActionableRecourse,
1314
CausalRecourse,
1415
ClaPROAR,

methods/catalog/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .growing_spheres import GrowingSpheres
1818
from .mace import MACE
1919
from .probe import Probe
20+
from .rbr import RBR
2021
from .revise import Revise
2122
from .roar import Roar
2223
from .wachter import Wachter

methods/catalog/probe/library/probe.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,15 @@ def perturb_sample(x, n_samples, sigma2):
7171
return X + eps
7272

7373

74-
def reparametrization_trick(mu, sigma2, n_samples):
74+
def reparametrization_trick(mu, sigma2, device, n_samples):
7575
# var = torch.eye(mu.shape[1]) * sigma2
76-
std = torch.sqrt(sigma2)
76+
std = torch.sqrt(sigma2).to(device)
7777
epsilon = MultivariateNormal(
7878
loc=torch.zeros(mu.shape[1]), covariance_matrix=torch.eye(mu.shape[1])
7979
)
8080
epsilon = epsilon.sample((n_samples,)) # standard Gaussian random noise
81-
ones = torch.ones_like(epsilon)
81+
epsilon = epsilon.to(device)
82+
ones = torch.ones_like(epsilon).to(device)
8283
random_samples = mu.reshape(-1) * ones + std * epsilon
8384

8485
return random_samples
@@ -176,7 +177,9 @@ def probe_recourse(
176177
costs = []
177178
ces = []
178179

179-
random_samples = reparametrization_trick(x_new, noise_variance, n_samples=1000)
180+
random_samples = reparametrization_trick(
181+
x_new, noise_variance, device, n_samples=1000
182+
)
180183
invalidation_rate = compute_invalidation_rate(torch_model, random_samples)
181184

182185
while (f_x_new <= DECISION_THRESHOLD) or (
@@ -226,7 +229,7 @@ def probe_recourse(
226229
optimizer.step()
227230

228231
random_samples = reparametrization_trick(
229-
x_new, noise_variance, n_samples=10000
232+
x_new, noise_variance, device, n_samples=10000
230233
)
231234
invalidation_rate = compute_invalidation_rate(torch_model, random_samples)
232235

methods/catalog/rbr/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# flake8: noqa
2+
3+
from .model import RBR
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# flake8: noqa
2+
3+
from .rbr_loss import robust_bayesian_recourse

0 commit comments

Comments
 (0)