@@ -172,6 +172,11 @@ def initialize_recourse_method(
172172 return Probe (mlmodel , hyperparams )
173173 elif method == "roar" :
174174 return Roar (mlmodel , hyperparams )
175+ elif method == "rbr" :
176+ hyperparams ["train_data" ] = data .df_train .drop (columns = ["y" ], axis = 1 )
177+ dev = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
178+ hyperparams ["device" ] = dev
179+ return RBR (mlmodel , hyperparams )
175180 else :
176181 raise ValueError ("Recourse method not known" )
177182
@@ -200,9 +205,9 @@ def create_parser():
200205 Choices: ["mlp", "linear", "forest"].
201206 -r, --recourse_method: Specifies recourse methods for the experiment.
202207 Default: ["dice", "ar", "causal_recourse", "cchvae", "cem", "cem_vae", "claproar", "clue", "cruds", "face_knn", "face_epsilon", "feature_tweak",
203- "focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "cfrl", "probe", "roar"].
208+ "focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "cfrl", "probe", "roar", "rbr" ].
204209 Choices: ["dice", "ar", "causal_recourse", "cchvae", "cem", "cem_vae", "claproar", "clue", "cruds", "face_knn", "face_epsilon", "feature_tweak",
205- "focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "cfrl", "probe", "roar"].
210+ "focus", "gravitational", "greedy", "gs", "mace", "revise", "wachter", "cfvae", "cfrl", "probe", "roar", "rbr" ].
206211 -n, --number_of_samples: Specifies the number of instances per dataset.
207212 Default: 20.
208213 -s, --train_split: Specifies the split of the available data used for training.
@@ -277,6 +282,7 @@ def create_parser():
277282 "cfrl" ,
278283 "probe" ,
279284 "roar" ,
285+ "rbr" ,
280286 ],
281287 choices = [
282288 "dice" ,
@@ -302,6 +308,7 @@ def create_parser():
302308 "cfrl" ,
303309 "probe" ,
304310 "roar" ,
311+ "rbr" ,
305312 ],
306313 help = "Recourse methods for experiment" ,
307314 )
@@ -385,6 +392,7 @@ def create_parser():
385392 "cfrl" ,
386393 "probe" ,
387394 "roar" ,
395+ "rbr" ,
388396 ]
389397 sklearn_methods = ["feature_tweak" , "focus" , "mace" ]
390398
0 commit comments