Skip to content

Commit 6979f37

Browse files
lucidtronixcmnbroad
authored andcommitted
Add CNNScoreVariants with 2D model, CNNVariantTrain, CNNVariantWriteTensors and FilterVariantTranches tools. (#4245)
1 parent 4a44f16 commit 6979f37

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+4140
-526
lines changed

scripts/gatkcondaenv.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@ dependencies:
1919
- xz=5.2.3=0
2020
- zlib=1.2.11=0
2121
- pip:
22+
- biopython==1.70
2223
- bleach==1.5.0
2324
- cycler==0.10.0
2425
- enum34==1.1.6
2526
- h5py==2.7.1
2627
- html5lib==0.9999999
2728
- joblib==0.11
28-
- keras==2.1.1
29+
- keras==2.1.4
2930
- markdown==2.6.9
3031
- matplotlib==2.1.0
3132
- numpy==1.13.3
@@ -34,9 +35,12 @@ dependencies:
3435
- protobuf==3.5.0.post1
3536
- pymc3==3.1
3637
- pyparsing==2.2.0
38+
- pysam==0.13
3739
- python-dateutil==2.6.1
3840
- pytz==2017.3
41+
- pyvcf==0.6.8
3942
- pyyaml==3.12
43+
- scikit-learn==0.19.1
4044
- scipy==1.0.0
4145
- six==1.11.0
4246
- tensorflow==1.4.0

src/main/java/org/broadinstitute/hellbender/engine/filters/VariantFilterLibrary.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@
55
*/
66
public final class VariantFilterLibrary {
77
public static VariantFilter ALLOW_ALL_VARIANTS = variant -> true;
8+
public static VariantFilter NOT_SV_OR_SYMBOLIC = variant -> !variant.isSymbolicOrSV();
89
}

src/main/java/org/broadinstitute/hellbender/tools/walkers/vqsr/CNNScoreVariants.java

Lines changed: 492 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
package org.broadinstitute.hellbender.tools.walkers.vqsr;
2+
3+
import org.broadinstitute.barclay.argparser.*;
4+
import org.broadinstitute.barclay.help.DocumentedFeature;
5+
import org.broadinstitute.hellbender.cmdline.CommandLineProgram;
6+
import org.broadinstitute.hellbender.exceptions.GATKException;
7+
import org.broadinstitute.hellbender.utils.io.Resource;
8+
import org.broadinstitute.hellbender.utils.python.PythonScriptExecutor;
9+
import picard.cmdline.programgroups.VariantEvaluationProgramGroup;
10+
11+
12+
import java.util.ArrayList;
13+
import java.util.Arrays;
14+
import java.util.List;
15+
16+
/**
17+
* Train a Convolutional Neural Network (CNN) for filtering variants.
18+
* This tool expects requires training data generated by {@link CNNVariantWriteTensors}.
19+
*
20+
*
21+
* <h3>Inputs</h3>
22+
* <ul>
23+
* <li>data-dir The training data created by {@link CNNVariantWriteTensors}.</li>
24+
* <li>The tensor-name argument determines what types of tensors the model will expect.
25+
* Set it to "reference" for 1D tensors or "read_tensor" for 2D tensors.</li>
26+
* </ul>
27+
*
28+
* <h3>Outputs</h3>
29+
* <ul>
30+
* <li>output-dir The model weights file and semantic configuration json are saved here.
31+
* This default to the current working directory.</li>
32+
* <li>model-name The name for your model.</li>
33+
* </ul>
34+
*
35+
* <h3>Usage example</h3>
36+
*
37+
* <h4>Train a 1D CNN on Reference Tensors</h4>
38+
* <pre>
39+
* gatk CNNVariantTrain \
40+
* -tensor-type reference \
41+
* -input-tensors-dir my_tensor_folder \
42+
* -model-name my_1d_model
43+
* </pre>
44+
*
45+
* <h4>Train a 2D CNN on Read Tensors</h4>
46+
* <pre>
47+
* gatk CNNVariantTrain \
48+
* -input-tensors-dir my_tensor_folder \
49+
* -tensor-type read-tensor \
50+
* -model-name my_2d_model
51+
* </pre>
52+
*
53+
*/
54+
@CommandLineProgramProperties(
55+
summary = "Train a CNN model for filtering variants",
56+
oneLineSummary = "Train a CNN model for filtering variants",
57+
programGroup = VariantEvaluationProgramGroup.class
58+
)
59+
@DocumentedFeature
60+
@ExperimentalFeature
61+
public class CNNVariantTrain extends CommandLineProgram {
62+
63+
@Argument(fullName = "input-tensor-dir", shortName = "input-tensor-dir", doc = "Directory of training tensors to create.")
64+
private String inputTensorDir;
65+
66+
@Argument(fullName = "output-dir", shortName = "output-dir", doc = "Directory where models will be saved, defaults to current working directory.", optional = true)
67+
private String outputDir = "./";
68+
69+
@Argument(fullName = "tensor-type", shortName = "tensor-type", doc = "Name of the tensors to generate, reference for 1D reference tensors and read_tensor for 2D tensors.", optional = true)
70+
private TensorType tensorType = TensorType.reference;
71+
72+
@Argument(fullName = "model-name", shortName = "model-name", doc = "Name of the model to be trained.", optional = true)
73+
private String modelName = "variant_filter_model";
74+
75+
@Argument(fullName = "epochs", shortName = "epochs", doc = "Maximum number of training epochs.", optional = true, minValue = 0)
76+
private int epochs = 10;
77+
78+
@Argument(fullName = "training-steps", shortName = "training-steps", doc = "Number of training steps per epoch.", optional = true, minValue = 0)
79+
private int trainingSteps = 10;
80+
81+
@Argument(fullName = "validation-steps", shortName = "validation-steps", doc = "Number of validation steps per epoch.", optional = true, minValue = 0)
82+
private int validationSteps = 2;
83+
84+
@Argument(fullName = "image-dir", shortName = "image-dir", doc = "Path where plots and figures are saved.", optional = true)
85+
private String imageDir;
86+
87+
@Advanced
88+
@Argument(fullName = "channels-last", shortName = "channels-last", doc = "Store the channels in the last axis of tensors, tensorflow->true, theano->false", optional = true)
89+
private boolean channelsLast = true;
90+
91+
@Advanced
92+
@Argument(fullName = "annotation-set", shortName = "annotation-set", doc = "Which set of annotations to use.", optional = true)
93+
private String annotationSet = "best_practices";
94+
95+
// Start the Python executor. This does not actually start the Python process, but fails if python can't be located
96+
final PythonScriptExecutor pythonExecutor = new PythonScriptExecutor(true);
97+
98+
99+
@Override
100+
protected void onStartup() {
101+
PythonScriptExecutor.checkPythonEnvironmentForPackage("vqsr_cnn");
102+
}
103+
104+
@Override
105+
protected Object doWork() {
106+
final Resource pythonScriptResource = new Resource("training.py", FilterVariantTranches.class);
107+
List<String> arguments = new ArrayList<>(Arrays.asList(
108+
"--data_dir", inputTensorDir,
109+
"--output_dir", outputDir,
110+
"--tensor_name", tensorType.name(),
111+
"--annotation_set", annotationSet,
112+
"--epochs", Integer.toString(epochs),
113+
"--training_steps", Integer.toString(trainingSteps),
114+
"--validation_steps", Integer.toString(validationSteps),
115+
"--id", modelName));
116+
117+
if(channelsLast){
118+
arguments.add("--channels_last");
119+
} else {
120+
arguments.add("--channels_first");
121+
}
122+
123+
if(imageDir != null){
124+
arguments.addAll(Arrays.asList("--image_dir", imageDir));
125+
}
126+
127+
if (tensorType == TensorType.reference) {
128+
arguments.addAll(Arrays.asList("--mode", "train_on_reference_tensors_and_annotations"));
129+
} else if (tensorType == TensorType.read_tensor) {
130+
arguments.addAll(Arrays.asList("--mode", "train_small_model_on_read_tensors_and_annotations"));
131+
} else {
132+
throw new GATKException("Unknown tensor mapping mode:"+ tensorType.name());
133+
}
134+
135+
logger.info("Args are:"+ Arrays.toString(arguments.toArray()));
136+
final boolean pythonReturnCode = pythonExecutor.executeScript(
137+
pythonScriptResource,
138+
null,
139+
arguments
140+
);
141+
return pythonReturnCode;
142+
}
143+
144+
}
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
package org.broadinstitute.hellbender.tools.walkers.vqsr;
2+
3+
import org.broadinstitute.barclay.argparser.*;
4+
import org.broadinstitute.barclay.help.DocumentedFeature;
5+
import org.broadinstitute.hellbender.cmdline.CommandLineProgram;
6+
import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
7+
import org.broadinstitute.hellbender.exceptions.GATKException;
8+
import org.broadinstitute.hellbender.utils.io.Resource;
9+
import org.broadinstitute.hellbender.utils.python.PythonScriptExecutor;
10+
import picard.cmdline.programgroups.VariantEvaluationProgramGroup;
11+
12+
import java.util.ArrayList;
13+
import java.util.Arrays;
14+
import java.util.List;
15+
16+
/**
17+
* Write variant tensors for training a Convolutional Neural Network (CNN) for filtering variants.
18+
* After running this tool, a model can be trained with the {@link CNNVariantTrain} tool.
19+
*
20+
*
21+
* <h3>Inputs</h3>
22+
* <ul>
23+
* <li>The input variants to make into tensors.
24+
* These variant calls must be annotated with the standard best practices annotations.</li>
25+
* <li>The truth VCF has validated variant calls, like those in the genomes in a bottle,
26+
* platinum genomes, or CHM VCFs. Variants in both the input VCF and the truth VCF
27+
* will be used as positive training data.</li>
28+
* <li>The truth BED is a bed file define the confident region for the validated calls.
29+
* Variants from the input VCF inside this region, but not included in the truth VCF
30+
* will be used as negative training data.</li>
31+
* <li>The tensor-name argument determines what types of tensors will be written.
32+
* Set it to "reference" to write 1D tensors or "read_tensor" to write 2D tensors.</li>
33+
* <li>The bam-file argument is necessary to write 2D tensors which incorporate read data.</li>
34+
* </ul>
35+
*
36+
* <h3>Outputs</h3>
37+
* <ul>
38+
* <li>data-dir This directory is created and populated with variant tensors.
39+
* it will be divided into training, validation and test sets and each set will be further divided into
40+
* positive and negative SNPs and INDELs.</li>
41+
* </ul>
42+
*
43+
* <h3>Usage example</h3>
44+
*
45+
* <h4>Write Reference Tensors</h4>
46+
* <pre>
47+
* gatk CNNVariantWriteTensors \
48+
* -R reference.fasta \
49+
* -V input.vcf.gz \
50+
* -truth-vcf platinum-genomes.vcf \
51+
* -truth-bed platinum-confident-region.bed \
52+
* -tensor-name reference \
53+
* -output-tensor-dir my-tensor-folder
54+
* </pre>
55+
*
56+
* <h4>Write Read Tensors</h4>
57+
* <pre>
58+
* gatk CNNVariantWriteTensors \
59+
* -R reference.fasta \
60+
* -V input.vcf.gz \
61+
* -truth-vcf platinum-genomes.vcf \
62+
* -truth-bed platinum-confident-region.bed \
63+
* -tensor-name read_tensor \
64+
* -bam-file input.bam \
65+
* -output-tensor-dir my-tensor-folder
66+
* </pre>
67+
*
68+
*/
69+
@CommandLineProgramProperties(
70+
summary = "Write variant tensors for training a CNN to filter variants",
71+
oneLineSummary = "Write variant tensors for training a CNN to filter variants",
72+
programGroup = VariantEvaluationProgramGroup.class
73+
)
74+
@DocumentedFeature
75+
@ExperimentalFeature
76+
public class CNNVariantWriteTensors extends CommandLineProgram {
77+
78+
@Argument(fullName = StandardArgumentDefinitions.REFERENCE_LONG_NAME,
79+
shortName = StandardArgumentDefinitions.REFERENCE_SHORT_NAME,
80+
doc = "Reference fasta file.")
81+
private String reference;
82+
83+
@Argument(fullName = StandardArgumentDefinitions.VARIANT_LONG_NAME,
84+
shortName = StandardArgumentDefinitions.VARIANT_SHORT_NAME,
85+
doc = "Input VCF file")
86+
private String inputVcf;
87+
88+
@Argument(fullName = "output-tensor-dir", shortName = "output-tensor-dir", doc = "Directory of training tensors. Subdivided into train, valid and test sets.")
89+
private String outputTensorsDir;
90+
91+
@Argument(fullName = "truth-vcf", shortName = "truth-vcf", doc = "Validated VCF file.")
92+
private String truthVcf;
93+
94+
@Argument(fullName = "truth-bed", shortName = "truth-bed", doc = "Confident region of the validated VCF file.")
95+
private String truthBed;
96+
97+
@Argument(fullName = "bam-file", shortName = "bam-file", doc = "BAM or BAMout file to use for read data when generating 2D tensors.", optional = true)
98+
private String bamFile = "";
99+
100+
@Argument(fullName = "tensor-type", shortName = "tensor-type", doc = "Name of the tensors to generate.")
101+
private TensorType tensorType = TensorType.reference;
102+
103+
@Advanced
104+
@Argument(fullName = "channels-last", shortName = "channels-last", doc = "Store the channels in the last axis of tensors, tensorflow->true, theano->false", optional = true)
105+
private boolean channelsLast = true;
106+
107+
@Advanced
108+
@Argument(fullName = "annotation-set", shortName = "annotation-set", doc = "Which set of annotations to use.", optional = true)
109+
private String annotationSet = "best_practices";
110+
111+
@Argument(fullName = "max-tensors", shortName = "max-tensors", doc = "Maximum number of tensors to write.", optional = true, minValue = 0)
112+
private int maxTensors = 1000000;
113+
114+
// Start the Python executor. This does not actually start the Python process, but fails if python can't be located
115+
final PythonScriptExecutor pythonExecutor = new PythonScriptExecutor(true);
116+
117+
@Override
118+
protected void onStartup() {
119+
PythonScriptExecutor.checkPythonEnvironmentForPackage("vqsr_cnn");
120+
}
121+
122+
@Override
123+
protected Object doWork() {
124+
final Resource pythonScriptResource = new Resource("training.py", FilterVariantTranches.class);
125+
List<String> arguments = new ArrayList<>(Arrays.asList(
126+
"--reference_fasta", reference,
127+
"--input_vcf", inputVcf,
128+
"--bam_file", bamFile,
129+
"--train_vcf", truthVcf,
130+
"--bed_file", truthBed,
131+
"--tensor_name", tensorType.name(),
132+
"--annotation_set", annotationSet,
133+
"--samples", Integer.toString(maxTensors),
134+
"--data_dir", outputTensorsDir));
135+
136+
if(channelsLast){
137+
arguments.add("--channels_last");
138+
} else{
139+
arguments.add("--channels_first");
140+
}
141+
142+
if (tensorType == TensorType.reference) {
143+
arguments.addAll(Arrays.asList("--mode", "write_reference_and_annotation_tensors"));
144+
} else if (tensorType == TensorType.read_tensor) {
145+
arguments.addAll(Arrays.asList("--mode", "write_read_and_annotation_tensors"));
146+
} else {
147+
throw new GATKException("Unknown tensor mapping mode:"+ tensorType.name());
148+
}
149+
150+
logger.info("Args are:"+ Arrays.toString(arguments.toArray()));
151+
final boolean pythonReturnCode = pythonExecutor.executeScript(
152+
pythonScriptResource,
153+
null,
154+
arguments
155+
);
156+
return pythonReturnCode;
157+
}
158+
159+
}

0 commit comments

Comments
 (0)