diff --git a/go.mod b/go.mod index 23c9dac7e..2eae3523b 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module github.com/llm-d/llm-d-inference-scheduler -go 1.24.1 +go 1.24.9 -toolchain go1.24.2 +toolchain go1.24.12 require ( github.com/go-logr/logr v1.4.3 @@ -10,7 +10,7 @@ require ( github.com/google/uuid v1.6.0 github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/jellydator/ttlcache/v3 v3.4.0 - github.com/llm-d/llm-d-kv-cache v0.5.0-RC1 + github.com/llm-d/llm-d-kv-cache v0.5.0-rc1 github.com/onsi/ginkgo/v2 v2.27.5 github.com/onsi/gomega v1.39.0 github.com/openai/openai-go v1.12.0 @@ -23,10 +23,10 @@ require ( k8s.io/apimachinery v0.34.3 k8s.io/client-go v0.34.3 k8s.io/component-base v0.34.3 - k8s.io/utils v0.0.0-20250820121507-0af2bda4dd1d + k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 sigs.k8s.io/controller-runtime v0.22.5 sigs.k8s.io/gateway-api v1.4.1 - sigs.k8s.io/gateway-api-inference-extension v1.3.0 + sigs.k8s.io/gateway-api-inference-extension v0.0.0-20260128235548-fd30cb97714a ) require ( @@ -46,7 +46,7 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/emicklei/go-restful/v3 v3.13.0 // indirect github.com/envoyproxy/go-control-plane/envoy v1.36.0 // indirect - github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect + github.com/envoyproxy/protoc-gen-validate v1.3.0 // indirect github.com/evanphx/json-patch/v5 v5.9.11 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect @@ -54,23 +54,32 @@ require ( github.com/go-errors/errors v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/zapr v1.3.0 // indirect - github.com/go-openapi/jsonpointer v0.21.2 // indirect - github.com/go-openapi/jsonreference v0.21.0 // indirect - github.com/go-openapi/swag v0.23.1 // indirect + github.com/go-openapi/jsonpointer v0.22.1 // indirect + github.com/go-openapi/jsonreference v0.21.3 // indirect + github.com/go-openapi/swag v0.25.4 // indirect + 
github.com/go-openapi/swag/cmdutils v0.25.4 // indirect + github.com/go-openapi/swag/conv v0.25.4 // indirect + github.com/go-openapi/swag/fileutils v0.25.4 // indirect + github.com/go-openapi/swag/jsonname v0.25.4 // indirect + github.com/go-openapi/swag/jsonutils v0.25.4 // indirect + github.com/go-openapi/swag/loading v0.25.4 // indirect + github.com/go-openapi/swag/mangling v0.25.4 // indirect + github.com/go-openapi/swag/netutils v0.25.4 // indirect + github.com/go-openapi/swag/stringutils v0.25.4 // indirect + github.com/go-openapi/swag/typeutils v0.25.4 // indirect + github.com/go-openapi/swag/yamlutils v0.25.4 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/google/btree v1.1.3 // indirect github.com/google/cel-go v0.26.0 // indirect github.com/google/gnostic-models v0.7.0 // indirect - github.com/google/pprof v0.0.0-20250923004556-9e5a51aed1e8 // indirect + github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f // indirect github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect github.com/grafana/regexp v0.0.0-20250905093917-f7b3be9d1853 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/kylelemons/godebug v1.1.0 // indirect - github.com/mailru/easyjson v0.9.0 // indirect github.com/moby/spdystream v0.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect @@ -83,7 +92,7 @@ require ( github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.67.5 // indirect github.com/prometheus/procfs v0.17.0 // indirect - github.com/prometheus/prometheus v0.308.1 // indirect + github.com/prometheus/prometheus v0.309.1 // indirect 
github.com/redis/go-redis/v9 v9.11.0 // indirect github.com/spf13/cobra v1.9.1 // indirect github.com/spf13/pflag v1.0.10 // indirect @@ -97,7 +106,7 @@ require ( github.com/x448/float16 v0.8.4 // indirect github.com/xlab/treeprint v1.2.0 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 // indirect go.opentelemetry.io/otel v1.39.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0 // indirect @@ -118,10 +127,10 @@ require ( golang.org/x/sys v0.39.0 // indirect golang.org/x/term v0.38.0 // indirect golang.org/x/text v0.32.0 // indirect - golang.org/x/time v0.13.0 // indirect + golang.org/x/time v0.14.0 // indirect golang.org/x/tools v0.39.0 // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20251213004720-97cd9d5aeac2 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect google.golang.org/protobuf v1.36.11 // indirect gopkg.in/evanphx/json-patch.v4 v4.13.0 // indirect diff --git a/go.sum b/go.sum index 91f7b75a6..98149f336 100644 --- a/go.sum +++ b/go.sum @@ -6,14 +6,14 @@ cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIi cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.19.1 h1:5YTBM8QDVIBN3sxBil89WfdAAqDZbyJTgh688DSxX5w= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.19.1/go.mod 
h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.12.0 h1:wL5IEG5zb7BVv1Kv0Xm92orq+5hB5Nipn3B5tn4Rqfk= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.12.0/go.mod h1:J7MUC/wtRpfGVbQ5sIItY5/FuVWmvzlY21WAOfQnq/I= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 h1:JXg2dwJUmPB9JmtVmdEB16APJ7jurfbY5jnfXpJoRMc= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0= github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA= github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI= -github.com/AzureAD/microsoft-authentication-library-for-go v1.5.0 h1:XkkQbfMyuH2jTSjQjSoihryI8GINRcs4xp8lNawg0FI= -github.com/AzureAD/microsoft-authentication-library-for-go v1.5.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1Xbatp0= github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/alecthomas/units v0.0.0-20240927000941-0f3dac36c52b h1:mimo19zliBX/vSQ6PWWSL9lK8qwHozUj03+zLoEB8O0= @@ -24,32 +24,34 @@ github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8 github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= 
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= -github.com/aws/aws-sdk-go-v2 v1.39.6 h1:2JrPCVgWJm7bm83BDwY5z8ietmeJUbh3O2ACnn+Xsqk= -github.com/aws/aws-sdk-go-v2 v1.39.6/go.mod h1:c9pm7VwuW0UPxAEYGyTmyurVcNrbF6Rt/wixFqDhcjE= -github.com/aws/aws-sdk-go-v2/config v1.31.17 h1:QFl8lL6RgakNK86vusim14P2k8BFSxjvUkcWLDjgz9Y= -github.com/aws/aws-sdk-go-v2/config v1.31.17/go.mod h1:V8P7ILjp/Uef/aX8TjGk6OHZN6IKPM5YW6S78QnRD5c= -github.com/aws/aws-sdk-go-v2/credentials v1.18.21 h1:56HGpsgnmD+2/KpG0ikvvR8+3v3COCwaF4r+oWwOeNA= -github.com/aws/aws-sdk-go-v2/credentials v1.18.21/go.mod h1:3YELwedmQbw7cXNaII2Wywd+YY58AmLPwX4LzARgmmA= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.13 h1:T1brd5dR3/fzNFAQch/iBKeX07/ffu/cLu+q+RuzEWk= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.13/go.mod h1:Peg/GBAQ6JDt+RoBf4meB1wylmAipb7Kg2ZFakZTlwk= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.13 h1:a+8/MLcWlIxo1lF9xaGt3J/u3yOZx+CdSveSNwjhD40= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.13/go.mod h1:oGnKwIYZ4XttyU2JWxFrwvhF6YKiK/9/wmE3v3Iu9K8= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.13 h1:HBSI2kDkMdWz4ZM7FjwE7e/pWDEZ+nR95x8Ztet1ooY= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.13/go.mod h1:YE94ZoDArI7awZqJzBAZ3PDD2zSfuP7w6P2knOzIn8M= +github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= +github.com/aws/aws-sdk-go-v2 v1.41.0/go.mod h1:MayyLB8y+buD9hZqkCW3kX1AKq07Y5pXxtgB+rRFhz0= +github.com/aws/aws-sdk-go-v2/config v1.32.6 h1:hFLBGUKjmLAekvi1evLi5hVvFQtSo3GYwi+Bx4lpJf8= +github.com/aws/aws-sdk-go-v2/config v1.32.6/go.mod h1:lcUL/gcd8WyjCrMnxez5OXkO3/rwcNmvfno62tnXNcI= +github.com/aws/aws-sdk-go-v2/credentials v1.19.6 h1:F9vWao2TwjV2MyiyVS+duza0NIRtAslgLUM0vTA1ZaE= +github.com/aws/aws-sdk-go-v2/credentials v1.19.6/go.mod h1:SgHzKjEVsdQr6Opor0ihgWtkWdfRAIwxYzSJ8O85VHY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 
h1:80+uETIWS1BqjnN9uJ0dBUaETh+P1XwFy5vwHwK5r9k= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16/go.mod h1:wOOsYuxYuB/7FlnVtzeBYRcjSRtQpAW0hCP7tIULMwo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 h1:rgGwPzb82iBYSvHMHXc8h9mRoOUBZIGFgKb9qniaZZc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16/go.mod h1:L/UxsGeKpGoIj6DxfhOWHWQ/kGKcd4I1VncE4++IyKA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 h1:1jtGzuV7c82xnqOVfx2F0xmJcOw5374L7N6juGW6x6U= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16/go.mod h1:M2E5OQf+XLe+SZGmmpaI2yy+J326aFf6/+54PoxSANc= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3 h1:x2Ibm/Af8Fi+BH+Hsn9TXGdT+hKbDd5XOTZxTMxDk7o= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3/go.mod h1:IW1jwyrQgMdhisceG8fQLmQIydcT/jWY21rFhzgaKwo= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13 h1:kDqdFvMY4AtKoACfzIGD8A0+hbT41KTKF//gq7jITfM= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13/go.mod h1:lmKuogqSU3HzQCwZ9ZtcqOc5XGMqtDK7OIc2+DxiUEg= -github.com/aws/aws-sdk-go-v2/service/sso v1.30.1 h1:0JPwLz1J+5lEOfy/g0SURC9cxhbQ1lIMHMa+AHZSzz0= -github.com/aws/aws-sdk-go-v2/service/sso v1.30.1/go.mod h1:fKvyjJcz63iL/ftA6RaM8sRCtN4r4zl4tjL3qw5ec7k= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.5 h1:OWs0/j2UYR5LOGi88sD5/lhN6TDLG6SfA7CqsQO9zF0= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.5/go.mod h1:klO+ejMvYsB4QATfEOIXk8WAEwN4N0aBfJpvC+5SZBo= -github.com/aws/aws-sdk-go-v2/service/sts v1.39.1 h1:mLlUgHn02ue8whiR4BmxxGJLR2gwU6s6ZzJ5wDamBUs= -github.com/aws/aws-sdk-go-v2/service/sts v1.39.1/go.mod h1:E19xDjpzPZC7LS2knI9E6BaRFDK43Eul7vd6rSq2HWk= -github.com/aws/smithy-go v1.23.2 h1:Crv0eatJUQhaManss33hS5r40CG3ZFH+21XSkqMrIUM= 
-github.com/aws/smithy-go v1.23.2/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 h1:0ryTNEdJbzUCEWkVXEXoqlXV72J5keC1GvILMOuD00E= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4/go.mod h1:HQ4qwNZh32C3CBeO6iJLQlgtMzqeG17ziAA/3KDJFow= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16 h1:oHjJHeUy0ImIV0bsrX0X91GkV5nJAyv1l1CC9lnO0TI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16/go.mod h1:iRSNGgOYmiYwSCXxXaKb9HfOEj40+oTKn8pTxMlYkRM= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.4 h1:HpI7aMmJ+mm1wkSHIA2t5EaFFv5EFYXePW30p1EIrbQ= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.4/go.mod h1:C5RdGMYGlfM0gYq/tifqgn4EbyX99V15P2V3R+VHbQU= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.8 h1:aM/Q24rIlS3bRAhTyFurowU8A0SMyGDtEOY/l/s/1Uw= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.8/go.mod h1:+fWt2UHSb4kS7Pu8y+BMBvJF0EWx+4H0hzNwtDNRTrg= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12 h1:AHDr0DaHIAo8c9t1emrzAlVDFp+iMMKnPdYy6XO4MCE= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12/go.mod h1:GQ73XawFFiWxyWXMHWfhiomvP3tXtdNar/fi8z18sx0= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 h1:SciGFVNZ4mHdm7gpD1dgZYnCuVdX1s+lFTg4+4DOy70= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5/go.mod h1:iW40X4QBmUxdP+fZNOpfmkdMZqsovezbAeO+Ubiv2pk= +github.com/aws/smithy-go v1.24.0 h1:LpilSUItNPFr1eY85RYgTIg5eIEPtvFbskaFcmmIUnk= +github.com/aws/smithy-go v1.24.0/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/bboreham/go-loser v0.0.0-20230920113527-fcc2c21820a3 h1:6df1vn4bBlDDo4tARvBm7l6KA9iVMnE3NWizDeWSrps= github.com/bboreham/go-loser v0.0.0-20230920113527-fcc2c21820a3/go.mod h1:CIWtjkly68+yqLPbvwwR/fjNJA/idrtULjZWh2v1ys0= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -87,8 +89,8 @@ github.com/emicklei/go-restful/v3 v3.13.0 h1:C4Bl2xDndpU6nJ4bc1jXd+uTmYPVUwkD6bF 
github.com/emicklei/go-restful/v3 v3.13.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/envoyproxy/go-control-plane/envoy v1.36.0 h1:yg/JjO5E7ubRyKX3m07GF3reDNEnfOboJ0QySbH736g= github.com/envoyproxy/go-control-plane/envoy v1.36.0/go.mod h1:ty89S1YCCVruQAm9OtKeEkQLTb+Lkz0k8v9W0Oxsv98= -github.com/envoyproxy/protoc-gen-validate v1.2.1 h1:DEo3O99U8j4hBFwbJfrz9VtgcDfUKS7KJ7spH3d86P8= -github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2TmLfsJ31lvEjwamM4DxlWXU= +github.com/envoyproxy/protoc-gen-validate v1.3.0 h1:TvGH1wof4H33rezVKWSpqKz5NXWg5VPuZ0uONDT6eb4= +github.com/envoyproxy/protoc-gen-validate v1.3.0/go.mod h1:HvYl7zwPa5mffgyeTUHA9zHIH36nmrm7oCbo4YKoSWA= github.com/evanphx/json-patch v0.5.2 h1:xVCHIVMUu1wtM/VkR9jVZ45N3FhZfYMMYGorLCR8P3k= github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ= github.com/evanphx/json-patch/v5 v5.9.11 h1:/8HVnzMq13/3x9TPvjG08wUGqBTmZBsCWzjTM0wiaDU= @@ -114,12 +116,40 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-logr/zapr v1.3.0 h1:XGdV8XW8zdwFiwOA2Dryh1gj2KRQyOOoNmBy4EplIcQ= github.com/go-logr/zapr v1.3.0/go.mod h1:YKepepNBd1u/oyhd/yQmtjVXmm9uML4IXUgMOwR8/Gg= -github.com/go-openapi/jsonpointer v0.21.2 h1:AqQaNADVwq/VnkCmQg6ogE+M3FOsKTytwges0JdwVuA= -github.com/go-openapi/jsonpointer v0.21.2/go.mod h1:50I1STOfbY1ycR8jGz8DaMeLCdXiI6aDteEdRNNzpdk= -github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ= -github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4= -github.com/go-openapi/swag v0.23.1 h1:lpsStH0n2ittzTnbaSloVZLuB5+fvSY/+hnagBjSNZU= -github.com/go-openapi/swag v0.23.1/go.mod h1:STZs8TbRvEQQKUA+JZNAm3EWlgaOBGpyFDqQnDHMef0= +github.com/go-openapi/jsonpointer v0.22.1 h1:sHYI1He3b9NqJ4wXLoJDKmUmHkWy/L7rtEo92JUxBNk= 
+github.com/go-openapi/jsonpointer v0.22.1/go.mod h1:pQT9OsLkfz1yWoMgYFy4x3U5GY5nUlsOn1qSBH5MkCM= +github.com/go-openapi/jsonreference v0.21.3 h1:96Dn+MRPa0nYAR8DR1E03SblB5FJvh7W6krPI0Z7qMc= +github.com/go-openapi/jsonreference v0.21.3/go.mod h1:RqkUP0MrLf37HqxZxrIAtTWW4ZJIK1VzduhXYBEeGc4= +github.com/go-openapi/swag v0.25.4 h1:OyUPUFYDPDBMkqyxOTkqDYFnrhuhi9NR6QVUvIochMU= +github.com/go-openapi/swag v0.25.4/go.mod h1:zNfJ9WZABGHCFg2RnY0S4IOkAcVTzJ6z2Bi+Q4i6qFQ= +github.com/go-openapi/swag/cmdutils v0.25.4 h1:8rYhB5n6WawR192/BfUu2iVlxqVR9aRgGJP6WaBoW+4= +github.com/go-openapi/swag/cmdutils v0.25.4/go.mod h1:pdae/AFo6WxLl5L0rq87eRzVPm/XRHM3MoYgRMvG4A0= +github.com/go-openapi/swag/conv v0.25.4 h1:/Dd7p0LZXczgUcC/Ikm1+YqVzkEeCc9LnOWjfkpkfe4= +github.com/go-openapi/swag/conv v0.25.4/go.mod h1:3LXfie/lwoAv0NHoEuY1hjoFAYkvlqI/Bn5EQDD3PPU= +github.com/go-openapi/swag/fileutils v0.25.4 h1:2oI0XNW5y6UWZTC7vAxC8hmsK/tOkWXHJQH4lKjqw+Y= +github.com/go-openapi/swag/fileutils v0.25.4/go.mod h1:cdOT/PKbwcysVQ9Tpr0q20lQKH7MGhOEb6EwmHOirUk= +github.com/go-openapi/swag/jsonname v0.25.4 h1:bZH0+MsS03MbnwBXYhuTttMOqk+5KcQ9869Vye1bNHI= +github.com/go-openapi/swag/jsonname v0.25.4/go.mod h1:GPVEk9CWVhNvWhZgrnvRA6utbAltopbKwDu8mXNUMag= +github.com/go-openapi/swag/jsonutils v0.25.4 h1:VSchfbGhD4UTf4vCdR2F4TLBdLwHyUDTd1/q4i+jGZA= +github.com/go-openapi/swag/jsonutils v0.25.4/go.mod h1:7OYGXpvVFPn4PpaSdPHJBtF0iGnbEaTk8AvBkoWnaAY= +github.com/go-openapi/swag/jsonutils/fixtures_test v0.25.4 h1:IACsSvBhiNJwlDix7wq39SS2Fh7lUOCJRmx/4SN4sVo= +github.com/go-openapi/swag/jsonutils/fixtures_test v0.25.4/go.mod h1:Mt0Ost9l3cUzVv4OEZG+WSeoHwjWLnarzMePNDAOBiM= +github.com/go-openapi/swag/loading v0.25.4 h1:jN4MvLj0X6yhCDduRsxDDw1aHe+ZWoLjW+9ZQWIKn2s= +github.com/go-openapi/swag/loading v0.25.4/go.mod h1:rpUM1ZiyEP9+mNLIQUdMiD7dCETXvkkC30z53i+ftTE= +github.com/go-openapi/swag/mangling v0.25.4 h1:2b9kBJk9JvPgxr36V23FxJLdwBrpijI26Bx5JH4Hp48= +github.com/go-openapi/swag/mangling v0.25.4/go.mod 
h1:6dxwu6QyORHpIIApsdZgb6wBk/DPU15MdyYj/ikn0Hg= +github.com/go-openapi/swag/netutils v0.25.4 h1:Gqe6K71bGRb3ZQLusdI8p/y1KLgV4M/k+/HzVSqT8H0= +github.com/go-openapi/swag/netutils v0.25.4/go.mod h1:m2W8dtdaoX7oj9rEttLyTeEFFEBvnAx9qHd5nJEBzYg= +github.com/go-openapi/swag/stringutils v0.25.4 h1:O6dU1Rd8bej4HPA3/CLPciNBBDwZj9HiEpdVsb8B5A8= +github.com/go-openapi/swag/stringutils v0.25.4/go.mod h1:GTsRvhJW5xM5gkgiFe0fV3PUlFm0dr8vki6/VSRaZK0= +github.com/go-openapi/swag/typeutils v0.25.4 h1:1/fbZOUN472NTc39zpa+YGHn3jzHWhv42wAJSN91wRw= +github.com/go-openapi/swag/typeutils v0.25.4/go.mod h1:Ou7g//Wx8tTLS9vG0UmzfCsjZjKhpjxayRKTHXf2pTE= +github.com/go-openapi/swag/yamlutils v0.25.4 h1:6jdaeSItEUb7ioS9lFoCZ65Cne1/RZtPBZ9A56h92Sw= +github.com/go-openapi/swag/yamlutils v0.25.4/go.mod h1:MNzq1ulQu+yd8Kl7wPOut/YHAAU/H6hL91fF+E2RFwc= +github.com/go-openapi/testify/enable/yaml/v2 v2.0.2 h1:0+Y41Pz1NkbTHz8NngxTuAXxEodtNSI1WG1c/m5Akw4= +github.com/go-openapi/testify/enable/yaml/v2 v2.0.2/go.mod h1:kme83333GCtJQHXQ8UKX3IBZu6z8T5Dvy5+CW3NLUUg= +github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6Ub6wls= +github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54= github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= @@ -143,14 +173,14 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/pprof v0.0.0-20250923004556-9e5a51aed1e8 h1:ZI8gCoCjGzPsum4L21jHdQs8shFBIQih1TM9Rd/c+EQ= -github.com/google/pprof 
v0.0.0-20250923004556-9e5a51aed1e8/go.mod h1:I6V7YzU0XDpsHqbsyrghnFZLO1gwK6NPTNvmetQIk9U= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f h1:HU1RgM6NALf/KW9HEY6zry3ADbDKcmpQ+hJedoNGQYQ= +github.com/google/pprof v0.0.0-20251213031049-b05bdaca462f/go.mod h1:67FPmZWbr+KDT/VlpWtw6sO9XSjpJmLuHpoLmWiTGgY= github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU9uHLo7OnF5tL52HFAgMmyrf4= -github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= +github.com/googleapis/enterprise-certificate-proxy v0.3.7 h1:zrn2Ee/nWmHulBx5sAVrGgAa0f2/R35S4DJwfFaUPFQ= +github.com/googleapis/enterprise-certificate-proxy v0.3.7/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo= github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc= github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 h1:JeSE6pjso5THxAzdVpqr6/geYxZytqFMBCOtn/ujyeo= @@ -165,8 +195,6 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jellydator/ttlcache/v3 v3.4.0 h1:YS4P125qQS0tNhtL6aeYkheEaB/m8HCqdMMP4mnWdTY= github.com/jellydator/ttlcache/v3 v3.4.0/go.mod h1:Hw9EgjymziQD3yGsQdf1FqFdpp7YjFMd4Srg5EJlgD4= -github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= -github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/joshdk/go-junit v1.0.0 
h1:S86cUKIdwBHWwA6xCmFlf3RTLfVXYQfvanM5Uh+K6GE= github.com/joshdk/go-junit v1.0.0/go.mod h1:TiiV0PqkaNfFXjEiyjWM3XXrhVyCa1K4Zfga6W52ung= github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA= @@ -183,12 +211,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/llm-d/llm-d-kv-cache v0.4.1-0.20260121180456-e3fafddd09f4 h1:3LHSnDQ2tLsSIbh4BgN+7RYz/Wi+KjvIigcxVHb3mkE= -github.com/llm-d/llm-d-kv-cache v0.4.1-0.20260121180456-e3fafddd09f4/go.mod h1:XyhzHBYeOWamBMPkuRySB5nJ0zzQpK/mbuXKqJRFT6A= -github.com/llm-d/llm-d-kv-cache v0.5.0-RC1 h1:qlZeZw43CsvO8XnDaaZEKVWUJRRwb9AEwl4OcV833Bc= -github.com/llm-d/llm-d-kv-cache v0.5.0-RC1/go.mod h1:XyhzHBYeOWamBMPkuRySB5nJ0zzQpK/mbuXKqJRFT6A= -github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= -github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/llm-d/llm-d-kv-cache v0.5.0-rc1 h1:UkJZU8hGRdZKPeCiXnuGjLivqIS6yeFAl9pv4QDQcWY= +github.com/llm-d/llm-d-kv-cache v0.5.0-rc1/go.mod h1:XyhzHBYeOWamBMPkuRySB5nJ0zzQpK/mbuXKqJRFT6A= github.com/maruel/natural v1.1.1 h1:Hja7XhhmvEFhcByqDoHz9QZbkWey+COd9xWfCfn1ioo= github.com/maruel/natural v1.1.1/go.mod h1:v+Rfd79xlw1AgVBjbO0BEQmptqb5HvL/k9GRHB7ZKEg= github.com/mfridman/tparse v0.18.0 h1:wh6dzOKaIwkUGyKgOntDW4liXSo37qg5AXbIhkMV3vE= @@ -241,8 +265,8 @@ github.com/prometheus/otlptranslator v1.0.0 h1:s0LJW/iN9dkIH+EnhiD3BlkkP5QVIUVEo github.com/prometheus/otlptranslator v1.0.0/go.mod h1:vRYWnXvI6aWGpsdY/mOT/cbeVRBlPWtBNDb7kGR3uKM= github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7DuK0= github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw= 
-github.com/prometheus/prometheus v0.308.1 h1:ApMNI/3/es3Ze90Z7CMb+wwU2BsSYur0m5VKeqHj7h4= -github.com/prometheus/prometheus v0.308.1/go.mod h1:aHjYCDz9zKRyoUXvMWvu13K9XHOkBB12XrEqibs3e0A= +github.com/prometheus/prometheus v0.309.1 h1:jutK6eCYDpWdPTUbVbkcQsNCMO9CCkSwjQRMLds4jSo= +github.com/prometheus/prometheus v0.309.1/go.mod h1:d+dOGiVhuNDa4MaFXHVdnUBy/CzqlcNTooR8oM1wdTU= github.com/prometheus/sigv4 v0.3.0 h1:QIG7nTbu0JTnNidGI1Uwl5AGVIChWUACxn2B/BQ1kms= github.com/prometheus/sigv4 v0.3.0/go.mod h1:fKtFYDus2M43CWKMNtGvFNHGXnAJJEGZbiYCmVp/F8I= github.com/redis/go-redis/v9 v9.11.0 h1:E3S08Gl/nJNn5vkxd2i78wZxWAPNZgUNTp8WIJUAiIs= @@ -295,8 +319,8 @@ github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0/go.mod h1:h06DGIukJOevXaj/xrNjhi/2098RZzcLTbc0jDAUbsg= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 h1:ssfIgGNANqpVFCndZvcuyKbl0g+UAVcbBcqGkG28H0Y= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0/go.mod h1:GQ/474YrbE4Jx8gZ4q5I4hrhUzM6UPzyrqJYV2AqPoQ= go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 h1:f0cb2XPmrqn4XMy9PNliTgRKJgS5WcL/u0/WRYGz4t0= @@ -362,8 +386,8 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.32.0 
h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= -golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI= -golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= @@ -378,10 +402,10 @@ gomodules.xyz/jsonpatch/v2 v2.4.0 h1:Ci3iUJyx9UeRx7CeFN8ARgGbkESwJK+KB9lLcWxY/Zw gomodules.xyz/jsonpatch/v2 v2.4.0/go.mod h1:AH3dM2RI6uoBZxn3LVrfvJ3E0/9dG4cSrbuBJT4moAY= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= -google.golang.org/api v0.252.0 h1:xfKJeAJaMwb8OC9fesr369rjciQ704AjU/psjkKURSI= -google.golang.org/api v0.252.0/go.mod h1:dnHOv81x5RAmumZ7BWLShB/u7JZNeyalImxHmtTHxqw= -google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls= -google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto= +google.golang.org/api v0.257.0 h1:8Y0lzvHlZps53PEaw+G29SsQIkuKrumGWs9puiexNAA= +google.golang.org/api v0.257.0/go.mod h1:4eJrr+vbVaZSqs7vovFd1Jb/A6ml6iw2e6FBYf3GAO4= +google.golang.org/genproto/googleapis/api v0.0.0-20251213004720-97cd9d5aeac2 h1:7LRqPCEdE4TP4/9psdaB7F2nhZFfBiGJomA5sojLWdU= +google.golang.org/genproto/googleapis/api v0.0.0-20251213004720-97cd9d5aeac2/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto= 
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww= google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc= @@ -414,16 +438,18 @@ k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk= k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE= k8s.io/kube-openapi v0.0.0-20250814151709-d7b6acb124c3 h1:liMHz39T5dJO1aOKHLvwaCjDbf07wVh6yaUlTpunnkE= k8s.io/kube-openapi v0.0.0-20250814151709-d7b6acb124c3/go.mod h1:UZ2yyWbFTpuhSbFhv24aGNOdoRdJZgsIObGBUaYVsts= -k8s.io/utils v0.0.0-20250820121507-0af2bda4dd1d h1:wAhiDyZ4Tdtt7e46e9M5ZSAJ/MnPGPs+Ki1gHw4w1R0= -k8s.io/utils v0.0.0-20250820121507-0af2bda4dd1d/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= +k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 h1:SjGebBtkBqHFOli+05xYbK8YF1Dzkbzn+gDM4X9T4Ck= +k8s.io/utils v0.0.0-20251002143259-bc988d571ff4/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.31.2 h1:jpcvIRr3GLoUoEKRkHKSmGjxb6lWwrBlJsXc+eUYQHM= sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.31.2/go.mod h1:Ve9uj1L+deCXFrPOk1LpFXqTg7LCFzFso6PA48q/XZw= sigs.k8s.io/controller-runtime v0.22.5 h1:v3nfSUMowX/2WMp27J9slwGFyAt7IV0YwBxAkrUr0GE= sigs.k8s.io/controller-runtime v0.22.5/go.mod h1:pc5SoYWnWI6I+cBHYYdZ7B6YHZVY5xNfll88JB+vniI= sigs.k8s.io/gateway-api v1.4.1 h1:NPxFutNkKNa8UfLd2CMlEuhIPMQgDQ6DXNKG9sHbJU8= sigs.k8s.io/gateway-api v1.4.1/go.mod h1:AR5RSqciWP98OPckEjOjh2XJhAe2Na4LHyXD2FUY7Qk= -sigs.k8s.io/gateway-api-inference-extension v1.3.0 h1:Ng2Qs1Oum4WycuWyi3rOkAC7pZ2aDqgN2ku6Lr/mryQ= -sigs.k8s.io/gateway-api-inference-extension v1.3.0/go.mod h1:Cyex0AlEzhuXFklzl0y5Hdf5zVY8PUtSKhzMvHh5D9M= +sigs.k8s.io/gateway-api-inference-extension v0.0.0-20260128073548-aea9ebe8cea3 
h1:sobxO5HxXOd9RdhIUbUP0p+rZyn3ZFJAL6NolaHx1ZQ= +sigs.k8s.io/gateway-api-inference-extension v0.0.0-20260128073548-aea9ebe8cea3/go.mod h1:lvMpB9a+Lk+xBi5Pk6teUG+NqA16WR8nRpmBNFJbflU= +sigs.k8s.io/gateway-api-inference-extension v0.0.0-20260128235548-fd30cb97714a h1:Ce5CZ0R3c5H475uEuJ92FMgux3j99wDrSsI4ivTBEXQ= +sigs.k8s.io/gateway-api-inference-extension v0.0.0-20260128235548-fd30cb97714a/go.mod h1:lvMpB9a+Lk+xBi5Pk6teUG+NqA16WR8nRpmBNFJbflU= sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 h1:IpInykpT6ceI+QxKBbEflcR5EXP7sU1kvOlxwZh5txg= sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730/go.mod h1:mdzfpAEoE6DHQEN0uh9ZbOCuHbLK5wOm7dK4ctXE9Tg= sigs.k8s.io/kustomize/api v0.21.0 h1:I7nry5p8iDJbuRdYS7ez8MUvw7XVNPcIP5GkzzuXIIQ= diff --git a/pkg/plugins/filter/by_label.go b/pkg/plugins/filter/by_label.go index 070b6a627..464bc81a7 100644 --- a/pkg/plugins/filter/by_label.go +++ b/pkg/plugins/filter/by_label.go @@ -5,9 +5,8 @@ import ( "encoding/json" "fmt" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" ) const ( @@ -21,10 +20,10 @@ type byLabelParameters struct { AllowsNoLabel bool `json:"allowsNoLabel"` } -var _ framework.Filter = &ByLabel{} // validate interface conformance +var _ scheduling.Filter = &ByLabel{} // validate interface conformance // ByLabelFactory defines the factory function for the ByLabel filter. 
-func ByLabelFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { +func ByLabelFactory(name string, rawParameters json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) { parameters := byLabelParameters{} if rawParameters != nil { if err := json.Unmarshal(rawParameters, ¶meters); err != nil { @@ -47,7 +46,7 @@ func ByLabelFactory(name string, rawParameters json.RawMessage, _ plugins.Handle // NewByLabel creates and returns an instance of the RoleBasedFilter based on the input parameters // name - the filter name // labelName - the name of the label to use -// allowsNoLabel - if true pods without given label will be considered as valid (not filtered out) +// allowsNoLabel - if true endpoints without given label will be considered as valid (not filtered out) // validValuesApp - list of valid values func NewByLabel(name string, labelName string, allowsNoLabel bool, validValues ...string) *ByLabel { validValuesMap := map[string]struct{}{} @@ -57,27 +56,27 @@ func NewByLabel(name string, labelName string, allowsNoLabel bool, validValues . 
} return &ByLabel{ - typedName: plugins.TypedName{Type: ByLabelType, Name: name}, + typedName: plugin.TypedName{Type: ByLabelType, Name: name}, labelName: labelName, allowsNoLabel: allowsNoLabel, validValues: validValuesMap, } } -// ByLabel - filters out pods based on the values defined by the given label +// ByLabel - filters out endpoints based on the values defined by the given label type ByLabel struct { // name defines the filter typed name - typedName plugins.TypedName + typedName plugin.TypedName // labelName defines the name of the label to be checked labelName string // validValues defines list of valid label values validValues map[string]struct{} - // allowsNoLabel - if true pods without given label will be considered as valid (not filtered out) + // allowsNoLabel - if true endpoints without given label will be considered as valid (not filtered out) allowsNoLabel bool } // TypedName returns the typed name of the plugin -func (f *ByLabel) TypedName() plugins.TypedName { +func (f *ByLabel) TypedName() plugin.TypedName { return f.typedName } @@ -87,19 +86,19 @@ func (f *ByLabel) WithName(name string) *ByLabel { return f } -// Filter filters out all pods that are not marked with one of roles from the validRoles collection +// Filter filters out all endpoints whose label value is not one of the values in the validValues collection // or has no role label in case allowsNoRolesLabel is true -func (f *ByLabel) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod { - filteredPods := []types.Pod{} +func (f *ByLabel) Filter(_ context.Context, _ *scheduling.CycleState, _ *scheduling.LLMRequest, endpoints []scheduling.Endpoint) []scheduling.Endpoint { + filteredEndpoints := []scheduling.Endpoint{} - for _, pod := range pods { - val, labelDefined := pod.GetPod().Labels[f.labelName] + for _, endpoint := range endpoints { + val, labelDefined := endpoint.GetMetadata().Labels[f.labelName] _, valueExists := f.validValues[val] if
(!labelDefined && f.allowsNoLabel) || valueExists { - filteredPods = append(filteredPods, pod) + filteredEndpoints = append(filteredEndpoints, endpoint) } } - return filteredPods + return filteredEndpoints } diff --git a/pkg/plugins/filter/by_label_selector.go b/pkg/plugins/filter/by_label_selector.go index 98b95d418..ceb53a0e3 100644 --- a/pkg/plugins/filter/by_label_selector.go +++ b/pkg/plugins/filter/by_label_selector.go @@ -8,9 +8,8 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" ) const ( @@ -19,10 +18,10 @@ const ( ) // compile-time type assertion -var _ framework.Filter = &ByLabelSelector{} +var _ scheduling.Filter = &ByLabelSelector{} // ByLabelSelectorFactory defines the factory function for the ByLabelSelector filter -func ByLabelSelectorFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { +func ByLabelSelectorFactory(name string, rawParameters json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) { parameters := metav1.LabelSelector{} if rawParameters != nil { if err := json.Unmarshal(rawParameters, ¶meters); err != nil { @@ -44,30 +43,30 @@ func NewByLabelSelector(name string, selector *metav1.LabelSelector) (*ByLabelSe } return &ByLabelSelector{ - typedName: plugins.TypedName{Type: ByLabelSelectorType, Name: name}, + typedName: plugin.TypedName{Type: ByLabelSelectorType, Name: name}, selector: labelSelector, }, nil } -// ByLabelSelector filters out pods that do not match its label selector criteria +// ByLabelSelector filters out endpoints that do not match its label selector criteria 
type ByLabelSelector struct { - typedName plugins.TypedName + typedName plugin.TypedName selector labels.Selector } // TypedName returns the typed name of the plugin -func (blf *ByLabelSelector) TypedName() plugins.TypedName { +func (blf *ByLabelSelector) TypedName() plugin.TypedName { return blf.typedName } -// Filter filters out all pods that do not satisfy the label selector -func (blf *ByLabelSelector) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod { - filtered := []types.Pod{} +// Filter filters out all endpoints that do not satisfy the label selector +func (blf *ByLabelSelector) Filter(_ context.Context, _ *scheduling.CycleState, _ *scheduling.LLMRequest, endpoints []scheduling.Endpoint) []scheduling.Endpoint { + filtered := []scheduling.Endpoint{} - for _, pod := range pods { - labels := labels.Set(pod.GetPod().Labels) + for _, endpoint := range endpoints { + labels := labels.Set(endpoint.GetMetadata().Labels) if blf.selector.Matches(labels) { - filtered = append(filtered, pod) + filtered = append(filtered, endpoint) } } return filtered diff --git a/pkg/plugins/filter/by_label_selector_test.go b/pkg/plugins/filter/by_label_selector_test.go index 3dc16e9d0..0d7947f09 100644 --- a/pkg/plugins/filter/by_label_selector_test.go +++ b/pkg/plugins/filter/by_label_selector_test.go @@ -9,9 +9,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/filter" 
"github.com/llm-d/llm-d-inference-scheduler/test/utils" @@ -147,35 +146,35 @@ func TestByLabelSelectorFactoryWithInvalidJSON(t *testing.T) { } func TestByLabelSelectorFiltering(t *testing.T) { - pods := []types.Pod{ - createPod(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-1"}, + endpoints := []scheduling.Endpoint{ + createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-1"}, "10.0.0.1", map[string]string{ "app": "nginx", "version": "v1.0", "tier": "frontend", }), - createPod(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-2"}, + createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-2"}, "10.0.0.2", map[string]string{ "app": "nginx", "version": "v1.1", "tier": "frontend", }), - createPod(k8stypes.NamespacedName{Namespace: "kube-system", Name: "coredns-1"}, + createEndpoint(k8stypes.NamespacedName{Namespace: "kube-system", Name: "coredns-1"}, "10.0.0.3", map[string]string{ "app": "coredns", "tier": "system", }), - createPod(k8stypes.NamespacedName{Namespace: "default", Name: "redis-1"}, + createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "redis-1"}, "10.0.0.4", map[string]string{ "app": "redis", "tier": "backend", "deprecated": "true", }), - createPod(k8stypes.NamespacedName{Namespace: "default", Name: "web-1"}, + createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "web-1"}, "10.0.0.5", map[string]string{ "app": "web", @@ -301,17 +300,17 @@ func TestByLabelSelectorFiltering(t *testing.T) { ctx := utils.NewTestContext(t) - filteredPods := blf.Filter(ctx, nil, nil, pods) + filteredEndpoints := blf.Filter(ctx, nil, nil, endpoints) - var actualPodNames []string - for _, pod := range filteredPods { - actualPodNames = append(actualPodNames, pod.GetPod().NamespacedName.Name) + var actualEndpointNames []string + for _, endpoint := range filteredEndpoints { + actualEndpointNames = append(actualEndpointNames, endpoint.GetMetadata().NamespacedName.Name) } - assert.ElementsMatch(t, 
tt.expectedPods, actualPodNames, - "filtered pods should match expected pods") - assert.Len(t, filteredPods, len(tt.expectedPods), - "filtered pods count should match expected count") + assert.ElementsMatch(t, tt.expectedPods, actualEndpointNames, + "filtered endpoints should match expected endpoints") + assert.Len(t, filteredEndpoints, len(tt.expectedPods), + "filtered endpoints count should match expected count") }) } } @@ -326,26 +325,26 @@ func TestByLabelSelectorFilterEdgeCases(t *testing.T) { ctx := utils.NewTestContext(t) - t.Run("empty pods slice", func(t *testing.T) { - result := blf.Filter(ctx, nil, nil, []types.Pod{}) + t.Run("empty endpoints slice", func(t *testing.T) { + result := blf.Filter(ctx, nil, nil, []scheduling.Endpoint{}) assert.Empty(t, result) }) - t.Run("nil pods slice", func(t *testing.T) { + t.Run("nil endpoints slice", func(t *testing.T) { result := blf.Filter(ctx, nil, nil, nil) assert.Empty(t, result) }) - t.Run("pods with nil labels", func(t *testing.T) { - pods := []types.Pod{createPod(k8stypes.NamespacedName{Name: "pod-1"}, "10.0.0.1", nil)} - result := blf.Filter(ctx, nil, nil, pods) - assert.Empty(t, result, "pod with nil labels should not match") + t.Run("endpoints with nil labels", func(t *testing.T) { + endpoints := []scheduling.Endpoint{createEndpoint(k8stypes.NamespacedName{Name: "pod-1"}, "10.0.0.1", nil)} + result := blf.Filter(ctx, nil, nil, endpoints) + assert.Empty(t, result, "endpoint with nil labels should not match") }) - t.Run("pods with empty labels", func(t *testing.T) { - pods := []types.Pod{createPod(k8stypes.NamespacedName{Name: "pod-1"}, "10.0.0.1", map[string]string{})} - result := blf.Filter(ctx, nil, nil, pods) - assert.Empty(t, result, "pod with empty labels should not match") + t.Run("endpoints with empty labels", func(t *testing.T) { + endpoints := []scheduling.Endpoint{createEndpoint(k8stypes.NamespacedName{Name: "pod-1"}, "10.0.0.1", map[string]string{})} + result := blf.Filter(ctx, nil, nil, endpoints) 
+ assert.Empty(t, result, "endpoint with empty labels should not match") }) } @@ -371,7 +370,7 @@ func ExamplePrefillDecodeRolesInLWS() { plugin, _ = filter.ByLabelSelectorFactory("prefill-role", prefillWorkerJSON, nil) prefillworker, _ := plugin.(*filter.ByLabelSelector) - pods := []types.Pod{createPod(k8stypes.NamespacedName{Namespace: "default", Name: "vllm"}, + endpoints := []scheduling.Endpoint{createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "vllm"}, "10.0.0.1", map[string]string{ "app.kubernetes.io/component": "vllm-worker", @@ -383,7 +382,7 @@ func ExamplePrefillDecodeRolesInLWS() { name := "" for _, blf := range []*filter.ByLabelSelector{decodeLeader, decodeFollower, prefillworker} { - filtered := PrefillDecodeRolesInLWS(blf, pods) + filtered := PrefillDecodeRolesInLWS(blf, endpoints) if len(filtered) > 0 { name = blf.TypedName().Name } @@ -395,17 +394,18 @@ func ExamplePrefillDecodeRolesInLWS() { } // Helper functions -func createPod(nsn k8stypes.NamespacedName, ipaddr string, labels map[string]string) types.Pod { - return &types.PodMetrics{ - Pod: &backend.Pod{ +func createEndpoint(nsn k8stypes.NamespacedName, ipaddr string, labels map[string]string) scheduling.Endpoint { + return scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: nsn, Address: ipaddr, Labels: labels, }, - MetricsState: &backendmetrics.MetricsState{}, - } + &fwkdl.Metrics{}, + nil, + ) } -func PrefillDecodeRolesInLWS(blf *filter.ByLabelSelector, pods []types.Pod) []types.Pod { - return blf.Filter(context.Background(), nil, nil, pods) +func PrefillDecodeRolesInLWS(blf *filter.ByLabelSelector, endpoints []scheduling.Endpoint) []scheduling.Endpoint { + return blf.Filter(context.Background(), nil, nil, endpoints) } diff --git a/pkg/plugins/filter/by_label_test.go b/pkg/plugins/filter/by_label_test.go index a3af75a4b..933871b61 100644 --- a/pkg/plugins/filter/by_label_test.go +++ b/pkg/plugins/filter/by_label_test.go @@ -8,9 +8,8 @@ import ( 
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) @@ -139,54 +138,55 @@ func TestByLabelFactoryInvalidJSON(t *testing.T) { } // Helper functions -func createPod(nsn k8stypes.NamespacedName, ipaddr string, labels map[string]string) types.Pod { - return &types.PodMetrics{ - Pod: &backend.Pod{ +func createEndpoint(nsn k8stypes.NamespacedName, ipaddr string, labels map[string]string) scheduling.Endpoint { + return scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: nsn, Address: ipaddr, Labels: labels, }, - MetricsState: &backendmetrics.MetricsState{}, - } + &fwkdl.Metrics{}, + nil, + ) } func TestByLabelFiltering(t *testing.T) { - pods := []types.Pod{ - createPod(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-1"}, + endpoints := []scheduling.Endpoint{ + createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-1"}, "10.0.0.1", map[string]string{ "app": "nginx", "version": "v1.0", "tier": "frontend", }), - createPod(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-2"}, + createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-2"}, "10.0.0.2", map[string]string{ "app": "nginx", "version": "v1.1", "tier": "frontend", }), - createPod(k8stypes.NamespacedName{Namespace: "kube-system", Name: "coredns-1"}, + createEndpoint(k8stypes.NamespacedName{Namespace: "kube-system", Name: "coredns-1"}, "10.0.0.3", map[string]string{ "app": "coredns", "tier": "system", }), - 
createPod(k8stypes.NamespacedName{Namespace: "default", Name: "redis-1"}, + createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "redis-1"}, "10.0.0.4", map[string]string{ "app": "redis", "tier": "backend", "deprecated": "true", }), - createPod(k8stypes.NamespacedName{Namespace: "default", Name: "web-1"}, + createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "web-1"}, "10.0.0.5", map[string]string{ "app": "web", "tier": "frontend", "environment": "production", }), - createPod(k8stypes.NamespacedName{Namespace: "default", Name: "no-tier-pod"}, + createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "no-tier-pod"}, "10.0.0.6", map[string]string{ "app": "unknown", @@ -247,17 +247,17 @@ func TestByLabelFiltering(t *testing.T) { ctx := utils.NewTestContext(t) - filteredPods := blf.Filter(ctx, nil, nil, pods) + filteredEndpoints := blf.Filter(ctx, nil, nil, endpoints) - var actualPodNames []string - for _, pod := range filteredPods { - actualPodNames = append(actualPodNames, pod.GetPod().NamespacedName.Name) + var actualEndpointNames []string + for _, endpoint := range filteredEndpoints { + actualEndpointNames = append(actualEndpointNames, endpoint.GetMetadata().NamespacedName.Name) } - assert.ElementsMatch(t, tt.expectedPods, actualPodNames, - "filtered pods should match expected pods") - assert.Len(t, filteredPods, len(tt.expectedPods), - "filtered pods count should match expected count") + assert.ElementsMatch(t, tt.expectedPods, actualEndpointNames, + "filtered endpoints should match expected endpoints") + assert.Len(t, filteredEndpoints, len(tt.expectedPods), + "filtered endpoints count should match expected count") }) } } diff --git a/pkg/plugins/filter/pd_role.go b/pkg/plugins/filter/pd_role.go index cc4cf7448..da96e7893 100644 --- a/pkg/plugins/filter/pd_role.go +++ b/pkg/plugins/filter/pd_role.go @@ -3,7 +3,7 @@ package filter import ( "encoding/json" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" ) const ( @@ -23,7 +23,7 @@ const ( ) // PrefillRoleFactory defines the factory function for the Prefill filter. -func PrefillRoleFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { +func PrefillRoleFactory(name string, _ json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) { return NewPrefillRole().WithName(name), nil } @@ -33,7 +33,7 @@ func NewPrefillRole() *ByLabel { } // DecodeRoleFactory defines the factory function for the Decode filter. -func DecodeRoleFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { +func DecodeRoleFactory(name string, _ json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) { return NewDecodeRole().WithName(name), nil } diff --git a/pkg/plugins/pre-request/pd_prerequest.go b/pkg/plugins/pre-request/pd_prerequest.go index beebbe46c..c77fc700f 100644 --- a/pkg/plugins/pre-request/pd_prerequest.go +++ b/pkg/plugins/pre-request/pd_prerequest.go @@ -7,9 +7,9 @@ import ( "fmt" "net" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "github.com/llm-d/llm-d-inference-scheduler/pkg/common" ) @@ -29,7 +29,7 @@ type prefillHeaderHandlerParameters struct { var _ requestcontrol.PreRequest = &PrefillHeaderHandler{} // PrefillHeaderHandlerFactory defines the factory function for the PrefillHeaderHandler -func PrefillHeaderHandlerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { +func PrefillHeaderHandlerFactory(name string, rawParameters 
json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) { parameters := prefillHeaderHandlerParameters{ PrefillProfile: defaultPrefillProfile, } @@ -44,19 +44,19 @@ func PrefillHeaderHandlerFactory(name string, rawParameters json.RawMessage, _ p // NewPrefillHeaderHandler initializes a new PrefillHeaderHandler and returns its pointer. func NewPrefillHeaderHandler(prefillProfile string) *PrefillHeaderHandler { return &PrefillHeaderHandler{ - typedName: plugins.TypedName{Type: PrefillHeaderHandlerType}, + typedName: plugin.TypedName{Type: PrefillHeaderHandlerType}, prefillProfile: prefillProfile, } } // PrefillHeaderHandler PreRequest plugin type PrefillHeaderHandler struct { - typedName plugins.TypedName + typedName plugin.TypedName prefillProfile string } // TypedName returns the typed name of the plugin. -func (p *PrefillHeaderHandler) TypedName() plugins.TypedName { +func (p *PrefillHeaderHandler) TypedName() plugin.TypedName { return p.typedName } @@ -67,7 +67,7 @@ func (p *PrefillHeaderHandler) WithName(name string) *PrefillHeaderHandler { } // PreRequest wires prefill SchedulerProfile result into a header to indicate prefill worker -func (p *PrefillHeaderHandler) PreRequest(_ context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult) { +func (p *PrefillHeaderHandler) PreRequest(_ context.Context, request *scheduling.LLMRequest, schedulingResult *scheduling.SchedulingResult) { if _, found := request.Headers[common.PrefillPodHeader]; found { request.Headers[common.PrefillPodHeader] = "" // clear header, if already set } @@ -77,7 +77,7 @@ func (p *PrefillHeaderHandler) PreRequest(_ context.Context, request *types.LLMR return // prefill profile failed to run or we chose not to run it, no-op in this case } - targetPod := prefillProfileRunResult.TargetPods[0].GetPod() + targetPod := prefillProfileRunResult.TargetEndpoints[0].GetMetadata() prefillHostPort := net.JoinHostPort(targetPod.Address, targetPod.Port) 
request.Headers[common.PrefillPodHeader] = prefillHostPort // in the form of } diff --git a/pkg/plugins/profile/dp_profile_handler.go b/pkg/plugins/profile/dp_profile_handler.go index ff6555a56..836872584 100644 --- a/pkg/plugins/profile/dp_profile_handler.go +++ b/pkg/plugins/profile/dp_profile_handler.go @@ -9,9 +9,8 @@ import ( "strconv" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "github.com/llm-d/llm-d-inference-scheduler/pkg/common" ) @@ -26,10 +25,10 @@ type dataParallelProfileHandlerParameters struct { } // compile-time type assertion -var _ framework.ProfileHandler = &DataParallelProfileHandler{} +var _ scheduling.ProfileHandler = &DataParallelProfileHandler{} // DataParallelProfileHandlerFactory defines the factory function for the DataParallelProfileHandler -func DataParallelProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { +func DataParallelProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) { parameters := dataParallelProfileHandlerParameters{ PrimaryPort: 8000, } @@ -51,19 +50,19 @@ func DataParallelProfileHandlerFactory(name string, rawParameters json.RawMessag // NewDataParallelProfileHandler initializes a new PdProfileHandler and returns its pointer. 
func NewDataParallelProfileHandler(primaryPort int) *DataParallelProfileHandler { return &DataParallelProfileHandler{ - typedName: plugins.TypedName{Type: DataParallelProfileHandlerType}, + typedName: plugin.TypedName{Type: DataParallelProfileHandlerType}, primaryPort: strconv.Itoa(primaryPort), } } // DataParallelProfileHandler handles scheduler profiles for Data Parallel. type DataParallelProfileHandler struct { - typedName plugins.TypedName + typedName plugin.TypedName primaryPort string } // TypedName returns the typed name of the plugin. -func (h *DataParallelProfileHandler) TypedName() plugins.TypedName { +func (h *DataParallelProfileHandler) TypedName() plugin.TypedName { return h.typedName } @@ -75,17 +74,17 @@ func (h *DataParallelProfileHandler) WithName(name string) *DataParallelProfileH // Pick selects the SchedulingProfiles to run from the list of candidate profiles, while taking into consideration the request properties and the // previously executed cycles along with their results. 
-func (h *DataParallelProfileHandler) Pick(ctx context.Context, _ *types.CycleState, _ *types.LLMRequest, profiles map[string]*framework.SchedulerProfile, - profileResults map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile { +func (h *DataParallelProfileHandler) Pick(ctx context.Context, _ *scheduling.CycleState, _ *scheduling.LLMRequest, profiles map[string]scheduling.SchedulerProfile, + profileResults map[string]*scheduling.ProfileRunResult) map[string]scheduling.SchedulerProfile { if len(profiles) == len(profileResults) { // all profiles have been executed already in previous call - return map[string]*framework.SchedulerProfile{} + return map[string]scheduling.SchedulerProfile{} } // Validate that only one profile is configured for Data Parallel mode if len(profiles) != 1 { log.FromContext(ctx).Error(nil, "Data Parallel profile handler requires exactly one scheduling profile", "profileCount", len(profiles), ) - return map[string]*framework.SchedulerProfile{} // return empty map for fast exit in later steps + return map[string]scheduling.SchedulerProfile{} // return empty map for fast exit in later steps } // return only one profile return profiles @@ -95,8 +94,8 @@ func (h *DataParallelProfileHandler) Pick(ctx context.Context, _ *types.CycleSta // It may aggregate results, log test profile outputs, or apply custom logic. It specifies in the SchedulingResult the // key of the primary profile that should be used to get the request selected destination. // When a profile run fails, its result in the profileResults map is nil. 
-func (h *DataParallelProfileHandler) ProcessResults(_ context.Context, _ *types.CycleState, request *types.LLMRequest, - profileResults map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) { +func (h *DataParallelProfileHandler) ProcessResults(_ context.Context, _ *scheduling.CycleState, request *scheduling.LLMRequest, + profileResults map[string]*scheduling.ProfileRunResult) (*scheduling.SchedulingResult, error) { if len(profileResults) != 1 { return nil, errors.New("data parallel profile handler is intended to be used with a single profile, failed to process multiple profiles") } @@ -112,23 +111,23 @@ func (h *DataParallelProfileHandler) ProcessResults(_ context.Context, _ *types. return nil, fmt.Errorf("failed to run scheduler profile '%s'", singleProfileName) } - newResult := types.ProfileRunResult{ - TargetPods: []types.Pod{}, + newResult := scheduling.ProfileRunResult{ + TargetEndpoints: []scheduling.Endpoint{}, } - targetPod := profileResult.TargetPods[0].GetPod() + targetPod := profileResult.TargetEndpoints[0].GetMetadata() request.Headers[common.DataParallelPodHeader] = net.JoinHostPort(targetPod.Address, targetPod.Port) - for _, target := range profileResult.TargetPods { - newPodInfo := target.GetPod().Clone() - newPodInfo.Port = h.primaryPort - targetPod := &types.PodMetrics{Pod: newPodInfo, MetricsState: target.GetMetrics().Clone()} - newResult.TargetPods = append(newResult.TargetPods, targetPod) + for _, target := range profileResult.TargetEndpoints { + newMetadata := target.GetMetadata().Clone() + newMetadata.Port = h.primaryPort + targetEndpoint := scheduling.NewEndpoint(newMetadata, target.GetMetrics().Clone(), nil) + newResult.TargetEndpoints = append(newResult.TargetEndpoints, targetEndpoint) } - modifiedResults := map[string]*types.ProfileRunResult{singleProfileName: &newResult} + modifiedResults := map[string]*scheduling.ProfileRunResult{singleProfileName: &newResult} - return &types.SchedulingResult{ + return 
&scheduling.SchedulingResult{ ProfileResults: modifiedResults, PrimaryProfileName: singleProfileName, }, nil diff --git a/pkg/plugins/profile/dp_profile_handler_test.go b/pkg/plugins/profile/dp_profile_handler_test.go index 49ee1c58a..9bb942327 100644 --- a/pkg/plugins/profile/dp_profile_handler_test.go +++ b/pkg/plugins/profile/dp_profile_handler_test.go @@ -8,8 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "github.com/llm-d/llm-d-inference-scheduler/pkg/common" ) @@ -124,28 +123,28 @@ func TestDataParallelProfileHandlerFactoryInvalidJSON(t *testing.T) { func Test_DataParallelProfileHandler_Pick(t *testing.T) { tests := []struct { name string - profiles map[string]*framework.SchedulerProfile - profileResults map[string]*types.ProfileRunResult + profiles map[string]scheduling.SchedulerProfile + profileResults map[string]*scheduling.ProfileRunResult expectEmptyResult bool expectLogError bool description string }{ { name: "success: single profile, first call", - profiles: map[string]*framework.SchedulerProfile{ - "default": {}, + profiles: map[string]scheduling.SchedulerProfile{ + "default": newMockSchedulerProfile(), }, - profileResults: map[string]*types.ProfileRunResult{}, + profileResults: map[string]*scheduling.ProfileRunResult{}, expectEmptyResult: false, expectLogError: false, description: "Should return the single profile to run", }, { name: "success: single profile, second call (all already executed)", - profiles: map[string]*framework.SchedulerProfile{ - "default": {}, + profiles: map[string]scheduling.SchedulerProfile{ + "default": newMockSchedulerProfile(), }, - profileResults: map[string]*types.ProfileRunResult{ + profileResults: map[string]*scheduling.ProfileRunResult{ 
"default": newMockProfileRunResult(DefaultTestPodPort, "pod1"), }, expectEmptyResult: true, @@ -154,19 +153,19 @@ func Test_DataParallelProfileHandler_Pick(t *testing.T) { }, { name: "error: multiple profiles configured in EPP", - profiles: map[string]*framework.SchedulerProfile{ - "profile1": {}, - "profile2": {}, + profiles: map[string]scheduling.SchedulerProfile{ + "profile1": newMockSchedulerProfile(), + "profile2": newMockSchedulerProfile(), }, - profileResults: map[string]*types.ProfileRunResult{}, + profileResults: map[string]*scheduling.ProfileRunResult{}, expectEmptyResult: true, expectLogError: true, description: "Should return empty map and log error for multiple profiles", }, { name: "error: zero profiles configured in EPP", - profiles: map[string]*framework.SchedulerProfile{}, - profileResults: map[string]*types.ProfileRunResult{}, + profiles: map[string]scheduling.SchedulerProfile{}, + profileResults: map[string]*scheduling.ProfileRunResult{}, expectEmptyResult: true, expectLogError: true, description: "Should return empty map and log error for zero profiles", @@ -178,7 +177,7 @@ func Test_DataParallelProfileHandler_Pick(t *testing.T) { handler := NewDataParallelProfileHandler(8000).WithName("test-handler") ctx := context.Background() - result := handler.Pick(ctx, &types.CycleState{}, &types.LLMRequest{}, tt.profiles, tt.profileResults) + result := handler.Pick(ctx, &scheduling.CycleState{}, &scheduling.LLMRequest{}, tt.profiles, tt.profileResults) if tt.expectEmptyResult { assert.Empty(t, result, tt.description) @@ -194,13 +193,13 @@ func Test_DataParallelProfileHandler_ProcessResults(t *testing.T) { tests := []struct { name string primaryPort int - profileResults map[string]*types.ProfileRunResult + profileResults map[string]*scheduling.ProfileRunResult expectError bool - checkResult func(*testing.T, *types.SchedulingResult, map[string]string) + checkResult func(*testing.T, *scheduling.SchedulingResult, map[string]string) }{ { name: "error: multiple 
profiles not supported", - profileResults: map[string]*types.ProfileRunResult{ + profileResults: map[string]*scheduling.ProfileRunResult{ "profile1": newMockProfileRunResult(DefaultTestPodPort, "pod1"), "profile2": newMockProfileRunResult(DefaultTestPodPort, "pod2"), }, @@ -208,7 +207,7 @@ func Test_DataParallelProfileHandler_ProcessResults(t *testing.T) { }, { name: "error: single profile but result is nil", - profileResults: map[string]*types.ProfileRunResult{ + profileResults: map[string]*scheduling.ProfileRunResult{ "nil-profile": nil, }, expectError: true, @@ -216,16 +215,16 @@ func Test_DataParallelProfileHandler_ProcessResults(t *testing.T) { { name: "success: single profile with primaryPort → port overridden, header set", primaryPort: 9000, - profileResults: map[string]*types.ProfileRunResult{ + profileResults: map[string]*scheduling.ProfileRunResult{ "dp-profile": newMockProfileRunResult(DefaultTestPodPort, "pod1"), }, expectError: false, - checkResult: func(t *testing.T, res *types.SchedulingResult, headers map[string]string) { + checkResult: func(t *testing.T, res *scheduling.SchedulingResult, headers map[string]string) { assert.Equal(t, "dp-profile", res.PrimaryProfileName) - pods := res.ProfileResults["dp-profile"].TargetPods + pods := res.ProfileResults["dp-profile"].TargetEndpoints require.Len(t, pods, 1) - assert.Equal(t, "9000", pods[0].GetPod().Port) // overridden + assert.Equal(t, "9000", pods[0].GetMetadata().Port) // overridden expectedHeader := net.JoinHostPort("10.0.0.1", DefaultTestPodPort) // original assert.Equal(t, expectedHeader, headers[common.DataParallelPodHeader]) }, @@ -233,28 +232,28 @@ func Test_DataParallelProfileHandler_ProcessResults(t *testing.T) { { name: "success: primaryPort=0 → port becomes '0'", primaryPort: 0, - profileResults: map[string]*types.ProfileRunResult{ + profileResults: map[string]*scheduling.ProfileRunResult{ "dp": newMockProfileRunResult("8080", "pod1"), }, expectError: false, - checkResult: func(t 
*testing.T, res *types.SchedulingResult, headers map[string]string) { - pod := res.ProfileResults["dp"].TargetPods[0] - assert.Equal(t, "0", pod.GetPod().Port) + checkResult: func(t *testing.T, res *scheduling.SchedulingResult, headers map[string]string) { + pod := res.ProfileResults["dp"].TargetEndpoints[0] + assert.Equal(t, "0", pod.GetMetadata().Port) assert.Equal(t, "10.0.0.1:8080", headers[common.DataParallelPodHeader]) }, }, { name: "success: multiple target pods → all ports overridden", primaryPort: 8080, - profileResults: map[string]*types.ProfileRunResult{ + profileResults: map[string]*scheduling.ProfileRunResult{ "dp-profile": newMockProfileRunResult(DefaultTestPodPort, "pod1", "pod2"), }, expectError: false, - checkResult: func(t *testing.T, res *types.SchedulingResult, headers map[string]string) { - pods := res.ProfileResults["dp-profile"].TargetPods + checkResult: func(t *testing.T, res *scheduling.SchedulingResult, headers map[string]string) { + pods := res.ProfileResults["dp-profile"].TargetEndpoints assert.Len(t, pods, 2) for _, p := range pods { - assert.Equal(t, "8080", p.GetPod().Port) + assert.Equal(t, "8080", p.GetMetadata().Port) } assert.Equal(t, net.JoinHostPort("10.0.0.1", DefaultTestPodPort), headers[common.DataParallelPodHeader]) }, @@ -265,8 +264,8 @@ func Test_DataParallelProfileHandler_ProcessResults(t *testing.T) { t.Run(tt.name, func(t *testing.T) { handler := NewDataParallelProfileHandler(tt.primaryPort).WithName("test-handler") headers := make(map[string]string) - req := &types.LLMRequest{Headers: headers} - result, err := handler.ProcessResults(context.Background(), &types.CycleState{}, req, tt.profileResults) + req := &scheduling.LLMRequest{Headers: headers} + result, err := handler.ProcessResults(context.Background(), &scheduling.CycleState{}, req, tt.profileResults) if tt.expectError { assert.Error(t, err) diff --git a/pkg/plugins/profile/pd_profile_handler.go b/pkg/plugins/profile/pd_profile_handler.go index 
8dff33e43..4b86b471f 100644 --- a/pkg/plugins/profile/pd_profile_handler.go +++ b/pkg/plugins/profile/pd_profile_handler.go @@ -9,16 +9,15 @@ import ( "net" "strconv" - "github.com/llm-d/llm-d-inference-scheduler/pkg/metrics" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/prefix" "github.com/llm-d/llm-d-inference-scheduler/pkg/common" + "github.com/llm-d/llm-d-inference-scheduler/pkg/metrics" ) const ( @@ -28,6 +26,9 @@ const ( defaultDecodeProfile = "decode" defaultPrefillProfile = "prefill" defaultPrefixPluginType = prefix.PrefixCachePluginType + + // The estimated average number of characters per token, used because the cached request text is not tokenized.
+ averageCharactersPerToken = 4 ) type pdProfileHandlerParameters struct { @@ -41,16 +42,16 @@ type pdProfileHandlerParameters struct { } // compile-time type assertion -var _ framework.ProfileHandler = &PdProfileHandler{} +var _ scheduling.ProfileHandler = &PdProfileHandler{} // PdProfileHandlerFactory defines the factory function for the PdProfileHandler -func PdProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { +func PdProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) { parameters := pdProfileHandlerParameters{ Threshold: 0, DecodeProfile: defaultDecodeProfile, PrefillProfile: defaultPrefillProfile, PrefixPluginType: defaultPrefixPluginType, - HashBlockSize: prefix.DefaultBlockSize, + HashBlockSize: prefix.DefaultBlockSizeTokens * averageCharactersPerToken, PrimaryPort: 0, } if rawParameters != nil { @@ -84,8 +85,8 @@ func PdProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugi // NewPdProfileHandler initializes a new PdProfileHandler and returns its pointer. func NewPdProfileHandler(prefillProfile, decodeProfile, prefixPluginType, prefixPluginName string, pdThreshold, hashBlockSize, primaryPort int) *PdProfileHandler { result := &PdProfileHandler{ - typedName: plugins.TypedName{Type: PdProfileHandlerType}, - prefixPluginTypedName: plugins.TypedName{Type: prefixPluginType, Name: prefixPluginName}, + typedName: plugin.TypedName{Type: PdProfileHandlerType}, + prefixPluginTypedName: plugin.TypedName{Type: prefixPluginType, Name: prefixPluginName}, decodeProfile: decodeProfile, prefillProfile: prefillProfile, pdThreshold: pdThreshold, @@ -100,8 +101,8 @@ func NewPdProfileHandler(prefillProfile, decodeProfile, prefixPluginType, prefix // PdProfileHandler handles scheduler profiles for PD. 
type PdProfileHandler struct { - typedName plugins.TypedName - prefixPluginTypedName plugins.TypedName + typedName plugin.TypedName + prefixPluginTypedName plugin.TypedName decodeProfile string prefillProfile string pdThreshold int @@ -110,7 +111,7 @@ type PdProfileHandler struct { } // TypedName returns the typed name of the plugin. -func (h *PdProfileHandler) TypedName() plugins.TypedName { +func (h *PdProfileHandler) TypedName() plugin.TypedName { return h.typedName } @@ -122,11 +123,11 @@ func (h *PdProfileHandler) WithName(name string) *PdProfileHandler { // Pick selects the SchedulingProfiles to run from the list of candidate profiles, while taking into consideration the request properties and the // previously executed cycles along with their results. -func (h *PdProfileHandler) Pick(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, profiles map[string]*framework.SchedulerProfile, - profileResults map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile { +func (h *PdProfileHandler) Pick(ctx context.Context, cycleState *scheduling.CycleState, request *scheduling.LLMRequest, profiles map[string]scheduling.SchedulerProfile, + profileResults map[string]*scheduling.ProfileRunResult) map[string]scheduling.SchedulerProfile { if _, executed := profileResults[h.decodeProfile]; !executed { // if decode profile was not executed yet, first let the scheduler run the decode profile - return map[string]*framework.SchedulerProfile{ + return map[string]scheduling.SchedulerProfile{ h.decodeProfile: profiles[h.decodeProfile], } } @@ -135,7 +136,7 @@ func (h *PdProfileHandler) Pick(ctx context.Context, cycleState *types.CycleStat // when a profile run fails its result value is nil. we need to check decode result before continuing to prefill // check if all configured profiles have been executed, or if decode failed, no need to run more profiles. 
if len(profiles) == len(profileResults) || profileResults[h.decodeProfile] == nil { - return map[string]*framework.SchedulerProfile{} + return map[string]scheduling.SchedulerProfile{} } if h.pdThreshold > 0 { @@ -150,13 +151,13 @@ func (h *PdProfileHandler) Pick(ctx context.Context, cycleState *types.CycleStat // inspect decode execution result to decide if prefill should run or not. // if the request is short enough, use decode results only and don't run the prefill profile. hitPercentagePrefix := 0.0 // default to 0, meaning no prefix cache hit - prefixState, err := types.ReadCycleStateKey[*prefix.SchedulingContextState](cycleState, plugins.StateKey(h.prefixPluginTypedName.String())) + prefixState, err := scheduling.ReadCycleStateKey[*prefix.SchedulingContextState](cycleState, plugin.StateKey(h.prefixPluginTypedName.String())) if err != nil { log.FromContext(ctx).Error(err, "unable to read prefix state") } else { - decodePod := profileResults[h.decodeProfile].TargetPods[0].GetPod().NamespacedName - hitPrefix := max(prefixState.PrefixCacheServers[prefix.ServerID(decodePod)]-1, 0) // The first hit is always the model name - hitPercentagePrefix = float64(hitPrefix*h.hashBlockSize) / float64(len(userInput)) + decodeEndpoint := profileResults[h.decodeProfile].TargetEndpoints[0].GetMetadata().NamespacedName + hitPrefix := max(prefixState.PrefixCacheServers[prefix.ServerID(decodeEndpoint)]-1, 0) // The first hit is always the model name + hitPercentagePrefix = float64(hitPrefix*h.hashBlockSize*averageCharactersPerToken) / float64(len(userInput)) log.FromContext(ctx).V(logutil.DEBUG).Info("Computed hit percentage for prefix cache", "hitPercentage", hitPercentagePrefix, "promptLength", len(userInput)) } @@ -164,13 +165,13 @@ func (h *PdProfileHandler) Pick(ctx context.Context, cycleState *types.CycleStat if (1.0-hitPercentagePrefix)*float64(len(userInput)) < float64(h.pdThreshold) { log.FromContext(ctx).Info("Non-cached suffix is smaller than threshold, using decode 
profile only", "hitPercentage", hitPercentagePrefix) metrics.RecordPDDecision(request.TargetModel, metrics.DecisionTypeDecodeOnly) - return map[string]*framework.SchedulerProfile{} // do not run prefill + return map[string]scheduling.SchedulerProfile{} // do not run prefill } } metrics.RecordPDDecision(request.TargetModel, metrics.DecisionTypePrefillDecode) // run the prefill profile - return map[string]*framework.SchedulerProfile{ + return map[string]scheduling.SchedulerProfile{ h.prefillProfile: profiles[h.prefillProfile], } } @@ -178,32 +179,32 @@ func (h *PdProfileHandler) Pick(ctx context.Context, cycleState *types.CycleStat // ProcessResults handles the outcome of the profile runs after the selected profiles ran. // In case of an error in any of the profiles, the matching entry in the profileResults will contain nil, to indicate there was // an error while running the profile. -func (h *PdProfileHandler) ProcessResults(_ context.Context, _ *types.CycleState, request *types.LLMRequest, - profileResults map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) { +func (h *PdProfileHandler) ProcessResults(_ context.Context, _ *scheduling.CycleState, request *scheduling.LLMRequest, + profileResults map[string]*scheduling.ProfileRunResult) (*scheduling.SchedulingResult, error) { decodeRunResults := profileResults[h.decodeProfile] if decodeRunResults == nil { // if decode profile failed to run, we should fail return nil, errors.New("failed to find available decode workers") } // otherwise, decode ran successfully - updatedResults := map[string]*types.ProfileRunResult{} + updatedResults := map[string]*scheduling.ProfileRunResult{} // Add decode profile to result if h.primaryPort != "" { // Data Parallel is active - targetPod := decodeRunResults.TargetPods[0].GetPod() + targetPod := decodeRunResults.TargetEndpoints[0].GetMetadata() request.Headers[common.DataParallelPodHeader] = net.JoinHostPort(targetPod.Address, targetPod.Port) - updatedResult := 
types.ProfileRunResult{ - TargetPods: []types.Pod{}, + updatedResult := scheduling.ProfileRunResult{ + TargetEndpoints: []scheduling.Endpoint{}, } - for _, target := range decodeRunResults.TargetPods { - updatedPodInfo := target.GetPod().Clone() + for _, target := range decodeRunResults.TargetEndpoints { + updatedPodInfo := target.GetMetadata().Clone() updatedPodInfo.Port = h.primaryPort - targetPod := &types.PodMetrics{Pod: updatedPodInfo, MetricsState: target.GetMetrics().Clone()} - updatedResult.TargetPods = append(updatedResult.TargetPods, targetPod) + targetEndpoint := scheduling.NewEndpoint(updatedPodInfo, target.GetMetrics().Clone(), nil) + updatedResult.TargetEndpoints = append(updatedResult.TargetEndpoints, targetEndpoint) } updatedResults[h.decodeProfile] = &updatedResult } else { @@ -216,13 +217,13 @@ func (h *PdProfileHandler) ProcessResults(_ context.Context, _ *types.CycleState updatedResults[h.prefillProfile] = prefillRunResult } - return &types.SchedulingResult{ + return &scheduling.SchedulingResult{ PrimaryProfileName: h.decodeProfile, ProfileResults: updatedResults, }, nil } -func getUserInputBytes(request *types.LLMRequest) ([]byte, error) { +func getUserInputBytes(request *scheduling.LLMRequest) ([]byte, error) { if request.Body.Completions != nil { // assumed to be valid if not nil return []byte(request.Body.Completions.Prompt), nil } diff --git a/pkg/plugins/profile/pd_profile_handler_test.go b/pkg/plugins/profile/pd_profile_handler_test.go index 09068104b..3f62fb640 100644 --- a/pkg/plugins/profile/pd_profile_handler_test.go +++ b/pkg/plugins/profile/pd_profile_handler_test.go @@ -8,12 +8,10 @@ import ( "github.com/stretchr/testify/assert" k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/prefix" "github.com/llm-d/llm-d-inference-scheduler/pkg/common" "github.com/llm-d/llm-d-inference-scheduler/test/utils" @@ -180,51 +178,58 @@ func TestPdProfileHandlerFactoryInvalidJSON(t *testing.T) { const DefaultTestPodPort = "8000" -// createPod creates a mock Pod with customizable IP and port. -func createPod(nsn k8stypes.NamespacedName, ipaddr, port string, labels map[string]string) types.Pod { - return &types.PodMetrics{ - Pod: &backend.Pod{ +// createEndpoint creates a mock Endpoint with customizable IP and port. +func createEndpoint(nsn k8stypes.NamespacedName, ipaddr, port string, labels map[string]string) scheduling.Endpoint { + return scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: nsn, Address: ipaddr, Port: port, Labels: labels, }, - MetricsState: &backendmetrics.MetricsState{}, - } + &fwkdl.Metrics{}, + nil, + ) } // newMockProfileRunResult creates a ProfileRunResult with Pods using the given port.
-func newMockProfileRunResult(port string, podNames ...string) *types.ProfileRunResult { - pods := make([]types.Pod, 0, len(podNames)) - for i, name := range podNames { +func newMockProfileRunResult(port string, endpointNames ...string) *scheduling.ProfileRunResult { + endpoints := make([]scheduling.Endpoint, 0, len(endpointNames)) + for i, name := range endpointNames { ip := fmt.Sprintf("10.0.0.%d", i+1) - pods = append(pods, createPod( + endpoints = append(endpoints, createEndpoint( k8stypes.NamespacedName{Namespace: "default", Name: name}, ip, port, map[string]string{}, )) } - return &types.ProfileRunResult{ - TargetPods: pods, + return &scheduling.ProfileRunResult{ + TargetEndpoints: endpoints, } } -func newMockSchedulerProfile() *framework.SchedulerProfile { - return &framework.SchedulerProfile{} +func newMockSchedulerProfile() scheduling.SchedulerProfile { + return &mockSchedulerProfile{} +} + +type mockSchedulerProfile struct{} + +func (p *mockSchedulerProfile) Run(_ context.Context, _ *scheduling.LLMRequest, _ *scheduling.CycleState, _ []scheduling.Endpoint) (*scheduling.ProfileRunResult, error) { + return nil, nil } func TestPdProfileHandler_Pick(t *testing.T) { ctx := utils.NewTestContext(t) - request := &types.LLMRequest{ - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + request := &scheduling.LLMRequest{ + Body: &scheduling.LLMRequestBody{ + Completions: &scheduling.CompletionsRequest{ Prompt: "hello world", }, }, } - profiles := map[string]*framework.SchedulerProfile{ + profiles := map[string]scheduling.SchedulerProfile{ "decode": newMockSchedulerProfile(), "prefill": newMockSchedulerProfile(), } @@ -235,8 +240,8 @@ func TestPdProfileHandler_Pick(t *testing.T) { hashBlockSize int prefixPluginType string prefixPluginName string - setupPrefixState func(*types.CycleState) - profileResults map[string]*types.ProfileRunResult + setupPrefixState func(*scheduling.CycleState) + profileResults map[string]*scheduling.ProfileRunResult 
expectedProfiles []string }{ { @@ -245,7 +250,7 @@ func TestPdProfileHandler_Pick(t *testing.T) { hashBlockSize: 16, prefixPluginType: prefix.PrefixCachePluginType, prefixPluginName: prefix.PrefixCachePluginType, - profileResults: map[string]*types.ProfileRunResult{}, + profileResults: map[string]*scheduling.ProfileRunResult{}, expectedProfiles: []string{"decode"}, }, { @@ -254,7 +259,7 @@ func TestPdProfileHandler_Pick(t *testing.T) { hashBlockSize: 16, prefixPluginType: prefix.PrefixCachePluginType, prefixPluginName: prefix.PrefixCachePluginType, - profileResults: map[string]*types.ProfileRunResult{ + profileResults: map[string]*scheduling.ProfileRunResult{ "decode": nil, }, expectedProfiles: []string{}, @@ -265,7 +270,7 @@ func TestPdProfileHandler_Pick(t *testing.T) { hashBlockSize: 16, prefixPluginType: prefix.PrefixCachePluginType, prefixPluginName: prefix.PrefixCachePluginType, - profileResults: map[string]*types.ProfileRunResult{ + profileResults: map[string]*scheduling.ProfileRunResult{ "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"), "prefill": newMockProfileRunResult(DefaultTestPodPort, "pod2"), }, @@ -277,16 +282,16 @@ func TestPdProfileHandler_Pick(t *testing.T) { hashBlockSize: 16, prefixPluginType: prefix.PrefixCachePluginType, prefixPluginName: prefix.PrefixCachePluginType, - setupPrefixState: func(cs *types.CycleState) { + setupPrefixState: func(cs *scheduling.CycleState) { state := &prefix.SchedulingContextState{ PrefixCacheServers: map[prefix.ServerID]int{ prefix.ServerID(k8stypes.NamespacedName{Name: "pod1", Namespace: "default"}): 1, }, } - key := plugins.StateKey(fmt.Sprintf("%s/%s", prefix.PrefixCachePluginType, prefix.PrefixCachePluginType)) + key := plugin.StateKey(fmt.Sprintf("%s/%s", prefix.PrefixCachePluginType, prefix.PrefixCachePluginType)) cs.Write(key, state) }, - profileResults: map[string]*types.ProfileRunResult{ + profileResults: map[string]*scheduling.ProfileRunResult{ "decode": 
newMockProfileRunResult(DefaultTestPodPort, "pod1"), }, expectedProfiles: []string{"prefill"}, @@ -297,16 +302,16 @@ func TestPdProfileHandler_Pick(t *testing.T) { hashBlockSize: 16, prefixPluginType: prefix.PrefixCachePluginType, prefixPluginName: prefix.PrefixCachePluginType, - setupPrefixState: func(cs *types.CycleState) { + setupPrefixState: func(cs *scheduling.CycleState) { state := &prefix.SchedulingContextState{ PrefixCacheServers: map[prefix.ServerID]int{ prefix.ServerID(k8stypes.NamespacedName{Name: "pod1", Namespace: "default"}): 5, }, } - key := plugins.StateKey(fmt.Sprintf("%s/%s", prefix.PrefixCachePluginType, prefix.PrefixCachePluginType)) + key := plugin.StateKey(fmt.Sprintf("%s/%s", prefix.PrefixCachePluginType, prefix.PrefixCachePluginType)) cs.Write(key, state) }, - profileResults: map[string]*types.ProfileRunResult{ + profileResults: map[string]*scheduling.ProfileRunResult{ "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"), }, expectedProfiles: []string{}, @@ -325,7 +330,7 @@ func TestPdProfileHandler_Pick(t *testing.T) { 0, ).WithName("test-handler") - cs := &types.CycleState{} + cs := &scheduling.CycleState{} if tt.setupPrefixState != nil { tt.setupPrefixState(cs) } @@ -346,13 +351,13 @@ func TestPdProfileHandler_ProcessResults(t *testing.T) { tests := []struct { name string primaryPort int - profileResults map[string]*types.ProfileRunResult + profileResults map[string]*scheduling.ProfileRunResult expectError bool - checkResult func(*testing.T, *types.SchedulingResult, map[string]string) + checkResult func(*testing.T, *scheduling.SchedulingResult, map[string]string) }{ { name: "decode failed → error", - profileResults: map[string]*types.ProfileRunResult{ + profileResults: map[string]*scheduling.ProfileRunResult{ "decode": nil, }, expectError: true, @@ -360,28 +365,28 @@ func TestPdProfileHandler_ProcessResults(t *testing.T) { { name: "decode success, no prefill, no primaryPort", primaryPort: 0, - profileResults: 
map[string]*types.ProfileRunResult{ + profileResults: map[string]*scheduling.ProfileRunResult{ "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"), }, expectError: false, - checkResult: func(t *testing.T, res *types.SchedulingResult, headers map[string]string) { + checkResult: func(t *testing.T, res *scheduling.SchedulingResult, headers map[string]string) { assert.Equal(t, "decode", res.PrimaryProfileName) assert.Contains(t, res.ProfileResults, "decode") assert.NotContains(t, res.ProfileResults, "prefill") - pod := res.ProfileResults["decode"].TargetPods[0].GetPod() - assert.Equal(t, DefaultTestPodPort, pod.Port) + metadata := res.ProfileResults["decode"].TargetEndpoints[0].GetMetadata() + assert.Equal(t, DefaultTestPodPort, metadata.Port) assert.Empty(t, headers[common.DataParallelPodHeader]) }, }, { name: "decode success, with prefill", primaryPort: 0, - profileResults: map[string]*types.ProfileRunResult{ + profileResults: map[string]*scheduling.ProfileRunResult{ "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"), "prefill": newMockProfileRunResult(DefaultTestPodPort, "pod2"), }, expectError: false, - checkResult: func(t *testing.T, res *types.SchedulingResult, _ map[string]string) { + checkResult: func(t *testing.T, res *scheduling.SchedulingResult, _ map[string]string) { assert.Equal(t, "decode", res.PrimaryProfileName) assert.Contains(t, res.ProfileResults, "decode") assert.Contains(t, res.ProfileResults, "prefill") @@ -390,13 +395,13 @@ func TestPdProfileHandler_ProcessResults(t *testing.T) { { name: "with primaryPort → port updated and header set", primaryPort: 9000, - profileResults: map[string]*types.ProfileRunResult{ + profileResults: map[string]*scheduling.ProfileRunResult{ "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"), }, expectError: false, - checkResult: func(t *testing.T, res *types.SchedulingResult, headers map[string]string) { - pod := res.ProfileResults["decode"].TargetPods[0].GetPod() - assert.Equal(t, "9000", 
pod.Port) + checkResult: func(t *testing.T, res *scheduling.SchedulingResult, headers map[string]string) { + metadata := res.ProfileResults["decode"].TargetEndpoints[0].GetMetadata() + assert.Equal(t, "9000", metadata.Port) hostPort := headers[common.DataParallelPodHeader] assert.Equal(t, "10.0.0.1:8000", hostPort) @@ -412,15 +417,15 @@ func TestPdProfileHandler_ProcessResults(t *testing.T) { prefix.PrefixCachePluginType, prefix.PrefixCachePluginType, 0, - prefix.DefaultBlockSize, + prefix.DefaultBlockSizeTokens*averageCharactersPerToken, tt.primaryPort, ).WithName("test-handler") headers := make(map[string]string) - req := &types.LLMRequest{ + req := &scheduling.LLMRequest{ Headers: headers, } - result, err := handler.ProcessResults(context.Background(), &types.CycleState{}, req, tt.profileResults) + result, err := handler.ProcessResults(context.Background(), &scheduling.CycleState{}, req, tt.profileResults) if tt.expectError { assert.Error(t, err) diff --git a/pkg/plugins/register.go b/pkg/plugins/register.go index f78c4afc7..08bb7b0b3 100644 --- a/pkg/plugins/register.go +++ b/pkg/plugins/register.go @@ -5,21 +5,21 @@ import ( prerequest "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/pre-request" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/profile" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" ) // RegisterAllPlugins registers the factory functions of all plugins in this repository. 
func RegisterAllPlugins() { - plugins.Register(filter.ByLabelType, filter.ByLabelFactory) - plugins.Register(filter.ByLabelSelectorType, filter.ByLabelSelectorFactory) - plugins.Register(filter.DecodeRoleType, filter.DecodeRoleFactory) - plugins.Register(filter.PrefillRoleType, filter.PrefillRoleFactory) - plugins.Register(prerequest.PrefillHeaderHandlerType, prerequest.PrefillHeaderHandlerFactory) - plugins.Register(profile.DataParallelProfileHandlerType, profile.DataParallelProfileHandlerFactory) - plugins.Register(profile.PdProfileHandlerType, profile.PdProfileHandlerFactory) - plugins.Register(scorer.PrecisePrefixCachePluginType, scorer.PrecisePrefixCachePluginFactory) - plugins.Register(scorer.LoadAwareType, scorer.LoadAwareFactory) - plugins.Register(scorer.SessionAffinityType, scorer.SessionAffinityFactory) - plugins.Register(scorer.ActiveRequestType, scorer.ActiveRequestFactory) - plugins.Register(scorer.NoHitLRUType, scorer.NoHitLRUFactory) + plugin.Register(filter.ByLabelType, filter.ByLabelFactory) + plugin.Register(filter.ByLabelSelectorType, filter.ByLabelSelectorFactory) + plugin.Register(filter.DecodeRoleType, filter.DecodeRoleFactory) + plugin.Register(filter.PrefillRoleType, filter.PrefillRoleFactory) + plugin.Register(prerequest.PrefillHeaderHandlerType, prerequest.PrefillHeaderHandlerFactory) + plugin.Register(profile.DataParallelProfileHandlerType, profile.DataParallelProfileHandlerFactory) + plugin.Register(profile.PdProfileHandlerType, profile.PdProfileHandlerFactory) + plugin.Register(scorer.PrecisePrefixCachePluginType, scorer.PrecisePrefixCachePluginFactory) + plugin.Register(scorer.LoadAwareType, scorer.LoadAwareFactory) + plugin.Register(scorer.SessionAffinityType, scorer.SessionAffinityFactory) + plugin.Register(scorer.ActiveRequestType, scorer.ActiveRequestFactory) + plugin.Register(scorer.NoHitLRUType, scorer.NoHitLRUFactory) } diff --git a/pkg/plugins/scorer/active_request.go b/pkg/plugins/scorer/active_request.go index 
2b0a0d559..33aaff853 100644 --- a/pkg/plugins/scorer/active_request.go +++ b/pkg/plugins/scorer/active_request.go @@ -10,12 +10,11 @@ import ( "github.com/jellydator/ttlcache/v3" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" ) const ( @@ -49,12 +48,12 @@ func (r requestEntry) String() string { } // compile-time type assertion -var _ framework.Scorer = &ActiveRequest{} +var _ scheduling.Scorer = &ActiveRequest{} var _ requestcontrol.PreRequest = &ActiveRequest{} var _ requestcontrol.ResponseComplete = &ActiveRequest{} // ActiveRequestFactory defines the factory function for the ActiveRequest scorer. 
-func ActiveRequestFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) { +func ActiveRequestFactory(name string, rawParameters json.RawMessage, handle plugin.Handle) (plugin.Plugin, error) { parameters := ActiveRequestParameters{} if rawParameters != nil { if err := json.Unmarshal(rawParameters, ¶meters); err != nil { @@ -87,19 +86,19 @@ func NewActiveRequest(ctx context.Context, params *ActiveRequestParameters) *Act ) scorer := &ActiveRequest{ - typedName: plugins.TypedName{Type: ActiveRequestType}, - requestCache: requestCache, - podCounts: make(map[string]int), - mutex: &sync.RWMutex{}, + typedName: plugin.TypedName{Type: ActiveRequestType}, + requestCache: requestCache, + endpointCounts: make(map[string]int), + mutex: &sync.RWMutex{}, } // callback to decrement count when requests expire // most requests will be removed in ResponseComplete, but this ensures - // that we don't leak pod counts if ResponseComplete is not called + // that we don't leak endpoint counts if ResponseComplete is not called requestCache.OnEviction(func(_ context.Context, reason ttlcache.EvictionReason, item *ttlcache.Item[string, *requestEntry]) { if reason == ttlcache.EvictionReasonExpired { - for _, podName := range item.Value().PodNames { - scorer.decrementPodCount(podName) + for _, endpointName := range item.Value().PodNames { + scorer.decrementPodCount(endpointName) } } }) @@ -110,20 +109,20 @@ func NewActiveRequest(ctx context.Context, params *ActiveRequestParameters) *Act } // ActiveRequest keeps track of individual requests being served -// per pod. +// per endpoint. 
type ActiveRequest struct { - typedName plugins.TypedName + typedName plugin.TypedName - // requestCache stores individual request entries with unique composite keys (podName.requestID) + // requestCache stores individual request entries with unique composite keys (endpointName.requestID) requestCache *ttlcache.Cache[string, *requestEntry] - // podCounts maintains fast lookup for request counts per pod - podCounts map[string]int - mutex *sync.RWMutex + // endpointCounts maintains fast lookup for request counts per endpoint + endpointCounts map[string]int + mutex *sync.RWMutex } // TypedName returns the typed name of the plugin. -func (s *ActiveRequest) TypedName() plugins.TypedName { +func (s *ActiveRequest) TypedName() plugin.TypedName { return s.typedName } @@ -133,78 +132,83 @@ func (s *ActiveRequest) WithName(name string) *ActiveRequest { return s } -// Score scores the given pods based on the number of active requests -// being served by each pod. The score is normalized to a range of 0-1. -func (s *ActiveRequest) Score(ctx context.Context, _ *types.CycleState, _ *types.LLMRequest, - pods []types.Pod) map[types.Pod]float64 { - scoredPods := make(map[string]int) +// Category returns the preference the scorer applies when scoring candidate endpoints. +func (s *ActiveRequest) Category() scheduling.ScorerCategory { + return scheduling.Distribution +} + +// Score scores the given endpoints based on the number of active requests +// being served by each endpoint. The score is normalized to a range of 0-1. 
+func (s *ActiveRequest) Score(ctx context.Context, _ *scheduling.CycleState, _ *scheduling.LLMRequest, + endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 { + scoredEndpoints := make(map[string]int) maxCount := 0 s.mutex.RLock() - for podName, count := range s.podCounts { - scoredPods[podName] = count + for endpointName, count := range s.endpointCounts { + scoredEndpoints[endpointName] = count if count >= maxCount { maxCount = count } } s.mutex.RUnlock() - scoredPodsMap := make(map[types.Pod]float64, len(pods)) - for _, pod := range pods { - podName := pod.GetPod().NamespacedName.String() - if count, exists := scoredPods[podName]; exists { + scoredEndpointsMap := make(map[scheduling.Endpoint]float64, len(endpoints)) + for _, endpoint := range endpoints { + endpointName := endpoint.GetMetadata().NamespacedName.String() + if count, exists := scoredEndpoints[endpointName]; exists { if count == 0 || maxCount == 0 { - scoredPodsMap[pod] = 1.0 // no requests means highest score + scoredEndpointsMap[endpoint] = 1.0 // no requests means highest score } else { - scoredPodsMap[pod] = float64(maxCount-count) / float64(maxCount) + scoredEndpointsMap[endpoint] = float64(maxCount-count) / float64(maxCount) } } else { - scoredPodsMap[pod] = 1.0 + scoredEndpointsMap[endpoint] = 1.0 } } - log.FromContext(ctx).V(logutil.DEBUG).Info("Scored pods", "scores", scoredPodsMap) - return scoredPodsMap + log.FromContext(ctx).V(logutil.DEBUG).Info("Scored endpoints", "scores", scoredEndpointsMap) + return scoredEndpointsMap } -// PreRequest is called before a request is sent to the target pod. +// PreRequest is called before a request is sent to the target endpoint. // It creates a new request entry in the cache with its own TTL and -// increments the pod count for fast lookup. +// increments the endpoint count for fast lookup. 
func (s *ActiveRequest) PreRequest( ctx context.Context, - request *types.LLMRequest, - schedulingResult *types.SchedulingResult, + request *scheduling.LLMRequest, + schedulingResult *scheduling.SchedulingResult, ) { debugLogger := log.FromContext(ctx).V(logutil.DEBUG) - podNames := make([]string, 0, len(schedulingResult.ProfileResults)) + endpointNames := make([]string, 0, len(schedulingResult.ProfileResults)) for profileName, profileResult := range schedulingResult.ProfileResults { - if profileResult == nil || len(profileResult.TargetPods) == 0 { + if profileResult == nil || len(profileResult.TargetEndpoints) == 0 { continue } - podName := profileResult.TargetPods[0].GetPod().NamespacedName.String() - podNames = append(podNames, podName) - s.incrementPodCount(podName) + endpointName := profileResult.TargetEndpoints[0].GetMetadata().NamespacedName.String() + endpointNames = append(endpointNames, endpointName) + s.incrementPodCount(endpointName) debugLogger.Info( "Added request to cache", "requestId", request.RequestId, - "podName", podName, + "endpointName", endpointName, "profileName", profileName, ) } // add to request cache - s.requestCache.Set(request.RequestId, &requestEntry{PodNames: podNames, RequestID: request.RequestId}, 0) // Use default TTL + s.requestCache.Set(request.RequestId, &requestEntry{PodNames: endpointNames, RequestID: request.RequestId}, 0) // Use default TTL } // ResponseComplete is called after a response is sent to the client. // It removes the specific request entry from the cache and decrements -// the pod count. +// the endpoint count. 
func (s *ActiveRequest) ResponseComplete( ctx context.Context, - request *types.LLMRequest, + request *scheduling.LLMRequest, _ *requestcontrol.Response, - targetPod *backend.Pod, + targetPod *datalayer.EndpointMetadata, ) { debugLogger := log.FromContext(ctx).V(logutil.DEBUG).WithName("ActiveRequest.ResponseComplete") if targetPod == nil { @@ -215,8 +219,8 @@ func (s *ActiveRequest) ResponseComplete( if item, found := s.requestCache.GetAndDelete(request.RequestId); found { entry := item.Value() if entry != nil { - for _, podName := range entry.PodNames { - s.decrementPodCount(podName) + for _, endpointName := range entry.PodNames { + s.decrementPodCount(endpointName) } debugLogger.Info("Removed request from cache", "requestEntry", entry.String()) } else { @@ -227,25 +231,25 @@ func (s *ActiveRequest) ResponseComplete( } } -// incrementPodCount increments the request count for a pod. -func (s *ActiveRequest) incrementPodCount(podName string) { +// incrementPodCount increments the request count for a endpoint. +func (s *ActiveRequest) incrementPodCount(endpointName string) { s.mutex.Lock() defer s.mutex.Unlock() - s.podCounts[podName]++ + s.endpointCounts[endpointName]++ } -// decrementPodCount decrements the request count for a pod and removes +// decrementPodCount decrements the request count for a endpoint and removes // the entry if count reaches zero. 
-func (s *ActiveRequest) decrementPodCount(podName string) { +func (s *ActiveRequest) decrementPodCount(endpointName string) { s.mutex.Lock() defer s.mutex.Unlock() - if count, exists := s.podCounts[podName]; exists { + if count, exists := s.endpointCounts[endpointName]; exists { if count <= 1 { - delete(s.podCounts, podName) + delete(s.endpointCounts, endpointName) } else { - s.podCounts[podName] = count - 1 + s.endpointCounts[endpointName] = count - 1 } } } diff --git a/pkg/plugins/scorer/active_request_test.go b/pkg/plugins/scorer/active_request_test.go index 1009aaa59..12ab609aa 100644 --- a/pkg/plugins/scorer/active_request_test.go +++ b/pkg/plugins/scorer/active_request_test.go @@ -7,109 +7,109 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) // Test helper functions -func newTestPod(name string, queueSize int) *types.PodMetrics { - return &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: name, Namespace: "default"}}, - MetricsState: &backendmetrics.MetricsState{ +func newTestEndpoint(name string, queueSize int) scheduling.Endpoint { + return scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: name, Namespace: "default"}}, + &fwkdl.Metrics{ WaitingQueueSize: queueSize, }, - } + nil, + ) } 
-func newTestRequest(id string) *types.LLMRequest { - return &types.LLMRequest{ +func newTestRequest(id string) *scheduling.LLMRequest { + return &scheduling.LLMRequest{ RequestId: id, } } -func newTestSchedulingResult(profilePods map[string]types.Pod) *types.SchedulingResult { - profileResults := make(map[string]*types.ProfileRunResult) - for profile, pod := range profilePods { - profileResults[profile] = &types.ProfileRunResult{ - TargetPods: []types.Pod{pod}, +func newTestSchedulingResult(profileEndpoints map[string]scheduling.Endpoint) *scheduling.SchedulingResult { + profileResults := make(map[string]*scheduling.ProfileRunResult) + for profile, endpoint := range profileEndpoints { + profileResults[profile] = &scheduling.ProfileRunResult{ + TargetEndpoints: []scheduling.Endpoint{endpoint}, } } - return &types.SchedulingResult{ + return &scheduling.SchedulingResult{ ProfileResults: profileResults, } } -func (s *ActiveRequest) getPodCount(podName string) int { +func (s *ActiveRequest) getPodCount(endpointName string) int { s.mutex.RLock() defer s.mutex.RUnlock() - return s.podCounts[podName] + return s.endpointCounts[endpointName] } -func (s *ActiveRequest) hasPodCount(podName string) bool { +func (s *ActiveRequest) hasPodCount(endpointName string) bool { s.mutex.RLock() defer s.mutex.RUnlock() - _, exists := s.podCounts[podName] + _, exists := s.endpointCounts[endpointName] return exists } func TestActiveRequestScorer_Score(t *testing.T) { - podA := newTestPod("pod-a", 2) - podB := newTestPod("pod-b", 0) - podC := newTestPod("pod-c", 15) + endpointA := newTestEndpoint("pod-a", 2) + endpointB := newTestEndpoint("pod-b", 0) + endpointC := newTestEndpoint("pod-c", 15) tests := []struct { name string setupCache func(*ActiveRequest) - input []types.Pod - wantScores map[types.Pod]float64 + input []scheduling.Endpoint + wantScores map[scheduling.Endpoint]float64 }{ { - name: "no pods in cache", + name: "no endpoints in cache", setupCache: func(_ *ActiveRequest) { // 
Cache is empty }, - input: []types.Pod{podA, podB, podC}, - wantScores: map[types.Pod]float64{ - podA: 1, - podB: 1, - podC: 1, + input: []scheduling.Endpoint{endpointA, endpointB, endpointC}, + wantScores: map[scheduling.Endpoint]float64{ + endpointA: 1, + endpointB: 1, + endpointC: 1, }, }, { - name: "all pods in cache with different request counts", + name: "all endpoints in cache with different request counts", setupCache: func(s *ActiveRequest) { s.mutex.Lock() - s.podCounts["default/pod-a"] = 3 - s.podCounts["default/pod-b"] = 0 - s.podCounts["default/pod-c"] = 6 + s.endpointCounts["default/pod-a"] = 3 + s.endpointCounts["default/pod-b"] = 0 + s.endpointCounts["default/pod-c"] = 6 s.mutex.Unlock() }, - input: []types.Pod{podA, podB, podC}, - wantScores: map[types.Pod]float64{ - podA: 0.5, - podB: 1.0, - podC: 0.0, + input: []scheduling.Endpoint{endpointA, endpointB, endpointC}, + wantScores: map[scheduling.Endpoint]float64{ + endpointA: 0.5, + endpointB: 1.0, + endpointC: 0.0, }, }, { - name: "some pods in cache", + name: "some endpoints in cache", setupCache: func(s *ActiveRequest) { s.mutex.Lock() - s.podCounts["default/pod-a"] = 4 - s.podCounts["default/pod-c"] = 1 + s.endpointCounts["default/pod-a"] = 4 + s.endpointCounts["default/pod-c"] = 1 // pod-b not in cache s.mutex.Unlock() }, - input: []types.Pod{podA, podB, podC}, - wantScores: map[types.Pod]float64{ - podA: 0.0, - podB: 1.0, - podC: 0.75, + input: []scheduling.Endpoint{endpointA, endpointB, endpointC}, + wantScores: map[scheduling.Endpoint]float64{ + endpointA: 0.0, + endpointB: 1.0, + endpointC: 0.75, }, }, } @@ -132,35 +132,35 @@ func TestActiveRequestScorer_PreRequest(t *testing.T) { ctx := utils.NewTestContext(t) scorer := NewActiveRequest(ctx, nil) - podA := newTestPod("pod-a", 2) - podB := newTestPod("pod-b", 0) + endpointA := newTestEndpoint("pod-a", 2) + endpointB := newTestEndpoint("pod-b", 0) testProfile := "test-profile" t.Run("First request", func(t *testing.T) { request := 
newTestRequest("test-request-1") - schedulingResult := newTestSchedulingResult(map[string]types.Pod{ - testProfile: podA, + schedulingResult := newTestSchedulingResult(map[string]scheduling.Endpoint{ + testProfile: endpointA, }) scorer.PreRequest(ctx, request, schedulingResult) assert.True(t, scorer.requestCache.Has(request.RequestId), "Expected request to be in cache") - assert.Equal(t, 1, scorer.getPodCount(podA.GetPod().NamespacedName.String())) + assert.Equal(t, 1, scorer.getPodCount(endpointA.GetMetadata().NamespacedName.String())) }) - t.Run("Second request to multiple pods", func(t *testing.T) { + t.Run("Second request to multiple endpoints", func(t *testing.T) { request := newTestRequest("test-request-2") - schedulingResult := newTestSchedulingResult(map[string]types.Pod{ - testProfile: podA, - "prefill": podB, + schedulingResult := newTestSchedulingResult(map[string]scheduling.Endpoint{ + testProfile: endpointA, + "prefill": endpointB, }) scorer.PreRequest(ctx, request, schedulingResult) assert.True(t, scorer.requestCache.Has(request.RequestId), "Expected request to be in cache") - assert.Equal(t, 2, scorer.getPodCount(podA.GetPod().NamespacedName.String())) - assert.Equal(t, 1, scorer.getPodCount(podB.GetPod().NamespacedName.String())) + assert.Equal(t, 2, scorer.getPodCount(endpointA.GetMetadata().NamespacedName.String())) + assert.Equal(t, 1, scorer.getPodCount(endpointB.GetMetadata().NamespacedName.String())) }) } @@ -168,20 +168,20 @@ func TestActiveRequestScorer_ResponseComplete(t *testing.T) { ctx := utils.NewTestContext(t) scorer := NewActiveRequest(ctx, nil) - podA := newTestPod("pod-a", 2) + endpointA := newTestEndpoint("pod-a", 2) request := newTestRequest("test-request-1") // Setup initial state: add request through PreRequest - schedulingResult := newTestSchedulingResult(map[string]types.Pod{ - "test-profile": podA, + schedulingResult := newTestSchedulingResult(map[string]scheduling.Endpoint{ + "test-profile": endpointA, }) 
scorer.PreRequest(ctx, request, schedulingResult) // Call ResponseComplete - scorer.ResponseComplete(ctx, request, &requestcontrol.Response{}, podA.GetPod()) + scorer.ResponseComplete(ctx, request, &requestcontrol.Response{}, endpointA.GetMetadata()) assert.False(t, scorer.requestCache.Has(request.RequestId)) - assert.False(t, scorer.hasPodCount(podA.GetPod().NamespacedName.String()), + assert.False(t, scorer.hasPodCount(endpointA.GetMetadata().NamespacedName.String()), "Pod count should be removed after decrement to zero") } @@ -192,10 +192,10 @@ func TestActiveRequestScorer_TTLExpiration(t *testing.T) { params := &ActiveRequestParameters{RequestTimeout: "1s"} scorer := NewActiveRequest(ctx, params) - podA := newTestPod("pod-a", 0) + endpointA := newTestEndpoint("pod-a", 0) request := newTestRequest("test-request-ttl") - schedulingResult := newTestSchedulingResult(map[string]types.Pod{ - "test-profile": podA, + schedulingResult := newTestSchedulingResult(map[string]scheduling.Endpoint{ + "test-profile": endpointA, }) // Add request @@ -210,9 +210,9 @@ func TestActiveRequestScorer_TTLExpiration(t *testing.T) { // Trigger cleanup scorer.requestCache.DeleteExpired() - // Check that pod count is decremented due to TTL expiration + // Check that endpoint count is decremented due to TTL expiration assert.False(t, scorer.hasPodCount("default/pod-a"), - "Pod should be removed from podCounts after TTL expiration") + "Pod should be removed from endpointCounts after TTL expiration") } func TestNewActiveRequestScorer_InvalidTimeout(t *testing.T) { diff --git a/pkg/plugins/scorer/load_aware.go b/pkg/plugins/scorer/load_aware.go index c4f86d0bb..4f3ef918b 100644 --- a/pkg/plugins/scorer/load_aware.go +++ b/pkg/plugins/scorer/load_aware.go @@ -6,10 +6,9 @@ import ( "fmt" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" ) const ( @@ -25,10 +24,10 @@ type loadAwareParameters struct { } // compile-time type assertion -var _ framework.Scorer = &LoadAware{} +var _ scheduling.Scorer = &LoadAware{} // LoadAwareFactory defines the factory function for the LoadAware -func LoadAwareFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) { +func LoadAwareFactory(name string, rawParameters json.RawMessage, handle plugin.Handle) (plugin.Plugin, error) { parameters := loadAwareParameters{Threshold: QueueThresholdDefault} if rawParameters != nil { if err := json.Unmarshal(rawParameters, ¶meters); err != nil { @@ -47,19 +46,19 @@ func NewLoadAware(ctx context.Context, queueThreshold int) *LoadAware { } return &LoadAware{ - typedName: plugins.TypedName{Type: LoadAwareType}, + typedName: plugin.TypedName{Type: LoadAwareType}, queueThreshold: float64(queueThreshold), } } // LoadAware scorer that is based on load type LoadAware struct { - typedName plugins.TypedName + typedName plugin.TypedName queueThreshold float64 } // TypedName returns the typed name of the plugin. -func (s *LoadAware) TypedName() plugins.TypedName { +func (s *LoadAware) TypedName() plugin.TypedName { return s.typedName } @@ -69,6 +68,11 @@ func (s *LoadAware) WithName(name string) *LoadAware { return s } +// Category returns the preference the scorer applies when scoring candidate endpoints. 
+func (s *LoadAware) Category() scheduling.ScorerCategory { + return scheduling.Distribution +} + // Score scores the given pod in range of 0-1 // Currently metrics contains number of requests waiting in the queue, there is no information about number of requests // that can be processed in the given pod immediately. @@ -76,20 +80,20 @@ func (s *LoadAware) WithName(name string) *LoadAware { // Pod with requests in the queue will get score between 0.5 and 0. // Score 0 will get pod with number of requests in the queue equal to the threshold used in load-based filter // In the future, pods with additional capacity will get score higher than 0.5 -func (s *LoadAware) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { - scoredPods := make(map[types.Pod]float64) +func (s *LoadAware) Score(_ context.Context, _ *scheduling.CycleState, _ *scheduling.LLMRequest, endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 { + scoredEndpoints := make(map[scheduling.Endpoint]float64) - for _, pod := range pods { - waitingRequests := float64(pod.GetMetrics().WaitingQueueSize) + for _, endpoint := range endpoints { + waitingRequests := float64(endpoint.GetMetrics().WaitingQueueSize) if waitingRequests == 0 { - scoredPods[pod] = 0.5 + scoredEndpoints[endpoint] = 0.5 } else { if waitingRequests > s.queueThreshold { waitingRequests = s.queueThreshold } - scoredPods[pod] = 0.5 * (1.0 - (waitingRequests / s.queueThreshold)) + scoredEndpoints[endpoint] = 0.5 * (1.0 - (waitingRequests / s.queueThreshold)) } } - return scoredPods + return scoredEndpoints } diff --git a/pkg/plugins/scorer/load_aware_test.go b/pkg/plugins/scorer/load_aware_test.go index e693e99b5..c3e43a94b 100644 --- a/pkg/plugins/scorer/load_aware_test.go +++ b/pkg/plugins/scorer/load_aware_test.go @@ -6,56 +6,57 @@ import ( "github.com/google/go-cmp/cmp" - k8stypes "k8s.io/apimachinery/pkg/types" - 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + k8stypes "k8s.io/apimachinery/pkg/types" // Import config for thresholds + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer" "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) func TestLoadBasedScorer(t *testing.T) { - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, - MetricsState: &backendmetrics.MetricsState{ + endpointA := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, + &fwkdl.Metrics{ WaitingQueueSize: 2, }, - } - podB := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}}, - MetricsState: &backendmetrics.MetricsState{ + nil, + ) + endpointB := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}}, + &fwkdl.Metrics{ WaitingQueueSize: 0, }, - } - podC := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-c"}}, - MetricsState: &backendmetrics.MetricsState{ + nil, + ) + endpointC := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-c"}}, + &fwkdl.Metrics{ WaitingQueueSize: 15, }, - } + nil, + ) tests := []struct { name string - scorer framework.Scorer - req *types.LLMRequest - input []types.Pod - wantScores map[types.Pod]float64 + scorer scheduling.Scorer + req *scheduling.LLMRequest + input []scheduling.Endpoint + wantScores 
map[scheduling.Endpoint]float64 }{ { name: "load based scorer", scorer: scorer.NewLoadAware(utils.NewTestContext(t), 10), - req: &types.LLMRequest{ + req: &scheduling.LLMRequest{ TargetModel: "critical", }, - input: []types.Pod{ - podA, podB, podC, + input: []scheduling.Endpoint{ + endpointA, endpointB, endpointC, }, - wantScores: map[types.Pod]float64{ - podA: 0.4, - podB: 0.5, - podC: 0, + wantScores: map[scheduling.Endpoint]float64{ + endpointA: 0.4, + endpointB: 0.5, + endpointC: 0, }, }, } diff --git a/pkg/plugins/scorer/no_hit_lru.go b/pkg/plugins/scorer/no_hit_lru.go index a367caa34..65182199a 100644 --- a/pkg/plugins/scorer/no_hit_lru.go +++ b/pkg/plugins/scorer/no_hit_lru.go @@ -7,19 +7,18 @@ import ( lru "github.com/hashicorp/golang-lru/v2" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/prefix" ) const ( // NoHitLRUType is the type of the NoHitLRU scorer NoHitLRUType = "no-hit-lru-scorer" - // defaultLRUSize is the maximum number of pods we'll consider in the cache + // defaultLRUSize is the maximum number of endpoints we'll consider in the cache defaultLRUSize = 1024 // 
defaultPrefillProfile is the name of the prefill profile @@ -30,7 +29,7 @@ const ( ) // compile-time type assertions -var _ framework.Scorer = &NoHitLRU{} +var _ scheduling.Scorer = &NoHitLRU{} var _ requestcontrol.PreRequest = &NoHitLRU{} // NoHitLRUParameters defines the parameters for the NoHitLRU scorer. @@ -42,7 +41,7 @@ type NoHitLRUParameters struct { // Defaults to "prefix-cache-scorer". PrefixPluginName string `json:"prefixPluginName"` - // LRUSize defines the maximum number of pods to track in the LRU cache. + // LRUSize defines the maximum number of endpoints to track in the LRU cache. LRUSize int `json:"lruSize"` } @@ -52,13 +51,13 @@ type coldRequestState struct { isCold bool } -// Clone implements the plugins.StateData interface -func (c *coldRequestState) Clone() plugins.StateData { +// Clone implements the plugin.StateData interface +func (c *coldRequestState) Clone() plugin.StateData { return &coldRequestState{isCold: c.isCold} } // NoHitLRUFactory defines the factory function for the NoHitLRU -func NoHitLRUFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) { +func NoHitLRUFactory(name string, rawParameters json.RawMessage, handle plugin.Handle) (plugin.Plugin, error) { parameters := NoHitLRUParameters{} if rawParameters != nil { if err := json.Unmarshal(rawParameters, ¶meters); err != nil { @@ -101,25 +100,25 @@ func NewNoHitLRU(ctx context.Context, params *NoHitLRUParameters) *NoHitLRU { } return &NoHitLRU{ - typedName: plugins.TypedName{Type: NoHitLRUType}, + typedName: plugin.TypedName{Type: NoHitLRUType}, lruCache: lruCache, - prefixPluginTypedName: plugins.TypedName{Type: prefixPluginType, Name: prefixPluginName}, - pluginState: plugins.NewPluginState(ctx), + prefixPluginTypedName: plugin.TypedName{Type: prefixPluginType, Name: prefixPluginName}, + pluginState: plugin.NewPluginState(ctx), } } -// NoHitLRU scorer that favors pods that were least recently used for cold requests. 
+// NoHitLRU scorer that favors endpoints that were least recently used for cold requests. // This can help evenly distribute cache growth, since cold requests result in more // new KV blocks. type NoHitLRU struct { - typedName plugins.TypedName - lruCache *lru.Cache[string, struct{}] // pod name -> dummy value (we only care about order) - prefixPluginTypedName plugins.TypedName - pluginState *plugins.PluginState + typedName plugin.TypedName + lruCache *lru.Cache[string, struct{}] // endpoint name -> dummy value (we only care about order) + prefixPluginTypedName plugin.TypedName + pluginState *plugin.PluginState } // TypedName returns the typed name of the plugin. -func (s *NoHitLRU) TypedName() plugins.TypedName { +func (s *NoHitLRU) TypedName() plugin.TypedName { return s.typedName } @@ -129,14 +128,19 @@ func (s *NoHitLRU) WithName(name string) *NoHitLRU { return s } +// Category returns the preference the scorer applies when scoring candidate endpoints. +func (s *NoHitLRU) Category() scheduling.ScorerCategory { + return scheduling.Distribution +} + // isColdRequest determines if a request is cold by reading the prefix cache state. // Returns true if no prefix cache hits were found, or if prefix cache state is unavailable. 
-func (s *NoHitLRU) isColdRequest(ctx context.Context, cycleState *types.CycleState) bool { +func (s *NoHitLRU) isColdRequest(ctx context.Context, cycleState *scheduling.CycleState) bool { logger := log.FromContext(ctx).V(logutil.DEBUG) // Read prefix cache state to determine if this is a cold request // This is treated as an optimization - if the state isn't available, we assume cold request - prefixState, err := types.ReadCycleStateKey[*prefix.SchedulingContextState](cycleState, plugins.StateKey(s.prefixPluginTypedName.String())) + prefixState, err := scheduling.ReadCycleStateKey[*prefix.SchedulingContextState](cycleState, plugin.StateKey(s.prefixPluginTypedName.String())) if err != nil { logger.Info("No prefix cache state found, treating as cold request for LRU optimization", "error", err) @@ -147,17 +151,17 @@ func (s *NoHitLRU) isColdRequest(ctx context.Context, cycleState *types.CycleSta return len(prefixState.PrefixCacheServers) == 0 } -// scoreNeutral returns neutral scores (0.5) for all pods. +// scoreNeutral returns neutral scores (0.5) for all endpoints. // Used when a request has cache hits and LRU optimization should not apply. -func (s *NoHitLRU) scoreNeutral(pods []types.Pod) map[types.Pod]float64 { - scoredPods := make(map[types.Pod]float64, len(pods)) - for _, pod := range pods { - scoredPods[pod] = 0.5 +func (s *NoHitLRU) scoreNeutral(endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 { + scoredEndpoints := make(map[scheduling.Endpoint]float64, len(endpoints)) + for _, endpoint := range endpoints { + scoredEndpoints[endpoint] = 0.5 } - return scoredPods + return scoredEndpoints } -// getLRUPositions returns a map of pod names to their LRU position. +// getLRUPositions returns a map of endpoint names to their LRU position. // Position 0 represents the oldest (least recently used) entry. 
func (s *NoHitLRU) getLRUPositions() map[string]int { // Get all keys from LRU cache in order (oldest first) @@ -171,105 +175,105 @@ func (s *NoHitLRU) getLRUPositions() map[string]int { return lruPosition } -// partitionPodsByUsage separates pods into those that have received cold requests +// partitionPodsByUsage separates endpoints into those that have received cold requests // (usedPods) and those that have never received cold requests (neverUsedPods). -func (s *NoHitLRU) partitionPodsByUsage(pods []types.Pod, lruPosition map[string]int) (usedPods, neverUsedPods []types.Pod) { - for _, pod := range pods { - podName := pod.GetPod().NamespacedName.String() - if _, exists := lruPosition[podName]; exists { - usedPods = append(usedPods, pod) +func (s *NoHitLRU) partitionPodsByUsage(endpoints []scheduling.Endpoint, lruPosition map[string]int) (usedEndpoints, neverUsedEndpoints []scheduling.Endpoint) { + for _, endpoint := range endpoints { + endpointName := endpoint.GetMetadata().NamespacedName.String() + if _, exists := lruPosition[endpointName]; exists { + usedEndpoints = append(usedEndpoints, endpoint) } else { - neverUsedPods = append(neverUsedPods, pod) + neverUsedEndpoints = append(neverUsedEndpoints, endpoint) } } - return usedPods, neverUsedPods + return usedEndpoints, neverUsedEndpoints } -// scoreNeverUsedPods assigns scores to pods that have never received a cold request. -// The first never-used pod gets the highest score (1.0), with subsequent pods +// scoreNeverUsedEndpoints assigns scores to endpoints that have never received a cold request. +// The first never-used endpoint gets the highest score (1.0), with subsequent endpoints // receiving progressively lower scores. 
-func (s *NoHitLRU) scoreNeverUsedPods(scoredPods map[types.Pod]float64, neverUsedPods []types.Pod, totalPods int) { +func (s *NoHitLRU) scoreNeverUsedPods(scoredPods map[scheduling.Endpoint]float64, neverUsedPods []scheduling.Endpoint, totalEndpoints int) { // Avoid possibility of dividing by zero. - if totalPods <= 1 { + if totalEndpoints <= 1 { return } - for i, pod := range neverUsedPods { - score := 1.0 - float64(i)/float64(totalPods-1) - scoredPods[pod] = score + for i, endpoint := range neverUsedPods { + score := 1.0 - float64(i)/float64(totalEndpoints-1) + scoredPods[endpoint] = score } } -// scoreUsedPods assigns scores to pods based on their LRU position. +// scoreUsedPods assigns scores to endpoints based on their LRU position. // Pods that were least recently used for cold requests receive higher scores. -func (s *NoHitLRU) scoreUsedPods(scoredPods map[types.Pod]float64, usedPods []types.Pod, lruPosition map[string]int, neverUsedCount, totalPods int) { +func (s *NoHitLRU) scoreUsedPods(scoredEndpoints map[scheduling.Endpoint]float64, usedPods []scheduling.Endpoint, lruPosition map[string]int, neverUsedCount, totalEndpoints int) { // Avoid possibility of dividing by zero. - if totalPods <= 1 { + if totalEndpoints <= 1 { return } - for _, pod := range usedPods { - podName := pod.GetPod().NamespacedName.String() - lruPos := lruPosition[podName] + for _, endpoint := range usedPods { + endpointName := endpoint.GetMetadata().NamespacedName.String() + lruPos := lruPosition[endpointName] // LRU keys are oldest to newest so rank 0 = oldest - // The never used pod count is added to the rank so that - // a never-used pod will always have the highest score. + // The never used endpoint count is added to the rank so that + // a never-used endpoint will always have the highest score. 
rank := neverUsedCount + lruPos - score := 1.0 - float64(rank)/float64(totalPods-1) + score := 1.0 - float64(rank)/float64(totalEndpoints-1) if score < 0 { score = 0 } - scoredPods[pod] = score + scoredEndpoints[endpoint] = score } } -// scoreColdRequestByLRU scores pods based on their LRU position for cold requests. +// scoreColdRequestByLRU scores endpoints based on their LRU position for cold requests. // Pods that have never received a cold request get the highest scores. -// Among previously used pods, least recently used ones get higher scores. -func (s *NoHitLRU) scoreColdRequestByLRU(pods []types.Pod) map[types.Pod]float64 { - scoredPods := make(map[types.Pod]float64, len(pods)) - totalPods := len(pods) +// Among previously used endpoints, least recently used ones get higher scores. +func (s *NoHitLRU) scoreColdRequestByLRU(endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 { + scoredEndpoints := make(map[scheduling.Endpoint]float64, len(endpoints)) + totalEndpoints := len(endpoints) // Avoid possibility of dividing by zero. - if totalPods == 1 { - scoredPods[pods[0]] = 1.0 - return scoredPods + if totalEndpoints == 1 { + scoredEndpoints[endpoints[0]] = 1.0 + return scoredEndpoints } lruPosition := s.getLRUPositions() - usedPods, neverUsedPods := s.partitionPodsByUsage(pods, lruPosition) + usedEndpoints, neverUsedEndpoints := s.partitionPodsByUsage(endpoints, lruPosition) - s.scoreNeverUsedPods(scoredPods, neverUsedPods, totalPods) - s.scoreUsedPods(scoredPods, usedPods, lruPosition, len(neverUsedPods), totalPods) + s.scoreNeverUsedPods(scoredEndpoints, neverUsedEndpoints, totalEndpoints) + s.scoreUsedPods(scoredEndpoints, usedEndpoints, lruPosition, len(neverUsedEndpoints), totalEndpoints) - return scoredPods + return scoredEndpoints } -// Score scores the given pods based on LRU for cold requests. -// For cache hits, returns neutral scores (0.5) for all pods. -// For cache misses, ranks pods by their LRU order. 
-// - LRU ordering is with respect to when a pod last received a cold request. -// - Least recently used (or never used) pods get highest score (1.0) -// - Most recently used pods get lowest score (approaching 0.0) -func (s *NoHitLRU) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { +// Score scores the given endpoints based on LRU for cold requests. +// For cache hits, returns neutral scores (0.5) for all endpoints. +// For cache misses, ranks endpoints by their LRU order. +// - LRU ordering is with respect to when a endpoint last received a cold request. +// - Least recently used (or never used) endpoints get highest score (1.0) +// - Most recently used endpoints get lowest score (approaching 0.0) +func (s *NoHitLRU) Score(ctx context.Context, cycleState *scheduling.CycleState, request *scheduling.LLMRequest, endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 { logger := log.FromContext(ctx).V(logutil.DEBUG) isCold := s.isColdRequest(ctx, cycleState) // Store the cold request state in plugin state for PreRequest to use coldState := &coldRequestState{isCold: isCold} - s.pluginState.Write(request.RequestId, plugins.StateKey(s.typedName.String()), coldState) + s.pluginState.Write(request.RequestId, plugin.StateKey(s.typedName.String()), coldState) if !isCold { logger.Info("Cache hit detected, returning neutral scores") - return s.scoreNeutral(pods) + return s.scoreNeutral(endpoints) } - logger.Info("Cold request detected, scoring pods by LRU") - return s.scoreColdRequestByLRU(pods) + logger.Info("Cold request detected, scoring endpoints by LRU") + return s.scoreColdRequestByLRU(endpoints) } -// PreRequest is called before a request is sent to the target pod. -// For cold requests, it updates the LRU cache to track which pods have been used recently. 
-func (s *NoHitLRU) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult) { +// PreRequest is called before a request is sent to the target endpoint. +// For cold requests, it updates the LRU cache to track which endpoints have been used recently. +func (s *NoHitLRU) PreRequest(ctx context.Context, request *scheduling.LLMRequest, schedulingResult *scheduling.SchedulingResult) { logger := log.FromContext(ctx).V(logutil.DEBUG) if schedulingResult == nil || len(schedulingResult.ProfileResults) == 0 { @@ -278,7 +282,7 @@ func (s *NoHitLRU) PreRequest(ctx context.Context, request *types.LLMRequest, sc } // Read the cold request state we stored in Score - coldState, err := plugins.ReadPluginStateKey[*coldRequestState](s.pluginState, request.RequestId, plugins.StateKey(s.typedName.String())) + coldState, err := plugin.ReadPluginStateKey[*coldRequestState](s.pluginState, request.RequestId, plugin.StateKey(s.typedName.String())) // After fetching the cold state, drop it from the plugin state immediately (otherwise it will hang around until it becomes stale). 
s.pluginState.Delete(request.RequestId) @@ -292,23 +296,23 @@ func (s *NoHitLRU) PreRequest(ctx context.Context, request *types.LLMRequest, sc return } - if targetProfile, ok := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]; ok && targetProfile != nil && len(targetProfile.TargetPods) != 0 { + if targetProfile, ok := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]; ok && targetProfile != nil && len(targetProfile.TargetEndpoints) != 0 { s.moveTargetPodToFront(ctx, request, targetProfile, schedulingResult.PrimaryProfileName) } - if targetProfile, ok := schedulingResult.ProfileResults[defaultPrefillProfile]; ok && targetProfile != nil && len(targetProfile.TargetPods) != 0 { + if targetProfile, ok := schedulingResult.ProfileResults[defaultPrefillProfile]; ok && targetProfile != nil && len(targetProfile.TargetEndpoints) != 0 { s.moveTargetPodToFront(ctx, request, targetProfile, defaultPrefillProfile) } } -func (s *NoHitLRU) moveTargetPodToFront(ctx context.Context, request *types.LLMRequest, targetProfile *types.ProfileRunResult, profileName string) { +func (s *NoHitLRU) moveTargetPodToFront(ctx context.Context, request *scheduling.LLMRequest, targetProfile *scheduling.ProfileRunResult, profileName string) { logger := log.FromContext(ctx).V(logutil.DEBUG) - targetPod := targetProfile.TargetPods[0] - podName := targetPod.GetPod().NamespacedName.String() + targetPod := targetProfile.TargetEndpoints[0] + endpointName := targetPod.GetMetadata().NamespacedName.String() - // Move the pod to the front of the LRU. + // Move the endpoint to the front of the LRU. 
var present struct{} // dummy value - s.lruCache.Add(podName, present) + s.lruCache.Add(endpointName, present) - logger.Info("Updated LRU cache for cold request", "profile", profileName, "pod", podName, "requestId", request.RequestId) + logger.Info("Updated LRU cache for cold request", "profile", profileName, "endpoint", endpointName, "requestId", request.RequestId) } diff --git a/pkg/plugins/scorer/no_hit_lru_test.go b/pkg/plugins/scorer/no_hit_lru_test.go index a03dd3aae..2af62b021 100644 --- a/pkg/plugins/scorer/no_hit_lru_test.go +++ b/pkg/plugins/scorer/no_hit_lru_test.go @@ -9,49 +9,47 @@ import ( "github.com/google/go-cmp/cmp" k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/prefix" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer" "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) -var _ plugins.Handle = &fakeHandle{} +var _ plugin.Handle = &fakeHandle{} type fakeHandle struct { ctx context.Context - plugins map[string]plugins.Plugin + plugins map[string]plugin.Plugin } func newFakeHandle(ctx context.Context) *fakeHandle { - return &fakeHandle{ctx: ctx, plugins: map[string]plugins.Plugin{}} + return &fakeHandle{ctx: ctx, plugins: 
map[string]plugin.Plugin{}} } func (h *fakeHandle) Context() context.Context { return h.ctx } -func (h *fakeHandle) Plugin(name string) plugins.Plugin { +func (h *fakeHandle) Plugin(name string) plugin.Plugin { return h.plugins[name] } -func (h *fakeHandle) AddPlugin(name string, plugin plugins.Plugin) { +func (h *fakeHandle) AddPlugin(name string, plugin plugin.Plugin) { h.plugins[name] = plugin } -func (h *fakeHandle) GetAllPlugins() []plugins.Plugin { - result := make([]plugins.Plugin, 0, len(h.plugins)) +func (h *fakeHandle) GetAllPlugins() []plugin.Plugin { + result := make([]plugin.Plugin, 0, len(h.plugins)) for _, plugin := range h.plugins { result = append(result, plugin) } return result } -func (h *fakeHandle) GetAllPluginsWithNames() map[string]plugins.Plugin { +func (h *fakeHandle) GetAllPluginsWithNames() map[string]plugin.Plugin { return h.plugins } @@ -60,10 +58,10 @@ func (h *fakeHandle) PodList() []k8stypes.NamespacedName { } type stubPlugin struct { - name plugins.TypedName + name plugin.TypedName } -func (p *stubPlugin) TypedName() plugins.TypedName { +func (p *stubPlugin) TypedName() plugin.TypedName { return p.name } @@ -84,7 +82,7 @@ func TestNoHitLRUFactoryDependencyValidation(t *testing.T) { name: "prefix plugin present - should work", handle: func() *fakeHandle { h := newFakeHandle(utils.NewTestContext(t)) - h.AddPlugin(prefix.PrefixCachePluginType, &stubPlugin{name: plugins.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}}) + h.AddPlugin(prefix.PrefixCachePluginType, &stubPlugin{name: plugin.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}}) return h }(), expectError: false, @@ -123,87 +121,90 @@ func TestNoHitLRUFactoryDependencyValidation(t *testing.T) { } func TestNoHitLRUScorer(t *testing.T) { - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - podB := 
&types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - podC := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-c"}}, - MetricsState: &backendmetrics.MetricsState{}, - } + endpointA := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, + &fwkdl.Metrics{}, + nil, + ) + endpointB := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}}, + &fwkdl.Metrics{}, + nil, + ) + endpointC := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-c"}}, + &fwkdl.Metrics{}, + nil, + ) tests := []struct { name string - scorer framework.Scorer - req *types.LLMRequest - input []types.Pod + scorer scheduling.Scorer + req *scheduling.LLMRequest + input []scheduling.Endpoint prefixState *prefix.SchedulingContextState - wantScores map[types.Pod]float64 + wantScores map[scheduling.Endpoint]float64 description string }{ { - name: "cold request - all pods never used", + name: "cold request - all endpoints never used", scorer: scorer.NewNoHitLRU(utils.NewTestContext(t), nil), - req: &types.LLMRequest{ + req: &scheduling.LLMRequest{ TargetModel: "test-model", }, - input: []types.Pod{podA, podB, podC}, + input: []scheduling.Endpoint{endpointA, endpointB, endpointC}, prefixState: &prefix.SchedulingContextState{ PrefixCacheServers: make(map[prefix.ServerID]int), // empty = cold request }, - wantScores: map[types.Pod]float64{ - podA: 1.0, // All never-used pods get high scores - podB: 0.5, - podC: 0.0, + wantScores: map[scheduling.Endpoint]float64{ + endpointA: 1.0, // All never-used endpoints get high scores + endpointB: 0.5, + endpointC: 0.0, }, - description: "Never-used pods should get high scores for cold requests", + description: "Never-used endpoints should get high scores for cold requests", }, { name: 
"cache hit - neutral scores", scorer: scorer.NewNoHitLRU(utils.NewTestContext(t), nil), - req: &types.LLMRequest{ + req: &scheduling.LLMRequest{ TargetModel: "test-model", }, - input: []types.Pod{podA, podB, podC}, + input: []scheduling.Endpoint{endpointA, endpointB, endpointC}, prefixState: &prefix.SchedulingContextState{ PrefixCacheServers: map[prefix.ServerID]int{ {Name: "server1", Namespace: "default"}: 5, // non-empty = cache hit }, }, - wantScores: map[types.Pod]float64{ - podA: 0.5, // All pods get neutral scores for cache hits - podB: 0.5, - podC: 0.5, + wantScores: map[scheduling.Endpoint]float64{ + endpointA: 0.5, // All endpoints get neutral scores for cache hits + endpointB: 0.5, + endpointC: 0.5, }, description: "Cache hits should return neutral scores", }, { - name: "single pod - max score", + name: "single endpoint - max score", scorer: scorer.NewNoHitLRU(utils.NewTestContext(t), nil), - req: &types.LLMRequest{ + req: &scheduling.LLMRequest{ TargetModel: "test-model", }, - input: []types.Pod{podA}, + input: []scheduling.Endpoint{endpointA}, prefixState: &prefix.SchedulingContextState{ PrefixCacheServers: make(map[prefix.ServerID]int), // empty = cold request }, - wantScores: map[types.Pod]float64{ - podA: 1.0, // Single pod gets max score + wantScores: map[scheduling.Endpoint]float64{ + endpointA: 1.0, // Single endpoint gets max score }, - description: "Single pod should get maximum score", + description: "Single endpoint should get maximum score", }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { // Create cycle state and set prefix state - cycleState := &types.CycleState{} + cycleState := &scheduling.CycleState{} if test.prefixState != nil { - cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType, + cycleState.Write(plugin.StateKey(plugin.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}.String()), test.prefixState) } @@ -221,42 +222,44 @@ func 
TestNoHitLRUBasicFunctionality(t *testing.T) { scorer := scorer.NewNoHitLRU(ctx, nil) - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - podB := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}}, - MetricsState: &backendmetrics.MetricsState{}, - } + endpointA := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, + &fwkdl.Metrics{}, + nil, + ) + endpointB := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}}, + &fwkdl.Metrics{}, + nil, + ) - pods := []types.Pod{podA, podB} + endpoints := []scheduling.Endpoint{endpointA, endpointB} // Test basic scoring for cold request (no crashes, returns valid scores) coldPrefixState := &prefix.SchedulingContextState{ PrefixCacheServers: make(map[prefix.ServerID]int), // empty = cold request } - cycleState := &types.CycleState{} - cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType, + cycleState := &scheduling.CycleState{} + cycleState.Write(plugin.StateKey(plugin.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}.String()), coldPrefixState) - scores := scorer.Score(ctx, cycleState, &types.LLMRequest{}, pods) + scores := scorer.Score(ctx, cycleState, &scheduling.LLMRequest{}, endpoints) - // Should return scores for all pods + // Should return scores for all endpoints if len(scores) != 2 { t.Errorf("Expected 2 scores, got %d", len(scores)) } // All scores should be valid (between 0 and 1) - for pod, score := range scores { + for endpoint, score := range scores { if score < 0 || score > 1 { - t.Errorf("Invalid score %f for pod %s", score, pod.GetPod().NamespacedName.String()) + t.Errorf("Invalid score %f for endpoint %s", score, endpoint.GetMetadata().NamespacedName.String()) } } - // For never-used pods, 
should have different scores (to provide ordering) - if scores[podA] == scores[podB] { - t.Errorf("Expected different scores for different pods, both got %f", scores[podA]) + // For never-used endpoints, should have different scores (to provide ordering) + if scores[endpointA] == scores[endpointB] { + t.Errorf("Expected different scores for different endpoints, both got %f", scores[endpointA]) } } @@ -264,16 +267,17 @@ func TestNoPrefixCacheStateFound(t *testing.T) { ctx := utils.NewTestContext(t) scorer := scorer.NewNoHitLRU(ctx, nil) - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - pods := []types.Pod{podA} - cycleState := &types.CycleState{} + endpointA := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, + &fwkdl.Metrics{}, + nil, + ) + endpoints := []scheduling.Endpoint{endpointA} + cycleState := &scheduling.CycleState{} - scores := scorer.Score(ctx, cycleState, &types.LLMRequest{}, pods) + scores := scorer.Score(ctx, cycleState, &scheduling.LLMRequest{}, endpoints) - if scores[podA] != 1.0 { + if scores[endpointA] != 1.0 { t.Errorf("Failure to find a prefix cache should result in scoring as a cold request.") } } @@ -282,120 +286,123 @@ func TestNoHitLRUPreferLeastRecentlyUsedAfterColdRequests(t *testing.T) { ctx := utils.NewTestContext(t) scorer := scorer.NewNoHitLRU(ctx, nil) - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - podB := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-b", Namespace: "default"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - podC := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-c", Namespace: "default"}}, - MetricsState: 
&backendmetrics.MetricsState{}, - } - pods := []types.Pod{podA, podB, podC} + endpointA := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}}, + &fwkdl.Metrics{}, + nil, + ) + endpointB := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-b", Namespace: "default"}}, + &fwkdl.Metrics{}, + nil, + ) + endpointC := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-c", Namespace: "default"}}, + &fwkdl.Metrics{}, + nil, + ) + endpoints := []scheduling.Endpoint{endpointA, endpointB, endpointC} primaryProfile := "primary-profile" - toPrefixState := func(entries map[prefix.ServerID]int) *types.CycleState { - cycle := &types.CycleState{} - cycle.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType, + toPrefixState := func(entries map[prefix.ServerID]int) *scheduling.CycleState { + cycle := &scheduling.CycleState{} + cycle.Write(plugin.StateKey(plugin.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}.String()), &prefix.SchedulingContextState{PrefixCacheServers: entries}) return cycle } - requestToPod := func(target types.Pod) *types.SchedulingResult { - return &types.SchedulingResult{ + requestToEndpoint := func(target scheduling.Endpoint) *scheduling.SchedulingResult { + return &scheduling.SchedulingResult{ PrimaryProfileName: primaryProfile, - ProfileResults: map[string]*types.ProfileRunResult{ + ProfileResults: map[string]*scheduling.ProfileRunResult{ primaryProfile: { - TargetPods: []types.Pod{target}, + TargetEndpoints: []scheduling.Endpoint{target}, }, }, } } // Test LRU behavior indirectly through scoring rather than internal state - assertHighestScoredPod := func(expectedPod types.Pod, testName string) { + assertHighestScoredPod := func(expectedEndpoint scheduling.Endpoint, testName string) { t.Helper() - coldReq := 
&types.LLMRequest{RequestId: testName + "-scoring-check"} - scores := scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReq, pods) + coldReq := &scheduling.LLMRequest{RequestId: testName + "-scoring-check"} + scores := scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReq, endpoints) highestScore := -1.0 - var highestPod types.Pod - for pod, score := range scores { + var highestEndpoint scheduling.Endpoint + for endpoint, score := range scores { if score > highestScore { highestScore = score - highestPod = pod + highestEndpoint = endpoint } } - if highestPod.GetPod().NamespacedName.String() != expectedPod.GetPod().NamespacedName.String() { + if highestEndpoint.GetMetadata().NamespacedName.String() != expectedEndpoint.GetMetadata().NamespacedName.String() { t.Fatalf("expected %s to have highest score for LRU behavior, but %s had highest score (%f). All scores: %+v", - expectedPod.GetPod().NamespacedName.String(), - highestPod.GetPod().NamespacedName.String(), + expectedEndpoint.GetMetadata().NamespacedName.String(), + highestEndpoint.GetMetadata().NamespacedName.String(), highestScore, scores) } } t.Run("initial cold request seeds cache", func(_ *testing.T) { - coldReqA := &types.LLMRequest{RequestId: "cold-1"} - scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqA, pods) - scorer.PreRequest(ctx, coldReqA, requestToPod(podA)) - // After podA handles a cold request, other pods should score higher for new cold requests - assertHighestScoredPod(podB, "after-podA-used") + coldReqA := &scheduling.LLMRequest{RequestId: "cold-1"} + scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqA, endpoints) + scorer.PreRequest(ctx, coldReqA, requestToEndpoint(endpointA)) + // After endpointA handles a cold request, other endpoints should score higher for new cold requests + assertHighestScoredPod(endpointB, "after-endpointA-used") }) - t.Run("unused pods rank above existing ones", func(t *testing.T) { - 
coldReqCheck := &types.LLMRequest{RequestId: "cold-check"} - coldScores := scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqCheck, pods) - if coldScores[podB] <= coldScores[podA] { - t.Fatalf("expected pod-b to outrank pod-a after pod-a handled previous cold request, scores=%+v", coldScores) + t.Run("unused endpoints rank above existing ones", func(t *testing.T) { + coldReqCheck := &scheduling.LLMRequest{RequestId: "cold-check"} + coldScores := scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqCheck, endpoints) + if coldScores[endpointB] <= coldScores[endpointA] { + t.Fatalf("expected endpoint-b to outrank endpoint-a after endpoint-a handled previous cold request, scores=%+v", coldScores) } - if coldScores[podB] != 1.0 { - t.Fatalf("expected pod-b to score 1.0, scores=%+v", coldScores) + if coldScores[endpointB] != 1.0 { + t.Fatalf("expected endpoint-b to score 1.0, scores=%+v", coldScores) } - if coldScores[podC] != 0.5 { - t.Fatalf("expected pod-c to score 0.5, scores=%+v", coldScores) + if coldScores[endpointC] != 0.5 { + t.Fatalf("expected endpoint-c to score 0.5, scores=%+v", coldScores) } }) t.Run("warm request leaves LRU untouched", func(t *testing.T) { - warmReq := &types.LLMRequest{RequestId: "warm-1"} + warmReq := &scheduling.LLMRequest{RequestId: "warm-1"} warmState := map[prefix.ServerID]int{ {Name: "server1", Namespace: "default"}: 1, } - warmScores := scorer.Score(ctx, toPrefixState(warmState), warmReq, pods) + warmScores := scorer.Score(ctx, toPrefixState(warmState), warmReq, endpoints) for _, score := range warmScores { if score != 0.5 { t.Fatalf("expected neutral score for warm request, got %f", score) } } - scorer.PreRequest(ctx, warmReq, requestToPod(podB)) - postWarmReq := &types.LLMRequest{RequestId: "cold-after-warm"} - postWarmScores := scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), postWarmReq, pods) - if postWarmScores[podB] <= postWarmScores[podA] { + scorer.PreRequest(ctx, 
warmReq, requestToEndpoint(endpointB)) + postWarmReq := &scheduling.LLMRequest{RequestId: "cold-after-warm"} + postWarmScores := scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), postWarmReq, endpoints) + if postWarmScores[endpointB] <= postWarmScores[endpointA] { t.Fatalf("expected warm request to leave ordering unchanged, scores=%+v", postWarmScores) } }) - t.Run("second cold request rotates to podB", func(_ *testing.T) { - // Simulate podB handling a cold request - coldReqB := &types.LLMRequest{RequestId: "cold-2"} - scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqB, pods) - scorer.PreRequest(ctx, coldReqB, requestToPod(podB)) - // Now podC should score highest since both podA and podB have been used - assertHighestScoredPod(podC, "after-podB-used") + t.Run("second cold request rotates to endpointB", func(_ *testing.T) { + // Simulate endpointB handling a cold request + coldReqB := &scheduling.LLMRequest{RequestId: "cold-2"} + scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqB, endpoints) + scorer.PreRequest(ctx, coldReqB, requestToEndpoint(endpointB)) + // Now endpointC should score highest since both endpointA and endpointB have been used + assertHighestScoredPod(endpointC, "after-endpointB-used") }) - t.Run("third cold request rotates back to podA", func(_ *testing.T) { - // Simulate podC handling a cold request - coldReqC := &types.LLMRequest{RequestId: "cold-3"} - scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqC, pods) - scorer.PreRequest(ctx, coldReqC, requestToPod(podC)) - // Now podA should score highest again (LRU rotation) - assertHighestScoredPod(podA, "after-podC-used") + t.Run("third cold request rotates back to endpointA", func(_ *testing.T) { + // Simulate endpointC handling a cold request + coldReqC := &scheduling.LLMRequest{RequestId: "cold-3"} + scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqC, endpoints) + scorer.PreRequest(ctx, coldReqC, 
requestToEndpoint(endpointC)) + // Now endpointA should score highest again (LRU rotation) + assertHighestScoredPod(endpointA, "after-endpointC-used") }) } @@ -403,85 +410,90 @@ func TestNoHitLRUEdgeCases(t *testing.T) { ctx := utils.NewTestContext(t) scorer := scorer.NewNoHitLRU(ctx, nil) - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, - MetricsState: &backendmetrics.MetricsState{}, - } + endpointA := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, + &fwkdl.Metrics{}, + nil, + ) - t.Run("empty pods list", func(t *testing.T) { - emptyPods := []types.Pod{} - cycleState := &types.CycleState{} - cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType, + t.Run("empty endpoints list", func(t *testing.T) { + emptyEndpoints := []scheduling.Endpoint{} + cycleState := &scheduling.CycleState{} + cycleState.Write(plugin.StateKey(plugin.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}.String()), &prefix.SchedulingContextState{ PrefixCacheServers: make(map[prefix.ServerID]int), // cold request }) - scores := scorer.Score(ctx, cycleState, &types.LLMRequest{}, emptyPods) + scores := scorer.Score(ctx, cycleState, &scheduling.LLMRequest{}, emptyEndpoints) if len(scores) != 0 { - t.Errorf("Expected empty scores for empty pods list, got %d scores", len(scores)) + t.Errorf("Expected empty scores for empty endpoints list, got %d scores", len(scores)) } }) - t.Run("nil pods list", func(t *testing.T) { - cycleState := &types.CycleState{} - cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType, + t.Run("nil endpoints list", func(t *testing.T) { + cycleState := &scheduling.CycleState{} + cycleState.Write(plugin.StateKey(plugin.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}.String()), &prefix.SchedulingContextState{ PrefixCacheServers: 
make(map[prefix.ServerID]int), // cold request }) - scores := scorer.Score(ctx, cycleState, &types.LLMRequest{}, nil) + scores := scorer.Score(ctx, cycleState, &scheduling.LLMRequest{}, nil) if scores == nil { - t.Errorf("Expected non-nil scores map for nil pods list") + t.Errorf("Expected non-nil scores map for nil endpoints list") } if len(scores) != 0 { - t.Errorf("Expected empty scores for nil pods list, got %d scores", len(scores)) + t.Errorf("Expected empty scores for nil endpoints list, got %d scores", len(scores)) } }) - t.Run("single pod returns 1.0", func(t *testing.T) { - pods := []types.Pod{podA} - cycleState := &types.CycleState{} - cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType, + t.Run("single endpoint returns 1.0", func(t *testing.T) { + endpoints := []scheduling.Endpoint{endpointA} + cycleState := &scheduling.CycleState{} + cycleState.Write(plugin.StateKey(plugin.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}.String()), &prefix.SchedulingContextState{ PrefixCacheServers: make(map[prefix.ServerID]int), // cold request }) - scores := scorer.Score(ctx, cycleState, &types.LLMRequest{}, pods) + scores := scorer.Score(ctx, cycleState, &scheduling.LLMRequest{}, endpoints) - if scores[podA] != 1.0 { - t.Errorf("Expected single pod to get score 1.0, got %f", scores[podA]) + if scores[endpointA] != 1.0 { + t.Errorf("Expected single endpoint to get score 1.0, got %f", scores[endpointA]) } }) } func TestNoHitLRUPrefillDecodeTracking(t *testing.T) { - // Prefill worker pods - prefillPodA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "prefill-a", Namespace: "default"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - prefillPodB := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "prefill-b", Namespace: "default"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - - // Decode worker pods - decodePodA 
:= &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "decode-a", Namespace: "default"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - decodePodB := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "decode-b", Namespace: "default"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - - prefillPods := []types.Pod{prefillPodA, prefillPodB} - decodePods := []types.Pod{decodePodA, decodePodB} - - coldPrefixState := &types.CycleState{} - coldPrefixState.Write(plugins.StateKey(prefix.PrefixCachePluginType), &prefix.SchedulingContextState{ + // Prefill worker endpoints + prefillEndpointA := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "prefill-a", Namespace: "default"}}, + &fwkdl.Metrics{}, + nil, + ) + prefillEndpointB := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "prefill-b", Namespace: "default"}}, + &fwkdl.Metrics{}, + nil, + ) + + // Decode worker endpoints + decodeEndpointA := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "decode-a", Namespace: "default"}}, + &fwkdl.Metrics{}, + nil, + ) + decodeEndpointB := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "decode-b", Namespace: "default"}}, + &fwkdl.Metrics{}, + nil, + ) + + prefillEndpoints := []scheduling.Endpoint{prefillEndpointA, prefillEndpointB} + decodeEndpoints := []scheduling.Endpoint{decodeEndpointA, decodeEndpointB} + + coldPrefixState := &scheduling.CycleState{} + coldPrefixState.Write(plugin.StateKey(prefix.PrefixCachePluginType), &prefix.SchedulingContextState{ PrefixCacheServers: make(map[prefix.ServerID]int), // empty = cold request }) @@ -491,18 +503,18 @@ func TestNoHitLRUPrefillDecodeTracking(t *testing.T) { scorer := scorer.NewNoHitLRU(ctx, nil) // First cold request with P/D - req1 := &types.LLMRequest{RequestId: 
"pd-request-1"} - scorer.Score(ctx, coldPrefixState, req1, append(prefillPods, decodePods...)) + req1 := &scheduling.LLMRequest{RequestId: "pd-request-1"} + scorer.Score(ctx, coldPrefixState, req1, append(prefillEndpoints, decodeEndpoints...)) // Simulate scheduling result with both prefill and decode profiles - pdResult := &types.SchedulingResult{ + pdResult := &scheduling.SchedulingResult{ PrimaryProfileName: "decode", - ProfileResults: map[string]*types.ProfileRunResult{ + ProfileResults: map[string]*scheduling.ProfileRunResult{ "prefill": { - TargetPods: []types.Pod{prefillPodA}, + TargetEndpoints: []scheduling.Endpoint{prefillEndpointA}, }, "decode": { - TargetPods: []types.Pod{decodePodA}, + TargetEndpoints: []scheduling.Endpoint{decodeEndpointA}, }, }, } @@ -510,30 +522,30 @@ func TestNoHitLRUPrefillDecodeTracking(t *testing.T) { // Second cold request - both prefillPodB and decodePodB should score higher // since prefillPodA and decodePodA were just used - req2 := &types.LLMRequest{RequestId: "pd-request-2"} - prefillScores := scorer.Score(ctx, coldPrefixState, req2, prefillPods) - decodeScores := scorer.Score(ctx, coldPrefixState, req2, decodePods) + req2 := &scheduling.LLMRequest{RequestId: "pd-request-2"} + prefillScores := scorer.Score(ctx, coldPrefixState, req2, prefillEndpoints) + decodeScores := scorer.Score(ctx, coldPrefixState, req2, decodeEndpoints) - if prefillScores[prefillPodB] <= prefillScores[prefillPodA] { + if prefillScores[prefillEndpointB] <= prefillScores[prefillEndpointA] { t.Errorf("Expected prefill-b to score higher than prefill-a after prefill-a was used: %+v", prefillScores) } - if decodeScores[decodePodB] <= decodeScores[decodePodA] { + if decodeScores[decodeEndpointB] <= decodeScores[decodeEndpointA] { t.Errorf("Expected decode-b to score higher than decode-a after decode-a was used: %+v", decodeScores) } }) t.Run("non-P/D scenario - only primary profile exists", func(t *testing.T) { - req := &types.LLMRequest{RequestId: 
"non-pd-request"} + req := &scheduling.LLMRequest{RequestId: "non-pd-request"} scorer := scorer.NewNoHitLRU(ctx, nil) - scorer.Score(ctx, coldPrefixState, req, decodePods) + scorer.Score(ctx, coldPrefixState, req, decodeEndpoints) // Scheduling result with only decode profile (no prefill) - result := &types.SchedulingResult{ + result := &scheduling.SchedulingResult{ PrimaryProfileName: "decode", - ProfileResults: map[string]*types.ProfileRunResult{ + ProfileResults: map[string]*scheduling.ProfileRunResult{ "decode": { - TargetPods: []types.Pod{decodePodA}, + TargetEndpoints: []scheduling.Endpoint{decodeEndpointA}, }, // No "prefill" profile in results }, @@ -542,31 +554,31 @@ func TestNoHitLRUPrefillDecodeTracking(t *testing.T) { scorer.PreRequest(ctx, req, result) // Verify decodePodA was tracked - req2 := &types.LLMRequest{RequestId: "non-pd-request-2"} - scores := scorer.Score(ctx, coldPrefixState, req2, decodePods) + req2 := &scheduling.LLMRequest{RequestId: "non-pd-request-2"} + scores := scorer.Score(ctx, coldPrefixState, req2, decodeEndpoints) - if scores[decodePodB] <= scores[decodePodA] { + if scores[decodeEndpointB] <= scores[decodeEndpointA] { t.Errorf("Expected decode-b to score higher than decode-a: %+v", scores) } }) t.Run("nil scheduling result - graceful handling", func(_ *testing.T) { - req := &types.LLMRequest{RequestId: "nil-result"} + req := &scheduling.LLMRequest{RequestId: "nil-result"} scorer := scorer.NewNoHitLRU(ctx, nil) - scorer.Score(ctx, coldPrefixState, req, decodePods) + scorer.Score(ctx, coldPrefixState, req, decodeEndpoints) // Should not panic with nil result scorer.PreRequest(ctx, req, nil) }) t.Run("empty profile results - graceful handling", func(_ *testing.T) { - req := &types.LLMRequest{RequestId: "empty-results"} + req := &scheduling.LLMRequest{RequestId: "empty-results"} scorer := scorer.NewNoHitLRU(ctx, nil) - scorer.Score(ctx, coldPrefixState, req, decodePods) + scorer.Score(ctx, coldPrefixState, req, decodeEndpoints) - 
result := &types.SchedulingResult{ + result := &scheduling.SchedulingResult{ PrimaryProfileName: "decode", - ProfileResults: map[string]*types.ProfileRunResult{}, + ProfileResults: map[string]*scheduling.ProfileRunResult{}, } // Should not panic with empty profile results scorer.PreRequest(ctx, req, result) diff --git a/pkg/plugins/scorer/precise_prefix_cache.go b/pkg/plugins/scorer/precise_prefix_cache.go index 39a2ac844..f839625d8 100644 --- a/pkg/plugins/scorer/precise_prefix_cache.go +++ b/pkg/plugins/scorer/precise_prefix_cache.go @@ -14,11 +14,10 @@ import ( "github.com/llm-d/llm-d-kv-cache/pkg/kvevents" preprocessing "github.com/llm-d/llm-d-kv-cache/pkg/preprocessing/chat_completions" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/prefix" ) const ( @@ -33,7 +32,7 @@ type PrecisePrefixCachePluginConfig struct { // used to process tokens into KV-block keys. TokenProcessorConfig *kvblock.TokenProcessorConfig `json:"tokenProcessorConfig"` // IndexerConfig holds the configuration for the `kvcache.Indexer` which is - // used to score pods based on the KV-cache index state. + // used to score endpoints based on the KV-cache index state. 
IndexerConfig *kvcache.Config `json:"indexerConfig"` // KVEventsConfig holds the configuration for the `kvevents.Pool` which is // used to subscribe to KV-cache events and update the internal KV-cache @@ -42,12 +41,12 @@ type PrecisePrefixCachePluginConfig struct { } // compile-time type assertion -var _ framework.Scorer = &PrecisePrefixCacheScorer{} +var _ scheduling.Scorer = &PrecisePrefixCacheScorer{} // PrecisePrefixCachePluginFactory defines the factory function for creating // a new instance of the PrefixCacheTrackingPlugin. func PrecisePrefixCachePluginFactory(name string, rawParameters json.RawMessage, - handle plugins.Handle) (plugins.Plugin, error) { + handle plugin.Handle) (plugin.Plugin, error) { indexerConfig, err := kvcache.NewDefaultConfig() if err != nil { return nil, fmt.Errorf("failed to initialize indexer config: %w", err) @@ -91,7 +90,7 @@ func PrecisePrefixCachePluginFactory(name string, rawParameters json.RawMessage, // based on the provided configuration. The `kvevents.Pool` is started // in a goroutine to listen for KV-cache events and update the internal // KV-cache index state. The `kvcache.Indexer` is also started in a goroutine -// to score pods based on the KV-cache index state. +// to score endpoints based on the KV-cache index state. // // If the configuration is invalid or if the indexer fails to initialize, // an error is returned. 
@@ -117,7 +116,7 @@ func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePr subscribersManager := kvevents.NewSubscriberManager(pool) var subscribersCache *ttlcache.Cache[string, struct{}] - // initialize the subscribers cache only if pod discovery is enabled + // initialize the subscribers cache only if endpoint discovery is enabled if config.KVEventsConfig.DiscoverPods { // initialize the subscribers TTL cache subscriptionTimeout := 10 * time.Minute @@ -142,7 +141,7 @@ func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePr } return &PrecisePrefixCacheScorer{ - typedName: plugins.TypedName{Type: PrecisePrefixCachePluginType}, + typedName: plugin.TypedName{Type: PrecisePrefixCachePluginType}, kvCacheIndexer: kvCacheIndexer, subscribersCache: subscribersCache, subscribersManager: subscribersManager, @@ -152,16 +151,16 @@ func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePr } // PrecisePrefixCacheScorer implements the framework.Scorer interface. // The scorer implements precise prefix-cache KV-block locality scoring. -// It uses the `kvcache.Indexer` to score pods based on the KV-cache index +// It uses the `kvcache.Indexer` to score endpoints based on the KV-cache index // state, and the `kvevents.Pool` to subscribe to KV-cache events // to keep the internal KV-cache index state up-to-date. type PrecisePrefixCacheScorer struct { - typedName plugins.TypedName + typedName plugin.TypedName kvCacheIndexer *kvcache.Indexer // until the IGW data-layer is ready to provide endpoint events, - // we maintain a TTL cache of known pods that are discovered through - // the scoring process. If a pod is not in the received endpoints list + // we maintain a TTL cache of known endpoints that are discovered through + // the scoring process. If an endpoint is not in the received endpoints list // during scoring for a certain period, we consider it gone and // stop its KV events subscription.
subscribersCache *ttlcache.Cache[string, struct{}] @@ -170,7 +169,7 @@ type PrecisePrefixCacheScorer struct { } // TypedName returns the typed name of the plugin. -func (s *PrecisePrefixCacheScorer) TypedName() plugins.TypedName { +func (s *PrecisePrefixCacheScorer) TypedName() plugin.TypedName { return s.typedName } @@ -180,27 +179,32 @@ func (s *PrecisePrefixCacheScorer) WithName(name string) *PrecisePrefixCacheScor return s } -// Score scores the provided pod based on the KVCache index state. +// Category returns the preference the scorer applies when scoring candidate endpoints. +func (s *PrecisePrefixCacheScorer) Category() scheduling.ScorerCategory { + return scheduling.Affinity +} + +// Score scores the provided endpoint based on the KVCache index state. // The returned scores are normalized to a range of 0-1. -func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { +func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *scheduling.CycleState, request *scheduling.LLMRequest, endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 { logger := log.FromContext(ctx).WithName(s.typedName.String()) debugLogger := logger.V(logutil.DEBUG) if s.kvEventsConfig.DiscoverPods { // update subscribers here temporarily - for _, pod := range pods { - podObj := pod.GetPod() - if podObj == nil { + for _, endpoint := range endpoints { + endpointObj := endpoint.GetMetadata() + if endpointObj == nil { continue } - podKey := podObj.NamespacedName.String() - s.subscribersCache.Set(podKey, struct{}{}, 0) // use default TTL + endpointKey := endpointObj.NamespacedName.String() + s.subscribersCache.Set(endpointKey, struct{}{}, 0) // use default TTL - if err := s.subscribersManager.EnsureSubscriber(context.Background(), podKey, // dont use request ctx - fmt.Sprintf("tcp://%s:%d", podObj.Address, s.kvEventsConfig.PodDiscoveryConfig.SocketPort), + if err := 
s.subscribersManager.EnsureSubscriber(context.Background(), endpointKey, // don't use request ctx + fmt.Sprintf("tcp://%s:%d", endpointObj.Address, s.kvEventsConfig.PodDiscoveryConfig.SocketPort), s.kvEventsConfig.TopicFilter, true); err != nil { - logger.Error(err, "Failed to ensure KV-events subscriber for pod", "pod", podKey, - "endpoint", podObj.Address) + logger.Error(err, "Failed to ensure KV-events subscriber for endpoint", "endpoint", endpointKey, + "address", endpointObj.Address) continue } } @@ -213,41 +217,41 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *types. scores, err := s.getScores(ctx, request) if err != nil { - logger.Error(err, "Failed to get pod scores") + logger.Error(err, "Failed to get endpoint scores") return nil } - debugLogger.Info("Got pod scores", "scores", scores) + debugLogger.Info("Got endpoint scores", "scores", scores) - podToKey := func(pod types.Pod) (string, bool) { - metricsPod := pod.GetPod() - if metricsPod == nil { + endpointToKey := func(endpoint scheduling.Endpoint) (string, bool) { + metadata := endpoint.GetMetadata() + if metadata == nil { return "", false } - return metricsPod.Address, true + return metadata.Address, true } state := &prefix.SchedulingContextState{ PrefixHashes: []prefix.BlockHash{}, PrefixCacheServers: map[prefix.ServerID]int{}, } - for _, pod := range pods { - key, ok := podToKey(pod) + for _, endpoint := range endpoints { + key, ok := endpointToKey(endpoint) if !ok { continue } - state.PrefixCacheServers[prefix.ServerID(pod.GetPod().NamespacedName)] = int(scores[key]) + state.PrefixCacheServers[prefix.ServerID(endpoint.GetMetadata().NamespacedName)] = int(scores[key]) } - cycleState.Write(plugins.StateKey(s.typedName.String()), state) + cycleState.Write(plugin.StateKey(s.typedName.String()), state) - return indexedScoresToNormalizedScoredPods(pods, podToKey, scores) + return indexedScoresToNormalizedScoredPods(endpoints, endpointToKey, scores) } -// getScores retrieves
the pod scores from the KV-cache indexer +// getScores retrieves the endpoint scores from the KV-cache indexer // based on the provided LLM request. // If the request contains chat completions, it processes them accordingly. // If the request contains regular completions, it uses the prompt directly. -func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *types.LLMRequest) (map[string]float64, error) { +func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *scheduling.LLMRequest) (map[string]float64, error) { logger := log.FromContext(ctx).WithName(s.typedName.String()) traceLogger := logger.V(logutil.TRACE) @@ -289,7 +293,7 @@ func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *types scores, err := s.kvCacheIndexer.GetPodScores(ctx, renderReq, "", request.TargetModel, nil) if err != nil { - return nil, fmt.Errorf("failed to get pod scores for chat/completions: %w", err) + return nil, fmt.Errorf("failed to get endpoint scores for chat/completions: %w", err) } return scores, nil } @@ -301,7 +305,7 @@ func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *types scores, err := s.kvCacheIndexer.GetPodScores(ctx, nil, prompt, request.TargetModel, nil) if err != nil { - return nil, fmt.Errorf("failed to get pod scores for completions: %w", err) + return nil, fmt.Errorf("failed to get endpoint scores for completions: %w", err) } return scores, nil } diff --git a/pkg/plugins/scorer/precise_prefix_cache_test.go b/pkg/plugins/scorer/precise_prefix_cache_test.go index 1a8bf9eec..968497798 100644 --- a/pkg/plugins/scorer/precise_prefix_cache_test.go +++ b/pkg/plugins/scorer/precise_prefix_cache_test.go @@ -13,9 +13,8 @@ import ( "github.com/llm-d/llm-d-kv-cache/pkg/tokenization" "github.com/stretchr/testify/require" k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - backendmetrics 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) @@ -37,34 +36,38 @@ func TestPrefixCacheTracking_Score(t *testing.T) { testcases := []struct { name string - pods []types.Pod - request *types.LLMRequest - kvBlockData func(req *types.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry + endpoints []scheduling.Endpoint + request *scheduling.LLMRequest + kvBlockData func(req *scheduling.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry wantScoresByAddress map[string]float64 }{ { name: "nil request", - pods: []types.Pod{ - &types.PodMetrics{ - Pod: &backend.Pod{ + endpoints: []scheduling.Endpoint{ + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}, Address: "10.0.0.1:8080", }, - }, + nil, + nil, + ), }, wantScoresByAddress: map[string]float64{}, // empty map }, { name: "empty request body", - pods: []types.Pod{ - &types.PodMetrics{ - Pod: &backend.Pod{ + endpoints: []scheduling.Endpoint{ + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}, Address: "10.0.0.1:8080", }, - }, + nil, + nil, + ), }, - request: &types.LLMRequest{ + request: &scheduling.LLMRequest{ RequestId: "test-request", TargetModel: "test-model", Body: nil, @@ -73,45 +76,48 @@ func TestPrefixCacheTracking_Score(t *testing.T) { }, { name: "longest prefix scorer (default scorer)", - pods: []types.Pod{ - &types.PodMetrics{ - Pod: &backend.Pod{ + endpoints: []scheduling.Endpoint{ + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}, Address: "10.0.0.1:8080", }, - MetricsState: 
&backendmetrics.MetricsState{ + &fwkdl.Metrics{ WaitingQueueSize: 0, }, - }, - &types.PodMetrics{ - Pod: &backend.Pod{ + nil, + ), + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}, Address: "10.0.0.2:8080", }, - MetricsState: &backendmetrics.MetricsState{ + &fwkdl.Metrics{ WaitingQueueSize: 1, }, - }, - &types.PodMetrics{ - Pod: &backend.Pod{ + nil, + ), + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-c"}, Address: "10.0.0.3:8080", }, - MetricsState: &backendmetrics.MetricsState{ + &fwkdl.Metrics{ WaitingQueueSize: 2, }, - }, + nil, + ), }, - request: &types.LLMRequest{ + request: &scheduling.LLMRequest{ RequestId: "test-request", TargetModel: "test-model", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &scheduling.LLMRequestBody{ + Completions: &scheduling.CompletionsRequest{ Prompt: prompt, }, }, }, - kvBlockData: func(req *types.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry { + kvBlockData: func(req *scheduling.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry { require.NotNil(t, req.Completions, "req expected to use Completions API") prompt := req.Completions.Prompt @@ -129,7 +135,7 @@ func TestPrefixCacheTracking_Score(t *testing.T) { require.GreaterOrEqual(t, len(chunkKeys), 3, "Need at least 3 chunks for test") // populate kvblock.Index to test longest prefix matching: - // - chunk0 (first chunk): all pods have it (common prefix start) + // - chunk0 (first chunk): all endpoints have it (common prefix start) // - chunk1: pod-a and pod-b have it (pod-c drops off after chunk0) // - chunk2: only pod-a has it (pod-b drops off after chunk1) // LongestPrefixScorer uses intersection, so: @@ -161,51 +167,53 @@ func TestPrefixCacheTracking_Score(t *testing.T) { }, { name: "chat completions request", - pods: []types.Pod{ - &types.PodMetrics{ - Pod: &backend.Pod{ + endpoints: 
[]scheduling.Endpoint{ + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}, Address: "10.0.0.1:8080", }, - MetricsState: &backendmetrics.MetricsState{ + &fwkdl.Metrics{ WaitingQueueSize: 0, }, - }, - &types.PodMetrics{ - Pod: &backend.Pod{ + nil, + ), + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}, Address: "10.0.0.2:8080", }, - MetricsState: &backendmetrics.MetricsState{ + &fwkdl.Metrics{ WaitingQueueSize: 1, }, - }, + nil, + ), }, - request: &types.LLMRequest{ + request: &scheduling.LLMRequest{ RequestId: "test-request", TargetModel: "test-model", - Body: &types.LLMRequestBody{ - ChatCompletions: &types.ChatCompletionsRequest{ + Body: &scheduling.LLMRequestBody{ + ChatCompletions: &scheduling.ChatCompletionsRequest{ ChatTemplate: `{% for message in messages %}{{ message.role }}: {{ message.content }} {% endfor %}`, - Messages: []types.Message{ + Messages: []scheduling.Message{ { Role: "user", - Content: types.Content{Raw: "Hello, how are you?"}, + Content: scheduling.Content{Raw: "Hello, how are you?"}, }, { Role: "assistant", - Content: types.Content{Raw: "I'm doing well, thank you for asking!"}, + Content: scheduling.Content{Raw: "I'm doing well, thank you for asking!"}, }, { Role: "user", - Content: types.Content{Raw: "Can you help me with a question about prefix caching in LLM inference?"}, + Content: scheduling.Content{Raw: "Can you help me with a question about prefix caching in LLM inference?"}, }, }, }, }, }, - kvBlockData: func(req *types.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry { + kvBlockData: func(req *scheduling.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry { require.NotNil(t, req.ChatCompletions, "req expected to use ChatCompletions API") // convert to preprocessing format @@ -263,45 +271,48 @@ func TestPrefixCacheTracking_Score(t *testing.T) { }, { name: "partial prefix", - pods: 
[]types.Pod{ - &types.PodMetrics{ - Pod: &backend.Pod{ + endpoints: []scheduling.Endpoint{ + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}, Address: "10.0.0.1:8080", }, - MetricsState: &backendmetrics.MetricsState{ + &fwkdl.Metrics{ WaitingQueueSize: 0, }, - }, - &types.PodMetrics{ - Pod: &backend.Pod{ + nil, + ), + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}, Address: "10.0.0.2:8080", }, - MetricsState: &backendmetrics.MetricsState{ + &fwkdl.Metrics{ WaitingQueueSize: 1, }, - }, - &types.PodMetrics{ - Pod: &backend.Pod{ + nil, + ), + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-c"}, Address: "10.0.0.3:8080", }, - MetricsState: &backendmetrics.MetricsState{ + &fwkdl.Metrics{ WaitingQueueSize: 2, }, - }, + nil, + ), }, - request: &types.LLMRequest{ + request: &scheduling.LLMRequest{ RequestId: "test-request", TargetModel: "test-model", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &scheduling.LLMRequestBody{ + Completions: &scheduling.CompletionsRequest{ Prompt: prompt, }, }, }, - kvBlockData: func(req *types.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry { + kvBlockData: func(req *scheduling.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry { require.NotNil(t, req.Completions, "req expected to use Completions API") testTokenizer, err := tokenization.NewCachedLocalTokenizer(t.Context(), model, localTokenizerConfig) @@ -316,7 +327,7 @@ func TestPrefixCacheTracking_Score(t *testing.T) { require.GreaterOrEqual(t, len(chunkKeys), 3, "Need at least 3 chunks for test") // Test partial prefix cache scenario: - // - chunk0: all pods (common prefix start) + // - chunk0: all endpoints (common prefix start) // - chunk1: only pod-a (creates a gap for pod-b and pod-c) // - chunk2: pod-a and pod-b (pod-b has this but missing 
chunk1) // @@ -349,28 +360,29 @@ func TestPrefixCacheTracking_Score(t *testing.T) { }, }, { - name: "single pod", - pods: []types.Pod{ - &types.PodMetrics{ - Pod: &backend.Pod{ + name: "single endpoint", + endpoints: []scheduling.Endpoint{ + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}, Address: "10.0.0.1:8080", }, - MetricsState: &backendmetrics.MetricsState{ + &fwkdl.Metrics{ WaitingQueueSize: 0, }, - }, + nil, + ), }, - request: &types.LLMRequest{ + request: &scheduling.LLMRequest{ RequestId: "test-request", TargetModel: "test-model", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &scheduling.LLMRequestBody{ + Completions: &scheduling.CompletionsRequest{ Prompt: prompt, }, }, }, - kvBlockData: func(req *types.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry { + kvBlockData: func(req *scheduling.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry { require.NotNil(t, req.Completions, "req expected to use Completions API") testTokenizer, err := tokenization.NewCachedLocalTokenizer(t.Context(), model, localTokenizerConfig) @@ -384,7 +396,7 @@ func TestPrefixCacheTracking_Score(t *testing.T) { require.GreaterOrEqual(t, len(chunkKeys), 2, "Need at least 2 chunks for test") - // Single pod has 2 chunks cached + // Single endpoint has 2 chunks cached return map[kvblock.BlockHash][]kvblock.PodEntry{ chunkKeys[0]: { {PodIdentifier: "10.0.0.1:8080"}, @@ -395,81 +407,93 @@ func TestPrefixCacheTracking_Score(t *testing.T) { } }, wantScoresByAddress: map[string]float64{ - // with only one pod, minScore == maxScore, so normalization returns 1.0 + // with only one endpoint, minScore == maxScore, so normalization returns 1.0 "10.0.0.1:8080": 1.0, }, }, { name: "no cache hits (empty index)", - pods: []types.Pod{ - &types.PodMetrics{ - Pod: &backend.Pod{ + endpoints: []scheduling.Endpoint{ + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ 
NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}, Address: "10.0.0.1:8080", }, - }, - &types.PodMetrics{ - Pod: &backend.Pod{ + nil, + nil, + ), + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}, Address: "10.0.0.2:8080", }, - }, - &types.PodMetrics{ - Pod: &backend.Pod{ + nil, + nil, + ), + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-c"}, Address: "10.0.0.3:8080", }, - }, + nil, + nil, + ), }, - request: &types.LLMRequest{ + request: &scheduling.LLMRequest{ RequestId: "test-request", TargetModel: "test-model", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ - Prompt: "This prompt has never been cached before on any pod.", + Body: &scheduling.LLMRequestBody{ + Completions: &scheduling.CompletionsRequest{ + Prompt: "This prompt has never been cached before on any endpoint.", }, }, }, kvBlockData: nil, // no cached data wantScoresByAddress: map[string]float64{ - // when no pods have any cache hits, all should get equal scores (0.0) + // when no endpoints have any cache hits, all should get equal scores (0.0) "10.0.0.1:8080": 0.0, "10.0.0.2:8080": 0.0, "10.0.0.3:8080": 0.0, }, }, { - name: "all pods have equal prefix length", - pods: []types.Pod{ - &types.PodMetrics{ - Pod: &backend.Pod{ + name: "all endpoints have equal prefix length", + endpoints: []scheduling.Endpoint{ + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}, Address: "10.0.0.1:8080", }, - }, - &types.PodMetrics{ - Pod: &backend.Pod{ + nil, + nil, + ), + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}, Address: "10.0.0.2:8080", }, - }, - &types.PodMetrics{ - Pod: &backend.Pod{ + nil, + nil, + ), + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-c"}, Address: "10.0.0.3:8080", }, - }, + nil, + 
nil, + ), }, - request: &types.LLMRequest{ + request: &scheduling.LLMRequest{ RequestId: "test-request", TargetModel: "test-model", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &scheduling.LLMRequestBody{ + Completions: &scheduling.CompletionsRequest{ Prompt: prompt, }, }, }, - kvBlockData: func(req *types.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry { + kvBlockData: func(req *scheduling.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry { require.NotNil(t, req.Completions, "req expected to use Completions API") testTokenizer, err := tokenization.NewCachedLocalTokenizer(t.Context(), model, localTokenizerConfig) @@ -483,7 +507,7 @@ func TestPrefixCacheTracking_Score(t *testing.T) { require.GreaterOrEqual(t, len(chunkKeys), 2, "Need at least 2 chunks for test") - // all pods have the same 2 chunks cached + // all endpoints have the same 2 chunks cached return map[kvblock.BlockHash][]kvblock.PodEntry{ chunkKeys[0]: { {PodIdentifier: "10.0.0.1:8080"}, @@ -498,8 +522,8 @@ func TestPrefixCacheTracking_Score(t *testing.T) { } }, wantScoresByAddress: map[string]float64{ - // when all pods have equal cache (minScore == maxScore), the implementation - // returns 1.0 for all pods to avoid division by zero + // when all endpoints have equal cache (minScore == maxScore), the implementation + // returns 1.0 for all endpoints to avoid division by zero "10.0.0.1:8080": 1.0, "10.0.0.2:8080": 1.0, "10.0.0.3:8080": 1.0, @@ -537,12 +561,12 @@ func TestPrefixCacheTracking_Score(t *testing.T) { } } - got := prefixCacheScorer.Score(ctx, types.NewCycleState(), tt.request, tt.pods) + got := prefixCacheScorer.Score(ctx, scheduling.NewCycleState(), tt.request, tt.endpoints) gotByAddress := make(map[string]float64) - for pod, score := range got { - if podMetrics, ok := pod.(*types.PodMetrics); ok && podMetrics.GetPod() != nil { - gotByAddress[podMetrics.GetPod().Address] = score + for endpoint, score := range 
got { + if endpoint.GetMetadata() != nil { + gotByAddress[endpoint.GetMetadata().Address] = score } } diff --git a/pkg/plugins/scorer/session_affinity.go b/pkg/plugins/scorer/session_affinity.go index 3ac9230c6..87e9d2be9 100644 --- a/pkg/plugins/scorer/session_affinity.go +++ b/pkg/plugins/scorer/session_affinity.go @@ -6,12 +6,11 @@ import ( "encoding/json" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" ) const ( @@ -22,18 +21,18 @@ const ( ) // compile-time type assertion -var _ framework.Scorer = &SessionAffinity{} +var _ scheduling.Scorer = &SessionAffinity{} var _ requestcontrol.ResponseComplete = &SessionAffinity{} // SessionAffinityFactory defines the factory function for SessionAffinity scorer. 
-func SessionAffinityFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { +func SessionAffinityFactory(name string, _ json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) { return NewSessionAffinity().WithName(name), nil } // NewSessionAffinity returns a scorer func NewSessionAffinity() *SessionAffinity { return &SessionAffinity{ - typedName: plugins.TypedName{Type: SessionAffinityType}, + typedName: plugin.TypedName{Type: SessionAffinityType}, } } @@ -42,11 +41,11 @@ func NewSessionAffinity() *SessionAffinity { // session was sent to, by giving that pod the specified weight and assigning // zero score to the rest of the targets type SessionAffinity struct { - typedName plugins.TypedName + typedName plugin.TypedName } // TypedName returns the typed name of the plugin. -func (s *SessionAffinity) TypedName() plugins.TypedName { +func (s *SessionAffinity) TypedName() plugin.TypedName { return s.typedName } @@ -56,9 +55,14 @@ func (s *SessionAffinity) WithName(name string) *SessionAffinity { return s } +// Category returns the preference the scorer applies when scoring candidate endpoints. 
+func (s *SessionAffinity) Category() scheduling.ScorerCategory { + return scheduling.Affinity +} + // Score assign a high score to the pod used in previous requests and zero to others -func (s *SessionAffinity) Score(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { - scoredPods := make(map[types.Pod]float64) +func (s *SessionAffinity) Score(ctx context.Context, _ *scheduling.CycleState, request *scheduling.LLMRequest, endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 { + scoredEndpoints := make(map[scheduling.Endpoint]float64) sessionToken := request.Headers[sessionTokenHeader] podName := "" @@ -70,21 +74,21 @@ func (s *SessionAffinity) Score(ctx context.Context, _ *types.CycleState, reques podName = string(decodedBytes) } } - for _, pod := range pods { - scoredPods[pod] = 0.0 // initial value - if pod.GetPod().NamespacedName.String() == podName { - scoredPods[pod] = 1.0 + for _, endpoint := range endpoints { + scoredEndpoints[endpoint] = 0.0 // initial value + if endpoint.GetMetadata().NamespacedName.String() == podName { + scoredEndpoints[endpoint] = 1.0 } } - return scoredPods + return scoredEndpoints } // ResponseComplete sets the session header on the response sent to the client // TODO: this should be using a cookie and ensure not overriding any other // cookie values if present. 
// Tracked in https://github.com/llm-d/llm-d-inference-scheduler/issues/28 -func (s *SessionAffinity) ResponseComplete(ctx context.Context, _ *types.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) { +func (s *SessionAffinity) ResponseComplete(ctx context.Context, _ *scheduling.LLMRequest, response *requestcontrol.Response, targetPod *datalayer.EndpointMetadata) { if response == nil || targetPod == nil { reqID := "undefined" if response != nil { diff --git a/pkg/plugins/scorer/session_affinity_test.go b/pkg/plugins/scorer/session_affinity_test.go index 943b06eb4..d7acf3468 100644 --- a/pkg/plugins/scorer/session_affinity_test.go +++ b/pkg/plugins/scorer/session_affinity_test.go @@ -8,79 +8,80 @@ import ( "github.com/google/go-cmp/cmp" k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer" "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) func TestSessionAffinity_Score(t *testing.T) { - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - podB := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - - inputPods := []types.Pod{podA, podB} - - // valid session token for podB - validSessionTokenForPodB := 
base64.StdEncoding.EncodeToString([]byte(podB.GetPod().NamespacedName.String())) + endpointA := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, + &fwkdl.Metrics{}, + nil, + ) + endpointB := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}}, + &fwkdl.Metrics{}, + nil, + ) + + inputEndpoints := []scheduling.Endpoint{endpointA, endpointB} + + // valid session token for endpointB + validSessionTokenForEndpointB := base64.StdEncoding.EncodeToString([]byte(endpointB.GetMetadata().NamespacedName.String())) sessionAffinityScorer := scorer.NewSessionAffinity() tests := []struct { name string - req *types.LLMRequest - input []types.Pod - wantScores map[types.Pod]float64 + req *scheduling.LLMRequest + input []scheduling.Endpoint + wantScores map[scheduling.Endpoint]float64 }{ { - name: "selects correct pod : podB", - req: &types.LLMRequest{ - Headers: map[string]string{"x-session-token": validSessionTokenForPodB}, + name: "selects correct endpoint : endpointB", + req: &scheduling.LLMRequest{ + Headers: map[string]string{"x-session-token": validSessionTokenForEndpointB}, }, - input: inputPods, - wantScores: map[types.Pod]float64{ - podA: 0.0, - podB: 1.0, + input: inputEndpoints, + wantScores: map[scheduling.Endpoint]float64{ + endpointA: 0.0, + endpointB: 1.0, }, }, { name: "no session token", - req: &types.LLMRequest{ + req: &scheduling.LLMRequest{ Headers: map[string]string{}, }, - // both pods get score 0.0 - input: inputPods, - wantScores: map[types.Pod]float64{ - podA: 0.0, - podB: 0.0, + // both endpoints get score 0.0 + input: inputEndpoints, + wantScores: map[scheduling.Endpoint]float64{ + endpointA: 0.0, + endpointB: 0.0, }, }, { name: "invalid session token", - req: &types.LLMRequest{ + req: &scheduling.LLMRequest{ Headers: map[string]string{"x-session-token": "garbage-token"}, }, // expect same behavior as no session token - input: inputPods, - 
wantScores: map[types.Pod]float64{ - podA: 0.0, - podB: 0.0, + input: inputEndpoints, + wantScores: map[scheduling.Endpoint]float64{ + endpointA: 0.0, + endpointB: 0.0, }, }, { - name: "no pods available", - req: &types.LLMRequest{}, - input: []types.Pod{}, + name: "no endpoints available", + req: &scheduling.LLMRequest{}, + input: []scheduling.Endpoint{}, // returns empty score map - wantScores: map[types.Pod]float64{}, + wantScores: map[scheduling.Endpoint]float64{}, }, } @@ -97,30 +98,30 @@ func TestSessionAffinity_Score(t *testing.T) { func TestSessionAffinity_ResponseComplete(t *testing.T) { - targetPod := &backend.Pod{ + targetEndpoint := &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod1"}, Address: "1.2.3.4", } // expected token to be set in response header - wantToken := base64.StdEncoding.EncodeToString([]byte(targetPod.NamespacedName.String())) + wantToken := base64.StdEncoding.EncodeToString([]byte(targetEndpoint.NamespacedName.String())) tests := []struct { name string initialResponse *requestcontrol.Response - targetPod *backend.Pod + targetPod *fwkdl.EndpointMetadata wantHeaders map[string]string }{ { name: "standard case with existing headers map", initialResponse: &requestcontrol.Response{RequestId: "req-1", Headers: make(map[string]string)}, - targetPod: targetPod, + targetPod: targetEndpoint, wantHeaders: map[string]string{"x-session-token": wantToken}, }, { name: "response with nil headers map", initialResponse: &requestcontrol.Response{RequestId: "req-2", Headers: nil}, - targetPod: targetPod, + targetPod: targetEndpoint, wantHeaders: map[string]string{"x-session-token": wantToken}, }, { diff --git a/pkg/plugins/scorer/utils.go b/pkg/plugins/scorer/utils.go index 31a721b71..4d4b3c741 100644 --- a/pkg/plugins/scorer/utils.go +++ b/pkg/plugins/scorer/utils.go @@ -3,41 +3,41 @@ package scorer import ( "math" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" ) -// podToKey is a function type that converts a Pod to a string key. +// endpointToKey is a function type that converts a Pod to a string key. // It returns the key and a boolean indicating success. -type podToKeyFunc func(pod types.Pod) (string, bool) +type endpointToKeyFunc func(endpoint scheduling.Endpoint) (string, bool) // indexedScoresToNormalizedScoredPods converts a map of pod scores to a map of // normalized scores. The function takes a list of pods, a function to convert // a pod to a key, and a map of scores indexed by those keys. It returns a map // of pods to their normalized scores. -func indexedScoresToNormalizedScoredPods(pods []types.Pod, podToKey podToKeyFunc, - scores map[string]float64) map[types.Pod]float64 { - scoredPods := make(map[types.Pod]float64) +func indexedScoresToNormalizedScoredPods(endpoints []scheduling.Endpoint, endpointToKey endpointToKeyFunc, + scores map[string]float64) map[scheduling.Endpoint]float64 { + scoredEndpoints := make(map[scheduling.Endpoint]float64) minScore, maxScore := getMinMax(scores) - for _, pod := range pods { - key, ok := podToKey(pod) + for _, endpoint := range endpoints { + key, ok := endpointToKey(endpoint) if !ok { continue } if score, ok := scores[key]; ok { if minScore == maxScore { - scoredPods[pod] = 1.0 + scoredEndpoints[endpoint] = 1.0 continue } - scoredPods[pod] = (score - minScore) / (maxScore - minScore) + scoredEndpoints[endpoint] = (score - minScore) / (maxScore - minScore) } else { - scoredPods[pod] = 0.0 + scoredEndpoints[endpoint] = 0.0 } } - return scoredPods + return scoredEndpoints } func getMinMax(scores map[string]float64) (float64, float64) { diff --git a/pkg/scheduling/pd/scheduler_test.go b/pkg/scheduling/pd/scheduler_test.go index bd1f6b1c1..06efcc1f0 100644 --- a/pkg/scheduling/pd/scheduler_test.go +++ b/pkg/scheduling/pd/scheduler_test.go @@ -11,14 +11,12 @@ import ( "github.com/google/uuid" 
"github.com/stretchr/testify/assert" k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds + "sigs.k8s.io/controller-runtime/pkg/log" // Import config for thresholds + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + fwkschd "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/picker" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/prefix" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/picker" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/filter" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/profile" @@ -32,42 +30,45 @@ const ( // Tests the scheduler expected behavior. 
func TestPDSchedule(t *testing.T) { - pod1 := &types.PodMetrics{ - Pod: &backend.Pod{ - NamespacedName: k8stypes.NamespacedName{Name: "pod1"}, + endpoint1 := fwkschd.NewEndpoint( + &fwkdl.EndpointMetadata{ + NamespacedName: k8stypes.NamespacedName{Name: "endpoint1"}, Address: "1.2.3.4", Labels: map[string]string{filter.RoleLabel: filter.RolePrefill}, }, - MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 0}, - } - pod2 := &types.PodMetrics{ - Pod: &backend.Pod{ - NamespacedName: k8stypes.NamespacedName{Name: "pod2"}, + &fwkdl.Metrics{WaitingQueueSize: 0}, + nil, + ) + endpoint2 := fwkschd.NewEndpoint( + &fwkdl.EndpointMetadata{ + NamespacedName: k8stypes.NamespacedName{Name: "endpoint2"}, Address: "5.6.7.8", Labels: map[string]string{filter.RoleLabel: filter.RoleDecode}, }, - MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 0}, - } - noRolePod1 := &types.PodMetrics{ - Pod: &backend.Pod{ - NamespacedName: k8stypes.NamespacedName{Name: "noRolePod1"}, + &fwkdl.Metrics{WaitingQueueSize: 0}, + nil, + ) + noRoleEndpoint1 := fwkschd.NewEndpoint( + &fwkdl.EndpointMetadata{ + NamespacedName: k8stypes.NamespacedName{Name: "noRoleEndpoint1"}, Address: "1.1.1.1", }, - MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 2}, - } + &fwkdl.Metrics{WaitingQueueSize: 2}, + nil, + ) - prefillDecodeResult := &types.SchedulingResult{ - ProfileResults: map[string]*types.ProfileRunResult{ - decode: {TargetPods: []types.Pod{ - &types.ScoredPod{ - Pod: pod2, + prefillDecodeResult := &fwkschd.SchedulingResult{ + ProfileResults: map[string]*fwkschd.ProfileRunResult{ + decode: {TargetEndpoints: []fwkschd.Endpoint{ + &fwkschd.ScoredEndpoint{ + Endpoint: endpoint2, }, }, }, prefill: { - TargetPods: []types.Pod{ - &types.ScoredPod{ - Pod: pod1, + TargetEndpoints: []fwkschd.Endpoint{ + &fwkschd.ScoredEndpoint{ + Endpoint: endpoint1, }, }, }, @@ -76,12 +77,12 @@ func TestPDSchedule(t *testing.T) { PrimaryProfileName: decode, } - decodeResult := 
&types.SchedulingResult{ - ProfileResults: map[string]*types.ProfileRunResult{ + decodeResult := &fwkschd.SchedulingResult{ + ProfileResults: map[string]*fwkschd.ProfileRunResult{ decode: { - TargetPods: []types.Pod{ - &types.ScoredPod{ - Pod: pod2, + TargetEndpoints: []fwkschd.Endpoint{ + &fwkschd.ScoredEndpoint{ + Endpoint: endpoint2, }, }, }, @@ -91,114 +92,114 @@ func TestPDSchedule(t *testing.T) { tests := []struct { name string - req *types.LLMRequest - input []types.Pod - wantRes *types.SchedulingResult - wantRes2 *types.SchedulingResult // a subsequent call to check prefix cache and how it affects PD + req *fwkschd.LLMRequest + input []fwkschd.Endpoint + wantRes *fwkschd.SchedulingResult + wantRes2 *fwkschd.SchedulingResult // a subsequent call to check prefix cache and how it affects PD err bool }{ { - name: "no candidate pods", - req: &types.LLMRequest{ + name: "no candidate endpoints", + req: &fwkschd.LLMRequest{ RequestId: uuid.NewString(), TargetModel: "any-model", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &fwkschd.LLMRequestBody{ + Completions: &fwkschd.CompletionsRequest{ Prompt: "12345678901", }, }, }, - input: []types.Pod{}, + input: []fwkschd.Endpoint{}, err: true, }, { - name: "one decode pod, long prompt", - req: &types.LLMRequest{ + name: "one decode endpoint, long prompt", + req: &fwkschd.LLMRequest{ RequestId: uuid.NewString(), TargetModel: "critical", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &fwkschd.LLMRequestBody{ + Completions: &fwkschd.CompletionsRequest{ Prompt: "12345678901", }, }, }, - // pod2 will be picked because it is the only pod with Decode role - input: []types.Pod{pod2}, + // endpoint2 will be picked because it is the only endpoint with Decode role + input: []fwkschd.Endpoint{endpoint2}, wantRes: decodeResult, }, { - name: "one prefill pod, long prompt", - req: &types.LLMRequest{ + name: "one prefill endpoint, long prompt", + req: 
&fwkschd.LLMRequest{ RequestId: uuid.NewString(), TargetModel: "critical", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &fwkschd.LLMRequestBody{ + Completions: &fwkschd.CompletionsRequest{ Prompt: "12345678901", }, }, }, - // no Decode pod - input: []types.Pod{pod1}, + // no Decode endpoint + input: []fwkschd.Endpoint{endpoint1}, err: true, }, { name: "1P1D - long prompt", - req: &types.LLMRequest{ + req: &fwkschd.LLMRequest{ RequestId: uuid.NewString(), TargetModel: "critical", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &fwkschd.LLMRequestBody{ + Completions: &fwkschd.CompletionsRequest{ Prompt: "12345678906", }, }, }, - // pod2 will be picked in the decode profile result, pod1 will be in the prefill profile result - input: []types.Pod{pod1, pod2}, + // endpoint2 will be picked in the decode profile result, endpoint1 will be in the prefill profile result + input: []fwkschd.Endpoint{endpoint1, endpoint2}, wantRes: prefillDecodeResult, wantRes2: decodeResult, }, { name: "1P1Dshort", - req: &types.LLMRequest{ + req: &fwkschd.LLMRequest{ RequestId: uuid.NewString(), TargetModel: "critical", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &fwkschd.LLMRequestBody{ + Completions: &fwkschd.CompletionsRequest{ Prompt: "12345", }, }, }, - // pod2 will be picked because it is the decode pod, pod1 shouldn't be picked, + // endpoint2 will be picked because it is the decode endpoint, endpoint1 shouldn't be picked, // because the prompt is too short - input: []types.Pod{pod1, pod2}, + input: []fwkschd.Endpoint{endpoint1, endpoint2}, wantRes: decodeResult, wantRes2: decodeResult, }, { name: "TestRolesWithNoDecode", - req: &types.LLMRequest{ + req: &fwkschd.LLMRequest{ RequestId: uuid.NewString(), TargetModel: "critical", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &fwkschd.LLMRequestBody{ + Completions: &fwkschd.CompletionsRequest{ 
Prompt: "12345678901", }, }, }, - input: []types.Pod{pod1, noRolePod1}, - wantRes: &types.SchedulingResult{ - ProfileResults: map[string]*types.ProfileRunResult{ + input: []fwkschd.Endpoint{endpoint1, noRoleEndpoint1}, + wantRes: &fwkschd.SchedulingResult{ + ProfileResults: map[string]*fwkschd.ProfileRunResult{ decode: { - TargetPods: []types.Pod{ - &types.ScoredPod{ - Pod: noRolePod1, + TargetEndpoints: []fwkschd.Endpoint{ + &fwkschd.ScoredEndpoint{ + Endpoint: noRoleEndpoint1, }, }, }, prefill: { - TargetPods: []types.Pod{ - &types.ScoredPod{ - Pod: pod1, + TargetEndpoints: []fwkschd.Endpoint{ + &fwkschd.ScoredEndpoint{ + Endpoint: endpoint1, }, }, }, @@ -208,18 +209,18 @@ func TestPDSchedule(t *testing.T) { }, { name: "1P2D - long prompt", - req: &types.LLMRequest{ + req: &fwkschd.LLMRequest{ RequestId: uuid.NewString(), TargetModel: "critical", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &fwkschd.LLMRequestBody{ + Completions: &fwkschd.CompletionsRequest{ Prompt: "12345678906", }, }, }, - // pod2 will be picked in the decode profile result cause it has higher score than noRolePod1 - // pod1 will be in the prefill profile result - input: []types.Pod{pod1, pod2, noRolePod1}, + // endpoint2 will be picked in the decode profile result cause it has higher score than noRoleEndpoint1 + // endpoint1 will be in the prefill profile result + input: []fwkschd.Endpoint{endpoint1, endpoint2, noRoleEndpoint1}, wantRes: prefillDecodeResult, wantRes2: decodeResult, }, @@ -232,24 +233,24 @@ func TestPDSchedule(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { // initialize scheduler with config - prefixScorer := prefix.New(ctx, prefix.Config{BlockSize: 5, MaxPrefixBlocksToMatch: 256, LRUCapacityPerServer: 31250}) + prefixScorer, _ := prefix.New(ctx, prefix.Config{BlockSizeTokens: 1, MaxPrefixBlocksToMatch: 256, LRUCapacityPerServer: 31250}) - prefillSchedulerProfile := framework.NewSchedulerProfile(). 
+ prefillSchedulerProfile := scheduling.NewSchedulerProfile(). WithFilters(filter.NewPrefillRole()). WithPicker(picker.NewMaxScorePicker(picker.DefaultMaxNumOfEndpoints)) - err := prefillSchedulerProfile.AddPlugins(framework.NewWeightedScorer(prefixScorer, 50)) + err := prefillSchedulerProfile.AddPlugins(scheduling.NewWeightedScorer(prefixScorer, 50)) assert.NoError(t, err, "SchedulerProfile AddPlugins returned unexpected error") - decodeSchedulerProfile := framework.NewSchedulerProfile(). + decodeSchedulerProfile := scheduling.NewSchedulerProfile(). WithFilters(filter.NewDecodeRole()). - WithScorers(framework.NewWeightedScorer(scorer.NewLoadAware(ctx, scorer.QueueThresholdDefault), 1)). + WithScorers(scheduling.NewWeightedScorer(scorer.NewLoadAware(ctx, scorer.QueueThresholdDefault), 1)). WithPicker(picker.NewMaxScorePicker(picker.DefaultMaxNumOfEndpoints)) - err = decodeSchedulerProfile.AddPlugins(framework.NewWeightedScorer(prefixScorer, 0)) + err = decodeSchedulerProfile.AddPlugins(scheduling.NewWeightedScorer(prefixScorer, 0)) assert.NoError(t, err, "SchedulerProfile AddPlugins returned unexpected error") - profileHandle := profile.NewPdProfileHandler(prefill, decode, prefixScorer.TypedName().Type, prefixScorer.TypedName().Name, 10, 5, 0) + profileHandle := profile.NewPdProfileHandler(prefill, decode, prefixScorer.TypedName().Type, prefixScorer.TypedName().Name, 10, 1, 0) - schedulerConfig := scheduling.NewSchedulerConfig(profileHandle, map[string]*framework.SchedulerProfile{ + schedulerConfig := scheduling.NewSchedulerConfig(profileHandle, map[string]fwkschd.SchedulerProfile{ prefill: prefillSchedulerProfile, decode: decodeSchedulerProfile, }) @@ -260,7 +261,7 @@ func TestPDSchedule(t *testing.T) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } - if diff := cmp.Diff(test.wantRes, got, cmpopts.IgnoreFields(types.ScoredPod{}, "Score")); diff != "" { + if diff := cmp.Diff(test.wantRes, got, cmpopts.IgnoreUnexported(fwkdl.Attributes{}), 
cmpopts.IgnoreFields(fwkschd.ScoredEndpoint{}, "Score")); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } @@ -274,7 +275,7 @@ func TestPDSchedule(t *testing.T) { t.Errorf("Unexpected error in schedule call, got %v, want %v", err, test.err) } - if diff := cmp.Diff(test.wantRes2, got, cmpopts.IgnoreFields(types.ScoredPod{}, "Score")); diff != "" { + if diff := cmp.Diff(test.wantRes2, got, cmpopts.IgnoreUnexported(fwkdl.Attributes{}), cmpopts.IgnoreFields(fwkschd.ScoredEndpoint{}, "Score")); diff != "" { t.Errorf("Unexpected output in subsequent schedule call (-want +got): %v", diff) } } diff --git a/test/config/prefix_cache_mode_test.go b/test/config/prefix_cache_mode_test.go index 5a13d0087..3965f9246 100644 --- a/test/config/prefix_cache_mode_test.go +++ b/test/config/prefix_cache_mode_test.go @@ -7,7 +7,7 @@ import ( "github.com/go-logr/logr" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/config/loader" - giePlugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + giePlugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" "sigs.k8s.io/gateway-api-inference-extension/test/utils" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins" diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go index 027a49686..f339866f0 100644 --- a/test/e2e/e2e_test.go +++ b/test/e2e/e2e_test.go @@ -457,7 +457,7 @@ plugins: - type: prefill-header-handler - type: prefix-cache-scorer parameters: - hashBlockSize: 10 + blockSizeTokens: 10 maxPrefixBlocksToMatch: 256 lruCapacityPerServer: 256 - type: prefill-filter @@ -465,7 +465,8 @@ plugins: - type: max-score-picker - type: pd-profile-handler parameters: - threshold: 10 + hashBlockSize: 10 + threshold: 40 schedulingProfiles: - name: prefill plugins: