Skip to content

Commit d001d54

Browse files
authored
Merge 'Window functions plumbing 1: add stubs for builtins, refactor row_number() to use VDBE aggregate machinery' from Jussi Saurio
1. Introduce stubs for all builtin window functions 2. rename `WindowFunctionKind` to `AccumulatorFunc`, because both window functions and aggregate functions will use the same `op_agg_step` mechanism 3. remove specialcased `row_number()` handling and route it through `op_agg_step`, which is what other window functions will use as well Reviewed-by: Mikaël Francoeur (@LeMikaelF) Closes #7358
2 parents 466a5f5 + e6cdadd commit d001d54

16 files changed

Lines changed: 614 additions & 417 deletions

core/function.rs

Lines changed: 125 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -460,29 +460,146 @@ pub enum AggFunc {
460460
External(Arc<ExtFunc>),
461461
}
462462

463-
#[derive(Debug, Clone, Copy, PartialEq, Eq, strum::EnumIter)]
463+
#[derive(Debug, Clone, strum::EnumIter)]
464464
pub enum WindowFunc {
465465
RowNumber,
466+
Rank,
467+
DenseRank,
468+
PercentRank,
469+
CumeDist,
470+
Ntile,
471+
Lag,
472+
Lead,
473+
FirstValue,
474+
LastValue,
475+
NthValue,
476+
#[strum(disabled)]
477+
External(Arc<ExtFunc>),
466478
}
467479

468480
impl WindowFunc {
481+
/// SQL name of this window function. Matches the strings used by
482+
/// `Display` so EXPLAIN output and error messages agree.
483+
pub fn as_str(&self) -> &'static str {
484+
match self {
485+
Self::RowNumber => "row_number",
486+
Self::Rank => "rank",
487+
Self::DenseRank => "dense_rank",
488+
Self::PercentRank => "percent_rank",
489+
Self::CumeDist => "cume_dist",
490+
Self::Ntile => "ntile",
491+
Self::Lag => "lag",
492+
Self::Lead => "lead",
493+
Self::FirstValue => "first_value",
494+
Self::LastValue => "last_value",
495+
Self::NthValue => "nth_value",
496+
Self::External(_) => unreachable!(
497+
"WindowFunc::External is not constructible: ExtFunc has no Window variant"
498+
),
499+
}
500+
}
501+
469502
pub fn arities(&self) -> &'static [i32] {
470503
match self {
471-
Self::RowNumber => &[0],
504+
Self::RowNumber | Self::Rank | Self::DenseRank | Self::PercentRank | Self::CumeDist => {
505+
&[0]
506+
}
507+
Self::Ntile | Self::FirstValue | Self::LastValue => &[1],
508+
Self::NthValue => &[2],
509+
Self::Lag | Self::Lead => &[1, 2, 3],
510+
Self::External(_) => unreachable!(
511+
"WindowFunc::External is not constructible: ExtFunc has no Window variant"
512+
),
513+
}
514+
}
515+
516+
/// Whether name resolution + runtime dispatch are wired up. Stub variants
517+
/// must not be advertised via `pragma_function_list`, or introspection
518+
/// drifts ahead of the resolver and users get "no such function" when
519+
/// they try to call them.
520+
pub fn is_implemented(&self) -> bool {
521+
matches!(self, Self::RowNumber)
522+
}
523+
}
524+
525+
impl PartialEq for WindowFunc {
526+
fn eq(&self, other: &Self) -> bool {
527+
match (self, other) {
528+
(Self::RowNumber, Self::RowNumber)
529+
| (Self::Rank, Self::Rank)
530+
| (Self::DenseRank, Self::DenseRank)
531+
| (Self::PercentRank, Self::PercentRank)
532+
| (Self::CumeDist, Self::CumeDist)
533+
| (Self::Ntile, Self::Ntile)
534+
| (Self::Lag, Self::Lag)
535+
| (Self::Lead, Self::Lead)
536+
| (Self::FirstValue, Self::FirstValue)
537+
| (Self::LastValue, Self::LastValue)
538+
| (Self::NthValue, Self::NthValue) => true,
539+
(Self::External(a), Self::External(b)) => Arc::ptr_eq(a, b),
540+
_ => false,
472541
}
473542
}
474543
}
475544

545+
impl Eq for WindowFunc {}
546+
476547
impl Deterministic for WindowFunc {
477548
fn is_deterministic(&self) -> bool {
478-
true
549+
match self {
550+
Self::RowNumber
551+
| Self::Rank
552+
| Self::DenseRank
553+
| Self::PercentRank
554+
| Self::CumeDist
555+
| Self::Ntile
556+
| Self::Lag
557+
| Self::Lead
558+
| Self::FirstValue
559+
| Self::LastValue
560+
| Self::NthValue => true,
561+
Self::External(_) => unreachable!(
562+
"WindowFunc::External is not constructible: ExtFunc has no Window variant"
563+
),
564+
}
479565
}
480566
}
481567

482568
impl std::fmt::Display for WindowFunc {
483569
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
570+
f.write_str(self.as_str())
571+
}
572+
}
573+
574+
/// Function reference used by AggStep / AggValue / AggFinal opcodes.
575+
/// Aggregates used in window context and pure window functions share the same
576+
/// step/value dispatch path; this enum carries which side of that split a
577+
/// particular call belongs to.
578+
#[derive(Debug, Clone)]
579+
pub enum AccumulatorFunc {
580+
Agg(AggFunc),
581+
Window(WindowFunc),
582+
}
583+
584+
impl AccumulatorFunc {
585+
/// Extract the inner `AggFunc` when this kind is known to be an
586+
/// aggregate. `unreachable!`s on `Window(...)` — the only opcodes
587+
/// that carry an `AccumulatorFunc` are the AggStep / AggValue /
588+
/// AggFinal trio, and the call sites that emit those wrap aggregates
589+
/// only. A `Window` value reaching here is a planner bug.
590+
pub fn expect_agg(&self) -> &AggFunc {
484591
match self {
485-
Self::RowNumber => write!(f, "row_number"),
592+
Self::Agg(f) => f,
593+
Self::Window(f) => {
594+
unreachable!("window function {f} reached an aggregate-only dispatch path")
595+
}
596+
}
597+
}
598+
599+
pub fn as_str(&self) -> &'static str {
600+
match self {
601+
Self::Agg(f) => f.as_str(),
602+
Self::Window(f) => f.as_str(),
486603
}
487604
}
488605
}
@@ -1721,8 +1838,11 @@ impl Func {
17211838
push(f.to_string(), "w", f.arities(), f.is_deterministic());
17221839
}
17231840

1724-
// Window functions.
1841+
// Window functions (skip stub variants until they're wired up).
17251842
for f in WindowFunc::iter() {
1843+
if !f.is_implemented() {
1844+
continue;
1845+
}
17261846
push(f.to_string(), "w", f.arities(), f.is_deterministic());
17271847
}
17281848

core/translate/aggregation.rs

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use turso_parser::ast;
22

33
use crate::{
4-
function::AggFunc,
4+
function::{AccumulatorFunc, AggFunc},
55
schema::Table,
66
sync::Arc,
77
translate::collate::CollationSeq,
@@ -36,7 +36,7 @@ pub fn emit_ungrouped_aggregation<'a>(
3636
let agg_result_reg = agg_start_reg + i;
3737
program.emit_insn(Insn::AggFinal {
3838
register: agg_result_reg,
39-
func: agg.func.clone(),
39+
func: AccumulatorFunc::Agg(agg.func.clone()),
4040
});
4141
}
4242
// we now have the agg results in (agg_start_reg..agg_start_reg + aggregates.len() - 1)
@@ -364,7 +364,7 @@ pub fn translate_aggregation_step(
364364
acc_reg: target_register,
365365
col: expr_reg,
366366
delimiter: 0,
367-
func: AggFunc::Avg,
367+
func: AccumulatorFunc::Agg(AggFunc::Avg),
368368
comparator: None,
369369
});
370370
target_register
@@ -377,7 +377,7 @@ pub fn translate_aggregation_step(
377377
acc_reg: target_register,
378378
col: expr_reg,
379379
delimiter: 0,
380-
func: AggFunc::Count0,
380+
func: AccumulatorFunc::Agg(AggFunc::Count0),
381381
comparator: None,
382382
});
383383
target_register
@@ -392,7 +392,7 @@ pub fn translate_aggregation_step(
392392
acc_reg: target_register,
393393
col: expr_reg,
394394
delimiter: 0,
395-
func: AggFunc::Count,
395+
func: AccumulatorFunc::Agg(AggFunc::Count),
396396
comparator: None,
397397
});
398398
target_register
@@ -417,7 +417,7 @@ pub fn translate_aggregation_step(
417417
acc_reg: target_register,
418418
col: expr_reg,
419419
delimiter: delimiter_reg,
420-
func: AggFunc::GroupConcat,
420+
func: AccumulatorFunc::Agg(AggFunc::GroupConcat),
421421
comparator: None,
422422
});
423423

@@ -437,7 +437,7 @@ pub fn translate_aggregation_step(
437437
acc_reg: target_register,
438438
col: expr_reg,
439439
delimiter: 0,
440-
func: AggFunc::Max,
440+
func: AccumulatorFunc::Agg(AggFunc::Max),
441441
comparator,
442442
});
443443
target_register
@@ -456,7 +456,7 @@ pub fn translate_aggregation_step(
456456
acc_reg: target_register,
457457
col: expr_reg,
458458
delimiter: 0,
459-
func: AggFunc::Min,
459+
func: AccumulatorFunc::Agg(AggFunc::Min),
460460
comparator,
461461
});
462462
target_register
@@ -474,7 +474,7 @@ pub fn translate_aggregation_step(
474474
acc_reg: target_register,
475475
col: expr_reg,
476476
delimiter: value_reg,
477-
func: AggFunc::JsonGroupObject,
477+
func: AccumulatorFunc::Agg(AggFunc::JsonGroupObject),
478478
comparator: None,
479479
});
480480
target_register
@@ -490,7 +490,7 @@ pub fn translate_aggregation_step(
490490
acc_reg: target_register,
491491
col: expr_reg,
492492
delimiter: 0,
493-
func: AggFunc::JsonGroupArray,
493+
func: AccumulatorFunc::Agg(AggFunc::JsonGroupArray),
494494
comparator: None,
495495
});
496496
target_register
@@ -508,7 +508,7 @@ pub fn translate_aggregation_step(
508508
acc_reg: target_register,
509509
col: expr_reg,
510510
delimiter: delimiter_reg,
511-
func: AggFunc::StringAgg,
511+
func: AccumulatorFunc::Agg(AggFunc::StringAgg),
512512
comparator: None,
513513
});
514514

@@ -524,7 +524,7 @@ pub fn translate_aggregation_step(
524524
acc_reg: target_register,
525525
col: expr_reg,
526526
delimiter: 0,
527-
func: AggFunc::Sum,
527+
func: AccumulatorFunc::Agg(AggFunc::Sum),
528528
comparator: None,
529529
});
530530
target_register
@@ -539,7 +539,7 @@ pub fn translate_aggregation_step(
539539
acc_reg: target_register,
540540
col: expr_reg,
541541
delimiter: 0,
542-
func: AggFunc::Total,
542+
func: AccumulatorFunc::Agg(AggFunc::Total),
543543
comparator: None,
544544
});
545545
target_register
@@ -555,7 +555,7 @@ pub fn translate_aggregation_step(
555555
acc_reg: target_register,
556556
col: expr_reg,
557557
delimiter: 0,
558-
func: AggFunc::ArrayAgg,
558+
func: AccumulatorFunc::Agg(AggFunc::ArrayAgg),
559559
comparator: None,
560560
});
561561
target_register
@@ -573,7 +573,7 @@ pub fn translate_aggregation_step(
573573
acc_reg: target_register,
574574
col: value_reg,
575575
delimiter: 0,
576-
func: AggFunc::Mode,
576+
func: AccumulatorFunc::Agg(AggFunc::Mode),
577577
comparator: None,
578578
});
579579
target_register
@@ -595,7 +595,7 @@ pub fn translate_aggregation_step(
595595
acc_reg: target_register,
596596
col: value_reg,
597597
delimiter: fraction_reg,
598-
func: func.clone(),
598+
func: AccumulatorFunc::Agg(func.clone()),
599599
comparator: None,
600600
});
601601
target_register
@@ -630,11 +630,11 @@ pub fn translate_aggregation_step(
630630
acc_reg: target_register,
631631
col: expr_reg,
632632
delimiter: 0,
633-
func: AggFunc::External(if registered_argc < 0 {
633+
func: AccumulatorFunc::Agg(AggFunc::External(if registered_argc < 0 {
634634
Arc::new(func.with_aggregate_arg_count(num_args))
635635
} else {
636636
func.clone()
637-
}),
637+
})),
638638
comparator: None,
639639
});
640640
target_register

core/translate/group_by.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use super::{
66
plan::{Distinctness, GroupBy, SelectPlan, SubqueryEvalPhase, SubqueryOrigin},
77
result_row::emit_select_result,
88
};
9+
use crate::function::AccumulatorFunc;
910
use crate::translate::{
1011
aggregation::{translate_aggregation_step, AggArgumentSource},
1112
order_by::{custom_type_comparator, EmitOrderBy},
@@ -1014,7 +1015,7 @@ pub fn group_by_emit_row_phase<'a>(
10141015
let agg_result_reg = agg_start_reg + i;
10151016
program.emit_insn(Insn::AggFinal {
10161017
register: agg_result_reg,
1017-
func: agg.func.clone(),
1018+
func: AccumulatorFunc::Agg(agg.func.clone()),
10181019
});
10191020
t_ctx.resolver.cache_expr_reg(
10201021
std::borrow::Cow::Owned(agg.original_expr.clone()),

core/translate/main_loop/body.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ fn emit_loop_source<'a>(
266266
acc_reg: start_reg,
267267
col: expr_reg,
268268
delimiter: 0,
269-
func: min_max.func.clone(),
269+
func: crate::function::AccumulatorFunc::Agg(min_max.func.clone()),
270270
comparator,
271271
});
272272
program.emit_insn(Insn::Goto {

core/translate/plan.rs

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::{
22
alloc::{self, TursoIteratorExt},
3-
function::{AggFunc, WindowFunc},
3+
function::{AccumulatorFunc, AggFunc},
44
schema::{
55
BTreeTable, ColDef, Column, FromClauseSubquery, Index, Schema, Table, Type, ROWID_SENTINEL,
66
},
@@ -3090,12 +3090,6 @@ impl Window {
30903090
}
30913091
}
30923092

3093-
#[derive(Debug, Clone)]
3094-
pub enum WindowFunctionKind {
3095-
Agg(AggFunc),
3096-
Window(WindowFunc),
3097-
}
3098-
30993093
/// One window function call belonging to a `Window`.
31003094
///
31013095
/// Window queries are planned by wrapping the original FROM/WHERE in a
@@ -3108,7 +3102,7 @@ pub enum WindowFunctionKind {
31083102
pub struct WindowFunction {
31093103
/// The resolved function. Aggregate window functions and specialized window
31103104
/// functions such as ROW_NUMBER() are supported.
3111-
pub func: WindowFunctionKind,
3105+
pub func: AccumulatorFunc,
31123106
/// The expression from which the function was resolved. Used as the lookup
31133107
/// key when matching SQL occurrences back to this entry during rewriting.
31143108
pub original_expr: Expr,

0 commit comments

Comments
 (0)