Skip to content

Commit b7a2605

Browse files
committed
Adapt remaining SQL plugin callers to Java helpers
1 parent f302107 commit b7a2605

32 files changed

Lines changed: 292 additions & 239 deletions

sql-plugin/src/main/scala/com/nvidia/spark/rapids/AllocationRetryCoverageTracker.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import java.util.regex.Pattern
2323

2424
import com.nvidia.spark.rapids.Arm.withResource
2525

26-
import org.apache.spark.internal.Logging
2726

2827
/**
2928
* Memory allocation kind for retry coverage tracking.
@@ -62,7 +61,17 @@ object AllocationKind extends Enumeration {
6261
*
6362
* See: https://github.com/NVIDIA/spark-rapids/issues/13672
6463
*/
65-
object AllocationRetryCoverageTracker extends Logging {
64+
object AllocationRetryCoverageTracker {
65+
private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$"))
66+
67+
private def logWarning(msg: => String): Unit = {
68+
log.warn(msg)
69+
}
70+
71+
private def logError(msg: => String, throwable: Throwable): Unit = {
72+
log.error(msg, throwable)
73+
}
74+
6675
import AllocationKind._
6776

6877
// Environment variable to enable retry coverage tracking (debug-only).

sql-plugin/src/main/scala/com/nvidia/spark/rapids/ArrayUtils.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2025-2025, NVIDIA CORPORATION.
2+
* Copyright (c) 2025-2026, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -19,7 +19,7 @@ package com.nvidia.spark.rapids
1919
import org.apache.spark.sql.catalyst.expressions.{ArrayDistinct, Expression}
2020
import org.apache.spark.sql.rapids.GpuArrayDistinct
2121

22-
case class GpuArrayDistinctMeta(
22+
class GpuArrayDistinctMeta(
2323
expr: ArrayDistinct,
2424
override val conf: RapidsConf,
2525
parentMetaOpt: Option[RapidsMeta[_, _, _]],

sql-plugin/src/main/scala/com/nvidia/spark/rapids/DateUtils.scala

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,12 @@
1616

1717
package com.nvidia.spark.rapids
1818

19-
import java.time.LocalDate
19+
import java.time.{Instant, LocalDate}
2020

2121
import scala.collection.mutable.ListBuffer
2222

2323
import ai.rapids.cudf.{DType, Scalar}
2424
import com.nvidia.spark.rapids.VersionUtils.isSpark320OrLater
25-
import com.nvidia.spark.rapids.shims.DateTimeUtilsShims
2625

2726
import org.apache.spark.sql.catalyst.util.DateTimeUtils.localDateToDays
2827
import org.apache.spark.sql.internal.SQLConf
@@ -53,6 +52,11 @@ object DateUtils {
5352

5453
val ONE_SECOND_MICROSECONDS = 1000000
5554

55+
private def currentTimestampMicros: Long = {
56+
val instant = Instant.now()
57+
instant.getEpochSecond * ONE_SECOND_MICROSECONDS + instant.getNano / 1000
58+
}
59+
5660
val ONE_DAY_SECONDS = 86400L
5761

5862
val ONE_DAY_MICROSECONDS = 86400000000L
@@ -80,7 +84,7 @@ object DateUtils {
8084
Map.empty
8185
} else {
8286
val today = currentDate()
83-
val now = DateTimeUtilsShims.currentTimestamp
87+
val now = currentTimestampMicros
8488
Map(
8589
EPOCH -> 0,
8690
NOW -> now / 1000000L,
@@ -94,7 +98,7 @@ object DateUtils {
9498
Map.empty
9599
} else {
96100
val today = currentDate()
97-
val now = DateTimeUtilsShims.currentTimestamp
101+
val now = currentTimestampMicros
98102
Map(
99103
EPOCH -> 0,
100104
NOW -> now,

sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuGetJsonObject.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ class GetJsonObjectCombiner(private val exp: GpuGetJsonObject) extends GpuExpres
280280
override def addExpression(e: Expression): Unit = {
281281
val localOutputLocation = outputLocation
282282
outputLocation += 1
283-
val key = GpuExpressionEquals(e)
283+
val key = new GpuExpressionEquals(e)
284284
if (!toCombine.contains(key)) {
285285
toCombine.put(key, localOutputLocation)
286286
}
@@ -329,7 +329,7 @@ class GetJsonObjectCombiner(private val exp: GpuGetJsonObject) extends GpuExpres
329329
}
330330

331331
override def getReplacementExpression(e: Expression): Option[Expression] = {
332-
toCombine.get(GpuExpressionEquals(e)).map { localId =>
332+
toCombine.get(new GpuExpressionEquals(e)).map { localId =>
333333
GpuGetStructField(multiGet, localId, Some(fieldName(localId)))
334334
}
335335
}

sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuInSet.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2020-2025, NVIDIA CORPORATION.
2+
* Copyright (c) 2020-2026, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -18,14 +18,15 @@ package com.nvidia.spark.rapids
1818

1919
import ai.rapids.cudf.{ColumnVector, DType, Scalar}
2020
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
21+
import com.nvidia.spark.rapids.shims.ShimPredicate
2122

22-
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, Predicate}
23+
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
2324
import org.apache.spark.sql.internal.SQLConf
2425
import org.apache.spark.sql.types.{DoubleType, FloatType}
2526

2627
case class GpuInSet(
2728
child: Expression,
28-
list: Seq[Any]) extends GpuUnaryExpression with Predicate {
29+
list: Seq[Any]) extends GpuUnaryExpression with ShimPredicate {
2930
require(list != null, "list should not be null")
3031

3132
@transient private[this] lazy val hasNull: Boolean = list.contains(null)

sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuMapUtils.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2021-2026, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -177,7 +177,7 @@ object GpuMapUtils {
177177

178178
}
179179

180-
case class GpuMapFromArraysMeta(expr: MapFromArrays,
180+
class GpuMapFromArraysMeta(expr: MapFromArrays,
181181
override val conf: RapidsConf,
182182
override val parent: Option[RapidsMeta[_, _, _]],
183183
rule: DataFromReplacementRule)

sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuMetrics.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import scala.collection.immutable.TreeMap
2121
import com.nvidia.spark.rapids.metrics.GpuBubbleTimerManager
2222

2323
import org.apache.spark.{SparkContext, TaskContext}
24-
import org.apache.spark.internal.Logging
2524
import org.apache.spark.sql.SparkSession
2625
import org.apache.spark.sql.catalyst.expressions.Expression
2726
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -86,7 +85,7 @@ class GpuMetricFactory(metricsConf: MetricsLevel, context: SparkContext) {
8685
createInternal(level, SQLMetrics.createTimingMetric(context, name))
8786
}
8887

89-
object GpuMetric extends Logging {
88+
object GpuMetric {
9089
// Metric names.
9190
val BUFFER_TIME = "bufferTime"
9291
val BUFFER_TIME_BUBBLE = "bufferTimeBubble"

sql-plugin/src/main/scala/com/nvidia/spark/rapids/MemoryChecker.scala

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2025, NVIDIA CORPORATION.
2+
* Copyright (c) 2025-2026, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -25,7 +25,6 @@ import scala.util.{Failure, Success, Try}
2525
import com.nvidia.spark.rapids.Arm.withResource
2626

2727
import org.apache.spark.SparkConf
28-
import org.apache.spark.internal.Logging
2928

3029
trait MemoryChecker {
3130
def getAvailableMemoryBytes(rapidsConf: RapidsConf): Option[Long]
@@ -38,7 +37,19 @@ trait MemoryChecker {
3837
* on which it checks corresponding files, env variables, etc. for memory usage
3938
* and limits.
4039
*/
41-
object MemoryCheckerImpl extends MemoryChecker with Logging {
40+
object MemoryCheckerImpl extends MemoryChecker {
41+
private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$"))
42+
43+
private def logInfo(msg: => String): Unit = {
44+
if (log.isInfoEnabled) {
45+
log.info(msg)
46+
}
47+
}
48+
49+
private def logWarning(msg: => String): Unit = {
50+
log.warn(msg)
51+
}
52+
4253
def main(args: Array[String]): Unit = {
4354
val conf = new RapidsConf(new SparkConf())
4455
println(s"Available memory: ${getAvailableMemoryBytes(conf)} bytes")

sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxRangeWithDoc.scala

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2025, NVIDIA CORPORATION.
2+
* Copyright (c) 2025-2026, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -22,18 +22,17 @@ import scala.collection.mutable
2222

2323
import ai.rapids.cudf.{NvtxColor, NvtxRange}
2424

25-
import org.apache.spark.internal.Logging
26-
27-
object RangeDebugger extends Logging {
25+
object RangeDebugger {
26+
private val log = org.slf4j.LoggerFactory.getLogger(getClass.getName.stripSuffix("$"))
2827
val threadLocalStack = new ThreadLocal[mutable.ArrayStack[NvtxId]] {
2928
override def initialValue(): mutable.ArrayStack[NvtxId] = mutable.ArrayStack[NvtxId]()
3029
}
3130

3231
private def dumpOrderErrorMessage(popped: Option[NvtxId], elem: NvtxId): Unit = {
33-
logError(s"OUT OF ORDER POP of $elem")
34-
logError(s"TOP OF STACK IS ${popped.getOrElse("<nil>")}")
32+
log.error(s"OUT OF ORDER POP of $elem")
33+
log.error(s"TOP OF STACK IS ${popped.getOrElse("<nil>")}")
3534
val stackTrace = Thread.currentThread.getStackTrace
36-
stackTrace.foreach(elem => logError(elem.toString))
35+
stackTrace.foreach(elem => log.error(elem.toString))
3736
}
3837

3938
def push(elem: NvtxId): Unit = {

sql-plugin/src/main/scala/com/nvidia/spark/rapids/NvtxWithMetrics.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2019-2025, NVIDIA CORPORATION.
2+
* Copyright (c) 2019-2026, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -76,11 +76,11 @@ object NvtxIdWithMetrics {
7676
}
7777
}
7878

79-
class MetricRange(val metrics: Seq[GpuMetric], val excludeMetric: Seq[GpuMetric] = Seq.empty)
79+
class MetricRange(val metrics: Seq[GpuMetric], val excludeMetric: Seq[GpuMetric])
8080
extends AutoCloseable {
8181

8282
// add a convenient constructor
83-
def this(metrics: GpuMetric*) = this(metrics.toSeq)
83+
def this(metrics: GpuMetric*) = this(metrics.toSeq, Seq.empty)
8484

8585
val needTracks = metrics.map(_.tryActivateTimer(excludeMetric))
8686
private val start = System.nanoTime()

0 commit comments

Comments
 (0)