Skip to content

Commit dd925fd

Browse files
authored
Add mem grow to builtin (#505)
* Add mem grow to builtin * Add OpValidateArgs opcode
1 parent 7b890a1 commit dd925fd

File tree

6 files changed

+63
-24
lines changed

6 files changed

+63
-24
lines changed

builtin/builtin.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@ import (
1313
)
1414

1515
type Function struct {
16-
Name string
17-
Func func(args ...any) (any, error)
18-
Fast func(arg any) any
19-
Types []reflect.Type
20-
Validate func(args []reflect.Type) (reflect.Type, error)
21-
Predicate bool
16+
Name string
17+
Func func(args ...any) (any, error)
18+
Fast func(arg any) any
19+
ValidateArgs func(args ...any) (any, error)
20+
Types []reflect.Type
21+
Validate func(args []reflect.Type) (reflect.Type, error)
22+
Predicate bool
2223
}
2324

2425
var (
@@ -325,12 +326,15 @@ var Builtins = []*Function{
325326
},
326327
{
327328
Name: "repeat",
328-
Func: func(args ...any) (any, error) {
329+
ValidateArgs: func(args ...any) (any, error) {
329330
n := runtime.ToInt(args[1])
330-
if n > 1e6 {
331-
panic("memory budget exceeded")
331+
if n < 0 {
332+
panic(fmt.Errorf("invalid argument for repeat (expected positive integer, got %d)", n))
332333
}
333-
return strings.Repeat(args[0].(string), n), nil
334+
return uint(n), nil
335+
},
336+
Func: func(args ...any) (any, error) {
337+
return strings.Repeat(args[0].(string), runtime.ToInt(args[1])), nil
334338
},
335339
Types: types(strings.Repeat),
336340
},

builtin/builtin_test.go

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,26 @@ func TestBuiltin_memory_limits(t *testing.T) {
260260

261261
for _, test := range tests {
262262
t.Run(test.input, func(t *testing.T) {
263-
_, err := expr.Eval(test.input, nil)
264-
assert.Error(t, err)
265-
assert.Contains(t, err.Error(), "memory budget exceeded")
263+
timeout := make(chan bool, 1)
264+
go func() {
265+
time.Sleep(time.Second)
266+
timeout <- true
267+
}()
268+
269+
done := make(chan bool, 1)
270+
go func() {
271+
_, err := expr.Eval(test.input, nil)
272+
assert.Error(t, err)
273+
assert.Contains(t, err.Error(), "memory budget exceeded")
274+
done <- true
275+
}()
276+
277+
select {
278+
case <-done:
279+
// Success.
280+
case <-timeout:
281+
t.Fatal("timeout")
282+
}
266283
})
267284
}
268285
}

compiler/compiler.go

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -147,31 +147,31 @@ func (c *compiler) addVariable(name string) int {
147147
func (c *compiler) emitFunction(fn *builtin.Function, argsLen int) {
148148
switch argsLen {
149149
case 0:
150-
c.emit(OpCall0, c.addFunction(fn))
150+
c.emit(OpCall0, c.addFunction(fn.Name, fn.Func))
151151
case 1:
152-
c.emit(OpCall1, c.addFunction(fn))
152+
c.emit(OpCall1, c.addFunction(fn.Name, fn.Func))
153153
case 2:
154-
c.emit(OpCall2, c.addFunction(fn))
154+
c.emit(OpCall2, c.addFunction(fn.Name, fn.Func))
155155
case 3:
156-
c.emit(OpCall3, c.addFunction(fn))
156+
c.emit(OpCall3, c.addFunction(fn.Name, fn.Func))
157157
default:
158-
c.emit(OpLoadFunc, c.addFunction(fn))
158+
c.emit(OpLoadFunc, c.addFunction(fn.Name, fn.Func))
159159
c.emit(OpCallN, argsLen)
160160
}
161161
}
162162

163163
// addFunction adds builtin.Function.Func to the program.functions and returns its index.
164-
func (c *compiler) addFunction(fn *builtin.Function) int {
164+
func (c *compiler) addFunction(name string, fn Function) int {
165165
if fn == nil {
166166
panic("function is nil")
167167
}
168-
if p, ok := c.functionsIndex[fn.Name]; ok {
168+
if p, ok := c.functionsIndex[name]; ok {
169169
return p
170170
}
171171
p := len(c.functions)
172-
c.functions = append(c.functions, fn.Func)
173-
c.functionsIndex[fn.Name] = p
174-
c.debugInfo[fmt.Sprintf("func_%d", p)] = fn.Name
172+
c.functions = append(c.functions, fn)
173+
c.functionsIndex[name] = p
174+
c.debugInfo[fmt.Sprintf("func_%d", p)] = name
175175
return p
176176
}
177177

@@ -904,6 +904,12 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
904904
for _, arg := range node.Arguments {
905905
c.compile(arg)
906906
}
907+
908+
if f.ValidateArgs != nil {
909+
c.emit(OpLoadFunc, c.addFunction("$_validate_args_"+f.Name, f.ValidateArgs))
910+
c.emit(OpValidateArgs, len(node.Arguments))
911+
}
912+
907913
if f.Fast != nil {
908914
c.emit(OpCallBuiltin1, id)
909915
} else if f.Func != nil {

vm/opcodes.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ const (
6060
OpCallFast
6161
OpCallTyped
6262
OpCallBuiltin1
63+
OpValidateArgs
6364
OpArray
6465
OpMap
6566
OpLen

vm/program.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ func (program *Program) DisassembleWriter(w io.Writer) {
137137
constant("OpLoadMethod")
138138

139139
case OpLoadFunc:
140-
argument("OpLoadFunc")
140+
argumentWithInfo("OpLoadFunc", "func")
141141

142142
case OpLoadEnv:
143143
code("OpLoadEnv")
@@ -278,6 +278,9 @@ func (program *Program) DisassembleWriter(w io.Writer) {
278278
case OpCallBuiltin1:
279279
builtinArg("OpCallBuiltin1")
280280

281+
case OpValidateArgs:
282+
argument("OpValidateArgs")
283+
281284
case OpArray:
282285
code("OpArray")
283286

vm/vm.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,14 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
397397
case OpCallBuiltin1:
398398
vm.push(builtin.Builtins[arg].Fast(vm.pop()))
399399

400+
case OpValidateArgs:
401+
fn := vm.pop().(Function)
402+
mem, err := fn(vm.stack[len(vm.stack)-arg:]...)
403+
if err != nil {
404+
panic(err)
405+
}
406+
vm.memGrow(mem.(uint))
407+
400408
case OpArray:
401409
size := vm.pop().(int)
402410
vm.memGrow(uint(size))

0 commit comments

Comments
 (0)