Skip to content

Commit 16acdba

Browse files
committed
Update to DeepNull v0.2.0 with XGBoost support
1 parent 682b454 commit 16acdba

15 files changed

Lines changed: 988 additions & 491 deletions

nonlinear-covariate-gwas/DeepNull_e2e.ipynb

Lines changed: 72 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
"colab": {
66
"name": "DeepNull_e2e.ipynb",
77
"provenance": [],
8-
"collapsed_sections": [],
9-
"toc_visible": true
8+
"collapsed_sections": []
109
},
1110
"kernelspec": {
1211
"display_name": "Python 3",
@@ -125,9 +124,11 @@
125124
"import numpy as np\n",
126125
"import pandas as pd\n",
127126
"import seaborn as sns\n",
127+
"from sklearn import metrics as skmetrics\n",
128128
"import tensorflow as tf\n",
129129
"from typing import Dict, List\n",
130130
"\n",
131+
"from deepnull import config\n",
131132
"from deepnull import data\n",
132133
"from deepnull import metrics as metrics_lib\n",
133134
"from deepnull import model as model_lib\n",
@@ -214,17 +215,16 @@
214215
},
215216
"source": [
216217
"# These are the parameters used in Hormozdiari et al 2021. The definition of\n",
217-
"# each parameter is given in the ModelParameters class in deepnull/model.py.\n",
218-
"# Edit directly to change.\n",
219-
"model_params = model_lib.ModelParameters(\n",
220-
" mlp_units=[64, 64, 32, 16],\n",
221-
" mlp_activation='relu',\n",
222-
" learning_rate_batch_1024=1e-4,\n",
223-
" beta_1=0.9,\n",
224-
" beta_2=0.99,\n",
225-
" num_epochs=1000,\n",
226-
" batch_size=1024\n",
227-
")"
218+
"# each parameter is given in the config class in deepnull/config.py. Note that\n",
219+
"# XGBoost models are also available by specifying config.XGBOOST.\n",
220+
"full_config = config.get_config(config.DEEPNULL)\n",
221+
"\n",
222+
"# These parameters can be edited directly like in the following statement. Here\n",
223+
"# we train for many fewer epochs than a typical run so that the colab finishes\n",
224+
"# quickly. Note that this will likely cause the following cell to complain that\n",
225+
"# there is poor performance across data folds, since the model folds do not\n",
226+
"# converge.\n",
227+
"full_config.training_config.num_epochs = 2"
228228
],
229229
"execution_count": null,
230230
"outputs": []
@@ -263,14 +263,14 @@
263263
" target=target_phenotype,\n",
264264
" target_is_binary=target_is_binary,\n",
265265
" covariates=covariates,\n",
266+
" full_config=full_config,\n",
266267
" prediction_column=output_column_name,\n",
267268
" num_folds=num_folds,\n",
268-
" model_params=model_params,\n",
269269
" seed=random_seed,\n",
270270
" # Where temporary outputs will be written.\n",
271271
" logdir='/content/deepnull',\n",
272272
" verbosity=1)\n",
273-
"output_df, histories, validation_performance, test_perf_df = outputs\n",
273+
"output_df, validation_performance, test_perf_df = outputs\n",
274274
"\n",
275275
"if not metrics_lib.acceptable_model_performance(validation_performance):\n",
276276
" print('\\n\\n##### Warning!! #####')\n",
@@ -315,23 +315,58 @@
315315
"id": "ke0YDoKy7QE3"
316316
},
317317
"source": [
318-
"def plot_model_performance(validation_summary_stats: List[Dict[str, float]],\n",
319-
" test_performance_df: pd.DataFrame,\n",
320-
" x: str,\n",
321-
" y: str):\n",
318+
"def plot_binary_model_performance(\n",
319+
" validation_summary_stats: List[Dict[str, float]],\n",
320+
" test_performance_df: pd.DataFrame,\n",
321+
" label_col: str,\n",
322+
" prediction_col: str):\n",
323+
" \"\"\"Plots performance for binary traits.\"\"\"\n",
324+
" num_folds = len(validation_summary_stats)\n",
325+
" fig, axs = plt.subplots(1, num_folds, figsize=(num_folds * 4, 5),\n",
326+
" sharex=True, sharey=True)\n",
327+
" fold_column = f'{label_col}_deepnull_eval_fold'\n",
328+
" for fold, val_performance in enumerate(validation_summary_stats):\n",
329+
" fold_mask = test_performance_df[fold_column] == fold\n",
330+
" test_fold_df = test_performance_df[fold_mask]\n",
331+
" ax = axs[fold]\n",
332+
" sns.regplot(data=test_fold_df, x=prediction_col, y=label_col, ax=ax,\n",
333+
" logistic=True, scatter_kws={'alpha': 0.5})\n",
334+
" # DeepNull and XGBoost name their equivalent metrics slightly differently.\n",
335+
" val_auroc = val_performance.get('auroc') or val_performance.get('auc')\n",
336+
" val_auprc = val_performance.get('auprc') or val_performance.get('aucpr')\n",
337+
" test_auroc = skmetrics.roc_auc_score(test_fold_df[label_col],\n",
338+
" test_fold_df[prediction_col])\n",
339+
" test_auprc = skmetrics.average_precision_score(test_fold_df[label_col],\n",
340+
" test_fold_df[prediction_col])\n",
341+
" ax.set_title(f'Fold {fold}\\n'\n",
342+
" f'Validation AUROC: {val_auroc:.2f}\\n'\n",
343+
" f'Validation AUPRC: {val_auprc:.2f}\\n'\n",
344+
" f'Test AUROC: {test_auroc:.2f}\\n'\n",
345+
" f'Test AUPRC: {test_auprc:.2f}')\n",
346+
" plt.tight_layout()\n",
347+
"\n",
348+
"\n",
349+
"def plot_quantitative_model_performance(\n",
350+
" validation_summary_stats: List[Dict[str, float]],\n",
351+
" test_performance_df: pd.DataFrame,\n",
352+
" label_col: str,\n",
353+
" prediction_col: str):\n",
354+
" \"\"\"Plots performance for quantitative traits.\"\"\"\n",
322355
" num_folds = len(validation_summary_stats)\n",
323356
" fig, axs = plt.subplots(1, num_folds, figsize=(num_folds * 4, 5),\n",
324357
" sharex=True, sharey=True)\n",
325-
" fold_column = f'{x}_deepnull_eval_fold'\n",
358+
" fold_column = f'{label_col}_deepnull_eval_fold'\n",
326359
" for fold, val_performance in enumerate(validation_summary_stats):\n",
327360
" fold_mask = test_performance_df[fold_column] == fold\n",
328361
" test_fold_df = test_performance_df[fold_mask]\n",
329362
" ax = axs[fold]\n",
330-
" sns.regplot(data=test_fold_df, x=x, y=y, ax=ax, scatter_kws={'alpha': 0.5})\n",
331-
" val_mse = val_performance['mse']\n",
332-
" val_corr = val_performance['tf_pearson']\n",
333-
" test_mse = np.square(test_fold_df[x] - test_fold_df[y]).mean()\n",
334-
" test_corr = np.corrcoef(test_fold_df[x], test_fold_df[y])[0, 1]\n",
363+
" sns.regplot(data=test_fold_df, x=prediction_col, y=label_col, ax=ax,\n",
364+
" scatter_kws={'alpha': 0.5})\n",
365+
" # DeepNull and XGBoost name their equivalent metrics slightly differently.\n",
366+
" val_mse = val_performance.get('mse') or val_performance.get('rmse')**2\n",
367+
" val_corr = val_performance.get('tf_pearson') or val_performance.get('pearson')\n",
368+
" test_mse = np.square(test_fold_df[label_col] - test_fold_df[prediction_col]).mean()\n",
369+
" test_corr = np.corrcoef(test_fold_df[label_col], test_fold_df[prediction_col])[0, 1]\n",
335370
" ax.set_title(f'Fold {fold}\\n'\n",
336371
" f'Validation MSE: {val_mse:.2f}\\n'\n",
337372
" f'Validation Pearson R: {val_corr:.2f}\\n'\n",
@@ -348,11 +383,18 @@
348383
"id": "3w2h1z5E-snM"
349384
},
350385
"source": [
351-
"if not target_is_binary:\n",
352-
" plot_model_performance(validation_summary_stats=validation_performance,\n",
353-
" test_performance_df=test_perf_df,\n",
354-
" x=target_phenotype,\n",
355-
" y=output_column_name)"
386+
"if target_is_binary:\n",
387+
" plot_binary_model_performance(\n",
388+
" validation_summary_stats=validation_performance,\n",
389+
" test_performance_df=test_perf_df,\n",
390+
" label_col=target_phenotype,\n",
391+
" prediction_col=output_column_name)\n",
392+
"else:\n",
393+
" plot_quantitative_model_performance(\n",
394+
" validation_summary_stats=validation_performance,\n",
395+
" test_performance_df=test_perf_df,\n",
396+
" label_col=target_phenotype,\n",
397+
" prediction_col=output_column_name)"
356398
],
357399
"execution_count": null,
358400
"outputs": []

nonlinear-covariate-gwas/README.md

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,72 @@ To see all available flags, run
5252
python -m deepnull.main --help 2> /dev/null
5353
```
5454

55+
Of particular note is the `--model_config` flag. DeepNull uses the
56+
[ml_collections](https://github.com/google/ml_collections) library to specify
57+
all parameters related to the model and training regimen. The supported
58+
configuration code is located in [`config.py`](config.py), and parameters can
59+
be modified as described in detail in the
60+
[`ml_collections README`](https://github.com/google/ml_collections#parameterising-the-get_config-function).
61+
As a brief example, to use the DeepNull architecture with the `elu` activation
62+
and train with batch size 4096, the above example command would be modified as
63+
follows:
64+
65+
```bash
66+
python -m deepnull.main \
67+
--input_tsv=/input/ORIGINAL_PHENOCOVAR_TSV \
68+
--output_tsv=/output/PHENOCOVAR_WITH_DEEPNULL_PREDICTION_TSV \
69+
--target=pheno \
70+
--covariates="age,sex,genotyping_array" \
71+
--model_config=/path/to/config.py:deepnull \
72+
--model_config.model_config.mlp_activation=elu \
73+
--model_config.training_config.batch_size=4096
74+
```
75+
76+
where `/path/to/config.py` provides the path to [`config.py`](config.py) on your
77+
machine.
78+
79+
## Incorporating DeepNull into a GWAS analysis
80+
81+
The above section, "How to run DeepNull", shows that the DeepNull software adds
82+
a single column to a phenotype+covariate file of interest that represents a
83+
nonlinear prediction of the target phenotype of interest. To incorporate this
84+
into a GWAS analysis, the single additional covariate should be **added** as an
85+
additional covariate. A concrete example with `BOLT-LMM`, using the same file,
86+
phenotype `pheno`, and covariates `age`, `sex`, `genotyping_array` as above, is
87+
shown below:
88+
89+
### Original example GWAS command
90+
```bash
91+
# N.B. Data loading flags are omitted for brevity.
92+
93+
bolt \
94+
--phenoFile /input/ORIGINAL_PHENOCOVAR_TSV \
95+
--covarFile /input/ORIGINAL_PHENOCOVAR_TSV \
96+
--qCovarCol age \
97+
--qCovarCol sex \
98+
--qCovarCol genotyping_array \
99+
--phenoCol pheno
100+
```
101+
102+
After running DeepNull on the `/input/ORIGINAL_PHENOCOVAR_TSV` to create the new
103+
TSV `/output/PHENOCOVAR_WITH_DEEPNULL_PREDICTION_TSV` that includes the column
104+
`pheno_deepnull`, the updated command is given below:
105+
106+
### Updated GWAS command to incorporate DeepNull
107+
```bash
108+
# N.B. Data loading flags are omitted for brevity.
109+
# Note the addition of the single `--qCovarCol pheno_deepnull` line.
110+
111+
bolt \
112+
--phenoFile /output/PHENOCOVAR_WITH_DEEPNULL_PREDICTION_TSV \
113+
--covarFile /output/PHENOCOVAR_WITH_DEEPNULL_PREDICTION_TSV \
114+
--qCovarCol age \
115+
--qCovarCol sex \
116+
--qCovarCol genotyping_array \
117+
--qCovarCol pheno_deepnull \
118+
--phenoCol pheno
119+
```
120+
55121
## Data
56122

57123
Datasets used to reproduce the results from the above publication are available

nonlinear-covariate-gwas/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,4 @@
2626
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2727
"""DeepNull."""
2828

29-
__version__ = '0.1.3'
29+
__version__ = '0.2.0'

nonlinear-covariate-gwas/config.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# Copyright 2021 Google LLC.
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions are met:
5+
#
6+
# 1. Redistributions of source code must retain the above copyright notice, this
7+
# list of conditions and the following disclaimer.
8+
#
9+
# 2. Redistributions in binary form must reproduce the above copyright notice,
10+
# this list of conditions and the following disclaimer in the documentation
11+
# and/or other materials provided with the distribution.
12+
#
13+
# 3. Neither the name of the copyright holder nor the names of its contributors
14+
# may be used to endorse or promote products derived from this software
15+
# without specific prior written permission.
16+
#
17+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
18+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
19+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
20+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
21+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
22+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
23+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
24+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
25+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27+
"""Configuration for all model types.
28+
29+
This configuration file is used to specify all different supported types of
30+
models for training DeepNull. The configuration is parsed by model.py for the
31+
proper instantiation of the selected model.
32+
33+
See https://github.com/google/ml_collections for details on ConfigDict.
34+
"""
35+
import ml_collections
36+
37+
# Valid model types.
38+
# The model used for the main figures in the paper.
39+
DEEPNULL = 'deepnull'
40+
# XGBoost-based models.
41+
XGBOOST = 'xgboost'
42+
43+
44+
def get_config(config_name: str) -> ml_collections.ConfigDict:
45+
"""Returns the config specified by `config_name`."""
46+
supported_models = {
47+
DEEPNULL:
48+
ml_collections.ConfigDict({
49+
'model_type':
50+
DEEPNULL,
51+
'model_config':
52+
ml_collections.ConfigDict({
53+
# The MLP units for the nonlinear path of DeepNull.
54+
'mlp_units': (64, 64, 32, 16),
55+
# The activation function to use. See
56+
# https://keras.io/api/layers/activations.
57+
'mlp_activation': 'relu',
58+
}),
59+
'optimizer_config':
60+
ml_collections.ConfigDict({
61+
# Learning rate for a batch size of 1024. The actual
62+
# learning rate used is scaled linearly as
63+
# `learning_rate * batch_size / 1024`.
64+
'learning_rate_batch_1024': 1e-4,
65+
# Betas for the Adam optimizer.
66+
'beta_1': 0.9,
67+
'beta_2': 0.99,
68+
# The optimization metric to use to select the best model
69+
# checkpoint. This must be a metric generated during
70+
# training (which depends on whether the target is a
71+
# binary or continuous variable). If unspecified, the
72+
# default metric for the associated target type is used.
73+
'optimization_metric': '',
74+
}),
75+
'training_config':
76+
ml_collections.ConfigDict({
77+
# Number of full passes through the training data.
78+
'num_epochs': 1000,
79+
# Number of training examples per batch.
80+
'batch_size': 1024,
81+
}),
82+
}),
83+
XGBOOST:
84+
ml_collections.ConfigDict({
85+
'model_type':
86+
XGBOOST,
87+
'model_config':
88+
ml_collections.ConfigDict({
89+
# See
90+
# https://xgboost.readthedocs.io/en/latest/parameter.html
91+
# for full details on all parameters.
92+
# The target objective. If unspecified, will be the
93+
# default objective for the type of model prediction (i.e.
94+
# regression vs classification).
95+
'objective': '',
96+
'max_depth': 3,
97+
'eta': 0.32,
98+
'alpha': 0.658,
99+
'lambda': 2.0,
100+
# If unspecified, will be the default metric for the type
101+
# of model prediction.
102+
'eval_metric': '',
103+
}),
104+
'training_config':
105+
ml_collections.ConfigDict({
106+
'num_boost_round': 25,
107+
}),
108+
}),
109+
}
110+
111+
if config_name not in supported_models:
112+
raise ValueError(f'Config "{config_name}" is not a supported model: '
113+
f'{sorted(supported_models)}')
114+
115+
return supported_models[config_name]

0 commit comments

Comments
 (0)