|
| 1 | +import logging |
| 2 | +import matplotlib as mpl |
| 3 | +import numpy as np |
| 4 | +import os |
| 5 | +import dvu |
| 6 | +import pickle |
| 7 | +from matplotlib import pyplot as plt |
| 8 | +from os.path import join as oj |
| 9 | +from os.path import dirname |
| 10 | +import matplotlib.gridspec as gridspec |
| 11 | + |
| 12 | +from sklearn.ensemble import GradientBoostingClassifier, GradientBoostingRegressor |
| 13 | +from sklearn.model_selection import train_test_split |
| 14 | +from sklearn.metrics import roc_auc_score, r2_score |
| 15 | + |
| 16 | +from imodels import FIGSRegressor, FIGSClassifier, get_clean_dataset |
| 17 | + |
# plt.rcParams['text.usetex'] = True

# from config.figs.datasets import DATASETS_REGRESSION, DATASETS_CLASSIFICATION
# Each entry is a (display_name, dataset_id, source) triple consumed by
# imodels.get_clean_dataset(dataset_id, data_source=source).
DATASETS_CLASSIFICATION = [
    # classification datasets from original random forests paper
    # page 9: https://www.stat.berkeley.edu/~breiman/randomforest2001.pdf
    # ("sonar", "sonar", "pmlb"),
    # ("heart", "heart", 'imodels'),
    # ("breast-cancer", "breast_cancer", 'imodels'),  # this is the wrong breast-cancer dataset (https://new.openml.org/search?type=data&sort=runs&id=13&status=active)
    # ("haberman", "haberman", 'imodels'),
    # ("ionosphere", "ionosphere", 'pmlb'),
    ("diabetes", "diabetes", "pmlb"),
    # ("liver", "8", "openml"),  # note: we omit this dataset bc it's label was found to be incorrect (see caveat here: https://archive.ics.uci.edu/ml/datasets/liver+disorders#:~:text=The%207th%20field%20(selector)%20has%20been%20widely%20misinterpreted%20in%20the%20past%20as%20a%20dependent%20variable%20representing%20presence%20or%20absence%20of%20a%20liver%20disorder.)
    # ("credit-g", "credit_g", 'imodels'),  # like german-credit, but more feats
    ("german-credit", "german", "pmlb"),

    # clinical-decision rules
    # ("iai-pecarn", "iai_pecarn.csv", "imodels"),

    # popular classification datasets used in rule-based modeling / fairness
    # page 7: http://proceedings.mlr.press/v97/wang19a/wang19a.pdf
    ("juvenile", "juvenile_clean", 'imodels'),
    ("recidivism", "compas_two_year_clean", 'imodels'),
    ("credit", "credit_card_clean", 'imodels'),
    ("readmission", 'readmission_clean', 'imodels'),  # v big
]

DATASETS_REGRESSION = [
    # leo-breiman paper random forest uses some UCI datasets as well
    # pg 23: https://www.stat.berkeley.edu/~breiman/randomforest2001.pdf
    ('friedman1', 'friedman1', 'synthetic'),
    ('friedman2', 'friedman2', 'synthetic'),
    ('friedman3', 'friedman3', 'synthetic'),
    ("diabetes-regr", "diabetes", 'sklearn'),
    ('abalone', '183', 'openml'),
    ("echo-months", "1199_BNG_echoMonths", 'pmlb'),
    ("satellite-image", "294_satellite_image", 'pmlb'),
    ("california-housing", "california_housing", 'sklearn'),  # this replaced boston-housing due to ethical issues
    ("breast-tumor", "1201_BNG_breastTumor", 'pmlb'),  # this one is v big (100k examples)

]
LOGGER = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# global plotting style (dvu is a project-local viz-utility package)
dvu.set_style()
mpl.rcParams['figure.dpi'] = 250

# named colors reused in the COLORS mapping below
cb2 = '#66ccff'
cb = '#1f77b4'
cr = '#cc0000'
cp = '#cc3399'
cy = '#d8b365'
cg = '#5ab4ac'

# figures are written next to this file under ./figures
DIR_FIGS = oj(dirname(os.path.realpath(__file__)), 'figures')
# dataset name -> (n_samples, n_features), used for plot titles
DSET_METADATA = {'sonar': (208, 60), 'heart': (270, 15), 'breast-cancer': (277, 17), 'haberman': (306, 3),
                 'ionosphere': (351, 34), 'diabetes': (768, 8), 'german-credit': (1000, 20), 'juvenile': (3640, 286),
                 'recidivism': (6172, 20), 'credit': (30000, 33), 'readmission': (101763, 150), 'friedman1': (200, 10),
                 'friedman2': (200, 4), 'friedman3': (200, 4), 'abalone': (4177, 8), 'diabetes-regr': (442, 10),
                 'california-housing': (20640, 8), 'satellite-image': (6435, 36), 'echo-months': (17496, 9),
                 'breast-tumor': (116640, 9), "vo_pati": (100, 100), "radchenko_james": (300, 50),
                 'tbi-pecarn': (42428, 121), 'csi-pecarn': (3313, 36), 'iai-pecarn': (12044, 58),
                 }

# model/display name -> line color used across all figures
COLORS = {
    'FIGS': 'black',
    'CART': 'orange',  # cp,
    'Rulefit': 'green',
    'C45': cb,
    'CART_(MSE)': 'orange',
    'CART_(MAE)': cg,
    'FIGS_(Reweighted)': cg,
    'FIGS_(Include_Linear)': cb,
    'GBDT-1': cp,
    'GBDT-2': "green",
    'GBDT-3': cy,
    'Dist-GB-FIGS': cg,
    'Dist-RF-FIGS': cp,
    'Dist-RF-FIGS-3': 'green',
    'RandomForest': 'gray',
    'GBDT': 'black',
    'BFIGS': 'green',
    'TAO': cb,
}
def tune_boosting(X, y, budget, is_classification=True):
    """Pick the best (n_trees, max_depth) boosting config within a split budget.

    Enumerates boosted ensembles whose total number of splits stays near
    ``budget`` (a depth-d tree has at most 2**d - 1 splits, so depth is set to
    floor(log2(budget / n_trees))), scores each candidate on a held-out split,
    then refits the winning configuration on all of the data.

    Parameters
    ----------
    X, y : array-like
        Features and targets.
    budget : int
        Approximate total number of splits allowed across the whole ensemble.
    is_classification : bool
        If True, use GradientBoostingClassifier scored by AUC; otherwise use
        GradientBoostingRegressor scored by R^2.

    Returns
    -------
    The best gradient-boosting model, refit on the full (X, y).
    """
    gb_model = GradientBoostingClassifier if is_classification else GradientBoostingRegressor
    metric = roc_auc_score if is_classification else r2_score
    # held-out split used only for model selection
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
    models_scores = {}
    for n_trees in range(1, int(budget / 2)):
        # deepest tree allowed so n_trees * (2**depth - 1) stays near budget
        max_depth = max(int(np.floor(np.log2(budget / n_trees))), 1)
        LOGGER.info(f"tuning model with {n_trees} trees and max depth {max_depth}")

        model = gb_model(n_estimators=n_trees + 1, max_depth=max_depth)
        model.fit(X_train, y_train)
        models_scores[n_trees] = metric(y_test, model.predict(X_test))
    # refit the best configuration on all the data
    n_trees_best = max(models_scores, key=models_scores.get)
    # bug fix: recompute depth with the same budget-based formula used during
    # the search; the previous ceil(log2(n_trees_best + 1)) was unrelated to
    # the tuned configuration and could exceed the split budget
    max_depth = max(int(np.floor(np.log2(budget / n_trees_best))), 1)
    model_best = gb_model(n_estimators=n_trees_best + 1, max_depth=max_depth)
    model_best.fit(X, y)
    return model_best
| 123 | + |
| 124 | + |
def figs_vs_boosting(X, y, budget, depth, n_seeds=10, only_boosting=False):
    """Score FIGS and gradient boosting over random splits at a fixed budget.

    Parameters
    ----------
    X, y : array-like
        Features and targets; binary y triggers classification (AUC),
        otherwise regression (R^2).
    budget : int
        Total number of splits each model may use.
    depth : int
        max_depth for the boosted trees; the number of boosted trees is
        derived so the ensemble stays within ``budget``.
    n_seeds : int
        Number of random train/test splits to average over.
    only_boosting : bool
        If True, skip fitting FIGS (its score list stays empty).

    Returns
    -------
    dict with keys "figs" and "boosting", each a list of per-split scores
    (the "boosting" list holds NaN when the budget admits no tree at all).
    """
    is_classification = len(np.unique(y)) == 2
    score_fn = roc_auc_score if is_classification else r2_score
    # trees that fit in the budget: a depth-d tree has 2**d - 1 splits
    n_estimators = budget // (np.sum([2 ** i for i in range(depth)]))

    scores = {"figs": [], "boosting": []}

    for _seed in range(n_seeds):
        # fresh random split each round
        X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2)

        if n_estimators > 0:
            booster_cls = GradientBoostingClassifier if is_classification else GradientBoostingRegressor
            booster = booster_cls(n_estimators=n_estimators, max_depth=depth, learning_rate=1)
            booster.fit(X_tr, y_tr)
            if is_classification:
                boost_preds = booster.predict_proba(X_te)[:, 1]
            else:
                boost_preds = booster.predict(X_te)
            scores["boosting"].append(score_fn(y_te, boost_preds))
        else:
            # budget too small for even one tree of this depth
            scores["boosting"].append(np.nan)

        if only_boosting:
            continue

        figs_cls = FIGSClassifier if is_classification else FIGSRegressor
        figs_est = figs_cls(max_rules=budget)
        figs_est.fit(X_tr, y_tr)
        if is_classification:
            figs_preds = figs_est.predict_proba(X_te)[:, 1]
        else:
            figs_preds = figs_est.predict(X_te)
        scores["figs"].append(score_fn(y_te, figs_preds))

    return scores
| 158 | + |
| 159 | + |
def analyze_datasets(datasets, fig_name=None, reg=False):
    """Plot FIGS vs gradient boosting (depths 1-3) across datasets and budgets.

    For each dataset, results are loaded from a per-dataset pickle cache when
    present; otherwise the dataset is fetched, `figs_vs_boosting` is run for
    each budget in 5..20 and boosting depths 1-3, and the results are cached.
    One errorbar subplot per dataset is drawn and the whole grid is saved to
    ``{fig_name}_new.png``.

    Parameters
    ----------
    datasets : list of tuple or str
        (display_name, dataset_id, source) triples, or bare display names
        (a bare name requires a pre-existing cache pickle, since the dataset
        id/source needed to fetch the data are unavailable).
    fig_name : str
        Basename for the saved figure.
    reg : bool
        Regression mode (R^2 y-label, different legend layout, "_reg" cache
        suffix) vs classification (AUC, "_cls" suffix).
    """
    n_cols = 3
    n_rows = int(np.ceil(len(datasets) / n_cols))
    budgets = np.arange(5, 21)

    # squeeze=False keeps axes 2-D even when there is a single row,
    # so axes[row, col] indexing below is always valid
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
    plt.subplots_adjust(wspace=0.2, hspace=0.5)

    n_seeds = 20
    for i, d in enumerate(datasets):
        if isinstance(d, str):
            dset_name = d
        elif isinstance(d, tuple):
            dset_name = d[0]
        row = i // n_cols
        col = i % n_cols
        ax = axes[row, col]
        f_name = f"figs_vs_boosting_{dset_name}_cls.pkl" if not reg else f"figs_vs_boosting_{dset_name}_reg.pkl"
        if os.path.exists(f_name):
            with open(f_name, "rb") as f:
                ds_data = pickle.load(f)
            means = ds_data["means"]
            std = ds_data["std"]
        else:
            # bug fix: the dataset must actually be loaded before computing
            # scores (this line was commented out, causing a NameError on X/y
            # whenever the cache pickle was missing)
            X, y, feat_names = get_clean_dataset(d[1], data_source=d[2])
            means = {"figs": [], "boosting d1": [], "boosting d2": [], "boosting d3": []}
            std = {"figs": [], "boosting d1": [], "boosting d2": [], "boosting d3": []}
            for budget in budgets:
                scores = figs_vs_boosting(X, y, budget=budget, n_seeds=n_seeds, depth=1)
                means["figs"].append(np.mean(scores["figs"]))
                means["boosting d1"].append(np.mean(scores["boosting"]))
                std["figs"].append(np.std(scores["figs"]) / np.sqrt(n_seeds))
                std["boosting d1"].append(np.std(scores["boosting"]) / np.sqrt(n_seeds))
                # bug fix: loop variable renamed from `d`, which shadowed the
                # outer dataset tuple
                for d2 in [2, 3]:
                    scores = figs_vs_boosting(X, y, budget=budget, n_seeds=n_seeds, depth=d2)
                    is_na = np.isnan(scores["boosting"]).sum() > 0
                    if is_na:
                        # budget too small for this depth on some seed:
                        # record NaN for mean and std
                        means[f"boosting d{d2}"].append(np.nan)
                        std[f"boosting d{d2}"].append(np.nan)
                        continue
                    means[f"boosting d{d2}"].append(np.nanmean(scores["boosting"]))
                    std[f"boosting d{d2}"].append(np.nanstd(scores["boosting"]) / np.sqrt(n_seeds))
            # cache the computed curves for future runs
            ds_data = {"means": means, "std": std, "budgets": budgets}
            with open(f_name, "wb") as f:
                pickle.dump(ds_data, f)
        # make plot with error bars vs budget
        ax.errorbar(budgets, means["figs"], yerr=std["figs"], color=COLORS["FIGS"], elinewidth=3, fmt='o')
        if reg:
            ax.plot(budgets, means["figs"], label="FIGS", color=COLORS["FIGS"], linewidth=3)
        else:
            ax.plot(budgets, means["figs"], color=COLORS["FIGS"], linewidth=3)

        for depth in [1, 2, 3]:
            ax.errorbar(budgets, means[f"boosting d{depth}"], yerr=std[f"boosting d{depth}"],
                        color=COLORS[f"GBDT-{depth}"], alpha=0.5, elinewidth=3, fmt='o')
            if depth == 1 and reg:
                # regression: legend labels come from ax.text below instead
                ax.plot(budgets, means[f"boosting d{depth}"], color=COLORS[f"GBDT-{depth}"],
                        alpha=0.5, linewidth=3)
            else:
                ax.plot(budgets, means[f"boosting d{depth}"], label=f"GB (max_depth = {depth})",
                        color=COLORS[f"GBDT-{depth}"], alpha=0.5, linewidth=3)

        # bug fix: the metadata default must be a tuple so the [0] subscript
        # works for datasets missing from DSET_METADATA (was the int -1)
        ax.set_title(dset_name.capitalize().replace('-', ' ') + f' ($n={DSET_METADATA.get(dset_name, (-1, -1))[0]}$)',
                     fontsize=16)
        ylab = "AUC" if not reg else r'$R^2$'
        ax.set_xlabel("number of splits", fontsize=14)
        if i % n_cols == 0:
            ax.set_ylabel(ylab, fontsize=14)
        if row == 0 and col == n_cols - 1:
            # label curves next to the lines instead of a boxed legend
            dvu.line_legend(fontsize=14, adjust_text_labels=False, ax=ax, xoffset_spacing=0.05)
            if not reg:
                ax.text(0.89, 0.79, "FIGS", color=COLORS["FIGS"], fontsize=14,
                        transform=ax.transAxes)
            else:
                depth = 1
                ax.text(0.89, 0.68, f"GB (max_depth = {depth})", color=COLORS[f"GBDT-{depth}"], fontsize=14,
                        transform=ax.transAxes)

    txt = "Classification" if not reg else "Regression"
    fig.text(0.02, 0.5, txt, fontsize=20, ha='center', va='center', rotation='vertical')
    fig.subplots_adjust(left=0.15)
    plt.tight_layout(pad=5, h_pad=5, w_pad=5)
    plt.savefig(f"{fig_name}_new.png", dpi=300)
| 247 | + |
| 248 | + |
def main():
    """Generate the FIGS-vs-boosting comparison figures for both task types."""
    # datasets are reversed so the largest ones are processed first
    jobs = [
        (DATASETS_CLASSIFICATION[::-1], "figs_vs_boosting_classification_reformatted", False),
        (DATASETS_REGRESSION[::-1], "figs_vs_boosting_regression_reformatted", True),
    ]
    for dsets, name, is_reg in jobs:
        analyze_datasets(dsets, fig_name=name, reg=is_reg)


if __name__ == '__main__':
    main()