Skip to content

Commit f323b66

Browse files
lioraron and claude
committed
merge: resolve conflict with main in gate_factory.go
Keep both the DefaultCacheTTL constant from this branch and the GateFactory interface compliance check added on main. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2 parents 6fcad43 + 405738c commit f323b66

15 files changed

+64
-1
lines changed

pkg/async/api/api.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ type GateFactory interface {
4141
CreateGate(gateType string, params map[string]string) (DispatchGate, error)
4242
}
4343

44+
var _ DispatchGate = DispatchGateFunc(nil)
45+
4446
// DispatchGateFunc is a function type that implements DispatchGate.
4547
// This allows any function with the signature func(context.Context) float64
4648
// to be used as a DispatchGate.

pkg/async/api/http_client.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"net/http"
99
)
1010

11+
var _ InferenceClient = (*HTTPInferenceClient)(nil)
12+
1113
// HTTPInferenceClient is the default HTTP implementation of InferenceClient.
1214
type HTTPInferenceClient struct {
1315
client *http.Client

pkg/async/api/inference_error.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ type InferenceError interface {
3131
Category() ErrorCategory
3232
}
3333

34+
var _ InferenceError = (*ClientError)(nil)
35+
3436
// ClientError represents an inference client error with category and context.
3537
type ClientError struct {
3638
ErrorCategory ErrorCategory

pkg/async/inference/flowcontrol/binary_metric_dispatch_gate.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"context"
2121
"flag"
2222

23+
asyncapi "github.com/llm-d-incubation/llm-d-async/pkg/async/api"
2324
"github.com/prometheus/client_golang/api"
2425
"sigs.k8s.io/controller-runtime/pkg/log"
2526
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
@@ -30,6 +31,8 @@ var prometheusURL = flag.String("gate.prometheus.url", "", "Prometheus URL for n
3031
var gmpProjectID = flag.String("gate.pmetric.gmp.project-id", "", "Project ID for Google Managed Prometheus")
3132
var prometheusQueryModelName = flag.String("gate.prometheus.model-name", "", "metrics name to use for avg_queue_size")
3233

34+
var _ asyncapi.DispatchGate = (*BinaryMetricDispatchGate)(nil)
35+
3336
// BinaryMetricDispatchGate implements DispatchGate using a MetricSource.
3437
// It returns 0.0 (no capacity) if the metric value is non-zero,
3538
// and 1.0 (full capacity) if the metric value is zero.

pkg/async/inference/flowcontrol/dispatch_gate.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import (
2222
"github.com/llm-d-incubation/llm-d-async/pkg/async/api"
2323
)
2424

25+
var _ api.DispatchGate = DispatchGateFunc(nil)
26+
2527
// DispatchGateFunc is a function type that implements DispatchGate.
2628
// This allows any function with the signature func(context.Context) float64
2729
// to be used as a DispatchGate.

pkg/async/inference/flowcontrol/gate_factory.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ import (
3030
// DefaultCacheTTL is the default TTL for cached Prometheus metric sources.
3131
const DefaultCacheTTL = 5 * time.Second
3232

33+
var _ asyncapi.GateFactory = (*GateFactory)(nil)
34+
3335
// GateFactory creates DispatchGate instances based on configuration.
3436
type GateFactory struct {
3537
prometheusURL string

pkg/async/inference/flowcontrol/metric_source.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ type MetricSource interface {
4646
Query(ctx context.Context) ([]Sample, error)
4747
}
4848

49+
var _ MetricSource = (*PromQLMetricSource)(nil)
50+
4951
// PromQLMetricSource implements MetricSource by executing a PromQL expression
5052
// against a Prometheus-compatible API.
5153
type PromQLMetricSource struct {

pkg/async/inference/flowcontrol/saturation_metric_dispatch_gate.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"flag"
2222
"math"
2323

24+
asyncapi "github.com/llm-d-incubation/llm-d-async/pkg/async/api"
2425
"github.com/prometheus/client_golang/api"
2526
"sigs.k8s.io/controller-runtime/pkg/log"
2627
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
@@ -31,6 +32,8 @@ var saturationThreshold = flag.Float64("gate.saturation.threshold", 0.8, "satura
3132
var saturationFallback = flag.Float64("gate.saturation.fallback", 0.0, "fallback saturation value on error/missing metrics; default 0.0")
3233
var saturationQueryExpr = flag.String("gate.saturation.query-expr", "", "custom PromQL expression for saturation metric; overrides inference-pool label selector")
3334

35+
var _ asyncapi.DispatchGate = (*SaturationMetricDispatchGate)(nil)
36+
3437
// SaturationMetricDispatchGate implements DispatchGate based on pool saturation.
3538
// It queries a MetricSource for saturation samples and returns 0.0 if saturation
3639
// is at or above the configured threshold, otherwise returns max(0, 1 - saturation),

pkg/async/random_robin_policy.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ func NewRandomRobinPolicy() api.RequestMergePolicy {
1010
return &RandomRobinPolicy{}
1111
}
1212

13+
var _ api.RequestMergePolicy = (*RandomRobinPolicy)(nil)
14+
1315
type RandomRobinPolicy struct {
1416
}
1517

1618
func (r *RandomRobinPolicy) MergeRequestChannels(channels []api.RequestChannel) api.EmbelishedRequestChannel {
17-
mergedChannel := make(chan api.EmbelishedRequestMessage)
19+
mergedChannel := make(chan api.EmbelishedRequestMessage, len(channels))
1820

1921
cases := make([]reflect.SelectCase, len(channels)) //nolint:staticcheck
2022
for i, ch := range channels {

pkg/async/random_robin_policy_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package async
22

33
import (
44
"testing"
5+
"time"
56

67
"github.com/llm-d-incubation/llm-d-async/pkg/async/api"
78
)
@@ -41,3 +42,33 @@ func TestProcessAllChannels(t *testing.T) {
4142
}
4243
}
4344
}
45+
46+
func TestMergedChannelIsBuffered(t *testing.T) {
47+
numChannels := 3
48+
channels := make([]api.RequestChannel, numChannels)
49+
for i := range numChannels {
50+
channels[i] = api.RequestChannel{Channel: make(chan api.RequestMessage, 1)}
51+
}
52+
policy := NewRandomRobinPolicy()
53+
merged := policy.MergeRequestChannels(channels)
54+
55+
// Send one message per input channel (each input has capacity 1, so these sends do not block).
56+
for i, ch := range channels {
57+
ch.Channel <- api.RequestMessage{Id: string(rune('A' + i))}
58+
}
59+
60+
// Drain the merged channel, expecting one forwarded message per input
61+
// channel before the 2s deadline. NOTE(review): this loop is itself a
62+
// consumer of merged.Channel, so it presumably also passes when the merged
63+
// channel is unbuffered; checking cap(merged.Channel) would be direct — TODO confirm.
64+
deadline := time.After(2 * time.Second)
65+
received := 0
66+
for received < numChannels {
67+
select {
68+
case <-merged.Channel:
69+
received++
70+
case <-deadline:
71+
t.Fatalf("timed out: only received %d/%d messages — merged channel may be unbuffered", received, numChannels)
72+
}
73+
}
74+
}

0 commit comments

Comments
 (0)