Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions pkg/async/inference/flowcontrol/cached_metric_source.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
Copyright 2026 The llm-d Authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package flowcontrol

import (
"context"
"sync"
"time"
)

// CachedMetricSource wraps a MetricSource with a TTL cache so that repeated
// queries within the TTL return the cached result instead of hitting the
// backend (e.g. Prometheus) on every call.
//
// Both successful results and errors are cached until expiry, so a failing
// backend is re-queried at most once per TTL window.
type CachedMetricSource struct {
	source MetricSource  // underlying source queried on cache miss
	ttl    time.Duration // how long a cached result (or error) stays valid

	mu      sync.Mutex // guards samples, err, and expiry
	samples []Sample   // last samples returned by source.Query
	err     error      // last error returned by source.Query (cached too)
	expiry  time.Time  // instant after which the cached entry is stale
}

// NewCachedMetricSource returns a CachedMetricSource that serves answers from
// source, memoizing each result (including errors) for ttl before asking the
// backend again.
func NewCachedMetricSource(source MetricSource, ttl time.Duration) *CachedMetricSource {
	cms := &CachedMetricSource{source: source, ttl: ttl}
	return cms
}

// Query returns the cached samples while the cache entry is still fresh;
// otherwise it delegates to the underlying source and caches whatever comes
// back (result or error) for another TTL.
//
// The mutex is held across the backend call, so concurrent callers block
// until an in-flight refresh completes and then share its outcome.
func (c *CachedMetricSource) Query(ctx context.Context) ([]Sample, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	// Refresh only once the previous entry has expired. The same timestamp
	// is used for both the staleness check and the new deadline.
	if now := time.Now(); !now.Before(c.expiry) {
		c.samples, c.err = c.source.Query(ctx)
		c.expiry = now.Add(c.ttl)
	}
	return c.samples, c.err
}
118 changes: 118 additions & 0 deletions pkg/async/inference/flowcontrol/cached_metric_source_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
Copyright 2026 The llm-d Authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package flowcontrol

import (
"context"
"errors"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/require"
)

// countingMetricSource counts how many times Query is called and returns
// configurable samples/errors.
type countingMetricSource struct {
	calls   atomic.Int32 // number of Query invocations; inspected by tests
	samples []Sample     // fixed samples returned by every Query call
	err     error        // fixed error returned by every Query call
}

// Query bumps the call counter and hands back the canned samples/error.
func (m *countingMetricSource) Query(_ context.Context) ([]Sample, error) {
	m.calls.Add(1)
	return m.samples, m.err
}

func TestCachedMetricSource_CachesWithinTTL(t *testing.T) {
	ctx := context.Background()
	inner := &countingMetricSource{samples: []Sample{{Value: 42.0}}}
	cached := NewCachedMetricSource(inner, time.Hour)

	// The initial query must reach the wrapped source.
	first, err := cached.Query(ctx)
	require.NoError(t, err)
	require.Len(t, first, 1)
	require.Equal(t, 42.0, first[0].Value)
	require.Equal(t, int32(1), inner.calls.Load())

	// A follow-up query inside the TTL window is served from the cache.
	second, err := cached.Query(ctx)
	require.NoError(t, err)
	require.Equal(t, first, second)
	require.Equal(t, int32(1), inner.calls.Load())
}

func TestCachedMetricSource_RefreshesAfterTTL(t *testing.T) {
	ctx := context.Background()
	inner := &countingMetricSource{samples: []Sample{{Value: 1.0}}}
	cached := NewCachedMetricSource(inner, 10*time.Millisecond)

	_, err := cached.Query(ctx)
	require.NoError(t, err)
	require.Equal(t, int32(1), inner.calls.Load())

	// Let the cache entry expire before querying again.
	time.Sleep(20 * time.Millisecond)

	// The stale entry must be refreshed from the underlying source.
	_, err = cached.Query(ctx)
	require.NoError(t, err)
	require.Equal(t, int32(2), inner.calls.Load())
}

func TestCachedMetricSource_CachesErrors(t *testing.T) {
	ctx := context.Background()
	wantErr := errors.New("prometheus down")
	inner := &countingMetricSource{err: wantErr}
	cached := NewCachedMetricSource(inner, time.Hour)

	_, err := cached.Query(ctx)
	require.ErrorIs(t, err, wantErr)
	require.Equal(t, int32(1), inner.calls.Load())

	// A failed lookup is cached just like a successful one, so the backend
	// is not queried again within the TTL.
	_, err = cached.Query(ctx)
	require.ErrorIs(t, err, wantErr)
	require.Equal(t, int32(1), inner.calls.Load())
}

func TestCachedMetricSource_WorksWithGates(t *testing.T) {
	ctx := context.Background()
	inner := &countingMetricSource{samples: []Sample{{Value: 0.0}}}
	cached := NewCachedMetricSource(inner, time.Hour)

	// CachedMetricSource satisfies MetricSource, so gates can consume it
	// without knowing about the caching layer.
	binaryGate := NewBinaryMetricDispatchGateWithSource(cached)
	require.Equal(t, 1.0, binaryGate.Budget(ctx))
	require.Equal(t, 1.0, binaryGate.Budget(ctx))
	require.Equal(t, int32(1), inner.calls.Load())

	// A second gate sharing the same cached source triggers no extra
	// backend calls while the entry is fresh.
	satGate := NewSaturationMetricDispatchGateWithSource(cached, 0.8, 0.0)
	require.Equal(t, 1.0, satGate.Budget(ctx))
	require.Equal(t, int32(1), inner.calls.Load())
}
21 changes: 20 additions & 1 deletion pkg/async/inference/flowcontrol/gate_factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,26 +19,40 @@ package flowcontrol
import (
"fmt"
"strconv"
"time"

asyncapi "github.com/llm-d-incubation/llm-d-async/pkg/async/api"
redisgate "github.com/llm-d-incubation/llm-d-async/pkg/redis"
promapi "github.com/prometheus/client_golang/api"
goredis "github.com/redis/go-redis/v9"
)

// DefaultCacheTTL is the default TTL for cached Prometheus metric sources.
// It is used by NewGateFactory; pass a different value (or 0 to disable
// caching) via NewGateFactoryWithCacheTTL.
const DefaultCacheTTL = 5 * time.Second

// Compile-time assertion that *GateFactory implements asyncapi.GateFactory.
var _ asyncapi.GateFactory = (*GateFactory)(nil)

// GateFactory creates DispatchGate instances based on configuration.
type GateFactory struct {
	prometheusURL string        // base URL for Prometheus-backed gates; empty means such gates fail at creation
	cacheTTL      time.Duration // TTL for caching Prometheus metric sources; 0 disables caching
	// redisClients holds reusable go-redis clients; presumably keyed per
	// Redis endpoint — confirm against CreateGate (not fully visible here).
	redisClients map[string]*goredis.Client
}

// NewGateFactory builds a GateFactory for the given Prometheus URL. When
// prometheusURL is empty, Prometheus gates will fail at creation time.
// Prometheus metric sources are cached using DefaultCacheTTL; callers that
// need a different TTL should use NewGateFactoryWithCacheTTL instead.
func NewGateFactory(prometheusURL string) *GateFactory {
	// Delegate to the TTL-aware constructor with the package default.
	return NewGateFactoryWithCacheTTL(prometheusURL, DefaultCacheTTL)
}

// NewGateFactoryWithCacheTTL builds a GateFactory whose Prometheus metric
// sources are cached for cacheTTL. Passing a TTL of 0 turns caching off.
func NewGateFactoryWithCacheTTL(prometheusURL string, cacheTTL time.Duration) *GateFactory {
	factory := &GateFactory{
		prometheusURL: prometheusURL,
		cacheTTL:      cacheTTL,
	}
	// Redis clients are created lazily and reused across gate creations.
	factory.redisClients = make(map[string]*goredis.Client)
	return factory
}
Expand Down Expand Up @@ -111,7 +125,12 @@ func (f *GateFactory) CreateGate(gateType string, params map[string]string) (asy
return nil, fmt.Errorf("failed to create Prometheus metric source: %w", err)
}

return NewSaturationMetricDispatchGateWithSource(source, threshold, fallback), nil
var ms MetricSource = source
if f.cacheTTL > 0 {
ms = NewCachedMetricSource(source, f.cacheTTL)
}

return NewSaturationMetricDispatchGateWithSource(ms, threshold, fallback), nil

default:
// Unknown gate types default to open gate
Expand Down
Loading