77
88import numpy as np
99import pandas as pd
10- import tensorflow as tf
1110import torch
1211import yaml
13- from tensorflow import Graph , Session
14- from tensorflow .python .keras .backend import set_session
1512
1613import evaluation .catalog as evaluation_catalog
1714from data .api import Data
2724RANDOM_SEED = 54321
2825
2926np .random .seed (RANDOM_SEED )
30- os .environ ["TF_CPP_MIN_LOG_LEVEL" ] = "3"
3127seed (
3228 RANDOM_SEED
3329) # set the random seed so that the random permutations can be reproduced again
34- tf .set_random_seed (RANDOM_SEED )
3530torch .manual_seed (RANDOM_SEED )
3631warnings .simplefilter (action = "ignore" , category = FutureWarning )
3732
@@ -67,11 +62,11 @@ def initialize_recourse_method(
6762 data_name : str ,
6863 model_type : str ,
6964 setup : Dict ,
70- sess : Session = None ,
65+ sess = None ,
7166) -> RecourseMethod :
7267 """
7368 Initializes and returns an instance of a recourse method based on the specified recourse method,
74- machine learning model, data, and an optional TensorFlow session.
69+ machine learning model, data, and an optional session parameter .
7570
7671 Parameters
7772 ----------
@@ -81,7 +76,7 @@ def initialize_recourse_method(
8176 data_name (str): The name of the dataset.
8277 model_type (str): The type of machine learning model.
8378 setup (Dict): The experimental setup containing hyperparameters for the recourse methods.
84- sess (Session, optional): Optional TensorFlow session. Defaults to None.
79+ sess (optional): Optional session parameter . Defaults to None.
8580
8681 Returns
8782 -------
@@ -101,13 +96,27 @@ def initialize_recourse_method(
10196 elif method == "ar" :
10297 coeffs , intercepts = None , None
10398 if model_type == "linear" :
104- # get weights and bias of linear layer for negative class 0
105- coeffs_neg = mlmodel .raw_model .layers [0 ].get_weights ()[0 ][:, 0 ]
106- intercepts_neg = np .array (mlmodel .raw_model .layers [0 ].get_weights ()[1 ][0 ])
99+ if hasattr (mlmodel .raw_model , "layers" ):
100+ # Keras-style
101+ coeffs_neg = mlmodel .raw_model .layers [0 ].get_weights ()[0 ][:, 0 ]
102+ intercepts_neg = np .array (
103+ mlmodel .raw_model .layers [0 ].get_weights ()[1 ][0 ]
104+ )
107105
108- # get weights and bias of linear layer for positive class 1
109- coeffs_pos = mlmodel .raw_model .layers [0 ].get_weights ()[0 ][:, 1 ]
110- intercepts_pos = np .array (mlmodel .raw_model .layers [0 ].get_weights ()[1 ][1 ])
106+ coeffs_pos = mlmodel .raw_model .layers [0 ].get_weights ()[0 ][:, 1 ]
107+ intercepts_pos = np .array (
108+ mlmodel .raw_model .layers [0 ].get_weights ()[1 ][1 ]
109+ )
110+ elif hasattr (mlmodel .raw_model , "linear" ):
111+ # PyTorch-style
112+ weights = mlmodel .raw_model .linear .weight .detach ().cpu ().numpy ()
113+ bias = mlmodel .raw_model .linear .bias .detach ().cpu ().numpy ()
114+ coeffs_neg = weights [0 ]
115+ intercepts_neg = np .array (bias [0 ])
116+ coeffs_pos = weights [1 ]
117+ intercepts_pos = np .array (bias [1 ])
118+ else :
119+ raise ValueError ("Unsupported linear model for AR coefficients." )
111120
112121 coeffs = - (coeffs_neg - coeffs_pos )
113122 intercepts = - (intercepts_neg - intercepts_pos )
@@ -175,7 +184,7 @@ def initialize_recourse_method(
175184 elif method == "larr" :
176185 return Larr (mlmodel , hyperparams )
177186 elif method == "rbr" :
178- hyperparams ["train_data" ] = data .df_train .drop (columns = ["y" ], axis = 1 )
187+ hyperparams ["train_data" ] = data .df_train .drop (columns = ["y" ])
179188 dev = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
180189 hyperparams ["device" ] = dev
181190 return RBR (mlmodel , hyperparams )
@@ -420,8 +429,7 @@ def _append_to_csv(path: str, df: pd.DataFrame):
420429 ]
421430 )
422431
423- session_models = ["cem" , "cem_vae" , "greedy" ]
424- torch_methods = [
432+ pytorch_methods = [
425433 "cchvae" ,
426434 "claproar" ,
427435 "clue" ,
@@ -437,14 +445,27 @@ def _append_to_csv(path: str, df: pd.DataFrame):
437445 "rbr" ,
438446 ]
439447 sklearn_methods = ["feature_tweak" , "focus" , "mace" ]
448+ common_methods = ["ar" , "dice" , "face" , "face_knn" , "face_epsilon" , "gs" ]
449+ disabled_methods = {
450+ "cem" , # tensorflow
451+ "cem_vae" , # tensorflow
452+ "greedy" , # tensorflow
453+ "causal_recourse" , # causal
454+ }
440455
441456 for method_name in args .recourse_method :
442- if method_name in torch_methods :
457+ if method_name in disabled_methods :
458+ log .info ("Skipping disabled recourse method: {}" .format (method_name ))
459+ continue
460+ if method_name in pytorch_methods :
443461 backend = "pytorch"
444462 elif method_name in sklearn_methods :
445463 backend = "sklearn"
464+ elif method_name in common_methods :
465+ backend = "pytorch" # pytorch by default
446466 else :
447- backend = "tensorflow"
467+ log .warning ("Skipping unknown recourse method: {}" .format (method_name ))
468+ continue
448469 log .info ("Recourse method: {}" .format (method_name ))
449470 for data_name in args .dataset :
450471 for model_name in args .type :
@@ -467,91 +488,43 @@ def _append_to_csv(path: str, df: pd.DataFrame):
467488 and (data_name == "mortgage" or data_name == "twomoon" )
468489 ):
469490 continue
491+ # feature_tweak requires forest model
492+ if method_name == "feature_tweak" and model_name != "forest" :
493+ log .info (
494+ "Skipping feature_tweak for non-forest model: {}" .format (
495+ model_name
496+ )
497+ )
498+ continue
470499
471500 dataset = DataCatalog (data_name , model_name , args .train_split )
472501
473- if method_name in session_models :
474- graph = Graph ()
475- ann_sess = Session ()
476- session_graph = tf .get_default_graph ()
477- init = tf .global_variables_initializer ()
478- ann_sess .run (init )
479- with graph .as_default ():
480- with session_graph .as_default ():
481- set_session (ann_sess )
482- mlmodel_sess = ModelCatalog (dataset , model_name , backend )
483-
484- factuals_sess = predict_negative_instances (
485- mlmodel_sess , dataset
486- )
487-
488- recourse_method_sess = initialize_recourse_method (
489- method_name ,
490- mlmodel_sess ,
491- dataset ,
492- data_name ,
493- model_name ,
494- setup ,
495- sess = ann_sess ,
496- )
497- factuals_len = len (factuals_sess )
498- if factuals_len == 0 :
499- continue
500- elif factuals_len > args .number_of_samples :
501- factuals_sess = factuals_sess .sample (
502- n = args .number_of_samples , random_state = RANDOM_SEED
503- )
504-
505- factuals_sess = factuals_sess .reset_index (drop = True )
506- benchmark = Benchmark (
507- mlmodel_sess , recourse_method_sess , factuals_sess
508- )
509- evaluation_measures = [
510- evaluation_catalog .YNN (
511- benchmark .mlmodel , {"y" : 5 , "cf_label" : 1 }
512- ),
513- evaluation_catalog .Distance (benchmark .mlmodel ),
514- evaluation_catalog .SuccessRate (),
515- evaluation_catalog .Redundancy (
516- benchmark .mlmodel , {"cf_label" : 1 }
517- ),
518- evaluation_catalog .ConstraintViolation (
519- benchmark .mlmodel
520- ),
521- evaluation_catalog .AvgTime ({"time" : benchmark .timer }),
522- ]
523- df_benchmark = benchmark .run_benchmark (evaluation_measures )
524- else :
525- mlmodel = ModelCatalog (dataset , model_name , backend )
526- factuals = predict_negative_instances (mlmodel , dataset )
527-
528- factuals_len = len (factuals )
529- if factuals_len == 0 :
530- continue
531- elif factuals_len > args .number_of_samples :
532- factuals = factuals .sample (
533- n = args .number_of_samples , random_state = RANDOM_SEED
534- )
502+ mlmodel = ModelCatalog (dataset , model_name , backend )
503+ factuals = predict_negative_instances (mlmodel , dataset )
535504
536- factuals = factuals .reset_index (drop = True )
537- recourse_method = initialize_recourse_method (
538- method_name , mlmodel , dataset , data_name , model_name , setup
505+ factuals_len = len (factuals )
506+ if factuals_len == 0 :
507+ continue
508+ elif factuals_len > args .number_of_samples :
509+ factuals = factuals .sample (
510+ n = args .number_of_samples , random_state = RANDOM_SEED
539511 )
540512
541- benchmark = Benchmark (mlmodel , recourse_method , factuals )
542- evaluation_measures = [
543- evaluation_catalog .YNN (
544- benchmark .mlmodel , {"y" : 5 , "cf_label" : 1 }
545- ),
546- evaluation_catalog .Distance (benchmark .mlmodel ),
547- evaluation_catalog .SuccessRate (),
548- evaluation_catalog .Redundancy (
549- benchmark .mlmodel , {"cf_label" : 1 }
550- ),
551- evaluation_catalog .ConstraintViolation (benchmark .mlmodel ),
552- evaluation_catalog .AvgTime ({"time" : benchmark .timer }),
553- ]
554- df_benchmark = benchmark .run_benchmark (evaluation_measures )
513+ factuals = factuals .reset_index (drop = True )
514+ recourse_method = initialize_recourse_method (
515+ method_name , mlmodel , dataset , data_name , model_name , setup
516+ )
517+
518+ benchmark = Benchmark (mlmodel , recourse_method , factuals )
519+ evaluation_measures = [
520+ evaluation_catalog .YNN (benchmark .mlmodel , {"y" : 5 , "cf_label" : 1 }),
521+ evaluation_catalog .Distance (benchmark .mlmodel ),
522+ evaluation_catalog .SuccessRate (),
523+ evaluation_catalog .Redundancy (benchmark .mlmodel , {"cf_label" : 1 }),
524+ evaluation_catalog .ConstraintViolation (benchmark .mlmodel ),
525+ evaluation_catalog .AvgTime ({"time" : benchmark .timer }),
526+ ]
527+ df_benchmark = benchmark .run_benchmark (evaluation_measures )
555528
556529 df_benchmark ["Recourse_Method" ] = method_name
557530 df_benchmark ["Dataset" ] = data_name
0 commit comments