Skip to content

Commit 8a3927f

Browse files
committed
Adapt plugin metadata callers to Java helpers
Signed-off-by: Gera Shegalov <gshegalov@nvidia.com>
1 parent add1afc commit 8a3927f

20 files changed

Lines changed: 328 additions & 153 deletions

File tree

delta-lake/common/src/main/delta-io/scala/org/apache/spark/sql/delta/rapids/GpuCheckDeltaInvariant.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022-2025, NVIDIA CORPORATION.
2+
* Copyright (c) 2022-2026, NVIDIA CORPORATION.
33
*
44
* This file was derived from CheckDeltaInvariant.scala in the
55
* Delta Lake project at https://github.com/delta-io/delta.
@@ -132,8 +132,8 @@ object GpuCheckDeltaInvariant extends Logging {
132132
ExprChecks.projectOnly(
133133
TypeSig.all,
134134
TypeSig.all,
135-
paramCheck = Seq(ParamCheck("input", TypeSig.all, TypeSig.all)),
136-
repeatingParamCheck = Some(RepeatingParamCheck("extra", TypeSig.all, TypeSig.all))
135+
paramCheck = Seq(new ParamCheck("input", TypeSig.all, TypeSig.all)),
136+
repeatingParamCheck = Some(new RepeatingParamCheck("extra", TypeSig.all, TypeSig.all))
137137
),
138138
(c, conf, p, r) => new GpuCheckDeltaInvariantMeta(c, conf, p, r))
139139

sql-plugin/src/main/scala/com/nvidia/spark/rapids/AstUtil.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2023, NVIDIA CORPORATION.
2+
* Copyright (c) 2023-2026, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -110,7 +110,7 @@ object AstUtil {
110110
val gpuExpr = expr.convertToGpu()
111111

112112
// Check if we've already processed this expression (for deduplication)
113-
processed.get(GpuExpressionEquals(gpuExpr)) match {
113+
processed.get(new GpuExpressionEquals(gpuExpr)) match {
114114
case Some(replacement) =>
115115
replacement
116116
case None =>
@@ -135,7 +135,7 @@ object AstUtil {
135135
// Create an AttributeReference explicitly to avoid issues with unresolved aliases
136136
val attributeRef = AttributeReference(alias.name, gpuExpr.dataType,
137137
gpuExpr.nullable, alias.metadata)(alias.exprId, alias.qualifier)
138-
processed.put(GpuExpressionEquals(gpuExpr), attributeRef)
138+
processed.put(new GpuExpressionEquals(gpuExpr), attributeRef)
139139
attributeRef
140140
}
141141
} else {

sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuPostHocResolutionOverrides.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2025, NVIDIA CORPORATION.
2+
* Copyright (c) 2025-2026, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -26,7 +26,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
2626
* phase by `SparkSessionExtensions.injectPostHocResolutionRule`. As its name suggests, it will
2727
* be applied after the logical plan has been resolved.
2828
*/
29-
case class GpuPostHocResolutionOverrides(spark: SparkSession) extends Rule[LogicalPlan] {
29+
class GpuPostHocResolutionOverrides(val spark: SparkSession)
30+
extends Rule[LogicalPlan] with Serializable {
3031

3132
@transient private val rapidsConf = new RapidsConf(spark.sessionState.conf)
3233

sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuUserDefinedFunction.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2021-2026, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -23,7 +23,6 @@ import com.nvidia.spark.rapids.RapidsPluginImplicits._
2323
import com.nvidia.spark.rapids.shims.ShimExpression
2424

2525
import org.apache.spark.SparkException
26-
import org.apache.spark.internal.Logging
2726
import org.apache.spark.sql.catalyst.InternalRow
2827
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UserDefinedExpression}
2928
import org.apache.spark.sql.rapids.execution.TrampolineUtil
@@ -91,7 +90,13 @@ object GpuUserDefinedFunction {
9190
* and do the processing on CPU.
9291
*/
9392
trait GpuRowBasedUserDefinedFunction extends GpuExpression
94-
with ShimExpression with UserDefinedExpression with Serializable with Logging {
93+
with ShimExpression with UserDefinedExpression with Serializable {
94+
95+
@transient private lazy val log = org.slf4j.LoggerFactory.getLogger(
96+
classOf[GpuRowBasedUserDefinedFunction])
97+
98+
private def logDebug(msg: => String): Unit = if (log.isDebugEnabled) log.debug(msg)
99+
95100
/** name of the UDF function */
96101
val name: String
97102

sql-plugin/src/main/scala/com/nvidia/spark/rapids/HashExprMetas.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2025, NVIDIA CORPORATION.
2+
* Copyright (c) 2025-2026, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -83,14 +83,14 @@ object HashExprChecks {
8383

8484
val murmur3ProjectChecks: ExprChecks = ExprChecks.projectOnly(
8585
TypeSig.INT, TypeSig.INT,
86-
repeatingParamCheck = Some(RepeatingParamCheck(
86+
repeatingParamCheck = Some(new RepeatingParamCheck(
8787
"input",
8888
murmur3InputTypes,
8989
TypeSig.all)))
9090

9191
val xxhash64ProjectChecks: ExprChecks = ExprChecks.projectOnly(
9292
TypeSig.LONG, TypeSig.LONG,
93-
repeatingParamCheck = Some(RepeatingParamCheck(
93+
repeatingParamCheck = Some(new RepeatingParamCheck(
9494
"input",
9595
XxHash64Shims.supportedTypes,
9696
TypeSig.all)))

sql-plugin/src/main/scala/com/nvidia/spark/rapids/MetaUtils.scala

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2020-2026, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@
1616

1717
package com.nvidia.spark.rapids
1818

19-
import java.nio.{ByteBuffer, ByteOrder}
19+
import java.nio.{Buffer, ByteBuffer, ByteOrder}
2020

2121
import scala.collection.mutable.ArrayBuffer
2222

@@ -25,7 +25,6 @@ import com.google.flatbuffers.FlatBufferBuilder
2525
import com.nvidia.spark.rapids.Arm.withResource
2626
import com.nvidia.spark.rapids.format._
2727

28-
import org.apache.spark.internal.Logging
2928
import org.apache.spark.sql.types.DataType
3029
import org.apache.spark.sql.vectorized.ColumnarBatch
3130
import org.apache.spark.storage.ShuffleBlockBatchId
@@ -117,9 +116,9 @@ object MetaUtils {
117116
packedMeta: ByteBuffer,
118117
numRows: Long): TableMeta = {
119118
val vectorBuffer = fbb.createUnintializedVector(1, packedMeta.remaining(), 1)
120-
packedMeta.mark()
119+
packedMeta.asInstanceOf[Buffer].mark()
121120
vectorBuffer.put(packedMeta)
122-
packedMeta.reset()
121+
packedMeta.asInstanceOf[Buffer].reset()
123122
val packedMetaOffset = fbb.endVector()
124123

125124
TableMeta.startTableMeta(fbb)
@@ -262,7 +261,7 @@ class DirectByteBufferFactory extends FlatBufferBuilder.ByteBufferFactory {
262261
}
263262
}
264263

265-
object ShuffleMetadata extends Logging{
264+
object ShuffleMetadata {
266265

267266
val bbFactory = new DirectByteBufferFactory
268267

0 commit comments

Comments
 (0)