Skip to content

Commit 4b565db

Browse files
authored
Merge pull request charmlab#27 from HashirA123/RBR-model
feat: RBR Implementation
2 parents 36c72aa + bccf19e commit 4b565db

File tree

16 files changed

+4051
-201
lines changed

16 files changed

+4051
-201
lines changed

experiments/experimental_setup.yaml

Lines changed: 4 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -146,3 +146,7 @@ recourse_methods:
146146
hyperparams:
147147
roar:
148148
hyperparams:
149+
rbr:
150+
hyperparams:
151+
train_data: None
152+
device: "cpu"

experiments/results.csv

Lines changed: 227 additions & 195 deletions
Large diffs are not rendered by default.

experiments/run_experiment.py

Lines changed: 8 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -170,6 +170,11 @@ def initialize_recourse_method(
170170
return Probe(mlmodel, hyperparams)
171171
elif method == "roar":
172172
return Roar(mlmodel, hyperparams)
173+
elif method == "rbr":
174+
hyperparams["train_data"] = data.df_train.drop(columns=["y"], axis=1)
175+
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
176+
hyperparams["device"] = dev
177+
return RBR(mlmodel, hyperparams)
173178
else:
174179
raise ValueError("Recourse method not known")
175180

@@ -199,7 +204,7 @@ def create_parser():
199204
-r, --recourse_method: Specifies recourse methods for the experiment.
200205
Default: ["dice", "cchvae", "cem", "cem_vae", "clue", "cruds", "face_knn", "face_epsilon", "gs", "mace", "revise", "wachter"].
201206
Choices: ["dice", "ar", "causal_recourse", "cchvae", "cem", "cem_vae", "claproar", "clue", "cruds", "face_knn", "face_epsilon", "feature_tweak",
202-
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "roar", "probe"].
207+
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "roar", "probe", "rbr"].
203208
-n, --number_of_samples: Specifies the number of instances per dataset.
204209
Default: 20.
205210
-s, --train_split: Specifies the split of the available data used for training.
@@ -292,6 +297,7 @@ def create_parser():
292297
"cfvae",
293298
"probe",
294299
"roar",
300+
"rbr",
295301
],
296302
help="Recourse methods for experiment",
297303
)
@@ -374,6 +380,7 @@ def create_parser():
374380
"cfvae",
375381
"probe",
376382
"roar",
383+
"rbr",
377384
]
378385
sklearn_methods = ["feature_tweak", "focus", "mace"]
379386

methods/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -8,6 +8,7 @@
88
CRUD,
99
FOCUS,
1010
MACE,
11+
RBR,
1112
ActionableRecourse,
1213
CausalRecourse,
1314
ClaPROAR,

methods/catalog/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -16,6 +16,7 @@
1616
from .growing_spheres import GrowingSpheres
1717
from .mace import MACE
1818
from .probe import Probe
19+
from .rbr import RBR
1920
from .revise import Revise
2021
from .roar import Roar
2122
from .wachter import Wachter

methods/catalog/probe/library/probe.py

Lines changed: 8 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -71,14 +71,15 @@ def perturb_sample(x, n_samples, sigma2):
7171
return X + eps
7272

7373

74-
def reparametrization_trick(mu, sigma2, n_samples):
74+
def reparametrization_trick(mu, sigma2, device, n_samples):
7575
# var = torch.eye(mu.shape[1]) * sigma2
76-
std = torch.sqrt(sigma2)
76+
std = torch.sqrt(sigma2).to(device)
7777
epsilon = MultivariateNormal(
7878
loc=torch.zeros(mu.shape[1]), covariance_matrix=torch.eye(mu.shape[1])
7979
)
8080
epsilon = epsilon.sample((n_samples,)) # standard Gaussian random noise
81-
ones = torch.ones_like(epsilon)
81+
epsilon = epsilon.to(device)
82+
ones = torch.ones_like(epsilon).to(device)
8283
random_samples = mu.reshape(-1) * ones + std * epsilon
8384

8485
return random_samples
@@ -176,7 +177,9 @@ def probe_recourse(
176177
costs = []
177178
ces = []
178179

179-
random_samples = reparametrization_trick(x_new, noise_variance, n_samples=1000)
180+
random_samples = reparametrization_trick(
181+
x_new, noise_variance, device, n_samples=1000
182+
)
180183
invalidation_rate = compute_invalidation_rate(torch_model, random_samples)
181184

182185
while (f_x_new <= DECISION_THRESHOLD) or (
@@ -226,7 +229,7 @@ def probe_recourse(
226229
optimizer.step()
227230

228231
random_samples = reparametrization_trick(
229-
x_new, noise_variance, n_samples=10000
232+
x_new, noise_variance, device, n_samples=10000
230233
)
231234
invalidation_rate = compute_invalidation_rate(torch_model, random_samples)
232235

methods/catalog/rbr/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,3 @@
1+
# flake8: noqa
2+
3+
from .model import RBR
Lines changed: 3 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,3 @@
1+
# flake8: noqa
2+
3+
from .rbr_loss import robust_bayesian_recourse

0 commit comments

Comments (0)