Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions data/catalog/_data_main/process_data/process_utils_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ def get_one_hot_encoding(in_arr):
output: m (ndarray): one-hot encoded matrix
d (dict): also returns a dictionary original_val -> column in encoded matrix
"""
valid_types = (int, np.int64, float, np.float64)

for k in in_arr:
if (
str(type(k)) != "<type 'numpy.float64'>"
and type(k) != int
and type(k) != np.int64
):
# if (
# str(type(k)) != "<type 'numpy.float64'>"
# and type(k) != int
# and type(k) != np.int64
# ):
if not isinstance(k, valid_types):
print(str(type(k)))
print("************* ERROR: Input arr does not have integer types")
return None
Expand Down
4 changes: 4 additions & 0 deletions experiments/experimental_setup.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ recourse_methods:
hyperparams:
roar:
hyperparams:
larr:
hyperparams:
alpha: 0.5
beta: 1.0
rbr:
hyperparams:
train_data: None
Expand Down
37 changes: 35 additions & 2 deletions experiments/results.csv
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,7 @@ greedy,twomoon,linear,0.0,3.767193501591493e-09,7.308982136233745e-18,2.21002338
greedy,twomoon,linear,0.0,3.5754390226294694e-08,7.772752307703987e-16,2.6186437573905152e-08,0.0,0.0,,,
greedy,twomoon,linear,0.0,2.6068826278202728e-08,5.69812657066172e-16,2.3758703449061613e-08,0.0,0.0,,,
greedy,twomoon,linear,0.0,1.6624034954171307e-08,1.4464775997720767e-16,1.0110418280362412e-08,0.0,0.0,,,
greedy,twomoon,linear,0.0,3.810764737099959e-08,8.018428030433007e-16,2.520794462057552e-08,0.0,0.0,,,
greedy,twomoon,linear,0.0,3.810764737099959e-08,8.0184280304330075e-16,2.520794462057552e-08,0.0,0.0,,,
greedy,twomoon,linear,0.0,2.0360846092470908e-08,3.2369424982559667e-16,1.7809716035266376e-08,0.0,0.0,,,
greedy,twomoon,linear,0.0,3.3704222579533656e-08,7.973010507410352e-16,2.755990613501069e-08,0.0,0.0,,,
greedy,twomoon,linear,0.0,4.011398219150309e-08,8.634995116366633e-16,2.5485330068519832e-08,0.0,0.0,,,
Expand Down Expand Up @@ -1414,7 +1414,7 @@ claproar,twomoon,linear,0.0,2.4025606226718565e-08,3.495247200454482e-16,1.75314
claproar,twomoon,linear,0.0,2.2577137270829453e-08,3.667492287085482e-16,1.876806277056886e-08,0.0,0.0,,,
claproar,twomoon,linear,0.0,1.6213392628472434e-08,1.431988503656087e-16,1.0531753025233572e-08,0.0,0.0,,,
claproar,twomoon,linear,0.0,3.312241828035134e-08,7.202236947752843e-16,2.5826099814274528e-08,0.0,0.0,,,
claproar,twomoon,linear,0.0,3.2146713291325575e-08,6.212848124877535e-16,2.330451231991049e-08,0.0,0.0,,,
claproar,twomoon,linear,0.0,3.2146713291325575e-08,6.212848124877539e-16,2.3304512319910493e-08,0.0,0.0,,,
cfvae,adult,linear,14.0,11.148035683882576,10.427443757409122,1.0,2.0,14.0,1.0,1.0,0.0021751707477960735
cfvae,adult,linear,11.0,8.964258306384831,8.36082031186287,1.0,2.0,11.0,,,
cfvae,adult,linear,11.0,8.378064919937332,7.676155956565286,1.0,1.0,11.0,,,
Expand Down Expand Up @@ -1774,6 +1774,39 @@ roar,boston_housing,linear,12.0,10.45135350677032,10.479898040589251,1.069328410
roar,boston_housing,linear,12.0,9.480605434904518,8.787099818607752,0.9835587124161372,0.0,12.0,,,
roar,boston_housing,linear,12.0,9.171005273011977,8.215278014423253,0.951259922770738,0.0,12.0,,,
roar,boston_housing,linear,11.0,11.091221900347287,12.334272797155895,1.1652295128925485,0.0,11.0,,,
larr,compass,linear,5.0,6.28278229715521,9.211095016205215,2.28278229715521,4.0,4.0,1.0,1.0,0.0838862200000001
larr,compass,linear,5.0,6.28278229715521,9.211095016205215,2.28278229715521,4.0,4.0,,,
larr,compass,linear,4.0,5.28278229715521,8.211095016205217,2.28278229715521,3.0,3.0,,,
larr,compass,linear,4.0,5.28278229715521,8.211095016205217,2.28278229715521,3.0,3.0,,,
larr,compass,linear,3.0,4.28278229715521,7.211095016205217,2.28278229715521,2.0,2.0,,,
larr,credit,linear,9.0,6.001184878655714,6.000000574598035,1.0,2.0,9.0,1.0,1.0,1.2996949600000005
larr,credit,linear,9.0,4.655533630429869,4.148313540135076,1.0,2.0,9.0,,,
larr,credit,linear,6.0,6.0,6.0,1.0,2.0,6.0,,,
larr,credit,linear,7.0,4.244995992685195,4.027799497132385,1.0,2.0,7.0,,,
larr,credit,linear,7.0,5.539559014267185,5.251564915609792,1.0,1.0,7.0,,,
larr,mortgage,linear,,,,,,,,0.0,0.0972256399999999
larr,twomoon,linear,1.0,5.073372651230509,25.73911005825368,5.073372651230509,0.0,0.0,1.0,1.0,0.1430016799999997
larr,twomoon,linear,1.0,5.340802822613361,28.52417479003484,5.340802822613361,0.0,0.0,,,
larr,twomoon,linear,1.0,4.7617186457016425,22.673964460822685,4.7617186457016425,0.0,0.0,,,
larr,twomoon,linear,1.0,5.373180659396579,28.871070398513453,5.373180659396579,0.0,0.0,,,
larr,twomoon,linear,1.0,5.341437020192945,28.53094944068769,5.341437020192945,0.0,0.0,,,
larr,breast_cancer,linear,16.0,4.187701215673475,1.4833264815255618,0.6244470524510248,0.0,16.0,1.0,1.0,0.1024079199999995
larr,breast_cancer,linear,16.0,3.658850183852719,0.9913839687825852,0.424483163311366,0.0,16.0,,,
larr,breast_cancer,linear,15.0,3.2486284417348057,0.9069764153088704,0.4321567211338811,0.0,15.0,,,
larr,breast_cancer,linear,15.0,3.3811028406893766,1.0604258308017511,0.5690298507462687,0.0,15.0,,,
larr,breast_cancer,linear,18.0,9.765581080932463,6.492191714227761,0.9852233676975946,0.0,18.0,,,
larr,boston_housing,linear,6.0,2.421538282818797,1.6839497609854843,1.0,0.0,6.0,0.96,1.0,0.0626189199999998
larr,boston_housing,linear,6.0,2.900309991515271,1.9805343016374413,1.0,0.0,6.0,,,
larr,boston_housing,linear,5.0,3.515425322251533,2.913253315859378,1.0,0.0,5.0,,,
larr,boston_housing,linear,5.0,3.2491099548815363,2.661170630523269,1.0,0.0,5.0,,,
larr,boston_housing,linear,5.0,1.9458024889093768,0.8761855842637001,0.575589193332056,0.0,5.0,,,
larr,adult,linear,5.0,3.7114155251141554,3.316157711473906,1.0,1.0,5.0,0.76,1.0,4.95049264
larr,adult,linear,12.0,11.248401826484018,10.970379266487354,1.0,1.0,12.0,,,
larr,adult,linear,9.0,7.621907854199746,6.930067436815111,1.0,1.0,9.0,,,
larr,adult,linear,8.0,7.094977168949772,6.5998882425303895,1.0,1.0,8.0,,,
larr,adult,linear,8.0,6.738812785388127,6.326666249661184,1.0,1.0,1.0,,,
larr,german,linear,1.0,0.26702982282381427,0.07130492627731765,0.26702982282381427,0.0,0.0,0.8,0.4,0.04217012000000011
larr,german,linear,1.0,0.8058765269065699,0.6494369766189955,0.8058765269065699,0.0,0.0,,,
rbr,twomoon,mlp,2.0,0.7286989127039623,0.2674033401335881,0.3951900744198769,0.0,1.0,1.0,1.0,4.879103590000001
rbr,twomoon,mlp,2.0,0.6894799992967063,0.2397964433977159,0.3771830935407586,0.0,1.0,,,
rbr,twomoon,mlp,2.0,0.1409640647837522,0.0189301471125527,0.13754436657573,0.0,1.0,,,
Expand Down
6 changes: 5 additions & 1 deletion experiments/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def initialize_recourse_method(
return Probe(mlmodel, hyperparams)
elif method == "roar":
return Roar(mlmodel, hyperparams)
elif method == "larr":
return Larr(mlmodel, hyperparams)
elif method == "rbr":
hyperparams["train_data"] = data.df_train.drop(columns=["y"], axis=1)
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -207,7 +209,7 @@ def create_parser():
Default: ["dice", "ar", "causal_recourse", "cchvae", "cem", "cem_vae", "claproar", "clue", "cruds", "face_knn", "face_epsilon", "feature_tweak",
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "cfrl", "probe", "roar", "rbr"].
Choices: ["dice", "ar", "causal_recourse", "cchvae", "cem", "cem_vae", "claproar", "clue", "cruds", "face_knn", "face_epsilon", "feature_tweak",
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "cfrl", "probe", "roar", "rbr"].
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "cfrl", "probe", "roar", "rbr", "larr"].
-n, --number_of_samples: Specifies the number of instances per dataset.
Default: 20.
-s, --train_split: Specifies the split of the available data used for training.
Expand Down Expand Up @@ -308,6 +310,7 @@ def create_parser():
"cfrl",
"probe",
"roar",
"larr",
"rbr",
],
help="Recourse methods for experiment",
Expand Down Expand Up @@ -392,6 +395,7 @@ def create_parser():
"cfrl",
"probe",
"roar",
"larr",
"rbr",
]
sklearn_methods = ["feature_tweak", "focus", "mace"]
Expand Down
1 change: 1 addition & 0 deletions methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
Gravitational,
Greedy,
GrowingSpheres,
Larr,
Probe,
Revise,
Roar,
Expand Down
1 change: 1 addition & 0 deletions methods/catalog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .gravitational import Gravitational
from .greedy import Greedy
from .growing_spheres import GrowingSpheres
from .larr import Larr
from .mace import MACE
from .probe import Probe
from .rbr import RBR
Expand Down
8 changes: 6 additions & 2 deletions methods/catalog/focus/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,13 @@ def _filter_hinge_loss(n_class, mask_vector, features, sigma, temperature, model
filtered_input = tf.boolean_mask(features, mask_vector)

# if sigma or temperature are not scalars
if type(sigma) != float or type(sigma) != int:
if not isinstance(
sigma, (float, int)
): # type(sigma) != float or type(sigma) != int:
sigma = tf.boolean_mask(sigma, mask_vector)
if type(temperature) != float or type(temperature) != int:
if not isinstance(
temperature, (float, int)
): # type(temperature) != float or type(temperature) != int:
temperature = tf.boolean_mask(temperature, mask_vector)

# compute loss
Expand Down
3 changes: 3 additions & 0 deletions methods/catalog/larr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# flake8: noqa

from .model import Larr
Loading