Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions experiments/experimental_setup.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,7 @@ recourse_methods:
hyperparams:
loss_type: "BCE"
binary_cat_features: True
cfvae:
hyperparams:
encoded_size: 10
train: True
7 changes: 6 additions & 1 deletion experiments/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ def initialize_recourse_method(
return Revise(mlmodel, data, hyperparams)
elif "wachter" in method:
return Wachter(mlmodel, hyperparams)
elif "cfvae" in method:
return CFVAE(mlmodel, hyperparams)
else:
raise ValueError("Recourse method not known")

Expand Down Expand Up @@ -193,7 +195,7 @@ def create_parser():
-r, --recourse_method: Specifies recourse methods for the experiment.
Default: ["dice", "cchvae", "cem", "cem_vae", "clue", "cruds", "face_knn", "face_epsilon", "gs", "mace", "revise", "wachter"].
Choices: ["dice", "ar", "causal_recourse", "cchvae", "cem", "cem_vae", "claproar", "clue", "cruds", "face_knn", "face_epsilon", "feature_tweak",
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter"].
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae"].
-n, --number_of_samples: Specifies the number of instances per dataset.
Default: 20.
-s, --train_split: Specifies the split of the available data used for training.
Expand Down Expand Up @@ -260,6 +262,7 @@ def create_parser():
"gs",
"revise",
"wachter",
"cfvae",
],
choices=[
"dice",
Expand All @@ -281,6 +284,7 @@ def create_parser():
"mace",
"revise",
"wachter",
"cfvae",
],
help="Recourse methods for experiment",
)
Expand Down Expand Up @@ -360,6 +364,7 @@ def create_parser():
"gravitational",
"wachter",
"revise",
"cfvae",
]
sklearn_methods = ["feature_tweak", "focus", "mace"]

Expand Down
1 change: 1 addition & 0 deletions methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@
Revise,
Roar,
Wachter,
CFVAE,
)
2 changes: 1 addition & 1 deletion methods/api/recourse_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,4 @@ def get_counterfactuals(self, factuals: pd.DataFrame):
pd.DataFrame
Encoded and normalised counterfactual examples.
"""
pass
return pd.DataFrame()
1 change: 1 addition & 0 deletions methods/catalog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
from .revise import Revise
from .roar import Roar
from .wachter import Wachter
from .cfvae import CFVAE
3 changes: 3 additions & 0 deletions methods/catalog/cfvae/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# flake8: noqa

from .model import CFVAE
Loading