Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions src/main/cpp/benchmarks/bloom_filter.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
* Copyright (c) 2023-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,17 +22,18 @@
#include <hash/hash.hpp>
#include <nvbench/nvbench.cuh>

static void bloom_filter_put(nvbench::state& state)
namespace {

void bloom_filter_put_impl(nvbench::state& state, int version)
{
constexpr int num_rows = 150'000'000;
constexpr int num_hashes = 3;

// create the bloom filter
cudf::size_type const bloom_filter_bytes = state.get_int64("bloom_filter_bytes");
cudf::size_type const bloom_filter_longs = bloom_filter_bytes / sizeof(int64_t);
auto bloom_filter = spark_rapids_jni::bloom_filter_create(num_hashes, bloom_filter_longs);
auto bloom_filter =
spark_rapids_jni::bloom_filter_create(version, num_hashes, bloom_filter_longs);

// create a column of hashed values
data_profile_builder builder;
builder.no_validity();
auto const src = create_random_table({{cudf::type_id::INT64}}, row_count{num_rows}, builder);
Expand All @@ -41,7 +42,7 @@ static void bloom_filter_put(nvbench::state& state)
auto const stream = cudf::get_default_stream();
state.set_cuda_stream(nvbench::make_cuda_stream_view(stream.value()));
state.exec(nvbench::exec_tag::timer | nvbench::exec_tag::sync,
[&](nvbench::launch& launch, auto& timer) {
[&](nvbench::launch&, auto& timer) {
timer.start();
spark_rapids_jni::bloom_filter_put(*bloom_filter, *input);
stream.synchronize();
Expand All @@ -57,7 +58,24 @@ static void bloom_filter_put(nvbench::state& state)
state.add_element_count(static_cast<double>(bytes_written) / time, "Write bytes/sec");
}

NVBENCH_BENCH(bloom_filter_put)
.set_name("Bloom Filter Put")
void bloom_filter_put_v1(nvbench::state& state)
{
bloom_filter_put_impl(state, spark_rapids_jni::bloom_filter_version_1);
}

void bloom_filter_put_v2(nvbench::state& state)
{
bloom_filter_put_impl(state, spark_rapids_jni::bloom_filter_version_2);
}

} // namespace

NVBENCH_BENCH(bloom_filter_put_v1)
.set_name("Bloom Filter Put V1")
.add_int64_axis("bloom_filter_bytes",
{512 * 1024, 1024 * 1024, 2 * 1024 * 1024, 4 * 1024 * 1024, 8 * 1024 * 1024});

NVBENCH_BENCH(bloom_filter_put_v2)
.set_name("Bloom Filter Put V2")
.add_int64_axis("bloom_filter_bytes",
{512 * 1024, 1024 * 1024, 2 * 1024 * 1024, 4 * 1024 * 1024, 8 * 1024 * 1024});
24 changes: 20 additions & 4 deletions src/main/cpp/src/BloomFilterJni.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2025, NVIDIA CORPORATION.
* Copyright (c) 2023-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,17 +20,33 @@
#include "jni_utils.hpp"
#include "utilities.hpp"

#include <limits>

extern "C" {

JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_BloomFilter_creategpu(
JNIEnv* env, jclass, jint numHashes, jlong bloomFilterBits)
JNIEnv* env, jclass, jint version, jint numHashes, jlong bloomFilterBits, jint seed)
{
JNI_TRY
{
cudf::jni::auto_set_device(env);

int bloom_filter_longs = static_cast<int>((bloomFilterBits + 63) / 64);
auto bloom_filter = spark_rapids_jni::bloom_filter_create(numHashes, bloom_filter_longs);
// Per the Spark implementation, according to the BitArray class,
// https://github.com/apache/spark/blob/5075ea6a85f3f1689766cf08a7d5b2ce500be1fb/common/sketch/src/main/java/org/apache/spark/util/sketch/BitArray.java#L34
// the number of longs representing the bit array can only be Integer.MAX_VALUE, at the most.
// (This is presumably because the BitArray is indexed with an int32_t.)
// This implies that the maximum supported bloom filter bit count is Integer.MAX_VALUE * 64.

JNI_ARG_CHECK(
env,
bloomFilterBits > 0 &&
bloomFilterBits <= static_cast<int64_t>(std::numeric_limits<int32_t>::max()) * 64,
"bloom filter bit count must be positive and less than or equal to the maximum supported "
"size",
0);
auto const bloom_filter_longs = static_cast<int32_t>((bloomFilterBits + 63) / 64);
auto bloom_filter =
spark_rapids_jni::bloom_filter_create(version, numHashes, bloom_filter_longs, seed);
return reinterpret_cast<jlong>(bloom_filter.release());
}
JNI_CATCH(env, 0);
Expand Down
Loading
Loading