Skip to content

Commit 52b43d6

Browse files
authored
Merge pull request #11339 from NVIDIA/merge-branch-24.08-to-main
Merge branch-24.08 into main
2 parents fd331a5 + d60008b commit 52b43d6

File tree

4 files changed

+111
-34
lines changed

4 files changed

+111
-34
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Change log
2-
Generated on 2024-08-12
2+
Generated on 2024-08-16
33

44
## Release 24.08
55

@@ -88,6 +88,8 @@ Generated on 2024-08-12
8888
### PRs
8989
|||
9090
|:---|:---|
91+
|[#11335](https://github.com/NVIDIA/spark-rapids/pull/11335)|Fix Delta Lake truncation of min/max string values|
92+
|[#11304](https://github.com/NVIDIA/spark-rapids/pull/11304)|Update changelog for v24.08.0 release [skip ci]|
9193
|[#11303](https://github.com/NVIDIA/spark-rapids/pull/11303)|Update rapids JNI and private dependency to 24.08.0|
9294
|[#11296](https://github.com/NVIDIA/spark-rapids/pull/11296)|[DOC] update doc for 2408 release [skip CI]|
9395
|[#11309](https://github.com/NVIDIA/spark-rapids/pull/11309)|[Doc ]Update lore doc about the range [skip ci]|

delta-lake/common/src/main/scala/com/nvidia/spark/rapids/delta/GpuStatisticsCollection.scala

Lines changed: 70 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
2+
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
33
*
44
* This file was derived from StatisticsCollection.scala
55
* in the Delta Lake project at https://github.com/delta-io/delta.
@@ -31,7 +31,7 @@ import com.nvidia.spark.rapids.delta.shims.{ShimDeltaColumnMapping, ShimDeltaUDF
3131
import org.apache.spark.sql.{Column, SparkSession}
3232
import org.apache.spark.sql.catalyst.InternalRow
3333
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
34-
import org.apache.spark.sql.functions.{count, lit, max, min, struct, substring, sum, when}
34+
import org.apache.spark.sql.functions.{count, lit, max, min, struct, sum, when}
3535
import org.apache.spark.sql.types._
3636
import org.apache.spark.sql.vectorized.ColumnarBatch
3737

@@ -87,7 +87,9 @@ trait GpuStatisticsCollection extends ShimUsesMetadataFields {
8787
collectStats(MIN, statCollectionSchema) {
8888
// Truncate string min values as necessary
8989
case (c, GpuSkippingEligibleDataType(StringType), true) =>
90-
substring(min(c), 0, stringPrefixLength)
90+
val udfTruncateMin = ShimDeltaUDF.stringStringUdf(
91+
GpuStatisticsCollection.truncateMinStringAgg(prefixLength)_)
92+
udfTruncateMin(min(c))
9193

9294
// Collect all numeric min values
9395
case (c, GpuSkippingEligibleDataType(_), true) =>
@@ -203,25 +205,76 @@ trait GpuStatisticsCollection extends ShimUsesMetadataFields {
203205
}
204206

205207
object GpuStatisticsCollection {
208+
val ASCII_MAX_CHARACTER = '\u007F'
209+
210+
val UTF8_MAX_CHARACTER = new String(Character.toChars(Character.MAX_CODE_POINT))
211+
212+
def truncateMinStringAgg(prefixLen: Int)(input: String): String = {
213+
if (input == null || input.length <= prefixLen) {
214+
return input
215+
}
216+
if (prefixLen <= 0) {
217+
return null
218+
}
219+
if (Character.isHighSurrogate(input.charAt(prefixLen - 1)) &&
220+
Character.isLowSurrogate(input.charAt(prefixLen))) {
221+
// If the character at prefixLen - 1 is a high surrogate and the next character is a low
222+
// surrogate, we need to include the next character in the prefix to ensure that we don't
223+
// truncate the string in the middle of a surrogate pair.
224+
input.take(prefixLen + 1)
225+
} else {
226+
input.take(prefixLen)
227+
}
228+
}
229+
206230
/**
207-
* Helper method to truncate the input string `x` to the given `prefixLen` length, while also
208-
* appending the unicode max character to the end of the truncated string. This ensures that any
209-
* value in this column is less than or equal to the max.
231+
* Helper method to truncate the input string `input` to the given `prefixLen` length, while also
232+
* ensuring the any value in this column is less than or equal to the truncated max in UTF-8
233+
* encoding.
210234
*/
211-
def truncateMaxStringAgg(prefixLen: Int)(x: String): String = {
212-
if (x == null || x.length <= prefixLen) {
213-
x
214-
} else {
215-
// Grab the prefix. We want to append `\ufffd` as a tie-breaker, but that is only safe
216-
// if the character we truncated was smaller. Keep extending the prefix until that
217-
// condition holds, or we run off the end of the string.
218-
// scalastyle:off nonascii
219-
val tieBreaker = '\ufffd'
220-
x.take(prefixLen) + x.substring(prefixLen).takeWhile(_ >= tieBreaker) + tieBreaker
221-
// scalastyle:off nonascii
235+
def truncateMaxStringAgg(prefixLen: Int)(originalMax: String): String = {
236+
// scalastyle:off nonascii
237+
if (originalMax == null || originalMax.length <= prefixLen) {
238+
return originalMax
222239
}
240+
if (prefixLen <= 0) {
241+
return null
242+
}
243+
244+
// Grab the prefix. We want to append max Unicode code point `\uDBFF\uDFFF` as a tie-breaker,
245+
// but that is only safe if the character we truncated was smaller in UTF-8 encoded binary
246+
// comparison. Keep extending the prefix until that condition holds, or we run off the end of
247+
// the string.
248+
// We also try to use the ASCII max character `\u007F` as a tie-breaker if possible.
249+
val maxLen = getExpansionLimit(prefixLen)
250+
// Start with a valid prefix
251+
var currLen = truncateMinStringAgg(prefixLen)(originalMax).length
252+
while (currLen <= maxLen) {
253+
if (currLen >= originalMax.length) {
254+
// Return originalMax if we have reached the end of the string
255+
return originalMax
256+
} else if (currLen + 1 < originalMax.length &&
257+
originalMax.substring(currLen, currLen + 2) == UTF8_MAX_CHARACTER) {
258+
// Skip the UTF-8 max character. It occupies two characters in a Scala string.
259+
currLen += 2
260+
} else if (originalMax.charAt(currLen) < ASCII_MAX_CHARACTER) {
261+
return originalMax.take(currLen) + ASCII_MAX_CHARACTER
262+
} else {
263+
return originalMax.take(currLen) + UTF8_MAX_CHARACTER
264+
}
265+
}
266+
267+
// Return null when the input string is too long to truncate.
268+
null
269+
// scalastyle:on nonascii
223270
}
224271

272+
/**
273+
* Calculates the upper character limit when constructing a maximum is not possible with only
274+
* prefixLen chars.
275+
*/
276+
private def getExpansionLimit(prefixLen: Int): Int = 2 * prefixLen
277+
225278
def batchStatsToRow(
226279
schema: StructType,
227280
explodedDataSchema: Map[Seq[String], Int],

integration_tests/src/main/python/delta_lake_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ def json_to_sort_key(j):
123123
jsons.sort(key=json_to_sort_key)
124124
return jsons
125125

126+
def read_delta_logs(spark, path):
    """Read Delta Lake transaction log files matched by `path`.

    Returns a dict mapping each matched log file's basename to the JSON objects
    decoded from that file's contents.
    """
    log_data = spark.sparkContext.wholeTextFiles(path).collect()
    # NOTE(review): assumes _decode_jsons turns one file's raw text into decoded
    # JSON objects — confirm against its definition elsewhere in this module.
    return {os.path.basename(name): _decode_jsons(contents) for name, contents in log_data}
129+
126130
def assert_gpu_and_cpu_delta_logs_equivalent(spark, data_path):
127131
cpu_log_data = spark.sparkContext.wholeTextFiles(data_path + "/CPU/_delta_log/*").collect()
128132
gpu_log_data = spark.sparkContext.wholeTextFiles(data_path + "/GPU/_delta_log/*").collect()

integration_tests/src/main/python/delta_lake_write_test.py

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import json
1516
import pyspark.sql.functions as f
1617
import pytest
17-
import sys
1818

1919
from asserts import *
2020
from data_gen import *
@@ -628,27 +628,45 @@ def gen_bad_data(spark):
628628
@allow_non_gpu(*delta_meta_allow)
@delta_lake
@ignore_order
@pytest.mark.skipif(is_before_spark_320(), reason="Delta Lake writes are not supported before Spark 3.2.x")
def test_delta_write_stat_column_limits(spark_tmp_path):
    # Verifies the min/max string statistics written to the Delta log honor the
    # configured 8-character prefix truncation, including surrogate-pair and
    # tie-breaker (ASCII max vs. Unicode max) edge cases.
    data_path = spark_tmp_path + "/DELTA_DATA"
    confs = copy_and_update(delta_writes_enabled_conf,
                            {"spark.databricks.io.skipping.stringPrefixLength": 8})
    # maximum unicode codepoint and maximum ascii character
    umax, amax = chr(1114111), chr(0x7f)
    expected_min = {"a": "abcdefgh", "b": "abcdefg�", "c": "abcdefgh",
                    "d": "abcdefgh", "e": umax * 4, "f": umax * 4}
    # no max expected for column f since it cannot be truncated to 8 characters and remain
    # larger than the original value
    expected_max = {"a": "bcdefghi", "b": "bcdefgh�", "c": "bcdefghi" + amax,
                    "d": "bcdefghi" + umax, "e": umax * 8}

    def write_table(spark, path):
        rows = [("bcdefghi", "abcdefg�", "bcdefghijk", "abcdefgh�", umax * 4, umax * 9),
                ("abcdefgh", "bcdefgh�", "abcdefghij", "bcdefghi�", umax * 8, umax * 9)]
        schema = "a string, b string, c string, d string, e string, f string"
        spark.createDataFrame(rows, schema).repartition(1).write.format("delta").save(path)

    def verify_stat_limits(spark):
        log_data = read_delta_logs(spark, data_path + "/GPU/_delta_log/*.json")
        assert len(log_data) == 1, "GPU should generate exactly one Delta log"
        json_objs = list(log_data.values())[0]
        json_adds = [x["add"] for x in json_objs if "add" in x]
        assert len(json_adds) == 1, "GPU should only generate a single add in Delta log"
        stats = json.loads(json_adds[0]["stats"])
        actual_min = stats["minValues"]
        assert expected_min == actual_min, \
            f"minValues mismatch, expected: {expected_min} actual: {actual_min}"
        actual_max = stats["maxValues"]
        assert expected_max == actual_max, \
            f"maxValues stats mismatch, expected: {expected_max} actual: {actual_max}"

    assert_gpu_and_cpu_writes_are_equal_collect(
        write_table,
        read_delta_path,
        data_path,
        conf=confs)
    # Many Delta Lake versions are missing the fix from https://github.com/delta-io/delta/pull/3430
    # so instead of a full delta log compare with the CPU, focus on the reported statistics on GPU.
    with_cpu_session(verify_stat_limits)
652670

653671
@allow_non_gpu("CreateTableExec", *delta_meta_allow)
654672
@delta_lake

0 commit comments

Comments
 (0)