|
117 | 117 | " \"rf_sk\",\n", |
118 | 118 | " \"ert\",\n", |
119 | 119 | " \"elastic\",\n", |
| 120 | + " \"sgd\",\n", |
120 | 121 | " \"lm\",\n", |
121 | | - " \"lm_svm\",\n", |
| 122 | + " \"lsvm\",\n", |
122 | 123 | " \"svm\",\n", |
123 | 124 | " \"nn\",\n", |
124 | 125 | " \"knn\",\n", |
|
136 | 137 | "def trial_runner(trial):\n", |
137 | 138 | " seed=42 + int(trial.replicate_num)\n", |
138 | 139 | " max_samples = 1000000000000\n", |
| 140 | + " n_calibration_folds = 4 # 4 uses all cores on the containers\n", |
139 | 141 | "\n", |
140 | 142 | " from interpret.glassbox import ExplainableBoostingClassifier, ExplainableBoostingRegressor\n", |
141 | 143 | " from xgboost import XGBClassifier, XGBRegressor, XGBRFClassifier, XGBRFRegressor\n", |
142 | 144 | " from lightgbm import LGBMClassifier, LGBMRegressor\n", |
143 | 145 | " from catboost import CatBoostClassifier, CatBoostRegressor\n", |
144 | 146 | " from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, ExtraTreesClassifier, ExtraTreesRegressor\n", |
145 | | - " from sklearn.linear_model import LogisticRegression, LinearRegression, ElasticNet\n", |
| 147 | + " from sklearn.linear_model import LogisticRegression, LinearRegression, ElasticNet, SGDClassifier, SGDRegressor\n", |
146 | 148 | " from sklearn.svm import LinearSVC, LinearSVR, SVC, SVR\n", |
147 | 149 | " from sklearn.neural_network import MLPClassifier, MLPRegressor\n", |
148 | 150 | " from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor\n", |
|
178 | 180 | " pass # Re-enable stratification if dataset fails from absent class in train/test sets (PMLB)\n", |
179 | 181 | " \n", |
180 | 182 | " fit_params = {}\n", |
181 | | - " fit_params[\"X\"], X_test, fit_params[\"y\"], y_test = train_test_split(X, y, test_size=0.3, stratify=stratification, random_state=seed)\n", |
| 183 | + " fit_params[\"X\"], X_test, fit_params[\"y\"], y_test = train_test_split(X, y, test_size=0.2, stratify=stratification, random_state=seed)\n", |
182 | 184 | " del X\n", |
183 | 185 | "\n", |
184 | 186 | " # Build optional preprocessor for use by methods below\n", |
|
197 | 199 | " rf_sk_params = {}\n", |
198 | 200 | " ert_params = {}\n", |
199 | 201 | " elastic_params = {}\n", |
| 202 | + " sgd_params = {}\n", |
200 | 203 | " lm_params = {}\n", |
201 | | - " lm_svm_params = {}\n", |
| 204 | + " lsvm_params = {}\n", |
202 | 205 | " svm_params = {}\n", |
203 | 206 | " nn_params = {}\n", |
204 | 207 | " knn_params = {}\n", |
|
212 | 215 | " catboost_params[\"verbose\"] = False\n", |
213 | 216 | " rf_xgb_params[\"enable_categorical\"] = True\n", |
214 | 217 | " rf_xgb_params[\"feature_types\"] = [\"c\" if cat else \"q\" for cat in cat_bools]\n", |
215 | | - " rf_sk_params[\"random_state\"] = seed # TODO: is this needed for reproducibility?\n", |
| 218 | + " rf_sk_params[\"random_state\"] = seed\n", |
216 | 219 | " rf_sk_params[\"n_jobs\"] = -1\n", |
217 | 220 | " ert_params[\"n_jobs\"] = -1\n", |
218 | | - " ert_params[\"random_state\"] = seed # TODO: is this needed for reproducibility?\n", |
219 | | - " elastic_params[\"random_state\"] = seed # TODO: is this needed for reproducibility?\n", |
| 221 | + " ert_params[\"random_state\"] = seed\n", |
| 222 | + " elastic_params[\"random_state\"] = seed\n", |
220 | 223 | " # elastic_params[\"selection\"] = 'cyclic' # 'random' # TODO: try both\n", |
| 224 | + " sgd_params[\"random_state\"] = seed\n", |
221 | 225 | " lm_params[\"n_jobs\"] = -1\n", |
222 | | - " lm_svm_params[\"random_state\"] = seed # TODO: is this needed for reproducibility?\n", |
223 | | - " nn_params[\"random_state\"] = seed # TODO: is this needed for reproducibility?\n", |
| 226 | + " lsvm_params[\"random_state\"] = seed\n", |
| 227 | + " nn_params[\"random_state\"] = seed\n", |
224 | 228 | " knn_params[\"n_jobs\"] = -1\n", |
225 | 229 | " aplr_params[\"m\"] = 3000\n", |
226 | 230 | "\n", |
|
241 | 245 | " #rf_sk_params[\"n_estimators\"] = 1\n", |
242 | 246 | " #ert_params[\"n_estimators\"] = 1\n", |
243 | 247 | " #elastic_params[\"max_iter\"] = 1\n", |
244 | | - " #lm_svm_params[\"max_iter\"] = 1\n", |
| 248 | + " #sgd_params[\"max_iter\"] = 1\n", |
| 249 | + " #lsvm_params[\"max_iter\"] = 1\n", |
245 | 250 | " #nn_params[\"max_iter\"] = 1\n", |
246 | 251 | " #knn_params[\"n_neighbors\"] = 1\n", |
247 | 252 | " #knn_params[\"leaf_size\"] = 1\n", |
|
270 | 275 | " elif trial.method.name == \"elastic\":\n", |
271 | 276 | " elastic_params[\"n_jobs\"] = -1\n", |
272 | 277 | " est = Pipeline([(\"ct\", ct), (\"est\", LogisticRegression(penalty='elasticnet', solver='saga', l1_ratio=0.5, **elastic_params))])\n", |
| 278 | + " elif trial.method.name == \"sgd\":\n", |
| 279 | + " est = Pipeline([(\"ct\", ct), (\"est\", CalibratedClassifierCV(SGDClassifier(**sgd_params), n_jobs=-1, cv=n_calibration_folds))])\n", |
273 | 280 | " elif trial.method.name == \"lm\":\n", |
274 | | - " lm_params[\"random_state\"] = seed # TODO: is this needed for reproducibility?\n", |
| 281 | + " lm_params[\"random_state\"] = seed\n", |
275 | 282 | " est = Pipeline([(\"ct\", ct), (\"est\", LogisticRegression(**lm_params))])\n", |
276 | | - " elif trial.method.name == \"lm_svm\":\n", |
| 283 | + " elif trial.method.name == \"lsvm\":\n", |
277 | 284 | " if trial.task.name in {\"CIFAR_10\", \"Devnagari-Script\"}:\n", |
278 | 285 | " max_samples = 10000 # crashes or fit time too long without subsampling\n", |
279 | | - " if trial.task.problem == \"multiclass\":\n", |
280 | | - " est = Pipeline([(\"ct\", ct), (\"est\", CalibratedClassifierCV(OneVsRestClassifier(LinearSVC(**lm_svm_params), n_jobs=-1)))])\n", |
281 | | - " else:\n", |
282 | | - " est = Pipeline([(\"ct\", ct), (\"est\", CalibratedClassifierCV(LinearSVC(**lm_svm_params), n_jobs=-1))])\n", |
| 286 | + " est = Pipeline([(\"ct\", ct), (\"est\", CalibratedClassifierCV(LinearSVC(**lsvm_params), n_jobs=-1, cv=n_calibration_folds))])\n", |
283 | 287 | " elif trial.method.name == \"svm\":\n", |
284 | 288 | " if trial.task.name in {\"CIFAR_10\", \"Devnagari-Script\"}:\n", |
285 | 289 | " max_samples = 10000 # crashes or fit time too long without subsampling\n", |
286 | 290 | " svm_params[\"random_state\"] = seed\n", |
287 | | - " if trial.task.problem == \"multiclass\":\n", |
288 | | - " est = Pipeline([(\"ct\", ct), (\"est\", CalibratedClassifierCV(OneVsRestClassifier(SVC(**svm_params), n_jobs=-1)))])\n", |
289 | | - " else:\n", |
290 | | - " est = Pipeline([(\"ct\", ct), (\"est\", CalibratedClassifierCV(SVC(**svm_params), n_jobs=-1))])\n", |
| 291 | + " est = Pipeline([(\"ct\", ct), (\"est\", CalibratedClassifierCV(SVC(**svm_params), n_jobs=-1, cv=n_calibration_folds))])\n", |
291 | 292 | " elif trial.method.name == \"nn\":\n", |
292 | 293 | " est = Pipeline([(\"ct\", ct), (\"est\", MLPClassifier(**nn_params))])\n", |
293 | 294 | " elif trial.method.name == \"knn\":\n", |
294 | 295 | " est = Pipeline([(\"ct\", ct), (\"est\", KNeighborsClassifier(**knn_params))])\n", |
295 | 296 | " elif trial.method.name == \"aplr\":\n", |
296 | 297 | " ct.sparse_threshold = 0 # APLR only handles dense\n", |
| 298 | + " if trial.task.name in {\"CIFAR_10\", \"Fashion-MNIST\", \"Devnagari-Script\", \"mnist_784\"}:\n", |
| 299 | + " max_samples = 10000 # crashes or fit time too long without subsampling\n", |
297 | 300 | " est = Pipeline([(\"ct\", ct), (\"est\", APLRClassifier(**aplr_params))])\n", |
298 | 301 | " fit_params[\"y\"] = fit_params[\"y\"].astype(str).to_numpy()\n", |
299 | 302 | " y_test = y_test.astype(str).to_numpy()\n", |
|
326 | 329 | " est = Pipeline([(\"ct\", ct), (\"est\", ExtraTreesRegressor(**ert_params))])\n", |
327 | 330 | " elif trial.method.name == \"elastic\":\n", |
328 | 331 | " est = Pipeline([(\"ct\", ct), (\"est\", ElasticNet(**elastic_params))])\n", |
| 332 | + " elif trial.method.name == \"sgd\":\n", |
| 333 | + " est = Pipeline([(\"ct\", ct), (\"est\", SGDRegressor(**sgd_params))])\n", |
329 | 334 | " elif trial.method.name == \"lm\":\n", |
330 | 335 | " est = Pipeline([(\"ct\", ct), (\"est\", LinearRegression(**lm_params))])\n", |
331 | | - " elif trial.method.name == \"lm_svm\":\n", |
332 | | - " est = Pipeline([(\"ct\", ct), (\"est\", LinearSVR(**lm_svm_params))])\n", |
| 336 | + " elif trial.method.name == \"lsvm\":\n", |
| 337 | + " est = Pipeline([(\"ct\", ct), (\"est\", LinearSVR(**lsvm_params))])\n", |
333 | 338 | " elif trial.method.name == \"svm\":\n", |
334 | | - " if trial.task.name in {\"Buzzinsocialmedia_Twitter\", \"nyc-taxi-green-dec-2016\", \"Airlines_DepDelay_10M\"}:\n", |
| 339 | + " if trial.task.name in {\"Buzzinsocialmedia_Twitter\", \"nyc-taxi-green-dec-2016\", \"Airlines_DepDelay_10M\", \"Yolanda\"}:\n", |
335 | 340 | " max_samples = 100000 # crashes or fit time too long without subsampling\n", |
336 | 341 | " est = Pipeline([(\"ct\", ct), (\"est\", SVR(**svm_params))])\n", |
337 | 342 | " elif trial.method.name == \"nn\":\n", |
|
341 | 346 | " max_samples = 100000 # crashes or fit time too long without subsampling\n", |
342 | 347 | " est = Pipeline([(\"ct\", ct), (\"est\", KNeighborsRegressor(**knn_params))])\n", |
343 | 348 | " elif trial.method.name == \"aplr\":\n", |
| 349 | + " ct.sparse_threshold = 0 # APLR only handles dense\n", |
344 | 350 | " if trial.task.name in {\"Airlines_DepDelay_10M\"}:\n", |
345 | 351 | " max_samples = 100000 # crashes or fit time too long without subsampling\n", |
346 | | - " ct.sparse_threshold = 0 # APLR only handles dense\n", |
347 | 352 | " est = Pipeline([(\"ct\", ct), (\"est\", APLRRegressor(**aplr_params))])\n", |
348 | 353 | " fit_params[\"y\"] = fit_params[\"y\"].astype(str).to_numpy()\n", |
349 | 354 | " y_test = y_test.astype(str).to_numpy()\n", |
|
441 | 446 | "if is_local:\n", |
442 | 447 | " conn_str = f\"sqlite:///{os.getcwd()}/powerlift.db\"\n", |
443 | 448 | "else:\n", |
| 449 | + " import requests\n", |
| 450 | + " import json\n", |
| 451 | + " import subprocess\n", |
444 | 452 | " from azure.identity import AzureCliCredential\n", |
445 | 453 | " credential = AzureCliCredential()\n", |
| 454 | + " access_token = credential.get_token(\"https://graph.microsoft.com/.default\").token\n", |
| 455 | + " headers = {'Authorization': f'Bearer {access_token}', 'Content-Type': 'application/json'}\n", |
| 456 | + " azure_client_id = requests.get('https://graph.microsoft.com/v1.0/me', headers=headers).json().get('id')\n", |
| 457 | + " azure_tenant_id = requests.get('https://graph.microsoft.com/v1.0/organization', headers=headers).json()['value'][0].get('id')\n", |
| 458 | + " subscription_id = json.loads(subprocess.run(\"az account show\", capture_output=True, text=True, shell=True).stdout).get(\"id\")\n", |
446 | 459 | " \n", |
447 | 460 | " from dotenv import load_dotenv\n", |
448 | 461 | " load_dotenv()\n", |
449 | 462 | " TIMEOUT_SEC = 60 * 60 * 24 * 180 # 180 days\n", |
450 | 463 | " wheel_filepaths = [\"interpret_core-0.6.3-py3-none-any.whl\", \"powerlift-0.1.11-py3-none-any.whl\"]\n", |
451 | 464 | " n_containers=198\n", |
452 | 465 | " conn_str = os.getenv(\"DOCKER_DB_URL\")\n", |
453 | | - " azure_tenant_id = os.getenv(\"AZURE_TENANT_ID\")\n", |
454 | | - " azure_client_id = os.getenv(\"AZURE_CLIENT_ID\")\n", |
455 | | - " azure_client_secret = os.getenv(\"AZURE_CLIENT_SECRET\")\n", |
456 | | - " subscription_id = os.getenv(\"AZURE_SUBSCRIPTION_ID\")\n", |
| 466 | + " azure_client_secret = None # use default credentials instead\n", |
457 | 467 | " resource_group = os.getenv(\"AZURE_RESOURCE_GROUP\")\n", |
458 | 468 | "\n", |
459 | 469 | "from powerlift.bench import retrieve_openml_automl_regression, retrieve_openml_automl_classification, retrieve_openml_cc18, retrieve_catboost_50k, retrieve_pmlb\n", |
|
0 commit comments