2222from models .api import MLModel
2323from models .catalog import ModelCatalog
2424from models .predict_factuals import predict_negative_instances
25- from tools .logging import log
25+ from tools .logging_tools import log
2626
2727RANDOM_SEED = 54321
2828
@@ -120,7 +120,7 @@ def initialize_recourse_method(
120120 elif method == "cchvae" :
121121 hyperparams ["data_name" ] = data_name
122122 hyperparams ["vae_params" ]["layers" ] = [
123- sum (mlmodel .get_mutable_mask ())
123+ sum (dataset .get_mutable_mask ())
124124 ] + hyperparams ["vae_params" ]["layers" ]
125125 return CCHVAE (mlmodel , hyperparams )
126126 elif "cem" in method :
@@ -136,30 +136,30 @@ def initialize_recourse_method(
136136 hyperparams ["data_name" ] = data_name
137137 # variable input layer dimension is first time here available
138138 hyperparams ["vae_params" ]["layers" ] = [
139- sum (mlmodel .get_mutable_mask ())
139+ sum (dataset .get_mutable_mask ())
140140 ] + hyperparams ["vae_params" ]["layers" ]
141- return CRUD (mlmodel , hyperparams )
141+ return CRUD (data , mlmodel , hyperparams )
142142 elif method == "dice" :
143- return Dice (mlmodel , hyperparams )
143+ return Dice (data , mlmodel , hyperparams )
144144 elif "face" in method :
145- return Face (mlmodel , hyperparams )
145+ return Face (data , mlmodel , hyperparams )
146146 elif method == "feature_tweak" :
147- return FeatureTweak (mlmodel )
147+ return FeatureTweak (data , mlmodel )
148148 elif method == "focus" :
149- return FOCUS (mlmodel )
149+ return FOCUS (data , mlmodel )
150150 elif method == "gravitational" :
151- return Gravitational (mlmodel , hyperparams )
151+ return Gravitational (data , mlmodel , hyperparams )
152152 elif method == "greedy" :
153- return Greedy (mlmodel , hyperparams )
153+ return Greedy (data , mlmodel , hyperparams )
154154 elif method == "gs" :
155- return GrowingSpheres (mlmodel )
155+ return GrowingSpheres (data , mlmodel )
156156 elif method == "mace" :
157- return MACE (mlmodel )
157+ return MACE (data , mlmodel )
158158 elif method == "revise" :
159159 hyperparams ["data_name" ] = data_name
160160 # variable input layer dimension is first time here available
161161 hyperparams ["vae_params" ]["layers" ] = [
162- sum (mlmodel .get_mutable_mask ())
162+ sum (dataset .get_mutable_mask ())
163163 ] + hyperparams ["vae_params" ]["layers" ]
164164 return Revise (mlmodel , data , hyperparams )
165165 elif "wachter" in method :
@@ -389,7 +389,12 @@ def create_parser():
389389 # face_knn requires datasets with immutable features.
390390 if exists_already or (
391391 "face" in method_name
392- and (data_name == "mortgage" or data_name == "twomoon" )
392+ and (
393+ data_name == "mortgage"
394+ or data_name == "twomoon"
395+ or data_name == "boston_housing"
396+ or data_name == "breast_cancer"
397+ )
393398 ):
394399 continue
395400
@@ -428,9 +433,7 @@ def create_parser():
428433 )
429434
430435 factuals_sess = factuals_sess .reset_index (drop = True )
431- benchmark = Benchmark (
432- mlmodel_sess , recourse_method_sess , factuals_sess
433- )
436+ benchmark = Benchmark (recourse_method_sess , factuals_sess )
434437 evaluation_measures = [
435438 evaluation_catalog .YNN (
436439 benchmark .mlmodel , {"y" : 5 , "cf_label" : 1 }
@@ -463,17 +466,15 @@ def create_parser():
463466 method_name , mlmodel , dataset , data_name , model_name , setup
464467 )
465468
466- benchmark = Benchmark (mlmodel , recourse_method , factuals )
469+ benchmark = Benchmark (recourse_method , factuals )
467470 evaluation_measures = [
468471 evaluation_catalog .YNN (
469- benchmark . mlmodel , {"y" : 5 , "cf_label" : 1 }
472+ mlmodel , dataset , {"y" : 5 , "cf_label" : 1 }
470473 ),
471- evaluation_catalog .Distance (benchmark . mlmodel ),
474+ evaluation_catalog .Distance (mlmodel , dataset ),
472475 evaluation_catalog .SuccessRate (),
473- evaluation_catalog .Redundancy (
474- benchmark .mlmodel , {"cf_label" : 1 }
475- ),
476- evaluation_catalog .ConstraintViolation (benchmark .mlmodel ),
476+ evaluation_catalog .Redundancy (mlmodel , {"cf_label" : 1 }),
477+ evaluation_catalog .ConstraintViolation (mlmodel , dataset ),
477478 evaluation_catalog .AvgTime ({"time" : benchmark .timer }),
478479 ]
479480 df_benchmark = benchmark .run_benchmark (evaluation_measures )
0 commit comments