|
| 1 | +package org.broadinstitute.hellbender.tools.walkers.vqsr; |
| 2 | + |
| 3 | +import org.broadinstitute.barclay.argparser.*; |
| 4 | +import org.broadinstitute.barclay.help.DocumentedFeature; |
| 5 | +import org.broadinstitute.hellbender.cmdline.CommandLineProgram; |
| 6 | +import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions; |
| 7 | +import org.broadinstitute.hellbender.exceptions.GATKException; |
| 8 | +import org.broadinstitute.hellbender.utils.io.Resource; |
| 9 | +import org.broadinstitute.hellbender.utils.python.PythonScriptExecutor; |
| 10 | +import picard.cmdline.programgroups.VariantEvaluationProgramGroup; |
| 11 | + |
| 12 | +import java.util.ArrayList; |
| 13 | +import java.util.Arrays; |
| 14 | +import java.util.List; |
| 15 | + |
| 16 | +/** |
| 17 | + * Write variant tensors for training a Convolutional Neural Network (CNN) for filtering variants. |
| 18 | + * After running this tool, a model can be trained with the {@link CNNVariantTrain} tool. |
| 19 | + * |
| 20 | + * |
| 21 | + * <h3>Inputs</h3> |
| 22 | + * <ul> |
| 23 | + * <li>The input variants to make into tensors. |
| 24 | + * These variant calls must be annotated with the standard best practices annotations.</li> |
| 25 | + * <li>The truth VCF has validated variant calls, like those in the genomes in a bottle, |
| 26 | + * platinum genomes, or CHM VCFs. Variants in both the input VCF and the truth VCF |
| 27 | + * will be used as positive training data.</li> |
| 28 | + * <li>The truth BED is a bed file define the confident region for the validated calls. |
| 29 | + * Variants from the input VCF inside this region, but not included in the truth VCF |
| 30 | + * will be used as negative training data.</li> |
| 31 | + * <li>The tensor-name argument determines what types of tensors will be written. |
| 32 | + * Set it to "reference" to write 1D tensors or "read_tensor" to write 2D tensors.</li> |
| 33 | + * <li>The bam-file argument is necessary to write 2D tensors which incorporate read data.</li> |
| 34 | + * </ul> |
| 35 | + * |
| 36 | + * <h3>Outputs</h3> |
| 37 | + * <ul> |
| 38 | + * <li>data-dir This directory is created and populated with variant tensors. |
| 39 | + * it will be divided into training, validation and test sets and each set will be further divided into |
| 40 | + * positive and negative SNPs and INDELs.</li> |
| 41 | + * </ul> |
| 42 | + * |
| 43 | + * <h3>Usage example</h3> |
| 44 | + * |
| 45 | + * <h4>Write Reference Tensors</h4> |
| 46 | + * <pre> |
| 47 | + * gatk CNNVariantWriteTensors \ |
| 48 | + * -R reference.fasta \ |
| 49 | + * -V input.vcf.gz \ |
| 50 | + * -truth-vcf platinum-genomes.vcf \ |
| 51 | + * -truth-bed platinum-confident-region.bed \ |
| 52 | + * -tensor-name reference \ |
| 53 | + * -output-tensor-dir my-tensor-folder |
| 54 | + * </pre> |
| 55 | + * |
| 56 | + * <h4>Write Read Tensors</h4> |
| 57 | + * <pre> |
| 58 | + * gatk CNNVariantWriteTensors \ |
| 59 | + * -R reference.fasta \ |
| 60 | + * -V input.vcf.gz \ |
| 61 | + * -truth-vcf platinum-genomes.vcf \ |
| 62 | + * -truth-bed platinum-confident-region.bed \ |
| 63 | + * -tensor-name read_tensor \ |
| 64 | + * -bam-file input.bam \ |
| 65 | + * -output-tensor-dir my-tensor-folder |
| 66 | + * </pre> |
| 67 | + * |
| 68 | + */ |
| 69 | +@CommandLineProgramProperties( |
| 70 | + summary = "Write variant tensors for training a CNN to filter variants", |
| 71 | + oneLineSummary = "Write variant tensors for training a CNN to filter variants", |
| 72 | + programGroup = VariantEvaluationProgramGroup.class |
| 73 | +) |
| 74 | +@DocumentedFeature |
| 75 | +@ExperimentalFeature |
| 76 | +public class CNNVariantWriteTensors extends CommandLineProgram { |
| 77 | + |
| 78 | + @Argument(fullName = StandardArgumentDefinitions.REFERENCE_LONG_NAME, |
| 79 | + shortName = StandardArgumentDefinitions.REFERENCE_SHORT_NAME, |
| 80 | + doc = "Reference fasta file.") |
| 81 | + private String reference; |
| 82 | + |
| 83 | + @Argument(fullName = StandardArgumentDefinitions.VARIANT_LONG_NAME, |
| 84 | + shortName = StandardArgumentDefinitions.VARIANT_SHORT_NAME, |
| 85 | + doc = "Input VCF file") |
| 86 | + private String inputVcf; |
| 87 | + |
| 88 | + @Argument(fullName = "output-tensor-dir", shortName = "output-tensor-dir", doc = "Directory of training tensors. Subdivided into train, valid and test sets.") |
| 89 | + private String outputTensorsDir; |
| 90 | + |
| 91 | + @Argument(fullName = "truth-vcf", shortName = "truth-vcf", doc = "Validated VCF file.") |
| 92 | + private String truthVcf; |
| 93 | + |
| 94 | + @Argument(fullName = "truth-bed", shortName = "truth-bed", doc = "Confident region of the validated VCF file.") |
| 95 | + private String truthBed; |
| 96 | + |
| 97 | + @Argument(fullName = "bam-file", shortName = "bam-file", doc = "BAM or BAMout file to use for read data when generating 2D tensors.", optional = true) |
| 98 | + private String bamFile = ""; |
| 99 | + |
| 100 | + @Argument(fullName = "tensor-type", shortName = "tensor-type", doc = "Name of the tensors to generate.") |
| 101 | + private TensorType tensorType = TensorType.reference; |
| 102 | + |
| 103 | + @Advanced |
| 104 | + @Argument(fullName = "channels-last", shortName = "channels-last", doc = "Store the channels in the last axis of tensors, tensorflow->true, theano->false", optional = true) |
| 105 | + private boolean channelsLast = true; |
| 106 | + |
| 107 | + @Advanced |
| 108 | + @Argument(fullName = "annotation-set", shortName = "annotation-set", doc = "Which set of annotations to use.", optional = true) |
| 109 | + private String annotationSet = "best_practices"; |
| 110 | + |
| 111 | + @Argument(fullName = "max-tensors", shortName = "max-tensors", doc = "Maximum number of tensors to write.", optional = true, minValue = 0) |
| 112 | + private int maxTensors = 1000000; |
| 113 | + |
| 114 | + // Start the Python executor. This does not actually start the Python process, but fails if python can't be located |
| 115 | + final PythonScriptExecutor pythonExecutor = new PythonScriptExecutor(true); |
| 116 | + |
| 117 | + @Override |
| 118 | + protected void onStartup() { |
| 119 | + PythonScriptExecutor.checkPythonEnvironmentForPackage("vqsr_cnn"); |
| 120 | + } |
| 121 | + |
| 122 | + @Override |
| 123 | + protected Object doWork() { |
| 124 | + final Resource pythonScriptResource = new Resource("training.py", FilterVariantTranches.class); |
| 125 | + List<String> arguments = new ArrayList<>(Arrays.asList( |
| 126 | + "--reference_fasta", reference, |
| 127 | + "--input_vcf", inputVcf, |
| 128 | + "--bam_file", bamFile, |
| 129 | + "--train_vcf", truthVcf, |
| 130 | + "--bed_file", truthBed, |
| 131 | + "--tensor_name", tensorType.name(), |
| 132 | + "--annotation_set", annotationSet, |
| 133 | + "--samples", Integer.toString(maxTensors), |
| 134 | + "--data_dir", outputTensorsDir)); |
| 135 | + |
| 136 | + if(channelsLast){ |
| 137 | + arguments.add("--channels_last"); |
| 138 | + } else{ |
| 139 | + arguments.add("--channels_first"); |
| 140 | + } |
| 141 | + |
| 142 | + if (tensorType == TensorType.reference) { |
| 143 | + arguments.addAll(Arrays.asList("--mode", "write_reference_and_annotation_tensors")); |
| 144 | + } else if (tensorType == TensorType.read_tensor) { |
| 145 | + arguments.addAll(Arrays.asList("--mode", "write_read_and_annotation_tensors")); |
| 146 | + } else { |
| 147 | + throw new GATKException("Unknown tensor mapping mode:"+ tensorType.name()); |
| 148 | + } |
| 149 | + |
| 150 | + logger.info("Args are:"+ Arrays.toString(arguments.toArray())); |
| 151 | + final boolean pythonReturnCode = pythonExecutor.executeScript( |
| 152 | + pythonScriptResource, |
| 153 | + null, |
| 154 | + arguments |
| 155 | + ); |
| 156 | + return pythonReturnCode; |
| 157 | + } |
| 158 | + |
| 159 | +} |
0 commit comments