diff --git a/integration_tests/src/main/python/join_test.py b/integration_tests/src/main/python/join_test.py index 5225a2082b5..bd2e357eb4d 100644 --- a/integration_tests/src/main/python/join_test.py +++ b/integration_tests/src/main/python/join_test.py @@ -1502,7 +1502,6 @@ def test_bloom_filter_join_cpu_probe(is_multi_column, kudo_enabled): @pytest.mark.parametrize("is_multi_column", [False, True], ids=idfn) @pytest.mark.skipif(is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/8921") @pytest.mark.skipif(is_before_spark_330(), reason="Bloom filter joins added in Spark 3.3.0") -@pytest.mark.xfail(condition=is_spark_411_or_later(), reason="https://github.com/NVIDIA/spark-rapids/issues/14148") @pytest.mark.parametrize("kudo_enabled", ["true", "false"], ids=idfn) def test_bloom_filter_join_cpu_build(is_multi_column, kudo_enabled): conf = {"spark.rapids.sql.expression.BloomFilterAggregate": "false", @@ -1517,7 +1516,6 @@ def test_bloom_filter_join_cpu_build(is_multi_column, kudo_enabled): @pytest.mark.parametrize("is_multi_column", [False, True], ids=idfn) @pytest.mark.skipif(is_databricks_runtime(), reason="https://github.com/NVIDIA/spark-rapids/issues/8921") @pytest.mark.skipif(is_before_spark_330(), reason="Bloom filter joins added in Spark 3.3.0") -@pytest.mark.xfail(condition=is_spark_411_or_later(), reason="https://github.com/NVIDIA/spark-rapids/issues/14148") @pytest.mark.parametrize("kudo_enabled", ["true", "false"], ids=idfn) def test_bloom_filter_join_split_cpu_build(agg_replace_mode, is_multi_column, kudo_enabled): conf = {"spark.rapids.sql.hashAgg.replaceMode": agg_replace_mode, diff --git a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/GpuBloomFilter.scala b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/GpuBloomFilter.scala index 74a37d875be..0dac6a4f79c 100644 --- a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/GpuBloomFilter.scala +++ 
b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/GpuBloomFilter.scala @@ -45,7 +45,8 @@ spark-rapids-shim-json-lines ***/ package com.nvidia.spark.rapids -import ai.rapids.cudf.{BaseDeviceMemoryBuffer, ColumnVector, Cuda, DeviceMemoryBuffer, DType} +import ai.rapids.cudf.{BaseDeviceMemoryBuffer, ColumnVector, Cuda, DeviceMemoryBuffer, DType, + HostMemoryBuffer} import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource} import com.nvidia.spark.rapids.jni.BloomFilter @@ -78,14 +79,33 @@ class GpuBloomFilter(buffer: DeviceMemoryBuffer) extends AutoCloseable { } object GpuBloomFilter { - // Spark serializes their bloom filters in a specific format, see BloomFilterImpl.readFrom. - // Data is written via DataOutputStream, so everything is big-endian. - // Byte Offset Size Description - // 0 4 Version ID (see Spark's BloomFilter.Version) - // 4 4 Number of hash functions - // 8 4 Number of longs, N - // 12 N*8 Bloom filter data buffer as longs - private val HEADER_SIZE = 12 + // Spark serializes bloom filters in one of two formats. All values are big-endian. 
+ // + // V1 (12-byte header): + // Byte Offset Size Description + // 0 4 Version ID (1) + // 4 4 Number of hash functions + // 8 4 Number of longs, N + // 12 N*8 Bloom filter data buffer as longs + // + // V2 (16-byte header): + // Byte Offset Size Description + // 0 4 Version ID (2) + // 4 4 Number of hash functions + // 8 4 Hash seed + // 12 4 Number of longs, N + // 16 N*8 Bloom filter data buffer as longs + private val HEADER_SIZE_V1 = 12 + private val HEADER_SIZE_V2 = 16 + + private def readVersionFromDevice(data: BaseDeviceMemoryBuffer): Int = { + // Copy only the 4-byte version field straight to the host; no temporary device buffer. + // Header fields are big-endian, so swap the bytes of the value read on the host. + withResource(HostMemoryBuffer.allocate(4)) { hostBuf => + hostBuf.copyFromMemoryBuffer(0, data, 0, 4, Cuda.DEFAULT_STREAM) + Integer.reverseBytes(hostBuf.getInt(0)) + } + } def apply(s: GpuScalar): GpuBloomFilter = { s.dataType match { @@ -100,11 +120,22 @@ object GpuBloomFilter { } def deserialize(data: BaseDeviceMemoryBuffer): GpuBloomFilter = { - // Sanity check bloom filter header val totalLen = data.getLength - val bitBufferLen = totalLen - HEADER_SIZE - require(totalLen >= HEADER_SIZE, s"header size is $totalLen") - require(bitBufferLen % 8 == 0, "buffer length not a multiple of 8") + require(totalLen >= HEADER_SIZE_V1, s"buffer too small: $totalLen") + + val version = readVersionFromDevice(data) + val headerSize = version match { + case 1 => HEADER_SIZE_V1 + case 2 => HEADER_SIZE_V2 + case _ => throw new IllegalArgumentException( + s"Unknown bloom filter version: $version") + } + require(totalLen >= headerSize, + s"buffer too small for bloom filter V$version: $totalLen") + val bitBufferLen = totalLen - headerSize + require(bitBufferLen % 8 == 0, + s"bit buffer length ($bitBufferLen) not a multiple of 8") + val filterBuffer = DeviceMemoryBuffer.allocate(totalLen) closeOnExcept(filterBuffer) { buf => buf.copyFromDeviceBufferAsync(0, data, 0, buf.getLength, Cuda.DEFAULT_STREAM) diff --git
a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/BloomFilterConstantsShims.scala b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/BloomFilterConstantsShims.scala new file mode 100644 index 00000000000..3332ba7d8f4 --- /dev/null +++ b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/BloomFilterConstantsShims.scala @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*** spark-rapids-shim-json-lines +{"spark": "330"} +{"spark": "330db"} +{"spark": "331"} +{"spark": "332"} +{"spark": "332db"} +{"spark": "333"} +{"spark": "334"} +{"spark": "340"} +{"spark": "341"} +{"spark": "341db"} +{"spark": "342"} +{"spark": "343"} +{"spark": "344"} +{"spark": "350"} +{"spark": "350db143"} +{"spark": "351"} +{"spark": "352"} +{"spark": "353"} +{"spark": "354"} +{"spark": "355"} +{"spark": "356"} +{"spark": "357"} +{"spark": "358"} +{"spark": "400"} +{"spark": "401"} +{"spark": "402"} +spark-rapids-shim-json-lines ***/ +package com.nvidia.spark.rapids.shims + +object BloomFilterConstantsShims { + val BLOOM_FILTER_FORMAT_VERSION: Int = 1 // V1 serialized layout (12-byte header) +} diff --git a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/BloomFilterShims.scala b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/BloomFilterShims.scala index 7590a075e89..707d83fc608 100644 ---
a/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/BloomFilterShims.scala +++ b/sql-plugin/src/main/spark330/scala/com/nvidia/spark/rapids/shims/BloomFilterShims.scala @@ -46,6 +46,7 @@ spark-rapids-shim-json-lines ***/ package com.nvidia.spark.rapids.shims import com.nvidia.spark.rapids._ +import com.nvidia.spark.rapids.jni.BloomFilter import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate @@ -80,7 +81,9 @@ object BloomFilterShims { GpuBloomFilterAggregate( childExprs.head.convertToGpu(), a.estimatedNumItemsExpression.eval().asInstanceOf[Number].longValue, - a.numBitsExpression.eval().asInstanceOf[Number].longValue) + a.numBitsExpression.eval().asInstanceOf[Number].longValue, + BloomFilterConstantsShims.BLOOM_FILTER_FORMAT_VERSION, + BloomFilter.DEFAULT_SEED) } }) ).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap diff --git a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/aggregate/GpuBloomFilterAggregate.scala b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/aggregate/GpuBloomFilterAggregate.scala index 2e0cab83747..05ea339a6b1 100644 --- a/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/aggregate/GpuBloomFilterAggregate.scala +++ b/sql-plugin/src/main/spark330/scala/org/apache/spark/sql/rapids/aggregate/GpuBloomFilterAggregate.scala @@ -59,7 +59,9 @@ import org.apache.spark.sql.types.{BinaryType, DataType} case class GpuBloomFilterAggregate( child: Expression, estimatedNumItemsRequested: Long, - numBitsRequested: Long) extends GpuAggregateFunction { + numBitsRequested: Long, + version: Int = BloomFilter.VERSION_2, + seed: Int = BloomFilter.DEFAULT_SEED) extends GpuAggregateFunction { override def nullable: Boolean = true @@ -81,7 +83,8 @@ case class GpuBloomFilterAggregate( override val inputProjection: Seq[Expression] = Seq(child) - override val updateAggregates: Seq[CudfAggregate] = 
Seq(GpuBloomFilterUpdate(numHashes, numBits)) + override val updateAggregates: Seq[CudfAggregate] = + Seq(GpuBloomFilterUpdate(numHashes, numBits, version, seed)) override val mergeAggregates: Seq[CudfAggregate] = Seq(GpuBloomFilterMerge()) @@ -110,9 +113,13 @@ object GpuBloomFilterAggregate { } } -case class GpuBloomFilterUpdate(numHashes: Int, numBits: Long) extends CudfAggregate { +case class GpuBloomFilterUpdate( + numHashes: Int, + numBits: Long, + version: Int, + seed: Int) extends CudfAggregate { override val reductionAggregate: ColumnVector => Scalar = (col: ColumnVector) => { - closeOnExcept(BloomFilter.create(numHashes, numBits)) { bloomFilter => + closeOnExcept(BloomFilter.create(version, numHashes, numBits, seed)) { bloomFilter => BloomFilter.put(bloomFilter, col) bloomFilter } diff --git a/sql-plugin/src/main/spark411/scala/com/nvidia/spark/rapids/shims/BloomFilterConstantsShims.scala b/sql-plugin/src/main/spark411/scala/com/nvidia/spark/rapids/shims/BloomFilterConstantsShims.scala new file mode 100644 index 00000000000..66b12a915f6 --- /dev/null +++ b/sql-plugin/src/main/spark411/scala/com/nvidia/spark/rapids/shims/BloomFilterConstantsShims.scala @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/*** spark-rapids-shim-json-lines +{"spark": "411"} +spark-rapids-shim-json-lines ***/ +package com.nvidia.spark.rapids.shims + +object BloomFilterConstantsShims { + val BLOOM_FILTER_FORMAT_VERSION: Int = 2 // V2 serialized layout (16-byte header with hash seed) +} diff --git a/tests/src/test/spark330/scala/com/nvidia/spark/rapids/BloomFilterAggregateQuerySuite.scala b/tests/src/test/spark330/scala/com/nvidia/spark/rapids/BloomFilterAggregateQuerySuite.scala index b13d8266333..3936407e406 100644 --- a/tests/src/test/spark330/scala/com/nvidia/spark/rapids/BloomFilterAggregateQuerySuite.scala +++ b/tests/src/test/spark330/scala/com/nvidia/spark/rapids/BloomFilterAggregateQuerySuite.scala @@ -177,8 +177,9 @@ class BloomFilterAggregateQuerySuite extends SparkQueryCompareTestSuite { } } + // V1 literal: version=1, numHashes=5, numLongs=3, followed by 3 longs of bit data testSparkResultsAreEqual( - "might_contain with literal bloom filter buffer", + "might_contain with V1 literal bloom filter buffer", spark => spark.range(1, 1).asInstanceOf[DataFrame], conf=bloomFilterEnabledConf.clone()) { df => @@ -190,6 +191,22 @@ class BloomFilterAggregateQuerySuite extends SparkQueryCompareTestSuite { } } + // V2 literal: version=2, numHashes=5, seed=0, numLongs=3, followed by 3 longs of bit data + // NOTE(review): CPU Spark releases that only know the V1 format presumably reject this + // literal in BloomFilter.readFrom; confirm the test is skipped on shims before 4.1.1. + testSparkResultsAreEqual( + "might_contain with V2 literal bloom filter buffer", + spark => spark.range(1, 1).asInstanceOf[DataFrame], + conf=bloomFilterEnabledConf.clone()) { + df => + withExposedSqlFuncs(df.sparkSession) { spark => + spark.sql( + """SELECT might_contain( + |X'0000000200000005000000000000000343A2EC6EA8C117E2D3CDB767296B144FC5BFBCED9737F267', + |cast(201 as long))""".stripMargin) + } + } + testSparkResultsAreEqual( "might_contain with all NULL inputs", spark => spark.range(1, 1).asInstanceOf[DataFrame],