Skip to content

Commit 591a07f

Browse files
committed
Inference: Fix EPP Endpoint Sync and Eliminate Races
- Stores endpoints via atomic.Value and adds setEndpoints/getEndpoints to snapshot safely without locks. - Updates Equals to compare endpoint snapshots without locks, fixing a race condition in krt.Equal/DeepEqual. - Switches error handling to hasErrors/snapshotErrors/setErrors. The backend path now returns an empty ClusterLoadAssignment when errors exist. - Updates tests to seed errors via setErrors and avoid direct field access. - Keeps DP collection returning Backend IR on empty endpoints and relaxes route pass to allow empty endpoint sets. - Passes errors to DP to ensure cluster management. Signed-off-by: Daneyon Hansen <[email protected]>
1 parent 2703002 commit 591a07f

File tree

6 files changed

+123
-122
lines changed

6 files changed

+123
-122
lines changed

internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker/backends.go

Lines changed: 47 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@ import (
1111
"google.golang.org/protobuf/types/known/anypb"
1212
"google.golang.org/protobuf/types/known/structpb"
1313
"google.golang.org/protobuf/types/known/wrapperspb"
14-
"istio.io/istio/pkg/kube/krt"
1514

16-
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/krtcollections"
1715
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/utils"
1816
"github.com/kgateway-dev/kgateway/v2/pkg/pluginsdk/ir"
1917
)
@@ -22,34 +20,57 @@ func processPoolBackendObjIR(
2220
ctx context.Context,
2321
in ir.BackendObjectIR,
2422
out *envoyclusterv3.Cluster,
25-
podIdx krt.Index[string, krtcollections.LocalityPod],
2623
) *ir.EndpointsForBackend {
27-
// Build an endpoint list
2824
irPool := in.ObjIr.(*inferencePool)
29-
poolEps := irPool.resolvePoolEndpoints(podIdx)
30-
if len(poolEps) == 0 {
31-
logger.Warn("no endpoints resolved for InferencePool",
32-
"namespace", irPool.obj.GetNamespace(),
33-
"name", irPool.obj.GetName())
34-
}
3525

36-
// If the pool has errors, create an empty LoadAssignment to return a 503
26+
// Always set the cluster name up front so error paths program the right cluster.
27+
out.Name = in.ClusterName()
28+
out.ClusterDiscoveryType = &envoyclusterv3.Cluster_Type{Type: envoyclusterv3.Cluster_STATIC}
29+
out.LbPolicy = envoyclusterv3.Cluster_ROUND_ROBIN
30+
31+
// If the pool has errors, create an empty LoadAssignment to return a 503.
3732
if irPool.hasErrors() {
33+
errs := irPool.snapshotErrors()
3834
logger.Debug("skipping endpoints due to InferencePool errors",
3935
"pool", in.ResourceName(),
40-
"errors", irPool.errors,
36+
"errors", errs,
4137
)
38+
4239
out.LoadAssignment = &envoyendpointv3.ClusterLoadAssignment{
4340
ClusterName: out.Name,
4441
Endpoints: []*envoyendpointv3.LocalityLbEndpoints{{}},
4542
}
43+
44+
// Still set subset config so Envoy’s view is consistent, but it won’t matter with no endpoints.
45+
out.LbSubsetConfig = &envoyclusterv3.Cluster_LbSubsetConfig{
46+
SubsetSelectors: []*envoyclusterv3.Cluster_LbSubsetConfig_LbSubsetSelector{{
47+
Keys: []string{dstEndpointKey},
48+
}},
49+
FallbackPolicy: envoyclusterv3.Cluster_LbSubsetConfig_ANY_ENDPOINT,
50+
}
51+
52+
// TODO [danehans]: Set H1/H2 app protocol programmatically:
53+
// https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/1273
54+
addHTTP1(out)
55+
out.CircuitBreakers = &envoyclusterv3.CircuitBreakers{
56+
Thresholds: []*envoyclusterv3.CircuitBreakers_Thresholds{{
57+
MaxConnections: wrapperspb.UInt32(defaultExtProcMaxRequests),
58+
MaxPendingRequests: wrapperspb.UInt32(defaultExtProcMaxRequests),
59+
MaxRequests: wrapperspb.UInt32(defaultExtProcMaxRequests),
60+
}},
61+
}
62+
4663
return nil
4764
}
4865

49-
// Static cluster with subset lb config
50-
out.Name = in.ClusterName()
51-
out.ClusterDiscoveryType = &envoyclusterv3.Cluster_Type{Type: envoyclusterv3.Cluster_STATIC}
52-
out.LbPolicy = envoyclusterv3.Cluster_ROUND_ROBIN
66+
// Build the static cluster with subset lb from IR endpoints.
67+
poolEps := irPool.getEndpoints()
68+
if len(poolEps) == 0 {
69+
logger.Warn("no endpoints resolved for InferencePool",
70+
"namespace", irPool.obj.GetNamespace(),
71+
"name", irPool.obj.GetName())
72+
}
73+
5374
out.LbSubsetConfig = &envoyclusterv3.Cluster_LbSubsetConfig{
5475
SubsetSelectors: []*envoyclusterv3.Cluster_LbSubsetConfig_LbSubsetSelector{{
5576
Keys: []string{dstEndpointKey},
@@ -68,7 +89,7 @@ func processPoolBackendObjIR(
6889

6990
// Build the subset metadata struct used by the EPP for endpoint selection
7091
mdStruct, err := structpb.NewStruct(map[string]interface{}{
71-
dstEndpointKey: addr,
92+
dstEndpointKey: addr,
7293
})
7394
if err != nil {
7495
logger.Error("failed to build endpoint metadata for endpoint",
@@ -78,15 +99,14 @@ func processPoolBackendObjIR(
7899
continue
79100
}
80101

81-
// Build the LB endpoint
82-
lbEp := &envoyendpointv3.LbEndpoint{
102+
lbEndpoints = append(lbEndpoints, &envoyendpointv3.LbEndpoint{
83103
HostIdentifier: &envoyendpointv3.LbEndpoint_Endpoint{
84104
Endpoint: &envoyendpointv3.Endpoint{
85105
Address: &envoycorev3.Address{
86106
Address: &envoycorev3.Address_SocketAddress{
87107
SocketAddress: &envoycorev3.SocketAddress{
88108
Address: ep.address,
89-
PortSpecifier: &envoycorev3.SocketAddress_PortValue{PortValue: uint32(ep.port)}, //nolint:gosec // G115: ep.port is int32 representing a port number, always in valid range
109+
PortSpecifier: &envoycorev3.SocketAddress_PortValue{PortValue: uint32(ep.port)}, //nolint:gosec // ep.port is a valid int32 port
90110
},
91111
},
92112
},
@@ -98,29 +118,20 @@ func processPoolBackendObjIR(
98118
envoyLbNamespace: mdStruct,
99119
},
100120
},
101-
}
102-
lbEndpoints = append(lbEndpoints, lbEp)
121+
})
103122
}
104123

105-
// Attach the endpoints to the cluster load assignment
106124
out.LoadAssignment = &envoyendpointv3.ClusterLoadAssignment{
107125
ClusterName: out.Name,
108-
Endpoints: []*envoyendpointv3.LocalityLbEndpoints{{
109-
LbEndpoints: lbEndpoints,
110-
}},
126+
Endpoints: []*envoyendpointv3.LocalityLbEndpoints{{LbEndpoints: lbEndpoints}},
111127
}
112-
113128
out.CircuitBreakers = &envoyclusterv3.CircuitBreakers{
114-
Thresholds: []*envoyclusterv3.CircuitBreakers_Thresholds{
115-
{
116-
MaxConnections: wrapperspb.UInt32(defaultExtProcMaxRequests),
117-
MaxPendingRequests: wrapperspb.UInt32(defaultExtProcMaxRequests),
118-
MaxRequests: wrapperspb.UInt32(defaultExtProcMaxRequests),
119-
},
120-
},
129+
Thresholds: []*envoyclusterv3.CircuitBreakers_Thresholds{{
130+
MaxConnections: wrapperspb.UInt32(defaultExtProcMaxRequests),
131+
MaxPendingRequests: wrapperspb.UInt32(defaultExtProcMaxRequests),
132+
MaxRequests: wrapperspb.UInt32(defaultExtProcMaxRequests),
133+
}},
121134
}
122-
123-
// Return nil since we're building a static cluster
124135
return nil
125136
}
126137

internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker/backends_test.go

Lines changed: 6 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,11 @@ import (
99
"github.com/stretchr/testify/assert"
1010
"github.com/stretchr/testify/require"
1111
structpb "google.golang.org/protobuf/types/known/structpb"
12-
"istio.io/istio/pkg/kube/krt"
13-
"istio.io/istio/pkg/kube/krt/krttest"
14-
corev1 "k8s.io/api/core/v1"
1512
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1613
inf "sigs.k8s.io/gateway-api-inference-extension/api/v1"
1714

18-
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/krtcollections"
1915
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/wellknown"
2016
"github.com/kgateway-dev/kgateway/v2/pkg/pluginsdk/ir"
21-
krtpkg "github.com/kgateway-dev/kgateway/v2/pkg/utils/krtutil"
2217
)
2318

2419
func makeBackendIR(pool *inf.InferencePool) *ir.BackendObjectIR {
@@ -53,34 +48,14 @@ func TestProcessPoolBackendObjIR_BuildsLoadAssignment(t *testing.T) {
5348
},
5449
}
5550

56-
// Build a fake Pod and wrap it into a LocalityPod
57-
corePod := &corev1.Pod{
58-
ObjectMeta: metav1.ObjectMeta{
59-
Name: "pod1",
60-
Namespace: "ns",
61-
Labels: map[string]string{"app": "test"},
62-
},
63-
Status: corev1.PodStatus{PodIP: "10.0.0.1"},
64-
}
65-
fakeLP := krtcollections.LocalityPod{
66-
Named: krt.NewNamed(corePod),
67-
AugmentedLabels: corePod.Labels,
68-
Addresses: []string{corePod.Status.PodIP},
69-
}
70-
71-
// Create a mock and with the LocalityPod collection
72-
mock := krttest.NewMock(t, []any{fakeLP})
73-
podCol := krttest.GetMockCollection[krtcollections.LocalityPod](mock)
74-
75-
// Index the pods
76-
poolKey := fmt.Sprintf("%s/%s", pool.Namespace, pool.Name)
77-
podIdx := krtpkg.UnnamedIndex(podCol, func(p krtcollections.LocalityPod) []string {
78-
return []string{poolKey}
79-
})
51+
// Build the Backend IR and seed endpoints
52+
beIR := makeBackendIR(pool)
53+
irp := beIR.ObjIr.(*inferencePool)
54+
irp.setEndpoints([]endpoint{{address: "10.0.0.1", port: 9000}})
8055

8156
// Call the code under test
8257
cluster := &envoyclusterv3.Cluster{}
83-
ret := processPoolBackendObjIR(context.Background(), *makeBackendIR(pool), cluster, podIdx)
58+
ret := processPoolBackendObjIR(context.Background(), *beIR, cluster)
8459
assert.Nil(t, ret, "Should return nil for a static cluster")
8560

8661
// Validate the generated LoadAssignment
@@ -119,13 +94,8 @@ func TestProcessPoolBackendObjIR_SkipsOnErrors(t *testing.T) {
11994
irp := beIR.ObjIr.(*inferencePool)
12095
irp.setErrors([]error{fmt.Errorf("failure injected")})
12196

122-
// Empty pod index
123-
mock := krttest.NewMock(t, []any{})
124-
podCol := krttest.GetMockCollection[krtcollections.LocalityPod](mock)
125-
podIdx := krtpkg.UnnamedIndex(podCol, func(krtcollections.LocalityPod) []string { return nil })
126-
12797
cluster := &envoyclusterv3.Cluster{}
128-
ret := processPoolBackendObjIR(context.Background(), *beIR, cluster, podIdx)
98+
ret := processPoolBackendObjIR(context.Background(), *beIR, cluster)
12999
assert.Nil(t, ret)
130100

131101
cla := cluster.LoadAssignment

internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker/collections.go

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,12 @@ func initInferencePoolCollections(
104104
poolCol,
105105
func(ctx krt.HandlerContext, ip *inf.InferencePool) *ir.BackendObjectIR {
106106
irPool := newInferencePool(ip)
107+
108+
// Propagate validation errors to the DP.
109+
if errs := validatePool(ip, commonCol.Services); len(errs) > 0 {
110+
irPool.setErrors(errs)
111+
}
112+
107113
pods := krt.Fetch(ctx, commonCol.LocalityPods, krt.FilterGeneric(func(obj any) bool {
108114
pod, ok := obj.(krtcollections.LocalityPod)
109115
if !ok {
@@ -121,9 +127,8 @@ func initInferencePoolCollections(
121127
eps = append(eps, endpoint{address: ip, port: irPool.targetPorts[0].number})
122128
}
123129
}
124-
if len(eps) == 0 {
125-
return nil
126-
}
130+
// Always return a backend IR so the static cluster exists.
131+
// Endpoints may be empty on first pass, they'll populate in subsequent passes.
127132
irPool.setEndpoints(eps)
128133
return buildBackendObjIrFromPool(irPool)
129134
},
@@ -135,7 +140,7 @@ func initInferencePoolCollections(
135140
backendsDP,
136141
func(_ krt.HandlerContext, be ir.BackendObjectIR) *ir.EndpointsForBackend {
137142
stub := &envoyclusterv3.Cluster{Name: be.ClusterName()}
138-
return processPoolBackendObjIR(ctx, be, stub, podIdx)
143+
return processPoolBackendObjIR(ctx, be, stub)
139144
},
140145
)
141146

0 commit comments

Comments
 (0)