Skip to content

Commit df562dd

Browse files
authored
Merge branch 'main' into fix--Fix-REVISE
2 parents b04b372 + 8d5291d commit df562dd

File tree

29 files changed

+268
-432
lines changed

29 files changed

+268
-432
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ This folder houses all datasets and their cached versions. It also contains the
112112

113113
This folder contains the implementation of all evaluation and benchmark metrics used to compare recourse methods in the repository. This includes metrics such as `distance`, `redundancy`, `success rate`, `time`, `violations`, and `y nearest neighbors`.
114114

115+
**WARNING: Current success rate implementation (to be specific, `methods.processing.check_counterfactuals()` embedded in `get_counterfactuals()`) will only consider 0->1 flips and mark 0->0 counterfactuals to nan!**
116+
115117
### Live Site Folder
116118

117119
This folder contains the implementation of the frontend UI interface, which displays results stored in `results.csv` from executing `./experiments/run_experiment.py`.
@@ -158,7 +160,7 @@ benchmark = Benchmark(model, gs, factuals)
158160
evaluation_measures = [
159161
evaluation_catalog.YNN(benchmark.mlmodel, {"y": 5, "cf_label": 1}),
160162
evaluation_catalog.Distance(benchmark.mlmodel),
161-
evaluation_catalog.SuccessRate(),
163+
evaluation_catalog.SuccessRate(), # Will only consider 0->1 flips!!!
162164
evaluation_catalog.Redundancy(benchmark.mlmodel, {"cf_label": 1}),
163165
evaluation_catalog.ConstraintViolation(benchmark.mlmodel),
164166
evaluation_catalog.AvgTime({"time": benchmark.timer}),

_deprecated/test/test_cfmodel.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from carla.recourse_methods.catalog.cchvae import CCHVAE
99
from carla.recourse_methods.catalog.cem import CEM
1010
from carla.recourse_methods.catalog.clue import Clue
11-
from carla.recourse_methods.catalog.crud import CRUD
11+
from carla.recourse_methods.catalog.cruds import CRUDS
1212
from carla.recourse_methods.catalog.dice import Dice
1313
from carla.recourse_methods.catalog.face import Face
1414
from carla.recourse_methods.catalog.feature_tweak import FeatureTweak
@@ -367,7 +367,7 @@ def test_cchvae(model_type):
367367

368368

369369
@pytest.mark.parametrize("model_type", testmodel)
370-
def test_crud(model_type):
370+
def test_cruds(model_type):
371371
# Build data and mlmodel
372372
data_name = "adult"
373373
data = OnlineCatalog(data_name)
@@ -384,8 +384,8 @@ def test_crud(model_type):
384384
},
385385
}
386386

387-
crud = CRUD(model, hyperparams)
388-
df_cfs = crud.get_counterfactuals(test_factual)
387+
cruds = CRUDS(model, hyperparams)
388+
df_cfs = cruds.get_counterfactuals(test_factual)
389389

390390
assert test_factual.shape[0] == df_cfs.shape[0]
391391
assert isinstance(df_cfs, pd.DataFrame)
4.45 KB
Binary file not shown.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.

data/catalog/_data_main/process_data/process_adult_data.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ def load_adult_data(load_data_size=None):
103103
for f in data_files:
104104
check_data_file(f)
105105

106+
f = os.path.join(
107+
os.path.dirname(os.path.realpath(__file__)), "..", "raw_data", f
108+
)
109+
106110
for line in open(f):
107111
line = line.strip()
108112
if line == "":
@@ -227,7 +231,6 @@ def load_adult_data_new():
227231
] # sex and race are sensitive feature so we will not use them in classification, we will not consider fnlwght for classification since its computed externally and it highly predictive for the class (for details, see documentation of the adult data)
228232

229233
# adult data comes in two different files, one for training and one for testing, however, we will combine data from both the files
230-
this_files_directory = os.path.dirname(os.path.realpath(__file__))
231234
data_files = ["adult.data", "adult.test"]
232235

233236
y = []
@@ -244,7 +247,9 @@ def load_adult_data_new():
244247

245248
for file_name in data_files:
246249
check_data_file(file_name)
247-
full_file_name = os.path.join(this_files_directory, file_name)
250+
full_file_name = os.path.join(
251+
os.path.dirname(os.path.realpath(__file__)), "..", "raw_data", file_name
252+
)
248253
print(full_file_name)
249254

250255
for line in open(full_file_name):

data/catalog/_data_main/process_data/process_compas_data.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,12 @@ def load_compas_data():
5959
COMPAS_INPUT_FILE = "compas-scores-two-years.csv"
6060
check_data_file(COMPAS_INPUT_FILE)
6161

62+
f = os.path.join(
63+
os.path.dirname(os.path.realpath(__file__)), "..", "raw_data", COMPAS_INPUT_FILE
64+
)
65+
6266
# load the data and get some stats
63-
df = pd.read_csv(COMPAS_INPUT_FILE)
67+
df = pd.read_csv(f)
6468
df = df.dropna(subset=["days_b_screening_arrest"]) # dropping missing vals
6569

6670
# convert to np array
@@ -185,9 +189,10 @@ def load_compas_data_new():
185189
CLASS_FEATURE = "two_year_recid" # the decision variable
186190

187191
file_name = "compas-scores-two-years.csv"
188-
this_files_directory = os.path.dirname(os.path.realpath(__file__))
189-
full_file_name = os.path.join(this_files_directory, file_name)
190192
check_data_file(file_name)
193+
full_file_name = os.path.join(
194+
os.path.dirname(os.path.realpath(__file__)), "..", "raw_data", file_name
195+
)
191196

192197
# load the data and get some stats
193198
df = pd.read_csv(full_file_name)

0 commit comments

Comments
 (0)