Skip to content

Commit 63a0477

Browse files
committed
Fix SQL caller unshim compile fallout
1 parent f714012 commit 63a0477

12 files changed

Lines changed: 39 additions & 27 deletions

File tree

sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGetJsonObject.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ class GetJsonObjectCombiner(private val exp: GpuGetJsonObject) extends GpuExpres
280280
override def addExpression(e: Expression): Unit = {
281281
val localOutputLocation = outputLocation
282282
outputLocation += 1
283-
val key = GpuExpressionEquals(e)
283+
val key = new GpuExpressionEquals(e)
284284
if (!toCombine.contains(key)) {
285285
toCombine.put(key, localOutputLocation)
286286
}
@@ -329,7 +329,7 @@ class GetJsonObjectCombiner(private val exp: GpuGetJsonObject) extends GpuExpres
329329
}
330330

331331
override def getReplacementExpression(e: Expression): Option[Expression] = {
332-
toCombine.get(GpuExpressionEquals(e)).map { localId =>
332+
toCombine.get(new GpuExpressionEquals(e)).map { localId =>
333333
GpuGetStructField(multiGet, localId, Some(fieldName(localId)))
334334
}
335335
}

sql-plugin/src/main/scala/com/nvidia/spark/rapids/literals.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
2+
* Copyright (c) 2019-2026, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ import com.nvidia.spark.rapids.Arm.withResource
3131
import com.nvidia.spark.rapids.RapidsPluginImplicits.AutoCloseableProducingArray
3232
import com.nvidia.spark.rapids.shims.{GpuTypeShims, SparkShimImpl}
3333
import org.apache.commons.codec.binary.{Hex => ApacheHex}
34-
import org.json4s.JsonAST.{JField, JNull, JString}
34+
import org.json4s.JsonAST.{JField, JNull, JString, JValue}
3535

3636
import org.apache.spark.internal.Logging
3737
import org.apache.spark.sql.catalyst.InternalRow
@@ -685,7 +685,7 @@ case class GpuLiteral (value: Any, dataType: DataType) extends GpuLeafExpression
685685
case (l: Long, TimestampType) => JString(DateTimeUtils.toJavaTimestamp(l).toString)
686686
case (other, _) => JString(other.toString)
687687
}
688-
("value" -> jsonValue) :: ("dataType" -> TrampolineUtil.jsonValue(dataType)) :: Nil
688+
("value" -> jsonValue) :: ("dataType" -> TrampolineUtil.jsonValue(dataType).asInstanceOf[JValue]) :: Nil
689689
}
690690

691691
override def sql: String = (value, dataType) match {

sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/BucketingUtilsShim.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ object BucketingUtilsShim {
4343
// table and a normal one.
4444
val bucketIdExpression = GpuHashPartitioning(bucketColumns, spec.numBuckets)
4545
.partitionIdExpression
46-
GpuWriterBucketSpec(bucketIdExpression, (_: Int) => "")
46+
new GpuWriterBucketSpec(bucketIdExpression, (_: Int) => "")
4747
}
4848
}
4949
}

sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/TreeNode.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
package com.nvidia.spark.rapids.shims
1818

19-
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, TernaryExpression, UnaryExpression}
19+
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, Predicate, TernaryExpression, UnaryExpression}
2020
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryCommand}
2121
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan, UnaryExecNode}
2222
import org.apache.spark.sql.execution.command.DataWritingCommand
@@ -49,6 +49,10 @@ trait ShimTernaryExpression extends TernaryExpression {
4949
}
5050
}
5151

52+
trait ShimPredicate extends Predicate {
53+
def contextIndependentFoldable: Boolean = children.forall(_.foldable)
54+
}
55+
5256
trait ShimSparkPlan extends SparkPlan {
5357
override def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan = {
5458
legacyWithNewChildren(newChildren)

sql-plugin/src/main/scala/com/nvidia/spark/rapids/shims/parquet/ParquetSchemaClipShims.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ object ParquetSchemaClipShims {
107107
val scale = decimalLogicalTypeAnnotation.getScale
108108

109109
if (!(maxPrecision == -1 || 1 <= precision && precision <= maxPrecision)) {
110-
throw new RapidsAnalysisException(s"Invalid decimal precision: $typeName " +
110+
throw RapidsAnalysisException(s"Invalid decimal precision: $typeName " +
111111
s"cannot store $precision digits (max $maxPrecision)")
112112
}
113113

@@ -166,14 +166,14 @@ object ParquetSchemaClipShims {
166166
ParquetTimestampAnnotationShims.timestampTypeForMillisOrMicros(timestamp)
167167
case timestamp: TimestampLogicalTypeAnnotation if timestamp.getUnit == TimeUnit.NANOS &&
168168
ParquetLegacyNanoAsLongShims.legacyParquetNanosAsLong =>
169-
throw new RapidsAnalysisException(
169+
throw RapidsAnalysisException(
170170
"GPU does not support spark.sql.legacy.parquet.nanosAsLong")
171171
case _ => illegalType()
172172
}
173173

174174
case INT96 =>
175175
if (!SQLConf.get.isParquetINT96AsTimestamp) {
176-
throw new RapidsAnalysisException(
176+
throw RapidsAnalysisException(
177177
"INT96 is not supported unless it's interpreted as timestamp. " +
178178
s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.")
179179
}

sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/execution/python/shims/PythonArgumentUtils.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,6 @@ object PythonArgumentUtils {
6464
}
6565
}.toArray
6666
}.toArray
67-
GpuPythonArguments(allInputs.toSeq, dataTypes.toSeq, argOffsets, None)
67+
new GpuPythonArguments(allInputs.toSeq, dataTypes.toSeq, argOffsets, None)
6868
}
6969
}

sql-plugin/src/main/spark332db/scala/org/apache/spark/sql/rapids/GpuFileFormatWriter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ trait GpuFileFormatWriterBase extends Serializable with Logging {
533533
private def verifySchema(format: ColumnarFileFormat, schema: StructType): Unit = {
534534
schema.foreach { field =>
535535
if (!format.supportDataType(field.dataType)) {
536-
throw new RapidsAnalysisException(
536+
throw RapidsAnalysisException(
537537
s"$format data source does not support ${field.dataType.catalogString} data type.")
538538
}
539539
}

sql-plugin/src/main/spark350db143/scala/org/apache/spark/sql/rapids/execution/python/shims/PythonArgumentsUtils.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,15 @@ object PythonArgumentUtils {
4848
(None, e)
4949
}
5050
if (allInputs.exists(_.semanticEquals(value))) {
51-
GpuArgumentMeta(allInputs.indexWhere(_.semanticEquals(value)), key)
51+
new GpuArgumentMeta(allInputs.indexWhere(_.semanticEquals(value)), key)
5252
} else {
5353
allInputs += value
5454
dataTypes += value.dataType
55-
GpuArgumentMeta(allInputs.length - 1, key)
55+
new GpuArgumentMeta(allInputs.length - 1, key)
5656
}
5757
}.toArray
5858
}.toArray
59-
GpuPythonArguments(allInputs.toSeq, dataTypes.toSeq,
59+
new GpuPythonArguments(allInputs.toSeq, dataTypes.toSeq,
6060
argMetas.map(_.map(_.offset)), Some(argMetas.map(_.map(_.name))))
6161
}
6262
}

tests/src/test/scala/com/nvidia/spark/rapids/GpuSemaphoreSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2019-2026, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -86,15 +86,15 @@ class GpuSemaphoreSuite extends AnyFunSuite
8686

8787
def assertAcquired(result: TryAcquireResult): Unit = result match {
8888
case SemaphoreAcquired => // NOOP
89-
case AcquireFailed(_) =>
89+
case _: AcquireFailed =>
9090
fail("The Semaphore was not acquired")
9191
}
9292

9393
def assertNotAcquired(numExpectedWaiting: Int, result: TryAcquireResult): Unit = result match {
9494
case SemaphoreAcquired =>
9595
fail("The Semaphore was acquired when we didn't expect it")
96-
case AcquireFailed(numWaiting) =>
97-
assert(numWaiting == numExpectedWaiting, "The number of waiting tasks didn't match")
96+
case failed: AcquireFailed =>
97+
assert(failed.numWaitingTasks == numExpectedWaiting, "The number of waiting tasks didn't match")
9898
}
9999

100100
test("multi tryAcquire") {

tests/src/test/scala/com/nvidia/spark/rapids/GpuSortRetrySuite.scala

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2023-2025, NVIDIA CORPORATION.
2+
* Copyright (c) 2023-2026, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -176,10 +176,14 @@ class GpuSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSugar {
176176
}
177177

178178
test("GPU each batch sort with GpuRetryOOM") {
179-
val eachBatchIter = GpuSortEachBatchIterator(
179+
val eachBatchIter = new GpuSortEachBatchIterator(
180180
batchIter(2),
181181
gpuSorter,
182-
singleBatch = false)
182+
false,
183+
NoopMetric,
184+
NoopMetric,
185+
NoopMetric,
186+
NoopMetric)
183187
RmmSpark.forceRetryOOM(RmmSpark.getCurrentThreadId, 2,
184188
RmmSpark.OomInjectionType.GPU.ordinal, 0)
185189
while (eachBatchIter.hasNext) {
@@ -201,10 +205,14 @@ class GpuSortRetrySuite extends RmmSparkRetrySuiteBase with MockitoSugar {
201205
test("GPU each batch sort throws GpuSplitAndRetryOOM") {
202206
val inputIter = batchIter(2)
203207
try {
204-
val eachBatchIter = GpuSortEachBatchIterator(
208+
val eachBatchIter = new GpuSortEachBatchIterator(
205209
inputIter,
206210
gpuSorter,
207-
singleBatch = false)
211+
false,
212+
NoopMetric,
213+
NoopMetric,
214+
NoopMetric,
215+
NoopMetric)
208216
RmmSpark.forceSplitAndRetryOOM(RmmSpark.getCurrentThreadId, 1,
209217
RmmSpark.OomInjectionType.GPU.ordinal, 0)
210218
assertThrows[GpuSplitAndRetryOOM] {

0 commit comments

Comments
 (0)