Skip to content

Commit 919f806

Browse files
lioraron and claude committed
perf: cache Prometheus queries in flow-control dispatch gates
Add CachedMetricSource, a TTL-based wrapper around MetricSource that avoids hitting Prometheus on every Budget() call. The default TTL is 5 seconds, configurable via NewGateFactoryWithCacheTTL. GateFactory now automatically wraps Prometheus-backed sources with the cache. Existing callers of NewGateFactory get the default TTL with no code changes. Closes #101 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f9a548c commit 919f806

3 files changed

Lines changed: 198 additions & 1 deletion

File tree

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
Copyright 2026 The llm-d Authors
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package flowcontrol
18+
19+
import (
20+
"context"
21+
"sync"
22+
"time"
23+
)
24+
25+
// CachedMetricSource wraps a MetricSource with a TTL cache so that repeated
26+
// queries within the TTL return the cached result instead of hitting the
27+
// backend (e.g. Prometheus) on every call.
28+
type CachedMetricSource struct {
29+
source MetricSource
30+
ttl time.Duration
31+
32+
mu sync.Mutex
33+
samples []Sample
34+
err error
35+
expiry time.Time
36+
}
37+
38+
// NewCachedMetricSource wraps the given source with a cache that holds results
39+
// for the specified TTL duration.
40+
func NewCachedMetricSource(source MetricSource, ttl time.Duration) *CachedMetricSource {
41+
return &CachedMetricSource{
42+
source: source,
43+
ttl: ttl,
44+
}
45+
}
46+
47+
// Query returns cached samples if the cache is still valid, otherwise
48+
// delegates to the underlying source and caches the result.
49+
func (c *CachedMetricSource) Query(ctx context.Context) ([]Sample, error) {
50+
c.mu.Lock()
51+
defer c.mu.Unlock()
52+
53+
if time.Now().Before(c.expiry) {
54+
return c.samples, c.err
55+
}
56+
57+
c.samples, c.err = c.source.Query(ctx)
58+
c.expiry = time.Now().Add(c.ttl)
59+
return c.samples, c.err
60+
}
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
/*
2+
Copyright 2026 The llm-d Authors
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package flowcontrol
18+
19+
import (
20+
"context"
21+
"errors"
22+
"sync/atomic"
23+
"testing"
24+
"time"
25+
26+
"github.com/stretchr/testify/require"
27+
)
28+
29+
// countingMetricSource counts how many times Query is called and returns
30+
// configurable samples/errors.
31+
type countingMetricSource struct {
32+
calls atomic.Int32
33+
samples []Sample
34+
err error
35+
}
36+
37+
func (c *countingMetricSource) Query(_ context.Context) ([]Sample, error) {
38+
c.calls.Add(1)
39+
return c.samples, c.err
40+
}
41+
42+
func TestCachedMetricSource_CachesWithinTTL(t *testing.T) {
43+
inner := &countingMetricSource{
44+
samples: []Sample{{Value: 42.0}},
45+
}
46+
cached := NewCachedMetricSource(inner, 1*time.Hour)
47+
ctx := context.Background()
48+
49+
// First call should hit the source.
50+
s1, err := cached.Query(ctx)
51+
require.NoError(t, err)
52+
require.Len(t, s1, 1)
53+
require.Equal(t, 42.0, s1[0].Value)
54+
require.Equal(t, int32(1), inner.calls.Load())
55+
56+
// Second call within TTL should return cached result.
57+
s2, err := cached.Query(ctx)
58+
require.NoError(t, err)
59+
require.Equal(t, s1, s2)
60+
require.Equal(t, int32(1), inner.calls.Load())
61+
}
62+
63+
func TestCachedMetricSource_RefreshesAfterTTL(t *testing.T) {
64+
inner := &countingMetricSource{
65+
samples: []Sample{{Value: 1.0}},
66+
}
67+
cached := NewCachedMetricSource(inner, 10*time.Millisecond)
68+
ctx := context.Background()
69+
70+
_, err := cached.Query(ctx)
71+
require.NoError(t, err)
72+
require.Equal(t, int32(1), inner.calls.Load())
73+
74+
// Wait for TTL to expire.
75+
time.Sleep(20 * time.Millisecond)
76+
77+
_, err = cached.Query(ctx)
78+
require.NoError(t, err)
79+
require.Equal(t, int32(2), inner.calls.Load())
80+
}
81+
82+
func TestCachedMetricSource_CachesErrors(t *testing.T) {
83+
expectedErr := errors.New("prometheus down")
84+
inner := &countingMetricSource{
85+
err: expectedErr,
86+
}
87+
cached := NewCachedMetricSource(inner, 1*time.Hour)
88+
ctx := context.Background()
89+
90+
_, err := cached.Query(ctx)
91+
require.ErrorIs(t, err, expectedErr)
92+
require.Equal(t, int32(1), inner.calls.Load())
93+
94+
// Error should also be cached.
95+
_, err = cached.Query(ctx)
96+
require.ErrorIs(t, err, expectedErr)
97+
require.Equal(t, int32(1), inner.calls.Load())
98+
}
99+
100+
func TestCachedMetricSource_WorksWithGates(t *testing.T) {
101+
inner := &countingMetricSource{
102+
samples: []Sample{{Value: 0.0}},
103+
}
104+
cached := NewCachedMetricSource(inner, 1*time.Hour)
105+
ctx := context.Background()
106+
107+
// CachedMetricSource implements MetricSource, so it should work
108+
// transparently with any gate.
109+
binaryGate := NewBinaryMetricDispatchGateWithSource(cached)
110+
require.Equal(t, 1.0, binaryGate.Budget(ctx))
111+
require.Equal(t, 1.0, binaryGate.Budget(ctx))
112+
require.Equal(t, int32(1), inner.calls.Load())
113+
114+
satGate := NewSaturationMetricDispatchGateWithSource(cached, 0.8, 0.0)
115+
require.Equal(t, 1.0, satGate.Budget(ctx))
116+
// Still only 1 call since the cache hasn't expired.
117+
require.Equal(t, int32(1), inner.calls.Load())
118+
}

pkg/async/inference/flowcontrol/gate_factory.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,24 +19,38 @@ package flowcontrol
1919
import (
2020
"fmt"
2121
"strconv"
22+
"time"
2223

2324
asyncapi "github.com/llm-d-incubation/llm-d-async/pkg/async/api"
2425
redisgate "github.com/llm-d-incubation/llm-d-async/pkg/redis"
2526
promapi "github.com/prometheus/client_golang/api"
2627
goredis "github.com/redis/go-redis/v9"
2728
)
2829

30+
// DefaultCacheTTL is the cache lifetime applied to Prometheus metric sources
// when a GateFactory is built without an explicit TTL.
const DefaultCacheTTL time.Duration = 5 * time.Second
32+
2933
// GateFactory creates DispatchGate instances based on configuration.
3034
type GateFactory struct {
3135
prometheusURL string
36+
cacheTTL time.Duration
3237
redisClients map[string]*goredis.Client
3338
}
3439

3540
// NewGateFactory creates a new GateFactory with an optional Prometheus URL.
3641
// If prometheusURL is empty, Prometheus gates will fail at creation time.
42+
// Prometheus metric sources are cached with DefaultCacheTTL; use
43+
// NewGateFactoryWithCacheTTL to override.
3744
func NewGateFactory(prometheusURL string) *GateFactory {
45+
return NewGateFactoryWithCacheTTL(prometheusURL, DefaultCacheTTL)
46+
}
47+
48+
// NewGateFactoryWithCacheTTL creates a GateFactory with a custom cache TTL
49+
// for Prometheus metric sources. A TTL of 0 disables caching.
50+
func NewGateFactoryWithCacheTTL(prometheusURL string, cacheTTL time.Duration) *GateFactory {
3851
return &GateFactory{
3952
prometheusURL: prometheusURL,
53+
cacheTTL: cacheTTL,
4054
redisClients: make(map[string]*goredis.Client),
4155
}
4256
}
@@ -109,7 +123,12 @@ func (f *GateFactory) CreateGate(gateType string, params map[string]string) (asy
109123
return nil, fmt.Errorf("failed to create Prometheus metric source: %w", err)
110124
}
111125

112-
return NewSaturationMetricDispatchGateWithSource(source, threshold, fallback), nil
126+
var ms MetricSource = source
127+
if f.cacheTTL > 0 {
128+
ms = NewCachedMetricSource(source, f.cacheTTL)
129+
}
130+
131+
return NewSaturationMetricDispatchGateWithSource(ms, threshold, fallback), nil
113132

114133
default:
115134
// Unknown gate types default to open gate

0 commit comments

Comments
 (0)