Skip to content

Commit 3f0db2b

Browse files
authored
Merge pull request #34 from HashirA123/LARR-method
feat: Larr method
2 parents 7140449 + 287534e commit 3f0db2b

File tree

17 files changed

+3121
-13
lines changed

17 files changed

+3121
-13
lines changed

data/catalog/_data_main/process_data/process_utils_data.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@ def get_one_hot_encoding(in_arr):
2222
output: m (ndarray): one-hot encoded matrix
2323
d (dict): also returns a dictionary original_val -> column in encoded matrix
2424
"""
25+
valid_types = (int, np.int64, float, np.float64)
2526

2627
for k in in_arr:
27-
if (
28-
str(type(k)) != "<type 'numpy.float64'>"
29-
and type(k) != int
30-
and type(k) != np.int64
31-
):
28+
# if (
29+
# str(type(k)) != "<type 'numpy.float64'>"
30+
# and type(k) != int
31+
# and type(k) != np.int64
32+
# ):
33+
if not isinstance(k, valid_types):
3234
print(str(type(k)))
3335
print("************* ERROR: Input arr does not have integer types")
3436
return None

experiments/experimental_setup.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ recourse_methods:
159159
hyperparams:
160160
roar:
161161
hyperparams:
162+
larr:
163+
hyperparams:
164+
alpha: 0.5
165+
beta: 1.0
162166
rbr:
163167
hyperparams:
164168
train_data: None

experiments/results.csv

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -896,7 +896,7 @@ greedy,twomoon,linear,0.0,3.767193501591493e-09,7.308982136233745e-18,2.21002338
896896
greedy,twomoon,linear,0.0,3.5754390226294694e-08,7.772752307703987e-16,2.6186437573905152e-08,0.0,0.0,,,
897897
greedy,twomoon,linear,0.0,2.6068826278202728e-08,5.69812657066172e-16,2.3758703449061613e-08,0.0,0.0,,,
898898
greedy,twomoon,linear,0.0,1.6624034954171307e-08,1.4464775997720767e-16,1.0110418280362412e-08,0.0,0.0,,,
899-
greedy,twomoon,linear,0.0,3.810764737099959e-08,8.018428030433007e-16,2.520794462057552e-08,0.0,0.0,,,
899+
greedy,twomoon,linear,0.0,3.810764737099959e-08,8.0184280304330075e-16,2.520794462057552e-08,0.0,0.0,,,
900900
greedy,twomoon,linear,0.0,2.0360846092470908e-08,3.2369424982559667e-16,1.7809716035266376e-08,0.0,0.0,,,
901901
greedy,twomoon,linear,0.0,3.3704222579533656e-08,7.973010507410352e-16,2.755990613501069e-08,0.0,0.0,,,
902902
greedy,twomoon,linear,0.0,4.011398219150309e-08,8.634995116366633e-16,2.5485330068519832e-08,0.0,0.0,,,
@@ -1414,7 +1414,7 @@ claproar,twomoon,linear,0.0,2.4025606226718565e-08,3.495247200454482e-16,1.75314
14141414
claproar,twomoon,linear,0.0,2.2577137270829453e-08,3.667492287085482e-16,1.876806277056886e-08,0.0,0.0,,,
14151415
claproar,twomoon,linear,0.0,1.6213392628472434e-08,1.431988503656087e-16,1.0531753025233572e-08,0.0,0.0,,,
14161416
claproar,twomoon,linear,0.0,3.312241828035134e-08,7.202236947752843e-16,2.5826099814274528e-08,0.0,0.0,,,
1417-
claproar,twomoon,linear,0.0,3.2146713291325575e-08,6.212848124877535e-16,2.330451231991049e-08,0.0,0.0,,,
1417+
claproar,twomoon,linear,0.0,3.2146713291325575e-08,6.212848124877539e-16,2.3304512319910493e-08,0.0,0.0,,,
14181418
cfvae,adult,linear,14.0,11.148035683882576,10.427443757409122,1.0,2.0,14.0,1.0,1.0,0.0021751707477960735
14191419
cfvae,adult,linear,11.0,8.964258306384831,8.36082031186287,1.0,2.0,11.0,,,
14201420
cfvae,adult,linear,11.0,8.378064919937332,7.676155956565286,1.0,1.0,11.0,,,
@@ -1774,6 +1774,39 @@ roar,boston_housing,linear,12.0,10.45135350677032,10.479898040589251,1.069328410
17741774
roar,boston_housing,linear,12.0,9.480605434904518,8.787099818607752,0.9835587124161372,0.0,12.0,,,
17751775
roar,boston_housing,linear,12.0,9.171005273011977,8.215278014423253,0.951259922770738,0.0,12.0,,,
17761776
roar,boston_housing,linear,11.0,11.091221900347287,12.334272797155895,1.1652295128925485,0.0,11.0,,,
1777+
larr,compass,linear,5.0,6.28278229715521,9.211095016205215,2.28278229715521,4.0,4.0,1.0,1.0,0.0838862200000001
1778+
larr,compass,linear,5.0,6.28278229715521,9.211095016205215,2.28278229715521,4.0,4.0,,,
1779+
larr,compass,linear,4.0,5.28278229715521,8.211095016205217,2.28278229715521,3.0,3.0,,,
1780+
larr,compass,linear,4.0,5.28278229715521,8.211095016205217,2.28278229715521,3.0,3.0,,,
1781+
larr,compass,linear,3.0,4.28278229715521,7.211095016205217,2.28278229715521,2.0,2.0,,,
1782+
larr,credit,linear,9.0,6.001184878655714,6.000000574598035,1.0,2.0,9.0,1.0,1.0,1.2996949600000005
1783+
larr,credit,linear,9.0,4.655533630429869,4.148313540135076,1.0,2.0,9.0,,,
1784+
larr,credit,linear,6.0,6.0,6.0,1.0,2.0,6.0,,,
1785+
larr,credit,linear,7.0,4.244995992685195,4.027799497132385,1.0,2.0,7.0,,,
1786+
larr,credit,linear,7.0,5.539559014267185,5.251564915609792,1.0,1.0,7.0,,,
1787+
larr,mortgage,linear,,,,,,,,0.0,0.0972256399999999
1788+
larr,twomoon,linear,1.0,5.073372651230509,25.73911005825368,5.073372651230509,0.0,0.0,1.0,1.0,0.1430016799999997
1789+
larr,twomoon,linear,1.0,5.340802822613361,28.52417479003484,5.340802822613361,0.0,0.0,,,
1790+
larr,twomoon,linear,1.0,4.7617186457016425,22.673964460822685,4.7617186457016425,0.0,0.0,,,
1791+
larr,twomoon,linear,1.0,5.373180659396579,28.871070398513453,5.373180659396579,0.0,0.0,,,
1792+
larr,twomoon,linear,1.0,5.341437020192945,28.53094944068769,5.341437020192945,0.0,0.0,,,
1793+
larr,breast_cancer,linear,16.0,4.187701215673475,1.4833264815255618,0.6244470524510248,0.0,16.0,1.0,1.0,0.1024079199999995
1794+
larr,breast_cancer,linear,16.0,3.658850183852719,0.9913839687825852,0.424483163311366,0.0,16.0,,,
1795+
larr,breast_cancer,linear,15.0,3.2486284417348057,0.9069764153088704,0.4321567211338811,0.0,15.0,,,
1796+
larr,breast_cancer,linear,15.0,3.3811028406893766,1.0604258308017511,0.5690298507462687,0.0,15.0,,,
1797+
larr,breast_cancer,linear,18.0,9.765581080932463,6.492191714227761,0.9852233676975946,0.0,18.0,,,
1798+
larr,boston_housing,linear,6.0,2.421538282818797,1.6839497609854843,1.0,0.0,6.0,0.96,1.0,0.0626189199999998
1799+
larr,boston_housing,linear,6.0,2.900309991515271,1.9805343016374413,1.0,0.0,6.0,,,
1800+
larr,boston_housing,linear,5.0,3.515425322251533,2.913253315859378,1.0,0.0,5.0,,,
1801+
larr,boston_housing,linear,5.0,3.2491099548815363,2.661170630523269,1.0,0.0,5.0,,,
1802+
larr,boston_housing,linear,5.0,1.9458024889093768,0.8761855842637001,0.575589193332056,0.0,5.0,,,
1803+
larr,adult,linear,5.0,3.7114155251141554,3.316157711473906,1.0,1.0,5.0,0.76,1.0,4.95049264
1804+
larr,adult,linear,12.0,11.248401826484018,10.970379266487354,1.0,1.0,12.0,,,
1805+
larr,adult,linear,9.0,7.621907854199746,6.930067436815111,1.0,1.0,9.0,,,
1806+
larr,adult,linear,8.0,7.094977168949772,6.5998882425303895,1.0,1.0,8.0,,,
1807+
larr,adult,linear,8.0,6.738812785388127,6.326666249661184,1.0,1.0,1.0,,,
1808+
larr,german,linear,1.0,0.26702982282381427,0.07130492627731765,0.26702982282381427,0.0,0.0,0.8,0.4,0.04217012000000011
1809+
larr,german,linear,1.0,0.8058765269065699,0.6494369766189955,0.8058765269065699,0.0,0.0,,,
17771810
rbr,twomoon,mlp,2.0,0.7286989127039623,0.2674033401335881,0.3951900744198769,0.0,1.0,1.0,1.0,4.879103590000001
17781811
rbr,twomoon,mlp,2.0,0.6894799992967063,0.2397964433977159,0.3771830935407586,0.0,1.0,,,
17791812
rbr,twomoon,mlp,2.0,0.1409640647837522,0.0189301471125527,0.13754436657573,0.0,1.0,,,

experiments/run_experiment.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ def initialize_recourse_method(
172172
return Probe(mlmodel, hyperparams)
173173
elif method == "roar":
174174
return Roar(mlmodel, hyperparams)
175+
elif method == "larr":
176+
return Larr(mlmodel, hyperparams)
175177
elif method == "rbr":
176178
hyperparams["train_data"] = data.df_train.drop(columns=["y"], axis=1)
177179
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -207,7 +209,7 @@ def create_parser():
207209
Default: ["dice", "ar", "causal_recourse", "cchvae", "cem", "cem_vae", "claproar", "clue", "cruds", "face_knn", "face_epsilon", "feature_tweak",
208210
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "cfrl", "probe", "roar", "rbr"].
209211
Choices: ["dice", "ar", "causal_recourse", "cchvae", "cem", "cem_vae", "claproar", "clue", "cruds", "face_knn", "face_epsilon", "feature_tweak",
210-
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "cfrl", "probe", "roar", "rbr"].
212+
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "cfrl", "probe", "roar", "rbr", "larr"].
211213
-n, --number_of_samples: Specifies the number of instances per dataset.
212214
Default: 20.
213215
-s, --train_split: Specifies the split of the available data used for training.
@@ -308,6 +310,7 @@ def create_parser():
308310
"cfrl",
309311
"probe",
310312
"roar",
313+
"larr",
311314
"rbr",
312315
],
313316
help="Recourse methods for experiment",
@@ -392,6 +395,7 @@ def create_parser():
392395
"cfrl",
393396
"probe",
394397
"roar",
398+
"larr",
395399
"rbr",
396400
]
397401
sklearn_methods = ["feature_tweak", "focus", "mace"]

methods/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Gravitational,
2121
Greedy,
2222
GrowingSpheres,
23+
Larr,
2324
Probe,
2425
Revise,
2526
Roar,

methods/catalog/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .gravitational import Gravitational
1616
from .greedy import Greedy
1717
from .growing_spheres import GrowingSpheres
18+
from .larr import Larr
1819
from .mace import MACE
1920
from .probe import Probe
2021
from .rbr import RBR

methods/catalog/focus/model.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,13 @@ def _filter_hinge_loss(n_class, mask_vector, features, sigma, temperature, model
4646
filtered_input = tf.boolean_mask(features, mask_vector)
4747

4848
# if sigma or temperature are not scalars
49-
if type(sigma) != float or type(sigma) != int:
49+
if not isinstance(
50+
sigma, (float, int)
51+
): # type(sigma) != float or type(sigma) != int:
5052
sigma = tf.boolean_mask(sigma, mask_vector)
51-
if type(temperature) != float or type(temperature) != int:
53+
if not isinstance(
54+
temperature, (float, int)
55+
): # type(temperature) != float or type(temperature) != int:
5256
temperature = tf.boolean_mask(temperature, mask_vector)
5357

5458
# compute loss

methods/catalog/larr/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# flake8: noqa
2+
3+
from .model import Larr

0 commit comments

Comments
 (0)