Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,18 @@ import org.apache.spark.sql.{Column, DataFrame}

object ACMGImplicits {

val variantColumns = Array("chromosome", "start", "end", "reference", "alternate")

private def validateRequiredColumns(map: Map[DataFrame, (String, Array[String])], criteriaName: String = "criteria"): Unit = {
map.foreach {
case (df, (dfName, columns)) => columns.foreach(
col => require(
df.columns.contains(col),
s"Column `$col` is required in DataFrame $dfName for $criteriaName.")
)
}
}

implicit class ACMGOperations(df: DataFrame) {

/**
Expand Down Expand Up @@ -46,6 +58,77 @@ object ACMGImplicits {
}
}

}
/**
* PM2 - ACMG criteria
* Moderate (supporting) evidence of pathogenic impact.
* Absent from controls (or at extremely low frequency if recessive) in Exome Sequencing Project, 1000 Genomes
* Project, or Exome Aggregration Consortium.
*
* WIP
*
*/
def getPM2(omim: DataFrame, frequencies: DataFrame): DataFrame = {

val map = Map(
df -> ("df", Array("symbol") ++ variantColumns),
omim -> ("omim", Array("symbols", "phenotype")),
frequencies -> ("frequencies", Array("external_frequencies", "genes_symbol") ++ variantColumns)
)
validateRequiredColumns(map, "PM2")

// Extracting inheritance, identifying recessive genes
val inheritanceModes = List(
"Pseudoautosomal recessive",
"Autosomal recessive",
"Digenic recessive",
"X-linked recessive")

val omimRecessive = omim.select("symbols", "phenotype.inheritance")
.withColumn("is_recessive", inheritanceModes.map(m => array_contains(col("inheritance"), m)).reduce(_ || _))
.select(col("is_recessive"), explode(col("symbols")).as("symbol"))
.filter(col("is_recessive") === true)
.distinct()


// Extracting frequency (and lack of)
val maxAf = frequencies.schema("external_frequencies").dataType match {
case s: StructType => {
val afCols = s.fields.map(_.name).map { field =>
struct(col(s"external_frequencies.$field.af") as "v", lit(field) as "k")
}
greatest(afCols: _*).getItem("v").as("max_af")
}
}

val freqPerSymbol = frequencies.select(
col("chromosome"),
col("start"),
col("end"),
col("reference"),
col("alternate"),
explode(col("genes_symbol")).as("symbol"),
maxAf,
maxAf.isNull.as("max_af_is_null")
)

df
.join(omimRecessive, Seq("symbol"), "leftouter")
.na.fill(false, Seq("is_recessive"))
.join(freqPerSymbol, Seq("chromosome", "start", "end", "reference", "alternate", "symbol"), "leftouter")
.na.fill(false, Seq("max_af_is_null"))
.na.fill(0, Seq("max_af"))
.withColumn("PM2",
struct(
col("is_recessive").as("is_recessive"),
col("max_af").as("max_af"),
col("max_af_is_null").as("max_af_is_null"),
(col("max_af_is_null") ||
col("max_af") === 0 ||
(col("is_recessive") && col("max_af") < 0.0001)).as("score")
)
)
.drop("is_recessive", "max_af", "max_af_is_null")

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@ class ACMGImplicitsSpec extends AnyFlatSpec with WithSparkSession with Matchers

spark.sparkContext.setLogLevel("ERROR")

val variantSchema = new StructType()
.add("chromosome", StringType, true)
.add("start", IntegerType, true)
.add("end", IntegerType, true)
.add("reference", StringType, true)
.add("alternate", StringType, true)

def ba1Fixture = {
new {
val querySchema = new StructType()
Expand Down Expand Up @@ -46,7 +53,7 @@ class ACMGImplicitsSpec extends AnyFlatSpec with WithSparkSession with Matchers
}
}

"get_BA1" should "throw IllegalArgumentException if `external_frequencies` column is absent" in {
"getBA1" should "throw IllegalArgumentException if `external_frequencies` column is absent" in {
val structureData = Seq(Row(1), Row(2))
val structureSchema = new StructType().add("start", IntegerType, true)

Expand All @@ -65,4 +72,120 @@ class ACMGImplicitsSpec extends AnyFlatSpec with WithSparkSession with Matchers
f.result.collect() should contain theSameElementsAs f.resultData
}


def pm2Fixture = {
new {
val omimSchema = new StructType()
.add("symbols", new ArrayType(StringType, true), true)
.add("phenotype", new StructType()
.add("inheritance", new ArrayType(StringType, true), true)
)

val omimData = Seq(
Row(Array("gene1", "gene2"), Row(Array("Digenic recessive"))),
Row(Array("gene3"), Row(Array("Autosomal Recessive"))),
Row(Array("gene4"), Row(Array("Autosomal Dominant")))
)

val omimDF = spark.createDataFrame(spark.sparkContext.parallelize(omimData), omimSchema)

val freqSchema = variantSchema
.add("genes_symbol", new ArrayType(StringType, true), true)
.add("external_frequencies", new StructType()
.add("thousand_genomes", new StructType()
.add("af", DoubleType, true)
.add("an", IntegerType, true))
.add("topmed_bravo", new StructType()
.add("af", DoubleType, true)
.add("an", IntegerType, true)))

val freqData = Seq(Row("1", 1, 2, "A", "C", Array("gene1"), null))

val freqDF = spark.createDataFrame(spark.sparkContext.parallelize(freqData), freqSchema)

val querySchema = variantSchema
.add("symbol", StringType, true)

val queryDF = spark.createDataFrame(spark.sparkContext.parallelize(queryData), querySchema)

val queryData = Seq(Row("1", 1, 2, "A", "C", "gene1"))

val resultData = Seq(Row("1", 1, 2, "A", "C", "gene1", Row(true, 0.01, true, false)))

val result = queryDF.getPM2(omimDF, freqDF)
}
}

"getPM2" should "throw IllegalArgumentException if `phenotype` column is absent from the OMIM DataFrame" in {
val f = pm2Fixture

an[IllegalArgumentException] should be thrownBy f.queryDF.getPM2(f.omimDF.drop("phenotype"), f.freqDF)
}

it should "return the correct PM2 schema" in {
val f = pm2Fixture

f.result.schema shouldBe f.querySchema
.add("PM2", new StructType()
.add("is_recessive", BooleanType, false)
.add("max_af", DoubleType, false)
.add("max_af_is_null", BooleanType, false)
.add("score", BooleanType, false), false
)
}

it should "return missing frequencies as PM2 true" in {
val f = pm2Fixture

val freqData = Seq(
Row("1", 1, 2, "A", "C", Array("gene1"), null),
Row("2", 1, 2, "A", "C", Array("gene1"), Row(null, Row(0.0, 0))),
Row("3", 1, 2, "A", "C", Array("gene1"), Row(Row(0.0, 1200), Row(0.0, 1000))),
)
val freqDF = spark.createDataFrame(spark.sparkContext.parallelize(freqData), f.freqSchema)

val queryData = Seq(
Row("1", 1, 2, "A", "C", "gene1"),
Row("2", 1, 2, "A", "C", "gene1"),
Row("3", 1, 2, "A", "C", "gene1"),
)

val resultData = Seq(
Row("1", 1, 2, "A", "C", "gene1", Row(true, 0.00, true, true)),
Row("2", 1, 2, "A", "C", "gene1", Row(true, 0.00, false, true)),
Row("3", 1, 2, "A", "C", "gene1", Row(true, 0.00, false, true)),
)
val queryDF = spark.createDataFrame(spark.sparkContext.parallelize(queryData), f.querySchema)
val result = queryDF.getPM2(f.omimDF, freqDF)

result.collect() should contain theSameElementsAs resultData
}

it should "return low AF in genes with recessive disease as PM2 true " in {
val f = pm2Fixture

val freqData = Seq(
Row("1", 1, 2, "A", "C", Array("gene4"), Row(Row(0.00001, 1), Row(0.0, 1))),
Row("1", 1, 2, "A", "C", Array("gene2"), Row(Row(0.00001, 1), Row(0.0, 1))),
Row("2", 1, 2, "A", "C", Array("gene2"), Row(Row(0.001, 1), Row(0.0, 1))),
)
val freqDF = spark.createDataFrame(spark.sparkContext.parallelize(freqData), f.freqSchema)

val queryData = Seq(
Row("1", 1, 2, "A", "C", "gene4"),
Row("1", 1, 2, "A", "C", "gene2"),
Row("2", 1, 2, "A", "C", "gene2"),
)

val resultData = Seq(
Row("1", 1, 2, "A", "C", "gene4", Row(false, 0.00001, false, false)),
Row("1", 1, 2, "A", "C", "gene2", Row(true, 0.00001, false, true)),
Row("2", 1, 2, "A", "C", "gene2", Row(true, 0.001, false, false)),
)
val queryDF = spark.createDataFrame(spark.sparkContext.parallelize(queryData), f.querySchema)
val result = queryDF.getPM2(f.omimDF, freqDF)

result.collect() should contain theSameElementsAs resultData
}

}