Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
0eff424
init commit. Added code from Originial repo to here, slightly modifie…
HashirA123 Oct 5, 2025
bce578f
clean up, WIP
HashirA123 Oct 5, 2025
9e1f0b8
added RBR to init
HashirA123 Oct 8, 2025
8508444
Seperated work done in This branch
HashirA123 Oct 21, 2025
76d6bfd
Added rough outline for the reproduce
HashirA123 Oct 24, 2025
e580b7a
Made the reproduce a test function
HashirA123 Oct 24, 2025
3318741
init commit. Added code from Originial repo to here, slightly modifie…
HashirA123 Oct 5, 2025
d2c9067
clean up, WIP
HashirA123 Oct 5, 2025
024b3d9
added RBR to init
HashirA123 Oct 8, 2025
e4674b1
Seperated work done in This branch
HashirA123 Oct 21, 2025
42b03af
Added rough outline for the reproduce
HashirA123 Oct 24, 2025
e19a431
Made the reproduce a test function
HashirA123 Oct 24, 2025
03deb15
Merge branch 'RBR-model' of https://github.com/HashirA123/recourse_be…
HashirA123 Oct 31, 2025
6c6123b
Getting RBR method to run
HashirA123 Nov 2, 2025
bf46302
Got Reproduce working
HashirA123 Nov 5, 2025
5e5e458
Reproduce file is functioning
HashirA123 Nov 7, 2025
680da7d
Merge remote-tracking branch 'origin/main' into RBR-model
HashirA123 Nov 9, 2025
617f5fa
ran run_experiment for RBR
HashirA123 Nov 9, 2025
e816c9e
resolved merge conflicts
HashirA123 Nov 12, 2025
3ce0d6b
fixed issue regarding cuda-cpu mismatch on Probe
HashirA123 Nov 13, 2025
deb2fa4
fixed cuda-cpu mismatch error for RBR
HashirA123 Nov 13, 2025
993d8b2
ran pre-commit hooks
HashirA123 Nov 13, 2025
9ac893d
reverted breaking changes to the requirements-dev.txt
HashirA123 Nov 13, 2025
8a60036
fixed validity calculations
HashirA123 Nov 20, 2025
21ea4ca
updated var names
HashirA123 Nov 20, 2025
e4962ff
Reran run_experiments.py
HashirA123 Nov 21, 2025
5b6b9ed
ran precommit hooks
HashirA123 Nov 21, 2025
bccf19e
resolved merge conflicts
HashirA123 Nov 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions experiments/experimental_setup.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,7 @@ recourse_methods:
hyperparams:
roar:
hyperparams:
rbr:
hyperparams:
train_data: None
device: "cpu"
422 changes: 227 additions & 195 deletions experiments/results.csv

Large diffs are not rendered by default.

9 changes: 8 additions & 1 deletion experiments/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ def initialize_recourse_method(
return Probe(mlmodel, hyperparams)
elif method == "roar":
return Roar(mlmodel, hyperparams)
elif method == "rbr":
hyperparams["train_data"] = data.df_train.drop(columns=["y"], axis=1)
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hyperparams["device"] = dev
return RBR(mlmodel, hyperparams)
else:
raise ValueError("Recourse method not known")

Expand Down Expand Up @@ -199,7 +204,7 @@ def create_parser():
-r, --recourse_method: Specifies recourse methods for the experiment.
Default: ["dice", "cchvae", "cem", "cem_vae", "clue", "cruds", "face_knn", "face_epsilon", "gs", "mace", "revise", "wachter"].
Choices: ["dice", "ar", "causal_recourse", "cchvae", "cem", "cem_vae", "claproar", "clue", "cruds", "face_knn", "face_epsilon", "feature_tweak",
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "roar", "probe"].
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "roar", "probe", "rbr"].
-n, --number_of_samples: Specifies the number of instances per dataset.
Default: 20.
-s, --train_split: Specifies the split of the available data used for training.
Expand Down Expand Up @@ -292,6 +297,7 @@ def create_parser():
"cfvae",
"probe",
"roar",
"rbr",
],
help="Recourse methods for experiment",
)
Expand Down Expand Up @@ -374,6 +380,7 @@ def create_parser():
"cfvae",
"probe",
"roar",
"rbr",
]
sklearn_methods = ["feature_tweak", "focus", "mace"]

Expand Down
1 change: 1 addition & 0 deletions methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
CRUD,
FOCUS,
MACE,
RBR,
ActionableRecourse,
CausalRecourse,
ClaPROAR,
Expand Down
1 change: 1 addition & 0 deletions methods/catalog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .growing_spheres import GrowingSpheres
from .mace import MACE
from .probe import Probe
from .rbr import RBR
from .revise import Revise
from .roar import Roar
from .wachter import Wachter
13 changes: 8 additions & 5 deletions methods/catalog/probe/library/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,15 @@ def perturb_sample(x, n_samples, sigma2):
return X + eps


def reparametrization_trick(mu, sigma2, n_samples):
def reparametrization_trick(mu, sigma2, device, n_samples):
# var = torch.eye(mu.shape[1]) * sigma2
std = torch.sqrt(sigma2)
std = torch.sqrt(sigma2).to(device)
epsilon = MultivariateNormal(
loc=torch.zeros(mu.shape[1]), covariance_matrix=torch.eye(mu.shape[1])
)
epsilon = epsilon.sample((n_samples,)) # standard Gaussian random noise
ones = torch.ones_like(epsilon)
epsilon = epsilon.to(device)
ones = torch.ones_like(epsilon).to(device)
random_samples = mu.reshape(-1) * ones + std * epsilon

return random_samples
Expand Down Expand Up @@ -176,7 +177,9 @@ def probe_recourse(
costs = []
ces = []

random_samples = reparametrization_trick(x_new, noise_variance, n_samples=1000)
random_samples = reparametrization_trick(
x_new, noise_variance, device, n_samples=1000
)
invalidation_rate = compute_invalidation_rate(torch_model, random_samples)

while (f_x_new <= DECISION_THRESHOLD) or (
Expand Down Expand Up @@ -226,7 +229,7 @@ def probe_recourse(
optimizer.step()

random_samples = reparametrization_trick(
x_new, noise_variance, n_samples=10000
x_new, noise_variance, device, n_samples=10000
)
invalidation_rate = compute_invalidation_rate(torch_model, random_samples)

Expand Down
3 changes: 3 additions & 0 deletions methods/catalog/rbr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# flake8: noqa

from .model import RBR
3 changes: 3 additions & 0 deletions methods/catalog/rbr/library/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# flake8: noqa

from .rbr_loss import robust_bayesian_recourse
Loading