Skip to content

Commit 06065c6

Browse files
committed
Add DownsampleVcf tool
1 parent 3a74fd2 commit 06065c6

File tree

2 files changed

+768
-0
lines changed

2 files changed

+768
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,281 @@
1+
package com.fulcrumgenomics.vcf
2+
3+
import com.fulcrumgenomics.FgBioDef._
4+
import com.fulcrumgenomics.commons.io.Io
5+
import com.fulcrumgenomics.commons.util.LazyLogging
6+
import com.fulcrumgenomics.fasta.SequenceDictionary
7+
import com.fulcrumgenomics.sopt.{arg, clp}
8+
import com.fulcrumgenomics.util.{Metric, ProgressLogger}
9+
import com.fulcrumgenomics.vcf.api.Allele.NoCallAllele
10+
import com.fulcrumgenomics.vcf.api.{Allele, Genotype, Variant, VcfCount, VcfFieldType, VcfFormatHeader, VcfHeader, VcfSource, VcfWriter}
11+
import com.fulcrumgenomics.cmdline.{ClpGroups, FgBioTool}
12+
import com.fulcrumgenomics.vcf.DownsampleVcf.{downsampleAndRegenotype, winnowVariants}
13+
14+
import scala.math.log10
15+
import scala.util.Random
16+
17+
object DownsampleVcf extends LazyLogging {

  /** Removes variants that are within a specified distance from a previous variant.
    * The end position of the current variant is compared with the start position of the following variant.
    *
    * @param variants an iterator of the variants to process
    * @param windowSize the interval (exclusive) in which to check for additional variants.
    *                   windowSize considers the distance between the end position of a variant
    *                   with the start position of the following variant
    * @param dict a sequencing dictionary to get contig ordering
    * @return a new iterator of variants with just the variant entries we want to keep
    */
  def winnowVariants(variants: Iterator[Variant], windowSize: Int, dict: SequenceDictionary): Iterator[Variant] = {
    require(windowSize >= 0, f"the windowSize ($windowSize) is negative")
    new Iterator[Variant] {
      private val iter = variants.bufferBetter

      def hasNext: Boolean = iter.hasNext

      /** True when `next` does not start before `current` ends: either a later contig,
        * or the same contig with `next.pos` at or past `current.end`. */
      def isInOrder(current: Variant, next: Variant, currentIndex: Int, nextIndex: Int): Boolean = {
        (currentIndex < nextIndex) || (currentIndex == nextIndex && current.end <= next.pos)
      }

      def next(): Variant = {
        val current      = iter.next()
        val currentIndex = dict(current.chrom).index
        // Drop every following variant that starts within `windowSize` of the end of
        // `current` on the same contig, validating sort order as we go.
        iter.dropWhile { next: Variant =>
          val nextIndex = dict(next.chrom).index
          require(
            isInOrder(current, next, currentIndex, nextIndex),
            f"variants out of order; ${current.chrom}:${current.pos} > ${next.chrom}:${next.pos}")

          currentIndex == nextIndex && next.pos - current.end < windowSize
        }
        current
      }
    }
  }

  /** Downsamples variants using Allele Depths.
    *
    * @param oldAds an indexed seq of the original allele depths
    * @param proportion the proportion to use for downsampling,
    *                   calculated using total base count from the index and a target base count
    * @param random random number generator for downsampling
    * @return a new IndexedSeq of allele depths of the same length as `oldAds`
    */
  def downsampleADs(oldAds: IndexedSeq[Int], proportion: Double, random: Random): IndexedSeq[Int] = {
    // Message corrected to match the check (`<= 1`, i.e. "no greater than 1").
    require(proportion <= 1, f"proportion must be no greater than 1: proportion = ${proportion}")
    // Binomial thinning: each supporting observation is kept independently with
    // probability `proportion`.
    oldAds.map(s => Range(0, s).iterator.map(_ => random.nextDouble()).count(_ < proportion))
  }

  /** Does the downsampling on a Variant: returns a new variant that has downsampled ADs,
    * recomputed PLs and updated genotypes.
    *
    * @param variant the variant with the genotype to downsample
    * @param proportions a map of downsampling target proportions for each sample
    * @param random random number generator for downsampling
    * @param epsilon the error rate for genotyping
    * @return a new variant with updated genotypes
    */
  def downsampleAndRegenotype(variant: Variant, proportions: Map[String, Double], random: Random, epsilon: Double = 0.01): Variant = {
    try {
      variant.copy(genotypes = variant.genotypes.map { case (sample, gt) =>
        val proportion = proportions(sample)
        sample -> downsampleAndRegenotype(gt = gt, proportion = proportion, random = random, epsilon = epsilon)
      })
    } catch {
      // Re-throw with the variant's location so a failure can be traced to its record
      // (the per-genotype code pattern-matches on exactly two alleles).
      case e: MatchError => throw new Exception(
        "processing " + variant.id + " at " + variant.chrom + ":" + variant.pos + "-" + variant.end, e
      )
    }
  }

  /** Does the downsampling on a Genotype.
    *
    * @param gt the genotype to downsample
    * @param proportion the proportion to use for downsampling allele depths
    * @param random random number generator for downsampling
    * @param epsilon the error rate for genotyping
    * @return a new Genotype with updated allele depths, PLs and genotype
    */
  def downsampleAndRegenotype(gt: Genotype, proportion: Double, random: Random, epsilon: Double): Genotype = {
    val oldAds = gt[IndexedSeq[Int]]("AD")
    val newAds = downsampleADs(oldAds, proportion, random)
    // Bug fix: thread the caller's epsilon through; previously `computePls(newAds)` was
    // called without it, so the default error rate was always used regardless of `epsilon`.
    val Seq(aa, ab, bb)       = computePls(newAds, epsilon)
    val Seq(alleleA, alleleB) = gt.alleles.toSeq

    val calls = {
      // All three PLs zero means no remaining evidence: emit a no-call.
      if (aa == 0 && ab == 0 && bb == 0) IndexedSeq(NoCallAllele, NoCallAllele)
      else if (aa < ab && aa < bb) IndexedSeq(alleleA, alleleA)
      else if (bb < ab && bb < aa) IndexedSeq(alleleB, alleleB)
      else IndexedSeq(alleleA, alleleB)
    }
    gt.copy(attrs = Map("PL" -> IndexedSeq(aa, ab, bb), "AD" -> newAds, "DP" -> newAds.sum), calls = calls)
  }

  /** Compute the genotype likelihoods given the allele depths.
    *
    * @param ads the allele depths to generate likelihoods from (ref depth at index 0, alt at index 1)
    * @param epsilon the error rate for genotyping (defaults to 0.01 for backward compatibility)
    * @return a list of three phred-scaled likelihoods, ordered AA, AB, BB
    */
  def computePls(ads: IndexedSeq[Int], epsilon: Double = 0.01): IndexedSeq[Int] = {
    val likelihoods = Likelihoods(ads(0), ads(1), epsilon)
    IndexedSeq(likelihoods.aa.round.toInt, likelihoods.ab.round.toInt, likelihoods.bb.round.toInt)
  }


  object Likelihoods {
    /** Computes the likelihoods for each possible genotype.
      *
      * @param alleleDepthA the reference allele depth
      * @param alleleDepthB the alternate allele depth
      * @param epsilon the error rate for genotyping
      * @return a new `Likelihood` that has the likelihoods of AA, AB, and BB
      */
    def apply(alleleDepthA: Int, alleleDepthB: Int, epsilon: Double = 0.01): Likelihoods = {
      // log10 probabilities of observing an A read under each true genotype.
      val aGivenAA = log10(1 - epsilon)
      val aGivenBB = log10(epsilon)
      val aGivenAB = log10((1 - epsilon) / 2)

      // Phred-scale (×-10) the summed per-read log-likelihoods.
      val rawGlAA = ((alleleDepthA * aGivenAA) + (alleleDepthB * aGivenBB)) * -10
      val rawGlBB = ((alleleDepthA * aGivenBB) + (alleleDepthB * aGivenAA)) * -10
      val rawGlAB = ((alleleDepthA + alleleDepthB) * aGivenAB) * -10

      // Normalize so the most likely genotype has PL 0, per VCF convention.
      val minGL = math.min(math.min(rawGlAA, rawGlAB), rawGlBB)

      Likelihoods(
        aa = rawGlAA - minGL,
        ab = rawGlAB - minGL,
        bb = rawGlBB - minGL
      )
    }
  }

  /** Stores the log10(likelihoods) for all possible bi-allelic genotypes.
    *
    * @param aa likelihood of AA
    * @param ab likelihood of AB
    * @param bb likelihood of BB
    */
  case class Likelihoods(aa: Double, ab: Double, bb: Double) {
    /** Phred-scaled likelihoods rounded to integers, ordered AA, AB, BB. */
    def pls = IndexedSeq(aa.round.toInt, ab.round.toInt, bb.round.toInt)
  }
}
160+
161+
@clp(group = ClpGroups.VcfOrBcf, description =
  """
    |DownsampleVcf takes a vcf file and metadata with sequencing info and
    |1. winnows the vcf to remove variants within a specified distance to each other,
    |2. downsamples the variants using the provided allele depths and target base count by
    |   re-computing/downsampling the allele depths for the new target base count
    |   and re-computing the genotypes based on the new allele depths
    |and writes a new downsampled vcf file.
    |For single-sample VCFs, the metadata file can be omitted, and instead you can specify originalBases.
  """)
class DownsampleVcf
(@arg(flag = 'i', doc = "The vcf to downsample.") val input: PathToVcf,
 @arg(flag = 'm', doc = "Index file with bases per sample.") val metadata: Option[FilePath] = None,
 @arg(flag = 'b', doc = "Original number of bases (for single-sample VCF)") val originalBases: Option[Double] = None,
 @arg(flag = 'n', doc = "Target number of bases to downsample to.") val downsampleToBases: Double,
 @arg(flag = 'o', doc = "Output file name.") val output: PathToVcf,
 @arg(flag = 'w', doc = "Winnowing window size.") val windowSize: Int = 150,
 @arg(flag = 'e', doc = "Error rate for genotyping.") val epsilon: Double = 0.01,
 @arg(flag = 'c', doc = "True to write out no-calls.") val writeNoCall: Boolean = false)
  extends FgBioTool {
  Io.assertReadable(input)
  Io.assertReadable(metadata)
  Io.assertCanWriteFile(output)
  require(downsampleToBases > 0, "target base count must be greater than zero")
  require(windowSize >= 0, "window size must be greater than or equal to zero")
  require(0 <= epsilon && epsilon <= 1, "epsilon/error rate must be between 0 and 1")
  // Exactly one of originalBases (single-sample mode) or metadata (multi-sample mode) must be given.
  originalBases match {
    case Some(x) =>
      require(x > 0, "originalBases must be greater than zero")
      require(metadata.isEmpty, "Must pass either originalBases (for single-sample VCF) or metadata, not both")
    case None =>
      require(metadata.isDefined, "Must pass either originalBases (for single-sample VCF) or metadata, not both")
  }

  override def execute(): Unit = {
    val vcf      = VcfSource(input)
    val progress = ProgressLogger(logger, noun = "variants")
    // Per-sample downsampling proportion = target bases / original bases, capped at 1.0.
    val proportions = (
      originalBases match {
        case Some(x) =>
          require(vcf.header.samples.length == 1, "--original-bases requires a single-sample VCF")
          LazyList(vcf.header.samples.head -> math.min(downsampleToBases / x, 1.0))
        case _ =>
          // The constructor `require` guarantees metadata is defined here; the message
          // documents the invariant instead of throwing a bare RuntimeException.
          Sample.read(metadata.getOrElse(throw new RuntimeException("metadata must be defined when originalBases is not given")))
            .filter(s => vcf.header.samples.contains(s.SAMPLE_NAME))
            .map(sample => sample.SAMPLE_NAME -> math.min(downsampleToBases / sample.BASE_COUNT.toDouble, 1.0))
      }
    ).toMap
    proportions.foreach { case (s, p) => logger.info(f"Downsampling $s with proportion ${p}%.4f") }

    val winnowed  = if (windowSize > 0) winnowVariants(vcf.iterator, windowSize = windowSize, dict = vcf.header.dict) else vcf.iterator
    val outputVcf = VcfWriter(path = output, header = buildOutputHeader(vcf.header))

    // Fixed seed so downsampling is reproducible across runs.
    val random = new Random(42)
    winnowed.foreach { v =>
      val ds = downsampleAndRegenotype(v, proportions = proportions, random = random, epsilon = epsilon)
      // Emit the variant unless every genotype became a no-call and no-calls were not requested.
      // (Previously two branches with identical bodies; folded into one condition.)
      if (writeNoCall || !ds.gts.forall(g => g.isNoCall)) {
        outputVcf += ds
        progress.record(ds)
      }
    }

    progress.logLast()
    vcf.safelyClose()
    outputVcf.close()
  }

  /** Builds the output header, adding AD/DP/PL FORMAT definitions when the input header lacks them. */
  def buildOutputHeader(in: VcfHeader): VcfHeader = {
    val fmts = Seq.newBuilder[VcfFormatHeader]
    fmts ++= in.formats

    if (!in.format.contains("AD")) {
      fmts += VcfFormatHeader(id="AD", count=VcfCount.OnePerAllele, kind=VcfFieldType.Integer, description="Per allele depths.")
    }

    if (!in.format.contains("DP")) {
      fmts += VcfFormatHeader(id="DP", count=VcfCount.Fixed(1), kind=VcfFieldType.Integer, description="Total depth across alleles.")
    }

    if (!in.format.contains("PL")) {
      fmts += VcfFormatHeader(id="PL", count=VcfCount.OnePerGenotype, kind=VcfFieldType.Integer, description="Per genotype phred scaled likelihoods.")
    }

    in.copy(formats = fmts.result())
  }
}
251+
252+
object Sample {
  /** Load a set of samples from the 1KG metadata file. */
  def read(path: FilePath): Seq[Sample] = {
    // Skip the "##" comment lines at the top of the file, then strip the leading '#'
    // from the column-header line so Metric can parse plain column names.
    val cleaned = Io
      .readLines(path)
      .dropWhile(_.startsWith("##"))
      .map(_.dropWhile(_ == '#'))
    Metric.read[Sample](lines = cleaned)
  }
}
259+
260+
/** One row of the 1KG sequence-index metadata file, parsed as a [[Metric]].
  *
  * Field names match the metadata file's column headers exactly (required for
  * `Metric.read` header-based parsing). Only `SAMPLE_NAME` and `BASE_COUNT` are
  * required (no default); DownsampleVcf uses them to compute per-sample
  * downsampling proportions. All other columns default to "." so partial rows
  * still parse.
  */
case class Sample(ENA_FILE_PATH: String = ".",
                  MD5SUM: String = ".",
                  RUN_ID: String = ".",
                  STUDY_ID: String = ".",
                  STUDY_NAME: String = ".",
                  CENTER_NAME: String = ".",
                  SUBMISSION_ID: String = ".",
                  SUBMISSION_DATE: String = ".",
                  SAMPLE_ID: String = ".",
                  SAMPLE_NAME: String,          // required: matched against the VCF's sample names
                  POPULATION: String = ".",
                  EXPERIMENT_ID: String = ".",
                  INSTRUMENT_PLATFORM: String = ".",
                  INSTRUMENT_MODEL: String = ".",
                  LIBRARY_NAME: String = ".",
                  RUN_NAME: String = ".",
                  INSERT_SIZE: String = ".",
                  LIBRARY_LAYOUT: String = ".",
                  PAIRED_FASTQ: String = ".",
                  READ_COUNT: String = ".",
                  BASE_COUNT: Long,             // required: original sequenced base count for the sample
                  ANALYSIS_GROUP: String = ".") extends Metric

0 commit comments

Comments
 (0)