Skip to content

Commit 81400af

Browse files
committed
filter: lua: bind fields, createRecord and validateRecord
1 parent e95ae6a commit 81400af

File tree

4 files changed

+234
-3
lines changed

4 files changed

+234
-3
lines changed

filter/all.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ var All = []baker.FilterDesc{
1111
ClearFieldsDesc,
1212
ConcatenateDesc,
1313
LUADesc,
14-
MatchRegexDesc,
1514
NotNullDesc,
1615
PartialCloneDesc,
1716
RegexMatchDesc,

filter/lua.go

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,13 @@ type LUA struct {
3131

3232
func NewLUA(cfg baker.FilterParams) (baker.Filter, error) {
3333
dcfg := cfg.DecodedConfig.(*LUAConfig)
34-
3534
l := lua.NewState()
3635
if err := l.DoFile(dcfg.Script); err != nil {
3736
return nil, fmt.Errorf("can't compile lua script %q: %v", dcfg.Script, err)
3837
}
39-
registerLUARecordType(l)
38+
39+
registerLUATypes(l, cfg.ComponentParams)
40+
4041
// TODO: check function exists
4142
luaFunc := l.GetGlobal(dcfg.FilterName)
4243

@@ -61,6 +62,32 @@ func NewLUA(cfg baker.FilterParams) (baker.Filter, error) {
6162
return f, nil
6263
}
6364

65+
func registerLUATypes(l *lua.LState, comp baker.ComponentParams) {
66+
registerLUARecordType(l)
67+
68+
l.SetGlobal("createRecord", l.NewFunction(func(L *lua.LState) int {
69+
rec := comp.CreateRecord()
70+
ud := recordToLua(l, rec)
71+
L.Push(ud)
72+
return 1
73+
}))
74+
75+
l.SetGlobal("validateRecord", l.NewFunction(func(L *lua.LState) int {
76+
luar := fastcheckLuaRecord(l, 1)
77+
ok, fidx := comp.ValidateRecord(luar.r)
78+
l.Push(lua.LBool(ok))
79+
l.Push(lua.LNumber(fidx))
80+
return 2
81+
}))
82+
83+
// Create the fields table.
84+
fields := l.NewTable()
85+
for i, n := range comp.FieldNames {
86+
fields.RawSetString(n, lua.LNumber(i))
87+
}
88+
l.SetGlobal("fields", fields)
89+
}
90+
6491
func (t *LUA) Stats() baker.FilterStats { return baker.FilterStats{} }
6592

6693
func (t *LUA) Process(rec baker.Record, next func(baker.Record)) {

filter/lua_test.go

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
package filter
2+
3+
import (
4+
"bytes"
5+
"io/ioutil"
6+
"os"
7+
"path/filepath"
8+
"testing"
9+
10+
"github.com/AdRoll/baker"
11+
)
12+
13+
func BenchmarkLUAProcess(b *testing.B) {
14+
b.ReportAllocs()
15+
const script = `
16+
-- rec is a record object
17+
-- next is function next(record)
18+
function dummy(rec, next)
19+
rec:set(0, "hey")
20+
next(rec)
21+
end
22+
`
23+
24+
dir, err := ioutil.TempDir("", b.Name())
25+
if err != nil {
26+
b.Fatal(err)
27+
}
28+
// fname := filepath.Join(b.TempDir(), "filters.lua")
29+
fname := filepath.Join(dir, "filters.lua")
30+
if err := ioutil.WriteFile(fname, []byte(script), os.ModePerm); err != nil {
31+
b.Fatalf("can't write lua script: %v", err)
32+
}
33+
34+
record := &baker.LogLine{}
35+
36+
fieldByName := func(name string) (baker.FieldIndex, bool) {
37+
switch name {
38+
case "foo":
39+
return 0, true
40+
case "bar":
41+
return 1, true
42+
case "baz":
43+
return 2, true
44+
}
45+
return 0, false
46+
}
47+
48+
f, err := NewLUA(baker.FilterParams{
49+
ComponentParams: baker.ComponentParams{
50+
FieldByName: fieldByName,
51+
DecodedConfig: &LUAConfig{
52+
Script: fname,
53+
FilterName: "dummy",
54+
},
55+
},
56+
})
57+
58+
if err != nil {
59+
b.Fatalf("NewLUA error = %v", err)
60+
}
61+
62+
b.ResetTimer()
63+
for n := 0; n < b.N; n++ {
64+
f.Process(record, func(baker.Record) {})
65+
}
66+
}
67+
68+
func TestLUAFilter(t *testing.T) {
69+
// This is the lua script containing the lua functions used in the test cases.
70+
fname := filepath.Join("testdata", "lua_test.lua")
71+
72+
fieldNames := []string{"foo", "bar", "baz"}
73+
fieldByName := func(name string) (baker.FieldIndex, bool) {
74+
for i, n := range fieldNames {
75+
if n == name {
76+
return baker.FieldIndex(i), true
77+
}
78+
}
79+
80+
return 0, false
81+
}
82+
83+
tests := []struct {
84+
name string // both test case name and lua filter name
85+
record string
86+
wantErr bool // configuration-time error
87+
want [][3]string // contains non discarded records with, for each of them, the 3 fields we want
88+
}{
89+
{
90+
name: "swapFieldsWithIndex",
91+
record: "abc,def,ghi",
92+
want: [][3]string{
93+
{"abc", "ghi", "def"},
94+
},
95+
},
96+
{
97+
name: "swapFieldsWithNames",
98+
record: "abc,def,ghi",
99+
want: [][3]string{
100+
{"abc", "ghi", "def"},
101+
},
102+
},
103+
{
104+
name: "_createRecord",
105+
record: "abc,def,ghi",
106+
want: [][3]string{
107+
{"hey", "ho", "let's go!"},
108+
{"abc", "def", "ghi"},
109+
},
110+
},
111+
{
112+
name: "_validateRecord",
113+
record: "ciao,,",
114+
want: [][3]string{
115+
{"good", "", ""},
116+
}},
117+
}
118+
for _, tt := range tests {
119+
t.Run(tt.name, func(t *testing.T) {
120+
f, err := NewLUA(baker.FilterParams{
121+
ComponentParams: baker.ComponentParams{
122+
FieldByName: fieldByName,
123+
FieldNames: fieldNames,
124+
CreateRecord: func() baker.Record {
125+
return &baker.LogLine{FieldSeparator: ','}
126+
},
127+
ValidateRecord: func(r baker.Record) (bool, baker.FieldIndex) {
128+
if string(r.Get(0)) != "hello" {
129+
return false, 0
130+
}
131+
return true, -1
132+
},
133+
DecodedConfig: &LUAConfig{
134+
Script: fname,
135+
FilterName: tt.name,
136+
},
137+
},
138+
})
139+
140+
if (err != nil) != (tt.wantErr) {
141+
t.Fatalf("got error = %v, want error = %v", err, tt.wantErr)
142+
}
143+
144+
if tt.wantErr {
145+
return
146+
}
147+
148+
l := &baker.LogLine{FieldSeparator: ','}
149+
if err := l.Parse([]byte(tt.record), nil); err != nil {
150+
t.Fatalf("parse error: %q", err)
151+
}
152+
153+
var got []baker.Record
154+
f.Process(l, func(r baker.Record) { got = append(got, r) })
155+
156+
// Check the number of non discarded records match
157+
if len(got) != len(tt.want) {
158+
t.Fatalf("got %d non-discarded records, want %d", len(got), len(tt.want))
159+
}
160+
161+
for recidx, rec := range tt.want {
162+
for fidx, fval := range rec {
163+
f := got[recidx].Get(baker.FieldIndex(fidx))
164+
if !bytes.Equal(f, []byte(fval)) {
165+
t.Errorf("got record[%d].Get(%d) = %q, want %q", recidx, fidx, string(f), fval)
166+
}
167+
}
168+
}
169+
})
170+
}
171+
}

filter/testdata/lua_test.lua

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
function swapFieldsWithIndex(rec, next)
2+
local f1, f2
3+
f1 = rec:get(1)
4+
rec:set(1, rec:get(2))
5+
rec:set(2, f1)
6+
next(rec)
7+
end
8+
9+
function swapFieldsWithNames(rec, next)
10+
local f1, f2
11+
f1 = rec:get(fields["bar"])
12+
rec:set(1, rec:get(fields["baz"]))
13+
rec:set(2, f1)
14+
next(rec)
15+
end
16+
17+
function _createRecord(rec, next)
18+
newrec = createRecord()
19+
newrec:set(0, "hey")
20+
newrec:set(1, "ho")
21+
newrec:set(2, "let's go!")
22+
next(newrec)
23+
next(rec)
24+
end
25+
26+
function _validateRecord(rec, next)
27+
ok, idx = validateRecord(rec)
28+
if ok == false and idx == 0 then
29+
rec:set(0, "good")
30+
else
31+
rec:set(0, "bad")
32+
end
33+
next(rec)
34+
end

0 commit comments

Comments
 (0)