Skip to content

Commit b24c0b8

Browse files
author
Chenghao Tan
committed
Merge remote-tracking branch 'origin/main' into fix--Fix-SuccessRate
2 parents 6ba1063 + 3f0db2b commit b24c0b8

File tree

17 files changed

+3151
-11
lines changed

17 files changed

+3151
-11
lines changed

data/catalog/_data_main/process_data/process_utils_data.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@ def get_one_hot_encoding(in_arr):
2222
output: m (ndarray): one-hot encoded matrix
2323
d (dict): also returns a dictionary original_val -> column in encoded matrix
2424
"""
25+
valid_types = (int, np.int64, float, np.float64)
2526

2627
for k in in_arr:
27-
if (
28-
str(type(k)) != "<type 'numpy.float64'>"
29-
and type(k) != int
30-
and type(k) != np.int64
31-
):
28+
# if (
29+
# str(type(k)) != "<type 'numpy.float64'>"
30+
# and type(k) != int
31+
# and type(k) != np.int64
32+
# ):
33+
if not isinstance(k, valid_types):
3234
print(str(type(k)))
3335
print("************* ERROR: Input arr does not have integer types")
3436
return None

experiments/experimental_setup.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ recourse_methods:
159159
hyperparams:
160160
roar:
161161
hyperparams:
162+
larr:
163+
hyperparams:
164+
alpha: 0.5
165+
beta: 1.0
162166
rbr:
163167
hyperparams:
164168
train_data: None

experiments/results.csv

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,3 +1524,68 @@ claproar,boston_housing,linear,,,,,,,,0.0,0.0001886077574454248
15241524
claproar,mortgage,linear,,,,,,,,0.0,0.0002759364084340632
15251525
claproar,twomoon,linear,,,,,,,,0.0,0.00016150964656844736
15261526
claproar,breast_cancer,linear,,,,,,,,0.0,0.00025387420319020746
1527+
larr,compass,linear,5.0,6.28278229715521,9.211095016205215,2.28278229715521,4.0,4.0,1.0,1.0,0.0838862200000001
1528+
larr,compass,linear,5.0,6.28278229715521,9.211095016205215,2.28278229715521,4.0,4.0,,,
1529+
larr,compass,linear,4.0,5.28278229715521,8.211095016205217,2.28278229715521,3.0,3.0,,,
1530+
larr,compass,linear,4.0,5.28278229715521,8.211095016205217,2.28278229715521,3.0,3.0,,,
1531+
larr,compass,linear,3.0,4.28278229715521,7.211095016205217,2.28278229715521,2.0,2.0,,,
1532+
larr,credit,linear,9.0,6.001184878655714,6.000000574598035,1.0,2.0,9.0,1.0,1.0,1.2996949600000005
1533+
larr,credit,linear,9.0,4.655533630429869,4.148313540135076,1.0,2.0,9.0,,,
1534+
larr,credit,linear,6.0,6.0,6.0,1.0,2.0,6.0,,,
1535+
larr,credit,linear,7.0,4.244995992685195,4.027799497132385,1.0,2.0,7.0,,,
1536+
larr,credit,linear,7.0,5.539559014267185,5.251564915609792,1.0,1.0,7.0,,,
1537+
larr,mortgage,linear,,,,,,,,0.0,0.0972256399999999
1538+
larr,twomoon,linear,1.0,5.073372651230509,25.73911005825368,5.073372651230509,0.0,0.0,1.0,1.0,0.1430016799999997
1539+
larr,twomoon,linear,1.0,5.340802822613361,28.52417479003484,5.340802822613361,0.0,0.0,,,
1540+
larr,twomoon,linear,1.0,4.7617186457016425,22.673964460822685,4.7617186457016425,0.0,0.0,,,
1541+
larr,twomoon,linear,1.0,5.373180659396579,28.871070398513453,5.373180659396579,0.0,0.0,,,
1542+
larr,twomoon,linear,1.0,5.341437020192945,28.53094944068769,5.341437020192945,0.0,0.0,,,
1543+
larr,breast_cancer,linear,16.0,4.187701215673475,1.4833264815255618,0.6244470524510248,0.0,16.0,1.0,1.0,0.1024079199999995
1544+
larr,breast_cancer,linear,16.0,3.658850183852719,0.9913839687825852,0.424483163311366,0.0,16.0,,,
1545+
larr,breast_cancer,linear,15.0,3.2486284417348057,0.9069764153088704,0.4321567211338811,0.0,15.0,,,
1546+
larr,breast_cancer,linear,15.0,3.3811028406893766,1.0604258308017511,0.5690298507462687,0.0,15.0,,,
1547+
larr,breast_cancer,linear,18.0,9.765581080932463,6.492191714227761,0.9852233676975946,0.0,18.0,,,
1548+
larr,boston_housing,linear,6.0,2.421538282818797,1.6839497609854843,1.0,0.0,6.0,0.96,1.0,0.0626189199999998
1549+
larr,boston_housing,linear,6.0,2.900309991515271,1.9805343016374413,1.0,0.0,6.0,,,
1550+
larr,boston_housing,linear,5.0,3.515425322251533,2.913253315859378,1.0,0.0,5.0,,,
1551+
larr,boston_housing,linear,5.0,3.2491099548815363,2.661170630523269,1.0,0.0,5.0,,,
1552+
larr,boston_housing,linear,5.0,1.9458024889093768,0.8761855842637001,0.575589193332056,0.0,5.0,,,
1553+
larr,adult,linear,5.0,3.7114155251141554,3.316157711473906,1.0,1.0,5.0,0.76,1.0,4.95049264
1554+
larr,adult,linear,12.0,11.248401826484018,10.970379266487354,1.0,1.0,12.0,,,
1555+
larr,adult,linear,9.0,7.621907854199746,6.930067436815111,1.0,1.0,9.0,,,
1556+
larr,adult,linear,8.0,7.094977168949772,6.5998882425303895,1.0,1.0,8.0,,,
1557+
larr,adult,linear,8.0,6.738812785388127,6.326666249661184,1.0,1.0,1.0,,,
1558+
larr,german,linear,1.0,0.26702982282381427,0.07130492627731765,0.26702982282381427,0.0,0.0,0.8,0.4,0.04217012000000011
1559+
larr,german,linear,1.0,0.8058765269065699,0.6494369766189955,0.8058765269065699,0.0,0.0,,,
1560+
rbr,twomoon,mlp,2.0,0.7286989127039623,0.2674033401335881,0.3951900744198769,0.0,1.0,1.0,1.0,4.879103590000001
1561+
rbr,twomoon,mlp,2.0,0.6894799992967063,0.2397964433977159,0.3771830935407586,0.0,1.0,,,
1562+
rbr,twomoon,mlp,2.0,0.1409640647837522,0.0189301471125527,0.13754436657573,0.0,1.0,,,
1563+
rbr,twomoon,mlp,2.0,0.3849629894626554,0.1410744506732003,0.3754790343940074,0.0,1.0,,,
1564+
rbr,twomoon,mlp,2.0,1.002957846445284,0.8269444367229775,0.9039601127519769,0.0,1.0,,,
1565+
rbr,twomoon,mlp,2.0,0.1446472630709976,0.0109513530407096,0.0879751120364685,0.0,1.0,,,
1566+
rbr,twomoon,mlp,2.0,0.2456580453177645,0.0308932669559418,0.1417938498564787,0.0,1.0,,,
1567+
rbr,twomoon,mlp,2.0,0.3707915824245534,0.0714714848971661,0.2223300984259902,0.0,1.0,,,
1568+
rbr,twomoon,mlp,2.0,0.373523153433866,0.127161158046786,0.3561743639721993,0.0,1.0,,,
1569+
rbr,twomoon,mlp,2.0,0.6861088172443294,0.265889329731744,0.4665790522798936,0.0,1.0,,,
1570+
rbr,compass,mlp,7.0,0.8117256915000708,0.1751537979559963,0.2588201761245727,5.0,5.0,0.4,1.0,45.98423136
1571+
rbr,compass,mlp,7.0,1.000752511777376,0.4067629195629747,0.479181706905365,5.0,6.0,,,
1572+
rbr,compass,mlp,7.0,0.7570992925841558,0.396842765160404,0.6263000965118408,5.0,6.0,,,
1573+
rbr,compass,mlp,7.0,0.7593743661418557,0.1561166400465182,0.2451530694961547,5.0,5.0,,,
1574+
rbr,compass,mlp,7.0,2.095355091420444,1.0422519967563235,0.5292404294013977,5.0,4.0,,,
1575+
rbr,adult,mlp,,,,,,,,0.0,0.0905930000000001
1576+
rbr,credit,mlp,20.0,0.965264297888805,0.0794397720991071,0.1038587912917137,5.0,20.0,0.0,1.0,59.347741940000006
1577+
rbr,credit,mlp,20.0,0.566989452251398,0.0259825727896292,0.0628717467188835,5.0,20.0,,,
1578+
rbr,credit,mlp,20.0,1.1195734536322235,0.0959842938489572,0.1067807227373123,5.0,19.0,,,
1579+
rbr,credit,mlp,20.0,2.84648185800286,0.7871216235889688,0.3118891716003418,5.0,15.0,,,
1580+
rbr,credit,mlp,20.0,4.035248189802385,1.4515882305083423,0.3944755792617798,5.0,15.0,,,
1581+
rbr,german,mlp,7.0,0.4157270057639756,0.0748703424622665,0.263138996789961,5.0,6.0,0.3199999999999999,1.0,18.5585004
1582+
rbr,german,mlp,7.0,0.915372353215671,0.5424925760678092,0.7276125465120588,5.0,6.0,,,
1583+
rbr,german,mlp,7.0,0.5011241306168176,0.0910087649095372,0.2593848833329897,5.0,5.0,,,
1584+
rbr,german,mlp,7.0,0.4318845116361215,0.081656268455909,0.2743761845484836,5.0,6.0,,,
1585+
rbr,german,mlp,7.0,0.8469910197915712,0.2895495604187955,0.5005260705947876,5.0,6.0,,,
1586+
rbr,boston_housing,mlp,13.0,1.4352224063896322,0.3736564812519745,0.4282479539830633,0.0,10.0,0.3199999999999999,1.0,12.221076720000005
1587+
rbr,boston_housing,mlp,13.0,1.3238735117244025,0.2947727664453581,0.3508007670036611,0.0,9.0,,,
1588+
rbr,boston_housing,mlp,13.0,1.3824252398802377,0.3149165118214244,0.4218325881247825,0.0,10.0,,,
1589+
rbr,boston_housing,mlp,13.0,0.7209683333367833,0.0688352053038206,0.153095543384552,0.0,12.0,,,
1590+
rbr,boston_housing,mlp,13.0,1.0697691839818662,0.1521799871864383,0.2351901829242706,0.0,10.0,,,
1591+
rbr,breast_cancer,mlp,,,,,,,,0.0,0.006507019999999741

experiments/run_experiment.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,8 @@ def initialize_recourse_method(
172172
return Probe(mlmodel, hyperparams)
173173
elif method == "roar":
174174
return Roar(mlmodel, hyperparams)
175+
elif method == "larr":
176+
return Larr(mlmodel, hyperparams)
175177
elif method == "rbr":
176178
hyperparams["train_data"] = data.df_train.drop(columns=["y"], axis=1)
177179
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -207,7 +209,7 @@ def create_parser():
207209
Default: ["dice", "ar", "causal_recourse", "cchvae", "cem", "cem_vae", "claproar", "clue", "cruds", "face_knn", "face_epsilon", "feature_tweak",
208210
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "cfrl", "probe", "roar", "rbr"].
209211
Choices: ["dice", "ar", "causal_recourse", "cchvae", "cem", "cem_vae", "claproar", "clue", "cruds", "face_knn", "face_epsilon", "feature_tweak",
210-
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "cfrl", "probe", "roar", "rbr"].
212+
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "cfrl", "probe", "roar", "rbr", "larr"].
211213
-n, --number_of_samples: Specifies the number of instances per dataset.
212214
Default: 20.
213215
-s, --train_split: Specifies the split of the available data used for training.
@@ -308,6 +310,7 @@ def create_parser():
308310
"cfrl",
309311
"probe",
310312
"roar",
313+
"larr",
311314
"rbr",
312315
],
313316
help="Recourse methods for experiment",
@@ -427,6 +430,7 @@ def _append_to_csv(path: str, df: pd.DataFrame):
427430
"cfrl",
428431
"probe",
429432
"roar",
433+
"larr",
430434
"rbr",
431435
]
432436
sklearn_methods = ["feature_tweak", "focus", "mace"]

methods/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Gravitational,
2121
Greedy,
2222
GrowingSpheres,
23+
Larr,
2324
Probe,
2425
Revise,
2526
Roar,

methods/catalog/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .gravitational import Gravitational
1616
from .greedy import Greedy
1717
from .growing_spheres import GrowingSpheres
18+
from .larr import Larr
1819
from .mace import MACE
1920
from .probe import Probe
2021
from .rbr import RBR

methods/catalog/focus/model.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,13 @@ def _filter_hinge_loss(n_class, mask_vector, features, sigma, temperature, model
4646
filtered_input = tf.boolean_mask(features, mask_vector)
4747

4848
# if sigma or temperature are not scalars
49-
if type(sigma) != float or type(sigma) != int:
49+
if not isinstance(
50+
sigma, (float, int)
51+
): # type(sigma) != float or type(sigma) != int:
5052
sigma = tf.boolean_mask(sigma, mask_vector)
51-
if type(temperature) != float or type(temperature) != int:
53+
if not isinstance(
54+
temperature, (float, int)
55+
): # type(temperature) != float or type(temperature) != int:
5256
temperature = tf.boolean_mask(temperature, mask_vector)
5357

5458
# compute loss

methods/catalog/larr/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# flake8: noqa
2+
3+
from .model import Larr

0 commit comments

Comments
 (0)