Skip to content

Commit 48cc1ab

Browse files
committed
Inference: Fix EPP Endpoint Sync and Eliminates Races
- Stores endpoints via atomic.Value and adds setEndpoints/getEndpoints to snapshot safely without locks. - Updates Equals to compare endpoint snapshots without locks, fixing race condition in krt.Equal/DeepEqual. - Switches error handling to hasErrors/snapshotErrors/setErrors. The backend path now returns empty ClusterLoadAssinment when errors exist. - Updates tests to seed errors via setErrors and avoid direct field access. - Keeps DP collection returning Backend IR on empty endpoints and relaxes route pass to allow empty endpoint sets. Signed-off-by: Daneyon Hansen <[email protected]>
1 parent 7f40d14 commit 48cc1ab

File tree

6 files changed

+79
-91
lines changed

6 files changed

+79
-91
lines changed

internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker/backends.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@ import (
1111
"google.golang.org/protobuf/types/known/anypb"
1212
"google.golang.org/protobuf/types/known/structpb"
1313
"google.golang.org/protobuf/types/known/wrapperspb"
14-
"istio.io/istio/pkg/kube/krt"
1514

16-
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/krtcollections"
1715
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/utils"
1816
"github.com/kgateway-dev/kgateway/v2/pkg/pluginsdk/ir"
1917
)
@@ -22,11 +20,10 @@ func processPoolBackendObjIR(
2220
ctx context.Context,
2321
in ir.BackendObjectIR,
2422
out *envoyclusterv3.Cluster,
25-
podIdx krt.Index[string, krtcollections.LocalityPod],
2623
) *ir.EndpointsForBackend {
2724
// Build an endpoint list
2825
irPool := in.ObjIr.(*inferencePool)
29-
poolEps := irPool.resolvePoolEndpoints(podIdx)
26+
poolEps := irPool.getEndpoints()
3027
if len(poolEps) == 0 {
3128
logger.Warn("no endpoints resolved for InferencePool",
3229
"namespace", irPool.obj.GetNamespace(),
@@ -35,9 +32,10 @@ func processPoolBackendObjIR(
3532

3633
// If the pool has errors, create an empty LoadAssignment to return a 503
3734
if irPool.hasErrors() {
35+
errs := irPool.snapshotErrors()
3836
logger.Debug("skipping endpoints due to InferencePool errors",
3937
"pool", in.ResourceName(),
40-
"errors", irPool.errors,
38+
"errors", errs,
4139
)
4240
out.LoadAssignment = &envoyendpointv3.ClusterLoadAssignment{
4341
ClusterName: out.Name,

internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker/backends_test.go

Lines changed: 6 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,11 @@ import (
99
"github.com/stretchr/testify/assert"
1010
"github.com/stretchr/testify/require"
1111
structpb "google.golang.org/protobuf/types/known/structpb"
12-
"istio.io/istio/pkg/kube/krt"
13-
"istio.io/istio/pkg/kube/krt/krttest"
14-
corev1 "k8s.io/api/core/v1"
1512
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1613
inf "sigs.k8s.io/gateway-api-inference-extension/api/v1"
1714

18-
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/krtcollections"
1915
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/wellknown"
2016
"github.com/kgateway-dev/kgateway/v2/pkg/pluginsdk/ir"
21-
krtpkg "github.com/kgateway-dev/kgateway/v2/pkg/utils/krtutil"
2217
)
2318

2419
func makeBackendIR(pool *inf.InferencePool) *ir.BackendObjectIR {
@@ -53,34 +48,14 @@ func TestProcessPoolBackendObjIR_BuildsLoadAssignment(t *testing.T) {
5348
},
5449
}
5550

56-
// Build a fake Pod and wrap it into a LocalityPod
57-
corePod := &corev1.Pod{
58-
ObjectMeta: metav1.ObjectMeta{
59-
Name: "pod1",
60-
Namespace: "ns",
61-
Labels: map[string]string{"app": "test"},
62-
},
63-
Status: corev1.PodStatus{PodIP: "10.0.0.1"},
64-
}
65-
fakeLP := krtcollections.LocalityPod{
66-
Named: krt.NewNamed(corePod),
67-
AugmentedLabels: corePod.Labels,
68-
Addresses: []string{corePod.Status.PodIP},
69-
}
70-
71-
// Create a mock and with the LocalityPod collection
72-
mock := krttest.NewMock(t, []any{fakeLP})
73-
podCol := krttest.GetMockCollection[krtcollections.LocalityPod](mock)
74-
75-
// Index the pods
76-
poolKey := fmt.Sprintf("%s/%s", pool.Namespace, pool.Name)
77-
podIdx := krtpkg.UnnamedIndex(podCol, func(p krtcollections.LocalityPod) []string {
78-
return []string{poolKey}
79-
})
51+
// Build the Backend IR and seed endpoints
52+
beIR := makeBackendIR(pool)
53+
irp := beIR.ObjIr.(*inferencePool)
54+
irp.setEndpoints([]endpoint{{address: "10.0.0.1", port: 9000}})
8055

8156
// Call the code under test
8257
cluster := &envoyclusterv3.Cluster{}
83-
ret := processPoolBackendObjIR(context.Background(), *makeBackendIR(pool), cluster, podIdx)
58+
ret := processPoolBackendObjIR(context.Background(), *beIR, cluster)
8459
assert.Nil(t, ret, "Should return nil for a static cluster")
8560

8661
// Validate the generated LoadAssignment
@@ -119,13 +94,8 @@ func TestProcessPoolBackendObjIR_SkipsOnErrors(t *testing.T) {
11994
irp := beIR.ObjIr.(*inferencePool)
12095
irp.setErrors([]error{fmt.Errorf("failure injected")})
12196

122-
// Empty pod index
123-
mock := krttest.NewMock(t, []any{})
124-
podCol := krttest.GetMockCollection[krtcollections.LocalityPod](mock)
125-
podIdx := krtpkg.UnnamedIndex(podCol, func(krtcollections.LocalityPod) []string { return nil })
126-
12797
cluster := &envoyclusterv3.Cluster{}
128-
ret := processPoolBackendObjIR(context.Background(), *beIR, cluster, podIdx)
98+
ret := processPoolBackendObjIR(context.Background(), *beIR, cluster)
12999
assert.Nil(t, ret)
130100

131101
cla := cluster.LoadAssignment

internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker/collections.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,8 @@ func initInferencePoolCollections(
121121
eps = append(eps, endpoint{address: ip, port: irPool.targetPorts[0].number})
122122
}
123123
}
124-
if len(eps) == 0 {
125-
return nil
126-
}
124+
// Always return a backend IR so the static cluster exists.
125+
// Endpoints may be empty on first pass, they'll populate in subsequent passes.
127126
irPool.setEndpoints(eps)
128127
return buildBackendObjIrFromPool(irPool)
129128
},
@@ -135,7 +134,7 @@ func initInferencePoolCollections(
135134
backendsDP,
136135
func(_ krt.HandlerContext, be ir.BackendObjectIR) *ir.EndpointsForBackend {
137136
stub := &envoyclusterv3.Cluster{Name: be.ClusterName()}
138-
return processPoolBackendObjIR(ctx, be, stub, podIdx)
137+
return processPoolBackendObjIR(ctx, be, stub)
139138
},
140139
)
141140

internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker/ir.go

Lines changed: 54 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ import (
44
"encoding/json"
55
"fmt"
66
"maps"
7-
"sync"
7+
"math"
8+
"sync/atomic"
89
"time"
910

1011
"istio.io/istio/pkg/kube/krt"
@@ -27,20 +28,17 @@ type inferencePool struct {
2728
// configRef is a reference to the extension configuration. A configRef is typically implemented
2829
// as a Kubernetes Service resource.
2930
configRef *service
30-
// mu is a mutex to protect access to the errors list.
31-
mu sync.Mutex
32-
// errors is a list of errors that occurred while processing the InferencePool.
33-
errors []error
31+
// errors that occurred while processing the InferencePool.
32+
errorsV atomic.Value
33+
errorCount atomic.Int32
3434
// endpoints define the list of endpoints resolved by the podSelector.
35-
endpoints []endpoint
35+
endpoints atomic.Value
3636
// failOpen configures how the proxy handles traffic when the EPP extension is
3737
// non-responsive. When set to `false` and the gRPC stream cannot be established, or if
3838
// it is closed prematurely with an error, the request will fail. When set to `true` and
3939
// the gRPC stream cannot be established, the request is forwarded based on the cluster
4040
// load balancing configuration.
4141
//
42-
// Defaults to `false`.
43-
//
4442
failOpen bool
4543
}
4644

@@ -67,27 +65,35 @@ func newInferencePool(pool *inf.InferencePool) *inferencePool {
6765
ports: []servicePort{port},
6866
}
6967

70-
return &inferencePool{
68+
ir := &inferencePool{
7169
obj: pool,
7270
podSelector: convertSelector(pool.Spec.Selector.MatchLabels),
7371
// InferencePool v1 only supports single port
7472
targetPorts: []targetPort{{number: int32(pool.Spec.TargetPorts[0].Number)}},
7573
configRef: svcIR,
76-
endpoints: []endpoint{},
7774
failOpen: isFailOpen(pool),
7875
}
76+
ir.endpoints.Store([]endpoint(nil))
77+
ir.errorsV.Store([]error(nil))
78+
ir.errorCount.Store(0)
79+
80+
return ir
7981
}
8082

8183
func (ir *inferencePool) setEndpoints(eps []endpoint) {
82-
ir.mu.Lock()
83-
defer ir.mu.Unlock()
84-
ir.endpoints = eps
84+
cp := append([]endpoint(nil), eps...)
85+
ir.endpoints.Store(cp)
8586
}
8687

8788
func (ir *inferencePool) getEndpoints() []endpoint {
88-
ir.mu.Lock()
89-
defer ir.mu.Unlock()
90-
return ir.endpoints
89+
v := ir.endpoints.Load()
90+
if v == nil {
91+
return nil
92+
}
93+
src := v.([]endpoint)
94+
out := make([]endpoint, len(src))
95+
copy(out, src)
96+
return out
9197
}
9298

9399
// resolvePoolEndpoints returns the slice of <IP:Port> for the given pool
@@ -125,31 +131,34 @@ func (ir *inferencePool) Equals(other any) bool {
125131
if !ok {
126132
return false
127133
}
134+
128135
// Compare pod selector
129136
if !maps.Equal(ir.Selector(), otherPool.Selector()) {
130137
return false
131138
}
139+
132140
// Compare error presence (we only need the boolean)
133141
if ir.hasErrors() != otherPool.hasErrors() {
134142
return false
135143
}
136-
// Compare endpoint set (order‑insensitive)
137-
ir.mu.Lock()
138-
otherPool.mu.Lock()
139-
defer ir.mu.Unlock()
140-
defer otherPool.mu.Unlock()
141-
if len(ir.endpoints) != len(otherPool.endpoints) {
144+
145+
// Snapshot endpoints (avoid holding locks during compare)
146+
epsA := ir.getEndpoints()
147+
epsB := otherPool.getEndpoints()
148+
149+
if len(epsA) != len(epsB) {
142150
return false
143151
}
144-
seen := make(map[string]struct{}, len(ir.endpoints))
145-
for _, ep := range ir.endpoints {
152+
seen := make(map[string]struct{}, len(epsA))
153+
for _, ep := range epsA {
146154
seen[ep.string()] = struct{}{}
147155
}
148-
for _, ep := range otherPool.endpoints {
156+
for _, ep := range epsB {
149157
if _, ok := seen[ep.string()]; !ok {
150158
return false
151159
}
152160
}
161+
153162
// Compare target port
154163
// InferencePool v1 only supports single port
155164
if len(ir.targetPorts) != 1 || len(otherPool.targetPorts) != 1 {
@@ -158,6 +167,7 @@ func (ir *inferencePool) Equals(other any) bool {
158167
if ir.targetPorts[0].number != otherPool.targetPorts[0].number {
159168
return false
160169
}
170+
161171
// Compare object metadata
162172
if ir.obj.GetName() != otherPool.obj.GetName() ||
163173
ir.obj.GetNamespace() != otherPool.obj.GetNamespace() ||
@@ -166,14 +176,17 @@ func (ir *inferencePool) Equals(other any) bool {
166176
ir.obj.GetGeneration() != otherPool.obj.GetGeneration() {
167177
return false
168178
}
179+
169180
// Compare configRef
170181
if !ir.configRefEquals(otherPool) {
171182
return false
172183
}
184+
173185
// Compare failure mode
174186
if !ir.failOpenEqual(otherPool) {
175187
return false
176188
}
189+
177190
return true
178191
}
179192

@@ -190,25 +203,31 @@ func (ir *inferencePool) configRefEquals(other *inferencePool) bool {
190203

191204
// setErrors atomically replaces p.errors under lock.
192205
func (ir *inferencePool) setErrors(errs []error) {
193-
ir.mu.Lock()
194-
defer ir.mu.Unlock()
195-
ir.errors = errs
206+
cp := append([]error(nil), errs...)
207+
ir.errorsV.Store(cp)
208+
// Bound before casting to avoid gosec G115 (int -> int32 overflow).
209+
n := len(cp)
210+
if n > math.MaxInt32 {
211+
n = math.MaxInt32
212+
}
213+
ir.errorCount.Store(int32(n))
196214
}
197215

198216
// snapshotErrors returns a copy of p.errors under lock.
199217
func (ir *inferencePool) snapshotErrors() []error {
200-
ir.mu.Lock()
201-
defer ir.mu.Unlock()
202-
out := make([]error, len(ir.errors))
203-
copy(out, ir.errors)
218+
v := ir.errorsV.Load()
219+
if v == nil {
220+
return nil
221+
}
222+
src := v.([]error)
223+
out := make([]error, len(src))
224+
copy(out, src)
204225
return out
205226
}
206227

207228
// hasErrors checks if the inferencePool has any errors.
208229
func (ir *inferencePool) hasErrors() bool {
209-
ir.mu.Lock()
210-
defer ir.mu.Unlock()
211-
return len(ir.errors) > 0
230+
return ir.errorCount.Load() > 0
212231
}
213232

214233
func (ir *inferencePool) failOpenEqual(other *inferencePool) bool {

internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker/plugin.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func NewPlugin(ctx context.Context, commonCols *collections.CommonCollections) s
6161

6262
// Wrap the init function so it can capture commonCols.Pods
6363
initBackend := func(ctx context.Context, in ir.BackendObjectIR, out *envoyclusterv3.Cluster) *ir.EndpointsForBackend {
64-
return processPoolBackendObjIR(ctx, in, out, p.podIndex)
64+
return processPoolBackendObjIR(ctx, in, out)
6565
}
6666

6767
return sdk.Plugin{
@@ -215,11 +215,9 @@ func (p *endpointPickerPass) ApplyForBackend(
215215

216216
// Ensure we are working with the latest set of endpoints for the pool.
217217
eps := irPool.resolvePoolEndpoints(p.podIdx)
218-
if len(eps) == 0 {
219-
return fmt.Errorf("no endpoints found for InferencePool %s/%s",
220-
irPool.obj.GetNamespace(),
221-
irPool.obj.GetName())
222-
}
218+
// If the pool has no endpoints yet, do not fail translation.
219+
// Keep the route valid and provide an empty subset hint so the EPP
220+
// will return 503 (or honor fail-open) rather than causing a 500.
223221
irPool.setEndpoints(eps)
224222

225223
// Tell the EPP the subset of endpoints to choose from.

internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker/status_test.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ func TestUpdatePoolStatus_NoReferences_NoErrors(t *testing.T) {
8989
Namespace: poolNN.Namespace,
9090
Name: poolNN.Name,
9191
},
92-
ObjIr: &inferencePool{errors: nil},
92+
ObjIr: &inferencePool{},
9393
}
9494

9595
// Call the function to update the pool status
@@ -181,7 +181,7 @@ func TestUpdatePoolStatus_WithReference_NoErrors(t *testing.T) {
181181
Namespace: poolNN.Namespace,
182182
Name: poolNN.Name,
183183
},
184-
ObjIr: &inferencePool{errors: nil},
184+
ObjIr: &inferencePool{},
185185
}
186186

187187
// Call the function to update the pool status
@@ -289,14 +289,18 @@ func TestUpdatePoolStatus_WithReference_WithErrors(t *testing.T) {
289289
ControllerName: controllerName,
290290
Routes: fakeRoutesIndex(col),
291291
}
292+
293+
poolIR := &inferencePool{}
294+
poolIR.setErrors([]error{fmt.Errorf("test error")})
295+
292296
beIR := ir.BackendObjectIR{
293297
ObjectSource: ir.ObjectSource{
294298
Group: inf.GroupVersion.Group,
295299
Kind: wellknown.InferencePoolKind,
296300
Namespace: poolNN.Namespace,
297301
Name: poolNN.Name,
298302
},
299-
ObjIr: &inferencePool{errors: []error{fmt.Errorf("test error")}},
303+
ObjIr: poolIR,
300304
}
301305

302306
// Call the function to update the pool status with errors
@@ -428,7 +432,7 @@ func TestUpdatePoolStatus_DeleteRoute(t *testing.T) {
428432
Namespace: poolNN.Namespace,
429433
Name: poolNN.Name,
430434
},
431-
ObjIr: &inferencePool{errors: nil},
435+
ObjIr: &inferencePool{},
432436
}
433437

434438
// Call the function to update the pool status with the route
@@ -476,7 +480,7 @@ func TestUpdatePoolStatus_WithExtraGws(t *testing.T) {
476480
Namespace: ns,
477481
Name: poolName,
478482
},
479-
ObjIr: &inferencePool{errors: nil},
483+
ObjIr: &inferencePool{},
480484
}
481485

482486
// Simulate controller knowing about a parent Gateway even if no HTTPRoute is present

0 commit comments

Comments
 (0)