
Commit a6f797d

Merge pull request #19 from OmerRonen/pnas_figs
Pnas figs
2 parents 09e750f + 02098d2 commit a6f797d

File tree

3 files changed: +258 −0 lines changed
(two image files, 1.06 MB and 1.53 MB, not rendered here)

notebooks/figs/pnas_boosting.py

Lines changed: 258 additions & 0 deletions
import logging
import os
import pickle

import dvu
import matplotlib as mpl
import numpy as np
from matplotlib import pyplot as plt
from os.path import dirname
from os.path import join as oj

from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor
from sklearn.metrics import roc_auc_score, r2_score
from sklearn.model_selection import train_test_split

from imodels import FIGSClassifier, FIGSRegressor, get_clean_dataset

# plt.rcParams['text.usetex'] = True

# from config.figs.datasets import DATASETS_REGRESSION, DATASETS_CLASSIFICATION
DATASETS_CLASSIFICATION = [
    # classification datasets from the original random forests paper
    # page 9: https://www.stat.berkeley.edu/~breiman/randomforest2001.pdf
    # ("sonar", "sonar", "pmlb"),
    # ("heart", "heart", 'imodels'),
    # ("breast-cancer", "breast_cancer", 'imodels'),  # this is the wrong breast-cancer dataset (https://new.openml.org/search?type=data&sort=runs&id=13&status=active)
    # ("haberman", "haberman", 'imodels'),
    # ("ionosphere", "ionosphere", 'pmlb'),
    ("diabetes", "diabetes", "pmlb"),
    # ("liver", "8", "openml"),  # note: we omit this dataset because its label was found to be incorrect (see the caveat here: https://archive.ics.uci.edu/ml/datasets/liver+disorders#:~:text=The%207th%20field%20(selector)%20has%20been%20widely%20misinterpreted%20in%20the%20past%20as%20a%20dependent%20variable%20representing%20presence%20or%20absence%20of%20a%20liver%20disorder.)
    # ("credit-g", "credit_g", 'imodels'),  # like german-credit, but with more features
    ("german-credit", "german", "pmlb"),

    # clinical decision rules
    # ("iai-pecarn", "iai_pecarn.csv", "imodels"),

    # popular classification datasets used in rule-based modeling / fairness
    # page 7: http://proceedings.mlr.press/v97/wang19a/wang19a.pdf
    ("juvenile", "juvenile_clean", 'imodels'),
    ("recidivism", "compas_two_year_clean", 'imodels'),
    ("credit", "credit_card_clean", 'imodels'),
    ("readmission", 'readmission_clean', 'imodels'),  # very big
]

DATASETS_REGRESSION = [
    # the Leo Breiman random forests paper uses some UCI datasets as well
    # pg 23: https://www.stat.berkeley.edu/~breiman/randomforest2001.pdf
    ('friedman1', 'friedman1', 'synthetic'),
    ('friedman2', 'friedman2', 'synthetic'),
    ('friedman3', 'friedman3', 'synthetic'),
    ("diabetes-regr", "diabetes", 'sklearn'),
    ('abalone', '183', 'openml'),
    ("echo-months", "1199_BNG_echoMonths", 'pmlb'),
    ("satellite-image", "294_satellite_image", 'pmlb'),
    ("california-housing", "california_housing", 'sklearn'),  # this replaced boston-housing due to ethical issues
    ("breast-tumor", "1201_BNG_breastTumor", 'pmlb'),  # this one is very big (~100k examples)
]

LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

dvu.set_style()
mpl.rcParams['figure.dpi'] = 250

cb2 = '#66ccff'
cb = '#1f77b4'
cr = '#cc0000'
cp = '#cc3399'
cy = '#d8b365'
cg = '#5ab4ac'

DIR_FIGS = oj(dirname(os.path.realpath(__file__)), 'figures')
DSET_METADATA = {'sonar': (208, 60), 'heart': (270, 15), 'breast-cancer': (277, 17), 'haberman': (306, 3),
                 'ionosphere': (351, 34), 'diabetes': (768, 8), 'german-credit': (1000, 20), 'juvenile': (3640, 286),
                 'recidivism': (6172, 20), 'credit': (30000, 33), 'readmission': (101763, 150), 'friedman1': (200, 10),
                 'friedman2': (200, 4), 'friedman3': (200, 4), 'abalone': (4177, 8), 'diabetes-regr': (442, 10),
                 'california-housing': (20640, 8), 'satellite-image': (6435, 36), 'echo-months': (17496, 9),
                 'breast-tumor': (116640, 9), "vo_pati": (100, 100), "radchenko_james": (300, 50),
                 'tbi-pecarn': (42428, 121), 'csi-pecarn': (3313, 36), 'iai-pecarn': (12044, 58),
                 }

COLORS = {
    'FIGS': 'black',
    'CART': 'orange',  # cp,
    'Rulefit': 'green',
    'C45': cb,
    'CART_(MSE)': 'orange',
    'CART_(MAE)': cg,
    'FIGS_(Reweighted)': cg,
    'FIGS_(Include_Linear)': cb,
    'GBDT-1': cp,
    'GBDT-2': "green",
    'GBDT-3': cy,
    'Dist-GB-FIGS': cg,
    'Dist-RF-FIGS': cp,
    'Dist-RF-FIGS-3': 'green',
    'RandomForest': 'gray',
    'GBDT': 'black',
    'BFIGS': 'green',
    'TAO': cb,
}


def tune_boosting(X, y, budget, is_classification=True):
    gb_model = GradientBoostingClassifier if is_classification else GradientBoostingRegressor
    metric = roc_auc_score if is_classification else r2_score
    # split data into train and test
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
    models_scores = {}
    models = {}
    for n_trees in range(1, int(budget / 2)):
        # use the deepest trees that keep the total split count within the budget
        max_depth = max(int(np.floor(np.log2(budget / n_trees))), 1)
        LOGGER.info(f"tuning model with {n_trees} trees and max depth {max_depth}")

        model = gb_model(n_estimators=n_trees + 1, max_depth=max_depth)
        model.fit(X_train, y_train)
        models_scores[n_trees] = metric(y_test, model.predict(X_test))
        models[n_trees] = model
    # refit the best configuration on all the data
    n_trees_best = max(models_scores, key=models_scores.get)
    max_depth = int(np.ceil(np.log2(n_trees_best + 1)))
    model_best = gb_model(n_estimators=n_trees_best + 1, max_depth=max_depth)
    model_best.fit(X, y)
    return model_best
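

# Hypothetical usage sketch for tune_boosting, which is defined above but never
# called in this script; the dataset name and budget are illustrative
# assumptions, and get_clean_dataset is the imodels loader used below.
# X, y, _ = get_clean_dataset("diabetes", data_source="pmlb")
# gb_tuned = tune_boosting(X, y, budget=20, is_classification=True)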


def figs_vs_boosting(X, y, budget, depth, n_seeds=10, only_boosting=False):
    is_classification = len(np.unique(y)) == 2
    metric = roc_auc_score if is_classification else r2_score
    # a full tree of depth d has sum_{i < d} 2**i = 2**d - 1 internal splits,
    # so this is how many depth-`depth` trees fit in the split budget
    n_estimators = budget // (np.sum([2 ** i for i in range(depth)]))

    scores = {"figs": [], "boosting": []}

    for _ in range(n_seeds):
        # split into train and test
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

        if n_estimators > 0:
            gb_model = GradientBoostingClassifier if is_classification else GradientBoostingRegressor
            gb = gb_model(n_estimators=n_estimators, max_depth=depth, learning_rate=1)
            gb.fit(X_train, y_train)
            preds = gb.predict_proba(X_test)[:, 1] if is_classification else gb.predict(X_test)
            gb_score = metric(y_test, preds)
            scores["boosting"].append(gb_score)
        else:
            scores["boosting"].append(np.nan)

        if only_boosting:
            continue

        figs_model = FIGSClassifier if is_classification else FIGSRegressor
        figs = figs_model(max_rules=budget)
        figs.fit(X_train, y_train)
        preds_figs = figs.predict_proba(X_test)[:, 1] if is_classification else figs.predict(X_test)
        figs_score = metric(y_test, preds_figs)
        scores["figs"].append(figs_score)

    return scores
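

# Worked example of the budget arithmetic above (illustrative numbers, not from
# the source): with budget=20 and depth=2 each tree costs 2**0 + 2**1 = 3
# splits, so n_estimators = 20 // 3 = 6; at depth=3 each tree costs 7 splits,
# so only 2 trees fit. When the budget is smaller than 2**depth - 1,
# n_estimators is 0 and the boosting score is recorded as NaN.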


def analyze_datasets(datasets, fig_name=None, reg=False):
    n_cols = 3
    n_rows = int(np.ceil(len(datasets) / n_cols))
    budgets = np.arange(5, 21)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))
    axes = np.atleast_2d(axes)  # keep the 2d indexing below valid even with a single row
    plt.subplots_adjust(wspace=0.2, hspace=0.5)

    n_seeds = 20
    for i, d in enumerate(datasets):
        if isinstance(d, str):
            dset_name = d
        elif isinstance(d, tuple):
            dset_name = d[0]
        row = i // n_cols
        col = i % n_cols
        ax = axes[row, col]
        # load the dataset (X and y are used below)
        X, y, feat_names = get_clean_dataset(d[1], data_source=d[2])
        f_name = f"figs_vs_boosting_{dset_name}_cls.pkl" if not reg else f"figs_vs_boosting_{dset_name}_reg.pkl"
        if os.path.exists(f_name):
            with open(f_name, "rb") as f:
                ds_data = pickle.load(f)
            means = ds_data["means"]
            std = ds_data["std"]
        else:
            means = {"figs": [], "boosting d1": [], "boosting d2": [], "boosting d3": []}
            std = {"figs": [], "boosting d1": [], "boosting d2": [], "boosting d3": []}
            for budget in budgets:
                scores = figs_vs_boosting(X, y, budget=budget, n_seeds=n_seeds, depth=1)
                means["figs"].append(np.mean(scores["figs"]))
                means["boosting d1"].append(np.mean(scores["boosting"]))
                std["figs"].append(np.std(scores["figs"]) / np.sqrt(n_seeds))
                std["boosting d1"].append(np.std(scores["boosting"]) / np.sqrt(n_seeds))
                for dep in [2, 3]:  # use a distinct loop variable so the dataset tuple `d` is not shadowed
                    scores = figs_vs_boosting(X, y, budget=budget, n_seeds=n_seeds, depth=dep)
                    is_na = np.isnan(scores["boosting"]).sum() > 0
                    if is_na:
                        # set mean and std to nan
                        means[f"boosting d{dep}"].append(np.nan)
                        std[f"boosting d{dep}"].append(np.nan)
                        continue
                    means[f"boosting d{dep}"].append(np.nanmean(scores["boosting"]))
                    std[f"boosting d{dep}"].append(np.nanstd(scores["boosting"]) / np.sqrt(n_seeds))
            # cache the results, then make the plot with error bars vs budget
            ds_data = {"means": means, "std": std, "budgets": budgets}
            with open(f_name, "wb") as f:
                pickle.dump(ds_data, f)
        ax.errorbar(budgets, means["figs"], yerr=std["figs"], color=COLORS["FIGS"], elinewidth=3, fmt='o')
        if reg:
            ax.plot(budgets, means["figs"], label="FIGS", color=COLORS["FIGS"], linewidth=3)
        else:
            ax.plot(budgets, means["figs"], color=COLORS["FIGS"], linewidth=3)

        for depth in [1, 2, 3]:
            ax.errorbar(budgets, means[f"boosting d{depth}"], yerr=std[f"boosting d{depth}"],
                        color=COLORS[f"GBDT-{depth}"], alpha=0.5, elinewidth=3, fmt='o')
            if depth == 1 and reg:
                ax.plot(budgets, means[f"boosting d{depth}"], color=COLORS[f"GBDT-{depth}"],
                        alpha=0.5, linewidth=3)
            else:
                ax.plot(budgets, means[f"boosting d{depth}"], label=f"GB (max_depth = {depth})",
                        color=COLORS[f"GBDT-{depth}"], alpha=0.5, linewidth=3)

        ax.set_title(dset_name.capitalize().replace('-', ' ') + f' ($n={DSET_METADATA.get(dset_name, (-1, -1))[0]}$)',
                     fontsize=16)
        ylab = "AUC" if not reg else r'$R^2$'
        ax.set_xlabel("number of splits", fontsize=14)
        if i % n_cols == 0:
            ax.set_ylabel(ylab, fontsize=14)
        if row == 0 and col == n_cols - 1:
            # ax.legend()
            dvu.line_legend(fontsize=14, adjust_text_labels=False, ax=ax, xoffset_spacing=0.05)
            # manually label the one curve that was plotted without a `label`
            if not reg:
                ax.text(0.89, 0.79, "FIGS", color=COLORS["FIGS"], fontsize=14,
                        transform=ax.transAxes)
            else:
                depth = 1
                ax.text(0.89, 0.68, f"GB (max_depth = {depth})", color=COLORS[f"GBDT-{depth}"], fontsize=14,
                        transform=ax.transAxes)

    txt = "Classification" if not reg else "Regression"
    fig.text(0.02, 0.5, txt, fontsize=20, ha='center', va='center', rotation='vertical')
    fig.subplots_adjust(left=0.15)
    plt.tight_layout(pad=5, h_pad=5, w_pad=5)
    plt.savefig(f"{fig_name}_new.png", dpi=300)
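

# Note: analyze_datasets caches each dataset's sweep over budgets in a local
# pickle (f_name above); delete that file to force recomputation. A
# hypothetical single-dataset call, for illustration only:
# analyze_datasets([("diabetes", "diabetes", "pmlb")], fig_name="diabetes_only", reg=False)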


def main():
    # for depth in [1, 2, 3]:
    cls_reverse = DATASETS_CLASSIFICATION[::-1]
    reg_reverse = DATASETS_REGRESSION[::-1]
    analyze_datasets(cls_reverse, fig_name="figs_vs_boosting_classification_reformatted", reg=False)
    analyze_datasets(reg_reverse, fig_name="figs_vs_boosting_regression_reformatted", reg=True)


if __name__ == '__main__':
    main()
