Skip to content

Commit eb4d673

Browse files
authored
feat: remove timeout metrics and rework timer usages (#122)
This PR reworks go-libddwaf context and run API to better work with timers. It also removes timeout summing that is now done on dd-trace-go side. - [x] Change `NewContext()` so it takes a list of timer options instead of simply the time budget for this context. - [x] Add a `TimerKey` to the parameters passed to `Context.Run()` to select the component where this run duration will be summed in (ex: rasp or waf) - [x] Create 3 public constants `EncodeTimeKey`, `DecodeTimeKey` and `DurationTimeKey` that all represent a part of a single WAF run that are supposed to be used as key to the map `Result.TimerStats` that was added as return value to `Context.Run()`. - [x] Rework tests to they work with the changes - [x] Adapt the README.md file to surface the main changes to the API - [x] Remove metrics.go --------- Signed-off-by: Eliott Bouhana <[email protected]>
1 parent 21ca39e commit eb4d673

File tree

9 files changed

+199
-349
lines changed

9 files changed

+199
-349
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,14 @@ func main() {
3333
}
3434
defer wafHandle.Close()
3535

36-
wafCtx := wafHandle.NewContext()
36+
wafCtx := wafHandle.NewContext(timer.WithUnlimitedBudget(), timer.WithComponent("waf", "rasp"))
3737
defer wafCtx.Close()
3838

3939
matches, actions := wafCtx.Run(RunAddressData{
4040
Persistent: map[string]any{
4141
"server.request.path_params": "/rfiinc.txt",
4242
},
43+
TimerKey: "waf",
4344
})
4445
}
4546
```

builder_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func TestBuilder(t *testing.T) {
7171
require.NotNil(t, handle)
7272
defer handle.Close()
7373

74-
ctx, err := handle.NewContext(timer.UnlimitedBudget)
74+
ctx, err := handle.NewContext(timer.WithBudget(timer.UnlimitedBudget))
7575
require.NoError(t, err)
7676
require.NotNil(t, ctx)
7777
defer ctx.Close()
@@ -195,7 +195,7 @@ func TestBuilder(t *testing.T) {
195195
waf := builder.Build()
196196
require.NotNil(t, waf)
197197
defer waf.Close()
198-
ctx, err := waf.NewContext(timer.UnlimitedBudget)
198+
ctx, err := waf.NewContext(timer.WithBudget(timer.UnlimitedBudget))
199199
require.NoError(t, err)
200200
require.NotNil(t, ctx)
201201
defer ctx.Close()
@@ -236,7 +236,7 @@ func TestBuilder(t *testing.T) {
236236
waf = builder.Build()
237237
require.NotNil(t, waf)
238238
defer waf.Close()
239-
ctx, err = waf.NewContext(timer.UnlimitedBudget)
239+
ctx, err = waf.NewContext(timer.WithBudget(timer.UnlimitedBudget))
240240
require.NoError(t, err)
241241
require.NotNil(t, ctx)
242242
defer ctx.Close()
@@ -369,7 +369,7 @@ func TestBuilder(t *testing.T) {
369369
require.NotNil(t, waf)
370370
defer waf.Close()
371371

372-
ctx, err := waf.NewContext(time.Hour)
372+
ctx, err := waf.NewContext(timer.WithBudget(time.Hour))
373373
require.NoError(t, err)
374374
defer ctx.Close()
375375

context.go

Lines changed: 79 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
package libddwaf
77

88
import (
9+
"maps"
910
"runtime"
1011
"sync"
11-
"sync/atomic"
1212
"time"
1313

1414
"github.com/DataDog/go-libddwaf/v4/internal/bindings"
@@ -22,25 +22,19 @@ import (
2222
// its own [Context]. New [Context] instances can be created by calling
2323
// [Handle.NewContext].
2424
type Context struct {
25+
// Timer registers the time spent in the WAF and go-libddwaf. It is created alongside the Context using the options
26+
// passed in to NewContext. Once its time budget is exhausted, each new call to Context.Run will return a timeout error.
27+
Timer timer.NodeTimer
28+
2529
handle *Handle // Instance of the WAF
2630

2731
cContext bindings.WAFContext // The C ddwaf_context pointer
2832

29-
// timeoutCount count all calls which have timeout'ed by scope. Keys are fixed at creation time.
30-
timeoutCount map[Scope]*atomic.Uint64
31-
32-
// mutex protecting the use of cContext which is not thread-safe and cgoRefs.
33+
// mutex protecting the use of cContext which is not thread-safe and truncations
3334
mutex sync.Mutex
3435

35-
// timer registers the time spent in the WAF and go-libddwaf
36-
timer timer.NodeTimer
37-
38-
// metrics stores the cumulative time spent in various parts of the WAF
39-
metrics metricsStore
40-
41-
// truncations provides details about truncations that occurred while encoding address data for
42-
// WAF execution.
43-
truncations map[Scope]map[TruncationReason][]int
36+
// truncations provides details about truncations that occurred while encoding address data for the WAF execution.
37+
truncations map[TruncationReason][]int
4438

4539
// pinner is used to retain Go data that is being passed to the WAF as part of
4640
// [RunAddressData.Persistent] until the [Context.Close] method results in the context being
@@ -49,109 +43,111 @@ type Context struct {
4943
}
5044

5145
// RunAddressData provides address data to the [Context.Run] method. If a given key is present in
52-
// both [RunAddressData.Persistent] and [RunAddressData.Ephemeral], the value from
53-
// [RunAddressData.Persistent] will take precedence.
46+
// both `Persistent` and `Ephemeral`, the value from `Persistent` will take precedence.
5447
// When encoding Go structs to the WAF-compatible format, fields with the `ddwaf:"ignore"` tag are
5548
// ignored and will not be visible to the WAF.
5649
type RunAddressData struct {
5750
// Persistent address data is scoped to the lifetime of a given Context, and subsquent calls to
58-
// [Context.Run] with the same address name will be silently ignored.
51+
// Context.Run with the same address name will be silently ignored.
5952
Persistent map[string]any
60-
// Ephemeral address data is scoped to a given [Context.Run] call and is not persisted across
53+
// Ephemeral address data is scoped to a given Context.Run call and is not persisted across
6154
// calls. This is used for protocols such as gRPC client/server streaming or GraphQL, where a
6255
// single request can incur multiple subrequests.
6356
Ephemeral map[string]any
64-
// Scope is the way to classify the different runs in the same context in order to have different
65-
// metrics.
66-
Scope Scope
57+
58+
// TimerKey is the key used to track the time spent in the WAF for this run.
59+
// If left empty, a new timer with unlimited budget is started.
60+
TimerKey timer.Key
6761
}
6862

6963
func (d RunAddressData) isEmpty() bool {
7064
return len(d.Persistent) == 0 && len(d.Ephemeral) == 0
7165
}
7266

67+
// newTimer creates a new timer for this run. If the TimerKey is empty, a new timer without taking the parent into account is created.
68+
func (d RunAddressData) newTimer(parent timer.NodeTimer) (timer.NodeTimer, error) {
69+
if d.TimerKey == "" {
70+
return timer.NewTreeTimer(
71+
timer.WithComponents(
72+
EncodeTimeKey,
73+
DurationTimeKey,
74+
DecodeTimeKey,
75+
),
76+
timer.WithBudget(parent.SumRemaining()),
77+
)
78+
}
79+
80+
return parent.NewNode(d.TimerKey,
81+
timer.WithComponents(
82+
EncodeTimeKey,
83+
DurationTimeKey,
84+
DecodeTimeKey,
85+
),
86+
timer.WithInheritedSumBudget(),
87+
)
88+
}
89+
7390
// Run encodes the given [RunAddressData] values and runs them against the WAF rules.
7491
// Callers must check the returned [Result] object even when an error is returned, as the WAF might
7592
// have been able to match some rules and generate events or actions before the error was reached;
7693
// especially when the error is [waferrors.ErrTimeout].
7794
func (context *Context) Run(addressData RunAddressData) (res Result, err error) {
7895
if addressData.isEmpty() {
79-
return
80-
}
81-
82-
if addressData.Scope == "" {
83-
addressData.Scope = DefaultScope
96+
return Result{}, nil
8497
}
8598

86-
defer func() {
87-
if err == waferrors.ErrTimeout {
88-
context.timeoutCount[addressData.Scope].Add(1)
89-
}
90-
}()
91-
9299
// If the context has already timed out, we don't need to run the WAF again
93-
if context.timer.SumExhausted() {
100+
if context.Timer.SumExhausted() {
94101
return Result{}, waferrors.ErrTimeout
95102
}
96103

97-
runTimer, err := context.timer.NewNode(wafRunTag,
98-
timer.WithComponents(
99-
wafEncodeTag,
100-
wafDecodeTag,
101-
wafDurationTag,
102-
),
103-
)
104+
runTimer, err := addressData.newTimer(context.Timer)
104105
if err != nil {
105106
return Result{}, err
106107
}
107108

108-
runTimer.Start()
109109
defer func() {
110-
context.metrics.add(addressData.Scope, wafRunTag, runTimer.Stop())
111-
context.metrics.merge(addressData.Scope, runTimer.Stats())
110+
res.TimerStats = runTimer.Stats()
112111
}()
113112

114-
wafEncodeTimer := runTimer.MustLeaf(wafEncodeTag)
113+
runTimer.Start()
114+
defer runTimer.Stop()
115+
116+
wafEncodeTimer := runTimer.MustLeaf(EncodeTimeKey)
115117
wafEncodeTimer.Start()
116-
persistentData, err := context.encodeOneAddressType(&context.pinner, addressData.Scope, addressData.Persistent, wafEncodeTimer)
118+
defer wafEncodeTimer.Stop()
119+
120+
persistentData, err := context.encodeOneAddressType(&context.pinner, addressData.Persistent, wafEncodeTimer)
117121
if err != nil {
118-
wafEncodeTimer.Stop()
119-
return res, err
122+
return Result{}, err
120123
}
121124

122125
// The WAF releases ephemeral address data at the max of each run call, so we need not keep the Go
123126
// values live beyond that in the same way we need for persistent data. We hence use a separate
124127
// encoder.
125128
var ephemeralPinner runtime.Pinner
126129
defer ephemeralPinner.Unpin()
127-
ephemeralData, err := context.encodeOneAddressType(&ephemeralPinner, addressData.Scope, addressData.Ephemeral, wafEncodeTimer)
130+
ephemeralData, err := context.encodeOneAddressType(&ephemeralPinner, addressData.Ephemeral, wafEncodeTimer)
128131
if err != nil {
129-
wafEncodeTimer.Stop()
130-
return res, err
132+
return Result{}, err
131133
}
132134

133135
wafEncodeTimer.Stop()
134136

135-
// ddwaf_run cannot run concurrently and we are going to mutate the context.cgoRefs, so we need to
136-
// lock the context
137+
// ddwaf_run cannot run concurrently, so we need to lock the context
137138
context.mutex.Lock()
138139
defer context.mutex.Unlock()
139140

140141
if context.cContext == 0 {
141142
// Context has been closed, returning an empty result...
142-
return res, waferrors.ErrContextClosed
143+
return Result{}, waferrors.ErrContextClosed
143144
}
144145

145146
if runTimer.SumExhausted() {
146-
return res, waferrors.ErrTimeout
147+
return Result{}, waferrors.ErrTimeout
147148
}
148149

149-
wafDecodeTimer := runTimer.MustLeaf(wafDecodeTag)
150-
res, err = context.run(persistentData, ephemeralData, wafDecodeTimer, runTimer.SumRemaining())
151-
152-
runTimer.AddTime(wafDurationTag, res.TimeSpent)
153-
154-
return
150+
return context.run(persistentData, ephemeralData, runTimer)
155151
}
156152

157153
// merge merges two maps of slices into a single map of slices. The resulting map will contain all
@@ -195,7 +191,7 @@ func merge[K comparable, V any](a, b map[K][]V) (merged map[K][]V) {
195191
// top level object is a nil map, but this behaviour is expected since either persistent or
196192
// ephemeral addresses are allowed to be null one at a time. In this case, Encode will return nil,
197193
// which is what we need to send to ddwaf_run to signal that the address data is empty.
198-
func (context *Context) encodeOneAddressType(pinner pin.Pinner, scope Scope, addressData map[string]any, timer timer.Timer) (*bindings.WAFObject, error) {
194+
func (context *Context) encodeOneAddressType(pinner pin.Pinner, addressData map[string]any, timer timer.Timer) (*bindings.WAFObject, error) {
199195
encoder := newLimitedEncoder(pinner, timer)
200196
if addressData == nil {
201197
return nil, nil
@@ -206,7 +202,7 @@ func (context *Context) encodeOneAddressType(pinner pin.Pinner, scope Scope, add
206202
context.mutex.Lock()
207203
defer context.mutex.Unlock()
208204

209-
context.truncations[scope] = merge(context.truncations[scope], encoder.truncations)
205+
context.truncations = merge(context.truncations, encoder.truncations)
210206
}
211207

212208
if timer.Exhausted() {
@@ -218,22 +214,25 @@ func (context *Context) encodeOneAddressType(pinner pin.Pinner, scope Scope, add
218214

219215
// run executes the ddwaf_run call with the provided data on this context. The caller is responsible for locking the
220216
// context appropriately around this call.
221-
func (context *Context) run(persistentData, ephemeralData *bindings.WAFObject, wafDecodeTimer timer.Timer, timeBudget time.Duration) (Result, error) {
217+
func (context *Context) run(persistentData, ephemeralData *bindings.WAFObject, runTimer timer.NodeTimer) (Result, error) {
222218
result := new(bindings.WAFResult)
223219
defer wafLib.ResultFree(result)
224220

225221
// The value of the timeout cannot exceed 2^55
226222
// cf. https://en.cppreference.com/w/cpp/chrono/duration
227-
timeout := uint64(timeBudget.Microseconds()) & 0x008FFFFFFFFFFFFF
223+
timeout := uint64(runTimer.SumRemaining().Microseconds()) & 0x008FFFFFFFFFFFFF
228224
ret := wafLib.Run(context.cContext, persistentData, ephemeralData, result, timeout)
229225

230-
wafDecodeTimer.Start()
231-
defer wafDecodeTimer.Stop()
226+
decodeTimer := runTimer.MustLeaf(DecodeTimeKey)
227+
decodeTimer.Start()
228+
defer decodeTimer.Stop()
232229

233-
return unwrapWafResult(ret, result)
230+
res, duration, err := unwrapWafResult(ret, result)
231+
runTimer.AddTime(DurationTimeKey, duration)
232+
return res, err
234233
}
235234

236-
func unwrapWafResult(ret bindings.WAFReturnCode, result *bindings.WAFResult) (res Result, err error) {
235+
func unwrapWafResult(ret bindings.WAFReturnCode, result *bindings.WAFResult) (res Result, duration time.Duration, err error) {
237236
if result.Timeout > 0 {
238237
err = waferrors.ErrTimeout
239238
} else {
@@ -242,28 +241,28 @@ func unwrapWafResult(ret bindings.WAFReturnCode, result *bindings.WAFResult) (re
242241
res.Derivatives, err = decodeMap(&result.Derivatives)
243242
}
244243

245-
res.TimeSpent = time.Duration(result.TotalRuntime) * time.Nanosecond
244+
duration = time.Duration(result.TotalRuntime) * time.Nanosecond
246245

247246
if ret == bindings.WAFOK {
248-
return res, err
247+
return res, duration, err
249248
}
250249

251250
if ret != bindings.WAFMatch {
252-
return res, goRunError(ret)
251+
return res, duration, goRunError(ret)
253252
}
254253

255254
res.Events, err = decodeArray(&result.Events)
256255
if err != nil {
257-
return res, err
256+
return res, duration, err
258257
}
259258
if size := result.Actions.NbEntries; size > 0 {
260259
res.Actions, err = decodeMap(&result.Actions)
261260
if err != nil {
262-
return res, err
261+
return res, duration, err
263262
}
264263
}
265264

266-
return res, err
265+
return res, duration, err
267266
}
268267

269268
// Close disposes of the underlying `ddwaf_context` and releases the associated
@@ -281,43 +280,13 @@ func (context *Context) Close() {
281280
context.pinner.Unpin() // The pinned data is no longer needed, explicitly release
282281
}
283282

284-
// Stats returns the cumulative time spent in various parts of the WAF, at
285-
// nanosecond resolution, as well as the timeout value used, and other
286-
// information about this [Context]'s usage.
287-
func (context *Context) Stats() Stats {
283+
// Truncations returns the truncations that occurred while encoding address data for WAF execution.
284+
// The key is the truncation reason: either because the object was too deep, the arrays where to large or the strings were too long.
285+
// The value is a slice of integers, each integer being the original size of the object that was truncated.
286+
// In case of the [ObjectTooDeep] reason, the original size can only be approximated because of recursive objects.
287+
func (context *Context) Truncations() map[TruncationReason][]int {
288288
context.mutex.Lock()
289289
defer context.mutex.Unlock()
290290

291-
truncations := make(map[TruncationReason][]int, len(context.truncations[DefaultScope]))
292-
for reason, counts := range context.truncations[DefaultScope] {
293-
truncations[reason] = make([]int, len(counts))
294-
copy(truncations[reason], counts)
295-
}
296-
297-
raspTruncations := make(map[TruncationReason][]int, len(context.truncations[RASPScope]))
298-
for reason, counts := range context.truncations[RASPScope] {
299-
raspTruncations[reason] = make([]int, len(counts))
300-
copy(raspTruncations[reason], counts)
301-
}
302-
303-
var (
304-
timeoutDefault uint64
305-
timeoutRASP uint64
306-
)
307-
308-
if atomic, ok := context.timeoutCount[DefaultScope]; ok {
309-
timeoutDefault = atomic.Load()
310-
}
311-
312-
if atomic, ok := context.timeoutCount[RASPScope]; ok {
313-
timeoutRASP = atomic.Load()
314-
}
315-
316-
return Stats{
317-
Timers: context.metrics.timers(),
318-
TimeoutCount: timeoutDefault,
319-
TimeoutRASPCount: timeoutRASP,
320-
Truncations: truncations,
321-
TruncationsRASP: raspTruncations,
322-
}
291+
return maps.Clone(context.truncations)
323292
}

0 commit comments

Comments
 (0)