|
17 | 17 |
|
18 | 18 | use arrow::array::{ArrowNativeTypeOp, BooleanArray, Int64Array}; |
19 | 19 | use arrow::datatypes::ToByteSlice; |
| 20 | +use datafusion::common::{DataFusionError, Result as DFResult}; |
20 | 21 | use std::cmp; |
21 | 22 |
|
22 | 23 | use crate::bloom_filter::spark_bit_array; |
@@ -271,17 +272,72 @@ impl SparkBloomFilter { |
271 | 272 | .collect() |
272 | 273 | } |
273 | 274 |
|
| 275 | + /// Number of set bits in the underlying bit array. Mirrors Spark's |
| 276 | + /// `BloomFilter.cardinality()`: a filter that has seen no non-null input |
| 277 | + /// has cardinality 0. |
| 278 | + pub fn cardinality(&self) -> usize { |
| 279 | + self.bits.cardinality() |
| 280 | + } |
| 281 | + |
274 | 282 | pub fn state_as_bytes(&self) -> Vec<u8> { |
275 | | - self.bits.to_bytes() |
| 283 | + self.spark_serialization() |
276 | 284 | } |
277 | 285 |
|
278 | | - pub fn merge_filter(&mut self, other: &[u8]) { |
279 | | - assert_eq!( |
280 | | - other.len(), |
281 | | - self.bits.byte_size(), |
282 | | - "Cannot merge SparkBloomFilters with different lengths." |
283 | | - ); |
284 | | - self.bits.merge_bits(other); |
| 286 | + pub fn merge_filter(&mut self, other: &[u8]) -> DFResult<()> { |
| 287 | + let mut offset = 0; |
| 288 | + |
| 289 | + let version_int = read_num_be_bytes!(i32, 4, other[offset..]); |
| 290 | + offset += 4; |
| 291 | + if version_int != self.version.to_int() { |
| 292 | + return Err(DataFusionError::Internal(format!( |
| 293 | + "BloomFilter merge: version mismatch (got {}, expected {})", |
| 294 | + version_int, |
| 295 | + self.version.to_int(), |
| 296 | + ))); |
| 297 | + } |
| 298 | + |
| 299 | + let num_hash = read_num_be_bytes!(i32, 4, other[offset..]) as u32; |
| 300 | + offset += 4; |
| 301 | + if num_hash != self.num_hash_functions { |
| 302 | + return Err(DataFusionError::Internal(format!( |
| 303 | + "BloomFilter merge: num_hash_functions mismatch (got {}, expected {})", |
| 304 | + num_hash, self.num_hash_functions, |
| 305 | + ))); |
| 306 | + } |
| 307 | + |
| 308 | + if let SparkBloomFilterVersion::V2 = self.version { |
| 309 | + let seed = read_num_be_bytes!(i32, 4, other[offset..]); |
| 310 | + offset += 4; |
| 311 | + if seed != self.seed { |
| 312 | + return Err(DataFusionError::Internal(format!( |
| 313 | + "BloomFilter merge: seed mismatch (got {}, expected {})", |
| 314 | + seed, self.seed, |
| 315 | + ))); |
| 316 | + } |
| 317 | + } |
| 318 | + |
| 319 | + let num_words = read_num_be_bytes!(i32, 4, other[offset..]) as usize; |
| 320 | + offset += 4; |
| 321 | + if num_words != self.bits.word_size() { |
| 322 | + return Err(DataFusionError::Internal(format!( |
| 323 | + "BloomFilter merge: num_words mismatch (got {}, expected {})", |
| 324 | + num_words, |
| 325 | + self.bits.word_size(), |
| 326 | + ))); |
| 327 | + } |
| 328 | + |
| 329 | + let expected_bytes = num_words * 8; |
| 330 | + if other.len() - offset < expected_bytes { |
| 331 | + return Err(DataFusionError::Internal(format!( |
| 332 | + "BloomFilter merge: truncated bit array (got {} bytes, expected {})", |
| 333 | + other.len() - offset, |
| 334 | + expected_bytes, |
| 335 | + ))); |
| 336 | + } |
| 337 | + |
| 338 | + self.bits |
| 339 | + .merge_be_words(&other[offset..offset + expected_bytes]); |
| 340 | + Ok(()) |
285 | 341 | } |
286 | 342 | } |
287 | 343 |
|
@@ -396,4 +452,97 @@ mod tests { |
396 | 452 | buf.extend_from_slice(&[0u8; 32]); // 4 words * 8 bytes |
397 | 453 | let _ = SparkBloomFilter::from(buf.as_slice()); |
398 | 454 | } |
| 455 | + |
| 456 | + /// Two V1 filters with identical parameters. Populate the first, serialize via |
| 457 | + /// state_as_bytes, merge into the empty second, and verify the second contains |
| 458 | + /// everything the first did. Exercises the aggregator state → merge_batch path. |
| 459 | + #[test] |
| 460 | + fn state_round_trip_v1_merge() { |
| 461 | + let num_bits = 1024; |
| 462 | + let num_hash = optimal_num_hash_functions(100, num_bits); |
| 463 | + let mut a = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0); |
| 464 | + for v in [1_i64, 7, 42, 99, -3, i64::MAX] { |
| 465 | + a.put_long(v); |
| 466 | + } |
| 467 | + |
| 468 | + let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0); |
| 469 | + b.merge_filter(&a.state_as_bytes()).unwrap(); |
| 470 | + |
| 471 | + for v in [1_i64, 7, 42, 99, -3, i64::MAX] { |
| 472 | + assert!(b.might_contain_long(v), "missing {v} after merge"); |
| 473 | + } |
| 474 | + } |
| 475 | + |
| 476 | + /// V2 default seed (0) round-trip through state_as_bytes → merge_filter. |
| 477 | + #[test] |
| 478 | + fn state_round_trip_v2_default_seed() { |
| 479 | + let num_bits = 1024; |
| 480 | + let num_hash = optimal_num_hash_functions(100, num_bits); |
| 481 | + let mut a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 0); |
| 482 | + for v in [11_i64, 222, 3333] { |
| 483 | + a.put_long(v); |
| 484 | + } |
| 485 | + |
| 486 | + let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 0); |
| 487 | + b.merge_filter(&a.state_as_bytes()).unwrap(); |
| 488 | + |
| 489 | + for v in [11_i64, 222, 3333] { |
| 490 | + assert!(b.might_contain_long(v)); |
| 491 | + } |
| 492 | + } |
| 493 | + |
| 494 | + /// V2 non-zero seed round-trip; verifies the seed field is parsed and that |
| 495 | + /// both filters use the same seed-dependent hash scattering. |
| 496 | + #[test] |
| 497 | + fn state_round_trip_v2_nonzero_seed() { |
| 498 | + let num_bits = 1024; |
| 499 | + let num_hash = optimal_num_hash_functions(100, num_bits); |
| 500 | + let seed = 0x5eed_5eed_u32 as i32; |
| 501 | + let mut a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, seed); |
| 502 | + a.put_long(123); |
| 503 | + |
| 504 | + let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, seed); |
| 505 | + b.merge_filter(&a.state_as_bytes()).unwrap(); |
| 506 | + |
| 507 | + assert!(b.might_contain_long(123)); |
| 508 | + } |
| 509 | + |
| 510 | + fn assert_merge_err_contains(filter: &mut SparkBloomFilter, buf: &[u8], needle: &str) { |
| 511 | + let err = filter.merge_filter(buf).unwrap_err().to_string(); |
| 512 | + assert!(err.contains(needle), "expected `{needle}` in error: {err}"); |
| 513 | + } |
| 514 | + |
| 515 | + #[test] |
| 516 | + fn merge_rejects_version_mismatch() { |
| 517 | + let num_bits = 1024; |
| 518 | + let num_hash = optimal_num_hash_functions(100, num_bits); |
| 519 | + let a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 0); |
| 520 | + let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, num_bits, 0); |
| 521 | + assert_merge_err_contains(&mut b, &a.state_as_bytes(), "version mismatch"); |
| 522 | + } |
| 523 | + |
| 524 | + #[test] |
| 525 | + fn merge_rejects_num_hash_mismatch() { |
| 526 | + let num_bits = 1024; |
| 527 | + let a = SparkBloomFilter::new(SparkBloomFilterVersion::V1, 5, num_bits, 0); |
| 528 | + let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, 7, num_bits, 0); |
| 529 | + assert_merge_err_contains(&mut b, &a.state_as_bytes(), "num_hash_functions mismatch"); |
| 530 | + } |
| 531 | + |
| 532 | + #[test] |
| 533 | + fn merge_rejects_seed_mismatch_v2() { |
| 534 | + let num_bits = 1024; |
| 535 | + let num_hash = optimal_num_hash_functions(100, num_bits); |
| 536 | + let a = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 1); |
| 537 | + let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V2, num_hash, num_bits, 2); |
| 538 | + assert_merge_err_contains(&mut b, &a.state_as_bytes(), "seed mismatch"); |
| 539 | + } |
| 540 | + |
| 541 | + #[test] |
| 542 | + fn merge_rejects_num_words_mismatch() { |
| 543 | + let num_hash = 5; |
| 544 | + let a = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, 512, 0); |
| 545 | + let mut b = SparkBloomFilter::new(SparkBloomFilterVersion::V1, num_hash, 1024, 0); |
| 546 | + assert_merge_err_contains(&mut b, &a.state_as_bytes(), "num_words mismatch"); |
| 547 | + } |
399 | 548 | } |
0 commit comments