Skip to content

Commit 36c72aa

Browse files
authored
Merge pull request #28 from Chenghao-Tan/feat-CFVAE-Support
feat: Add support for CFVAE
2 parents b0032a6 + d7dcfc1 commit 36c72aa

23 files changed

+1959
-4
lines changed

experiments/experimental_setup.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ recourse_methods:
138138
hyperparams:
139139
loss_type: "BCE"
140140
binary_cat_features: True
141+
cfvae:
142+
hyperparams:
143+
encoded_size: 10
144+
train: True
141145
probe:
142146
hyperparams:
143147
roar:

experiments/results.csv

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1415,6 +1415,114 @@ claproar,twomoon,linear,0.0,2.2577137270829443e-08,3.667492287085482e-16,1.87680
14151415
claproar,twomoon,linear,0.0,1.6213392628472434e-08,1.431988503656087e-16,1.0531753025233572e-08,0.0,0.0,,,
14161416
claproar,twomoon,linear,0.0,3.312241828035134e-08,7.202236947752843e-16,2.5826099814274528e-08,0.0,0.0,,,
14171417
claproar,twomoon,linear,0.0,3.2146713291325575e-08,6.212848124877535e-16,2.330451231991049e-08,0.0,0.0,,,
1418+
cfvae,adult,linear,14.0,11.148035683882576,10.427443757409122,1.0,2.0,14.0,1.0,1.0,0.0021751707477960735
1419+
cfvae,adult,linear,11.0,8.964258306384831,8.36082031186287,1.0,2.0,11.0,,,
1420+
cfvae,adult,linear,11.0,8.378064919937332,7.676155956565286,1.0,1.0,11.0,,,
1421+
cfvae,adult,linear,14.0,11.836404305897002,11.264988777815269,1.0,2.0,14.0,,,
1422+
cfvae,adult,linear,11.0,9.228735055988782,8.518062168544052,1.0,1.0,11.0,,,
1423+
cfvae,adult,linear,12.0,10.07007323509604,9.41834116641392,1.0,1.0,12.0,,,
1424+
cfvae,adult,linear,14.0,11.811383219396687,11.299482430992908,1.0,2.0,14.0,,,
1425+
cfvae,adult,linear,12.0,10.189344386038648,9.53576674349818,1.0,1.0,12.0,,,
1426+
cfvae,adult,linear,11.0,9.126694239662251,8.426433680414107,1.0,1.0,11.0,,,
1427+
cfvae,adult,linear,12.0,9.96024178609456,9.357729667247114,1.0,1.0,12.0,,,
1428+
cfvae,adult,linear,8.0,5.600834699939474,5.190816680907651,1.0,2.0,8.0,,,
1429+
cfvae,adult,linear,8.0,6.093453736661165,5.460893313561802,1.0,1.0,8.0,,,
1430+
cfvae,adult,linear,12.0,9.079412889525498,8.373943611876138,1.0,2.0,12.0,,,
1431+
cfvae,adult,linear,13.0,11.154091261517644,10.445198736457584,1.0,1.0,13.0,,,
1432+
cfvae,adult,linear,11.0,8.506157757643075,7.806209370052858,1.0,1.0,11.0,,,
1433+
cfvae,adult,linear,11.0,9.085598110832716,8.401100406564549,1.0,1.0,11.0,,,
1434+
cfvae,adult,linear,9.0,6.7249886615735575,6.192066859738652,1.0,2.0,9.0,,,
1435+
cfvae,adult,linear,11.0,9.253756023279806,8.537237089327121,1.0,1.0,11.0,,,
1436+
cfvae,adult,linear,13.0,11.330347207698779,10.593045613777402,1.0,2.0,13.0,,,
1437+
cfvae,adult,linear,8.0,5.987881335064834,5.391319099820862,1.0,1.0,8.0,,,
1438+
cfvae,compass,linear,4.0,3.024829186499119,3.000616488502208,1.0,2.0,4.0,1.0,1.0,0.0015247433970216665
1439+
cfvae,compass,linear,4.0,3.024829186499119,3.000616488502208,1.0,2.0,4.0,,,
1440+
cfvae,compass,linear,5.0,4.0244677290320405,4.000598669763985,1.0,3.0,5.0,,,
1441+
cfvae,compass,linear,5.0,4.025518359616399,4.000651186677512,1.0,3.0,5.0,,,
1442+
cfvae,compass,linear,6.0,5.025733323767781,5.000662203952136,1.0,4.0,6.0,,,
1443+
cfvae,compass,linear,6.0,5.00054485890034,5.000000296871222,1.0,4.0,6.0,,,
1444+
cfvae,compass,linear,4.0,3.0277654431564245,3.0007709198336725,1.0,2.0,4.0,,,
1445+
cfvae,compass,linear,6.0,5.053101077087615,5.002819724387865,1.0,4.0,6.0,,,
1446+
cfvae,compass,linear,5.0,4.001863296878965,4.000003471875258,1.0,3.0,5.0,,,
1447+
cfvae,compass,linear,5.0,4.028168947975102,4.000793489630024,1.0,3.0,5.0,,,
1448+
cfvae,compass,linear,4.0,3.079887599831349,3.0063820286068137,1.0,2.0,4.0,,,
1449+
cfvae,compass,linear,5.0,4.026247303932904,4.000688920963746,1.0,4.0,5.0,,,
1450+
cfvae,compass,linear,4.0,3.025780810043216,3.000664650166484,1.0,3.0,4.0,,,
1451+
cfvae,compass,linear,5.0,4.027201469595495,4.000739919948154,1.0,3.0,5.0,,,
1452+
cfvae,compass,linear,6.0,5.315943886771014,5.0998205395879745,1.0,4.0,6.0,,,
1453+
cfvae,compass,linear,6.0,5.237026503211574,5.0561815632247065,1.0,4.0,6.0,,,
1454+
cfvae,compass,linear,5.0,4.000912980911763,4.000000833534145,1.0,3.0,5.0,,,
1455+
cfvae,compass,linear,4.0,3.025780810043216,3.000664650166484,1.0,3.0,4.0,,,
1456+
cfvae,compass,linear,6.0,5.00054485890034,5.000000296871222,1.0,4.0,6.0,,,
1457+
cfvae,compass,linear,5.0,4.000043332380684,4.0000000018776944,1.0,4.0,5.0,,,
1458+
cfvae,credit,linear,11.0,3.8635048107397414,3.4543789350664933,1.0,2.0,12.0,1.0,1.0,0.0014838634000625462
1459+
cfvae,credit,linear,14.0,6.198726399398852,6.011102742569847,1.0,2.0,15.0,,,
1460+
cfvae,credit,linear,12.0,4.350875405698145,4.044041966345174,1.0,2.0,13.0,,,
1461+
cfvae,credit,linear,12.0,4.263478694309617,3.783341978995004,1.0,2.0,12.0,,,
1462+
cfvae,credit,linear,14.0,7.624633209525639,7.143803839765062,1.0,3.0,15.0,,,
1463+
cfvae,credit,linear,14.0,7.592809723475144,7.183050512047454,1.0,3.0,15.0,,,
1464+
cfvae,credit,linear,13.0,5.954872258443237,5.788788752716654,1.0,3.0,14.0,,,
1465+
cfvae,credit,linear,13.0,8.69606156849192,7.989534436531932,1.0,4.0,13.0,,,
1466+
cfvae,credit,linear,12.0,4.404955609766425,4.052602125029948,1.0,1.0,13.0,,,
1467+
cfvae,credit,linear,13.0,5.392869866462722,4.6604467313237175,1.0,2.0,13.0,,,
1468+
cfvae,credit,linear,13.0,5.634321083100907,5.254773997724432,1.0,3.0,14.0,,,
1469+
cfvae,credit,linear,12.0,5.5930198212291815,5.1411887012072,1.0,2.0,13.0,,,
1470+
cfvae,credit,linear,13.0,6.145782473042407,6.00508153133561,1.0,3.0,14.0,,,
1471+
cfvae,credit,linear,13.0,5.636465045002524,5.3428791774655515,1.0,2.0,14.0,,,
1472+
cfvae,credit,linear,14.0,6.93476258607812,6.659525192391085,1.0,3.0,15.0,,,
1473+
cfvae,credit,linear,11.0,4.258967514070494,3.7955233873850296,1.0,2.0,11.0,,,
1474+
cfvae,credit,linear,13.0,4.907895225057001,4.304941588158154,1.0,3.0,13.0,,,
1475+
cfvae,credit,linear,13.0,5.774344388452266,5.310840807995662,1.0,2.0,14.0,,,
1476+
cfvae,credit,linear,12.0,4.5209345588455445,4.1246171300212735,1.0,2.0,13.0,,,
1477+
cfvae,credit,linear,12.0,5.248803936646749,5.025442029200411,1.0,2.0,13.0,,,
1478+
cfvae,german,linear,,,,,,,,0.0,0.0013751629972830414
1479+
cfvae,mortgage,linear,,,,,,,,0.0,0.0013101477990858255
1480+
cfvae,twomoon,linear,2.0,1.278267372685709,0.8852150719661562,0.8238379552834283,0.0,0.0,0.0,1.0,0.0012534039502497762
1481+
cfvae,twomoon,linear,2.0,1.0350506053108428,0.542389349836534,0.5755100940207559,0.0,0.0,,,
1482+
cfvae,twomoon,linear,2.0,1.1099842384583596,0.7478669832308561,0.8117356473052192,0.0,0.0,,,
1483+
cfvae,twomoon,linear,2.0,1.053769507472274,0.6778692544696351,0.7745278104028399,0.0,0.0,,,
1484+
cfvae,twomoon,linear,2.0,1.0022305781983267,0.5230489365334128,0.6031345779506758,0.0,0.0,,,
1485+
cfvae,twomoon,linear,2.0,1.0665961662581838,0.5803080001612014,0.6091080733113688,0.0,0.0,,,
1486+
cfvae,twomoon,linear,2.0,1.1887700319568832,0.7066622746718962,0.6005160938817034,0.0,0.0,,,
1487+
cfvae,twomoon,linear,2.0,1.1920020013422483,0.7822199633346685,0.7854549234765197,0.0,0.0,,,
1488+
cfvae,twomoon,linear,2.0,0.451678944831882,0.1219330822957706,0.325654670891798,0.0,0.0,,,
1489+
cfvae,twomoon,linear,2.0,0.6595415909171616,0.2297017324759225,0.4078866626084505,0.0,0.0,,,
1490+
cfvae,twomoon,linear,2.0,1.2080345012158409,0.7470094481737078,0.6971187277976112,0.0,0.0,,,
1491+
cfvae,twomoon,linear,2.0,1.1750422859735123,0.733865281461683,0.735005200682417,0.0,0.0,,,
1492+
cfvae,twomoon,linear,2.0,0.6748372314018426,0.24140993510219194,0.4202053617993757,0.0,0.0,,,
1493+
cfvae,twomoon,linear,2.0,0.6457621748157715,0.29229754857498663,0.5275676225825146,0.0,0.0,,,
1494+
cfvae,twomoon,linear,2.0,1.2785728574438275,0.8446127885742218,0.7559879982875899,0.0,0.0,,,
1495+
cfvae,twomoon,linear,2.0,1.3329917697261102,0.9520751817475874,0.8448798144486733,0.0,0.0,,,
1496+
cfvae,twomoon,linear,2.0,0.7240719348770486,0.4052162985947956,0.6295020542201003,0.0,0.0,,,
1497+
cfvae,twomoon,linear,2.0,0.5760552931649328,0.19222788647149427,0.4027186141427455,0.0,0.0,,,
1498+
cfvae,twomoon,linear,2.0,1.2584016952337005,0.8655615000073452,0.8212608469970173,0.0,0.0,,,
1499+
cfvae,twomoon,linear,2.0,0.7697960412900238,0.31617957552542925,0.4846141249926379,0.0,0.0,,,
1500+
cfvae,breast_cancer,linear,30.0,7.08162388785715,2.9664198324413507,0.746176963944082,0.0,30.0,0.0,0.35,0.0013887853478081525
1501+
cfvae,breast_cancer,linear,30.0,6.664591540974017,3.403015792209841,0.8122290516618058,0.0,30.0,,,
1502+
cfvae,breast_cancer,linear,30.0,6.375213263705285,2.7283774457322907,0.7883810740931736,0.0,29.0,,,
1503+
cfvae,breast_cancer,linear,30.0,6.844435437402433,2.98689311672136,0.7775364279024392,0.0,30.0,,,
1504+
cfvae,breast_cancer,linear,30.0,5.726400578799684,2.1362629079163016,0.6787077764440472,0.0,30.0,,,
1505+
cfvae,breast_cancer,linear,30.0,5.597933558560285,2.001940612780337,0.6823853925042678,0.0,27.0,,,
1506+
cfvae,breast_cancer,linear,30.0,6.764009958434911,2.685179403513465,0.7363181661195569,0.0,27.0,,,
1507+
cfvae,boston_housing,linear,11.0,1.3756296956601155,0.2503766222685556,0.3084429971715237,0.0,6.0,0.6631578947368422,0.95,0.0013642499456182125
1508+
cfvae,boston_housing,linear,13.0,2.689280792994406,1.3429195876823918,1.0,0.0,9.0,,,
1509+
cfvae,boston_housing,linear,12.0,1.7124826888009486,0.3750786076067511,0.32121214365738177,0.0,7.0,,,
1510+
cfvae,boston_housing,linear,13.0,3.1184136916805447,1.6207198196856158,1.0,0.0,10.0,,,
1511+
cfvae,boston_housing,linear,13.0,2.793209294868518,1.4829030023015737,1.0,0.0,7.0,,,
1512+
cfvae,boston_housing,linear,12.0,2.8325242342940915,1.0518447593204854,0.541595995426178,0.0,7.0,,,
1513+
cfvae,boston_housing,linear,13.0,2.438999505456368,1.2431634415230215,1.0,0.0,10.0,,,
1514+
cfvae,boston_housing,linear,12.0,1.6577324851327695,0.4570327895920604,0.4648141604015206,0.0,10.0,,,
1515+
cfvae,boston_housing,linear,13.0,2.3588741779213,1.2372385521401563,1.0,0.0,11.0,,,
1516+
cfvae,boston_housing,linear,13.0,2.718044087600396,1.4084335108554928,1.0,0.0,8.0,,,
1517+
cfvae,boston_housing,linear,12.0,1.8132503315960813,0.7066282844438531,0.6871339156993509,0.0,9.0,,,
1518+
cfvae,boston_housing,linear,12.0,1.7952120795048256,0.4758684272687951,0.4647828977653359,0.0,9.0,,,
1519+
cfvae,boston_housing,linear,13.0,2.68523783144457,1.3536098527317284,1.0,0.0,8.0,,,
1520+
cfvae,boston_housing,linear,12.0,2.6634623119522995,0.9450720482092088,0.5413446724414825,0.0,7.0,,,
1521+
cfvae,boston_housing,linear,12.0,2.7303082781093764,0.9808185383635876,0.541551798582077,0.0,7.0,,,
1522+
cfvae,boston_housing,linear,13.0,2.128286996362456,1.1681508890811234,1.0,0.0,7.0,,,
1523+
cfvae,boston_housing,linear,11.0,1.1642179021592085,0.2220489567468797,0.3127339951535488,0.0,7.0,,,
1524+
cfvae,boston_housing,linear,12.0,1.0237826877700886,0.15038662312775705,0.24312054669597186,0.0,10.0,,,
1525+
cfvae,boston_housing,linear,13.0,2.354798590209251,1.2301028537404477,1.0,0.0,11.0,,,
14181526
probe,adult,linear,51.0,2.5574682458217044,0.20724095760647737,0.10291039943695068,2.0,51.0,0.0,1.0,11.27795516
14191527
probe,adult,linear,48.0,1.6209024338863478,0.0782517097074313,0.06244194507598877,2.0,51.0,,,
14201528
probe,adult,linear,51.0,6.151970284269187,1.3522899002686382,0.268756240606308,2.0,48.0,,,

experiments/run_experiment.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,10 @@ def initialize_recourse_method(
162162
sum(mlmodel.get_mutable_mask())
163163
] + hyperparams["vae_params"]["layers"]
164164
return Revise(mlmodel, data, hyperparams)
165-
elif "wachter" in method:
165+
elif method == "wachter":
166166
return Wachter(mlmodel, hyperparams)
167+
elif method == "cfvae":
168+
return CFVAE(mlmodel, hyperparams)
167169
elif method == "probe":
168170
return Probe(mlmodel, hyperparams)
169171
elif method == "roar":
@@ -197,7 +199,7 @@ def create_parser():
197199
-r, --recourse_method: Specifies recourse methods for the experiment.
198200
Default: ["dice", "cchvae", "cem", "cem_vae", "clue", "cruds", "face_knn", "face_epsilon", "gs", "mace", "revise", "wachter"].
199201
Choices: ["dice", "ar", "causal_recourse", "cchvae", "cem", "cem_vae", "claproar", "clue", "cruds", "face_knn", "face_epsilon", "feature_tweak",
200-
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "roar", "probe"].
202+
"focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "roar", "probe"].
201203
-n, --number_of_samples: Specifies the number of instances per dataset.
202204
Default: 20.
203205
-s, --train_split: Specifies the split of the available data used for training.
@@ -264,6 +266,7 @@ def create_parser():
264266
"gs",
265267
"revise",
266268
"wachter",
269+
"cfvae",
267270
"roar",
268271
],
269272
choices=[
@@ -286,6 +289,7 @@ def create_parser():
286289
"mace",
287290
"revise",
288291
"wachter",
292+
"cfvae",
289293
"probe",
290294
"roar",
291295
],
@@ -367,6 +371,7 @@ def create_parser():
367371
"gravitational",
368372
"wachter",
369373
"revise",
374+
"cfvae",
370375
"probe",
371376
"roar",
372377
]

methods/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .catalog import (
55
CCHVAE,
66
CEM,
7+
CFVAE,
78
CRUD,
89
FOCUS,
910
MACE,

methods/catalog/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .causal_recourse import CausalRecourse
44
from .cchvae import CCHVAE
55
from .cem import CEM
6+
from .cfvae import CFVAE
67
from .claproar import ClaPROAR
78
from .clue import Clue
89
from .crud import CRUD

methods/catalog/cfvae/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# flake8: noqa
2+
3+
from .model import CFVAE

0 commit comments

Comments
 (0)