Skip to content

Commit de6f577

Browse files
committed
support audit log plugin v2
Signed-off-by: Yang Keao <yangkeao@chunibyo.icu>
1 parent 14d5243 commit de6f577

File tree

13 files changed

+880
-44
lines changed

13 files changed

+880
-44
lines changed

cmd/replayer/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ func main() {
135135
Speed: *speed,
136136
Username: *username,
137137
Password: *password,
138-
Format: *format,
138+
Format: replaycmd.TrafficFormat(*format),
139139
ReadOnly: *readonly,
140140
StartTime: *startTime,
141141
CommandStartTime: *cmdStartTime,

pkg/server/api/traffic.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func (h *Server) TrafficReplay(c *gin.Context) {
100100
}
101101
cfg.Username = c.PostForm("username")
102102
cfg.Password = c.PostForm("password")
103-
cfg.Format = c.PostForm("format")
103+
cfg.Format = cmd.TrafficFormat(c.PostForm("format"))
104104
cfg.ReadOnly = strings.EqualFold(c.PostForm("readonly"), "true")
105105
cfg.IgnoreErrs = strings.EqualFold(c.PostForm("ignore-errs"), "true")
106106
cfg.KeyFile = globalCfg.Security.EncryptionKeyPath
Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
// Copyright 2026 PingCAP, Inc.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package cmd
5+
6+
import (
7+
"fmt"
8+
"strconv"
9+
"strings"
10+
"time"
11+
12+
"github.com/pingcap/tiproxy/lib/util/errors"
13+
pnet "github.com/pingcap/tiproxy/pkg/proxy/net"
14+
"github.com/siddontang/go/hack"
15+
"go.uber.org/zap"
16+
)
17+
18+
var _ AuditLogDecoder = (*AuditLogExtensionDecoder)(nil)
19+
20+
type AuditLogExtensionDecoder struct {
21+
connInfo map[uint64]auditLogPluginConnCtx
22+
commandEndTime time.Time
23+
// pendingCmds contains the commands that has not been returned yet.
24+
pendingCmds []*Command
25+
psCloseStrategy PSCloseStrategy
26+
idAllocator *ConnIDAllocator
27+
lg *zap.Logger
28+
}
29+
30+
func NewAuditLogExtensionDecoder(lg *zap.Logger) AuditLogDecoder {
31+
return &AuditLogExtensionDecoder{
32+
connInfo: make(map[uint64]auditLogPluginConnCtx),
33+
psCloseStrategy: PSCloseStrategyDirected,
34+
lg: lg,
35+
}
36+
}
37+
38+
// EnableFilterCommandWithRetry implements [AuditLogDecoder].
39+
func (decoder *AuditLogExtensionDecoder) EnableFilterCommandWithRetry() {
40+
// do nothing for extension decoder, it's not supported yet
41+
}
42+
43+
// SetCommandEndTime implements [AuditLogDecoder].
44+
func (decoder *AuditLogExtensionDecoder) SetCommandEndTime(t time.Time) {
45+
decoder.commandEndTime = t
46+
}
47+
48+
// SetIDAllocator implements [AuditLogDecoder].
49+
func (decoder *AuditLogExtensionDecoder) SetIDAllocator(alloc *ConnIDAllocator) {
50+
decoder.idAllocator = alloc
51+
}
52+
53+
// SetPSCloseStrategy implements [AuditLogDecoder].
54+
func (decoder *AuditLogExtensionDecoder) SetPSCloseStrategy(s PSCloseStrategy) {
55+
decoder.psCloseStrategy = s
56+
}
57+
58+
// SetCommandStartTime implements [AuditLogDecoder].
59+
func (decoder *AuditLogExtensionDecoder) SetCommandStartTime(t time.Time) {
60+
// do nothing for extension decoder
61+
}
62+
63+
func (decoder *AuditLogExtensionDecoder) Decode(reader LineReader) (retCmd *Command, err error) {
64+
defer func() {
65+
if retCmd != nil {
66+
fmt.Println("Decoded command:", retCmd.ConnID, retCmd.Line, retCmd.StartTs, retCmd.EndTs, "error:", err)
67+
}
68+
}()
69+
if len(decoder.pendingCmds) > 0 {
70+
cmd := decoder.pendingCmds[0]
71+
decoder.pendingCmds = decoder.pendingCmds[1:]
72+
return cmd, nil
73+
}
74+
75+
kvs := make(map[string]string, 25)
76+
for {
77+
line, filename, lineIdx, err := reader.ReadLine()
78+
if err != nil {
79+
return nil, err
80+
}
81+
clear(kvs)
82+
err = parseLog(kvs, hack.String(line))
83+
if err != nil {
84+
return nil, errors.Errorf("%s, line %d: %s", filename, lineIdx, err.Error())
85+
}
86+
connStr := kvs[auditPluginKeyConnID]
87+
if len(connStr) == 0 {
88+
return nil, errors.Errorf("%s, line %d: no connection id in line: %s", filename, lineIdx, line)
89+
}
90+
upstreamConnID, err := strconv.ParseUint(connStr, 10, 64)
91+
if err != nil {
92+
return nil, errors.Errorf("%s, line %d: parsing connection id failed: %s", filename, lineIdx, connStr)
93+
}
94+
95+
// TODO: add both startTs and endTs in extension log. We only have the endTS is the current format.
96+
endTs, err := time.Parse(timeLayout, kvs[auditPluginKeyLogTime])
97+
if endTs.Before(decoder.commandEndTime) {
98+
// Ignore the commands before CommandEndTime.
99+
continue
100+
}
101+
102+
var connID uint64
103+
if connCtx, ok := decoder.connInfo[upstreamConnID]; ok {
104+
connID = connCtx.connID
105+
} else {
106+
// New connection, allocate a new connection ID.
107+
if decoder.idAllocator == nil {
108+
connID = upstreamConnID
109+
} else {
110+
connID = decoder.idAllocator.alloc()
111+
}
112+
connCtx.connID = connID
113+
decoder.connInfo[upstreamConnID] = connCtx
114+
}
115+
116+
eventStr := kvs[auditPluginKeyEvent]
117+
if len(eventStr) <= 4 {
118+
return nil, errors.Errorf("%s, line %d: invalid event field: %s", filename, lineIdx, eventStr)
119+
}
120+
// Remove the surrounding quotes and brackets.
121+
eventStr = eventStr[2 : len(eventStr)-2]
122+
events := strings.Split(eventStr, ",")
123+
var cmds []*Command
124+
switch events[0] {
125+
case "CONNECTION":
126+
if len(events) > 1 && events[1] == "DISCONNECT" {
127+
delete(decoder.connInfo, upstreamConnID)
128+
cmds = []*Command{{
129+
Type: pnet.ComQuit,
130+
Payload: []byte{pnet.ComQuit.Byte()},
131+
}}
132+
}
133+
case "QUERY":
134+
cmds, err = decoder.parseQueryEvent(kvs, events, upstreamConnID)
135+
}
136+
if err != nil {
137+
return nil, errors.Wrapf(err, "%s, line %d", filename, lineIdx)
138+
}
139+
// The log is ignored, skip.
140+
if len(cmds) == 0 {
141+
continue
142+
}
143+
144+
db := kvs[auditPluginKeyCurDB]
145+
for _, cmd := range cmds {
146+
cmd.Success = true
147+
cmd.UpstreamConnID = upstreamConnID
148+
cmd.ConnID = connID
149+
// We don't have an accurate startTs in extension log.
150+
cmd.StartTs = endTs
151+
cmd.CurDB = db
152+
cmd.FileName = filename
153+
cmd.Line = lineIdx
154+
cmd.EndTs = endTs
155+
cmd.kvs = kvs
156+
}
157+
if len(cmds) > 1 {
158+
decoder.pendingCmds = cmds[1:]
159+
}
160+
return cmds[0], nil
161+
}
162+
}
163+
164+
func (decoder *AuditLogExtensionDecoder) parseQueryEvent(kvs map[string]string, events []string, connID uint64) ([]*Command, error) {
165+
connInfo := decoder.connInfo[connID]
166+
if connInfo.preparedStmt == nil {
167+
connInfo.preparedStmt = make(map[uint32]struct{})
168+
connInfo.preparedStmtSql = make(map[string]uint32)
169+
}
170+
171+
var sql string
172+
sqlStr := kvs[auditPluginKeySQL]
173+
if len(sqlStr) > 0 {
174+
var err error
175+
sql, err = parseSQL(sqlStr)
176+
if err != nil {
177+
return nil, errors.Wrapf(err, "unquote sql failed: %s", sqlStr)
178+
}
179+
}
180+
cmds := make([]*Command, 0, 3)
181+
// Only handle two events:
182+
// - QUERY,EXECUTE
183+
// - QUERY
184+
if events[0] == "QUERY" && len(events) > 1 && events[1] == "EXECUTE" {
185+
params, ok := kvs[auditPluginKeyParams]
186+
if !ok {
187+
return nil, nil
188+
}
189+
args, err := parseExecuteParamsForExtension(params)
190+
if err != nil {
191+
return nil, err
192+
}
193+
194+
var stmtID uint32
195+
var shouldPrepare bool
196+
197+
switch decoder.psCloseStrategy {
198+
case PSCloseStrategyAlways:
199+
connInfo.lastPsID++
200+
decoder.connInfo[connID] = connInfo
201+
stmtID = connInfo.lastPsID
202+
shouldPrepare = true
203+
case PSCloseStrategyNever:
204+
if id, ok := connInfo.preparedStmtSql[sql]; ok {
205+
shouldPrepare = false
206+
stmtID = id
207+
} else {
208+
connInfo.lastPsID++
209+
connInfo.preparedStmtSql[sql] = connInfo.lastPsID
210+
decoder.connInfo[connID] = connInfo
211+
stmtID = connInfo.lastPsID
212+
shouldPrepare = true
213+
}
214+
}
215+
216+
// Append PREPARE command if needed.
217+
if shouldPrepare {
218+
cmds = append(cmds, &Command{
219+
CapturedPsID: stmtID,
220+
Type: pnet.ComStmtPrepare,
221+
StmtType: kvs[auditPluginKeyStmtType],
222+
PreparedStmt: sql,
223+
Payload: append([]byte{pnet.ComStmtPrepare.Byte()}, hack.Slice(sql)...),
224+
})
225+
}
226+
227+
// Append EXECUTE command
228+
executeReq, err := pnet.MakeExecuteStmtRequest(stmtID, args, true)
229+
if err != nil {
230+
return nil, errors.Wrapf(err, "make execute request failed")
231+
}
232+
cmds = append(cmds, &Command{
233+
CapturedPsID: stmtID,
234+
Type: pnet.ComStmtExecute,
235+
StmtType: kvs[auditPluginKeyStmtType],
236+
PreparedStmt: sql,
237+
Params: args,
238+
Payload: executeReq,
239+
})
240+
connInfo.lastCmd = cmds[len(cmds)-1]
241+
242+
// Append CLOSE command if needed.
243+
if decoder.psCloseStrategy == PSCloseStrategyAlways {
244+
// close the prepared statement right after it's executed.
245+
cmds = append(cmds, &Command{
246+
CapturedPsID: stmtID,
247+
Type: pnet.ComStmtClose,
248+
StmtType: kvs[auditPluginKeyStmtType],
249+
PreparedStmt: sql,
250+
Payload: pnet.MakeCloseStmtRequest(stmtID),
251+
})
252+
}
253+
} else if events[0] == "QUERY" {
254+
cmds = append(cmds, &Command{
255+
Type: pnet.ComQuery,
256+
StmtType: kvs[auditPluginKeyStmtType],
257+
Payload: append([]byte{pnet.ComQuery.Byte()}, hack.Slice(sql)...),
258+
})
259+
connInfo.lastCmd = cmds[0]
260+
}
261+
262+
decoder.connInfo[connID] = connInfo
263+
return cmds, nil
264+
}
265+
266+
// parseExecuteParamsForExtension parses the param in audit log extension field like "[1,abc,NULL,\"test bytes\""]"
267+
// This function has the following known limitations:
268+
// - All params are returned as string type. It cannot distinguish int 1 and string "1".
269+
// - It cannot distinguish single empty string and no param.
270+
func parseExecuteParamsForExtension(value string) ([]any, error) {
271+
v, err := strconv.Unquote(value)
272+
if err != nil {
273+
return nil, errors.Wrapf(err, "unquote execute params failed: %s", value)
274+
}
275+
if v[0] != '[' || v[len(v)-1] != ']' {
276+
return nil, errors.Errorf("no brackets in params: %s", value)
277+
}
278+
v = v[1 : len(v)-1]
279+
if len(v) == 0 {
280+
return nil, nil
281+
}
282+
283+
params := make([]any, 0, 10)
284+
for idx := 0; idx < len(v); idx++ {
285+
switch v[idx] {
286+
case '"':
287+
endIdx := skipQuotes(v[idx+1:], false)
288+
if endIdx == -1 {
289+
return nil, errors.Errorf("unterminated quote in params: %s", v[idx+1:])
290+
}
291+
292+
unquoted, err := strconv.Unquote(v[idx : idx+endIdx+2])
293+
if err != nil {
294+
return nil, errors.Wrapf(err, "unquote param failed: %s", v[idx:idx+endIdx+2])
295+
}
296+
params = append(params, unquoted)
297+
idx += endIdx + 1
298+
case ',', ' ':
299+
default:
300+
endIdx := strings.Index(v[idx:], ",")
301+
if endIdx == -1 {
302+
endIdx = len(v) - idx
303+
}
304+
params = append(params, v[idx:idx+endIdx])
305+
idx += endIdx - 1
306+
}
307+
}
308+
309+
return params, nil
310+
}

0 commit comments

Comments
 (0)