|
1 | 1 | /* |
2 | | - * Copyright (c) 2022-2023, NVIDIA CORPORATION. |
| 2 | + * Copyright (c) 2022-2024, NVIDIA CORPORATION. |
3 | 3 | * |
4 | 4 | * This file was derived from StatisticsCollection.scala |
5 | 5 | * in the Delta Lake project at https://github.com/delta-io/delta. |
@@ -31,7 +31,7 @@ import com.nvidia.spark.rapids.delta.shims.{ShimDeltaColumnMapping, ShimDeltaUDF |
31 | 31 | import org.apache.spark.sql.{Column, SparkSession} |
32 | 32 | import org.apache.spark.sql.catalyst.InternalRow |
33 | 33 | import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute |
34 | | -import org.apache.spark.sql.functions.{count, lit, max, min, struct, substring, sum, when} |
| 34 | +import org.apache.spark.sql.functions.{count, lit, max, min, struct, sum, when} |
35 | 35 | import org.apache.spark.sql.types._ |
36 | 36 | import org.apache.spark.sql.vectorized.ColumnarBatch |
37 | 37 |
|
@@ -87,7 +87,9 @@ trait GpuStatisticsCollection extends ShimUsesMetadataFields { |
87 | 87 | collectStats(MIN, statCollectionSchema) { |
88 | 88 | // Truncate string min values as necessary |
89 | 89 | case (c, GpuSkippingEligibleDataType(StringType), true) => |
90 | | - substring(min(c), 0, stringPrefixLength) |
| 90 | + val udfTruncateMin = ShimDeltaUDF.stringStringUdf( |
| 91 | + GpuStatisticsCollection.truncateMinStringAgg(prefixLength)_) |
| 92 | + udfTruncateMin(min(c)) |
91 | 93 |
|
92 | 94 | // Collect all numeric min values |
93 | 95 | case (c, GpuSkippingEligibleDataType(_), true) => |
@@ -203,25 +205,76 @@ trait GpuStatisticsCollection extends ShimUsesMetadataFields { |
203 | 205 | } |
204 | 206 |
|
205 | 207 | object GpuStatisticsCollection { |
| 208 | + val ASCII_MAX_CHARACTER = '\u007F' |
| 209 | + |
| 210 | + val UTF8_MAX_CHARACTER = new String(Character.toChars(Character.MAX_CODE_POINT)) |
| 211 | + |
| 212 | + def truncateMinStringAgg(prefixLen: Int)(input: String): String = { |
| 213 | + if (input == null || input.length <= prefixLen) { |
| 214 | + return input |
| 215 | + } |
| 216 | + if (prefixLen <= 0) { |
| 217 | + return null |
| 218 | + } |
| 219 | + if (Character.isHighSurrogate(input.charAt(prefixLen - 1)) && |
| 220 | + Character.isLowSurrogate(input.charAt(prefixLen))) { |
| 221 | + // If the character at prefixLen - 1 is a high surrogate and the next character is a low |
| 222 | + // surrogate, we need to include the next character in the prefix to ensure that we don't |
| 223 | + // truncate the string in the middle of a surrogate pair. |
| 224 | + input.take(prefixLen + 1) |
| 225 | + } else { |
| 226 | + input.take(prefixLen) |
| 227 | + } |
| 228 | + } |
| 229 | + |
206 | 230 | /** |
207 | | - * Helper method to truncate the input string `x` to the given `prefixLen` length, while also |
208 | | - * appending the unicode max character to the end of the truncated string. This ensures that any |
209 | | - * value in this column is less than or equal to the max. |
| 231 | + * Helper method to truncate the input string `input` to the given `prefixLen` length, while also |
| 232 | + * ensuring the any value in this column is less than or equal to the truncated max in UTF-8 |
| 233 | + * encoding. |
210 | 234 | */ |
211 | | - def truncateMaxStringAgg(prefixLen: Int)(x: String): String = { |
212 | | - if (x == null || x.length <= prefixLen) { |
213 | | - x |
214 | | - } else { |
215 | | - // Grab the prefix. We want to append `\ufffd` as a tie-breaker, but that is only safe |
216 | | - // if the character we truncated was smaller. Keep extending the prefix until that |
217 | | - // condition holds, or we run off the end of the string. |
218 | | - // scalastyle:off nonascii |
219 | | - val tieBreaker = '\ufffd' |
220 | | - x.take(prefixLen) + x.substring(prefixLen).takeWhile(_ >= tieBreaker) + tieBreaker |
221 | | - // scalastyle:off nonascii |
| 235 | + def truncateMaxStringAgg(prefixLen: Int)(originalMax: String): String = { |
| 236 | + // scalastyle:off nonascii |
| 237 | + if (originalMax == null || originalMax.length <= prefixLen) { |
| 238 | + return originalMax |
222 | 239 | } |
| 240 | + if (prefixLen <= 0) { |
| 241 | + return null |
| 242 | + } |
| 243 | + |
| 244 | + // Grab the prefix. We want to append max Unicode code point `\uDBFF\uDFFF` as a tie-breaker, |
| 245 | + // but that is only safe if the character we truncated was smaller in UTF-8 encoded binary |
| 246 | + // comparison. Keep extending the prefix until that condition holds, or we run off the end of |
| 247 | + // the string. |
| 248 | + // We also try to use the ASCII max character `\u007F` as a tie-breaker if possible. |
| 249 | + val maxLen = getExpansionLimit(prefixLen) |
| 250 | + // Start with a valid prefix |
| 251 | + var currLen = truncateMinStringAgg(prefixLen)(originalMax).length |
| 252 | + while (currLen <= maxLen) { |
| 253 | + if (currLen >= originalMax.length) { |
| 254 | + // Return originalMax if we have reached the end of the string |
| 255 | + return originalMax |
| 256 | + } else if (currLen + 1 < originalMax.length && |
| 257 | + originalMax.substring(currLen, currLen + 2) == UTF8_MAX_CHARACTER) { |
| 258 | + // Skip the UTF-8 max character. It occupies two characters in a Scala string. |
| 259 | + currLen += 2 |
| 260 | + } else if (originalMax.charAt(currLen) < ASCII_MAX_CHARACTER) { |
| 261 | + return originalMax.take(currLen) + ASCII_MAX_CHARACTER |
| 262 | + } else { |
| 263 | + return originalMax.take(currLen) + UTF8_MAX_CHARACTER |
| 264 | + } |
| 265 | + } |
| 266 | + |
| 267 | + // Return null when the input string is too long to truncate. |
| 268 | + null |
| 269 | + // scalastyle:on nonascii |
223 | 270 | } |
224 | 271 |
|
| 272 | + /** |
| 273 | + * Calculates the upper character limit when constructing a maximum is not possible with only |
| 274 | + * prefixLen chars. |
| 275 | + */ |
| 276 | + private def getExpansionLimit(prefixLen: Int): Int = 2 * prefixLen |
| 277 | + |
225 | 278 | def batchStatsToRow( |
226 | 279 | schema: StructType, |
227 | 280 | explodedDataSchema: Map[Seq[String], Int], |
|
0 commit comments