-
Notifications
You must be signed in to change notification settings - Fork 77
Expand file tree
/
Copy pathlua.go
More file actions
125 lines (101 loc) · 2.55 KB
/
lua.go
File metadata and controls
125 lines (101 loc) · 2.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
package rdns
import (
"errors"
"fmt"
"strings"
"github.com/miekg/dns"
)
// Lua is a resolver that hands every query to a Lua script. It keeps a
// fixed set of script instances in a channel; Resolve takes one out for
// the duration of a call and puts it back afterwards.
type Lua struct {
	// id names this resolver; it is returned by String and passed to the
	// logger and to LuaCompile.
	id string

	// resolvers are injected into each script instance when it is created.
	resolvers []Resolver

	// scripts holds the script instances that are currently not in use.
	scripts chan *LuaScript

	// bytecode is the compiled script every instance is created from.
	bytecode ByteCode

	// opt holds the options given to NewLua (after defaults were applied).
	opt LuaOptions
}
// Compile-time check that *Lua satisfies the Resolver interface.
var _ Resolver = (*Lua)(nil)
// LuaOptions configures a Lua resolver created with NewLua.
type LuaOptions struct {
	// Script is the Lua source text. NewLua compiles it once.
	Script string

	// Concurrency is the number of script instances NewLua creates.
	// A value of 0 is replaced with 4.
	Concurrency uint

	// NoSandbox disables the sandbox. When false (default), scripts cannot
	// access os/io/debug/etc.
	NoSandbox bool
}
// NewLua creates a resolver that answers queries by calling the Resolve
// function of a Lua script.
//
// opt.Script is compiled once and opt.Concurrency script instances are
// created from the resulting bytecode (4 if Concurrency is 0). The given
// resolvers are injected into every instance so the script can use them.
//
// An error is returned if the script does not compile or if an instance
// cannot be initialized. In the latter case all instances created up to
// that point are closed before returning.
func NewLua(id string, opt LuaOptions, resolvers ...Resolver) (*Lua, error) {
	if opt.Concurrency == 0 {
		opt.Concurrency = 4
	}

	// Compile the script once; all instances share the same bytecode.
	bytecode, err := LuaCompile(strings.NewReader(opt.Script), id)
	if err != nil {
		return nil, err
	}

	r := &Lua{
		id:        id,
		resolvers: resolvers,
		opt:       opt,
		scripts:   make(chan *LuaScript, opt.Concurrency),
		bytecode:  bytecode,
	}

	// Fill the channel with ready-to-use script instances. The channel's
	// capacity equals the number of instances, so these sends never block.
	for range opt.Concurrency {
		s, err := r.newScript()
		if err != nil {
			// Don't leak the instances that were already created.
			r.Close()
			return nil, err
		}
		r.scripts <- s
	}
	return r, nil
}
// Resolve passes the query and client information to the script's Resolve
// function and returns what the script produced.
//
// The script function is expected to return two values: a message (or nil)
// and an error (or nil). Any other number or type of return values results
// in an error. An error is also returned if the resolver has been closed.
//
// Close must not be called while Resolve calls are still in flight: the
// script instance is sent back on the channel when the call finishes, and
// sending on a closed channel panics.
func (r *Lua) Resolve(q *dns.Msg, ci ClientInfo) (*dns.Msg, error) {
	// Take a script instance; this blocks while all instances are in use.
	// After Close the channel is closed and the receive yields a nil
	// instance with ok == false, which must not be used or sent back.
	s, ok := <-r.scripts
	if !ok {
		return nil, fmt.Errorf("lua resolver %q is closed", r.id)
	}
	defer func() { r.scripts <- s }()

	log := logger(r.id, q, ci)

	// Call the "Resolve" function in the script. It should return 2 values.
	ret, err := s.Call("Resolve", 2, q, ci)
	if err != nil {
		log.Error("failed to run lua script", "error", err)
		return nil, err
	}

	// Extract the answer and error from the returned values
	if len(ret) != 2 {
		return nil, fmt.Errorf("invalid return value, expected 2, got %d", len(ret))
	}

	// A nil value is accepted in either position; only a non-nil value of
	// the wrong type is rejected.
	answer, ok := ret[0].(*dns.Msg)
	if ret[0] != nil && !ok {
		return nil, fmt.Errorf("invalid return value, expected Message, got %T", ret[0])
	}
	err, ok = ret[1].(error)
	if ret[1] != nil && !ok {
		return nil, fmt.Errorf("invalid return value, expected Error, got %T", ret[1])
	}
	return answer, err
}
// String returns the id the resolver was created with.
func (r *Lua) String() string {
	return r.id
}
// Close closes the scripts channel and then closes the Lua state of every
// script instance that is still queued in it.
func (r *Lua) Close() {
	close(r.scripts)

	// Drain what is left in the closed channel; the receive reports
	// ok == false once it is empty.
	for {
		script, ok := <-r.scripts
		if !ok {
			return
		}
		script.L.Close()
	}
}
// newScript creates one script instance from the compiled bytecode,
// registers the types and constants, and injects the configured resolvers.
//
// It returns an error if the instance cannot be created from the bytecode
// or if the script does not define a Resolve function. In the latter case
// the instance's Lua state is closed before returning.
func (r *Lua) newScript() (*LuaScript, error) {
	// The second argument is the negated NoSandbox option (presumably the
	// flag that turns the sandbox on; see NewScriptFromByteCode).
	s, err := NewScriptFromByteCode(r.bytecode, !r.opt.NoSandbox)
	if err != nil {
		return nil, err
	}

	// Register types and methods
	s.RegisterConstants()
	s.RegisterMessageType()
	s.RegisterQuestionType()
	s.RegisterRRTypes()
	s.RegisterOPTType()
	s.RegisterEDNS0Types()
	s.RegisterErrorType()
	s.RegisterClientInfoType()

	// Inject the resolvers into the state (so they can be used in the script)
	s.InjectResolvers(r.resolvers)

	// The script must contain a Resolve() function which is the entry point
	if !s.HasFunction("Resolve") {
		// The instance is discarded, so release its Lua state.
		s.L.Close()
		return nil, errors.New("no Resolve() function found in lua script")
	}
	return s, nil
}