Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,7 +1502,6 @@ def test_bloom_filter_join_cpu_probe(is_multi_column, kudo_enabled):
@pytest.mark.parametrize("is_multi_column", [False, True], ids=idfn)
@pytest.mark.skipif(is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/8921")
@pytest.mark.skipif(is_before_spark_330(), reason="Bloom filter joins added in Spark 3.3.0")
@pytest.mark.xfail(condition=is_spark_411_or_later(), reason="https://github.com/NVIDIA/spark-rapids/issues/14148")
@pytest.mark.parametrize("kudo_enabled", ["true", "false"], ids=idfn)
def test_bloom_filter_join_cpu_build(is_multi_column, kudo_enabled):
conf = {"spark.rapids.sql.expression.BloomFilterAggregate": "false",
Expand All @@ -1517,7 +1516,6 @@ def test_bloom_filter_join_cpu_build(is_multi_column, kudo_enabled):
@pytest.mark.parametrize("is_multi_column", [False, True], ids=idfn)
@pytest.mark.skipif(is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/8921")
@pytest.mark.skipif(is_before_spark_330(), reason="Bloom filter joins added in Spark 3.3.0")
@pytest.mark.xfail(condition=is_spark_411_or_later(), reason="https://github.com/NVIDIA/spark-rapids/issues/14148")
@pytest.mark.parametrize("kudo_enabled", ["true", "false"], ids=idfn)
def test_bloom_filter_join_split_cpu_build(agg_replace_mode, is_multi_column, kudo_enabled):
conf = {"spark.rapids.sql.hashAgg.replaceMode": agg_replace_mode,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids

import ai.rapids.cudf.{BaseDeviceMemoryBuffer, ColumnVector, Cuda, DeviceMemoryBuffer, DType}
import ai.rapids.cudf.{BaseDeviceMemoryBuffer, ColumnVector, Cuda, DeviceMemoryBuffer, DType,
HostMemoryBuffer}
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.jni.BloomFilter

Expand Down Expand Up @@ -78,14 +79,33 @@ class GpuBloomFilter(buffer: DeviceMemoryBuffer) extends AutoCloseable {
}

object GpuBloomFilter {
// Spark serializes their bloom filters in a specific format, see BloomFilterImpl.readFrom.
// Data is written via DataOutputStream, so everything is big-endian.
// Byte Offset Size Description
// 0 4 Version ID (see Spark's BloomFilter.Version)
// 4 4 Number of hash functions
// 8 4 Number of longs, N
// 12 N*8 Bloom filter data buffer as longs
private val HEADER_SIZE = 12
// Spark serializes bloom filters in one of two formats. All values are big-endian.
//
// V1 (12-byte header):
// Byte Offset Size Description
// 0 4 Version ID (1)
// 4 4 Number of hash functions
// 8 4 Number of longs, N
// 12 N*8 Bloom filter data buffer as longs
//
// V2 (16-byte header):
// Byte Offset Size Description
// 0 4 Version ID (2)
// 4 4 Number of hash functions
// 8 4 Hash seed
// 12 4 Number of longs, N
// 16 N*8 Bloom filter data buffer as longs
private val HEADER_SIZE_V1 = 12
private val HEADER_SIZE_V2 = 16

/**
 * Reads the serialization format version from the first four bytes of a
 * Spark-serialized bloom filter that resides in device memory.
 *
 * The filter is serialized by Spark via DataOutputStream, so the version
 * field is big-endian on the wire; the little-endian host read is
 * byte-swapped before being returned.
 *
 * @param data device buffer holding the full serialized bloom filter
 * @return the version ID stored in the first four bytes
 */
private def readVersionFromDevice(data: BaseDeviceMemoryBuffer): Int = {
  withResource(data.sliceWithCopy(0, 4)) { deviceHeader =>
    withResource(HostMemoryBuffer.allocate(4)) { hostHeader =>
      hostHeader.copyFromDeviceBuffer(deviceHeader)
      // getInt reads in native (little-endian) order; swap to recover the
      // big-endian on-wire value.
      Integer.reverseBytes(hostHeader.getInt(0))
    }
  }
}

def apply(s: GpuScalar): GpuBloomFilter = {
s.dataType match {
Expand All @@ -100,11 +120,22 @@ object GpuBloomFilter {
}

def deserialize(data: BaseDeviceMemoryBuffer): GpuBloomFilter = {
// Sanity check bloom filter header
val totalLen = data.getLength
val bitBufferLen = totalLen - HEADER_SIZE
require(totalLen >= HEADER_SIZE, s"header size is $totalLen")
require(bitBufferLen % 8 == 0, "buffer length not a multiple of 8")
require(totalLen >= HEADER_SIZE_V1, s"buffer too small: $totalLen")

val version = readVersionFromDevice(data)
val headerSize = version match {
case 1 => HEADER_SIZE_V1
case 2 => HEADER_SIZE_V2
case _ => throw new IllegalArgumentException(
s"Unknown bloom filter version: $version")
}
require(totalLen >= headerSize,
s"buffer too small for bloom filter V$version: $totalLen")
val bitBufferLen = totalLen - headerSize
require(bitBufferLen % 8 == 0,
s"bit buffer length ($bitBufferLen) not a multiple of 8")

val filterBuffer = DeviceMemoryBuffer.allocate(totalLen)
closeOnExcept(filterBuffer) { buf =>
buf.copyFromDeviceBufferAsync(0, data, 0, buf.getLength, Cuda.DEFAULT_STREAM)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "330"}
{"spark": "330db"}
{"spark": "331"}
{"spark": "332"}
{"spark": "332db"}
{"spark": "333"}
{"spark": "334"}
{"spark": "340"}
{"spark": "341"}
{"spark": "341db"}
{"spark": "342"}
{"spark": "343"}
{"spark": "344"}
{"spark": "350"}
{"spark": "350db143"}
{"spark": "351"}
{"spark": "352"}
{"spark": "353"}
{"spark": "354"}
{"spark": "355"}
{"spark": "356"}
{"spark": "357"}
{"spark": "358"}
{"spark": "400"}
{"spark": "401"}
{"spark": "402"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

object BloomFilterConstantsShims {
  // Serialization format version used when building bloom filters for the
  // Spark versions covered by this shim (3.3.0 through 4.0.x per the
  // shim-json-lines header above). V1 uses a 12-byte header:
  // version (4), number of hash functions (4), number of longs (4).
  val BLOOM_FILTER_FORMAT_VERSION: Int = 1
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing newline at end of file

Both new BloomFilterConstantsShims.scala files (spark330 and spark411) are missing a trailing newline. This can cause issues with certain tools and doesn't follow standard POSIX text file convention. Please add a newline after the closing brace.

Suggested change
}
}

The same applies to sql-plugin/src/main/spark411/scala/com/nvidia/spark/rapids/shims/BloomFilterConstantsShims.scala at line 24.

Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.jni.BloomFilter

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
Expand Down Expand Up @@ -80,7 +81,9 @@ object BloomFilterShims {
GpuBloomFilterAggregate(
childExprs.head.convertToGpu(),
a.estimatedNumItemsExpression.eval().asInstanceOf[Number].longValue,
a.numBitsExpression.eval().asInstanceOf[Number].longValue)
a.numBitsExpression.eval().asInstanceOf[Number].longValue,
BloomFilterConstantsShims.BLOOM_FILTER_FORMAT_VERSION,
BloomFilter.DEFAULT_SEED)
}
})
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ import org.apache.spark.sql.types.{BinaryType, DataType}
case class GpuBloomFilterAggregate(
child: Expression,
estimatedNumItemsRequested: Long,
numBitsRequested: Long) extends GpuAggregateFunction {
numBitsRequested: Long,
version: Int = BloomFilter.VERSION_2,
seed: Int = BloomFilter.DEFAULT_SEED) extends GpuAggregateFunction {
Comment on lines +63 to +64
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Default version parameter targets wrong format for pre-4.11 shims

The default value for version is BloomFilter.VERSION_2, but this class (spark330/...) is compiled for all Spark versions from 3.3.0 through 4.1.1. For any code path that constructs GpuBloomFilterAggregate without explicitly passing version (e.g. direct instantiation in tests or future callers), the aggregate would produce a V2 filter even when running under a pre-4.1.1 Spark version.

While all current production paths go through BloomFilterShims.convertToGpuImpl() which explicitly passes BloomFilterConstantsShims.BLOOM_FILTER_FORMAT_VERSION, a safer default would be BloomFilter.VERSION_1 (or 1) to match the behaviour expected by Spark < 4.1.1, with the spark411 shim overriding this at construction time.

Suggested change
version: Int = BloomFilter.VERSION_2,
seed: Int = BloomFilter.DEFAULT_SEED) extends GpuAggregateFunction {
version: Int = BloomFilter.VERSION_1,
seed: Int = BloomFilter.DEFAULT_SEED) extends GpuAggregateFunction {


override def nullable: Boolean = true

Expand All @@ -81,7 +83,8 @@ case class GpuBloomFilterAggregate(

override val inputProjection: Seq[Expression] = Seq(child)

override val updateAggregates: Seq[CudfAggregate] = Seq(GpuBloomFilterUpdate(numHashes, numBits))
override val updateAggregates: Seq[CudfAggregate] =
Seq(GpuBloomFilterUpdate(numHashes, numBits, version, seed))

override val mergeAggregates: Seq[CudfAggregate] = Seq(GpuBloomFilterMerge())

Expand Down Expand Up @@ -110,9 +113,13 @@ object GpuBloomFilterAggregate {
}
}

case class GpuBloomFilterUpdate(numHashes: Int, numBits: Long) extends CudfAggregate {
case class GpuBloomFilterUpdate(
numHashes: Int,
numBits: Long,
version: Int,
seed: Int) extends CudfAggregate {
override val reductionAggregate: ColumnVector => Scalar = (col: ColumnVector) => {
closeOnExcept(BloomFilter.create(numHashes, numBits)) { bloomFilter =>
closeOnExcept(BloomFilter.create(version, numHashes, numBits, seed)) { bloomFilter =>
BloomFilter.put(bloomFilter, col)
bloomFilter
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "411"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

object BloomFilterConstantsShims {
  // Serialization format version used when building bloom filters for
  // Spark 4.1.1 (per the shim-json-lines header above). V2 uses a 16-byte
  // header: version (4), number of hash functions (4), hash seed (4),
  // number of longs (4).
  val BLOOM_FILTER_FORMAT_VERSION: Int = 2
}
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,9 @@ class BloomFilterAggregateQuerySuite extends SparkQueryCompareTestSuite {
}
}

// V1 literal: version=1, numHashes=5, numLongs=3, followed by 3 longs of bit data
testSparkResultsAreEqual(
"might_contain with literal bloom filter buffer",
"might_contain with V1 literal bloom filter buffer",
spark => spark.range(1, 1).asInstanceOf[DataFrame],
conf=bloomFilterEnabledConf.clone()) {
df =>
Expand All @@ -190,6 +191,20 @@ class BloomFilterAggregateQuerySuite extends SparkQueryCompareTestSuite {
}
}

// V2 literal: version=2, numHashes=5, seed=0, numLongs=3, followed by 3 longs of bit data
testSparkResultsAreEqual(
"might_contain with V2 literal bloom filter buffer",
spark => spark.range(1, 1).asInstanceOf[DataFrame],
conf=bloomFilterEnabledConf.clone()) {
df =>
withExposedSqlFuncs(df.sparkSession) { spark =>
spark.sql(
"""SELECT might_contain(
|X'0000000200000005000000000000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267',
|cast(201 as long))""".stripMargin)
}
}

testSparkResultsAreEqual(
"might_contain with all NULL inputs",
spark => spark.range(1, 1).asInstanceOf[DataFrame],
Expand Down
Loading