diff --git a/datalake-spark3/src/main/scala/bio/ferlab/datalake/spark3/implicits/ACMGImplicits.scala b/datalake-spark3/src/main/scala/bio/ferlab/datalake/spark3/implicits/ACMGImplicits.scala
index 9def4b43..daf8d29b 100644
--- a/datalake-spark3/src/main/scala/bio/ferlab/datalake/spark3/implicits/ACMGImplicits.scala
+++ b/datalake-spark3/src/main/scala/bio/ferlab/datalake/spark3/implicits/ACMGImplicits.scala
@@ -6,6 +6,26 @@ import org.apache.spark.sql.{Column, DataFrame}
 
 object ACMGImplicits {
 
+  /** Columns identifying a variant that the criteria below require on their input DataFrames. */
+  val variantColumns: Array[String] = Array("chromosome", "start", "end", "reference", "alternate")
+
+  /**
+   * Checks that every DataFrame exposes the columns a criteria depends on.
+   *
+   * @param map          each DataFrame mapped to its display name and its required column names
+   * @param criteriaName name of the criteria, only used in the error message
+   * @throws IllegalArgumentException (through `require`) on the first missing column
+   */
+  private def validateRequiredColumns(map: Map[DataFrame, (String, Array[String])], criteriaName: String = "criteria"): Unit = {
+    map.foreach {
+      case (df, (dfName, columns)) => columns.foreach(
+        required => require(
+          df.columns.contains(required),
+          s"Column `$required` is required in DataFrame $dfName for $criteriaName.")
+      )
+    }
+  }
+
   implicit class ACMGOperations(df: DataFrame) {
 
     /**
@@ -46,6 +66,88 @@ object ACMGImplicits {
       }
     }
 
-  }
+    /**
+     * PM2 - ACMG criteria
+     * Moderate (supporting) evidence of pathogenic impact.
+     * Absent from controls (or at extremely low frequency if recessive) in Exome Sequencing Project, 1000 Genomes
+     * Project, or Exome Aggregation Consortium.
+     *
+     * WIP
+     *
+     * The `PM2.score` flag is true when no allele frequency is known, when the highest known allele frequency is 0,
+     * or when the gene is linked to a recessive disease and the highest allele frequency is below 0.0001.
+     *
+     * @param omim        OMIM DataFrame with `symbols` and `phenotype` (struct holding an `inheritance` array)
+     * @param frequencies DataFrame with the variant columns, `genes_symbol` and an `external_frequencies` struct
+     *                    whose fields each expose an `af` value
+     * @return the input DataFrame with an additional `PM2` struct column
+     * @throws IllegalArgumentException if a required column is missing or `external_frequencies` is not a struct
+     */
+    def getPM2(omim: DataFrame, frequencies: DataFrame): DataFrame = {
+
+      val map = Map(
+        df -> ("df", Array("symbol") ++ variantColumns),
+        omim -> ("omim", Array("symbols", "phenotype")),
+        frequencies -> ("frequencies", Array("external_frequencies", "genes_symbol") ++ variantColumns)
+      )
+      validateRequiredColumns(map, "PM2")
+
+      // Extracting inheritance, identifying recessive genes
+      val inheritanceModes = List(
+        "Pseudoautosomal recessive",
+        "Autosomal recessive",
+        "Digenic recessive",
+        "X-linked recessive")
+
+      val omimRecessive = omim.select("symbols", "phenotype.inheritance")
+        .withColumn("is_recessive", inheritanceModes.map(m => array_contains(col("inheritance"), m)).reduce(_ || _))
+        .select(col("is_recessive"), explode(col("symbols")).as("symbol"))
+        .filter(col("is_recessive") === true)
+        .distinct()
+
+      // Extracting frequency (and lack of)
+      val maxAf = frequencies.schema("external_frequencies").dataType match {
+        case s: StructType =>
+          val afCols = s.fields.map(_.name).map { field =>
+            struct(col(s"external_frequencies.$field.af") as "v", lit(field) as "k")
+          }
+          // `greatest` needs at least two inputs, so a single frequency source is used as is.
+          val best = if (afCols.length == 1) afCols.head else greatest(afCols: _*)
+          best.getItem("v").as("max_af")
+        case other => throw new IllegalArgumentException(
+          s"Column `external_frequencies` must be a struct for PM2, found ${other.simpleString}.")
+      }
+
+      val freqPerSymbol = frequencies.select(
+        col("chromosome"),
+        col("start"),
+        col("end"),
+        col("reference"),
+        col("alternate"),
+        explode(col("genes_symbol")).as("symbol"),
+        maxAf,
+        maxAf.isNull.as("max_af_is_null")
+      )
+
+      df
+        .join(omimRecessive, Seq("symbol"), "leftouter")
+        .na.fill(false, Seq("is_recessive"))
+        .join(freqPerSymbol, Seq("chromosome", "start", "end", "reference", "alternate", "symbol"), "leftouter")
+        .na.fill(false, Seq("max_af_is_null"))
+        .na.fill(0, Seq("max_af"))
+        .withColumn("PM2",
+          struct(
+            col("is_recessive").as("is_recessive"),
+            col("max_af").as("max_af"),
+            col("max_af_is_null").as("max_af_is_null"),
+            (col("max_af_is_null") ||
+              col("max_af") === 0 ||
+              (col("is_recessive") && col("max_af") < 0.0001)).as("score")
+          )
+        )
+        .drop("is_recessive", "max_af", "max_af_is_null")
+
+    }
+  }
 
 }
diff --git a/datalake-spark3/src/test/scala/bio/ferlab/datalake/spark3/implicits/ACMGImplicitsSpec.scala b/datalake-spark3/src/test/scala/bio/ferlab/datalake/spark3/implicits/ACMGImplicitsSpec.scala
index a937cf99..c60685c9 100644
--- a/datalake-spark3/src/test/scala/bio/ferlab/datalake/spark3/implicits/ACMGImplicitsSpec.scala
+++ b/datalake-spark3/src/test/scala/bio/ferlab/datalake/spark3/implicits/ACMGImplicitsSpec.scala
@@ -12,6 +12,13 @@ class ACMGImplicitsSpec extends AnyFlatSpec with WithSparkSession with Matchers
 
   spark.sparkContext.setLogLevel("ERROR")
 
+  val variantSchema = new StructType()
+    .add("chromosome", StringType, true)
+    .add("start", IntegerType, true)
+    .add("end", IntegerType, true)
+    .add("reference", StringType, true)
+    .add("alternate", StringType, true)
+
   def ba1Fixture = {
     new {
       val querySchema = new StructType()
@@ -46,7 +53,7 @@ class ACMGImplicitsSpec extends AnyFlatSpec with WithSparkSession with Matchers
     }
   }
 
-  "get_BA1" should "throw IllegalArgumentException if `external_frequencies` column is absent" in {
+  "getBA1" should "throw IllegalArgumentException if `external_frequencies` column is absent" in {
     val structureData = Seq(Row(1), Row(2))
     val structureSchema = new StructType().add("start", IntegerType, true)
 
@@ -65,4 +72,121 @@ class ACMGImplicitsSpec extends AnyFlatSpec with WithSparkSession with Matchers
 
     f.result.collect() should contain theSameElementsAs f.resultData
   }
+
+  def pm2Fixture = {
+    new {
+      val omimSchema = new StructType()
+        .add("symbols", new ArrayType(StringType, true), true)
+        .add("phenotype", new StructType()
+          .add("inheritance", new ArrayType(StringType, true), true)
+        )
+
+      val omimData = Seq(
+        Row(Array("gene1", "gene2"), Row(Array("Digenic recessive"))),
+        Row(Array("gene3"), Row(Array("Autosomal recessive"))),
+        Row(Array("gene4"), Row(Array("Autosomal dominant")))
+      )
+
+      val omimDF = spark.createDataFrame(spark.sparkContext.parallelize(omimData), omimSchema)
+
+      val freqSchema = variantSchema
+        .add("genes_symbol", new ArrayType(StringType, true), true)
+        .add("external_frequencies", new StructType()
+          .add("thousand_genomes", new StructType()
+            .add("af", DoubleType, true)
+            .add("an", IntegerType, true))
+          .add("topmed_bravo", new StructType()
+            .add("af", DoubleType, true)
+            .add("an", IntegerType, true)))
+
+      val freqData = Seq(Row("1", 1, 2, "A", "C", Array("gene1"), null))
+
+      val freqDF = spark.createDataFrame(spark.sparkContext.parallelize(freqData), freqSchema)
+
+      val querySchema = variantSchema
+        .add("symbol", StringType, true)
+
+      // Declared before `queryDF`: vals initialise in declaration order, the reverse order would pass null.
+      val queryData = Seq(Row("1", 1, 2, "A", "C", "gene1"))
+
+      val queryDF = spark.createDataFrame(spark.sparkContext.parallelize(queryData), querySchema)
+
+      val resultData = Seq(Row("1", 1, 2, "A", "C", "gene1", Row(true, 0.0, true, true)))
+
+      val result = queryDF.getPM2(omimDF, freqDF)
+    }
+  }
+
+  "getPM2" should "throw IllegalArgumentException if `phenotype` column is absent from the OMIM DataFrame" in {
+    val f = pm2Fixture
+
+    an[IllegalArgumentException] should be thrownBy f.queryDF.getPM2(f.omimDF.drop("phenotype"), f.freqDF)
+  }
+
+  it should "return the correct PM2 schema" in {
+    val f = pm2Fixture
+
+    f.result.schema shouldBe f.querySchema
+      .add("PM2", new StructType()
+        .add("is_recessive", BooleanType, false)
+        .add("max_af", DoubleType, false)
+        .add("max_af_is_null", BooleanType, false)
+        .add("score", BooleanType, false), false
+      )
+  }
+
+  it should "return missing frequencies as PM2 true" in {
+    val f = pm2Fixture
+
+    val freqData = Seq(
+      Row("1", 1, 2, "A", "C", Array("gene1"), null),
+      Row("2", 1, 2, "A", "C", Array("gene1"), Row(null, Row(0.0, 0))),
+      Row("3", 1, 2, "A", "C", Array("gene1"), Row(Row(0.0, 1200), Row(0.0, 1000))),
+    )
+    val freqDF = spark.createDataFrame(spark.sparkContext.parallelize(freqData), f.freqSchema)
+
+    val queryData = Seq(
+      Row("1", 1, 2, "A", "C", "gene1"),
+      Row("2", 1, 2, "A", "C", "gene1"),
+      Row("3", 1, 2, "A", "C", "gene1"),
+    )
+
+    val resultData = Seq(
+      Row("1", 1, 2, "A", "C", "gene1", Row(true, 0.00, true, true)),
+      Row("2", 1, 2, "A", "C", "gene1", Row(true, 0.00, false, true)),
+      Row("3", 1, 2, "A", "C", "gene1", Row(true, 0.00, false, true)),
+    )
+    val queryDF = spark.createDataFrame(spark.sparkContext.parallelize(queryData), f.querySchema)
+    val result = queryDF.getPM2(f.omimDF, freqDF)
+
+    result.collect() should contain theSameElementsAs resultData
+  }
+
+  it should "return low AF in genes with recessive disease as PM2 true " in {
+    val f = pm2Fixture
+
+    val freqData = Seq(
+      Row("1", 1, 2, "A", "C", Array("gene4"), Row(Row(0.00001, 1), Row(0.0, 1))),
+      Row("1", 1, 2, "A", "C", Array("gene2"), Row(Row(0.00001, 1), Row(0.0, 1))),
+      Row("2", 1, 2, "A", "C", Array("gene2"), Row(Row(0.001, 1), Row(0.0, 1))),
+    )
+    val freqDF = spark.createDataFrame(spark.sparkContext.parallelize(freqData), f.freqSchema)
+
+    val queryData = Seq(
+      Row("1", 1, 2, "A", "C", "gene4"),
+      Row("1", 1, 2, "A", "C", "gene2"),
+      Row("2", 1, 2, "A", "C", "gene2"),
+    )
+
+    val resultData = Seq(
+      Row("1", 1, 2, "A", "C", "gene4", Row(false, 0.00001, false, false)),
+      Row("1", 1, 2, "A", "C", "gene2", Row(true, 0.00001, false, true)),
+      Row("2", 1, 2, "A", "C", "gene2", Row(true, 0.001, false, false)),
+    )
+    val queryDF = spark.createDataFrame(spark.sparkContext.parallelize(queryData), f.querySchema)
+    val result = queryDF.getPM2(f.omimDF, freqDF)
+
+    result.collect() should contain theSameElementsAs resultData
+  }
+
 }