Skip to content

Commit 512a1da

Browse files
authored
fix: make BloomFilter intermediate buffer Spark-compatible (#4390)
1 parent 6f61894 commit 512a1da

7 files changed

Lines changed: 392 additions & 47 deletions

File tree

native/spark-expr/src/bloom_filter/bloom_filter_agg.rs

Lines changed: 47 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@ use crate::bloom_filter::spark_bloom_filter::{SparkBloomFilter, SparkBloomFilter
2525
use arrow::array::ArrayRef;
2626
use arrow::array::BinaryArray;
2727
use datafusion::common::{downcast_value, ScalarValue};
28-
use datafusion::error::Result;
28+
use datafusion::error::{DataFusionError, Result};
2929
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
3030
use datafusion::logical_expr::{AggregateUDFImpl, Signature};
3131
use datafusion::physical_expr::expressions::Literal;
@@ -141,15 +141,30 @@ impl Accumulator for SparkBloomFilter {
141141
ScalarValue::Utf8(Some(value)) => {
142142
self.put_binary(value.as_bytes());
143143
}
144-
_ => {
145-
unreachable!()
144+
// Spark's BloomFilterAggregate.update ignores null inputs.
145+
ScalarValue::Int8(None)
146+
| ScalarValue::Int16(None)
147+
| ScalarValue::Int32(None)
148+
| ScalarValue::Int64(None)
149+
| ScalarValue::Utf8(None) => {}
150+
other => {
151+
return Err(DataFusionError::Internal(format!(
152+
"bloom_filter_agg received an unsupported input type: {other:?}"
153+
)));
146154
}
147155
}
148156
Ok(())
149157
})
150158
}
151159

152160
fn evaluate(&mut self) -> Result<ScalarValue> {
161+
// Spark's BloomFilterAggregate.eval returns NULL when no bit is set,
162+
// i.e. the aggregate saw no non-null input. Mirror that here so an
163+
// empty-input bloom_filter_agg yields NULL rather than a serialized
164+
// empty filter.
165+
if self.cardinality() == 0 {
166+
return Ok(ScalarValue::Binary(None));
167+
}
153168
Ok(ScalarValue::Binary(Some(self.spark_serialization())))
154169
}
155170

@@ -173,7 +188,34 @@ impl Accumulator for SparkBloomFilter {
173188
);
174189
assert_eq!(states[0].len(), 1);
175190
let state_sv = downcast_value!(states[0], BinaryArray);
176-
self.merge_filter(state_sv.value_data());
177-
Ok(())
191+
self.merge_filter(state_sv.value_data())
192+
}
193+
}
194+
195+
#[cfg(test)]
196+
mod tests {
197+
use super::*;
198+
199+
/// Spark's BloomFilterAggregate.eval returns NULL when the filter saw no
200+
/// non-null input (cardinality 0); an untouched accumulator must match.
201+
#[test]
202+
fn evaluate_empty_filter_yields_null() {
203+
let num_bits = 1024;
204+
let num_hash = spark_bloom_filter::optimal_num_hash_functions(100, num_bits);
205+
let mut acc = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0);
206+
assert_eq!(acc.evaluate().unwrap(), ScalarValue::Binary(None));
207+
}
208+
209+
/// A filter with at least one set bit serializes to a non-null binary.
210+
#[test]
211+
fn evaluate_non_empty_filter_yields_binary() {
212+
let num_bits = 1024;
213+
let num_hash = spark_bloom_filter::optimal_num_hash_functions(100, num_bits);
214+
let mut acc = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0);
215+
acc.put_long(42);
216+
assert!(matches!(
217+
acc.evaluate().unwrap(),
218+
ScalarValue::Binary(Some(_))
219+
));
178220
}
179221
}

native/spark-expr/src/bloom_filter/spark_bit_array.rs

Lines changed: 44 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -15,9 +15,6 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::datatypes::ToByteSlice;
19-
use std::iter::zip;
20-
2118
/// A simple bit array implementation that simulates the behavior of Spark's BitArray which is
2219
/// used in the BloomFilter implementation. Some methods are not implemented as they are not
2320
/// required for the current use case.
@@ -61,41 +58,28 @@ impl SparkBitArray {
6158
self.word_size() as u64 * 64
6259
}
6360

64-
pub fn byte_size(&self) -> usize {
65-
self.word_size() * 8
66-
}
67-
6861
pub fn word_size(&self) -> usize {
6962
self.data.len()
7063
}
7164

72-
#[allow(dead_code)] // this is only called from tests
73-
pub fn cardinality(&self) -> usize {
74-
self.bit_count
75-
}
76-
77-
pub fn to_bytes(&self) -> Vec<u8> {
78-
Vec::from(self.data.to_byte_slice())
79-
}
80-
8165
pub fn data(&self) -> Vec<u64> {
8266
self.data.clone()
8367
}
8468

85-
// Combines SparkBitArrays, however other is a &[u8] because we anticipate to come from an
86-
// Arrow ScalarValue::Binary which is a byte vector underneath, rather than a word vector.
87-
pub fn merge_bits(&mut self, other: &[u8]) {
88-
assert_eq!(self.byte_size(), other.len());
69+
/// Number of set bits in the array. Mirrors Spark's `BitArray.cardinality()`.
70+
pub fn cardinality(&self) -> usize {
71+
self.bit_count
72+
}
73+
74+
/// OR-merge `incoming` (big-endian `u64` words, one per word in `self`) into
75+
/// `self.data` in place and refresh `bit_count` in the same pass. The caller
76+
/// is responsible for ensuring `incoming.len() == self.word_size() * 8`.
77+
pub fn merge_be_words(&mut self, incoming: &[u8]) {
78+
debug_assert_eq!(self.data.len() * 8, incoming.len());
8979
let mut bit_count: usize = 0;
90-
// For each word, merge the bits into self, and accumulate a new bit_count.
91-
for i in zip(
92-
self.data.iter_mut(),
93-
other
94-
.chunks(8)
95-
.map(|chunk| u64::from_ne_bytes(chunk.try_into().unwrap())),
96-
) {
97-
*i.0 |= i.1;
98-
bit_count += i.0.count_ones() as usize;
80+
for (word, chunk) in self.data.iter_mut().zip(incoming.chunks_exact(8)) {
81+
*word |= u64::from_be_bytes(chunk.try_into().unwrap());
82+
bit_count += word.count_ones() as usize;
9983
}
10084
self.bit_count = bit_count;
10185
}
@@ -108,6 +92,37 @@ pub fn num_words(num_bits: usize) -> usize {
10892
#[cfg(test)]
10993
mod test {
11094
use super::*;
95+
use arrow::datatypes::ToByteSlice;
96+
use std::iter::zip;
97+
98+
impl SparkBitArray {
99+
fn byte_size(&self) -> usize {
100+
self.word_size() * 8
101+
}
102+
103+
fn to_bytes(&self) -> Vec<u8> {
104+
Vec::from(self.data.to_byte_slice())
105+
}
106+
107+
/// Combines SparkBitArrays. `other` is a &[u8] because we anticipate it will come from
108+
/// an Arrow ScalarValue::Binary, which is a byte vector underneath rather than a word
109+
/// vector.
110+
fn merge_bits(&mut self, other: &[u8]) {
111+
assert_eq!(self.byte_size(), other.len());
112+
let mut bit_count: usize = 0;
113+
// For each word, merge the bits into self, and accumulate a new bit_count.
114+
for i in zip(
115+
self.data.iter_mut(),
116+
other
117+
.chunks(8)
118+
.map(|chunk| u64::from_ne_bytes(chunk.try_into().unwrap())),
119+
) {
120+
*i.0 |= i.1;
121+
bit_count += i.0.count_ones() as usize;
122+
}
123+
self.bit_count = bit_count;
124+
}
125+
}
111126

112127
#[test]
113128
fn test_spark_bit_array() {

native/spark-expr/src/bloom_filter/spark_bloom_filter.rs

Lines changed: 157 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717

1818
use arrow::array::{ArrowNativeTypeOp, BooleanArray, Int64Array};
1919
use arrow::datatypes::ToByteSlice;
20+
use datafusion::common::{DataFusionError, Result as DFResult};
2021
use std::cmp;
2122

2223
use crate::bloom_filter::spark_bit_array;
@@ -271,17 +272,72 @@ impl SparkBloomFilter {
271272
.collect()
272273
}
273274

275+
/// Number of set bits in the underlying bit array. Mirrors Spark's
276+
/// `BloomFilter.cardinality()`: a filter that has seen no non-null input
277+
/// has cardinality 0.
278+
pub fn cardinality(&self) -> usize {
279+
self.bits.cardinality()
280+
}
281+
274282
pub fn state_as_bytes(&self) -> Vec<u8> {
275-
self.bits.to_bytes()
283+
self.spark_serialization()
276284
}
277285

278-
pub fn merge_filter(&mut self, other: &[u8]) {
279-
assert_eq!(
280-
other.len(),
281-
self.bits.byte_size(),
282-
"Cannot merge SparkBloomFilters with different lengths."
283-
);
284-
self.bits.merge_bits(other);
286+
pub fn merge_filter(&mut self, other: &[u8]) -> DFResult<()> {
287+
let mut offset = 0;
288+
289+
let version_int = read_num_be_bytes!(i32, 4, other[offset..]);
290+
offset += 4;
291+
if version_int != self.version.to_int() {
292+
return Err(DataFusionError::Internal(format!(
293+
"BloomFilter merge: version mismatch (got {}, expected {})",
294+
version_int,
295+
self.version.to_int(),
296+
)));
297+
}
298+
299+
let num_hash = read_num_be_bytes!(i32, 4, other[offset..]) as u32;
300+
offset += 4;
301+
if num_hash != self.num_hash_functions {
302+
return Err(DataFusionError::Internal(format!(
303+
"BloomFilter merge: num_hash_functions mismatch (got {}, expected {})",
304+
num_hash, self.num_hash_functions,
305+
)));
306+
}
307+
308+
if let SparkBloomFilterVersion::V2 = self.version {
309+
let seed = read_num_be_bytes!(i32, 4, other[offset..]);
310+
offset += 4;
311+
if seed != self.seed {
312+
return Err(DataFusionError::Internal(format!(
313+
"BloomFilter merge: seed mismatch (got {}, expected {})",
314+
seed, self.seed,
315+
)));
316+
}
317+
}
318+
319+
let num_words = read_num_be_bytes!(i32, 4, other[offset..]) as usize;
320+
offset += 4;
321+
if num_words != self.bits.word_size() {
322+
return Err(DataFusionError::Internal(format!(
323+
"BloomFilter merge: num_words mismatch (got {}, expected {})",
324+
num_words,
325+
self.bits.word_size(),
326+
)));
327+
}
328+
329+
let expected_bytes = num_words * 8;
330+
if other.len() - offset < expected_bytes {
331+
return Err(DataFusionError::Internal(format!(
332+
"BloomFilter merge: truncated bit array (got {} bytes, expected {})",
333+
other.len() - offset,
334+
expected_bytes,
335+
)));
336+
}
337+
338+
self.bits
339+
.merge_be_words(&other[offset..offset + expected_bytes]);
340+
Ok(())
285341
}
286342
}
287343

@@ -396,4 +452,97 @@ mod tests {
396452
buf.extend_from_slice(&[0u8; 32]); // 4 words * 8 bytes
397453
let _ = SparkBloomFilter::from(buf.as_slice());
398454
}
455+
456+
/// Two V1 filters with identical parameters. Populate the first, serialize via
457+
/// state_as_bytes, merge into the empty second, and verify the second contains
458+
/// everything the first did. Exercises the aggregator state → merge_batch path.
459+
#[test]
460+
fn state_round_trip_v1_merge() {
461+
let num_bits = 1024;
462+
let num_hash = optimal_num_hash_functions(100, num_bits);
463+
let mut a = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0);
464+
for v in [1_i64, 7, 42, 99, -3, i64::MAX] {
465+
a.put_long(v);
466+
}
467+
468+
let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0);
469+
b.merge_filter(&a.state_as_bytes()).unwrap();
470+
471+
for v in [1_i64, 7, 42, 99, -3, i64::MAX] {
472+
assert!(b.might_contain_long(v), "missing {v} after merge");
473+
}
474+
}
475+
476+
/// V2 default seed (0) round-trip through state_as_bytes → merge_filter.
477+
#[test]
478+
fn state_round_trip_v2_default_seed() {
479+
let num_bits = 1024;
480+
let num_hash = optimal_num_hash_functions(100, num_bits);
481+
let mut a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 0);
482+
for v in [11_i64, 222, 3333] {
483+
a.put_long(v);
484+
}
485+
486+
let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 0);
487+
b.merge_filter(&a.state_as_bytes()).unwrap();
488+
489+
for v in [11_i64, 222, 3333] {
490+
assert!(b.might_contain_long(v));
491+
}
492+
}
493+
494+
/// V2 non-zero seed round-trip; verifies the seed field is parsed and that
495+
/// both filters use the same seed-dependent hash scattering.
496+
#[test]
497+
fn state_round_trip_v2_nonzero_seed() {
498+
let num_bits = 1024;
499+
let num_hash = optimal_num_hash_functions(100, num_bits);
500+
let seed = 0x5eed_5eed_u32 as i32;
501+
let mut a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, seed);
502+
a.put_long(123);
503+
504+
let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, seed);
505+
b.merge_filter(&a.state_as_bytes()).unwrap();
506+
507+
assert!(b.might_contain_long(123));
508+
}
509+
510+
fn assert_merge_err_contains(filter: &mut SparkBloomFilter, buf: &[u8], needle: &str) {
511+
let err = filter.merge_filter(buf).unwrap_err().to_string();
512+
assert!(err.contains(needle), "expected `{needle}` in error: {err}");
513+
}
514+
515+
#[test]
516+
fn merge_rejects_version_mismatch() {
517+
let num_bits = 1024;
518+
let num_hash = optimal_num_hash_functions(100, num_bits);
519+
let a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 0);
520+
let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0);
521+
assert_merge_err_contains(&mut b, &a.state_as_bytes(), "version mismatch");
522+
}
523+
524+
#[test]
525+
fn merge_rejects_num_hash_mismatch() {
526+
let num_bits = 1024;
527+
let a = SparkBloomFilter::new(SparkBloomFilterVersion::V1, 5, num_bits, 0);
528+
let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, 7, num_bits, 0);
529+
assert_merge_err_contains(&mut b, &a.state_as_bytes(), "num_hash_functions mismatch");
530+
}
531+
532+
#[test]
533+
fn merge_rejects_seed_mismatch_v2() {
534+
let num_bits = 1024;
535+
let num_hash = optimal_num_hash_functions(100, num_bits);
536+
let a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 1);
537+
let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 2);
538+
assert_merge_err_contains(&mut b, &a.state_as_bytes(), "seed mismatch");
539+
}
540+
541+
#[test]
542+
fn merge_rejects_num_words_mismatch() {
543+
let num_hash = 5;
544+
let a = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, 512, 0);
545+
let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, 1024, 0);
546+
assert_merge_err_contains(&mut b, &a.state_as_bytes(), "num_words mismatch");
547+
}
399548
}

0 commit comments

Comments
 (0)