Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions integration_tests/src/main/python/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,7 +1502,6 @@ def test_bloom_filter_join_cpu_probe(is_multi_column, kudo_enabled):
@pytest.mark.parametrize("is_multi_column", [False, True], ids=idfn)
@pytest.mark.skipif(is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/8921")
@pytest.mark.skipif(is_before_spark_330(), reason="Bloom filter joins added in Spark 3.3.0")
@pytest.mark.xfail(condition=is_spark_411_or_later(), reason="https://github.com/NVIDIA/spark-rapids/issues/14148")
@pytest.mark.parametrize("kudo_enabled", ["true", "false"], ids=idfn)
def test_bloom_filter_join_cpu_build(is_multi_column, kudo_enabled):
conf = {"spark.rapids.sql.expression.BloomFilterAggregate": "false",
Expand All @@ -1517,7 +1516,6 @@ def test_bloom_filter_join_cpu_build(is_multi_column, kudo_enabled):
@pytest.mark.parametrize("is_multi_column", [False, True], ids=idfn)
@pytest.mark.skipif(is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/8921")
@pytest.mark.skipif(is_before_spark_330(), reason="Bloom filter joins added in Spark 3.3.0")
@pytest.mark.xfail(condition=is_spark_411_or_later(), reason="https://github.com/NVIDIA/spark-rapids/issues/14148")
@pytest.mark.parametrize("kudo_enabled", ["true", "false"], ids=idfn)
def test_bloom_filter_join_split_cpu_build(agg_replace_mode, is_multi_column, kudo_enabled):
conf = {"spark.rapids.sql.hashAgg.replaceMode": agg_replace_mode,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids

import ai.rapids.cudf.{BaseDeviceMemoryBuffer, ColumnVector, Cuda, DeviceMemoryBuffer, DType}
import ai.rapids.cudf.{BaseDeviceMemoryBuffer, ColumnVector, Cuda, DeviceMemoryBuffer, DType,
HostMemoryBuffer}
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.jni.BloomFilter

Expand Down Expand Up @@ -78,14 +79,33 @@ class GpuBloomFilter(buffer: DeviceMemoryBuffer) extends AutoCloseable {
}

object GpuBloomFilter {
// Spark serializes their bloom filters in a specific format, see BloomFilterImpl.readFrom.
// Data is written via DataOutputStream, so everything is big-endian.
// Byte Offset Size Description
// 0 4 Version ID (see Spark's BloomFilter.Version)
// 4 4 Number of hash functions
// 8 4 Number of longs, N
// 12 N*8 Bloom filter data buffer as longs
private val HEADER_SIZE = 12
// Spark serializes bloom filters in one of two formats. All values are big-endian.
//
// V1 (12-byte header):
// Byte Offset Size Description
// 0 4 Version ID (1)
// 4 4 Number of hash functions
// 8 4 Number of longs, N
// 12 N*8 Bloom filter data buffer as longs
//
// V2 (16-byte header):
// Byte Offset Size Description
// 0 4 Version ID (2)
// 4 4 Number of hash functions
// 8 4 Hash seed
// 12 4 Number of longs, N
// 16 N*8 Bloom filter data buffer as longs
private val HEADER_SIZE_V1 = 12
private val HEADER_SIZE_V2 = 16

/**
 * Reads the 4-byte serialization version ID from the start of a Spark-serialized
 * bloom filter buffer that resides in device memory.
 *
 * Spark writes the header via DataOutputStream, so the version word is big-endian.
 *
 * @param data device buffer holding the serialized bloom filter (version word first)
 * @return the version ID as a host-side Int
 */
private def readVersionFromDevice(data: BaseDeviceMemoryBuffer): Int = {
// Copy only the first 4 bytes (the version word) from device to host.
withResource(data.sliceWithCopy(0, 4)) { versionSlice =>
withResource(HostMemoryBuffer.allocate(4)) { hostBuf =>
hostBuf.copyFromDeviceBuffer(versionSlice)
// reverseBytes converts the big-endian on-the-wire value to the order getInt
// produced it in.
// NOTE(review): this assumes HostMemoryBuffer.getInt reads little-endian
// (native x86) order — confirm against cudf's HostMemoryBuffer contract.
Integer.reverseBytes(hostBuf.getInt(0))
}
}
}

def apply(s: GpuScalar): GpuBloomFilter = {
s.dataType match {
Expand All @@ -100,11 +120,22 @@ object GpuBloomFilter {
}

// Deserializes a Spark-serialized bloom filter (V1 or V2 header layout, see the
// format description on the HEADER_SIZE_* constants) from a device buffer into a
// GpuBloomFilter that owns its own copy of the data.
// NOTE(review): this method is truncated in the visible diff; the tail (after the
// async copy) is not shown here.
def deserialize(data: BaseDeviceMemoryBuffer): GpuBloomFilter = {
// Sanity check bloom filter header
val totalLen = data.getLength
// NOTE(review): the next three lines look like stale pre-change text fused into
// the rendered diff (HEADER_SIZE is no longer defined after the V1/V2 split, and
// bitBufferLen is re-declared below) — confirm they do not survive in the merged
// source file.
val bitBufferLen = totalLen - HEADER_SIZE
require(totalLen >= HEADER_SIZE, s"header size is $totalLen")
require(bitBufferLen % 8 == 0, "buffer length not a multiple of 8")
// The smallest valid buffer is a V1 12-byte header; check before reading anything.
require(totalLen >= HEADER_SIZE_V1, s"buffer too small: $totalLen")

// Read the big-endian version word from device memory to pick the header layout.
val version = readVersionFromDevice(data)
val headerSize = version match {
case 1 => HEADER_SIZE_V1
case 2 => HEADER_SIZE_V2
case _ => throw new IllegalArgumentException(
s"Unknown bloom filter version: $version")
}
// Re-check the length now that the actual (possibly 16-byte V2) header size is known.
require(totalLen >= headerSize,
s"buffer too small for bloom filter V$version: $totalLen")
// The bit data is serialized as longs, so it must be a whole number of 8-byte words.
val bitBufferLen = totalLen - headerSize
require(bitBufferLen % 8 == 0,
s"bit buffer length ($bitBufferLen) not a multiple of 8")

// Copy the entire serialized filter (header + bit data) into a buffer we own.
val filterBuffer = DeviceMemoryBuffer.allocate(totalLen)
closeOnExcept(filterBuffer) { buf =>
buf.copyFromDeviceBufferAsync(0, data, 0, buf.getLength, Cuda.DEFAULT_STREAM)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "330"}
{"spark": "330db"}
{"spark": "331"}
{"spark": "332"}
{"spark": "332db"}
{"spark": "333"}
{"spark": "334"}
{"spark": "340"}
{"spark": "341"}
{"spark": "341db"}
{"spark": "342"}
{"spark": "343"}
{"spark": "344"}
{"spark": "350"}
{"spark": "350db143"}
{"spark": "351"}
{"spark": "352"}
{"spark": "353"}
{"spark": "354"}
{"spark": "355"}
{"spark": "356"}
{"spark": "357"}
{"spark": "358"}
{"spark": "400"}
{"spark": "401"}
{"spark": "402"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

/**
 * Shim constants for Spark releases prior to 4.1.1, all of which serialize
 * bloom filters using the original V1 on-the-wire format (12-byte header:
 * version, number of hashes, number of longs).
 */
object BloomFilterConstantsShims {
  /** Bloom filter serialization format version these Spark releases write and expect. */
  val BLOOM_FILTER_FORMAT_VERSION: Int = 1
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.jni.BloomFilter

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
Expand Down Expand Up @@ -80,7 +81,9 @@ object BloomFilterShims {
GpuBloomFilterAggregate(
childExprs.head.convertToGpu(),
a.estimatedNumItemsExpression.eval().asInstanceOf[Number].longValue,
a.numBitsExpression.eval().asInstanceOf[Number].longValue)
a.numBitsExpression.eval().asInstanceOf[Number].longValue,
BloomFilterConstantsShims.BLOOM_FILTER_FORMAT_VERSION,
BloomFilter.DEFAULT_SEED)
}
})
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import ai.rapids.cudf.{ColumnVector, DType, GroupByAggregation, HostColumnVector
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.GpuLiteral
import com.nvidia.spark.rapids.jni.BloomFilter
import com.nvidia.spark.rapids.shims.BloomFilterConstantsShims

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.internal.SQLConf.{RUNTIME_BLOOM_FILTER_MAX_NUM_BITS, RUNTIME_BLOOM_FILTER_MAX_NUM_ITEMS}
Expand All @@ -59,7 +60,9 @@ import org.apache.spark.sql.types.{BinaryType, DataType}
case class GpuBloomFilterAggregate(
child: Expression,
estimatedNumItemsRequested: Long,
numBitsRequested: Long) extends GpuAggregateFunction {
numBitsRequested: Long,
version: Int = BloomFilterConstantsShims.BLOOM_FILTER_FORMAT_VERSION,
seed: Int = BloomFilter.DEFAULT_SEED) extends GpuAggregateFunction {

override def nullable: Boolean = true

Expand All @@ -81,7 +84,8 @@ case class GpuBloomFilterAggregate(

override val inputProjection: Seq[Expression] = Seq(child)

override val updateAggregates: Seq[CudfAggregate] = Seq(GpuBloomFilterUpdate(numHashes, numBits))
override val updateAggregates: Seq[CudfAggregate] =
Seq(GpuBloomFilterUpdate(numHashes, numBits, version, seed))

override val mergeAggregates: Seq[CudfAggregate] = Seq(GpuBloomFilterMerge())

Expand Down Expand Up @@ -110,9 +114,13 @@ object GpuBloomFilterAggregate {
}
}

case class GpuBloomFilterUpdate(numHashes: Int, numBits: Long) extends CudfAggregate {
case class GpuBloomFilterUpdate(
numHashes: Int,
numBits: Long,
version: Int,
seed: Int) extends CudfAggregate {
override val reductionAggregate: ColumnVector => Scalar = (col: ColumnVector) => {
closeOnExcept(BloomFilter.create(numHashes, numBits)) { bloomFilter =>
closeOnExcept(BloomFilter.create(version, numHashes, numBits, seed)) { bloomFilter =>
BloomFilter.put(bloomFilter, col)
bloomFilter
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) 2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "411"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

/**
 * Shim constants for Spark 4.1.1, which serializes bloom filters using the V2
 * on-the-wire format (16-byte header: version, number of hashes, hash seed,
 * number of longs).
 */
object BloomFilterConstantsShims {
  /** Bloom filter serialization format version this Spark release writes and expects. */
  val BLOOM_FILTER_FORMAT_VERSION: Int = 2
}
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,9 @@ class BloomFilterAggregateQuerySuite extends SparkQueryCompareTestSuite {
}
}

// V1 literal: version=1, numHashes=5, numLongs=3, followed by 3 longs of bit data
testSparkResultsAreEqual(
"might_contain with literal bloom filter buffer",
"might_contain with V1 literal bloom filter buffer",
spark => spark.range(1, 1).asInstanceOf[DataFrame],
conf=bloomFilterEnabledConf.clone()) {
df =>
Expand All @@ -190,6 +191,20 @@ class BloomFilterAggregateQuerySuite extends SparkQueryCompareTestSuite {
}
}

// V2 literal: version=2, numHashes=5, seed=0, numLongs=3, followed by 3 longs of bit data
testSparkResultsAreEqual(
"might_contain with V2 literal bloom filter buffer",
spark => spark.range(1, 1).asInstanceOf[DataFrame],
conf=bloomFilterEnabledConf.clone()) {
df =>
withExposedSqlFuncs(df.sparkSession) { spark =>
spark.sql(
"""SELECT might_contain(
|X'0000000200000005000000000000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267',
|cast(201 as long))""".stripMargin)
}
}

testSparkResultsAreEqual(
"might_contain with all NULL inputs",
spark => spark.range(1, 1).asInstanceOf[DataFrame],
Expand Down
Loading