Skip to content

Commit

Permalink
Adds initial infer ext epp plugin
Browse files Browse the repository at this point in the history
Signed-off-by: Daneyon Hansen <[email protected]>
  • Loading branch information
danehans committed Feb 22, 2025
1 parent 8947ce1 commit b981485
Show file tree
Hide file tree
Showing 4 changed files with 632 additions and 6 deletions.
16 changes: 11 additions & 5 deletions internal/kgateway/controller/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2"
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/common"
extensionsplug "github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/plugin"
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker"
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/registry"
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/settings"
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/ir"
Expand Down Expand Up @@ -104,11 +105,6 @@ func NewControllerBuilder(ctx context.Context, cfg StartConfig) (*ControllerBuil
return nil, err
}

// Extend the scheme if the InferencePool CRD exists.
if _, err := glooschemes.AddInferExtV1A1Scheme(cfg.RestConfig, scheme); err != nil {
return nil, err
}

mgrOpts := ctrl.Options{
BaseContext: func() context.Context { return ctx },
Scheme: scheme,
Expand Down Expand Up @@ -146,6 +142,16 @@ func NewControllerBuilder(ctx context.Context, cfg StartConfig) (*ControllerBuil
cli,
setupLog,
)

// Extend the scheme and add the EPP plugin if the InferencePool CRD exists.
exists, err := glooschemes.AddInferExtV1A1Scheme(cfg.RestConfig, scheme)
switch {
case err != nil:
return nil, err
case exists:
cfg.ExtraPlugins = append(cfg.ExtraPlugins, endpointpicker.NewPlugin(ctx, commoncol))
}

gwClasses := sets.New(append(cfg.SetupOpts.ExtraGatewayClasses, wellknown.GatewayClassName)...)
isOurGw := func(gw *apiv1.Gateway) bool {
return gwClasses.Has(string(gw.Spec.GatewayClassName))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package endpointpicker

import (
"context"
"time"

envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/wrapperspb"
"istio.io/istio/pkg/kube/kclient"
"istio.io/istio/pkg/kube/krt"
"k8s.io/apimachinery/pkg/runtime/schema"
infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1"

"github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/common"
extensionsplug "github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/plugin"
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/settings"
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/ir"
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/krtcollections"
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/utils/krtutil"
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/wellknown"
)

// NewPlugin builds the endpoint picker plugin from the shared collections.
//
// It wraps a typed InferencePool client in a krt collection named
// "InferencePools" and hands it, together with the pod collection, settings
// and krt options found on commoncol, to NewPluginFromCollections.
func NewPlugin(ctx context.Context, commoncol *common.CommonCollections) extensionsplug.Plugin {
	infPools := krt.WrapClient(
		kclient.New[*infextv1a1.InferencePool](commoncol.Client),
		commoncol.KrtOpts.ToOptions("InferencePools")...,
	)

	return NewPluginFromCollections(
		ctx,
		commoncol.KrtOpts,
		infPools,
		commoncol.Pods,
		commoncol.Settings,
	)
}

// NewPluginFromCollections assembles the endpoint picker plugin from explicit
// collections.
//
// Each InferencePool in pools is mapped to an ir.Upstream whose port is the
// pool's target port number. Those upstreams and pods feed the inference pool
// endpoints collection. Both collections are registered under the
// InferencePool GroupKind, with processUpstream as the upstream initializer.
// stngs is accepted but not read by this function.
func NewPluginFromCollections(
	ctx context.Context,
	krtOpts krtutil.KrtOptions,
	pools krt.Collection[*infextv1a1.InferencePool],
	pods krt.Collection[krtcollections.LocalityPod],
	stngs settings.Settings,
) extensionsplug.Plugin {
	poolGK := schema.GroupKind{
		Group: infextv1a1.GroupVersion.Group,
		Kind:  wellknown.InferencePoolKind,
	}

	// TODO [danehans]: Only keep InferencePools that are referenced by an HTTPRoute
	// with a status.parents[].controllerName that matches our Gateway controllerName.
	toUpstream := func(_ krt.HandlerContext, p *infextv1a1.InferencePool) *ir.Upstream {
		source := ir.ObjectSource{
			Kind:      poolGK.Kind,
			Group:     poolGK.Group,
			Namespace: p.Namespace,
			Name:      p.Name,
		}
		return &ir.Upstream{
			ObjectSource:      source,
			Obj:               p,
			Port:              p.Spec.TargetPortNumber,
			GvPrefix:          "endpoint-picker",
			CanonicalHostname: "",
		}
	}
	upstreams := krt.NewCollection(pools, toUpstream, krtOpts.ToOptions("EndpointPickerUpstreams")...)

	// Derive the endpoints collection from the upstreams and the pods.
	endpoints := krtcollections.NewInfPoolEndpoints(
		ctx,
		krtcollections.NewInfPoolEndpointsInputs(krtOpts, upstreams, pods),
	)

	upstreamPlugin := extensionsplug.UpstreamPlugin{
		UpstreamInit: ir.UpstreamInit{
			InitUpstream: processUpstream,
		},
		Endpoints: endpoints,
		Upstreams: upstreams,
	}

	return extensionsplug.Plugin{
		ContributesUpstreams: map[schema.GroupKind]extensionsplug.UpstreamPlugin{
			poolGK: upstreamPlugin,
		},
	}
}

// processUpstream configures out as the Envoy cluster for an InferencePool
// upstream.
//
// It sets the cluster to ORIGINAL_DST discovery with CLUSTER_PROVIDED load
// balancing, a fixed connect timeout, a single circuit breaker threshold, and
// an original-destination LB configuration that reads the destination from the
// "x-gateway-destination-endpoint" header. ctx and in are not read.
//
// The LB configuration is stored under the "envoy.lb" key of
// out.TypedExtensionProtocolOptions. Entries already present under other keys
// are preserved. If the configuration cannot be encoded, the typed options are
// left untouched while the other settings above stay applied.
func processUpstream(ctx context.Context, in ir.Upstream, out *envoy_config_cluster_v3.Cluster) {
	// Set cluster type to ORIGINAL_DST.
	out.ClusterDiscoveryType = &envoy_config_cluster_v3.Cluster_Type{
		Type: envoy_config_cluster_v3.Cluster_ORIGINAL_DST,
	}

	// Set connect timeout to 1000 seconds.
	// TODO [danehans]: Figure out an API that can be used to set this value.
	out.ConnectTimeout = durationpb.New(1000 * time.Second)

	// Use CLUSTER_PROVIDED load balancing.
	out.LbPolicy = envoy_config_cluster_v3.Cluster_CLUSTER_PROVIDED

	// Configure circuit breakers with a single threshold.
	// TODO [danehans]: Figure out an API that can be used to set these values.
	out.CircuitBreakers = &envoy_config_cluster_v3.CircuitBreakers{
		Thresholds: []*envoy_config_cluster_v3.CircuitBreakers_Thresholds{
			{
				MaxConnections:     wrapperspb.UInt32(40000),
				MaxPendingRequests: wrapperspb.UInt32(40000),
				MaxRequests:        wrapperspb.UInt32(40000),
			},
		},
	}

	// Encode the original-destination LB configuration as a typed extension.
	// The resulting type URL is
	// "type.googleapis.com/envoy.config.cluster.v3.Cluster.OriginalDstLbConfig".
	// NOTE(review): Envoy's Cluster message also exposes original_dst_lb_config
	// directly through its lb_config oneof; confirm that Envoy honors an
	// "envoy.lb" typed extension protocol option before relying on this.
	lbConfig := &envoy_config_cluster_v3.Cluster_OriginalDstLbConfig{
		UseHttpHeader:  true,
		HttpHeaderName: "x-gateway-destination-endpoint",
	}
	anyLbConfig, err := anypb.New(lbConfig)
	if err != nil {
		// This function has no error return, so the failure cannot be
		// propagated from here; the typed option is skipped.
		// TODO: surface this error once a reporting path is available.
		return
	}

	// Merge into the existing options instead of replacing the whole map, so
	// options stored under other keys are not dropped.
	if out.TypedExtensionProtocolOptions == nil {
		out.TypedExtensionProtocolOptions = make(map[string]*anypb.Any, 1)
	}
	out.TypedExtensionProtocolOptions["envoy.lb"] = anyLbConfig
}
111 changes: 111 additions & 0 deletions internal/kgateway/krtcollections/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
corev1 "k8s.io/api/core/v1"
discoveryv1 "k8s.io/api/discovery/v1"
"k8s.io/apimachinery/pkg/types"
infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1"

"github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/settings"
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/ir"
Expand Down Expand Up @@ -297,3 +298,113 @@ func findPortInEndpointSlice(endpointSlice *discoveryv1.EndpointSlice, singlePor
}
return port
}

// InfPoolEndpointsInputs bundles everything NewInfPoolEndpoints needs to build
// the InferencePool endpoints collection.
type InfPoolEndpointsInputs struct {
	// Upstreams is the collection of upstreams to derive endpoints for.
	Upstreams krt.Collection[ir.Upstream]
	// Pods is the pod collection searched for pods matching a pool's selector.
	Pods krt.Collection[LocalityPod]
	// KrtOpts supplies the krt options used when creating the endpoints collection.
	KrtOpts krtutil.KrtOptions
}

// NewInfPoolEndpointsInputs packs the given krt options, upstream collection
// and pod collection into an InfPoolEndpointsInputs value.
func NewInfPoolEndpointsInputs(
	krtopts krtutil.KrtOptions,
	infPoolUpstreams krt.Collection[ir.Upstream],
	pods krt.Collection[LocalityPod],
) InfPoolEndpointsInputs {
	var in InfPoolEndpointsInputs
	in.KrtOpts = krtopts
	in.Upstreams = infPoolUpstreams
	in.Pods = pods
	return in
}

// NewInfPoolEndpoints creates the "InfPoolEndpoints" collection by applying
// the transformation from transformInfPoolEndpoints to every upstream in
// inputs.Upstreams.
func NewInfPoolEndpoints(ctx context.Context, inputs InfPoolEndpointsInputs) krt.Collection[ir.EndpointsForUpstream] {
	build := transformInfPoolEndpoints(ctx, inputs)
	opts := inputs.KrtOpts.ToOptions("InfPoolEndpoints")
	return krt.NewCollection(inputs.Upstreams, build, opts...)
}

// transformInfPoolEndpoints returns the krt transformation that builds the
// endpoints for a single InferencePool-backed upstream.
//
// The returned function yields nil when the upstream's object is not an
// InferencePool. Otherwise it fetches the pods in the pool's namespace whose
// AugmentedLabels satisfy the pool's selector and adds one endpoint per unique,
// non-empty pod IP on the pool's target port, grouped by pod locality. A pool
// without matching pods still produces an empty, non-nil result.
func transformInfPoolEndpoints(ctx context.Context, inputs InfPoolEndpointsInputs) func(kctx krt.HandlerContext, us ir.Upstream) *ir.EndpointsForUpstream {
	logger := contextutils.LoggerFrom(ctx).Desugar()

	return func(kctx krt.HandlerContext, us ir.Upstream) *ir.EndpointsForUpstream {
		infPool, ok := us.Obj.(*infextv1a1.InferencePool)
		if !ok {
			logger.Debug("not an InferencePool upstream")
			return nil
		}

		logger.Debug("building endpoints for inference pool", zap.String("pool", infPool.Name))

		// Convert `spec.selector` from its custom key/value types to `map[string]string`.
		labelSelector := convertSelector(infPool.Spec.Selector)

		// Use `FilterGeneric()` to match `LocalityPod` on namespace and AugmentedLabels.
		podMatches := krt.Fetch(kctx, inputs.Pods, krt.FilterGeneric(func(obj any) bool {
			pod, ok := obj.(LocalityPod)
			if !ok {
				return false
			}
			// The pod must live in the same namespace as the InferencePool.
			if pod.Namespace != infPool.Namespace {
				return false
			}
			// The pod labels must satisfy the InferencePool selector.
			return labelsMatch(labelSelector, pod.AugmentedLabels)
		}))

		// Always return a valid EndpointsForUpstream instance, even without matching pods.
		ret := ir.NewEndpointsForUpstream(us)

		if len(podMatches) == 0 {
			logger.Debug("no matching pods found for inference pool", zap.String("pool", infPool.Name))
			return ret // Empty but valid EndpointsForUpstream.
		}

		// Tracks the pod IPs already turned into endpoints so each address is added once.
		seenAddresses := make(map[string]struct{}, len(podMatches))

		for _, pod := range podMatches {
			// Skip pods that do not have an address yet.
			podIP := pod.IP()
			if podIP == "" {
				continue
			}

			// Skip addresses that were already added.
			if _, exists := seenAddresses[podIP]; exists {
				continue
			}
			seenAddresses[podIP] = struct{}{}

			// Create the Envoy LB endpoint on the pool's target port.
			ep := CreateLBEndpoint(podIP, uint32(infPool.Spec.TargetPortNumber), pod.AugmentedLabels, true)

			ret.Add(pod.Locality, ir.EndpointWithMd{
				LbEndpoint: ep,
				EndpointMd: ir.EndpointMetadata{
					Labels: pod.AugmentedLabels,
				},
			})
		}

		// Report the number of unique addresses actually added. ret.LbEps is
		// populated per locality, so its length would presumably count
		// localities rather than addresses.
		logger.Debug("created endpoints", zap.Int("numAddresses", len(seenAddresses)))
		return ret
	}
}

// convertSelector copies an InferencePool selector into a plain
// map[string]string. The result is always non-nil, even for a nil selector.
func convertSelector(selector map[infextv1a1.LabelKey]infextv1a1.LabelValue) map[string]string {
	converted := make(map[string]string, len(selector))
	for key := range selector {
		converted[string(key)] = string(selector[key])
	}
	return converted
}

// labelsMatch reports whether podLabels satisfies every key/value pair in
// selector.
//
// A selector key must be present in podLabels with exactly the same value. A
// missing key never matches, not even when the selector value is the empty
// string. An empty or nil selector matches any set of labels.
func labelsMatch(selector map[string]string, podLabels map[string]string) bool {
	for k, want := range selector {
		// Use the two-value lookup so an absent key is not mistaken for a
		// label whose value is "".
		got, ok := podLabels[k]
		if !ok || got != want {
			return false
		}
	}
	return true
}
Loading

0 comments on commit b981485

Please sign in to comment.