|
1 | 1 | /*
|
2 |
| - * Copyright (c) 2020-2023, NVIDIA CORPORATION. |
| 2 | + * Copyright (c) 2020-2024, NVIDIA CORPORATION. |
3 | 3 | *
|
4 | 4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | 5 | * you may not use this file except in compliance with the License.
|
@@ -31,8 +31,8 @@ import org.apache.spark.rdd.RDD
|
31 | 31 | import org.apache.spark.sql.catalyst.expressions._
|
32 | 32 | import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
|
33 | 33 | import org.apache.spark.sql.execution.SparkPlan
|
34 |
| -import org.apache.spark.sql.rapids.execution.python.shims.GpuArrowPythonRunner |
35 |
| -import org.apache.spark.sql.rapids.shims.{ArrowUtilsShim, DataTypeUtilsShim} |
| 34 | +import org.apache.spark.sql.rapids.execution.python.shims.GpuGroupedPythonRunnerFactory |
| 35 | +import org.apache.spark.sql.rapids.shims.DataTypeUtilsShim |
36 | 36 | import org.apache.spark.sql.types.{DataType, StructField, StructType}
|
37 | 37 | import org.apache.spark.sql.vectorized.ColumnarBatch
|
38 | 38 |
|
@@ -109,8 +109,6 @@ case class GpuAggregateInPandasExec(
|
109 | 109 | val (mNumInputRows, mNumInputBatches, mNumOutputRows, mNumOutputBatches) = commonGpuMetrics()
|
110 | 110 |
|
111 | 111 | lazy val isPythonOnGpuEnabled = GpuPythonHelper.isPythonOnGpuEnabled(conf)
|
112 |
| - val sessionLocalTimeZone = conf.sessionLocalTimeZone |
113 |
| - val pythonRunnerConf = ArrowUtilsShim.getPythonRunnerConfMap(conf) |
114 | 112 | val childOutput = child.output
|
115 | 113 | val resultExprs = resultExpressions
|
116 | 114 |
|
@@ -204,27 +202,22 @@ case class GpuAggregateInPandasExec(
|
204 | 202 | }
|
205 | 203 | }
|
206 | 204 |
|
| 205 | + val runnerFactory = GpuGroupedPythonRunnerFactory(conf, pyFuncs, argOffsets, |
| 206 | + aggInputSchema, DataTypeUtilsShim.fromAttributes(pyOutAttributes), |
| 207 | + PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF) |
| 208 | + |
207 | 209 | // Third, sends to Python to execute the aggregate and returns the result.
|
208 | 210 | if (pyInputIter.hasNext) {
|
209 | 211 | // Launch Python workers only when the data is not empty.
|
210 |
| - val pyRunner = new GpuArrowPythonRunner( |
211 |
| - pyFuncs, |
212 |
| - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, |
213 |
| - argOffsets, |
214 |
| - aggInputSchema, |
215 |
| - sessionLocalTimeZone, |
216 |
| - pythonRunnerConf, |
217 |
| - // The whole group data should be written in a single call, so here is unlimited |
218 |
| - Int.MaxValue, |
219 |
| - DataTypeUtilsShim.fromAttributes(pyOutAttributes)) |
220 |
| - |
| 212 | + val pyRunner = runnerFactory.getRunner() |
221 | 213 | val pyOutputIterator = pyRunner.compute(pyInputIter, context.partitionId(), context)
|
222 | 214 |
|
223 | 215 | val combinedAttrs = gpuGroupingExpressions.map(_.toAttribute) ++ pyOutAttributes
|
224 | 216 | val resultRefs = GpuBindReferences.bindGpuReferences(resultExprs, combinedAttrs)
|
225 | 217 | // Gets the combined batch for each group and projects for the output.
|
226 |
| - new CombiningIterator(batchProducer.getBatchQueue, pyOutputIterator, pyRunner, |
227 |
| - mNumOutputRows, mNumOutputBatches).map { combinedBatch => |
| 218 | + new CombiningIterator(batchProducer.getBatchQueue, pyOutputIterator, |
| 219 | + pyRunner.asInstanceOf[GpuArrowOutput], mNumOutputRows, |
| 220 | + mNumOutputBatches).map { combinedBatch => |
228 | 221 | withResource(combinedBatch) { batch =>
|
229 | 222 | GpuProjectExec.project(batch, resultRefs)
|
230 | 223 | }
|
|
0 commit comments