diff --git a/datalake-spark3/src/main/scala/bio/ferlab/datalake/spark3/utils/DeltaUtils.scala b/datalake-spark3/src/main/scala/bio/ferlab/datalake/spark3/utils/DeltaUtils.scala index 339c131b..96e6b500 100644 --- a/datalake-spark3/src/main/scala/bio/ferlab/datalake/spark3/utils/DeltaUtils.scala +++ b/datalake-spark3/src/main/scala/bio/ferlab/datalake/spark3/utils/DeltaUtils.scala @@ -171,10 +171,18 @@ object DeltaUtils { * +---------------+-----------------------------------+ * }}} */ - def getPartitionValues(path: String)(implicit spark: SparkSession): DataFrame = { + def getPartitionValues(path: String)(implicit spark: SparkSession): DataFrame = + getPartitionValues(getTableStats(path)) + + /** + * @see Overload of [[getPartitionValues]]. + * @param statsDf Pre-computed stats DataFrame from [[getTableStats]]. Pass this to reuse + * an already-loaded snapshot across multiple calls. + */ + def getPartitionValues(statsDf: DataFrame)(implicit spark: SparkSession): DataFrame = { import spark.implicits._ - getTableStats(path) + statsDf .select(explode($"partitionValues") as Seq("partitionColumn", "value")) .groupBy("partitionColumn") .agg(array_sort(collect_set("value")) as "values") @@ -187,10 +195,18 @@ object DeltaUtils { * @param spark Spark session * @return The total number of records in the table. */ - def getNumRecords(path: String)(implicit spark: SparkSession): Long = { + def getNumRecords(path: String)(implicit spark: SparkSession): Long = + getNumRecords(getTableStats(path)) + + /** + * @see Overload of [[getNumRecords]]. + * @param statsDf Pre-computed stats DataFrame from [[getTableStats]]. Pass this to reuse + * an already-loaded snapshot across multiple calls. + */ + def getNumRecords(statsDf: DataFrame)(implicit spark: SparkSession): Long = { import spark.implicits._ - getTableStats(path) + statsDf .select(sum("stats.numRecords")) .as[Long] .head() @@ -219,9 +235,17 @@ object DeltaUtils { * +---------------+--------------+----------+ * }}} */ - def getNumRecordsPerPartition(path: String)(implicit spark: SparkSession): DataFrame = { + def getNumRecordsPerPartition(path: String)(implicit spark: SparkSession): DataFrame = + getNumRecordsPerPartition(getTableStats(path)) + + /** + * @see Overload of [[getNumRecordsPerPartition]]. + * @param statsDf Pre-computed stats DataFrame from [[getTableStats]]. Pass this to reuse + * an already-loaded snapshot across multiple calls. + */ + def getNumRecordsPerPartition(statsDf: DataFrame)(implicit spark: SparkSession): DataFrame = { val numRecords = sum("stats.numRecords") as "numRecords" - getStatPerPartition(getTableStats(path), numRecords) + getStatPerPartition(statsDf, numRecords) } /** @@ -240,12 +264,19 @@ object DeltaUtils { * ) * }}} */ - def getMinValues(path: String)(implicit spark: SparkSession): Map[String, Any] = { - val statsDf = getTableStats(path).select("stats.minValues.*") - val columnNames = statsDf.columns - val minValues = statsDf.columns.map(c => min(c) as c) + def getMinValues(path: String)(implicit spark: SparkSession): Map[String, Any] = + getMinValues(getTableStats(path)) + + /** + * @see Overload of [[getMinValues]]. + * @param statsDf Pre-computed stats DataFrame from [[getTableStats]]. Pass this to reuse + * an already-loaded snapshot across multiple calls. + */ + def getMinValues(statsDf: DataFrame)(implicit spark: SparkSession): Map[String, Any] = { + val minValuesDf = statsDf.select("stats.minValues.*") + val minValues = minValuesDf.columns.map(c => min(c) as c) - getMap[String](statsDf, minValues) + getMap[String](minValuesDf, minValues) } /** @@ -278,8 +309,15 @@ object DeltaUtils { * +---------------+--------------+------------+------------+------------+ * }}} */ - def getMinValuesPerPartition(path: String)(implicit spark: SparkSession): DataFrame = { - val statsDf = getTableStats(path) + def getMinValuesPerPartition(path: String)(implicit spark: SparkSession): DataFrame = + getMinValuesPerPartition(getTableStats(path)) + + /** + * @see Overload of [[getMinValuesPerPartition]]. + * @param statsDf Pre-computed stats DataFrame from [[getTableStats]]. Pass this to reuse + * an already-loaded snapshot across multiple calls. + */ + def getMinValuesPerPartition(statsDf: DataFrame)(implicit spark: SparkSession): DataFrame = { val columnNames = statsDf.select("stats.minValues.*").columns val minValues = columnNames.map(c => min(s"stats.minValues.$c") as c) @@ -302,12 +340,19 @@ object DeltaUtils { * ) * }}} */ - def getMaxValues(path: String)(implicit spark: SparkSession): Map[String, Any] = { - val statsDf = getTableStats(path).select("stats.maxValues.*") - val columnNames = statsDf.columns - val maxValues = columnNames.map(c => max(c) as c) + def getMaxValues(path: String)(implicit spark: SparkSession): Map[String, Any] = + getMaxValues(getTableStats(path)) + + /** + * @see Overload of [[getMaxValues]]. + * @param statsDf Pre-computed stats DataFrame from [[getTableStats]]. Pass this to reuse + * an already-loaded snapshot across multiple calls. + */ + def getMaxValues(statsDf: DataFrame)(implicit spark: SparkSession): Map[String, Any] = { + val maxValuesDf = statsDf.select("stats.maxValues.*") + val maxValues = maxValuesDf.columns.map(c => max(c) as c) - getMap[Any](statsDf, maxValues) + getMap[Any](maxValuesDf, maxValues) } /** @@ -340,9 +385,15 @@ object DeltaUtils { * +---------------+--------------+------------+------------+------------+ * }}} */ - def getMaxValuesPerPartition(path: String)(implicit spark: SparkSession): DataFrame = { + def getMaxValuesPerPartition(path: String)(implicit spark: SparkSession): DataFrame = + getMaxValuesPerPartition(getTableStats(path)) - val statsDf = getTableStats(path) + /** + * @see Overload of [[getMaxValuesPerPartition]]. + * @param statsDf Pre-computed stats DataFrame from [[getTableStats]]. Pass this to reuse + * an already-loaded snapshot across multiple calls. + */ + def getMaxValuesPerPartition(statsDf: DataFrame)(implicit spark: SparkSession): DataFrame = { val columnNames = statsDf.select("stats.maxValues.*").columns val maxValues = columnNames.map(c => max(s"stats.maxValues.$c") as c) @@ -365,12 +416,19 @@ object DeltaUtils { * ) * }}} */ - def getNullCounts(path: String)(implicit spark: SparkSession): Map[String, Long] = { - val statsDf = getTableStats(path).select("stats.nullCount.*") - val columnNames = statsDf.columns - val nullCounts = columnNames.map(c => sum(c) as c) + def getNullCounts(path: String)(implicit spark: SparkSession): Map[String, Long] = + getNullCounts(getTableStats(path)) + + /** + * @see Overload of [[getNullCounts]]. + * @param statsDf Pre-computed stats DataFrame from [[getTableStats]]. Pass this to reuse + * an already-loaded snapshot across multiple calls. + */ + def getNullCounts(statsDf: DataFrame)(implicit spark: SparkSession): Map[String, Long] = { + val nullCountsDf = statsDf.select("stats.nullCount.*") + val nullCounts = nullCountsDf.columns.map(c => sum(c) as c) - getMap[Long](statsDf, nullCounts) + getMap[Long](nullCountsDf, nullCounts) } /** @@ -403,9 +461,15 @@ object DeltaUtils { * +---------------+--------------+-----+---+------------+ * }}} */ - def getNullCountsPerPartition(path: String)(implicit spark: SparkSession): DataFrame = { + def getNullCountsPerPartition(path: String)(implicit spark: SparkSession): DataFrame = + getNullCountsPerPartition(getTableStats(path)) - val statsDf = getTableStats(path) + /** + * @see Overload of [[getNullCountsPerPartition]]. + * @param statsDf Pre-computed stats DataFrame from [[getTableStats]]. Pass this to reuse + * an already-loaded snapshot across multiple calls. + */ + def getNullCountsPerPartition(statsDf: DataFrame)(implicit spark: SparkSession): DataFrame = { val columnNames = statsDf.select("stats.nullCount.*").columns val nullCounts = columnNames.map(c => sum(s"stats.nullCount.$c") as c)