Skip to content

Commit a39dfad

Browse files
authored
feat(func): support median function (#3755)
Signed-off-by: Song Gao <disxiaofei@163.com>
1 parent 232aa83 commit a39dfad

File tree

4 files changed

+166
-0
lines changed

4 files changed

+166
-0
lines changed

docs/en_US/sqls/functions/aggregate_functions.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,14 @@ Examples:
177177
SELECT deduplicate(a, false)->a as r1 FROM demo GROUP BY SlidingWindow(hh, 1)
178178
```
179179

180+
## MEDIAN
181+
182+
```text
183+
median(col)
184+
```
185+
186+
Returns the median value of expression in the group.
187+
180188
## STDDEV
181189

182190
```text

docs/zh_CN/sqls/functions/aggregate_functions.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,14 @@ deduplicate(col, false)
148148
SELECT deduplicate(a, false)->a as r1 FROM demo GROUP BY SlidingWindow(hh, 1)
149149
```
150150

151+
## MEDIAN
152+
153+
```text
154+
median(col)
155+
```
156+
157+
返回组中所有值的中位数。
158+
151159
## STDDEV
152160

153161
```text

internal/binder/function/funcs_agg.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package function
1616

1717
import (
1818
"fmt"
19+
"sort"
1920

2021
"github.com/lf-edge/ekuiper/contract/v2/api"
2122
"github.com/montanaflynn/stats"
@@ -25,6 +26,33 @@ import (
2526
)
2627

2728
func registerAggFunc() {
29+
builtins["median"] = builtinFunc{
30+
fType: ast.FuncTypeAgg,
31+
exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) {
32+
arg0 := args[0].([]interface{})
33+
if len(arg0) < 1 {
34+
return int64(0), true
35+
}
36+
switch arg0[0].(type) {
37+
case float64:
38+
f64s, err := cast.ToFloat64Slice(arg0, cast.CONVERT_SAMEKIND, cast.IGNORE_NIL)
39+
if err != nil {
40+
return err, false
41+
}
42+
return median(f64s), true
43+
case int64, int:
44+
i64s, err := cast.ToInt64Slice(arg0, cast.CONVERT_SAMEKIND)
45+
if err != nil {
46+
return err, false
47+
}
48+
return median(i64s), true
49+
default:
50+
return fmt.Errorf("%v should be number", arg0[0]), false
51+
}
52+
},
53+
val: ValidateOneNumberArg,
54+
check: returnNilIfHasAnyNil,
55+
}
2856
builtins["avg"] = builtinFunc{
2957
fType: ast.FuncTypeAgg,
3058
exec: func(ctx api.FunctionContext, args []interface{}) (interface{}, bool) {
@@ -379,3 +407,22 @@ func registerAggFunc() {
379407
check: returnNilIfHasAnyNil,
380408
}
381409
}
410+
411+
type Number interface {
412+
int64 | float64
413+
}
414+
415+
func median[T Number](nums []T) interface{} {
416+
sort.Slice(nums, func(i, j int) bool {
417+
return nums[i] < nums[j]
418+
})
419+
n := len(nums)
420+
if n == 0 {
421+
return 0
422+
}
423+
if n%2 == 1 {
424+
return nums[n/2]
425+
} else {
426+
return float64((nums[n/2-1])+(nums[n/2])) / 2
427+
}
428+
}

internal/binder/function/funcs_agg_test.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,3 +667,106 @@ func TestLastValueValidation(t *testing.T) {
667667
}
668668
}
669669
}
670+
671+
func TestMedian(t *testing.T) {
672+
tests := []struct {
673+
name string
674+
input interface{}
675+
expected interface{}
676+
}{
677+
{
678+
name: "single int64 element",
679+
input: []int64{5},
680+
expected: int64(5),
681+
},
682+
{
683+
name: "odd length int64 slice",
684+
input: []int64{1, 3, 5, 7, 9},
685+
expected: int64(5),
686+
},
687+
{
688+
name: "even length int64 slice",
689+
input: []int64{1, 3, 5, 7, 9, 11},
690+
expected: float64(6),
691+
},
692+
{
693+
name: "unsorted int64 slice",
694+
input: []int64{9, 1, 5, 3, 7},
695+
expected: int64(5),
696+
},
697+
{
698+
name: "empty float64 slice",
699+
input: []float64{},
700+
expected: 0,
701+
},
702+
{
703+
name: "single float64 element",
704+
input: []float64{5.5},
705+
expected: 5.5,
706+
},
707+
{
708+
name: "odd length float64 slice",
709+
input: []float64{1.1, 3.3, 5.5, 7.7, 9.9},
710+
expected: 5.5,
711+
},
712+
{
713+
name: "even length float64 slice",
714+
input: []float64{1.1, 3.3, 5.5, 7.7, 9.9, 11.1},
715+
expected: 6.6,
716+
},
717+
{
718+
name: "unsorted float64 slice",
719+
input: []float64{9.9, 1.1, 5.5, 3.3, 7.7},
720+
expected: 5.5,
721+
},
722+
}
723+
for _, tt := range tests {
724+
t.Run(tt.name, func(t *testing.T) {
725+
var result interface{}
726+
switch v := tt.input.(type) {
727+
case []int64:
728+
result = median(v)
729+
case []float64:
730+
result = median(v)
731+
default:
732+
t.Fatalf("unsupported input type: %T", v)
733+
}
734+
require.Equal(t, tt.expected, result)
735+
})
736+
}
737+
}
738+
739+
func TestMedianFunc(t *testing.T) {
740+
fm, ok := builtins["median"]
741+
require.True(t, ok)
742+
contextLogger := conf.Log.WithField("rule", "testExec")
743+
ctx := kctx.WithValue(kctx.Background(), kctx.LoggerKey, contextLogger)
744+
tempStore, _ := state.CreateStore("mockRule0", def.AtMostOnce)
745+
fctx := kctx.NewDefaultFuncContext(ctx.WithMeta("mockRule0", "test", tempStore), 2)
746+
tests := []struct {
747+
args []interface{}
748+
expect interface{}
749+
}{
750+
{
751+
args: []interface{}{},
752+
expect: int64(0),
753+
},
754+
{
755+
args: []interface{}{
756+
int64(5), int64(1), int64(2), int64(3), int64(4),
757+
},
758+
expect: int64(3),
759+
},
760+
{
761+
args: []interface{}{
762+
float64(5), float64(1), float64(2), float64(3), float64(4),
763+
},
764+
expect: float64(3),
765+
},
766+
}
767+
for _, tt := range tests {
768+
got, ok := fm.exec(fctx, []interface{}{tt.args})
769+
require.True(t, ok)
770+
require.Equal(t, tt.expect, got)
771+
}
772+
}

0 commit comments

Comments
 (0)