Skip to content

Commit 7417a2e

Browse files
committed
Ran precommit hooks
1 parent 7a459ef commit 7417a2e

File tree

6 files changed

+12
-18
lines changed

6 files changed

+12
-18
lines changed

data/catalog/_data_main/process_data/process_german_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
def load_german_data(modified=False):
1717
# input vars
1818
this_files_directory = os.path.dirname(os.path.realpath(__file__))
19-
if modified == False:
19+
if modified is False:
2020
raw_data_file = os.path.join(
2121
this_files_directory, "..", "raw_data", "german_v1.csv"
2222
)

data/catalog/_data_main/process_data/process_sba_data.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
import os
2-
from random import seed, shuffle
2+
from random import seed
33

4-
import numpy as np
54
import pandas as pd
6-
import process_data.process_utils_data as ut
7-
from sklearn.preprocessing import StandardScaler
85

96
RANDOM_SEED = 54321
107
seed(

data/catalog/loadData.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,7 @@
7474
print(f"[ENV WARNING] process_boston_housing_data not available. Error: {e}")
7575

7676
try:
77-
from data.catalog._data_main.process_data.process_sba_data import (
78-
load_sba_data,
79-
load_sba_data_modified,
80-
)
77+
from data.catalog._data_main.process_data.process_sba_data import load_sba_data
8178
except Exception as e:
8279
print(f"[ENV WARNING] process_sba_data not available. Error: {e}")
8380

experiments/run_experiment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def create_parser():
284284
"mace",
285285
"revise",
286286
"wachter",
287-
'roar',
287+
"roar",
288288
],
289289
help="Recourse methods for experiment",
290290
)

methods/catalog/roar/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,9 @@ def get_counterfactuals(self, factuals: pd.DataFrame) -> pd.DataFrame:
203203

204204
coeffs_neg = (
205205
# self._mlmodel.raw_model.output.weight.cpu().detach()[0].numpy()
206-
self._mlmodel.raw_model.linear.weight.cpu().detach()[0].numpy()
206+
self._mlmodel.raw_model.linear.weight.cpu()
207+
.detach()[0]
208+
.numpy()
207209
)
208210
coeffs_pos = (
209211
self._mlmodel.raw_model.linear.weight.cpu().detach()[1].numpy()

methods/catalog/roar/reproduce.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
from pathlib import Path
2-
31
import numpy as np
4-
import pandas as pd
52
import pytest
63

74
from data.catalog.online_catalog import DataCatalog
@@ -10,6 +7,7 @@
107

118
RANDOM_SEED = 54321
129

10+
1311
# Find indices where recourse is needed
1412
def recourse_needed(predict_fn, X, target=1):
1513
return np.where(predict_fn(X) == 1 - target)[0]
@@ -86,17 +84,17 @@ def test_roar(dataset_name, model_type, backend):
8684
m2._test_accuracy()
8785

8886
print("Using %s cost" % args["cost"])
89-
if args["cost"] == "l1":
90-
feature_costs = None
87+
# if args["cost"] == "l1":
88+
# feature_costs = None
9189

9290
coefficients = intercept = None
9391

9492
roar = Roar(mlmodel=m1, hyperparams={}, coeffs=coefficients, intercepts=intercept)
9593

96-
lamb = args["lamb"]
94+
# lamb = args["lamb"]
9795

9896
recourses = []
99-
deltas = []
97+
# deltas = []
10098

10199
factuals = (data._df_test).sample(n=10, random_state=RANDOM_SEED)
102100

0 commit comments

Comments
 (0)