Skip to content

Commit 9d1aa06

Browse files
committed
make changes suggested in PR
1 parent 06065c6 commit 9d1aa06

File tree

1 file changed

+151
-109
lines changed

1 file changed

+151
-109
lines changed

src/main/scala/com/fulcrumgenomics/vcf/DownsampleVcf.scala

+151-109
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@ import scala.math.log10
1515
import scala.util.Random
1616

1717
object DownsampleVcf extends LazyLogging {
18-
/** Removes variants that are within a specified distance from a previous variant
19-
* The end position of the current variant is compared with the start position of the following variant
20-
*
18+
/** Removes variants that are within a specified distance from a previous variant.
19+
* The end position of the current variant is compared with the start position of the following variant.
2120
* @param variants an iterator of the variants to process
2221
* @param windowSize the interval (exclusive) in which to check for additional variants.
2322
* windowSize considers the distance between the end position of a variant
@@ -32,7 +31,7 @@ object DownsampleVcf extends LazyLogging {
3231

3332
def hasNext: Boolean = iter.hasNext
3433

35-
def isInOrder(current: Variant, next: Variant, currentIndex: Int, nextIndex: Int): Boolean = {
34+
private def isInOrder(current: Variant, next: Variant, currentIndex: Int, nextIndex: Int): Boolean = {
3635
(currentIndex < nextIndex) || (currentIndex == nextIndex && current.end <= next.pos)
3736
}
3837

@@ -52,32 +51,33 @@ object DownsampleVcf extends LazyLogging {
5251
}
5352
}
5453

55-
/** Downsamples variants using Allele Depths
56-
*
54+
/** Downsamples variants by randomly sampling the total allele depths at the given proportion.
5755
* @param oldAds an indexed seq of the original allele depths
5856
* @param proportion the proportion to use for downsampling,
5957
* calculated using total base count from the index and a target base count
6058
* @return a new IndexedSeq of allele depths of the same length as `oldAds`
6159
*/
62-
def downsampleADs(oldAds: IndexedSeq[Int], proportion: Double, random: Random): IndexedSeq[Int] = {
60+
def downsampleADs(oldAds: IterableOnce[Int], proportion: Double, random: Random): IndexedSeq[Int] = {
6361
require(proportion <= 1, f"proportion must be less than 1: proportion = ${proportion}")
64-
oldAds.map(s => Range(0, s).iterator.map(_ => random.nextDouble()).count(_ < proportion))
62+
oldAds.iterator.toIndexedSeq.map(s => Range(0, s).iterator.map(_ => random.nextDouble()).count(_ < proportion))
6563
}
6664

6765
/**
68-
* Does the downsampling on a Variant
69-
* @param variant the variant with the genotype to downsample
70-
* @param proportions a map of downsampling target proportions for each sample
66+
* Re-genotypes a variant for each sample after downsampling the allele counts based on the given
67+
* per-sample proportions.
68+
* @param variant the variant to downsample and re-genotype
69+
* @param proportions proportion to downsample the allele counts for each sample prior to re-genotyping
7170
* @param random random number generator for downsampling
72-
* @param epsilon the error rate for genotyping
73-
* @return a new variant with updated genotypes
71+
* @param epsilon the sequencing error rate for genotyping
72+
* @return a new variant with updated genotypes, downsampled ADs, and recomputed PLs
7473
*/
75-
// Returns a new variant that has downsampled ADs, recomputed PLs and updated genotypes
76-
def downsampleAndRegenotype(variant: Variant, proportions: Map[String, Double], random: Random, epsilon: Double = 0.01): Variant = {
74+
def downsampleAndRegenotype(variant: Variant,
75+
proportions: Map[String, Double],
76+
random: Random, epsilon: Double=0.01): Variant = {
7777
try {
78-
variant.copy(genotypes = variant.genotypes.map { case (sample, gt) =>
78+
variant.copy(genotypes=variant.genotypes.map { case (sample, gt) =>
7979
val proportion = proportions(sample)
80-
sample -> downsampleAndRegenotype(gt = gt, proportion = proportion, random = random, epsilon = epsilon)
80+
sample -> downsampleAndRegenotype(gt=gt, proportion=proportion, random=random, epsilon=epsilon)
8181
})
8282
} catch {
8383
case e: MatchError => throw new Exception(
@@ -87,15 +87,15 @@ object DownsampleVcf extends LazyLogging {
8787
}
8888

8989
/**
90-
* Does the downsampling on a Genotype
90+
* Re-genotypes a sample after downsampling the allele counts based on the given proportion.
9191
* @param gt the genotype to downsample
92-
* @param proportion the proportion to use for downsampling allele depths
92+
* @param proportion proportion to downsample the allele count prior to re-genotyping
9393
* @param random random number generator for downsampling
94-
* @param epsilon the error rate for genotyping
95-
* @return a new Genotype with updated allele depths, PLs and genotype
94+
* @param epsilon the sequencing error rate for genotyping
95+
* @return a new Genotype with updated allele depths, PLs, and genotype
9696
*/
9797
def downsampleAndRegenotype(gt: Genotype, proportion: Double, random: Random, epsilon: Double): Genotype = {
98-
val oldAds = gt[IndexedSeq[Int]]("AD")
98+
val oldAds = gt.getOrElse[IndexedSeq[Int]]("AD", throw new Exception(s"AD tag not found for sample ${gt.sample}"))
9999
val newAds = downsampleADs(oldAds, proportion, random)
100100
val Seq(aa, ab, bb) = computePls(newAds)
101101
val Seq(alleleA, alleleB) = gt.alleles.toSeq
@@ -106,20 +106,21 @@ object DownsampleVcf extends LazyLogging {
106106
else if (bb < ab && bb < aa) IndexedSeq(alleleB, alleleB)
107107
else IndexedSeq(alleleA, alleleB)
108108
}
109-
gt.copy(attrs = Map("PL" -> IndexedSeq(aa, ab, bb), "AD" -> newAds, "DP" -> newAds.sum), calls = calls)
109+
gt.copy(attrs=Map("PL" -> IndexedSeq(aa, ab, bb), "AD" -> newAds, "DP" -> newAds.sum), calls=calls)
110110
}
111111

112112
/**
113-
* Compute the genotype likelihoods given the allele depths.
114-
* @param ads The allele depths to generate likelihoods from
115-
* @return a list of three likelihoods
113+
* Compute the genotype likelihoods given the allele depths, assuming a diploid genotype (i.e.
114+
* two allele depths).
115+
* @param ads The input depths for the two alleles A and B.
116+
* @return a list of three likelihoods for the genotypes AA, AB, and BB.
116117
*/
117118
def computePls(ads: IndexedSeq[Int]): IndexedSeq[Int] = {
119+
require(ads.length == 2, "there must be exactly two allele depths")
118120
val likelihoods = Likelihoods(ads(0), ads(1))
119121
IndexedSeq(likelihoods.aa.round.toInt, likelihoods.ab.round.toInt, likelihoods.bb.round.toInt)
120122
}
121123

122-
123124
object Likelihoods {
124125
/** Computes the likelihoods for each possible genotype.
125126
*
@@ -128,7 +129,7 @@ object DownsampleVcf extends LazyLogging {
128129
* @param epsilon the error rate for genotyping
129130
* @return a new `Likelihood` that has the likelihoods of AA, AB, and BB
130131
*/
131-
def apply(alleleDepthA: Int, alleleDepthB: Int, epsilon: Double = 0.01): Likelihoods = {
132+
def apply(alleleDepthA: Int, alleleDepthB: Int, epsilon: Double=0.01): Likelihoods = {
132133
val aGivenAA = log10(1 - epsilon)
133134
val aGivenBB = log10(epsilon)
134135
val aGivenAB = log10((1 - epsilon) / 2)
@@ -154,102 +155,143 @@ object DownsampleVcf extends LazyLogging {
154155
* @param bb likelihood of BB
155156
*/
156157
case class Likelihoods(aa: Double, ab: Double, bb: Double) {
158+
/**
159+
* Returns the likelihoods as a list of phred-scaled integers (i.e., the value of the PL tag).
160+
* @return a list of phred-scaled likelihoods for AA, AB, BB.
161+
*/
157162
def pls = IndexedSeq(aa.round.toInt, ab.round.toInt, bb.round.toInt)
158163
}
159164
}
160165

161-
@clp(group = ClpGroups.VcfOrBcf, description =
162-
"""
163-
|DownsampleVcf takes a vcf file and metadata with sequencing info and
164-
|1. winnows the vcf to remove variants within a specified distance to each other,
165-
|2. downsamples the variants using the provided allele depths and target base count by
166-
| re-computing/downsampling the allele depths for the new target base count
167-
| and re-computing the genotypes based on the new allele depths
168-
|and writes a new downsampled vcf file.
169-
|For single-sample VCFs, the metadata file can be omitted, and instead you can specify originalBases.
170-
""")
171-
class DownsampleVcf
172-
(@arg(flag = 'i', doc = "The vcf to downsample.") val input: PathToVcf,
173-
@arg(flag = 'm', doc = "Index file with bases per sample.") val metadata: Option[FilePath] = None,
174-
@arg(flag = 'b', doc = "Original number of bases (for single-sample VCF)") val originalBases: Option[Double] = None,
175-
@arg(flag = 'n', doc = "Target number of bases to downsample to.") val downsampleToBases: Double,
176-
@arg(flag = 'o', doc = "Output file name.") val output: PathToVcf,
177-
@arg(flag = 'w', doc = "Winnowing window size.") val windowSize: Int = 150,
178-
@arg(flag = 'e', doc = "Error rate for genotyping.") val epsilon: Double = 0.01,
179-
@arg(flag = 'c', doc = "True to write out no-calls.") val writeNoCall: Boolean = false)
180-
extends FgBioTool {
181-
Io.assertReadable(input)
182-
Io.assertReadable(metadata)
183-
Io.assertCanWriteFile(output)
184-
require(downsampleToBases > 0, "target base count must be greater than zero")
185-
require(windowSize >= 0, "window size must be greater than or equal to zero")
186-
require(0 <= epsilon && epsilon <= 1, "epsilon/error rate must be between 0 and 1")
187-
originalBases match {
188-
case Some(x) =>
189-
require(x > 0, "originalBases must be greater than zero")
190-
require(metadata.isEmpty, "Must pass either originalBases (for single-sample VCF) or metadata, not both")
191-
case None =>
192-
require(metadata.isDefined, "Must pass either originalBases (for single-sample VCF) or metadata, not both")
166+
@clp(group=ClpGroups.VcfOrBcf, description =
167+
"""
168+
|Re-genotypes a VCF after downsampling the allele counts.
169+
|
170+
|The input VCF must have at least one sample.
171+
|
172+
|If the input VCF contains a single sample, the downsampling target may be specified as a
173+
|proportion of the original read depth using `--proportion=(0..1)`, or as the combination of
174+
|the original and target _number of sequenced bases_ (`--originalBases` and
175+
|`--downsampleToBases`). For multi-sample VCFs, the downsampling target must be specified using
176+
|`--downsampleToBases`, and a metadata file with the total number of sequenced bases per sample
177+
|is required as well. The metadata file must follow the
178+
|[[https://www.internationalgenome.org/category/meta-data/] 1000 Genomes index format], but the
179+
|only required columns are `SAMPLE_NAME` and `BASE_COUNT`. A proportion for each sample is
180+
|calculated by dividing the _target number of sequenced bases_ by the _original number of
181+
|sequenced bases_.
182+
|
183+
|The tool first (optionally) winnows the VCF file to remove variants within a distance to each
184+
|other specified by `--window-size` (the default value of `0` disables winnowing). Next, each
185+
|sample at each variant is examined independently. The allele depths per-genotype are randomly
186+
|downsampled given the proportion. The downsampled allele depths are then used to re-compute
187+
|allele likelihoods and produce a new genotype.
188+
|
189+
|The tool outputs a downsampled VCF file with the winnowed variants removed, and with the
190+
|genotype calls and `DP`, `AD`, and `PL` tags updated for each sample at each retained variant.
191+
""")
192+
class DownsampleVcf
193+
(@arg(flag='i', doc="The vcf to downsample.") val input: PathToVcf,
194+
@arg(flag='p', doc="Proportion of bases to retain (for single-sample VCF).") val proportion: Option[Double] = None,
195+
@arg(flag='b', doc="Original number of bases (for single-sample VCF).") val originalBases: Option[Double] = None,
196+
@arg(flag='m', doc="Index file with bases per sample.") val metadata: Option[FilePath] = None,
197+
@arg(flag='n', doc="Target number of bases to downsample to.") val downsampleToBases: Option[Double],
198+
@arg(flag='o', doc="Output file name.") val output: PathToVcf,
199+
@arg(flag='w', doc="Winnowing window size.") val windowSize: Int = 0,
200+
@arg(flag='e', doc="Sequencing Error rate for genotyping.") val epsilon: Double = 0.01,
201+
@arg(flag='c', doc="True to write out no-calls.") val writeNoCall: Boolean = false,
202+
@arg(flag='s', doc="Random seed value.") val seed: Int = 42,
203+
) extends FgBioTool {
204+
Io.assertReadable(input)
205+
Io.assertCanWriteFile(output)
206+
require(windowSize >= 0, "window size must be greater than or equal to zero")
207+
require(0 <= epsilon && epsilon <= 1, "epsilon/error rate must be between 0 and 1")
208+
(proportion, originalBases, metadata, downsampleToBases) match {
209+
case (Some(x), None, None, None) =>
210+
require(x > 0, "proportion must be greater than 0")
211+
require(x < 1, "proportion must be less than 1")
212+
case (None, Some(original), None, Some(target)) =>
213+
require(original > 0, "originalBases must be greater than zero")
214+
require(target > 0, "target base count must be greater than zero")
215+
case (None, None, Some(metadata), Some(target)) =>
216+
Io.assertReadable(metadata)
217+
require(target > 0, "target base count must be greater than zero")
218+
case (None, _, _, None) =>
219+
throw new IllegalArgumentException(
220+
"exactly one of proportion or downsampleToBases must be specified"
221+
)
222+
case _ =>
223+
throw new IllegalArgumentException(
224+
"exactly one of proportion, originalBases, or metadata must be specified"
225+
)
226+
}
227+
228+
override def execute(): Unit = {
229+
val vcf = VcfSource(input)
230+
val proportions = (
231+
(proportion, originalBases, metadata, downsampleToBases) match {
232+
case (Some(x), None, None, None) =>
233+
require(vcf.header.samples.length == 1, "--original-bases requires a single-sample VCF")
234+
LazyList(vcf.header.samples.head -> x)
235+
case (None, Some(original), None, Some(target)) =>
236+
require(vcf.header.samples.length == 1, "--original-bases requires a single-sample VCF")
237+
LazyList(vcf.header.samples.head -> math.min(target / original, 1.0))
238+
case (None, None, Some(metadata), Some(target)) =>
239+
Sample.read(metadata)
240+
.filter(s => vcf.header.samples.contains(s.SAMPLE_NAME))
241+
.map(sample => sample.SAMPLE_NAME -> math.min(target / sample.BASE_COUNT.toDouble, 1.0))
242+
case _ =>
243+
throw new RuntimeException("unexpected parameter combination")
244+
}
245+
).toMap
246+
proportions.foreach { case (s, p) => logger.info(f"Downsampling $s with proportion ${p}%.4f") }
247+
248+
val inputProgress = ProgressLogger(logger, noun="variants read")
249+
val inputVariants = ProgressLogger.ProgressLoggingIterator(vcf.iterator).progress(inputProgress)
250+
val winnowed = if (windowSize > 0) {
251+
val winnowed = winnowVariants(inputVariants, windowSize=windowSize, dict=vcf.header.dict)
252+
val winnowedProgress = ProgressLogger(logger, noun="variants retained")
253+
ProgressLogger.ProgressLoggingIterator(winnowed).progress(winnowedProgress)
254+
} else {
255+
inputVariants
193256
}
257+
val outputVcf = VcfWriter(path=output, header=buildOutputHeader(vcf.header))
194258

195-
override def execute(): Unit = {
196-
val vcf = VcfSource(input)
197-
val progress = ProgressLogger(logger, noun = "variants")
198-
val proportions = (
199-
originalBases match {
200-
case Some(x) =>
201-
require(vcf.header.samples.length == 1, "--original-bases requires a single-sample VCF")
202-
LazyList(vcf.header.samples.head -> math.min(downsampleToBases / x, 1.0))
203-
case _ =>
204-
Sample.read(metadata.getOrElse(throw new RuntimeException))
205-
.filter(s => vcf.header.samples.contains(s.SAMPLE_NAME))
206-
.map(sample => sample.SAMPLE_NAME -> math.min(downsampleToBases / sample.BASE_COUNT.toDouble, 1.0))
207-
}
208-
).toMap
209-
proportions.foreach { case (s, p) => logger.info(f"Downsampling $s with proportion ${p}%.4f") }
210-
211-
val winnowed = if (windowSize > 0) winnowVariants(vcf.iterator, windowSize = windowSize, dict = vcf.header.dict) else vcf.iterator
212-
val outputVcf = VcfWriter(path = output, header = buildOutputHeader(vcf.header))
213-
214-
val random = new Random(42)
215-
winnowed.foreach { v =>
216-
val ds = downsampleAndRegenotype(v, proportions = proportions, random = random, epsilon = epsilon)
217-
if (writeNoCall) {
218-
outputVcf += ds
219-
progress.record(ds)
220-
}
221-
else if (!ds.gts.forall(g => g.isNoCall)) {
222-
outputVcf += ds
223-
progress.record(ds)
224-
}
259+
val progress = ProgressLogger(logger, noun="variants written")
260+
val random = new Random(seed)
261+
winnowed.foreach { v =>
262+
val ds = downsampleAndRegenotype(v, proportions=proportions, random=random, epsilon=epsilon)
263+
if (writeNoCall || !ds.gts.forall(g => g.isNoCall)) {
264+
outputVcf += ds
265+
progress.record(ds)
225266
}
226-
227-
progress.logLast()
228-
vcf.safelyClose()
229-
outputVcf.close()
230267
}
268+
269+
progress.logLast()
270+
vcf.safelyClose()
271+
outputVcf.close()
272+
}
231273

232-
def buildOutputHeader(in: VcfHeader): VcfHeader = {
233-
val fmts = Seq.newBuilder[VcfFormatHeader]
234-
fmts ++= in.formats
274+
private def buildOutputHeader(in: VcfHeader): VcfHeader = {
275+
val fmts = Seq.newBuilder[VcfFormatHeader]
276+
fmts ++= in.formats
235277

236-
if (!in.format.contains("AD")) {
237-
fmts += VcfFormatHeader(id="AD", count=VcfCount.OnePerAllele, kind=VcfFieldType.Integer, description="Per allele depths.")
238-
}
239-
240-
if (!in.format.contains("DP")) {
241-
fmts += VcfFormatHeader(id="DP", count=VcfCount.Fixed(1), kind=VcfFieldType.Integer, description="Total depth across alleles.")
242-
}
278+
if (!in.format.contains("AD")) {
279+
fmts += VcfFormatHeader(id="AD", count=VcfCount.OnePerAllele, kind=VcfFieldType.Integer, description="Per allele depths.")
280+
}
243281

244-
if (!in.format.contains("PL")) {
245-
fmts += VcfFormatHeader(id="PL", count=VcfCount.OnePerGenotype, kind=VcfFieldType.Integer, description="Per genotype phred scaled likelihoods.")
246-
}
282+
if (!in.format.contains("DP")) {
283+
fmts += VcfFormatHeader(id="DP", count=VcfCount.Fixed(1), kind=VcfFieldType.Integer, description="Total depth across alleles.")
284+
}
247285

248-
in.copy(formats = fmts.result())
286+
if (!in.format.contains("PL")) {
287+
fmts += VcfFormatHeader(id="PL", count=VcfCount.OnePerGenotype, kind=VcfFieldType.Integer, description="Per genotype phred scaled likelihoods.")
249288
}
289+
290+
in.copy(formats=fmts.result())
250291
}
292+
}
251293

252-
object Sample {
294+
private object Sample {
253295
/** Load a set of samples from the 1KG metadata file. */
254296
def read(path: FilePath): Seq[Sample] = {
255297
val lines = Io.readLines(path).dropWhile(_.startsWith("##")).map(line => line.dropWhile(_ == '#'))

0 commit comments

Comments
 (0)