Skip to content

Commit 693cc76

Browse files
committed
fix: experiments impacted by structure changes
1 parent 1abcc89 commit 693cc76

File tree

1 file changed

+25
-24
lines changed

1 file changed

+25
-24
lines changed

experiments/run_experiment.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
 from models.api import MLModel
 from models.catalog import ModelCatalog
 from models.predict_factuals import predict_negative_instances
-from tools.logging import log
+from tools.logging_tools import log

 RANDOM_SEED = 54321


@@ -120,7 +120,7 @@ def initialize_recourse_method(
     elif method == "cchvae":
         hyperparams["data_name"] = data_name
         hyperparams["vae_params"]["layers"] = [
-            sum(mlmodel.get_mutable_mask())
+            sum(dataset.get_mutable_mask())
         ] + hyperparams["vae_params"]["layers"]
         return CCHVAE(mlmodel, hyperparams)
     elif "cem" in method:
elif "cem" in method:
@@ -136,30 +136,30 @@ def initialize_recourse_method(
         hyperparams["data_name"] = data_name
         # variable input layer dimension is first time here available
         hyperparams["vae_params"]["layers"] = [
-            sum(mlmodel.get_mutable_mask())
+            sum(dataset.get_mutable_mask())
         ] + hyperparams["vae_params"]["layers"]
-        return CRUD(mlmodel, hyperparams)
+        return CRUD(data, mlmodel, hyperparams)
     elif method == "dice":
-        return Dice(mlmodel, hyperparams)
+        return Dice(data, mlmodel, hyperparams)
     elif "face" in method:
-        return Face(mlmodel, hyperparams)
+        return Face(data, mlmodel, hyperparams)
     elif method == "feature_tweak":
-        return FeatureTweak(mlmodel)
+        return FeatureTweak(data, mlmodel)
     elif method == "focus":
-        return FOCUS(mlmodel)
+        return FOCUS(data, mlmodel)
     elif method == "gravitational":
-        return Gravitational(mlmodel, hyperparams)
+        return Gravitational(data, mlmodel, hyperparams)
     elif method == "greedy":
-        return Greedy(mlmodel, hyperparams)
+        return Greedy(data, mlmodel, hyperparams)
     elif method == "gs":
-        return GrowingSpheres(mlmodel)
+        return GrowingSpheres(data, mlmodel)
     elif method == "mace":
-        return MACE(mlmodel)
+        return MACE(data, mlmodel)
     elif method == "revise":
         hyperparams["data_name"] = data_name
         # variable input layer dimension is first time here available
         hyperparams["vae_params"]["layers"] = [
-            sum(mlmodel.get_mutable_mask())
+            sum(dataset.get_mutable_mask())
         ] + hyperparams["vae_params"]["layers"]
         return Revise(mlmodel, data, hyperparams)
     elif "wachter" in method:
@@ -389,7 +389,12 @@ def create_parser():
     # face_knn requires datasets with immutable features.
     if exists_already or (
         "face" in method_name
-        and (data_name == "mortgage" or data_name == "twomoon")
+        and (
+            data_name == "mortgage"
+            or data_name == "twomoon"
+            or data_name == "boston_housing"
+            or data_name == "breast_cancer"
+        )
     ):
         continue


@@ -428,9 +433,7 @@ def create_parser():
     )

     factuals_sess = factuals_sess.reset_index(drop=True)
-    benchmark = Benchmark(
-        mlmodel_sess, recourse_method_sess, factuals_sess
-    )
+    benchmark = Benchmark(recourse_method_sess, factuals_sess)
     evaluation_measures = [
         evaluation_catalog.YNN(
             benchmark.mlmodel, {"y": 5, "cf_label": 1}
@@ -463,17 +466,15 @@ def create_parser():
         method_name, mlmodel, dataset, data_name, model_name, setup
     )

-    benchmark = Benchmark(mlmodel, recourse_method, factuals)
+    benchmark = Benchmark(recourse_method, factuals)
     evaluation_measures = [
         evaluation_catalog.YNN(
-            benchmark.mlmodel, {"y": 5, "cf_label": 1}
+            mlmodel, dataset, {"y": 5, "cf_label": 1}
         ),
-        evaluation_catalog.Distance(benchmark.mlmodel),
+        evaluation_catalog.Distance(mlmodel, dataset),
         evaluation_catalog.SuccessRate(),
-        evaluation_catalog.Redundancy(
-            benchmark.mlmodel, {"cf_label": 1}
-        ),
-        evaluation_catalog.ConstraintViolation(benchmark.mlmodel),
+        evaluation_catalog.Redundancy(mlmodel, {"cf_label": 1}),
+        evaluation_catalog.ConstraintViolation(mlmodel, dataset),
         evaluation_catalog.AvgTime({"time": benchmark.timer}),
     ]
     df_benchmark = benchmark.run_benchmark(evaluation_measures)

0 commit comments

Comments (0)