Skip to content

Commit 97e5b48

Browse files
CopilotpxLi
andcommitted
Fix method call for checking null count in GpuSlice implementation
Co-authored-by: pxLi <8086184+pxLi@users.noreply.github.com>
1 parent 71251d2 commit 97e5b48

1 file changed

Lines changed: 11 additions & 6 deletions

File tree

sql-plugin/src/main/scala/org/apache/spark/sql/rapids/collectionOperations.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,16 +162,21 @@ case class GpuSlice(x: Expression, start: Expression, length: Expression)
162162
override def doColumnar(listS: GpuScalar, startCol: GpuColumnVector,
163163
lengthCol: GpuColumnVector): ColumnVector = {
164164
val numRows = startCol.getRowCount.toInt
165-
withResource(GpuColumnVector.from(listS, numRows, dataType)) { listCol =>
166-
doColumnar(listCol, startCol, lengthCol)
165+
// When list is null, return all nulls like the CPU does.
166+
if (!listS.isValid) {
167+
GpuColumnVector.columnVectorFromNull(numRows, dataType)
168+
} else {
169+
withResource(GpuColumnVector.from(listS, numRows, dataType)) { listCol =>
170+
doColumnar(listCol, startCol, lengthCol)
171+
}
167172
}
168173
}
169174

170175
override def doColumnar(listCol: GpuColumnVector, startS: GpuScalar,
171176
lengthS: GpuScalar): ColumnVector = {
172177
// When input column is all nulls or either start or length is null, return all nulls like the CPU does.
173178
// This matches CPU behavior for slice function with null inputs.
174-
if (listCol.getRowCount == listCol.numNulls() || !startS.isValid || !lengthS.isValid) {
179+
if (listCol.getRowCount == listCol.getBase.getNullCount || !startS.isValid || !lengthS.isValid) {
175180
GpuColumnVector.columnVectorFromNull(listCol.getRowCount.toInt, dataType)
176181
} else {
177182
val list = listCol.getBase
@@ -190,7 +195,7 @@ case class GpuSlice(x: Expression, start: Expression, length: Expression)
190195
override def doColumnar(listCol: GpuColumnVector, startS: GpuScalar,
191196
lengthCol: GpuColumnVector): ColumnVector = {
192197
// When start is null, return all nulls like the CPU does.
193-
if (listCol.getRowCount == listCol.numNulls() || !startS.isValid) {
198+
if (listCol.getRowCount == listCol.getBase.getNullCount || !startS.isValid) {
194199
GpuColumnVector.columnVectorFromNull(listCol.getRowCount.toInt, dataType)
195200
} else {
196201
val list = listCol.getBase
@@ -212,7 +217,7 @@ case class GpuSlice(x: Expression, start: Expression, length: Expression)
212217
override def doColumnar(listCol: GpuColumnVector, startCol: GpuColumnVector,
213218
lengthS: GpuScalar): ColumnVector = {
214219
// When length is null, return all nulls like the CPU does.
215-
if (listCol.getRowCount == listCol.numNulls() || !lengthS.isValid) {
220+
if (listCol.getRowCount == listCol.getBase.getNullCount || !lengthS.isValid) {
216221
GpuColumnVector.columnVectorFromNull(listCol.getRowCount.toInt, dataType)
217222
} else {
218223
val list = listCol.getBase
@@ -237,7 +242,7 @@ case class GpuSlice(x: Expression, start: Expression, length: Expression)
237242

238243
override def doColumnar(listCol: GpuColumnVector, startCol: GpuColumnVector,
239244
lengthCol: GpuColumnVector): ColumnVector = {
240-
if (listCol.getRowCount == listCol.numNulls()) {
245+
if (listCol.getRowCount == listCol.getBase.getNullCount) {
241246
GpuColumnVector.columnVectorFromNull(listCol.getRowCount.toInt, dataType)
242247
} else {
243248
val list = listCol.getBase

0 commit comments

Comments
 (0)