Skip to content

Commit 40cccfc

Browse files
sdk: clean up ABI aggregates funcs
1 parent 220a7a8 commit 40cccfc

2 files changed

Lines changed: 44 additions & 186 deletions

File tree

mysql-test/suite/villagesql/std_data/aggregate_vdf.cc

Lines changed: 39 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,15 @@
1616

1717
// Test aggregate VDFs for VillageSQL.
1818
//
19-
// This extension demonstrates different ways to implement aggregate VDFs,
20-
// ordered from most idiomatic to most manual:
19+
// This extension demonstrates different aggregate VDF implementations using
20+
// make_aggregate_func. The three callbacks of each aggregate (clear,
// accumulate, and result) all use typed C++ signatures:
2121
//
22-
// 1. vdf_sum(INT) -> INT: Typed API with std::optional<long long> as state.
23-
// Uses .state<T>() for automatic allocation, typed clear/accumulate
24-
// callbacks taking State&, and a result function returning
25-
// std::optional<long long> (nullopt -> SQL NULL).
22+
// 1. vdf_sum(INT) -> INT: std::optional<long long> state (nullable result).
23+
// 2. vdf_count(INT) -> INT: plain long long state (non-nullable result).
24+
// 3. vdf_concat(STRING) -> STRING: std::optional<std::string> state.
25+
// 4. vdf_max(INT) -> INT: struct state with typed clear/accumulate/result.
2626
//
27-
// 2. vdf_count(INT) -> INT: Typed API with plain long long as state. Same
28-
// pattern as vdf_sum but the result is never NULL, so the result function
29-
// returns long long directly.
30-
//
31-
// 3. vdf_concat(STRING) -> STRING: Typed API with std::optional<std::string>
32-
// as state. Shows that the typed aggregate API works with string types.
33-
//
34-
// 4. vdf_max(INT) -> INT: Raw ABI with manual prerun/postrun and explicit
35-
// user_data casts. For authors who prefer direct control over state
36-
// management or need custom prerun/postrun logic beyond simple allocation.
37-
//
38-
// Also includes simple_double(INT) -> INT as a scalar VDF for testing mixed
39-
// scalar/aggregate queries.
27+
// Also includes simple_double(INT) -> INT as a scalar VDF for testing mixed
// scalar/aggregate queries.
4028

4129
#include <villagesql/vsql.h>
4230

@@ -46,7 +34,6 @@
4634
using namespace vsql;
4735

4836
// vdf_sum: aggregate that sums INT values, returns NULL for empty groups.
49-
// State is optional<long long> — nullopt means no non-NULL values seen.
5037

5138
using SumState = std::optional<long long>;
5239

@@ -58,7 +45,13 @@ void vdf_sum_accumulate(SumState &state, IntArg val) {
5845
}
5946
}
6047

61-
std::optional<long long> vdf_sum_result(const SumState &state) { return state; }
48+
void vdf_sum_result(const SumState &state, IntResult out) {
49+
if (!state.has_value()) {
50+
out.set_null();
51+
return;
52+
}
53+
out.set(state.value());
54+
}
6255

6356
// vdf_count: aggregate that counts non-NULL INT values (always returns a value)
6457

@@ -72,10 +65,11 @@ void vdf_count_accumulate(CountState &state, IntArg val) {
7265
}
7366
}
7467

75-
long long vdf_count_result(const CountState &state) { return state; }
68+
void vdf_count_result(const CountState &state, IntResult out) {
69+
out.set(state);
70+
}
7671

7772
// vdf_concat: aggregate that concatenates STRING values with commas.
78-
// Returns NULL for empty groups.
7973

8074
using ConcatState = std::optional<std::string>;
8175

@@ -92,57 +86,41 @@ void vdf_concat_accumulate(ConcatState &state, StringArg val) {
9286
}
9387
}
9488

95-
std::optional<std::string> vdf_concat_result(const ConcatState &state) {
96-
return state;
89+
void vdf_concat_result(const ConcatState &state, StringResult out) {
90+
if (!state.has_value()) {
91+
out.set_null();
92+
return;
93+
}
94+
out.set(state.value());
9795
}
9896

9997
// vdf_max: aggregate that returns the maximum INT value.
100-
// Uses raw ABI callbacks with manual prerun/postrun instead of the typed API,
101-
// for authors who prefer explicit control over state management.
10298

10399
struct MaxState {
104100
long long max_val = 0;
105101
bool has_value = false;
106102
};
107103

108-
void vdf_max_prerun(vef_context_t *, vef_prerun_args_t *,
109-
vef_prerun_result_t *result) {
110-
result->user_data = new MaxState{};
111-
result->type = VEF_RESULT_VALUE;
112-
}
113-
114-
void vdf_max_postrun(vef_context_t *, vef_postrun_args_t *args,
115-
vef_postrun_result_t *) {
116-
delete static_cast<MaxState *>(args->user_data);
104+
void vdf_max_clear(MaxState &state) {
105+
state.max_val = 0;
106+
state.has_value = false;
117107
}
118108

119-
void vdf_max_clear(vef_context_t *, vef_vdf_args_t *args) {
120-
auto *state = static_cast<MaxState *>(args->user_data);
121-
state->max_val = 0;
122-
state->has_value = false;
123-
}
124-
125-
void vdf_max_accumulate(vef_context_t *ctx, vef_vdf_args_t *args,
126-
vef_vdf_result_t *) {
127-
auto *state = static_cast<MaxState *>(args->user_data);
128-
vef_invalue_t val = vsql::func_builder::get_invalue(ctx, args, 0);
129-
if (!val.is_null) {
130-
if (!state->has_value || val.int_value > state->max_val) {
131-
state->max_val = val.int_value;
132-
state->has_value = true;
109+
void vdf_max_accumulate(MaxState &state, IntArg val) {
110+
if (!val.is_null()) {
111+
if (!state.has_value || val.value() > state.max_val) {
112+
state.max_val = val.value();
113+
state.has_value = true;
133114
}
134115
}
135116
}
136117

137-
void vdf_max_result(vef_context_t *, vef_vdf_args_t *args,
138-
vef_vdf_result_t *out) {
139-
auto *state = static_cast<MaxState *>(args->user_data);
140-
if (!state->has_value) {
141-
out->type = VEF_RESULT_NULL;
118+
void vdf_max_result(const MaxState &state, IntResult out) {
119+
if (!state.has_value) {
120+
out.set_null();
142121
return;
143122
}
144-
out->int_value = state->max_val;
145-
out->type = VEF_RESULT_VALUE;
123+
out.set(state.max_val);
146124
}
147125

148126
// Scalar function that doubles an INT value (for mixed testing)
@@ -156,32 +134,27 @@ void simple_double_impl(IntArg input, IntResult out) {
156134

157135
VEF_GENERATE_ENTRY_POINTS(
158136
make_extension()
159-
.func(make_func<&vdf_sum_result>("vdf_sum")
137+
.func(make_aggregate_func<SumState, &vdf_sum_result>("vdf_sum")
160138
.returns(INT)
161139
.param(INT)
162-
.state<SumState>()
163140
.clear<&vdf_sum_clear>()
164141
.accumulate<&vdf_sum_accumulate>()
165142
.build())
166-
.func(make_func<&vdf_count_result>("vdf_count")
143+
.func(make_aggregate_func<CountState, &vdf_count_result>("vdf_count")
167144
.returns(INT)
168145
.param(INT)
169-
.state<CountState>()
170146
.clear<&vdf_count_clear>()
171147
.accumulate<&vdf_count_accumulate>()
172148
.build())
173-
.func(make_func<&vdf_concat_result>("vdf_concat")
149+
.func(make_aggregate_func<ConcatState, &vdf_concat_result>("vdf_concat")
174150
.returns(STRING)
175151
.param(STRING)
176-
.state<ConcatState>()
177152
.clear<&vdf_concat_clear>()
178153
.accumulate<&vdf_concat_accumulate>()
179154
.build())
180-
.func(make_func<&vdf_max_result>("vdf_max")
155+
.func(make_aggregate_func<MaxState, &vdf_max_result>("vdf_max")
181156
.returns(INT)
182157
.param(INT)
183-
.prerun<&vdf_max_prerun>()
184-
.postrun<&vdf_max_postrun>()
185158
.clear<&vdf_max_clear>()
186159
.accumulate<&vdf_max_accumulate>()
187160
.build())

villagesql/sdk/include/villagesql/vsql/func_builder.h

Lines changed: 5 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -289,55 +289,6 @@ struct AggResultWithOutputWrapper {
289289
}
290290
};
291291

292-
// Wraps T(const State&) or optional<T>(const State&) -> vef_vdf_func_t
293-
template <typename State, auto Func>
294-
struct AggResultWrapper {
295-
static void invoke(vef_context_t *, vef_vdf_args_t *args,
296-
vef_vdf_result_t *result) {
297-
const auto &state = *static_cast<State *>(args->user_data);
298-
write_result(Func(state), result);
299-
}
300-
301-
private:
302-
template <typename T>
303-
static void write_result(const std::optional<T> &val,
304-
vef_vdf_result_t *result) {
305-
if (!val.has_value()) {
306-
result->type = VEF_RESULT_NULL;
307-
} else {
308-
write_scalar(*val, result);
309-
}
310-
}
311-
312-
template <typename T>
313-
static void write_result(const T &val, vef_vdf_result_t *result) {
314-
write_scalar(val, result);
315-
}
316-
317-
static void write_scalar(long long v, vef_vdf_result_t *r) {
318-
r->int_value = v;
319-
r->type = VEF_RESULT_VALUE;
320-
}
321-
322-
static void write_scalar(double v, vef_vdf_result_t *r) {
323-
r->real_value = v;
324-
r->type = VEF_RESULT_VALUE;
325-
}
326-
327-
static void write_scalar(const std::string &v, vef_vdf_result_t *r) {
328-
if (v.size() > r->max_str_len) {
329-
r->type = VEF_RESULT_ERROR;
330-
snprintf(r->error_msg, VEF_MAX_ERROR_LEN,
331-
"aggregate result (%zu bytes) exceeds buffer (%zu bytes)",
332-
v.size(), r->max_str_len);
333-
return;
334-
}
335-
memcpy(r->str_buf, v.data(), v.size());
336-
r->actual_len = v.size();
337-
r->type = VEF_RESULT_VALUE;
338-
}
339-
};
340-
341292
// =============================================================================
342293
// Wrapper Template
343294
// =============================================================================
@@ -861,8 +812,6 @@ struct FuncBuilder {
861812
buffer_size_(0),
862813
prerun_(nullptr),
863814
postrun_(nullptr),
864-
clear_(nullptr),
865-
accumulate_(nullptr),
866815
deterministic_(false) {}
867816

868817
const char *name_;
@@ -871,8 +820,6 @@ struct FuncBuilder {
871820
size_t buffer_size_;
872821
vef_prerun_func_t prerun_;
873822
vef_postrun_func_t postrun_;
874-
vef_vdf_clear_func_t clear_;
875-
vef_vdf_accumulate_func_t accumulate_;
876823
bool deterministic_;
877824

878825
constexpr FuncBuilder<Func, NumParams> &returns(const char *t) {
@@ -887,8 +834,6 @@ struct FuncBuilder {
887834
next.buffer_size_ = buffer_size_;
888835
next.prerun_ = prerun_;
889836
next.postrun_ = postrun_;
890-
next.clear_ = clear_;
891-
next.accumulate_ = accumulate_;
892837
next.deterministic_ = deterministic_;
893838
for (size_t i = 0; i < NumParams; ++i) {
894839
next.param_types_[i] = param_types_[i];
@@ -919,74 +864,17 @@ struct FuncBuilder {
919864
return *this;
920865
}
921866

922-
// TODO(villagesql-beta): remove .clear(), .accumulate(), and .state() from
923-
// FuncBuilder once all callers have migrated to make_aggregate_func, which
924-
// validates all three callback signatures against State at compile time.
925-
// Raw vef_vdf_clear_func_t / vef_vdf_accumulate_func_t are accepted here
926-
// for backward compatibility with existing extensions.
927-
template <auto Fn>
928-
constexpr FuncBuilder<Func, NumParams> &clear() {
929-
if constexpr (std::is_same_v<decltype(Fn), vef_vdf_clear_func_t>) {
930-
clear_ = Fn;
931-
} else {
932-
using Params = typename FuncParamTypes<decltype(Fn)>::type;
933-
using State = std::remove_reference_t<std::tuple_element_t<0, Params>>;
934-
clear_ = &agg_clear_wrapper<State, Fn>;
935-
}
936-
return *this;
937-
}
938-
939-
template <auto Fn>
940-
constexpr FuncBuilder<Func, NumParams> &accumulate() {
941-
if constexpr (std::is_same_v<decltype(Fn), vef_vdf_accumulate_func_t>) {
942-
accumulate_ = Fn;
943-
} else {
944-
using Params = typename FuncParamTypes<decltype(Fn)>::type;
945-
using State = std::remove_reference_t<std::tuple_element_t<0, Params>>;
946-
accumulate_ = &AggAccumulateWrapper<State, Fn, NumParams>::invoke;
947-
}
948-
return *this;
949-
}
950-
951-
template <typename State>
952-
constexpr FuncBuilder<Func, NumParams> &state() {
953-
prerun_ = &auto_prerun<State>;
954-
postrun_ = &auto_postrun<State>;
955-
return *this;
956-
}
957-
958867
constexpr StaticFuncDesc<NumParams> build() const {
959868
static_assert(NumParams <= kMaxParams,
960869
"Too many parameters (max is kMaxParams)");
961-
if ((clear_ == nullptr) != (accumulate_ == nullptr)) {
962-
config_error__aggregate_must_set_both_clear_and_accumulate();
963-
}
964870

965871
using AllParams = typename FuncParamTypes<decltype(Func)>::type;
966872
using UniquePTuple = typename unique_params_types<AllParams>::type;
967873

968874
FuncWithMetadata meta{};
969-
if constexpr (std::is_same_v<decltype(Func), vef_vdf_func_t>) {
970-
// Raw vef_vdf_func_t: passed through for backward compatibility.
971-
// TODO(villagesql-beta): remove once all callers use make_aggregate_func.
972-
meta.f = Func;
973-
} else if constexpr (std::tuple_size_v<AllParams> == 1 &&
974-
std::is_lvalue_reference_v<
975-
std::tuple_element_t<0, AllParams>> &&
976-
std::is_const_v<std::remove_reference_t<
977-
std::tuple_element_t<0, AllParams>>>) {
978-
// Typed aggregate result: T(const State&) or optional<T>(const State&).
979-
using State = std::remove_const_t<
980-
std::remove_reference_t<std::tuple_element_t<0, AllParams>>>;
981-
meta.f = &AggResultWrapper<State, Func>::invoke;
982-
} else {
983-
// Typed scalar VDF.
984-
meta.f = &Wrapper<Func, NumParams>::invoke;
985-
}
875+
meta.f = &Wrapper<Func, NumParams>::invoke;
986876
meta.prerun = prerun_;
987877
meta.postrun = postrun_;
988-
meta.clear = clear_;
989-
meta.accumulate = accumulate_;
990878
meta.return_type = to_vef_type(return_type_);
991879
meta.num_params = NumParams;
992880
meta.buffer_size = buffer_size_;
@@ -1007,9 +895,9 @@ struct FuncBuilder {
1007895
// AggFuncBuilder
1008896
// =============================================================================
1009897
//
1010-
// Builder for aggregate VDFs. Prefer make_aggregate_func<State, &result_fn>()
1011-
// over make_func() + .state<>() for aggregates: the State type is explicit,
1012-
// and clear/accumulate signatures are validated against State at compile time.
898+
// Builder for aggregate VDFs. Use make_aggregate_func<State, &result_fn>().
899+
// The State type is explicit, and clear/accumulate signatures are validated
900+
// against State at compile time.
1013901

1014902
template <typename State, auto Func, size_t NumParams>
1015903
struct AggFuncBuilder {
@@ -1155,14 +1043,11 @@ constexpr AggFuncBuilder<State, Func, 0> make_aggregate_func(const char *name) {
11551043
}
11561044

11571045
// make_func<&impl>("name")
1158-
// TODO(villagesql-beta): remove the vef_vdf_func_t exception below once all
1159-
// callers have migrated to make_aggregate_func.
11601046
template <auto Func>
11611047
constexpr FuncBuilder<Func, 0> make_func(const char *name) {
11621048
using AllParams = typename FuncParamTypes<decltype(Func)>::type;
11631049
static_assert(
1164-
std::is_same_v<decltype(Func), vef_vdf_func_t> ||
1165-
std::tuple_size_v<AllParams> == 0 ||
1050+
std::tuple_size_v<AllParams> == 0 ||
11661051
!is_context_param<std::tuple_element_t<0, AllParams>>::value,
11671052
"vsql make_func: deprecated vef_context_t* first parameter not "
11681053
"supported; use a typed function or make_aggregate_func (see "

0 commit comments

Comments
 (0)