Skip to content

Commit fc451d6

Browse files
committed
Inference: Fix EPP Endpoint Sync and Eliminates Races
- Stores endpoints via atomic.Value and adds setEndpoints/getEndpoints to snapshot safely without locks. - Updates Equals to compare endpoint snapshots without locks, fixing race condition in krt.Equal/DeepEqual. - Switches error handling to hasErrors/snapshotErrors/setErrors. The backend path now returns empty ClusterLoadAssinment when errors exist. - Updates tests to seed errors via setErrors and avoid direct field access. - Keeps DP collection returning Backend IR on empty endpoints and relaxes route pass to allow empty endpoint sets. Signed-off-by: Daneyon Hansen <[email protected]>
1 parent 7f40d14 commit fc451d6

File tree

6 files changed

+73
-91
lines changed

6 files changed

+73
-91
lines changed

internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker/backends.go

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@ import (
1111
"google.golang.org/protobuf/types/known/anypb"
1212
"google.golang.org/protobuf/types/known/structpb"
1313
"google.golang.org/protobuf/types/known/wrapperspb"
14-
"istio.io/istio/pkg/kube/krt"
1514

16-
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/krtcollections"
1715
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/utils"
1816
"github.com/kgateway-dev/kgateway/v2/pkg/pluginsdk/ir"
1917
)
@@ -22,11 +20,10 @@ func processPoolBackendObjIR(
2220
ctx context.Context,
2321
in ir.BackendObjectIR,
2422
out *envoyclusterv3.Cluster,
25-
podIdx krt.Index[string, krtcollections.LocalityPod],
2623
) *ir.EndpointsForBackend {
2724
// Build an endpoint list
2825
irPool := in.ObjIr.(*inferencePool)
29-
poolEps := irPool.resolvePoolEndpoints(podIdx)
26+
poolEps := irPool.getEndpoints()
3027
if len(poolEps) == 0 {
3128
logger.Warn("no endpoints resolved for InferencePool",
3229
"namespace", irPool.obj.GetNamespace(),
@@ -35,9 +32,10 @@ func processPoolBackendObjIR(
3532

3633
// If the pool has errors, create an empty LoadAssignment to return a 503
3734
if irPool.hasErrors() {
35+
errs := irPool.snapshotErrors()
3836
logger.Debug("skipping endpoints due to InferencePool errors",
3937
"pool", in.ResourceName(),
40-
"errors", irPool.errors,
38+
"errors", errs,
4139
)
4240
out.LoadAssignment = &envoyendpointv3.ClusterLoadAssignment{
4341
ClusterName: out.Name,

internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker/backends_test.go

Lines changed: 6 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,11 @@ import (
99
"github.com/stretchr/testify/assert"
1010
"github.com/stretchr/testify/require"
1111
structpb "google.golang.org/protobuf/types/known/structpb"
12-
"istio.io/istio/pkg/kube/krt"
13-
"istio.io/istio/pkg/kube/krt/krttest"
14-
corev1 "k8s.io/api/core/v1"
1512
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1613
inf "sigs.k8s.io/gateway-api-inference-extension/api/v1"
1714

18-
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/krtcollections"
1915
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/wellknown"
2016
"github.com/kgateway-dev/kgateway/v2/pkg/pluginsdk/ir"
21-
krtpkg "github.com/kgateway-dev/kgateway/v2/pkg/utils/krtutil"
2217
)
2318

2419
func makeBackendIR(pool *inf.InferencePool) *ir.BackendObjectIR {
@@ -53,34 +48,14 @@ func TestProcessPoolBackendObjIR_BuildsLoadAssignment(t *testing.T) {
5348
},
5449
}
5550

56-
// Build a fake Pod and wrap it into a LocalityPod
57-
corePod := &corev1.Pod{
58-
ObjectMeta: metav1.ObjectMeta{
59-
Name: "pod1",
60-
Namespace: "ns",
61-
Labels: map[string]string{"app": "test"},
62-
},
63-
Status: corev1.PodStatus{PodIP: "10.0.0.1"},
64-
}
65-
fakeLP := krtcollections.LocalityPod{
66-
Named: krt.NewNamed(corePod),
67-
AugmentedLabels: corePod.Labels,
68-
Addresses: []string{corePod.Status.PodIP},
69-
}
70-
71-
// Create a mock and with the LocalityPod collection
72-
mock := krttest.NewMock(t, []any{fakeLP})
73-
podCol := krttest.GetMockCollection[krtcollections.LocalityPod](mock)
74-
75-
// Index the pods
76-
poolKey := fmt.Sprintf("%s/%s", pool.Namespace, pool.Name)
77-
podIdx := krtpkg.UnnamedIndex(podCol, func(p krtcollections.LocalityPod) []string {
78-
return []string{poolKey}
79-
})
51+
// Build the Backend IR and seed endpoints
52+
beIR := makeBackendIR(pool)
53+
irp := beIR.ObjIr.(*inferencePool)
54+
irp.setEndpoints([]endpoint{{address: "10.0.0.1", port: 9000}})
8055

8156
// Call the code under test
8257
cluster := &envoyclusterv3.Cluster{}
83-
ret := processPoolBackendObjIR(context.Background(), *makeBackendIR(pool), cluster, podIdx)
58+
ret := processPoolBackendObjIR(context.Background(), *beIR, cluster)
8459
assert.Nil(t, ret, "Should return nil for a static cluster")
8560

8661
// Validate the generated LoadAssignment
@@ -119,13 +94,8 @@ func TestProcessPoolBackendObjIR_SkipsOnErrors(t *testing.T) {
11994
irp := beIR.ObjIr.(*inferencePool)
12095
irp.setErrors([]error{fmt.Errorf("failure injected")})
12196

122-
// Empty pod index
123-
mock := krttest.NewMock(t, []any{})
124-
podCol := krttest.GetMockCollection[krtcollections.LocalityPod](mock)
125-
podIdx := krtpkg.UnnamedIndex(podCol, func(krtcollections.LocalityPod) []string { return nil })
126-
12797
cluster := &envoyclusterv3.Cluster{}
128-
ret := processPoolBackendObjIR(context.Background(), *beIR, cluster, podIdx)
98+
ret := processPoolBackendObjIR(context.Background(), *beIR, cluster)
12999
assert.Nil(t, ret)
130100

131101
cla := cluster.LoadAssignment

internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker/collections.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,8 @@ func initInferencePoolCollections(
121121
eps = append(eps, endpoint{address: ip, port: irPool.targetPorts[0].number})
122122
}
123123
}
124-
if len(eps) == 0 {
125-
return nil
126-
}
124+
// Always return a backend IR so the static cluster exists.
125+
// Endpoints may be empty on first pass, they'll populate in subsequent passes.
127126
irPool.setEndpoints(eps)
128127
return buildBackendObjIrFromPool(irPool)
129128
},
@@ -135,7 +134,7 @@ func initInferencePoolCollections(
135134
backendsDP,
136135
func(_ krt.HandlerContext, be ir.BackendObjectIR) *ir.EndpointsForBackend {
137136
stub := &envoyclusterv3.Cluster{Name: be.ClusterName()}
138-
return processPoolBackendObjIR(ctx, be, stub, podIdx)
137+
return processPoolBackendObjIR(ctx, be, stub)
139138
},
140139
)
141140

internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker/ir.go

Lines changed: 48 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import (
44
"encoding/json"
55
"fmt"
66
"maps"
7-
"sync"
7+
"sync/atomic"
88
"time"
99

1010
"istio.io/istio/pkg/kube/krt"
@@ -27,20 +27,17 @@ type inferencePool struct {
2727
// configRef is a reference to the extension configuration. A configRef is typically implemented
2828
// as a Kubernetes Service resource.
2929
configRef *service
30-
// mu is a mutex to protect access to the errors list.
31-
mu sync.Mutex
32-
// errors is a list of errors that occurred while processing the InferencePool.
33-
errors []error
30+
// errors that occurred while processing the InferencePool.
31+
errorsV atomic.Value
32+
errorCount atomic.Int64
3433
// endpoints define the list of endpoints resolved by the podSelector.
35-
endpoints []endpoint
34+
endpoints atomic.Value
3635
// failOpen configures how the proxy handles traffic when the EPP extension is
3736
// non-responsive. When set to `false` and the gRPC stream cannot be established, or if
3837
// it is closed prematurely with an error, the request will fail. When set to `true` and
3938
// the gRPC stream cannot be established, the request is forwarded based on the cluster
4039
// load balancing configuration.
4140
//
42-
// Defaults to `false`.
43-
//
4441
failOpen bool
4542
}
4643

@@ -67,27 +64,35 @@ func newInferencePool(pool *inf.InferencePool) *inferencePool {
6764
ports: []servicePort{port},
6865
}
6966

70-
return &inferencePool{
67+
ir := &inferencePool{
7168
obj: pool,
7269
podSelector: convertSelector(pool.Spec.Selector.MatchLabels),
7370
// InferencePool v1 only supports single port
7471
targetPorts: []targetPort{{number: int32(pool.Spec.TargetPorts[0].Number)}},
7572
configRef: svcIR,
76-
endpoints: []endpoint{},
7773
failOpen: isFailOpen(pool),
7874
}
75+
ir.endpoints.Store([]endpoint(nil))
76+
ir.errorsV.Store([]error(nil))
77+
ir.errorCount.Store(0)
78+
79+
return ir
7980
}
8081

8182
func (ir *inferencePool) setEndpoints(eps []endpoint) {
82-
ir.mu.Lock()
83-
defer ir.mu.Unlock()
84-
ir.endpoints = eps
83+
cp := append([]endpoint(nil), eps...)
84+
ir.endpoints.Store(cp)
8585
}
8686

8787
func (ir *inferencePool) getEndpoints() []endpoint {
88-
ir.mu.Lock()
89-
defer ir.mu.Unlock()
90-
return ir.endpoints
88+
v := ir.endpoints.Load()
89+
if v == nil {
90+
return nil
91+
}
92+
src := v.([]endpoint)
93+
out := make([]endpoint, len(src))
94+
copy(out, src)
95+
return out
9196
}
9297

9398
// resolvePoolEndpoints returns the slice of <IP:Port> for the given pool
@@ -125,31 +130,34 @@ func (ir *inferencePool) Equals(other any) bool {
125130
if !ok {
126131
return false
127132
}
133+
128134
// Compare pod selector
129135
if !maps.Equal(ir.Selector(), otherPool.Selector()) {
130136
return false
131137
}
138+
132139
// Compare error presence (we only need the boolean)
133140
if ir.hasErrors() != otherPool.hasErrors() {
134141
return false
135142
}
136-
// Compare endpoint set (order‑insensitive)
137-
ir.mu.Lock()
138-
otherPool.mu.Lock()
139-
defer ir.mu.Unlock()
140-
defer otherPool.mu.Unlock()
141-
if len(ir.endpoints) != len(otherPool.endpoints) {
143+
144+
// Snapshot endpoints (avoid holding locks during compare)
145+
epsA := ir.getEndpoints()
146+
epsB := otherPool.getEndpoints()
147+
148+
if len(epsA) != len(epsB) {
142149
return false
143150
}
144-
seen := make(map[string]struct{}, len(ir.endpoints))
145-
for _, ep := range ir.endpoints {
151+
seen := make(map[string]struct{}, len(epsA))
152+
for _, ep := range epsA {
146153
seen[ep.string()] = struct{}{}
147154
}
148-
for _, ep := range otherPool.endpoints {
155+
for _, ep := range epsB {
149156
if _, ok := seen[ep.string()]; !ok {
150157
return false
151158
}
152159
}
160+
153161
// Compare target port
154162
// InferencePool v1 only supports single port
155163
if len(ir.targetPorts) != 1 || len(otherPool.targetPorts) != 1 {
@@ -158,6 +166,7 @@ func (ir *inferencePool) Equals(other any) bool {
158166
if ir.targetPorts[0].number != otherPool.targetPorts[0].number {
159167
return false
160168
}
169+
161170
// Compare object metadata
162171
if ir.obj.GetName() != otherPool.obj.GetName() ||
163172
ir.obj.GetNamespace() != otherPool.obj.GetNamespace() ||
@@ -166,14 +175,17 @@ func (ir *inferencePool) Equals(other any) bool {
166175
ir.obj.GetGeneration() != otherPool.obj.GetGeneration() {
167176
return false
168177
}
178+
169179
// Compare configRef
170180
if !ir.configRefEquals(otherPool) {
171181
return false
172182
}
183+
173184
// Compare failure mode
174185
if !ir.failOpenEqual(otherPool) {
175186
return false
176187
}
188+
177189
return true
178190
}
179191

@@ -190,25 +202,26 @@ func (ir *inferencePool) configRefEquals(other *inferencePool) bool {
190202

191203
// setErrors atomically replaces p.errors under lock.
192204
func (ir *inferencePool) setErrors(errs []error) {
193-
ir.mu.Lock()
194-
defer ir.mu.Unlock()
195-
ir.errors = errs
205+
cp := append([]error(nil), errs...)
206+
ir.errorsV.Store(cp)
207+
ir.errorCount.Store(int64(len(cp)))
196208
}
197209

198210
// snapshotErrors returns a copy of p.errors under lock.
199211
func (ir *inferencePool) snapshotErrors() []error {
200-
ir.mu.Lock()
201-
defer ir.mu.Unlock()
202-
out := make([]error, len(ir.errors))
203-
copy(out, ir.errors)
212+
v := ir.errorsV.Load()
213+
if v == nil {
214+
return nil
215+
}
216+
src := v.([]error)
217+
out := make([]error, len(src))
218+
copy(out, src)
204219
return out
205220
}
206221

207222
// hasErrors checks if the inferencePool has any errors.
208223
func (ir *inferencePool) hasErrors() bool {
209-
ir.mu.Lock()
210-
defer ir.mu.Unlock()
211-
return len(ir.errors) > 0
224+
return ir.errorCount.Load() > 0
212225
}
213226

214227
func (ir *inferencePool) failOpenEqual(other *inferencePool) bool {

internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker/plugin.go

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func NewPlugin(ctx context.Context, commonCols *collections.CommonCollections) s
6161

6262
// Wrap the init function so it can capture commonCols.Pods
6363
initBackend := func(ctx context.Context, in ir.BackendObjectIR, out *envoyclusterv3.Cluster) *ir.EndpointsForBackend {
64-
return processPoolBackendObjIR(ctx, in, out, p.podIndex)
64+
return processPoolBackendObjIR(ctx, in, out)
6565
}
6666

6767
return sdk.Plugin{
@@ -215,11 +215,9 @@ func (p *endpointPickerPass) ApplyForBackend(
215215

216216
// Ensure we are working with the latest set of endpoints for the pool.
217217
eps := irPool.resolvePoolEndpoints(p.podIdx)
218-
if len(eps) == 0 {
219-
return fmt.Errorf("no endpoints found for InferencePool %s/%s",
220-
irPool.obj.GetNamespace(),
221-
irPool.obj.GetName())
222-
}
218+
// If the pool has no endpoints yet, do not fail translation.
219+
// Keep the route valid and provide an empty subset hint so the EPP
220+
// will return 503 (or honor fail-open) rather than causing a 500.
223221
irPool.setEndpoints(eps)
224222

225223
// Tell the EPP the subset of endpoints to choose from.

internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker/status_test.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ func TestUpdatePoolStatus_NoReferences_NoErrors(t *testing.T) {
8989
Namespace: poolNN.Namespace,
9090
Name: poolNN.Name,
9191
},
92-
ObjIr: &inferencePool{errors: nil},
92+
ObjIr: &inferencePool{},
9393
}
9494

9595
// Call the function to update the pool status
@@ -181,7 +181,7 @@ func TestUpdatePoolStatus_WithReference_NoErrors(t *testing.T) {
181181
Namespace: poolNN.Namespace,
182182
Name: poolNN.Name,
183183
},
184-
ObjIr: &inferencePool{errors: nil},
184+
ObjIr: &inferencePool{},
185185
}
186186

187187
// Call the function to update the pool status
@@ -289,14 +289,18 @@ func TestUpdatePoolStatus_WithReference_WithErrors(t *testing.T) {
289289
ControllerName: controllerName,
290290
Routes: fakeRoutesIndex(col),
291291
}
292+
293+
poolIR := &inferencePool{}
294+
poolIR.setErrors([]error{fmt.Errorf("test error")})
295+
292296
beIR := ir.BackendObjectIR{
293297
ObjectSource: ir.ObjectSource{
294298
Group: inf.GroupVersion.Group,
295299
Kind: wellknown.InferencePoolKind,
296300
Namespace: poolNN.Namespace,
297301
Name: poolNN.Name,
298302
},
299-
ObjIr: &inferencePool{errors: []error{fmt.Errorf("test error")}},
303+
ObjIr: poolIR,
300304
}
301305

302306
// Call the function to update the pool status with errors
@@ -428,7 +432,7 @@ func TestUpdatePoolStatus_DeleteRoute(t *testing.T) {
428432
Namespace: poolNN.Namespace,
429433
Name: poolNN.Name,
430434
},
431-
ObjIr: &inferencePool{errors: nil},
435+
ObjIr: &inferencePool{},
432436
}
433437

434438
// Call the function to update the pool status with the route
@@ -476,7 +480,7 @@ func TestUpdatePoolStatus_WithExtraGws(t *testing.T) {
476480
Namespace: ns,
477481
Name: poolName,
478482
},
479-
ObjIr: &inferencePool{errors: nil},
483+
ObjIr: &inferencePool{},
480484
}
481485

482486
// Simulate controller knowing about a parent Gateway even if no HTTPRoute is present

0 commit comments

Comments
 (0)