Skip to content

Add predicate to sum() builtin #592

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions builtin/builtin.go
Original file line number Diff line number Diff line change
@@ -83,6 +83,11 @@ var Builtins = []*Function{
Predicate: true,
Types: types(new(func([]any, func(any) bool) int)),
},
{
Name: "sum",
Predicate: true,
Types: types(new(func([]any, func(any) bool) int)),
},
{
Name: "groupBy",
Predicate: true,
@@ -387,13 +392,6 @@ var Builtins = []*Function{
return validateAggregateFunc("min", args)
},
},
{
Name: "sum",
Func: sum,
Validate: func(args []reflect.Type) (reflect.Type, error) {
return validateAggregateFunc("sum", args)
},
},
{
Name: "mean",
Func: func(args ...any) (any, error) {
4 changes: 0 additions & 4 deletions builtin/builtin_test.go
Original file line number Diff line number Diff line change
@@ -90,9 +90,6 @@ func TestBuiltin(t *testing.T) {
{`sum([.5, 1.5, 2.5])`, 4.5},
{`sum([])`, 0},
{`sum([1, 2, 3.0, 4])`, 10.0},
{`sum(10, [1, 2, 3], 1..9)`, 61},
{`sum(-10, [1, 2, 3, 4])`, 0},
{`sum(-10.9, [1, 2, 3, 4, 9])`, 8.1},
{`mean(1..9)`, 5.0},
{`mean([.5, 1.5, 2.5])`, 1.5},
{`mean([])`, 0.0},
@@ -219,7 +216,6 @@ func TestBuiltin_errors(t *testing.T) {
{`min([1, "2"])`, `invalid argument for min (type string)`},
{`median(1..9, "t")`, "invalid argument for median (type string)"},
{`mean("s", 1..9)`, "invalid argument for mean (type string)"},
{`sum("s", "h")`, "invalid argument for sum (type string)"},
{`duration("error")`, `invalid duration`},
{`date("error")`, `invalid date`},
{`get()`, `invalid number of arguments (expected 2, got 0)`},
39 changes: 0 additions & 39 deletions builtin/lib.go
Original file line number Diff line number Diff line change
@@ -258,45 +258,6 @@ func String(arg any) any {
return fmt.Sprintf("%v", arg)
}

func sum(args ...any) (any, error) {
var total int
var fTotal float64

for _, arg := range args {
rv := reflect.ValueOf(deref.Deref(arg))

switch rv.Kind() {
case reflect.Array, reflect.Slice:
size := rv.Len()
for i := 0; i < size; i++ {
elemSum, err := sum(rv.Index(i).Interface())
if err != nil {
return nil, err
}
switch elemSum := elemSum.(type) {
case int:
total += elemSum
case float64:
fTotal += elemSum
}
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
total += int(rv.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
total += int(rv.Uint())
case reflect.Float32, reflect.Float64:
fTotal += rv.Float()
default:
return nil, fmt.Errorf("invalid argument for sum (type %T)", arg)
}
}

if fTotal != 0.0 {
return fTotal + float64(total), nil
}
return total, nil
}

func minMax(name string, fn func(any, any) bool, args ...any) (any, error) {
var val any
for _, arg := range args {
23 changes: 23 additions & 0 deletions checker/checker.go
Original file line number Diff line number Diff line change
@@ -668,6 +668,29 @@ func (v *checker) BuiltinNode(node *ast.BuiltinNode) (reflect.Type, info) {
}
return v.error(node.Arguments[1], "predicate should has one input and one output param")

case "sum":
collection, _ := v.visit(node.Arguments[0])
if !isArray(collection) && !isAny(collection) {
return v.error(node.Arguments[0], "builtin %v takes only array (got %v)", node.Name, collection)
}

if len(node.Arguments) == 2 {
v.begin(collection)
closure, _ := v.visit(node.Arguments[1])
v.end()

if isFunc(closure) &&
closure.NumOut() == 1 &&
closure.NumIn() == 1 && isAny(closure.In(0)) {
return closure.Out(0), info{}
}
} else {
if isAny(collection) {
return anyType, info{}
}
return collection.Elem(), info{}
}

case "find", "findLast":
collection, _ := v.visit(node.Arguments[0])
if !isArray(collection) && !isAny(collection) {
19 changes: 19 additions & 0 deletions compiler/compiler.go
Original file line number Diff line number Diff line change
@@ -809,6 +809,25 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
c.emit(OpEnd)
return

case "sum":
c.compile(node.Arguments[0])
c.emit(OpBegin)
c.emit(OpInt, 0)
c.emit(OpSetAcc)
c.emitLoop(func() {
if len(node.Arguments) == 2 {
c.compile(node.Arguments[1])
} else {
c.emit(OpPointer)
}
c.emit(OpGetAcc)
c.emit(OpAdd)
c.emit(OpSetAcc)
})
c.emit(OpGetAcc)
c.emit(OpEnd)
return

case "find":
c.compile(node.Arguments[0])
c.emit(OpBegin)
1 change: 1 addition & 0 deletions parser/parser.go
Original file line number Diff line number Diff line change
@@ -34,6 +34,7 @@ var predicates = map[string]struct {
"filter": {[]arg{expr, closure}},
"map": {[]arg{expr, closure}},
"count": {[]arg{expr, closure}},
"sum": {[]arg{expr, closure | optional}},
"find": {[]arg{expr, closure}},
"findIndex": {[]arg{expr, closure}},
"findLast": {[]arg{expr, closure}},
1 change: 0 additions & 1 deletion test/fuzz/fuzz_corpus.txt
Original file line number Diff line number Diff line change
@@ -10455,7 +10455,6 @@ max(f64, i64)
max(false ? 1 : 0.5)
max(false ? 1 : nil)
max(false ? add : ok)
max(false ? half : list)
max(false ? i : nil)
max(false ? i32 : score)
max(false ? true : 1)
14 changes: 0 additions & 14 deletions testdata/examples.txt
Original file line number Diff line number Diff line change
@@ -7419,12 +7419,6 @@ get(ok ? score : foo, String?.foo())
get(ok ? score : i64, foo)
get(reduce(list, array), i32)
get(sort(array), i32)
get(sum(array), Qux)
get(sum(array), String)
get(sum(array), f32)
get(sum(array), f64 == list)
get(sum(array), greet)
get(sum(array), i)
get(take(list, i), i64)
get(true ? "bar" : ok, score(i))
get(true ? "foo" : half, list)
@@ -7460,7 +7454,6 @@ greet != nil ? list : false
greet != score
greet != score != false
greet != score or ok
greet != sum(array)
greet == add
greet == add ? i : list
greet == add or ok
@@ -12200,7 +12193,6 @@ last(ok ? ok : 0.5)
last(reduce(array, list))
last(reduce(list, array))
last(sort(array))
last(sum(array))
last(true ? "bar" : half)
last(true ? add : list)
last(true ? foo : 1)
@@ -14818,7 +14810,6 @@ ok != nil ? nil : array
ok != not ok
ok != ok
ok != ok ? false : "bar"
ok != sum(array)
ok && !false
ok && !ok
ok && "foo" matches "bar"
@@ -16970,7 +16961,6 @@ string(groupBy(list, i))
string(half != nil)
string(half != score)
string(half == nil)
string(half == sum(array))
string(half(0.5))
string(half(1))
string(half(f64))
@@ -17297,18 +17287,14 @@ sum([0.5])
sum([f32])
sum(array)
sum(array) != f32
sum(array) != half
sum(array) != ok
sum(array) % i
sum(array) % i64
sum(array) - f32
sum(array) / -f64
sum(array) < i
sum(array) == div
sum(array) == i64 - i
sum(array) ^ f64
sum(array) not in array
sum(array) not in list
sum(filter(array, ok))
sum(groupBy(array, i32).String)
sum(groupBy(list, #)?.greet)