Skip to content

Commit 98901dd

Browse files
authored
perf(inline-agg): add BoolAnd and BoolOr accumulator types (#6984)
## Summary

Implements BoolAnd and BoolOr accumulators from #6585 (item 7) for the inline grouped aggregation path. Each accumulator holds a per-group `Option<bool>` state; the first non-null value seeds the state, and subsequent non-null values combine via `&&` (BoolAnd) or `||` (BoolOr). Output dtype is Boolean. Grouping semantics and final query results are unchanged.

## Why

`AggExpr::BoolAnd` and `AggExpr::BoolOr` already exist in the DSL and are wired in the fallback path (`src/daft-recordbatch/src/lib.rs` → `Series::bool_and(groups)` / `Series::bool_or(groups)`), but currently fall back to `make_groups + eval_agg_expression` even when the rest of the query qualifies for the inline path. Adding them to the inline accumulator framework completes inline coverage of the standard reducer-style aggregates (Count / Sum / Min / Max / Product / BoolAnd / BoolOr) that all share the same `Vec<Option<T>>` per-group state shape.

## Changes Made

- `src/daft-recordbatch/src/ops/inline_agg.rs`:
  - New `define_bool_and_accum!` and `define_bool_or_accum!` macros (kept separate per the Sum/Product precedent — these are semantically distinct ops with different identity and absorbing elements).
  - `define_agg_accumulator_enum!` extended with `BoolAnd` and `BoolOr` variants.
  - `try_create_accumulator` dispatches `AggExpr::BoolAnd(expr)` and `AggExpr::BoolOr(expr)` on `DataType::Boolean`.
  - `can_inline_agg` adds a separate Boolean-only dtype arm; the existing numeric arm for Sum/Min/Max/Product is unchanged.
  - 5 new tests + 4 helpers.

**Implementation note:** `BooleanArray::values()` doesn't expose a `&[bool]` slice because Arrow stores bools bit-packed. The null-free tight loop uses `self.source.to_bitmap()` + `bitmap.value(row_idx)` instead of the `.zip(values().iter())` pattern that Sum/Product use over primitive slices. It is functionally equivalent, just a different access pattern forced by the storage layout.
## Behavior

- Queries with `BoolAnd` / `BoolOr` over Boolean columns now take the inline path instead of falling back to `make_groups + eval_agg_expression`.
- Output values are identical to the fallback path (verified by inline-vs-fallback tests).
- All other agg types and dispatch paths are unchanged.
- **Not implemented (deferred):** short-circuit optimization (stop scanning a group once BoolAnd hits `false` / BoolOr hits `true`). Adding a per-row branch to the hot loop would regress non-short-circuiting groups; Sum/Min/Max have analogous opportunities and intentionally don't take them. Revisit if benchmarks show it matters.

## Test Plan

- `cargo test -p daft-recordbatch --release inline_agg` — 37 passed (32 pre-existing + 5 new).
- `cargo fmt -p daft-recordbatch --check` — clean.
- `cargo clippy -p daft-recordbatch --release --features python` — clean, no `#[allow]`s added.

New test cases:

- `test_inline_bool_and_matches_fallback` — Utf8 keys + Boolean vals (no-null tight loop).
- `test_inline_bool_or_matches_fallback` — Utf8 keys + Boolean vals (no-null tight loop, OR semantics).
- `test_inline_int_key_bool_and_matches_fallback` — Int64 keys + Boolean vals (FNV int-key fast path).
- `test_inline_bool_and_with_nulls_matches_fallback` — Boolean vals with `None` interspersed (exercises null-value branch).
- `test_inline_all_null_bool_or_matches_fallback` — all-null vals (exercises empty `Option<bool>` finalize path).

## Related Issues

- Part of #6585 (item 7).
1 parent 272cd83 commit 98901dd

1 file changed

Lines changed: 327 additions & 1 deletion

File tree

src/daft-recordbatch/src/ops/inline_agg.rs

Lines changed: 327 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,135 @@ define_minmax_accum!(MaxAccumF64, Float64Type, f64, |a, b| if a.gt(&b) {
368368
b
369369
});
370370

371+
// Generates a BoolAnd accumulator type named `$name`.
//
// The generated struct keeps one `Option<bool>` slot per group. A slot stays
// `None` until its group sees a first non-null source value; every later
// non-null value is folded into the slot with `&&`.
macro_rules! define_bool_and_accum {
    ($name:ident) => {
        struct $name {
            /// Running state, indexed by group id.
            accumulators: Vec<Option<bool>>,
            /// Boolean column the aggregate reads its values from.
            source: BooleanArray,
        }

        impl $name {
            /// Wraps `source`; no group slots are allocated yet.
            fn new(source: BooleanArray) -> Self {
                let accumulators = Vec::new();
                Self {
                    accumulators,
                    source,
                }
            }

            /// Resizes the slot vector to `n` entries, filling any new slots with `None`.
            fn init_groups(&mut self, n: usize) {
                self.accumulators.resize(n, None);
            }

            /// Folds source rows into their groups; `group_ids[i]` selects the
            /// slot that receives source row `i`.
            fn update_batch(&mut self, group_ids: &[u32]) {
                // Seeds an empty slot with `incoming`, otherwise ANDs it in.
                let fold = |slot: &mut Option<bool>, incoming: bool| {
                    *slot = Some(slot.map_or(incoming, |prev| prev && incoming));
                };

                let slots = &mut self.accumulators;
                if self.source.null_count() != 0 {
                    // Source has nulls: null rows leave their slot untouched.
                    for (row, &gid) in group_ids.iter().enumerate() {
                        if let Some(incoming) = self.source.get(row) {
                            fold(&mut slots[gid as usize], incoming);
                        }
                    }
                } else {
                    // No nulls: read values straight from the bitmap.
                    let bits = self.source.to_bitmap();
                    for (row, &gid) in group_ids.iter().enumerate() {
                        let incoming = bits.value(row);
                        fold(&mut slots[gid as usize], incoming);
                    }
                }
            }

            /// Consumes the accumulator and builds the output column called `name`.
            ///
            /// When every slot holds a value the column is built from plain
            /// `bool`s; otherwise the `Option<bool>` slots are passed through
            /// as-is so empty groups remain `None`.
            fn finalize(self, name: &str) -> DaftResult<Series> {
                let slots = self.accumulators;
                let series = if slots.iter().all(Option::is_some) {
                    BooleanArray::from_values(name, slots.into_iter().map(|slot| slot.unwrap()))
                        .into_series()
                } else {
                    BooleanArray::from_iter(name, slots.into_iter()).into_series()
                };
                Ok(series)
            }
        }
    };
}
433+
434+
// Generates a BoolOr accumulator type named `$name`.
//
// The generated struct keeps one `Option<bool>` slot per group. A slot stays
// `None` until its group sees a first non-null source value; every later
// non-null value is folded into the slot with `||`.
macro_rules! define_bool_or_accum {
    ($name:ident) => {
        struct $name {
            /// Running state, indexed by group id.
            accumulators: Vec<Option<bool>>,
            /// Boolean column the aggregate reads its values from.
            source: BooleanArray,
        }

        impl $name {
            /// Wraps `source`; no group slots are allocated yet.
            fn new(source: BooleanArray) -> Self {
                let accumulators = Vec::new();
                Self {
                    accumulators,
                    source,
                }
            }

            /// Resizes the slot vector to `n` entries, filling any new slots with `None`.
            fn init_groups(&mut self, n: usize) {
                self.accumulators.resize(n, None);
            }

            /// Folds source rows into their groups; `group_ids[i]` selects the
            /// slot that receives source row `i`.
            fn update_batch(&mut self, group_ids: &[u32]) {
                // Seeds an empty slot with `incoming`, otherwise ORs it in.
                let fold = |slot: &mut Option<bool>, incoming: bool| {
                    *slot = Some(slot.map_or(incoming, |prev| prev || incoming));
                };

                let slots = &mut self.accumulators;
                if self.source.null_count() != 0 {
                    // Source has nulls: null rows leave their slot untouched.
                    for (row, &gid) in group_ids.iter().enumerate() {
                        if let Some(incoming) = self.source.get(row) {
                            fold(&mut slots[gid as usize], incoming);
                        }
                    }
                } else {
                    // No nulls: read values straight from the bitmap.
                    let bits = self.source.to_bitmap();
                    for (row, &gid) in group_ids.iter().enumerate() {
                        let incoming = bits.value(row);
                        fold(&mut slots[gid as usize], incoming);
                    }
                }
            }

            /// Consumes the accumulator and builds the output column called `name`.
            ///
            /// When every slot holds a value the column is built from plain
            /// `bool`s; otherwise the `Option<bool>` slots are passed through
            /// as-is so empty groups remain `None`.
            fn finalize(self, name: &str) -> DaftResult<Series> {
                let slots = self.accumulators;
                let series = if slots.iter().all(Option::is_some) {
                    BooleanArray::from_values(name, slots.into_iter().map(|slot| slot.unwrap()))
                        .into_series()
                } else {
                    BooleanArray::from_iter(name, slots.into_iter()).into_series()
                };
                Ok(series)
            }
        }
    };
}
496+
497+
define_bool_and_accum!(BoolAndAccum);
498+
define_bool_or_accum!(BoolOrAccum);
499+
371500
// ---------------------------------------------------------------------------
372501
// AggAccumulator enum — eliminates vtable dispatch in the hot loop
373502
// ---------------------------------------------------------------------------
@@ -442,6 +571,8 @@ define_agg_accumulator_enum!(
442571
MaxU64(MaxAccumU64),
443572
MaxF32(MaxAccumF32),
444573
MaxF64(MaxAccumF64),
574+
BoolAnd(BoolAndAccum),
575+
BoolOr(BoolOrAccum),
445576
);
446577

447578
impl AggAccumulator {
@@ -604,6 +735,34 @@ fn try_create_accumulator(
604735
Float64 => Float64Array => MaxF64(MaxAccumF64),
605736
)
606737
}
738+
AggExpr::BoolAnd(expr) => {
739+
let evaluated = source.eval_agg_child(expr)?;
740+
let name = evaluated.name().to_string();
741+
match evaluated.data_type() {
742+
DataType::Boolean => {
743+
let arr = evaluated.downcast::<BooleanArray>()?;
744+
Ok(Some((
745+
AggAccumulator::BoolAnd(BoolAndAccum::new(arr.clone())),
746+
name,
747+
)))
748+
}
749+
_ => Ok(None),
750+
}
751+
}
752+
AggExpr::BoolOr(expr) => {
753+
let evaluated = source.eval_agg_child(expr)?;
754+
let name = evaluated.name().to_string();
755+
match evaluated.data_type() {
756+
DataType::Boolean => {
757+
let arr = evaluated.downcast::<BooleanArray>()?;
758+
Ok(Some((
759+
AggAccumulator::BoolOr(BoolOrAccum::new(arr.clone())),
760+
name,
761+
)))
762+
}
763+
_ => Ok(None),
764+
}
765+
}
607766
_ => Ok(None),
608767
}
609768
}
@@ -615,8 +774,9 @@ fn try_create_accumulator(
615774
/// Returns true if all agg expressions can be handled by the inline path.
616775
///
617776
/// Requirements:
618-
/// 1. All agg expressions are Count, Sum, Product, Min, or Max.
777+
/// 1. All agg expressions are Count, Sum, Product, Min, Max, BoolAnd, or BoolOr.
619778
/// 2. For Sum/Product/Min/Max, the value column dtype must be a supported numeric type.
779+
/// 3. For BoolAnd/BoolOr, the value column dtype must be Boolean.
620780
///
621781
/// Uses schema-level type inference (`to_field`) instead of expression evaluation
622782
/// to avoid materializing computed columns just for a dtype check.
@@ -630,6 +790,8 @@ pub(super) fn can_inline_agg(to_agg: &[BoundAggExpr], source: &RecordBatch) -> b
630790
| AggExpr::Product(..)
631791
| AggExpr::Min(..)
632792
| AggExpr::Max(..)
793+
| AggExpr::BoolAnd(..)
794+
| AggExpr::BoolOr(..)
633795
)
634796
}) {
635797
return false;
@@ -656,6 +818,13 @@ pub(super) fn can_inline_agg(to_agg: &[BoundAggExpr], source: &RecordBatch) -> b
656818
false
657819
}
658820
}
821+
AggExpr::BoolAnd(expr) | AggExpr::BoolOr(expr) => {
822+
if let Ok(field) = expr.to_field(&source.schema) {
823+
matches!(field.dtype, DataType::Boolean)
824+
} else {
825+
false
826+
}
827+
}
659828
_ => unreachable!("pre-check ensures only supported types reach here"),
660829
})
661830
}
@@ -1894,6 +2063,163 @@ mod tests {
18942063
assert_batches_equal(&inline_result, &fallback_result);
18952064
}
18962065

2066+
// --- BoolAnd / BoolOr tests ---
2067+
2068+
/// Helper for groupby tests with Boolean values (Utf8 keys).
2069+
fn make_bool_val_test_batch() -> (RecordBatch, Vec<BoundExpr>, Schema) {
2070+
let keys = Series::from_arrow(
2071+
Arc::new(Field::new("key", DataType::Utf8)),
2072+
Arc::new(arrow::array::LargeStringArray::from(vec![
2073+
Some("a"),
2074+
Some("b"),
2075+
Some("a"),
2076+
Some("b"),
2077+
Some("a"),
2078+
Some("c"),
2079+
])),
2080+
)
2081+
.unwrap();
2082+
let vals = BooleanArray::from_iter(
2083+
"val",
2084+
vec![
2085+
Some(true),
2086+
Some(true),
2087+
Some(true),
2088+
Some(false),
2089+
Some(false),
2090+
Some(true),
2091+
]
2092+
.into_iter(),
2093+
)
2094+
.into_series();
2095+
let schema = Schema::new(vec![
2096+
Field::new("key", DataType::Utf8),
2097+
Field::new("val", DataType::Boolean),
2098+
]);
2099+
let rb = RecordBatch::from_nonempty_columns(vec![keys, vals]).unwrap();
2100+
let group_by = vec![BoundExpr::try_new(resolved_col("key"), &schema).unwrap()];
2101+
(rb, group_by, schema)
2102+
}
2103+
2104+
/// Helper for groupby tests with Boolean values and Int64 keys (FNV fast path).
2105+
fn make_int_key_bool_val_test_batch() -> (RecordBatch, Vec<BoundExpr>, Schema) {
2106+
let keys = Int64Array::from_iter(
2107+
Field::new("key", DataType::Int64),
2108+
vec![Some(1), Some(2), Some(1), Some(2), Some(1)],
2109+
)
2110+
.into_series();
2111+
let vals = BooleanArray::from_iter(
2112+
"val",
2113+
vec![Some(true), Some(false), Some(true), Some(true), Some(false)].into_iter(),
2114+
)
2115+
.into_series();
2116+
let schema = Schema::new(vec![
2117+
Field::new("key", DataType::Int64),
2118+
Field::new("val", DataType::Boolean),
2119+
]);
2120+
let rb = RecordBatch::from_nonempty_columns(vec![keys, vals]).unwrap();
2121+
let group_by = vec![BoundExpr::try_new(resolved_col("key"), &schema).unwrap()];
2122+
(rb, group_by, schema)
2123+
}
2124+
2125+
/// Helper for groupby tests with nullable Boolean values.
2126+
fn make_bool_val_with_nulls_test_batch() -> (RecordBatch, Vec<BoundExpr>, Schema) {
2127+
let keys = Int64Array::from_iter(
2128+
Field::new("key", DataType::Int64),
2129+
vec![Some(1), Some(2), Some(1), Some(2), Some(1), Some(3)],
2130+
)
2131+
.into_series();
2132+
let vals = BooleanArray::from_iter(
2133+
"val",
2134+
vec![Some(true), None, Some(false), Some(true), None, None].into_iter(),
2135+
)
2136+
.into_series();
2137+
let schema = Schema::new(vec![
2138+
Field::new("key", DataType::Int64),
2139+
Field::new("val", DataType::Boolean),
2140+
]);
2141+
let rb = RecordBatch::from_nonempty_columns(vec![keys, vals]).unwrap();
2142+
let group_by = vec![BoundExpr::try_new(resolved_col("key"), &schema).unwrap()];
2143+
(rb, group_by, schema)
2144+
}
2145+
2146+
/// Helper for groupby where the Boolean value column is entirely null.
2147+
fn make_all_null_bool_val_test_batch() -> (RecordBatch, Vec<BoundExpr>, Schema) {
2148+
let keys = Int64Array::from_iter(
2149+
Field::new("key", DataType::Int64),
2150+
vec![Some(1), Some(2), Some(1), Some(2), Some(1)],
2151+
)
2152+
.into_series();
2153+
let vals = BooleanArray::from_iter("val", vec![None, None, None, None, None].into_iter())
2154+
.into_series();
2155+
let schema = Schema::new(vec![
2156+
Field::new("key", DataType::Int64),
2157+
Field::new("val", DataType::Boolean),
2158+
]);
2159+
let rb = RecordBatch::from_nonempty_columns(vec![keys, vals]).unwrap();
2160+
let group_by = vec![BoundExpr::try_new(resolved_col("key"), &schema).unwrap()];
2161+
(rb, group_by, schema)
2162+
}
2163+
2164+
/// BoolAnd over Utf8 keys and null-free Boolean values: the inline and
/// fallback group-by paths must produce equal batches.
#[test]
fn test_inline_bool_and_matches_fallback() {
    let (batch, keys, schema) = make_bool_val_test_batch();
    let bool_and = BoundAggExpr::try_new(AggExpr::BoolAnd(resolved_col("val")), &schema).unwrap();
    let aggs = vec![bool_and];

    let via_inline = batch.agg_groupby_inline(&aggs, &keys).unwrap();
    let via_fallback = batch.agg_groupby_fallback(&aggs, &keys).unwrap();
    assert_batches_equal(&via_inline, &via_fallback);
}
2173+
2174+
/// BoolOr over Utf8 keys and null-free Boolean values: the inline and
/// fallback group-by paths must produce equal batches.
#[test]
fn test_inline_bool_or_matches_fallback() {
    let (batch, keys, schema) = make_bool_val_test_batch();
    let bool_or = BoundAggExpr::try_new(AggExpr::BoolOr(resolved_col("val")), &schema).unwrap();
    let aggs = vec![bool_or];

    let via_inline = batch.agg_groupby_inline(&aggs, &keys).unwrap();
    let via_fallback = batch.agg_groupby_fallback(&aggs, &keys).unwrap();
    assert_batches_equal(&via_inline, &via_fallback);
}
2183+
2184+
/// BoolAnd over Int64 keys and null-free Boolean values: the inline and
/// fallback group-by paths must produce equal batches.
#[test]
fn test_inline_int_key_bool_and_matches_fallback() {
    let (batch, keys, schema) = make_int_key_bool_val_test_batch();
    let bool_and = BoundAggExpr::try_new(AggExpr::BoolAnd(resolved_col("val")), &schema).unwrap();
    let aggs = vec![bool_and];

    let via_inline = batch.agg_groupby_inline(&aggs, &keys).unwrap();
    let via_fallback = batch.agg_groupby_fallback(&aggs, &keys).unwrap();
    assert_batches_equal(&via_inline, &via_fallback);
}
2193+
2194+
/// BoolAnd over a Boolean column containing nulls: the inline and fallback
/// group-by paths must produce equal batches.
#[test]
fn test_inline_bool_and_with_nulls_matches_fallback() {
    let (batch, keys, schema) = make_bool_val_with_nulls_test_batch();
    let bool_and = BoundAggExpr::try_new(AggExpr::BoolAnd(resolved_col("val")), &schema).unwrap();
    let aggs = vec![bool_and];

    let via_inline = batch.agg_groupby_inline(&aggs, &keys).unwrap();
    let via_fallback = batch.agg_groupby_fallback(&aggs, &keys).unwrap();
    assert_batches_equal(&via_inline, &via_fallback);
}
2203+
2204+
/// BoolOr over a Boolean column containing nulls: the inline and fallback
/// group-by paths must produce equal batches.
#[test]
fn test_inline_bool_or_with_nulls_matches_fallback() {
    let (batch, keys, schema) = make_bool_val_with_nulls_test_batch();
    let bool_or = BoundAggExpr::try_new(AggExpr::BoolOr(resolved_col("val")), &schema).unwrap();
    let aggs = vec![bool_or];

    let via_inline = batch.agg_groupby_inline(&aggs, &keys).unwrap();
    let via_fallback = batch.agg_groupby_fallback(&aggs, &keys).unwrap();
    assert_batches_equal(&via_inline, &via_fallback);
}
2213+
2214+
/// BoolOr over an all-null Boolean column: the inline and fallback group-by
/// paths must produce equal batches.
#[test]
fn test_inline_all_null_bool_or_matches_fallback() {
    let (batch, keys, schema) = make_all_null_bool_val_test_batch();
    let bool_or = BoundAggExpr::try_new(AggExpr::BoolOr(resolved_col("val")), &schema).unwrap();
    let aggs = vec![bool_or];

    let via_inline = batch.agg_groupby_inline(&aggs, &keys).unwrap();
    let via_fallback = batch.agg_groupby_fallback(&aggs, &keys).unwrap();
    assert_batches_equal(&via_inline, &via_fallback);
}
18972223
// --- Product tests ---
18982224

18992225
#[test]

0 commit comments

Comments
 (0)