|
5 | 5 | "colab": { |
6 | 6 | "name": "DeepNull_e2e.ipynb", |
7 | 7 | "provenance": [], |
8 | | - "collapsed_sections": [], |
9 | | - "toc_visible": true |
| 8 | + "collapsed_sections": [] |
10 | 9 | }, |
11 | 10 | "kernelspec": { |
12 | 11 | "display_name": "Python 3", |
|
125 | 124 | "import numpy as np\n", |
126 | 125 | "import pandas as pd\n", |
127 | 126 | "import seaborn as sns\n", |
| 127 | + "from sklearn import metrics as skmetrics\n", |
128 | 128 | "import tensorflow as tf\n", |
129 | 129 | "from typing import Dict, List\n", |
130 | 130 | "\n", |
| 131 | + "from deepnull import config\n", |
131 | 132 | "from deepnull import data\n", |
132 | 133 | "from deepnull import metrics as metrics_lib\n", |
133 | 134 | "from deepnull import model as model_lib\n", |
|
214 | 215 | }, |
215 | 216 | "source": [ |
216 | 217 | "# These are the parameters used in Hormozdiari et al 2021. The definition of\n", |
217 | | - "# each parameter is given in the ModelParameters class in deepnull/model.py.\n", |
218 | | - "# Edit directly to change.\n", |
219 | | - "model_params = model_lib.ModelParameters(\n", |
220 | | - " mlp_units=[64, 64, 32, 16],\n", |
221 | | - " mlp_activation='relu',\n", |
222 | | - " learning_rate_batch_1024=1e-4,\n", |
223 | | - " beta_1=0.9,\n", |
224 | | - " beta_2=0.99,\n", |
225 | | - " num_epochs=1000,\n", |
226 | | - " batch_size=1024\n", |
227 | | - ")" |
| 218 | + "# each parameter is given in the config class in deepnull/config.py. Note that\n", |
| 219 | + "# XGBoost models are also available by specifying config.XGBOOST.\n", |
| 220 | + "full_config = config.get_config(config.DEEPNULL)\n", |
| 221 | + "\n", |
| 222 | + "# These parameters can be edited directly like in the following statement. Here\n", |
| 223 | + "# we train for many fewer epochs than a typical run so that the colab finishes\n", |
| 224 | + "# quickly. Note that this will likely cause the following cell to complain that\n", |
| 225 | + "# there is poor performance across data folds, since the model folds do not\n", |
| 226 | + "# converge.\n", |
| 227 | + "full_config.training_config.num_epochs = 2" |
228 | 228 | ], |
229 | 229 | "execution_count": null, |
230 | 230 | "outputs": [] |
|
263 | 263 | " target=target_phenotype,\n", |
264 | 264 | " target_is_binary=target_is_binary,\n", |
265 | 265 | " covariates=covariates,\n", |
| 266 | + " full_config=full_config,\n", |
266 | 267 | " prediction_column=output_column_name,\n", |
267 | 268 | " num_folds=num_folds,\n", |
268 | | - " model_params=model_params,\n", |
269 | 269 | " seed=random_seed,\n", |
270 | 270 | " # Where temporary outputs will be written.\n", |
271 | 271 | " logdir='/content/deepnull',\n", |
272 | 272 | " verbosity=1)\n", |
273 | | - "output_df, histories, validation_performance, test_perf_df = outputs\n", |
| 273 | + "output_df, validation_performance, test_perf_df = outputs\n", |
274 | 274 | "\n", |
275 | 275 | "if not metrics_lib.acceptable_model_performance(validation_performance):\n", |
276 | 276 | " print('\\n\\n##### Warning!! #####')\n", |
|
315 | 315 | "id": "ke0YDoKy7QE3" |
316 | 316 | }, |
317 | 317 | "source": [ |
318 | | - "def plot_model_performance(validation_summary_stats: List[Dict[str, float]],\n", |
319 | | - " test_performance_df: pd.DataFrame,\n", |
320 | | - " x: str,\n", |
321 | | - " y: str):\n", |
| 318 | + "def plot_binary_model_performance(\n", |
| 319 | + " validation_summary_stats: List[Dict[str, float]],\n", |
| 320 | + " test_performance_df: pd.DataFrame,\n", |
| 321 | + " label_col: str,\n", |
| 322 | + " prediction_col: str):\n", |
| 323 | + " \"\"\"Plots performance for binary traits.\"\"\"\n", |
| 324 | + " num_folds = len(validation_summary_stats)\n", |
| 325 | + " fig, axs = plt.subplots(1, num_folds, figsize=(num_folds * 4, 5),\n", |
| 326 | + " sharex=True, sharey=True)\n", |
| 327 | + " fold_column = f'{label_col}_deepnull_eval_fold'\n", |
| 328 | + " for fold, val_performance in enumerate(validation_summary_stats):\n", |
| 329 | + " fold_mask = test_performance_df[fold_column] == fold\n", |
| 330 | + " test_fold_df = test_performance_df[fold_mask]\n", |
| 331 | + " ax = axs[fold]\n", |
| 332 | + " sns.regplot(data=test_fold_df, x=prediction_col, y=label_col, ax=ax,\n", |
| 333 | + " logistic=True, scatter_kws={'alpha': 0.5})\n", |
| 334 | + " # DeepNull and XGBoost name their equivalent metrics slightly differently.\n", |
| 335 | + " val_auroc = val_performance.get('auroc') or val_performance.get('auc')\n", |
| 336 | + " val_auprc = val_performance.get('auprc') or val_performance.get('aucpr')\n", |
| 337 | + " test_auroc = skmetrics.roc_auc_score(test_fold_df[label_col],\n", |
| 338 | + " test_fold_df[prediction_col])\n", |
| 339 | + " test_auprc = skmetrics.average_precision_score(test_fold_df[label_col],\n", |
| 340 | + " test_fold_df[prediction_col])\n", |
| 341 | + " ax.set_title(f'Fold {fold}\\n'\n", |
| 342 | + " f'Validation AUROC: {val_auroc:.2f}\\n'\n", |
| 343 | + " f'Validation AUPRC: {val_auprc:.2f}\\n'\n", |
| 344 | + " f'Test AUROC: {test_auroc:.2f}\\n'\n", |
| 345 | + " f'Test AUPRC: {test_auprc:.2f}')\n", |
| 346 | + " plt.tight_layout()\n", |
| 347 | + "\n", |
| 348 | + "\n", |
| 349 | + "def plot_quantitative_model_performance(\n", |
| 350 | + " validation_summary_stats: List[Dict[str, float]],\n", |
| 351 | + " test_performance_df: pd.DataFrame,\n", |
| 352 | + " label_col: str,\n", |
| 353 | + " prediction_col: str):\n", |
| 354 | + " \"\"\"Plots performance for quantitative traits.\"\"\"\n", |
322 | 355 | " num_folds = len(validation_summary_stats)\n", |
323 | 356 | " fig, axs = plt.subplots(1, num_folds, figsize=(num_folds * 4, 5),\n", |
324 | 357 | " sharex=True, sharey=True)\n", |
325 | | - " fold_column = f'{x}_deepnull_eval_fold'\n", |
| 358 | + " fold_column = f'{label_col}_deepnull_eval_fold'\n", |
326 | 359 | " for fold, val_performance in enumerate(validation_summary_stats):\n", |
327 | 360 | " fold_mask = test_performance_df[fold_column] == fold\n", |
328 | 361 | " test_fold_df = test_performance_df[fold_mask]\n", |
329 | 362 | " ax = axs[fold]\n", |
330 | | - " sns.regplot(data=test_fold_df, x=x, y=y, ax=ax, scatter_kws={'alpha': 0.5})\n", |
331 | | - " val_mse = val_performance['mse']\n", |
332 | | - " val_corr = val_performance['tf_pearson']\n", |
333 | | - " test_mse = np.square(test_fold_df[x] - test_fold_df[y]).mean()\n", |
334 | | - " test_corr = np.corrcoef(test_fold_df[x], test_fold_df[y])[0, 1]\n", |
| 363 | + " sns.regplot(data=test_fold_df, x=prediction_col, y=label_col, ax=ax,\n", |
| 364 | + " scatter_kws={'alpha': 0.5})\n", |
| 365 | + " # DeepNull and XGBoost name their equivalent metrics slightly differently.\n", |
| 366 | + " val_mse = val_performance.get('mse') or val_performance.get('rmse')**2\n", |
| 367 | + " val_corr = val_performance.get('tf_pearson') or val_performance.get('pearson')\n", |
| 368 | + " test_mse = np.square(test_fold_df[label_col] - test_fold_df[prediction_col]).mean()\n", |
| 369 | + " test_corr = np.corrcoef(test_fold_df[label_col], test_fold_df[prediction_col])[0, 1]\n", |
335 | 370 | " ax.set_title(f'Fold {fold}\\n'\n", |
336 | 371 | " f'Validation MSE: {val_mse:.2f}\\n'\n", |
337 | 372 | " f'Validation Pearson R: {val_corr:.2f}\\n'\n", |
|
348 | 383 | "id": "3w2h1z5E-snM" |
349 | 384 | }, |
350 | 385 | "source": [ |
351 | | - "if not target_is_binary:\n", |
352 | | - " plot_model_performance(validation_summary_stats=validation_performance,\n", |
353 | | - " test_performance_df=test_perf_df,\n", |
354 | | - " x=target_phenotype,\n", |
355 | | - " y=output_column_name)" |
| 386 | + "if target_is_binary:\n", |
| 387 | + " plot_binary_model_performance(\n", |
| 388 | + " validation_summary_stats=validation_performance,\n", |
| 389 | + " test_performance_df=test_perf_df,\n", |
| 390 | + " label_col=target_phenotype,\n", |
| 391 | + " prediction_col=output_column_name)\n", |
| 392 | + "else:\n", |
| 393 | + " plot_quantitative_model_performance(\n", |
| 394 | + " validation_summary_stats=validation_performance,\n", |
| 395 | + " test_performance_df=test_perf_df,\n", |
| 396 | + " label_col=target_phenotype,\n", |
| 397 | + " prediction_col=output_column_name)" |
356 | 398 | ], |
357 | 399 | "execution_count": null, |
358 | 400 | "outputs": [] |
|
0 commit comments