diff --git a/go.mod b/go.mod index 45cb63db096..66ba0295e2f 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 github.com/kelseyhightower/envconfig v1.4.0 github.com/mitchellh/hashstructure v1.0.0 - github.com/onsi/ginkgo/v2 v2.22.1 + github.com/onsi/ginkgo/v2 v2.22.2 github.com/onsi/gomega v1.36.2 github.com/pkg/errors v0.9.1 github.com/rotisserie/eris v0.5.4 @@ -52,9 +52,10 @@ require ( k8s.io/kube-openapi v0.0.0-20241212222426-2c72e554b1e7 k8s.io/utils v0.0.0-20241210054802-24370beab758 knative.dev/pkg v0.0.0-20211206113427-18589ac7627e - sigs.k8s.io/controller-runtime v0.20.0 + sigs.k8s.io/controller-runtime v0.20.2 sigs.k8s.io/controller-tools v0.16.5 sigs.k8s.io/gateway-api v1.2.1 + sigs.k8s.io/gateway-api-inference-extension v0.0.0-20250219213427-2577f63f6a1c sigs.k8s.io/structured-merge-diff/v4 v4.5.0 sigs.k8s.io/yaml v1.4.0 ) @@ -98,7 +99,7 @@ require ( github.com/emicklei/go-restful/v3 v3.12.1 // indirect github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect github.com/evanphx/json-patch v5.9.0+incompatible // indirect - github.com/evanphx/json-patch/v5 v5.9.0 // indirect + github.com/evanphx/json-patch/v5 v5.9.11 // indirect github.com/exponent-io/jsonpath v0.0.0-20210407135951-1de76d718b3f // indirect github.com/fatih/color v1.18.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect diff --git a/go.sum b/go.sum index 1f01996868b..ea9cecee8b3 100644 --- a/go.sum +++ b/go.sum @@ -284,8 +284,8 @@ github.com/evanphx/json-patch v4.9.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLi github.com/evanphx/json-patch v5.9.0+incompatible h1:fBXyNpNMuTTDdquAq/uisOr2lShz4oaXpDTX2bLe7ls= github.com/evanphx/json-patch v5.9.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQLiYLvXMP4fmwYFNcr97nuDLSk= github.com/evanphx/json-patch/v5 v5.6.0/go.mod h1:G79N1coSVB93tBe7j6PhzjmR3/2VvlbKOFpnXhI9Bw4= -github.com/evanphx/json-patch/v5 v5.9.0 h1:kcBlZQbplgElYIlo/n1hJbls2z/1awpXxpRi0/FOJfg= 
-github.com/evanphx/json-patch/v5 v5.9.0/go.mod h1:VNkHZ/282BpEyt/tObQO8s5CMPmYYq14uClGH4abBuQ= +github.com/evanphx/json-patch/v5 v5.9.11 h1:/8HVnzMq13/3x9TPvjG08wUGqBTmZBsCWzjTM0wiaDU= +github.com/evanphx/json-patch/v5 v5.9.11/go.mod h1:3j+LviiESTElxA4p3EMKAB9HXj3/XEtnUf6OZxqIQTM= github.com/exponent-io/jsonpath v0.0.0-20210407135951-1de76d718b3f h1:Wl78ApPPB2Wvf/TIe2xdyJxTlb6obmF18d8QdkxNDu4= github.com/exponent-io/jsonpath v0.0.0-20210407135951-1de76d718b3f/go.mod h1:OSYXu++VVOHnXeitef/D8n/6y4QV8uLHSFXX4NeXMGc= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= @@ -738,8 +738,8 @@ github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108 github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= -github.com/onsi/ginkgo/v2 v2.22.1 h1:QW7tbJAUDyVDVOM5dFa7qaybo+CRfR7bemlQUN6Z8aM= -github.com/onsi/ginkgo/v2 v2.22.1/go.mod h1:S6aTpoRsSq2cZOd+pssHAlKW/Q/jZt6cPrPlnj4a1xM= +github.com/onsi/ginkgo/v2 v2.22.2 h1:/3X8Panh8/WwhU/3Ssa6rCKqPLuAkVY2I0RoyDLySlU= +github.com/onsi/ginkgo/v2 v2.22.2/go.mod h1:oeMosUL+8LtarXBHu/c0bx2D/K9zyQ6uX3cTyztHwsk= github.com/onsi/gomega v0.0.0-20170829124025-dcabb60a477c/go.mod h1:C1qb7wdrVGGVU+Z6iS04AVkA3Q65CEZX59MT0QO5uiA= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.7.0/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= @@ -1627,12 +1627,14 @@ rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.0.22/go.mod h1:LEScyzhFmoF5pso/YSeBstl57mOzx9xlU9n85RGrDQg= sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.31.1 h1:uOuSLOMBWkJH0TWa9X6l+mj5nZdm6Ay6Bli8HL8rNfk= sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.31.1/go.mod 
h1:Ve9uj1L+deCXFrPOk1LpFXqTg7LCFzFso6PA48q/XZw= -sigs.k8s.io/controller-runtime v0.20.0 h1:jjkMo29xEXH+02Md9qaVXfEIaMESSpy3TBWPrsfQkQs= -sigs.k8s.io/controller-runtime v0.20.0/go.mod h1:BrP3w158MwvB3ZbNpaAcIKkHQ7YGpYnzpoSTZ8E14WU= +sigs.k8s.io/controller-runtime v0.20.2 h1:/439OZVxoEc02psi1h4QO3bHzTgu49bb347Xp4gW1pc= +sigs.k8s.io/controller-runtime v0.20.2/go.mod h1:xg2XB0K5ShQzAgsoujxuKN4LNXR2LfwwHsPj7Iaw+XY= sigs.k8s.io/controller-tools v0.16.5 h1:5k9FNRqziBPwqr17AMEPPV/En39ZBplLAdOwwQHruP4= sigs.k8s.io/controller-tools v0.16.5/go.mod h1:8vztuRVzs8IuuJqKqbXCSlXcw+lkAv/M2sTpg55qjMY= sigs.k8s.io/gateway-api v1.2.1 h1:fZZ/+RyRb+Y5tGkwxFKuYuSRQHu9dZtbjenblleOLHM= sigs.k8s.io/gateway-api v1.2.1/go.mod h1:EpNfEXNjiYfUJypf0eZ0P5iXA9ekSGWaS1WgPaM42X0= +sigs.k8s.io/gateway-api-inference-extension v0.0.0-20250219213427-2577f63f6a1c h1:YyTNvnfjzdiHXFQdRzouvQO9SKFwZkgQffnbr9YADFE= +sigs.k8s.io/gateway-api-inference-extension v0.0.0-20250219213427-2577f63f6a1c/go.mod h1:H2DbSVDbCxG2cNTTgYC+V3RiotW077Xkx3fA3mRAwXs= sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8 h1:gBQPwqORJ8d8/YNZWEjoZs7npUVDpVXUUOFfW6CgAqE= sigs.k8s.io/json v0.0.0-20241014173422-cfa47c3a1cc8/go.mod h1:mdzfpAEoE6DHQEN0uh9ZbOCuHbLK5wOm7dK4ctXE9Tg= sigs.k8s.io/kustomize/api v0.18.0 h1:hTzp67k+3NEVInwz5BHyzc9rGxIauoXferXyjv5lWPo= diff --git a/hack/utils/oss_compliance/osa_provided.md b/hack/utils/oss_compliance/osa_provided.md index f3c22b6977b..613ed00d114 100644 --- a/hack/utils/oss_compliance/osa_provided.md +++ b/hack/utils/oss_compliance/osa_provided.md @@ -18,7 +18,7 @@ Name|Version|License [grpc-ecosystem/go-grpc-middleware](https://github.com/grpc-ecosystem/go-grpc-middleware)|v1.4.0|Apache License 2.0 [kelseyhightower/envconfig](https://github.com/kelseyhightower/envconfig)|v1.4.0|MIT License [mitchellh/hashstructure](https://github.com/mitchellh/hashstructure)|v1.0.0|MIT License -[ginkgo/v2](https://github.com/onsi/ginkgo)|v2.22.1|MIT License +[ginkgo/v2](https://github.com/onsi/ginkgo)|v2.22.2|MIT 
License [onsi/gomega](https://github.com/onsi/gomega)|v1.36.2|MIT License [pkg/errors](https://github.com/pkg/errors)|v0.9.1|BSD 2-clause "Simplified" License [rotisserie/eris](https://github.com/rotisserie/eris)|v0.5.4|MIT License @@ -45,9 +45,10 @@ Name|Version|License [k8s.io/kube-openapi](https://k8s.io/kube-openapi)|v0.0.0-20241212222426-2c72e554b1e7|Apache License 2.0 [k8s.io/utils](https://k8s.io/utils)|v0.0.0-20241210054802-24370beab758|Apache License 2.0 [knative.dev/pkg](https://knative.dev/pkg)|v0.0.0-20211206113427-18589ac7627e|Apache License 2.0 -[sigs.k8s.io/controller-runtime](https://sigs.k8s.io/controller-runtime)|v0.20.0|Apache License 2.0 +[sigs.k8s.io/controller-runtime](https://sigs.k8s.io/controller-runtime)|v0.20.2|Apache License 2.0 [sigs.k8s.io/controller-tools](https://sigs.k8s.io/controller-tools)|v0.16.5|Apache License 2.0 [sigs.k8s.io/gateway-api](https://sigs.k8s.io/gateway-api)|v1.2.1|Apache License 2.0 +[sigs.k8s.io/gateway-api-inference-extension](https://sigs.k8s.io/gateway-api-inference-extension)|v0.0.0-20250219213427-2577f63f6a1c|Apache License 2.0 [structured-merge-diff/v4](https://sigs.k8s.io/structured-merge-diff/v4)|v4.5.0|Apache License 2.0 [sigs.k8s.io/yaml](https://sigs.k8s.io/yaml)|v1.4.0|MIT License [cmd/goimports](https://golang.org/x/tools/cmd/goimports)|latest|MIT License diff --git a/install/helm/kgateway/templates/role.yaml b/install/helm/kgateway/templates/role.yaml index 917af45d1bb..7b87de9d380 100644 --- a/install/helm/kgateway/templates/role.yaml +++ b/install/helm/kgateway/templates/role.yaml @@ -126,3 +126,47 @@ rules: - get - list - watch +- apiGroups: + - inference.networking.x-k8s.io + resources: + - inferencemodels + verbs: + - get + - list + - watch +- apiGroups: + - inference.networking.x-k8s.io + resources: + - inferencepools + verbs: + - get + - list + - watch + - update +- apiGroups: + - rbac.authorization.k8s.io + # TODO [danehans]: EPP should use Role and RoleBinding resources: 
https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/224 + resources: + - clusterroles + - clusterrolebindings + verbs: + - create + - delete + - get + - list + - patch + - update + - watch +# TODO [danehans]: Unsure why the following rules are needed: https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/224 +- apiGroups: + - authentication.k8s.io + resources: + - tokenreviews + verbs: + - create +- apiGroups: + - authorization.k8s.io + resources: + - subjectaccessreviews + verbs: + - create diff --git a/internal/kgateway/controller/controller.go b/internal/kgateway/controller/controller.go index d1c6a281547..d7d1e31546c 100644 --- a/internal/kgateway/controller/controller.go +++ b/internal/kgateway/controller/controller.go @@ -17,6 +17,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/manager" "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/reconcile" + infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" apiv1 "sigs.k8s.io/gateway-api/apis/v1" "github.com/kgateway-dev/kgateway/v2/api/v1alpha1" @@ -25,10 +26,13 @@ import ( ) const ( - // field name used for indexing + // GatewayParamsField is the field name used for indexing Gateway objects. GatewayParamsField = "gateway-params" + // InferencePoolField is the field name used for indexing HTTPRoute objects. + InferencePoolField = "inferencepool-index" ) +// TODO [danehans]: Refactor so controller config is organized into shared and Gateway/InferencePool-specific controllers. 
type GatewayConfig struct { Mgr manager.Manager @@ -45,7 +49,7 @@ type GatewayConfig struct { func NewBaseGatewayController(ctx context.Context, cfg GatewayConfig) error { log := log.FromContext(ctx) - log.V(5).Info("starting controller", "controllerName", cfg.ControllerName) + log.V(5).Info("starting gateway controller", "controllerName", cfg.ControllerName) controllerBuilder := &controllerBuilder{ cfg: cfg, @@ -62,6 +66,29 @@ func NewBaseGatewayController(ctx context.Context, cfg GatewayConfig) error { ) } +type InferencePoolConfig struct { + Mgr manager.Manager + ControllerName string + InferenceExt *deployer.InferenceExtInfo +} + +func NewBaseInferencePoolController(ctx context.Context, poolCfg *InferencePoolConfig, gwCfg *GatewayConfig) error { + log := log.FromContext(ctx) + log.V(5).Info("starting inferencepool controller", "controllerName", poolCfg.ControllerName) + + // TODO [danehans]: Make GatewayConfig optional since Gateway and InferencePool are independent controllers. + controllerBuilder := &controllerBuilder{ + cfg: *gwCfg, + poolCfg: poolCfg, + reconciler: &controllerReconciler{ + cli: poolCfg.Mgr.GetClient(), + scheme: poolCfg.Mgr.GetScheme(), + }, + } + + return run(ctx, controllerBuilder.watchInferencePool) +} + func run(ctx context.Context, funcs ...func(ctx context.Context) error) error { for _, f := range funcs { if err := f(ctx); err != nil { @@ -72,8 +99,8 @@ func run(ctx context.Context, funcs ...func(ctx context.Context) error) error { } type controllerBuilder struct { - cfg GatewayConfig - + cfg GatewayConfig + poolCfg *InferencePoolConfig reconciler *controllerReconciler } @@ -98,7 +125,7 @@ func (c *controllerBuilder) watchGw(ctx context.Context) error { // setup a deployer log := log.FromContext(ctx) - log.Info("creating deployer", "ctrlname", c.cfg.ControllerName, "server", c.cfg.ControlPlane.XdsHost, "port", c.cfg.ControlPlane.XdsPort) + log.Info("creating gateway deployer", "ctrlname", c.cfg.ControllerName, "server", 
c.cfg.ControlPlane.XdsHost, "port", c.cfg.ControlPlane.XdsPort) d, err := deployer.NewDeployer(c.cfg.Mgr.GetClient(), &deployer.Inputs{ ControllerName: c.cfg.ControllerName, Dev: c.cfg.Dev, @@ -181,6 +208,152 @@ func (c *controllerBuilder) watchGw(ctx context.Context) error { return nil } +func (c *controllerBuilder) addHTTPRouteIndexes(ctx context.Context) error { + return c.cfg.Mgr.GetFieldIndexer().IndexField(ctx, new(apiv1.HTTPRoute), InferencePoolField, httpRouteInferencePoolIndex) +} + +func httpRouteInferencePoolIndex(obj client.Object) []string { + route, ok := obj.(*apiv1.HTTPRoute) + if !ok { + // Should never happen, but return empty slice in case of unexpected type. + return nil + } + + var poolNames []string + for _, rule := range route.Spec.Rules { + for _, ref := range rule.BackendRefs { + if ref.Kind != nil && *ref.Kind == wellknown.InferencePoolKind { + poolNames = append(poolNames, string(ref.Name)) + } + } + } + return poolNames +} + +// watchInferencePool adds a watch on InferencePool and HTTPRoute objects (that reference an InferencePool) +// to trigger reconciliation. +func (c *controllerBuilder) watchInferencePool(ctx context.Context) error { + log := log.FromContext(ctx) + log.Info("creating inference extension deployer", "controller", c.cfg.ControllerName) + + // Register the HTTPRoute index. + if err := c.addHTTPRouteIndexes(ctx); err != nil { + return fmt.Errorf("failed to register HTTPRoute index: %w", err) + } + + // Create a deployer using the controllerBuilder as inputs. + d, err := deployer.NewDeployer(c.cfg.Mgr.GetClient(), &deployer.Inputs{ + ControllerName: c.cfg.ControllerName, + InferenceExtension: c.poolCfg.InferenceExt, + }) + if err != nil { + return err + } + + buildr := ctrl.NewControllerManagedBy(c.cfg.Mgr). + For(&infextv1a1.InferencePool{}, builder.WithPredicates( + predicate.Or( + predicate.AnnotationChangedPredicate{}, + predicate.GenerationChangedPredicate{}, + ), + )). 
+ // Watch HTTPRoute objects so that changes there trigger a reconcile for referenced InferencePools. + Watches(&apiv1.HTTPRoute{}, handler.EnqueueRequestsFromMapFunc(func(ctx context.Context, obj client.Object) []reconcile.Request { + route, ok := obj.(*apiv1.HTTPRoute) + if !ok { + return nil + } + + // Use the index function to get the inference pool names. + poolNames := httpRouteInferencePoolIndex(route) + if len(poolNames) == 0 { + return nil + } + + hasOurGateway := false + for _, parentRef := range route.Spec.ParentRefs { + // We only care about references to Gateways. + if parentRef.Group != nil && string(*parentRef.Group) == apiv1.GroupName && + parentRef.Kind != nil && *parentRef.Kind == wellknown.GatewayKind { + + // Determine the namespace of the Gateway. If parentRef.Namespace is nil/empty, + // it defaults to the route's namespace. + gwNamespace := route.Namespace + if parentRef.Namespace != nil && *parentRef.Namespace != "" { + gwNamespace = string(*parentRef.Namespace) + } + gwName := string(parentRef.Name) + + // Fetch the Gateway + var gw apiv1.Gateway + if err := c.cfg.Mgr.GetClient().Get(ctx, client.ObjectKey{ + Namespace: gwNamespace, + Name: gwName, + }, &gw); err != nil { + // If we cannot get it, skip this parentRef + continue + } + + // Check if the Gateway is recognized as "ours" + if c.cfg.OurGateway(&gw) { + hasOurGateway = true + break + } + } + } + if !hasOurGateway { + // If no parentRef references one of our Gateways, skip it. + return nil + } + + // The HTTPRoute references an InferencePool and one of our Gateways. + // Enqueue each referenced InferencePool for reconciliation. + var reqs []reconcile.Request + for _, poolName := range poolNames { + reqs = append(reqs, reconcile.Request{ + NamespacedName: client.ObjectKey{ + Namespace: route.Namespace, + Name: poolName, + }, + }) + } + return reqs + })) + + // Watch child objects, e.g. Deployments, created by the inference pool deployer. 
+ gvks, err := d.GetGvksToWatch(ctx) + if err != nil { + return err + } + for _, gvk := range gvks { + obj, err := c.cfg.Mgr.GetScheme().New(gvk) + if err != nil { + return err + } + clientObj, ok := obj.(client.Object) + if !ok { + return fmt.Errorf("object %T is not a client.Object", obj) + } + log.Info("watching gvk as inferencepool child", "gvk", gvk) + var opts []builder.OwnsOption + if shouldIgnoreStatusChild(gvk) { + opts = append(opts, builder.WithPredicates(predicate.GenerationChangedPredicate{})) + } + buildr.Owns(clientObj, opts...) + } + + r := &inferencePoolReconciler{ + cli: c.cfg.Mgr.GetClient(), + scheme: c.cfg.Mgr.GetScheme(), + deployer: d, + } + if err := buildr.Complete(r); err != nil { + return err + } + + return nil +} + func shouldIgnoreStatusChild(gvk schema.GroupVersionKind) bool { // avoid triggering on pod changes that update deployment status return gvk.Kind == "Deployment" diff --git a/internal/kgateway/controller/controller_suite_test.go b/internal/kgateway/controller/controller_suite_test.go index 8a0977b8166..502b4429093 100644 --- a/internal/kgateway/controller/controller_suite_test.go +++ b/internal/kgateway/controller/controller_suite_test.go @@ -15,7 +15,9 @@ import ( . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" + appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/client-go/rest" @@ -28,11 +30,12 @@ import ( "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/controller-runtime/pkg/log/zap" "sigs.k8s.io/controller-runtime/pkg/webhook" - api "sigs.k8s.io/gateway-api/apis/v1" + infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" apiv1 "sigs.k8s.io/gateway-api/apis/v1" "github.com/kgateway-dev/kgateway/v2/api/v1alpha1" "github.com/kgateway-dev/kgateway/v2/internal/kgateway/controller" + "github.com/kgateway-dev/kgateway/v2/internal/kgateway/deployer" "github.com/kgateway-dev/kgateway/v2/internal/kgateway/wellknown" ) @@ -51,7 +54,7 @@ var ( const ( gatewayClassName = "clsname" altGatewayClassName = "clsname-alt" - gatewayControllerName = "controller/name" + gatewayControllerName = "kgateway.dev/kgateway" ) func getAssetsDir() string { @@ -72,6 +75,14 @@ var _ = BeforeSuite(func() { ctx, cancel = context.WithCancel(context.TODO()) By("bootstrapping test environment") + // Create a scheme and add both Gateway and InferencePool types. + scheme := schemes.GatewayScheme() + err := infextv1a1.AddToScheme(scheme) + Expect(err).NotTo(HaveOccurred()) + // Required to deploy endpoint picker RBAC resources. 
+ err = rbacv1.AddToScheme(scheme) + Expect(err).NotTo(HaveOccurred()) + testEnv = &envtest.Environment{ CRDDirectoryPaths: []string{ filepath.Join("..", "crds"), @@ -80,14 +91,12 @@ var _ = BeforeSuite(func() { ErrorIfCRDPathMissing: true, // set assets dir so we can run without the makefile BinaryAssetsDirectory: getAssetsDir(), - // web hook to add cluster ips to services - } - var err error - cfg, err = testEnv.Start() - Expect(err).NotTo(HaveOccurred()) + var err2 error + cfg, err2 = testEnv.Start() + Expect(err2).NotTo(HaveOccurred()) Expect(cfg).NotTo(BeNil()) - scheme := schemes.GatewayScheme() + k8sClient, err = client.New(cfg, client.Options{Scheme: scheme}) Expect(err).NotTo(HaveOccurred()) Expect(k8sClient).NotTo(BeNil()) @@ -114,9 +123,8 @@ var _ = BeforeSuite(func() { kubeconfig = generateKubeConfiguration(cfg) mgr.GetLogger().Info("starting manager", "kubeconfig", kubeconfig) - Expect(err).ToNot(HaveOccurred()) - - cfg := controller.GatewayConfig{ + // Start the Gateway controller. + gwCfg := controller.GatewayConfig{ Mgr: mgr, ControllerName: gatewayControllerName, OurGateway: func(gw *apiv1.Gateway) bool { @@ -124,27 +132,19 @@ var _ = BeforeSuite(func() { }, AutoProvision: true, } - err = controller.NewBaseGatewayController(ctx, cfg) + err = controller.NewBaseGatewayController(ctx, gwCfg) Expect(err).ToNot(HaveOccurred()) - for class := range gwClasses { - err = k8sClient.Create(ctx, &api.GatewayClass{ - ObjectMeta: metav1.ObjectMeta{ - Name: class, - }, - Spec: api.GatewayClassSpec{ - ControllerName: api.GatewayController(gatewayControllerName), - ParametersRef: &api.ParametersReference{ - Group: api.Group(v1alpha1.GroupVersion.Group), - Kind: api.Kind("GatewayParameters"), - Name: wellknown.DefaultGatewayParametersName, - Namespace: ptr.To(api.Namespace("default")), - }, - }, - }) - Expect(err).NotTo(HaveOccurred()) + // Start the inference pool controller. 
+ poolCfg := &controller.InferencePoolConfig{ + Mgr: mgr, + ControllerName: wellknown.GatewayControllerName, + InferenceExt: new(deployer.InferenceExtInfo), } + err = controller.NewBaseInferencePoolController(ctx, poolCfg, &gwCfg) + Expect(err).ToNot(HaveOccurred()) + // Create the default GatewayParameters and GatewayClass. err = k8sClient.Create(ctx, &v1alpha1.GatewayParameters{ ObjectMeta: metav1.ObjectMeta{ Name: wellknown.DefaultGatewayParametersName, @@ -161,6 +161,25 @@ var _ = BeforeSuite(func() { }) Expect(err).NotTo(HaveOccurred()) + for class := range gwClasses { + err = k8sClient.Create(ctx, &apiv1.GatewayClass{ + ObjectMeta: metav1.ObjectMeta{ + Name: class, + }, + Spec: apiv1.GatewayClassSpec{ + ControllerName: apiv1.GatewayController(gatewayControllerName), + ParametersRef: &apiv1.ParametersReference{ + Group: apiv1.Group(v1alpha1.GroupVersion.Group), + Kind: "GatewayParameters", + Name: wellknown.DefaultGatewayParametersName, + Namespace: ptr.To(apiv1.Namespace("default")), + }, + }, + }) + Expect(err).NotTo(HaveOccurred()) + } + + // Start the manager. go func() { defer GinkgoRecover() err = mgr.Start(ctx) @@ -204,11 +223,10 @@ func generateKubeConfiguration(restconfig *rest.Config) string { } clientConfig := clientcmdapi.Config{ - Kind: "Config", - APIVersion: "v1", - Clusters: clusters, - Contexts: contexts, - // current context must be mgmt cluster for now, as the api server doesn't have context configurable. + Kind: "Config", + APIVersion: "v1", + Clusters: clusters, + Contexts: contexts, CurrentContext: "cluster", AuthInfos: authinfos, } @@ -220,3 +238,143 @@ func generateKubeConfiguration(restconfig *rest.Config) string { Expect(err).NotTo(HaveOccurred()) return tmpfile.Name() } + +var _ = Describe("InferencePool controller", func() { + const defaultNamespace = "default" + + It("should reconcile an InferencePool referenced by an HTTPRoute managed by our controller", func() { + // Create a test Gateway that will be referenced by the HTTPRoute. 
+ testGw := &apiv1.Gateway{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-gateway", + Namespace: defaultNamespace, + }, + Spec: apiv1.GatewaySpec{ + GatewayClassName: gatewayClassName, + Listeners: []apiv1.Listener{ + { + Name: "listener-1", + Protocol: apiv1.HTTPProtocolType, + Port: 80, + }, + }, + }, + } + err := k8sClient.Create(ctx, testGw) + Expect(err).NotTo(HaveOccurred()) + + // Create an HTTPRoute without a status. + httpRoute := &apiv1.HTTPRoute{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-route", + Namespace: defaultNamespace, + }, + Spec: apiv1.HTTPRouteSpec{ + Rules: []apiv1.HTTPRouteRule{ + { + BackendRefs: []apiv1.HTTPBackendRef{ + { + BackendRef: apiv1.BackendRef{ + BackendObjectReference: apiv1.BackendObjectReference{ + Group: ptr.To(apiv1.Group(infextv1a1.GroupVersion.Group)), + Kind: ptr.To(apiv1.Kind("InferencePool")), + Name: "pool1", + }, + }, + }, + }, + }, + }, + }, + } + err = k8sClient.Create(ctx, httpRoute) + Expect(err).NotTo(HaveOccurred()) + + // Now update the status to include a valid Parents field. + httpRoute.Status = apiv1.HTTPRouteStatus{ + RouteStatus: apiv1.RouteStatus{ + Parents: []apiv1.RouteParentStatus{ + { + ParentRef: apiv1.ParentReference{ + Group: ptr.To(apiv1.Group("gateway.networking.k8s.io")), + Kind: ptr.To(apiv1.Kind("Gateway")), + Name: apiv1.ObjectName(testGw.Name), + Namespace: ptr.To(apiv1.Namespace(defaultNamespace)), + }, + ControllerName: gatewayControllerName, + }, + }, + }, + } + err = k8sClient.Status().Update(ctx, httpRoute) + Expect(err).NotTo(HaveOccurred()) + + // Create an InferencePool resource that is referenced by the HTTPRoute. 
+ pool := &infextv1a1.InferencePool{ + TypeMeta: metav1.TypeMeta{ + Kind: "InferencePool", + APIVersion: infextv1a1.GroupVersion.String(), + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "pool1", + Namespace: defaultNamespace, + UID: "pool-uid", + }, + Spec: infextv1a1.InferencePoolSpec{ + Selector: map[infextv1a1.LabelKey]infextv1a1.LabelValue{}, + TargetPortNumber: 1234, + EndpointPickerConfig: infextv1a1.EndpointPickerConfig{ + ExtensionRef: &infextv1a1.Extension{ + ExtensionReference: infextv1a1.ExtensionReference{ + Name: "doesnt-matter", + }, + }, + }, + }, + } + err = k8sClient.Create(ctx, pool) + Expect(err).NotTo(HaveOccurred()) + + // The secondary watch on HTTPRoute should now trigger reconciliation of pool "pool1". + // We expect the deployer to render and deploy an endpoint picker Deployment with name "pool1-endpoint-picker". + expectedName := fmt.Sprintf("%s-endpoint-picker", pool.Name) + var deploy appsv1.Deployment + Eventually(func() error { + return k8sClient.Get(ctx, client.ObjectKey{Namespace: defaultNamespace, Name: expectedName}, &deploy) + }, "10s", "1s").Should(Succeed()) + }) + + It("should ignore an InferencePool not referenced by any HTTPRoute", func() { + // Create an InferencePool that is not referenced by any HTTPRoute. + pool := &infextv1a1.InferencePool{ + TypeMeta: metav1.TypeMeta{ + Kind: "InferencePool", + APIVersion: infextv1a1.GroupVersion.String(), + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "pool2", + Namespace: defaultNamespace, + UID: "pool2-uid", + }, + Spec: infextv1a1.InferencePoolSpec{ + Selector: map[infextv1a1.LabelKey]infextv1a1.LabelValue{}, + TargetPortNumber: 1234, + EndpointPickerConfig: infextv1a1.EndpointPickerConfig{ + ExtensionRef: &infextv1a1.Extension{ + ExtensionReference: infextv1a1.ExtensionReference{ + Name: "doesnt-matter", + }, + }, + }, + }, + } + err := k8sClient.Create(ctx, pool) + Expect(err).NotTo(HaveOccurred()) + + // Consistently check that no endpoint picker deployment is created. 
+ Consistently(func() error { + var dep appsv1.Deployment + return k8sClient.Get(ctx, client.ObjectKey{Namespace: defaultNamespace, Name: fmt.Sprintf("%s-endpoint-picker", pool.Name)}, &dep) + }, "5s", "1s").ShouldNot(Succeed()) + }) +}) diff --git a/internal/kgateway/controller/inferencepool_controller.go b/internal/kgateway/controller/inferencepool_controller.go new file mode 100644 index 00000000000..6858e86c32c --- /dev/null +++ b/internal/kgateway/controller/inferencepool_controller.go @@ -0,0 +1,88 @@ +package controller + +import ( + "context" + "slices" + + "k8s.io/apimachinery/pkg/runtime" + ctrl "sigs.k8s.io/controller-runtime" + "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/log" + infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" + gwv1 "sigs.k8s.io/gateway-api/apis/v1" + + "github.com/kgateway-dev/kgateway/v2/internal/kgateway/deployer" + "github.com/kgateway-dev/kgateway/v2/internal/kgateway/wellknown" +) + +type inferencePoolReconciler struct { + cli client.Client + scheme *runtime.Scheme + deployer *deployer.Deployer +} + +func (r *inferencePoolReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { + log := log.FromContext(ctx).WithValues("inferencepool", req.NamespacedName) + log.V(1).Info("reconciling request", "request", req) + + pool := new(infextv1a1.InferencePool) + if err := r.cli.Get(ctx, req.NamespacedName, pool); err != nil { + return ctrl.Result{}, client.IgnoreNotFound(err) + } + + if pool.GetDeletionTimestamp() != nil { + log.Info("Removing endpoint picker for InferencePool", "name", pool.Name, "namespace", pool.Namespace) + // TODO [danehans]: EPP should use role and rolebinding RBAC: https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/224 + if err := r.deployer.CleanupClusterScopedResources(ctx, pool); err != nil { + return ctrl.Result{}, err + } + // Remove the finalizer. 
+ pool.Finalizers = slices.DeleteFunc(pool.Finalizers, func(s string) bool { + return s == wellknown.InferencePoolFinalizer + }) + + if err := r.cli.Update(ctx, pool); err != nil { + return ctrl.Result{}, err + } + return ctrl.Result{}, nil + } + + // Ensure the finalizer is present for the InferencePool. + if err := r.deployer.EnsureFinalizer(ctx, pool); err != nil { + return ctrl.Result{}, err + } + + // Use the registered index to list HTTPRoutes that reference this pool. + var routeList gwv1.HTTPRouteList + if err := r.cli.List(ctx, &routeList, + client.InNamespace(pool.Namespace), + client.MatchingFields{InferencePoolField: pool.Name}, + ); err != nil { + log.Error(err, "failed to list HTTPRoutes referencing InferencePool", "name", pool.Name, "namespace", pool.Namespace) + return ctrl.Result{}, err + } + + // If no HTTPRoutes reference the pool, skip reconciliation. + if len(routeList.Items) == 0 { + log.Info("No HTTPRoutes reference this InferencePool; skipping reconcile", "pool", pool.Name) + return ctrl.Result{}, nil + } + + objs, err := r.deployer.GetEndpointPickerObjs(pool) + if err != nil { + return ctrl.Result{}, err + } + + // TODO [danehans]: Manage inferencepool status conditions. + + // Deploy the endpoint picker resources. 
+ log.Info("Deploying endpoint picker for InferencePool", "name", pool.Name, "namespace", pool.Namespace) + err = r.deployer.DeployObjs(ctx, objs) + if err != nil { + return ctrl.Result{}, err + } + + log.V(1).Info("reconciled request", "request", req) + + return ctrl.Result{}, nil +} diff --git a/internal/kgateway/controller/start.go b/internal/kgateway/controller/start.go index 424ffa8004c..4eec0e1090f 100644 --- a/internal/kgateway/controller/start.go +++ b/internal/kgateway/controller/start.go @@ -18,12 +18,14 @@ import ( "sigs.k8s.io/controller-runtime/pkg/healthz" czap "sigs.k8s.io/controller-runtime/pkg/log/zap" metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" + infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" apiv1 "sigs.k8s.io/gateway-api/apis/v1" "github.com/kgateway-dev/kgateway/v2/internal/kgateway/deployer" "github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2" "github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/common" extensionsplug "github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/plugin" + "github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker" "github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/registry" "github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/settings" "github.com/kgateway-dev/kgateway/v2/internal/kgateway/ir" @@ -32,7 +34,7 @@ import ( "github.com/kgateway-dev/kgateway/v2/internal/kgateway/utils/krtutil" "github.com/kgateway-dev/kgateway/v2/internal/kgateway/wellknown" "github.com/kgateway-dev/kgateway/v2/pkg/client/clientset/versioned" - glooschemes "github.com/kgateway-dev/kgateway/v2/pkg/schemes" + kgtwschemes "github.com/kgateway-dev/kgateway/v2/pkg/schemes" "github.com/kgateway-dev/kgateway/v2/pkg/utils/kubeutils" "github.com/kgateway-dev/kgateway/v2/pkg/utils/namespaces" ) @@ -100,7 +102,7 @@ func NewControllerBuilder(ctx context.Context, cfg StartConfig) 
(*ControllerBuil scheme := DefaultScheme() // Extend the scheme if the TCPRoute CRD exists. - if err := glooschemes.AddGatewayV1A2Scheme(cfg.RestConfig, scheme); err != nil { + if err := kgtwschemes.AddGatewayV1A2Scheme(cfg.RestConfig, scheme); err != nil { return nil, err } @@ -142,6 +144,22 @@ func NewControllerBuilder(ctx context.Context, cfg StartConfig) (*ControllerBuil setupLog, *cfg.SetupOpts.GlobalSettings, ) + + // Extend the scheme and add the EPP plugin if the InferencePool CRD exists. + exists, err := kgtwschemes.AddInferExtV1A1Scheme(cfg.RestConfig, scheme) + setupLog.Info("checking inference extension CRDs exist", "result", exists) + + switch { + case err != nil: + return nil, err + case exists: + setupLog.Info("adding inference extension endpoint picker plugin") + if cfg.ExtraPlugins == nil { + cfg.ExtraPlugins = []extensionsplug.Plugin{} + } + cfg.ExtraPlugins = append(cfg.ExtraPlugins, endpointpicker.NewPlugin(ctx, commoncol)) + } + gwClasses := sets.New(append(cfg.SetupOpts.ExtraGatewayClasses, wellknown.GatewayClassName)...) isOurGw := func(gw *apiv1.Gateway) bool { return gwClasses.Has(string(gw.Spec.GatewayClassName)) @@ -229,9 +247,22 @@ func (c *ControllerBuilder) Start(ctx context.Context) error { } if err := NewBaseGatewayController(ctx, gwCfg); err != nil { - setupLog.Error(err, "unable to create controller") + setupLog.Error(err, "unable to create gateway controller") return err } + // Create the InferencePool controller if the inference extension API group is registered. 
+ if c.mgr.GetScheme().IsGroupRegistered(infextv1a1.GroupVersion.Group) { + poolCfg := &InferencePoolConfig{ + Mgr: c.mgr, + ControllerName: wellknown.GatewayControllerName, + InferenceExt: new(deployer.InferenceExtInfo), + } + if err := NewBaseInferencePoolController(ctx, poolCfg, &gwCfg); err != nil { + setupLog.Error(err, "unable to create inferencepool controller") + return err + } + } + return c.mgr.Start(ctx) } diff --git a/internal/kgateway/crds/inferencepools.yaml b/internal/kgateway/crds/inferencepools.yaml new file mode 100644 index 00000000000..9e6473b9e20 --- /dev/null +++ b/internal/kgateway/crds/inferencepools.yaml @@ -0,0 +1,206 @@ +--- +apiVersion: apiextensions.k8s.io/v1 +kind: CustomResourceDefinition +metadata: + annotations: + controller-gen.kubebuilder.io/version: v0.16.1 + name: inferencepools.inference.networking.x-k8s.io +spec: + group: inference.networking.x-k8s.io + names: + kind: InferencePool + listKind: InferencePoolList + plural: inferencepools + singular: inferencepool + scope: Namespaced + versions: + - name: v1alpha1 + schema: + openAPIV3Schema: + description: InferencePool is the Schema for the InferencePools API. + properties: + apiVersion: + description: |- + APIVersion defines the versioned schema of this representation of an object. + Servers should convert recognized schemas to the latest internal value, and + may reject unrecognized values. + More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#resources + type: string + kind: + description: |- + Kind is a string value representing the REST resource this object represents. + Servers may infer this from the endpoint the client submits requests to. + Cannot be updated. + In CamelCase. 
+ More info: https://git.k8s.io/community/contributors/devel/sig-architecture/api-conventions.md#types-kinds + type: string + metadata: + type: object + spec: + description: InferencePoolSpec defines the desired state of InferencePool + properties: + extensionRef: + description: Extension configures an endpoint picker as an extension + service. + properties: + failureMode: + default: FailClose + description: |- + Configures how the gateway handles the case when the extension is not responsive. + Defaults to failClose. + enum: + - FailOpen + - FailClose + type: string + group: + default: "" + description: |- + Group is the group of the referent. + When unspecified or empty string, core API group is inferred. + type: string + kind: + default: Service + description: |- + Kind is the Kubernetes resource kind of the referent. For example + "Service". + + Defaults to "Service" when not specified. + + ExternalName services can refer to CNAME DNS records that may live + outside of the cluster and as such are difficult to reason about in + terms of conformance. They also may not be safe to forward to (see + CVE-2021-25740 for more information). Implementations MUST NOT + support ExternalName Services. + type: string + name: + description: Name is the name of the referent. + type: string + targetPortNumber: + description: |- + The port number on the pods running the extension. When unspecified, implementations SHOULD infer a + default value of 9002 when the Kind is Service. + format: int32 + maximum: 65535 + minimum: 1 + type: integer + required: + - name + type: object + selector: + additionalProperties: + description: |- + LabelValue is the value of a label. This is used for validation + of maps. This matches the Kubernetes label validation rules: + * must be 63 characters or less (can be empty), + * unless empty, must begin and end with an alphanumeric character ([a-z0-9A-Z]), + * could contain dashes (-), underscores (_), dots (.), and alphanumerics between. 
+ + Valid values include: + + * MyValue + * my.name + * 123-my-value + maxLength: 63 + minLength: 0 + pattern: ^(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])?$ + type: string + description: |- + Selector defines a map of labels to watch model server pods + that should be included in the InferencePool. + In some cases, implementations may translate this field to a Service selector, so this matches the simple + map used for Service selectors instead of the full Kubernetes LabelSelector type. + type: object + targetPortNumber: + description: |- + TargetPortNumber defines the port number to access the selected model servers. + The number must be in the range 1 to 65535. + format: int32 + maximum: 65535 + minimum: 1 + type: integer + required: + - extensionRef + - selector + - targetPortNumber + type: object + status: + description: InferencePoolStatus defines the observed state of InferencePool + properties: + conditions: + default: + - lastTransitionTime: "1970-01-01T00:00:00Z" + message: Waiting for controller + reason: Pending + status: Unknown + type: Ready + description: |- + Conditions track the state of the InferencePool. + + Known condition types are: + + * "Ready" + items: + description: Condition contains details for one aspect of the current + state of this API Resource. + properties: + lastTransitionTime: + description: |- + lastTransitionTime is the last time the condition transitioned from one status to another. + This should be when the underlying condition changed. If that is not known, then using the time when the API field changed is acceptable. + format: date-time + type: string + message: + description: |- + message is a human readable message indicating details about the transition. + This may be an empty string. + maxLength: 32768 + type: string + observedGeneration: + description: |- + observedGeneration represents the .metadata.generation that the condition was set based upon. 
+ For instance, if .metadata.generation is currently 12, but the .status.conditions[x].observedGeneration is 9, the condition is out of date + with respect to the current state of the instance. + format: int64 + minimum: 0 + type: integer + reason: + description: |- + reason contains a programmatic identifier indicating the reason for the condition's last transition. + Producers of specific condition types may define expected values and meanings for this field, + and whether the values are considered a guaranteed API. + The value should be a CamelCase string. + This field may not be empty. + maxLength: 1024 + minLength: 1 + pattern: ^[A-Za-z]([A-Za-z0-9_,:]*[A-Za-z0-9_])?$ + type: string + status: + description: status of the condition, one of True, False, Unknown. + enum: + - "True" + - "False" + - Unknown + type: string + type: + description: type of condition in CamelCase or in foo.example.com/CamelCase. + maxLength: 316 + pattern: ^([a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*/)?(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])$ + type: string + required: + - lastTransitionTime + - message + - reason + - status + - type + type: object + maxItems: 8 + type: array + x-kubernetes-list-map-keys: + - type + x-kubernetes-list-type: map + type: object + type: object + served: true + storage: true + subresources: + status: {} diff --git a/internal/kgateway/deployer/deployer.go b/internal/kgateway/deployer/deployer.go index 3daa6c4184e..7c95a36de6e 100644 --- a/internal/kgateway/deployer/deployer.go +++ b/internal/kgateway/deployer/deployer.go @@ -8,15 +8,16 @@ import ( "io" "io/fs" "path/filepath" + "slices" "github.com/rotisserie/eris" - "golang.org/x/exp/slices" "helm.sh/helm/v3/pkg/action" "helm.sh/helm/v3/pkg/chart" "helm.sh/helm/v3/pkg/chart/loader" "helm.sh/helm/v3/pkg/storage" "helm.sh/helm/v3/pkg/storage/driver" corev1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 
"k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" "k8s.io/apimachinery/pkg/runtime" @@ -25,6 +26,7 @@ import ( "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/log" + infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" api "sigs.k8s.io/gateway-api/apis/v1" "github.com/kgateway-dev/kgateway/v2/api/v1alpha1" @@ -62,25 +64,39 @@ type AwsInfo struct { StsUri string } -// Inputs is the set of options used to configure the gateway deployer deployment +// InferenceExtInfo defines the runtime state of Gateway API inference extensions. +type InferenceExtInfo struct{} + +// Inputs is the set of options used to configure the deployer deployment type Inputs struct { ControllerName string Dev bool IstioIntegrationEnabled bool ControlPlane ControlPlaneInfo Aws *AwsInfo + InferenceExtension *InferenceExtInfo } -// NewDeployer creates a new gateway deployer +// NewDeployer creates a new gateway deployer. +// TODO [danehans]: Reloading the chart for every reconciliation is inefficient. +// See https://github.com/kgateway-dev/kgateway/issues/10672 for details. func NewDeployer(cli client.Client, inputs *Inputs) (*Deployer, error) { if inputs == nil { return nil, NilDeployerInputsErr } - helmChart, err := loadFs(helm.KgatewayHelmChart) - if err != nil { - return nil, err + var err error + helmChart := new(chart.Chart) + if inputs.InferenceExtension == nil { + if helmChart, err = loadFs(helm.KgatewayHelmChart); err != nil { + return nil, err + } + } else { + if helmChart, err = loadFs(helm.InferenceExtensionHelmChart); err != nil { + return nil, err + } } + // simulate what `helm package` in the Makefile does if version.Version != version.UndefinedVersion { helmChart.Metadata.AppVersion = version.Version @@ -109,12 +125,7 @@ func (d *Deployer) GetGvksToWatch(ctx context.Context) ([]schema.GroupVersionKin // _slightly_ more dynamic way of getting the GVKs. 
It isn't a perfect solution since if // we add more resources to the helm chart that are gated by a flag, we may forget to // update the values here to enable them. - emptyGw := &api.Gateway{ - ObjectMeta: metav1.ObjectMeta{ - Name: "default", - Namespace: "default", - }, - } + // TODO(Law): these must be set explicitly as we don't have defaults for them // and the internal template isn't robust enough. // This should be empty eventually -- the template must be resilient against nil-pointers @@ -128,7 +139,16 @@ func (d *Deployer) GetGvksToWatch(ctx context.Context) ([]schema.GroupVersionKin }, } - objs, err := d.renderChartToObjects(emptyGw, vals) + if d.inputs.InferenceExtension != nil { + vals = map[string]any{ + "inferenceExtension": map[string]any{ + "endpointPicker": map[string]any{}, + }, + } + } + + // The namespace and name do not matter since we only care about the GVKs of the rendered resources. + objs, err := d.renderChartToObjects("default", "default", vals) if err != nil { return nil, err } @@ -152,14 +172,14 @@ func jsonConvert(in *helmConfig, out interface{}) error { return json.Unmarshal(b, out) } -func (d *Deployer) renderChartToObjects(gw *api.Gateway, vals map[string]any) ([]client.Object, error) { - objs, err := d.Render(gw.Name, gw.Namespace, vals) +func (d *Deployer) renderChartToObjects(ns, name string, vals map[string]any) ([]client.Object, error) { + objs, err := d.Render(name, ns, vals) if err != nil { return nil, err } for _, obj := range objs { - obj.SetNamespace(gw.Namespace) + obj.SetNamespace(ns) } return objs, nil @@ -365,6 +385,24 @@ func (d *Deployer) getValues(gw *api.Gateway, gwParam *v1alpha1.GatewayParameter return vals, nil } +func (d *Deployer) getInferExtVals(pool *infextv1a1.InferencePool) (*helmConfig, error) { + if d.inputs.InferenceExtension == nil { + return nil, fmt.Errorf("inference extension input not defined for deployer") + } + + // construct the default values + vals := &helmConfig{ + InferenceExtension: 
&helmInferenceExtension{ + EndpointPicker: &helmEndpointPickerExtension{ + PoolName: pool.Name, + PoolNamespace: pool.Namespace, + }, + }, + } + + return vals, nil +} + // Render relies on a `helm install` to render the Chart with the injected values // It returns the list of Objects that are rendered, and an optional error if rendering failed, // or converting the rendered manifests to objects failed. @@ -384,14 +422,19 @@ func (d *Deployer) Render(name, ns string, vals map[string]any) ([]client.Object install.ClientOnly = true installCtx := context.Background() + chartType := "gateway" + if d.inputs.InferenceExtension != nil { + chartType = "inference extension" + } + release, err := install.RunWithContext(installCtx, d.chart, vals) if err != nil { - return nil, fmt.Errorf("failed to render helm chart for gateway %s.%s: %w", ns, name, err) + return nil, fmt.Errorf("failed to render helm chart for %s %s.%s: %w", chartType, ns, name, err) } objs, err := ConvertYAMLToObjects(d.cli.Scheme(), []byte(release.Manifest)) if err != nil { - return nil, fmt.Errorf("failed to convert helm manifest yaml to objects for gateway %s.%s: %w", ns, name, err) + return nil, fmt.Errorf("failed to convert helm manifest yaml to objects for %s %s.%s: %w", chartType, ns, name, err) } return objs, nil } @@ -432,7 +475,7 @@ func (d *Deployer) GetObjsToDeploy(ctx context.Context, gw *api.Gateway) ([]clie if err != nil { return nil, fmt.Errorf("failed to convert helm values for gateway %s.%s: %w", gw.GetNamespace(), gw.GetName(), err) } - objs, err := d.renderChartToObjects(gw, convertedVals) + objs, err := d.renderChartToObjects(gw.Namespace, gw.Name, convertedVals) if err != nil { return nil, fmt.Errorf("failed to get objects to deploy for gateway %s.%s: %w", gw.GetNamespace(), gw.GetName(), err) } @@ -451,6 +494,54 @@ func (d *Deployer) GetObjsToDeploy(ctx context.Context, gw *api.Gateway) ([]clie return objs, nil } +// GetEndpointPickerObjs renders endpoint picker objects using the helm 
chart. +// It builds helm values from the InferencePool's name and namespace and +// renders the chart under the release name "<pool name>-endpoint-picker". +func (d *Deployer) GetEndpointPickerObjs(pool *infextv1a1.InferencePool) ([]client.Object, error) { + // Build the helm values for the inference extension. + vals, err := d.getInferExtVals(pool) + if err != nil { + return nil, err + } + + // Convert the helm values struct. + var convertedVals map[string]any + if err := jsonConvert(vals, &convertedVals); err != nil { + return nil, fmt.Errorf("failed to convert inference extension helm values: %w", err) + } + + // Use a unique release name for the endpoint picker child objects. + releaseName := fmt.Sprintf("%s-endpoint-picker", pool.Name) + objs, err := d.Render(releaseName, pool.Namespace, convertedVals) + if err != nil { + return nil, fmt.Errorf("failed to render inference extension objects: %w", err) + } + + // Ensure that each namespaced rendered object has its namespace and ownerRef set. + for _, obj := range objs { + gvk := obj.GetObjectKind().GroupVersionKind() + if IsNamespaced(gvk) { + if obj.GetNamespace() == "" { + obj.SetNamespace(pool.Namespace) + } + obj.SetOwnerReferences([]metav1.OwnerReference{{ + APIVersion: pool.APIVersion, + Kind: pool.Kind, + Name: pool.Name, + UID: pool.UID, + Controller: ptr.To(true), + }}) + } else { + // TODO [danehans]: Not sure why a ns must be set for cluster-scoped objects: + // failed to apply object rbac.authorization.k8s.io/v1, Kind=ClusterRoleBinding + // vllm-llama2-7b-pool-endpoint-picker: Namespace parameter required. 
+ obj.SetNamespace("") + } + } + + return objs, nil +} + func (d *Deployer) DeployObjs(ctx context.Context, objs []client.Object) error { logger := log.FromContext(ctx) for _, obj := range objs { @@ -462,6 +553,48 @@ func (d *Deployer) DeployObjs(ctx context.Context, objs []client.Object) error { return nil } +// EnsureFinalizer adds the InferencePool finalizer to the given pool if it’s not already present. +func (d *Deployer) EnsureFinalizer(ctx context.Context, pool *infextv1a1.InferencePool) error { + if slices.Contains(pool.Finalizers, wellknown.InferencePoolFinalizer) { + return nil + } + pool.Finalizers = append(pool.Finalizers, wellknown.InferencePoolFinalizer) + return d.cli.Update(ctx, pool) +} + +// CleanupClusterScopedResources deletes the ClusterRole and ClusterRoleBinding for the given pool. +// TODO [danehans]: EPP should use role and rolebinding RBAC: https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/224 +func (d *Deployer) CleanupClusterScopedResources(ctx context.Context, pool *infextv1a1.InferencePool) error { + // The same release name as in the Helm template. + releaseName := fmt.Sprintf("%s-endpoint-picker", pool.Name) + + // Delete the ClusterRole. + var cr rbacv1.ClusterRole + if err := d.cli.Get(ctx, client.ObjectKey{Name: releaseName}, &cr); err == nil { + if err := d.cli.Delete(ctx, &cr); err != nil { + return fmt.Errorf("failed to delete ClusterRole %s: %w", releaseName, err) + } + } + + // Delete the ClusterRoleBinding. + var crb rbacv1.ClusterRoleBinding + if err := d.cli.Get(ctx, client.ObjectKey{Name: releaseName}, &crb); err == nil { + if err := d.cli.Delete(ctx, &crb); err != nil { + return fmt.Errorf("failed to delete ClusterRoleBinding %s: %w", releaseName, err) + } + } + + return nil +} + +// IsNamespaced returns true if the resource is namespaced. 
+func IsNamespaced(gvk schema.GroupVersionKind) bool { + if gvk == wellknown.ClusterRoleGVK || gvk == wellknown.ClusterRoleBindingGVK { + return false + } + return true +} + func loadFs(filesystem fs.FS) (*chart.Chart, error) { var bufferedFiles []*loader.BufferedFile entries, err := fs.ReadDir(filesystem, ".") diff --git a/internal/kgateway/deployer/deployer_test.go b/internal/kgateway/deployer/deployer_test.go index 00ef56ee71c..b238a428eda 100644 --- a/internal/kgateway/deployer/deployer_test.go +++ b/internal/kgateway/deployer/deployer_test.go @@ -15,12 +15,14 @@ import ( "google.golang.org/protobuf/proto" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/util/intstr" "k8s.io/utils/ptr" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" + infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" api "sigs.k8s.io/gateway-api/apis/v1" gw2_v1alpha1 "github.com/kgateway-dev/kgateway/v2/api/v1alpha1" @@ -1439,12 +1441,133 @@ var _ = Describe("Deployer", func() { }), ) }) + + Context("Inference Extension endpoint picker", func() { + const defaultNamespace = "default" + + It("should deploy endpoint picker resources for an InferencePool", func() { + // Create a fake InferencePool resource. + pool := &infextv1a1.InferencePool{ + TypeMeta: metav1.TypeMeta{ + Kind: wellknown.InferencePoolKind, + APIVersion: fmt.Sprintf("%s/%s", infextv1a1.GroupVersion.Group, infextv1a1.GroupVersion.Version), + }, + ObjectMeta: metav1.ObjectMeta{ + Name: "pool1", + Namespace: defaultNamespace, + UID: "pool-uid", + }, + } + + // Initialize a new deployer with InferenceExtension inputs. 
+ d, err := deployer.NewDeployer(newFakeClientWithObjs(pool), &deployer.Inputs{ + ControllerName: wellknown.GatewayControllerName, + InferenceExtension: &deployer.InferenceExtInfo{}, + }) + Expect(err).NotTo(HaveOccurred()) + + // Simulate reconciliation so that the pool gets its finalizer added. + err = d.EnsureFinalizer(context.Background(), pool) + Expect(err).NotTo(HaveOccurred()) + + // Check that the pool itself has the finalizer set. + Expect(pool.GetFinalizers()).To(ContainElement(wellknown.InferencePoolFinalizer)) + + // Get the endpoint picker objects for the InferencePool. + objs, err := d.GetEndpointPickerObjs(pool) + Expect(err).NotTo(HaveOccurred()) + Expect(objs).NotTo(BeEmpty(), "expected non-empty objects for endpoint picker deployment") + Expect(objs).To(HaveLen(5)) + + // Find the child objects. + var sa *corev1.ServiceAccount + var clusterRole *rbacv1.ClusterRole + var crb *rbacv1.ClusterRoleBinding + var dep *appsv1.Deployment + var svc *corev1.Service + for _, obj := range objs { + switch t := obj.(type) { + case *corev1.ServiceAccount: + sa = t + case *rbacv1.ClusterRole: + clusterRole = t + case *rbacv1.ClusterRoleBinding: + crb = t + case *appsv1.Deployment: + dep = t + case *corev1.Service: + svc = t + } + } + Expect(sa).NotTo(BeNil(), "expected a ServiceAccount to be rendered") + Expect(clusterRole).NotTo(BeNil(), "expected a Role to be rendered") + Expect(crb).NotTo(BeNil(), "expected a RoleBinding to be rendered") + Expect(dep).NotTo(BeNil(), "expected a Deployment to be rendered") + Expect(svc).NotTo(BeNil(), "expected a Service to be rendered") + + // Check that owner references are set on all rendered objects to the InferencePool. 
+ for _, obj := range objs { + gvk := obj.GetObjectKind().GroupVersionKind() + if deployer.IsNamespaced(gvk) { + ownerRefs := obj.GetOwnerReferences() + Expect(ownerRefs).To(HaveLen(1)) + ref := ownerRefs[0] + Expect(ref.Name).To(Equal(pool.Name)) + Expect(ref.UID).To(Equal(pool.UID)) + Expect(ref.Kind).To(Equal(pool.Kind)) + Expect(ref.APIVersion).To(Equal(pool.APIVersion)) + Expect(*ref.Controller).To(BeTrue()) + } + } + + // Validate that every rendered object has the expected name. + // (All names follow the release name format "<pool name>-endpoint-picker".) + expectedName := fmt.Sprintf("%s-endpoint-picker", pool.Name) + Expect(sa.Name).To(Equal(expectedName)) + Expect(clusterRole.Name).To(Equal(expectedName)) + Expect(crb.Name).To(Equal(expectedName)) + Expect(dep.Name).To(Equal(expectedName)) + Expect(svc.Name).To(Equal(expectedName)) + + // Check the container args for the expected poolName. + Expect(dep.Spec.Template.Spec.Containers).To(HaveLen(1)) + pickerContainer := dep.Spec.Template.Spec.Containers[0] + Expect(pickerContainer.Args).To(Equal([]string{ + "-poolName", + pool.Name, + "-v", + "3", + "-grpcPort", + "9002", + "-grpcHealthPort", + "9003", + })) + }) + }) }) // initialize a fake controller-runtime client with the given list of objects func newFakeClientWithObjs(objs ...client.Object) client.Client { + scheme := schemes.GatewayScheme() + + // Ensure the rbac types are registered. + if err := rbacv1.AddToScheme(scheme); err != nil { + panic(fmt.Sprintf("failed to add rbacv1 scheme: %v", err)) + } + + // Check if any object is an InferencePool, and add its scheme if needed. + for _, obj := range objs { + gvk := obj.GetObjectKind().GroupVersionKind() + if gvk.Kind == wellknown.InferencePoolKind { + if err := infextv1a1.AddToScheme(scheme); err != nil { + panic(fmt.Sprintf("failed to add InferenceExtension scheme: %v", err)) + } + break + } + } + return fake.NewClientBuilder(). - WithScheme(schemes.GatewayScheme()). + WithScheme(scheme). 
WithObjects(objs...). Build() } diff --git a/internal/kgateway/deployer/values.go b/internal/kgateway/deployer/values.go index c3e28e8cb77..a3d19beed6e 100644 --- a/internal/kgateway/deployer/values.go +++ b/internal/kgateway/deployer/values.go @@ -8,7 +8,8 @@ import ( // The top-level helm values used by the deployer. type helmConfig struct { - Gateway *helmGateway `json:"gateway,omitempty"` + Gateway *helmGateway `json:"gateway,omitempty"` + InferenceExtension *helmInferenceExtension `json:"inferenceExtension,omitempty"` } type helmGateway struct { @@ -161,3 +162,12 @@ type helmAws struct { StsClusterName *string `json:"stsClusterName,omitempty"` StsUri *string `json:"stsUri,omitempty"` } + +type helmInferenceExtension struct { + EndpointPicker *helmEndpointPickerExtension `json:"endpointPicker,omitempty"` +} + +type helmEndpointPickerExtension struct { + PoolName string `json:"poolName"` + PoolNamespace string `json:"poolNamespace"` +} diff --git a/internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker/plugin.go b/internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker/plugin.go new file mode 100644 index 00000000000..95752f39a99 --- /dev/null +++ b/internal/kgateway/extensions2/plugins/inferenceextension/endpointpicker/plugin.go @@ -0,0 +1,400 @@ +package endpointpicker + +import ( + "context" + "fmt" + "time" + + clusterv3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + corev3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" + endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + listenerv3 "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + routev3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + extprocv3 "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/http/ext_proc/v3" + hcmv3 "github.com/envoyproxy/go-control-plane/envoy/extensions/filters/network/http_connection_manager/v3" + tlsv3 
"github.com/envoyproxy/go-control-plane/envoy/extensions/transport_sockets/tls/v3" + upstreamsv3 "github.com/envoyproxy/go-control-plane/envoy/extensions/upstreams/http/v3" + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/wrapperspb" + "istio.io/istio/pkg/kube/krt" + "k8s.io/apimachinery/pkg/runtime/schema" + infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" + + "github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/common" + extplug "github.com/kgateway-dev/kgateway/v2/internal/kgateway/extensions2/plugin" + "github.com/kgateway-dev/kgateway/v2/internal/kgateway/ir" + "github.com/kgateway-dev/kgateway/v2/internal/kgateway/plugins" + "github.com/kgateway-dev/kgateway/v2/internal/kgateway/utils" + "github.com/kgateway-dev/kgateway/v2/internal/kgateway/utils/krtutil" + "github.com/kgateway-dev/kgateway/v2/internal/kgateway/wellknown" +) + +func NewPlugin(ctx context.Context, commonCol *common.CommonCollections) extplug.Plugin { + poolGVR := schema.GroupVersionResource{ + Group: infextv1a1.GroupVersion.Group, + Version: infextv1a1.GroupVersion.Version, + Resource: "inferencepools", + } + + poolCol := krtutil.SetupCollectionDynamic[infextv1a1.InferencePool]( + ctx, + commonCol.Client, + poolGVR, + commonCol.KrtOpts.ToOptions("InferencePools")..., + ) + + return NewPluginFromCollections(ctx, commonCol, poolCol) +} + +func NewPluginFromCollections( + ctx context.Context, + commonCol *common.CommonCollections, + poolCol krt.Collection[*infextv1a1.InferencePool], +) extplug.Plugin { + // The InferencePool group kind used by the BackendObjectIR and the ContributesBackendObjectIRs plugin. 
+ gk := schema.GroupKind{ + Group: infextv1a1.GroupVersion.Group, + Kind: wellknown.InferencePoolKind, + } + + backendCol := krt.NewCollection(poolCol, func(kctx krt.HandlerContext, pool *infextv1a1.InferencePool) *ir.BackendObjectIR { + // Create a BackendObjectIR IR representation from the given InferencePool. + return &ir.BackendObjectIR{ + ObjectSource: ir.ObjectSource{ + Kind: gk.Kind, + Group: gk.Group, + Namespace: pool.Namespace, + Name: pool.Name, + }, + Obj: pool, + Port: pool.Spec.TargetPortNumber, + GvPrefix: "endpoint-picker", + CanonicalHostname: "", + ObjIr: ir.NewInferencePool(pool), + } + }, commonCol.KrtOpts.ToOptions("InferencePoolIR")...) + + policyCol := krt.NewCollection(poolCol, func(krtctx krt.HandlerContext, i *infextv1a1.InferencePool) *ir.PolicyWrapper { + // Create a PolicyWrapper IR representation from the given InferencePool. + return &ir.PolicyWrapper{ + ObjectSource: ir.ObjectSource{ + Group: gk.Group, + Kind: gk.Kind, + Namespace: i.Namespace, + Name: i.Name, + }, + Policy: i, + PolicyIR: ir.NewInferencePool(i), + } + }) + + // Return a plugin that contributes a policy and backend. + return extplug.Plugin{ + ContributesBackends: map[schema.GroupKind]extplug.BackendPlugin{ + gk: { + Backends: backendCol, + BackendInit: ir.BackendInit{ + InitBackend: processBackendObjectIR, + }, + }, + }, + ContributesPolicies: map[schema.GroupKind]extplug.PolicyPlugin{ + gk: { + Name: "endpointpicker", + Policies: policyCol, + NewGatewayTranslationPass: newEndpointPickerPass, + }, + }, + } +} + +// processBackendObjectIR processes the given BackendObjectIR into an Envoy cluster. +func processBackendObjectIR(ctx context.Context, in ir.BackendObjectIR, out *clusterv3.Cluster) { + // Large timeout based on upstream working config. + out.ConnectTimeout = durationpb.New(1000 * time.Second) + + // Set the cluster type to ORIGINAL_DST. 
+ out.ClusterDiscoveryType = &clusterv3.Cluster_Type{ + Type: clusterv3.Cluster_ORIGINAL_DST, + } + out.LbPolicy = clusterv3.Cluster_CLUSTER_PROVIDED + + // Use the headers added by the endpoint picker extension. + out.LbConfig = &clusterv3.Cluster_OriginalDstLbConfig_{ + OriginalDstLbConfig: &clusterv3.Cluster_OriginalDstLbConfig{ + UseHttpHeader: true, + HttpHeaderName: "x-gateway-destination-endpoint", + }, + } + + // Circuit breakers based on upstream working config. + out.CircuitBreakers = &clusterv3.CircuitBreakers{ + Thresholds: []*clusterv3.CircuitBreakers_Thresholds{ + { + MaxConnections: wrapperspb.UInt32(40000), + MaxPendingRequests: wrapperspb.UInt32(40000), + MaxRequests: wrapperspb.UInt32(40000), + }, + }, + } + + out.Name = clusterNameOriginalDst(in.Name, in.Namespace) +} + +// endpointPickerPass implements ir.ProxyTranslationPass. It collects any references to InferencePools, +// then in ResourcesToAdd() returns both the “ext_proc” cluster (STRICT_DNS) and “original_dst” cluster (ORIGINAL_DST). 
+type endpointPickerPass struct { + usedPool *ir.InferencePool +} + +func newEndpointPickerPass(ctx context.Context, tctx ir.GwTranslationCtx) ir.ProxyTranslationPass { + return &endpointPickerPass{ + usedPool: new(ir.InferencePool), + } +} + +func (p *endpointPickerPass) Name() string { + return "endpoint-picker" +} + +// No-op for these standard pass methods +func (p *endpointPickerPass) ApplyListenerPlugin(ctx context.Context, lctx *ir.ListenerContext, out *listenerv3.Listener) { +} +func (p *endpointPickerPass) ApplyHCM(ctx context.Context, hctx *ir.HcmContext, out *hcmv3.HttpConnectionManager) error { + return nil +} +func (p *endpointPickerPass) NetworkFilters(ctx context.Context) ([]plugins.StagedNetworkFilter, error) { + return nil, nil +} +func (p *endpointPickerPass) UpstreamHttpFilters(ctx context.Context) ([]plugins.StagedUpstreamHttpFilter, error) { + return nil, nil +} +func (p *endpointPickerPass) ApplyVhostPlugin(ctx context.Context, vctx *ir.VirtualHostContext, out *routev3.VirtualHost) { +} +func (p *endpointPickerPass) ApplyForRoute(ctx context.Context, rctx *ir.RouteContext, out *routev3.Route) error { + return nil +} +func (p *endpointPickerPass) ApplyRouteConfigPlugin( + ctx context.Context, + pCtx *ir.RouteConfigContext, + out *routev3.RouteConfiguration, +) { +} +func (p *endpointPickerPass) ApplyForRouteBackend( + ctx context.Context, + policy ir.PolicyIR, + pCtx *ir.RouteBackendContext, +) error { + return nil +} + +func (p *endpointPickerPass) ApplyForBackend( + ctx context.Context, + pCtx *ir.RouteBackendContext, + in ir.HttpBackend, + out *routev3.Route, +) error { + // Check if the backend’s Group/Kind matches your InferencePool GVK. 
+ if pCtx.Backend == nil { + return fmt.Errorf("unexpected nil route backend") + } + if pCtx.Backend.Group != infextv1a1.GroupVersion.Group { + return fmt.Errorf("unexpected group for route backend; expected %s and found %s", infextv1a1.GroupVersion.Group, pCtx.Backend.Group) + } + if pCtx.Backend.Kind != wellknown.InferencePoolKind { + return fmt.Errorf("unexpected kind for route backend; expected %s and found %s", wellknown.InferencePoolKind, pCtx.Backend.Kind) + } + + // Cast the underlying backend object. + pool, ok := pCtx.Backend.Obj.(*infextv1a1.InferencePool) + if !ok || pool == nil { + return fmt.Errorf("unexpected backend object") + } + + // Store the pool to build clusters in ResourcesToAdd. + irPool := ir.NewInferencePool(pool) + p.usedPool = irPool + + // Add things which require basic EPP backend. + if out.GetRoute() == nil { + out.Action = &routev3.Route_Route{Route: &routev3.RouteAction{}} + } + + // Create the ext_proc per-route override + override := &extprocv3.ExtProcPerRoute{ + Override: &extprocv3.ExtProcPerRoute_Overrides{ + Overrides: &extprocv3.ExtProcOverrides{ + GrpcService: &corev3.GrpcService{ + Timeout: durationpb.New(10 * time.Second), + TargetSpecifier: &corev3.GrpcService_EnvoyGrpc_{ + EnvoyGrpc: &corev3.GrpcService_EnvoyGrpc{ + ClusterName: clusterNameExtProc( + irPool.ObjMeta.GetName(), + irPool.ObjMeta.GetNamespace(), + ), + Authority: fmt.Sprintf("%s.%s.svc.cluster.local:%d", + irPool.ConfigRef.Name, + irPool.ObjMeta.GetNamespace(), + irPool.ConfigRef.Ports[0].PortNum), + }, + }, + }, + }, + }, + } + + // Attach to typed_per_filter_config, referencing the same filter name used in the HCM. 
+ pCtx.AddTypedConfig(wellknown.InfPoolBackendTransformationFilterName, override) + + // Override the route's cluster to point to the ORIGINAL_DST cluster + originalDstClusterName := clusterNameOriginalDst(irPool.ObjMeta.GetName(), irPool.ObjMeta.GetNamespace()) + out.GetRoute().ClusterSpecifier = &routev3.RouteAction_Cluster{ + Cluster: originalDstClusterName, + } + + return nil +} + +// HttpFilters inserts one ext_proc filter at the top-level. +func (p *endpointPickerPass) HttpFilters(ctx context.Context, fc ir.FilterChainCommon) ([]plugins.StagedHttpFilter, error) { + if p.usedPool == nil { + return nil, fmt.Errorf("unexpected nil usedPools") + } + if p.usedPool.ConfigRef == nil { + return nil, fmt.Errorf("unexpected nil usedPool ConfigRef") + } + + pool := p.usedPool + clusterName := clusterNameExtProc(pool.ObjMeta.GetName(), pool.ObjMeta.GetNamespace()) + authority := fmt.Sprintf("%s.%s:%d", pool.ConfigRef.Name, pool.ObjMeta.Namespace, pool.ConfigRef.Ports[0].PortNum) + + return AddEndpointPickerHTTPFilter(clusterName, authority) +} + +// AddEndpointPickerHTTPFilter returns a top-level ext_proc filter that references +// the cluster built in ResourcesToAdd(). This filter gets placed in the HCM's http_filters array. 
+func AddEndpointPickerHTTPFilter(clusterName, authority string) ([]plugins.StagedHttpFilter, error) { + var filters []plugins.StagedHttpFilter + + // This is the top-level ext_proc filter config + extProcSettings := &extprocv3.ExternalProcessor{ + GrpcService: &corev3.GrpcService{ + TargetSpecifier: &corev3.GrpcService_EnvoyGrpc_{ + EnvoyGrpc: &corev3.GrpcService_EnvoyGrpc{ + ClusterName: clusterName, + Authority: authority, + }, + }, + }, + ProcessingMode: &extprocv3.ProcessingMode{ + RequestHeaderMode: extprocv3.ProcessingMode_SEND, + RequestBodyMode: extprocv3.ProcessingMode_BUFFERED, + ResponseHeaderMode: extprocv3.ProcessingMode_SKIP, + RequestTrailerMode: extprocv3.ProcessingMode_SKIP, + ResponseTrailerMode: extprocv3.ProcessingMode_SKIP, + }, + MessageTimeout: durationpb.New(5 * time.Second), + FailureModeAllow: false, + } + + stagedFilter, err := plugins.NewStagedFilter( + wellknown.InfPoolBackendTransformationFilterName, // Filters must have a unique name. + extProcSettings, + plugins.BeforeStage(plugins.RouteStage), + ) + if err != nil { + return nil, err + } + filters = append(filters, stagedFilter) + + return filters, nil +} + +// ResourcesToAdd is called one time (per envoy proxy) and replaces GeneratedResources +// with the returned cluster resources. 
+func (p *endpointPickerPass) ResourcesToAdd(ctx context.Context) ir.Resources { + // Build an ext-proc cluster per InferencePool + return ir.Resources{Clusters: []*clusterv3.Cluster{buildExtProcCluster(p.usedPool)}} +} + +// buildExtProcCluster returns a “STRICT_DNS” cluster using the host/port from InferencePool.Spec.ExtensionRef +func buildExtProcCluster(pool *ir.InferencePool) *clusterv3.Cluster { + if pool.ConfigRef == nil || len(pool.ConfigRef.Ports) != 1 { + return nil + } + + name := clusterNameExtProc(pool.ObjMeta.GetName(), pool.ObjMeta.GetNamespace()) + c := &clusterv3.Cluster{ + Name: name, + ConnectTimeout: durationpb.New(10 * time.Second), + ClusterDiscoveryType: &clusterv3.Cluster_Type{ + Type: clusterv3.Cluster_STRICT_DNS, + }, + LbPolicy: clusterv3.Cluster_LEAST_REQUEST, + LoadAssignment: &endpointv3.ClusterLoadAssignment{ + ClusterName: name, + Endpoints: []*endpointv3.LocalityLbEndpoints{{ + LbEndpoints: []*endpointv3.LbEndpoint{{ + HealthStatus: corev3.HealthStatus_HEALTHY, + HostIdentifier: &endpointv3.LbEndpoint_Endpoint{ + Endpoint: &endpointv3.Endpoint{ + Address: &corev3.Address{ + Address: &corev3.Address_SocketAddress{ + SocketAddress: &corev3.SocketAddress{ + Address: fmt.Sprintf("%s.%s.svc.cluster.local", pool.ConfigRef.Name, pool.ObjMeta.Namespace), + Protocol: corev3.SocketAddress_TCP, + PortSpecifier: &corev3.SocketAddress_PortValue{ + PortValue: uint32(pool.ConfigRef.Ports[0].PortNum), + }, + }, + }, + }, + }, + }, + }}, + }}, + }, + // Ensure Envoy accepts untrusted certificates. 
+ TransportSocket: &corev3.TransportSocket{ + Name: "envoy.transport_sockets.tls", + ConfigType: &corev3.TransportSocket_TypedConfig{ + TypedConfig: func() *anypb.Any { + tlsCtx := &tlsv3.UpstreamTlsContext{ + CommonTlsContext: &tlsv3.CommonTlsContext{ + ValidationContextType: &tlsv3.CommonTlsContext_ValidationContext{}, + }, + } + anyTLS, _ := anypb.New(tlsCtx) + return anyTLS + }(), + }, + }, + } + + http2Opts := &upstreamsv3.HttpProtocolOptions{ + UpstreamProtocolOptions: &upstreamsv3.HttpProtocolOptions_ExplicitHttpConfig_{ + ExplicitHttpConfig: &upstreamsv3.HttpProtocolOptions_ExplicitHttpConfig{ + ProtocolConfig: &upstreamsv3.HttpProtocolOptions_ExplicitHttpConfig_Http2ProtocolOptions{ + Http2ProtocolOptions: &corev3.Http2ProtocolOptions{}, + }, + }, + }, + } + + // Marshall the HttpProtocolOptions proto message. + anyHttp2, _ := utils.MessageToAny(http2Opts) + c.TypedExtensionProtocolOptions = map[string]*anypb.Any{ + "envoy.extensions.upstreams.http.v3.HttpProtocolOptions": anyHttp2, + } + + return c +} + +func clusterNameExtProc(name, ns string) string { + return fmt.Sprintf("endpointpicker_%s_%s_ext_proc", name, ns) +} + +func clusterNameOriginalDst(name, ns string) string { + return fmt.Sprintf("endpointpicker_%s_%s_original_dst", name, ns) +} diff --git a/internal/kgateway/helm/embed.go b/internal/kgateway/helm/embed.go index 56b2589d39d..c00e874bdaa 100644 --- a/internal/kgateway/helm/embed.go +++ b/internal/kgateway/helm/embed.go @@ -4,5 +4,10 @@ import ( "embed" ) -//go:embed all:kgateway -var KgatewayHelmChart embed.FS +var ( + //go:embed all:kgateway + KgatewayHelmChart embed.FS + + //go:embed all:inference-extension + InferenceExtensionHelmChart embed.FS +) diff --git a/internal/kgateway/helm/inference-extension/.helmignore b/internal/kgateway/helm/inference-extension/.helmignore new file mode 100644 index 00000000000..ede6884aa6b --- /dev/null +++ b/internal/kgateway/helm/inference-extension/.helmignore @@ -0,0 +1,30 @@ +# Patterns to ignore when 
building packages. +# This supports shell glob matching, relative path matching, and +# negation (prefixed with !). Only one pattern per line. +.DS_Store +# Common VCS dirs +.git/ +.gitignore +.bzr/ +.bzrignore +.hg/ +.hgignore +.svn/ +# Common backup files +*.swp +*.bak +*.tmp +*.orig +*~ +# Various IDEs +.project +.idea/ +*.tmproj +.vscode/ + +# template files +*-template.yaml + +# generator files +*.go +generate/ diff --git a/internal/kgateway/helm/inference-extension/Chart.yaml b/internal/kgateway/helm/inference-extension/Chart.yaml new file mode 100644 index 00000000000..4476b5e761f --- /dev/null +++ b/internal/kgateway/helm/inference-extension/Chart.yaml @@ -0,0 +1,24 @@ +apiVersion: v2 +name: inference-extension +description: A Helm chart for managing Gateway API Inference Extensions + +# A chart can be either an 'application' or a 'library' chart. +# +# Application charts are a collection of templates that can be packaged into versioned archives +# to be deployed. +# +# Library charts provide useful utilities or functions for the chart developer. They're included as +# a dependency of application charts to inject those utilities and functions into the rendering +# pipeline. Library charts do not define any templates and therefore cannot be deployed. +type: application + +# This is the chart version. This version number should be incremented each time you make changes +# to the chart and its templates, including the app version. +# Versions are expected to follow Semantic Versioning (https://semver.org/) +version: 0.0.1-alpha1 + +# This is the version number of the application being deployed. This version number should be +# incremented each time you make changes to the application. Versions are not expected to +# follow Semantic Versioning. They should reflect the version the application is using. +# It is recommended to use it with quotes. 
+appVersion: "0.1.0" diff --git a/internal/kgateway/helm/inference-extension/templates/_helpers.tpl b/internal/kgateway/helm/inference-extension/templates/_helpers.tpl new file mode 100644 index 00000000000..3282a1b2a3c --- /dev/null +++ b/internal/kgateway/helm/inference-extension/templates/_helpers.tpl @@ -0,0 +1,6 @@ +{{/* +Create chart name and version as used by the chart label. +*/}} +{{- define "inference-extension.chart" -}} +{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} +{{- end }} diff --git a/internal/kgateway/helm/inference-extension/templates/endpoint-picker/resources.yaml b/internal/kgateway/helm/inference-extension/templates/endpoint-picker/resources.yaml new file mode 100644 index 00000000000..6052f4bd14f --- /dev/null +++ b/internal/kgateway/helm/inference-extension/templates/endpoint-picker/resources.yaml @@ -0,0 +1,116 @@ +{{- $endpointPicker := .Values.inferenceExtension.endpointPicker }} +apiVersion: v1 +kind: ServiceAccount +metadata: + name: {{ .Release.Name }} +--- + # TODO [danehans]: EPP should use Role and RoleBinding resources: https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/224 +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRole +metadata: + name: {{ .Release.Name }} +rules: +- apiGroups: ["inference.networking.x-k8s.io"] + resources: ["inferencemodels"] + verbs: ["get", "watch", "list"] +- apiGroups: [""] + resources: ["pods"] + verbs: ["get", "watch", "list"] +- apiGroups: ["inference.networking.x-k8s.io"] + resources: ["inferencepools"] + verbs: ["get", "watch", "list"] +- apiGroups: ["discovery.k8s.io"] + resources: ["endpointslices"] + verbs: ["get", "watch", "list"] + # TODO [danehans]: Unsure why the following rules are needed: https://github.com/kubernetes-sigs/gateway-api-inference-extension/issues/224 +- apiGroups: + - authentication.k8s.io + resources: + - tokenreviews + verbs: + - create +- apiGroups: + - authorization.k8s.io + resources: + - 
subjectaccessreviews + verbs: + - create +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: ClusterRoleBinding +metadata: + name: {{ .Release.Name }} +subjects: +- kind: ServiceAccount + name: {{ .Release.Name }} + namespace: {{ $endpointPicker.poolNamespace }} +roleRef: + kind: ClusterRole + name: {{ .Release.Name }} +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ .Release.Name }} + labels: + app.kubernetes.io/component: endpoint-picker + app.kubernetes.io/name: {{ .Release.Name }} + app.kubernetes.io/instance: kgateway +spec: + replicas: 1 + selector: + matchLabels: + app: {{ .Release.Name }} + template: + metadata: + labels: + app: {{ .Release.Name }} + spec: + serviceAccountName: {{ .Release.Name }} + containers: + - name: endpoint-picker + args: + - -poolName + - {{ $endpointPicker.poolName }} + - -v + - "3" + - -grpcPort + - "9002" + - -grpcHealthPort + - "9003" + image: "registry.k8s.io/gateway-api-inference-extension/epp:v0.1.0" + imagePullPolicy: IfNotPresent + ports: + - containerPort: 9002 + - containerPort: 9003 + - name: metrics + containerPort: 9090 + livenessProbe: + grpc: + port: 9003 + service: inference-extension + initialDelaySeconds: 5 + periodSeconds: 10 + readinessProbe: + grpc: + port: 9003 + service: inference-extension + initialDelaySeconds: 5 + periodSeconds: 10 +--- +apiVersion: v1 +kind: Service +metadata: + name: {{ .Release.Name }} + labels: + app.kubernetes.io/component: endpoint-picker + app.kubernetes.io/name: {{ .Release.Name }} + app.kubernetes.io/instance: kgateway +spec: + selector: + app: {{ .Release.Name }} + ports: + - protocol: TCP + port: 9002 + targetPort: 9002 + type: ClusterIP diff --git a/internal/kgateway/helm/inference-extension/values.yaml b/internal/kgateway/helm/inference-extension/values.yaml new file mode 100644 index 00000000000..6cb93c4c50a --- /dev/null +++ b/internal/kgateway/helm/inference-extension/values.yaml @@ -0,0 +1,12 @@ +# These values represent configurable values for the 
dynamic inference extension chart +# They are not intended to be actual "defaults," rather they are just placeholder values +# meant to allow rendering of the chart/template, as the real values will come from: +# * The `InferencePool` resource driving the inference extension provisioning +# * A (possibly merged) GatewayParameters object translated to helm values +# The actual defaults for these values should come from the "default GatewayParameters" object +# See: (install/helm/kgateway/templates/gatewayparameters.yaml) + +inferenceExtension: + endpointPicker: + poolName: default + poolNamespace: default diff --git a/internal/kgateway/ir/inferencepool.go b/internal/kgateway/ir/inferencepool.go new file mode 100644 index 00000000000..f4df4f8c0a1 --- /dev/null +++ b/internal/kgateway/ir/inferencepool.go @@ -0,0 +1,79 @@ +package ir + +import ( + "maps" + "time" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" +) + +type InferencePool struct { + ObjMeta metav1.ObjectMeta + // PodSelector is a label selector to select Pods that are members of the InferencePool. + PodSelector map[string]string + // TargetPort is the port number that should be targeted for Pods selected by Selector. + TargetPort int32 + // ConfigRef is a reference to the extension configuration. A ConfigRef is typically implemented + // as a Kubernetes Service resource. 
+ ConfigRef *Service +} + +func NewInferencePool(pool *infextv1a1.InferencePool) *InferencePool { + if pool == nil || pool.Spec.ExtensionRef == nil { + return nil + } + + port := ServicePort{Name: "grpc", PortNum: (int32(9002))} + if pool.Spec.ExtensionRef.TargetPortNumber != nil { + port.PortNum = *pool.Spec.ExtensionRef.TargetPortNumber + } + + svcIR := &Service{ + ObjectSource: ObjectSource{ + Group: "", + Kind: "Service", + Namespace: pool.Namespace, + Name: pool.Spec.ExtensionRef.Name, + }, + Obj: pool, + Ports: []ServicePort{port}, + } + + return &InferencePool{ + ObjMeta: pool.ObjectMeta, + PodSelector: convertSelector(pool.Spec.Selector), + TargetPort: pool.Spec.TargetPortNumber, + ConfigRef: svcIR, + } +} + +// In case multiple pools attached to the same resource, we sort by creation time. +func (ir *InferencePool) CreationTime() time.Time { + return ir.ObjMeta.CreationTimestamp.Time +} + +func (ir *InferencePool) Selector() map[string]string { + if ir.PodSelector == nil { + return nil + } + return ir.PodSelector +} + +func (ir *InferencePool) Equals(other any) bool { + otherPool, ok := other.(*InferencePool) + if !ok { + return false + } + return maps.EqualFunc(ir.Selector(), otherPool.Selector(), func(a, b string) bool { + return a == b + }) +} + +func convertSelector(selector map[infextv1a1.LabelKey]infextv1a1.LabelValue) map[string]string { + result := make(map[string]string, len(selector)) + for k, v := range selector { + result[string(k)] = string(v) + } + return result +} diff --git a/internal/kgateway/ir/service.go b/internal/kgateway/ir/service.go new file mode 100644 index 00000000000..29caae2ae3c --- /dev/null +++ b/internal/kgateway/ir/service.go @@ -0,0 +1,62 @@ +package ir + +import ( + "encoding/json" + + "istio.io/istio/pkg/kube/krt" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const ( + // GRPCPort is the default port number for the gRPC service. + GRPCPort = 9002 +) + +// Service defines an internal representation of a service. 
+type Service struct { + // ObjectSource is a reference to the source object. Sometimes the group and kind are not + // populated from api-server, so set them explicitly here, and pass this around as the reference. + ObjectSource `json:",inline"` + + // Obj is the original object. Opaque to us other than metadata. + Obj metav1.Object + + // Ports is a list of ports exposed by the service. + Ports []ServicePort +} + +// ServicePort is an exposed post of a service. +type ServicePort struct { + // Name is the name of the port. + Name string + // PortNum is the port number used to expose the service port. + PortNum int32 +} + +func (r Service) ResourceName() string { + return r.ObjectSource.ResourceName() +} + +func (r Service) Equals(in Service) bool { + return r.ObjectSource.Equals(in.ObjectSource) && versionEquals(r.Obj, in.Obj) +} + +var _ krt.ResourceNamer = Service{} +var _ krt.Equaler[Service] = Service{} +var _ json.Marshaler = Service{} + +func (l Service) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Group string + Kind string + Name string + Namespace string + Ports []ServicePort + }{ + Group: l.Group, + Kind: l.Kind, + Namespace: l.Namespace, + Name: l.Name, + Ports: l.Ports, + }) +} diff --git a/internal/kgateway/krtcollections/policy.go b/internal/kgateway/krtcollections/policy.go index 3fc3b9e6f4f..330f0f8999f 100644 --- a/internal/kgateway/krtcollections/policy.go +++ b/internal/kgateway/krtcollections/policy.go @@ -8,6 +8,7 @@ import ( "istio.io/istio/pkg/kube/krt" "k8s.io/apimachinery/pkg/runtime/schema" "k8s.io/apimachinery/pkg/types" + infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" gwv1 "sigs.k8s.io/gateway-api/apis/v1" gwv1a2 "sigs.k8s.io/gateway-api/apis/v1alpha2" gwv1beta1 "sigs.k8s.io/gateway-api/apis/v1beta1" @@ -16,6 +17,7 @@ import ( "github.com/kgateway-dev/kgateway/v2/internal/kgateway/ir" "github.com/kgateway-dev/kgateway/v2/internal/kgateway/translator/backendref" 
"github.com/kgateway-dev/kgateway/v2/internal/kgateway/utils/krtutil" + "github.com/kgateway-dev/kgateway/v2/internal/kgateway/wellknown" ) var ( @@ -646,12 +648,24 @@ func (h *RoutesIndex) resolveExtension(kctx krt.HandlerContext, ns string, ext g } func toFromBackendRef(fromns string, ref gwv1.BackendObjectReference) ir.ObjectSource { - return ir.ObjectSource{ + // Defaults to Service kind for returned ObjectSource. + ret := ir.ObjectSource{ Group: strOr(ref.Group, ""), - Kind: strOr(ref.Kind, "Service"), + Kind: strOr(ref.Kind, wellknown.ServiceKind), Namespace: strOr(ref.Namespace, fromns), Name: string(ref.Name), } + + // Change to the the InferencePool group/kind if needed. + if ref.Group != nil && + *ref.Group == gwv1.Group(infextv1a1.GroupVersion.Group) && + ref.Kind != nil && + *ref.Kind == wellknown.InferencePoolKind { + ret.Group = infextv1a1.GroupVersion.Group + ret.Kind = wellknown.InferencePoolKind + } + + return ret } func (h *RoutesIndex) getBackends(kctx krt.HandlerContext, src ir.ObjectSource, backendRefs []gwv1.HTTPBackendRef) []ir.HttpBackendOrDelegate { diff --git a/internal/kgateway/krtcollections/services.go b/internal/kgateway/krtcollections/services.go new file mode 100644 index 00000000000..0c13ade73be --- /dev/null +++ b/internal/kgateway/krtcollections/services.go @@ -0,0 +1,71 @@ +package krtcollections + +import ( + "istio.io/istio/pkg/kube/krt" + "k8s.io/apimachinery/pkg/runtime/schema" + infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" + + "github.com/kgateway-dev/kgateway/v2/internal/kgateway/ir" + "github.com/kgateway-dev/kgateway/v2/internal/kgateway/wellknown" +) + +type ServiceIndex struct { + services map[schema.GroupKind]krt.Collection[ir.Service] +} + +func NewServiceIndex(services map[schema.GroupKind]krt.Collection[ir.Service]) *ServiceIndex { + return &ServiceIndex{services: services} +} + +func (s *ServiceIndex) HasSynced() bool { + for _, col := range s.services { + if !col.HasSynced() { + return 
false + } + } + return true +} + +func (s *ServiceIndex) GetSvcForInferPool(kctx krt.HandlerContext, pool *infextv1a1.InferencePool) *ir.Service { + if pool == nil || pool.Spec.ExtensionRef == nil { + // TODO [danehans]: Add logging + return nil + } + + refGroup := "" + ref := *pool.Spec.ExtensionRef + if ref.Group != nil && *ref.Group != "" { + // TODO [danehans]: Add logging + return nil + } + refKind := wellknown.ServiceKind + if ref.Kind != nil && *ref.Kind != "" { + // TODO [danehans]: Add logging + return nil + } + + // Get the krt Service collection + gk := schema.GroupKind{Group: refGroup, Kind: *ref.Kind} + col := s.services[gk] + if col == nil { + // TODO [danehans]: Add logging + return nil + } + + // Create the object source used for filtering by name when fetching Services from the collection. + src := ir.ObjectSource{ + Group: refGroup, + Kind: refKind, + Namespace: pool.Namespace, + Name: pool.Name, + } + + // Fetch the Service from the krt collection, filtering based on object source name. 
+ ret := krt.FetchOne(kctx, col, krt.FilterKey(src.ResourceName())) + if ret == nil { + // TODO [danehans]: Add logging + return nil + } + + return ret +} diff --git a/internal/kgateway/wellknown/constants.go b/internal/kgateway/wellknown/constants.go index 781f7bcee99..ceb9fb09d37 100644 --- a/internal/kgateway/wellknown/constants.go +++ b/internal/kgateway/wellknown/constants.go @@ -28,8 +28,9 @@ const ( ) const ( - AIBackendTransformationFilterName = "ai.backend.transformation.kgateway.io" - AIPolicyTransformationFilterName = "ai.policy.transformation.kgateway.io" - AIExtProcFilterName = "ai.extproc.kgateway.io" - SetMetadataFilterName = "envoy.filters.http.set_filter_state" + InfPoolBackendTransformationFilterName = "inferencepool.backend.transformation.kgateway.io" + AIBackendTransformationFilterName = "ai.backend.transformation.kgateway.io" + AIPolicyTransformationFilterName = "ai.policy.transformation.kgateway.io" + AIExtProcFilterName = "ai.extproc.kgateway.io" + SetMetadataFilterName = "envoy.filters.http.set_filter_state" ) diff --git a/internal/kgateway/wellknown/controller.go b/internal/kgateway/wellknown/controller.go index 0fe97af4cc3..00a707d3fed 100644 --- a/internal/kgateway/wellknown/controller.go +++ b/internal/kgateway/wellknown/controller.go @@ -17,4 +17,8 @@ const ( // DefaultGatewayParametersName is the name of the GatewayParameters which is attached by // parametersRef to the GatewayClass. DefaultGatewayParametersName = "kgateway" + + // InferencePoolFinalizer is the InferencePool finalizer name to ensure cluster-scoped + // objects are cleaned up. 
+ InferencePoolFinalizer = "kgateway/inferencepool-cleanup" ) diff --git a/internal/kgateway/wellknown/gwapi.go b/internal/kgateway/wellknown/gwapi.go index 1a5f2f6353c..8c09c5c9831 100644 --- a/internal/kgateway/wellknown/gwapi.go +++ b/internal/kgateway/wellknown/gwapi.go @@ -35,6 +35,9 @@ const ( // Kind string for ReferenceGrant resource ReferenceGrantKind = "ReferenceGrant" + // Kind string for InferencePool resource + InferencePoolKind = "InferencePool" + // Kind strings for Gateway API list types HTTPRouteListKind = "HTTPRouteList" GatewayListKind = "GatewayList" diff --git a/internal/kgateway/wellknown/kgw.go b/internal/kgateway/wellknown/kgw.go index 823edb539d5..966bb42eca9 100644 --- a/internal/kgateway/wellknown/kgw.go +++ b/internal/kgateway/wellknown/kgw.go @@ -2,6 +2,7 @@ package wellknown import ( "k8s.io/apimachinery/pkg/runtime/schema" + infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" "github.com/kgateway-dev/kgateway/v2/api/v1alpha1" ) @@ -14,6 +15,14 @@ func buildKgatewayGvk(kind string) schema.GroupVersionKind { } } +func buildInferExtGvk(kind string) schema.GroupVersionKind { + return schema.GroupVersionKind{ + Group: infextv1a1.GroupVersion.Group, + Version: v1alpha1.GroupVersion.Version, + Kind: kind, + } +} + // TODO: consider generating these? 
// manually updated GVKs of the kgateway API types; for convenience var ( @@ -23,4 +32,5 @@ var ( RoutePolicyGVK = buildKgatewayGvk("RoutePolicy") ListenerPolicyGVK = buildKgatewayGvk("ListenerPolicy") HTTPListenerPolicyGVK = buildKgatewayGvk("HTTPListenerPolicy") + InferencePoolGVK = buildInferExtGvk("InferencePool") ) diff --git a/internal/kgateway/wellknown/kube.go b/internal/kgateway/wellknown/kube.go index 34ba6fde3f9..c831841e8ea 100644 --- a/internal/kgateway/wellknown/kube.go +++ b/internal/kgateway/wellknown/kube.go @@ -3,6 +3,7 @@ package wellknown import ( appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" + rbacv1 "k8s.io/api/rbac/v1" ) var ( @@ -11,5 +12,9 @@ var ( ServiceGVK = corev1.SchemeGroupVersion.WithKind("Service") ServiceAccountGVK = corev1.SchemeGroupVersion.WithKind("ServiceAccount") + // RBAC GVKs + ClusterRoleGVK = rbacv1.SchemeGroupVersion.WithKind("ClusterRoleBinding") + ClusterRoleBindingGVK = rbacv1.SchemeGroupVersion.WithKind("ClusterRole") + DeploymentGVK = appsv1.SchemeGroupVersion.WithKind("Deployment") ) diff --git a/pkg/schemes/extended_scheme.go b/pkg/schemes/extended_scheme.go index 26a2938d04d..8be19f7329f 100644 --- a/pkg/schemes/extended_scheme.go +++ b/pkg/schemes/extended_scheme.go @@ -3,6 +3,7 @@ package schemes import ( "fmt" + rbacv1 "k8s.io/api/rbac/v1" "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/meta" "k8s.io/apimachinery/pkg/runtime" @@ -11,6 +12,7 @@ import ( "github.com/kgateway-dev/kgateway/v2/internal/kgateway/wellknown" + infextv1a1 "sigs.k8s.io/gateway-api-inference-extension/api/v1alpha1" gwv1a2 "sigs.k8s.io/gateway-api/apis/v1alpha2" ) @@ -30,6 +32,26 @@ func AddGatewayV1A2Scheme(restConfig *rest.Config, scheme *runtime.Scheme) error return nil } +// AddInferExtV1A1Scheme adds the Inference Extension v1alpha1 scheme to the provided scheme if the InferencePool CRD exists. 
+func AddInferExtV1A1Scheme(restConfig *rest.Config, scheme *runtime.Scheme) (bool, error) { + exists, err := CRDExists(restConfig, infextv1a1.GroupVersion.Group, infextv1a1.GroupVersion.Version, wellknown.InferencePoolKind) + if err != nil { + return false, fmt.Errorf("error checking if %s CRD exists: %w", wellknown.InferencePoolKind, err) + } + + if exists { + // Required to deploy RBAC resources for endpoint picker extension. + if err := rbacv1.AddToScheme(scheme); err != nil { + return false, fmt.Errorf("error adding RBAC v1 to scheme: %w", err) + } + if err := infextv1a1.AddToScheme(scheme); err != nil { + return false, fmt.Errorf("error adding Gateway API Inference Extension v1alpha1 to scheme: %w", err) + } + } + + return exists, nil +} + // Helper function to check if a CRD exists func CRDExists(restConfig *rest.Config, group, version, kind string) (bool, error) { discoveryClient, err := discovery.NewDiscoveryClientForConfig(restConfig)