Skip to content

Commit b981485

Browse files
committed
Adds initial infer ext epp plugin
Signed-off-by: Daneyon Hansen <[email protected]>
1 parent 8947ce1 commit b981485

File tree

4 files changed

+632
-6
lines changed

4 files changed

+632
-6
lines changed

internal/kgateway/controller/start.go

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@ import (
3030
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2"
3131
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/common"
3232
extensionsplug "github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/plugin"
33+
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker"
3334
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/registry"
3435
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/settings"
3536
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/ir"
@@ -104,11 +105,6 @@ func NewControllerBuilder(ctx context.Context, cfg StartConfig) (*ControllerBuil
104105
return nil, err
105106
}
106107

107-
// Extend the scheme if the InferencePool CRD exists.
108-
if _, err := glooschemes.AddInferExtV1A1Scheme(cfg.RestConfig, scheme); err != nil {
109-
return nil, err
110-
}
111-
112108
mgrOpts := ctrl.Options{
113109
BaseContext: func() context.Context { return ctx },
114110
Scheme: scheme,
@@ -146,6 +142,16 @@ func NewControllerBuilder(ctx context.Context, cfg StartConfig) (*ControllerBuil
146142
cli,
147143
setupLog,
148144
)
145+
146+
// Extend the scheme and add the EPP plugin if the InferencePool CRD exists.
147+
exists, err := glooschemes.AddInferExtV1A1Scheme(cfg.RestConfig, scheme)
148+
switch {
149+
case err != nil:
150+
return nil, err
151+
case exists:
152+
cfg.ExtraPlugins = append(cfg.ExtraPlugins, endpointpicker.NewPlugin(ctx, commoncol))
153+
}
154+
149155
gwClasses := sets.New(append(cfg.SetupOpts.ExtraGatewayClasses, wellknown.GatewayClassName)...)
150156
isOurGw := func(gw *apiv1.Gateway) bool {
151157
return gwClasses.Has(string(gw.Spec.GatewayClassName))
Lines changed: 117 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,117 @@
1+
package endpointpicker
2+
3+
import (
4+
"context"
5+
"time"
6+
7+
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
8+
"google.golang.org/protobuf/types/known/anypb"
9+
"google.golang.org/protobuf/types/known/durationpb"
10+
"google.golang.org/protobuf/types/known/wrapperspb"
11+
"istio.io/istio/pkg/kube/kclient"
12+
"istio.io/istio/pkg/kube/krt"
13+
"k8s.io/apimachinery/pkg/runtime/schema"
14+
infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1"
15+
16+
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/common"
17+
extensionsplug "github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/plugin"
18+
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/settings"
19+
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/ir"
20+
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/krtcollections"
21+
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/utils/krtutil"
22+
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/wellknown"
23+
)
24+
25+
func NewPlugin(ctx context.Context, commoncol *common.CommonCollections) extensionsplug.Plugin {
26+
poolClient := kclient.New[*infextv1a1.InferencePool](commoncol.Client)
27+
pools := krt.WrapClient(poolClient, commoncol.KrtOpts.ToOptions("InferencePools")...)
28+
return NewPluginFromCollections(ctx, commoncol.KrtOpts, pools, commoncol.Pods, commoncol.Settings)
29+
}
30+
31+
func NewPluginFromCollections(
32+
ctx context.Context,
33+
krtOpts krtutil.KrtOptions,
34+
pools krt.Collection[*infextv1a1.InferencePool],
35+
pods krt.Collection[krtcollections.LocalityPod],
36+
stngs settings.Settings,
37+
) extensionsplug.Plugin {
38+
gk := schema.GroupKind{
39+
Group: infextv1a1.GroupVersion.Group,
40+
Kind: wellknown.InferencePoolKind,
41+
}
42+
43+
// TODO [danehans]: Filter InferencePools based one's that are referenced by an HTTPRoute
44+
// with a status.parents[].controllerName that matches our Gateway controllerName.
45+
infPoolUpstream := krt.NewCollection(pools, func(kctx krt.HandlerContext, pool *infextv1a1.InferencePool) *ir.Upstream {
46+
return &ir.Upstream{
47+
ObjectSource: ir.ObjectSource{
48+
Kind: gk.Kind,
49+
Group: gk.Group,
50+
Namespace: pool.Namespace,
51+
Name: pool.Name,
52+
},
53+
Obj: pool,
54+
Port: pool.Spec.TargetPortNumber,
55+
GvPrefix: "endpoint-picker",
56+
CanonicalHostname: "",
57+
}
58+
}, krtOpts.ToOptions("EndpointPickerUpstreams")...)
59+
60+
// Create the endpoints collection
61+
inputs := krtcollections.NewInfPoolEndpointsInputs(krtOpts, infPoolUpstream, pods)
62+
infPoolEndpoints := krtcollections.NewInfPoolEndpoints(ctx, inputs)
63+
64+
return extensionsplug.Plugin{
65+
ContributesUpstreams: map[schema.GroupKind]extensionsplug.UpstreamPlugin{
66+
gk: {
67+
UpstreamInit: ir.UpstreamInit{
68+
InitUpstream: processUpstream,
69+
},
70+
Endpoints: infPoolEndpoints,
71+
Upstreams: infPoolUpstream,
72+
},
73+
},
74+
}
75+
}
76+
77+
func processUpstream(ctx context.Context, in ir.Upstream, out *envoy_config_cluster_v3.Cluster) {
78+
// Set cluster type to ORIGINAL_DST
79+
out.ClusterDiscoveryType = &envoy_config_cluster_v3.Cluster_Type{
80+
Type: envoy_config_cluster_v3.Cluster_ORIGINAL_DST,
81+
}
82+
83+
// Set connect timeout to 1000 seconds.
84+
// TODO [danehans]: Figure out an API that can be used to set this value.
85+
out.ConnectTimeout = durationpb.New(1000 * time.Second)
86+
87+
// Use CLUSTER_PROVIDED load balancing.
88+
out.LbPolicy = envoy_config_cluster_v3.Cluster_CLUSTER_PROVIDED
89+
90+
// Configure circuit breakers with a single threshold.
91+
// TODO [danehans]: Figure out an API that can be used to set these values.
92+
out.CircuitBreakers = &envoy_config_cluster_v3.CircuitBreakers{
93+
Thresholds: []*envoy_config_cluster_v3.CircuitBreakers_Thresholds{
94+
{
95+
MaxConnections: wrapperspb.UInt32(40000),
96+
MaxPendingRequests: wrapperspb.UInt32(40000),
97+
MaxRequests: wrapperspb.UInt32(40000),
98+
},
99+
},
100+
}
101+
102+
// If OriginalDstLbConfig is not available on Cluster,
103+
// encode the configuration as a typed extension.
104+
// Note: The type URL will be "type.googleapis.com/envoy.config.cluster.v3.Cluster_OriginalDstLbConfig".
105+
lbConfig := &envoy_config_cluster_v3.Cluster_OriginalDstLbConfig{
106+
UseHttpHeader: true,
107+
HttpHeaderName: "x-gateway-destination-endpoint",
108+
}
109+
anyLbConfig, err := anypb.New(lbConfig)
110+
if err != nil {
111+
// handle error appropriately
112+
return
113+
}
114+
out.TypedExtensionProtocolOptions = map[string]*anypb.Any{
115+
"envoy.lb": anyLbConfig,
116+
}
117+
}

internal/kgateway/krtcollections/endpoints.go

Lines changed: 111 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@ import (
1414
corev1 "k8s.io/api/core/v1"
1515
discoveryv1 "k8s.io/api/discovery/v1"
1616
"k8s.io/apimachinery/pkg/types"
17+
infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1"
1718

1819
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/settings"
1920
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/ir"
@@ -297,3 +298,113 @@ func findPortInEndpointSlice(endpointSlice *discoveryv1.EndpointSlice, singlePor
297298
}
298299
return port
299300
}
301+
302+
type InfPoolEndpointsInputs struct {
303+
Upstreams krt.Collection[ir.Upstream]
304+
Pods krt.Collection[LocalityPod]
305+
KrtOpts krtutil.KrtOptions
306+
}
307+
308+
func NewInfPoolEndpointsInputs(
309+
krtopts krtutil.KrtOptions,
310+
infPoolUpstreams krt.Collection[ir.Upstream],
311+
pods krt.Collection[LocalityPod],
312+
) InfPoolEndpointsInputs {
313+
return InfPoolEndpointsInputs{
314+
Upstreams: infPoolUpstreams,
315+
Pods: pods,
316+
KrtOpts: krtopts,
317+
}
318+
}
319+
320+
func NewInfPoolEndpoints(ctx context.Context, inputs InfPoolEndpointsInputs) krt.Collection[ir.EndpointsForUpstream] {
321+
return krt.NewCollection(inputs.Upstreams, transformInfPoolEndpoints(ctx, inputs), inputs.KrtOpts.ToOptions("InfPoolEndpoints")...)
322+
}
323+
324+
func transformInfPoolEndpoints(ctx context.Context, inputs InfPoolEndpointsInputs) func(kctx krt.HandlerContext, us ir.Upstream) *ir.EndpointsForUpstream {
325+
logger := contextutils.LoggerFrom(ctx).Desugar()
326+
327+
return func(kctx krt.HandlerContext, us ir.Upstream) *ir.EndpointsForUpstream {
328+
infPool, ok := us.Obj.(*infextv1a1.InferencePool)
329+
if !ok {
330+
logger.Debug("not an InferencePool upstream")
331+
return nil
332+
}
333+
334+
logger.Debug("building endpoints for inference pool", zap.String("pool", infPool.Name))
335+
336+
// Convert `spec.selector` from custom type to `map[string]string`
337+
labelSelector := convertSelector(infPool.Spec.Selector)
338+
339+
// Use `FilterGeneric()` to match `LocalityPod` based on AugmentedLabels and Namespace
340+
podMatches := krt.Fetch(kctx, inputs.Pods, krt.FilterGeneric(func(obj any) bool {
341+
pod, ok := obj.(LocalityPod)
342+
if !ok {
343+
return false
344+
}
345+
// Ensure Pod is in the same namespace as the InferencePool
346+
if pod.Namespace != infPool.Namespace {
347+
return false
348+
}
349+
// Ensure the pod labels match the InferencePool selector
350+
return labelsMatch(labelSelector, pod.AugmentedLabels)
351+
}))
352+
353+
// Always return a valid EndpointsForUpstream instance, even if no matching pods
354+
ret := ir.NewEndpointsForUpstream(us)
355+
356+
if len(podMatches) == 0 {
357+
logger.Debug("no matching pods found for inference pool", zap.String("pool", infPool.Name))
358+
return ret // Return an empty but valid EndpointsForUpstream
359+
}
360+
361+
// Deduplicate Pod IPs
362+
seenAddresses := make(map[string]struct{})
363+
364+
// Process matching Pods
365+
for _, pod := range podMatches {
366+
// Get the primary pod address
367+
podIP := pod.IP()
368+
if podIP == "" {
369+
continue
370+
}
371+
372+
// Deduplicate addresses
373+
if _, exists := seenAddresses[podIP]; exists {
374+
continue
375+
}
376+
seenAddresses[podIP] = struct{}{}
377+
378+
// Create Envoy LB Endpoint
379+
ep := CreateLBEndpoint(podIP, uint32(infPool.Spec.TargetPortNumber), pod.AugmentedLabels, true)
380+
381+
// Add endpoint
382+
ret.Add(pod.Locality, ir.EndpointWithMd{
383+
LbEndpoint: ep,
384+
EndpointMd: ir.EndpointMetadata{
385+
Labels: pod.AugmentedLabels,
386+
},
387+
})
388+
}
389+
390+
logger.Debug("created endpoints", zap.Int("numAddresses", len(ret.LbEps)))
391+
return ret
392+
}
393+
}
394+
395+
func convertSelector(selector map[infextv1a1.LabelKey]infextv1a1.LabelValue) map[string]string {
396+
result := make(map[string]string, len(selector))
397+
for k, v := range selector {
398+
result[string(k)] = string(v)
399+
}
400+
return result
401+
}
402+
403+
// labelsMatch reports whether podLabels satisfies selector: every key in
// selector must be present in podLabels with exactly the same value.
//
// A key that is absent from podLabels never matches, even when the selector
// value is the empty string. An empty or nil selector matches any label set.
func labelsMatch(selector map[string]string, podLabels map[string]string) bool {
	for k, want := range selector {
		// Use the two-value lookup: a plain index would return "" for a
		// missing key and wrongly match a selector entry whose value is "".
		got, ok := podLabels[k]
		if !ok || got != want {
			return false
		}
	}
	return true
}

0 commit comments

Comments
 (0)