diff --git a/deploy/config/pd-epp-config.yaml b/deploy/config/pd-epp-config.yaml index 9732be2f3..35a94a3be 100644 --- a/deploy/config/pd-epp-config.yaml +++ b/deploy/config/pd-epp-config.yaml @@ -1,13 +1,25 @@ # Sample EPP configuration for tunning with P/D apiVersion: inference.networking.x-k8s.io/v1alpha1 kind: EndpointPickerConfig +featureGates: +- prepareDataPlugins plugins: - type: prefill-header-handler - type: prefix-cache-scorer + parameters: + maxPrefixBlocksToMatch: 256 + lruCapacityPerServer: 31250 +- type: queue-scorer - type: prefill-filter - type: decode-filter - type: max-score-picker +- type: prefix-based-pd-decider + parameters: + nonCachedTokens: 16 - type: pd-profile-handler + parameters: + primaryPort: ${PRIMARY_PORT} + deciderPluginName: prefix-based-pd-decider schedulingProfiles: - name: prefill plugins: @@ -15,9 +27,13 @@ schedulingProfiles: - pluginRef: max-score-picker - pluginRef: prefix-cache-scorer weight: 2 + - pluginRef: queue-scorer + weight: 1 - name: decode plugins: - pluginRef: decode-filter - pluginRef: max-score-picker - pluginRef: prefix-cache-scorer weight: 2 + - pluginRef: queue-scorer + weight: 1 diff --git a/deploy/config/sim-pd-epp-config.yaml b/deploy/config/sim-pd-epp-config.yaml index 2d6a85dd9..2f93504a1 100644 --- a/deploy/config/sim-pd-epp-config.yaml +++ b/deploy/config/sim-pd-epp-config.yaml @@ -2,21 +2,27 @@ # Use with small hash block size for simulation purposes apiVersion: inference.networking.x-k8s.io/v1alpha1 kind: EndpointPickerConfig +featureGates: +- prepareDataPlugins plugins: - type: prefill-header-handler - type: prefix-cache-scorer parameters: - hashBlockSize: 5 + blockSizeTokens: 16 + autoTune: false maxPrefixBlocksToMatch: 256 lruCapacityPerServer: 31250 +- type: queue-scorer - type: prefill-filter - type: decode-filter - type: max-score-picker +- type: prefix-based-pd-decider + parameters: + nonCachedTokens: 16 - type: pd-profile-handler parameters: - threshold: 10 - hashBlockSize: 5 primaryPort: ${PRIMARY_PORT} + deciderPluginName: prefix-based-pd-decider schedulingProfiles: - name: prefill plugins: @@ -24,9 +30,13 @@ schedulingProfiles: - pluginRef: max-score-picker - pluginRef: prefix-cache-scorer weight: 2 + - pluginRef: queue-scorer + weight: 1 - name: decode plugins: - pluginRef: decode-filter - pluginRef: max-score-picker - pluginRef: prefix-cache-scorer weight: 2 + - pluginRef: queue-scorer + weight: 1 diff --git a/docs/architecture.md b/docs/architecture.md index efdb92597..c6ab17198 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -206,15 +206,41 @@ Selects the profiles to use when running with disaggregated prefill/decode - **Type**: `pd-profile-handler` - **Parameters**: - - `threshold`: specifies the threshold at which there are enough new input tokens to send the request to prefill and then decode, vs just to decode. - - `hashBlockSize`: specifies the length of the prompt chunk that a block is keyed by. This must the same value used for the PrefixCachePlugin. - `decodeProfile`: specifies the name of the profile used for the decode scheduling. Only needed if the decode profile is not named `decode`. - `prefillProfile`: specifies the name of the profile used for the prefill scheduling. Only needed if the prefill profile is not named `prefill`. + - `deciderPluginName`: specifies the name of the decider plugin. Decider determines whether disaggregated PD should be executed + - `primaryPort`: the base port number used for data parallel communication. **Note:** When using this plugin you must also have a PrefixCachePlugin configured in the prefill and decode scheduling profiles. --- +#### Prefix Based Decider Plugin + +Type: `prefix-based-pd-decider` + +**Parameters** +- `nonCachedTokens`: length, in token, of the uncached part of the user input above which disaggregated PD is triggered. + +Note: `prepareDataPlugins` feature gate should be enabled + +**Example** +```yaml +kind: EndpointPickerConfig +featureGates: +- prepareDataPlugins +plugins: +- type: prefix-based-pd-decider + parameters: + nonCachedTokens: 4 +- type: pd-profile-handler + parameters: + primaryPort: 8000 + deciderPluginName: prefix-based-pd-decider +``` + +--- + #### ByLabelSelector Filters out pods using a standard Kubernetes label selector. diff --git a/docs/disagg_pd.md b/docs/disagg_pd.md index ad459bcb0..4c7cabd23 100644 --- a/docs/disagg_pd.md +++ b/docs/disagg_pd.md @@ -155,6 +155,8 @@ Below is a minimal `EndpointPickerConfig` that enables integration with workload ```yaml apiVersion: inference.networking.x-k8s.io/v1alpha1 kind: EndpointPickerConfig +featureGates: +- prepareDataPlugins plugins: # Prefill selection: match Pods with label role=prefill - type: by-label @@ -176,10 +178,12 @@ plugins: lruCapacityPerServer: 31250 - type: max-score-picker - type: prefill-header-handler - - type: pd-profile-handler + - type: prefix-based-pd-decider parameters: - threshold: 0 - hashBlockSize: 5 + nonCachedTokens: 8 + - type: pd-profile-handler + parameters: + deciderPluginName: prefix-based-pd-decider primaryPort: 8000 schedulingProfiles: - name: prefill @@ -200,6 +204,59 @@ schedulingProfiles: ![Disaggregated Prefill/Decode Architecture](./images/dp_architecture.png) +--- +## PD Deciders + +PD deciders are pd handler plugins responsible for determining whether disaggregated P/D should be executed for a given request, based on the properties of the request prompt. + + +### Prefix-Based PD Decider + +The `prefix-based-pd-decider` plugin makes the disaggregation decision according to the length of the non-cached suffix of the prompt relative to tokens already cached on the selected decode pod. + +**How It Works** +- Once a decode pod is selected, the decider checks how many tokens from the incoming prompt have already been sent to this pod + +- If the remaining non-cached suffix length is longer than the configured threshold (nonCachedTokens), disaggregation is triggered — the prefill will run remotely on a prefill pod, and decode locally on the decode pod + +- If the non-cached suffix is shorter or equal to the threshold, the full request runs locally on the decode worker without remote prefill + +**Configuration** +```yaml +- type: prefix-based-pd-decider + parameters: + nonCachedTokens: 8 +``` + +**Parameter:** + +- `nonCachedTokens`: Number of non-cached tokens that trigger disaggregation + - If set to 0, disaggregation always occurs for all requests + +**Feature Gate Requirement** +To activate this decider, ensure the following feature gate is enabled in your EndpointPickerConfig + +```yaml +featureGates: +- prepareDataPlugins +``` + + +### Always-Disagg PD Decider +The `always-disagg-pd-decider` is a simpler alternative used mainly for testing or benchmarking. +It always triggers disaggregation, regardless of prefix cache state or prompt characteristics. + +**Configuration example:** + +```yaml +- type: always-disagg-pd-decider +``` + +**Notes:** +This plugin accepts no parameters. + +It’s useful for validating end-to-end prefill/decode splitting and comparing system performance under forced disaggregation. + --- ## References diff --git a/pkg/plugins/profile/always_disagg_decider.go b/pkg/plugins/profile/always_disagg_decider.go new file mode 100644 index 000000000..1fefd47fe --- /dev/null +++ b/pkg/plugins/profile/always_disagg_decider.go @@ -0,0 +1,48 @@ +package profile + +import ( + "context" + "encoding/json" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" +) + +const ( + // AlwaysDisaggDeciderPluginType is the type-name of the alwaysDisaggPDDecider plugin. + AlwaysDisaggDeciderPluginType = "always-disagg-pd-decider" +) + +// compile-time type assertion +var _ pdDeciderPlugin = &AlwaysDisaggPDDecider{} + +// AlwaysDisaggPDDecider is a PD decider plugin which always decide to disaggregate PD +type AlwaysDisaggPDDecider struct { + typedName plugin.TypedName +} + +// AlwaysDisaggPDDeciderPluginFactory defines the factory function for creating +// a new instance of the AlwaysDisaggPDDecider. +func AlwaysDisaggPDDeciderPluginFactory(name string, _ json.RawMessage, + _ plugin.Handle) (plugin.Plugin, error) { + return newAlwaysDisaggPDDecider().WithName(name), nil +} + +func newAlwaysDisaggPDDecider() *AlwaysDisaggPDDecider { + return &AlwaysDisaggPDDecider{} +} + +// TypedName returns the typed name of the plugin. +func (d *AlwaysDisaggPDDecider) TypedName() plugin.TypedName { + return d.typedName +} + +// WithName sets the name of the plugin. +func (d *AlwaysDisaggPDDecider) WithName(name string) *AlwaysDisaggPDDecider { + d.typedName.Name = name + return d +} + +func (d *AlwaysDisaggPDDecider) disaggregate(ctx context.Context, inputTokens int, endpoint scheduling.Endpoint) bool { + return true +} diff --git a/pkg/plugins/profile/pd_profile_handler.go b/pkg/plugins/profile/pd_profile_handler.go index 4b86b471f..3f5a4b932 100644 --- a/pkg/plugins/profile/pd_profile_handler.go +++ b/pkg/plugins/profile/pd_profile_handler.go @@ -23,36 +23,42 @@ const ( // PdProfileHandlerType is the type of the PdProfileHandler PdProfileHandlerType = "pd-profile-handler" - defaultDecodeProfile = "decode" - defaultPrefillProfile = "prefill" - defaultPrefixPluginType = prefix.PrefixCachePluginType + defaultDecodeProfile = "decode" + defaultPrefillProfile = "prefill" + defaultPrefixPluginType = prefix.PrefixCachePluginType + defaultDeciderPluginName = AlwaysDisaggDeciderPluginType - // An estimated average characters per token, used since the request we cached is not tokenized. - averageCharactersPerToken = 4 + // AverageCharactersPerToken is an estimated average characters per token, used since the request we cached is not tokenized. + AverageCharactersPerToken = 4 ) +// pdDeciderPlugin interface for pd decider plugins +type pdDeciderPlugin interface { + plugin.Plugin + // disaggregate checks if disaggregated PD is required for the given request and endpoint. + disaggregate(ctx context.Context, inputTokens int, endpoint scheduling.Endpoint) bool +} + type pdProfileHandlerParameters struct { - Threshold int `json:"threshold"` - DecodeProfile string `json:"decodeProfile"` - PrefillProfile string `json:"prefillProfile"` - PrefixPluginType string `json:"prefixPluginType"` - PrefixPluginName string `json:"prefixPluginName"` - HashBlockSize int `json:"hashBlockSize"` - PrimaryPort int `json:"primaryPort"` + DecodeProfile string `json:"decodeProfile"` + PrefillProfile string `json:"prefillProfile"` + PrefixPluginType string `json:"prefixPluginType"` + PrefixPluginName string `json:"prefixPluginName"` + PrimaryPort int `json:"primaryPort"` + DeciderPluginName string `json:"deciderPluginName"` } // compile-time type assertion var _ scheduling.ProfileHandler = &PdProfileHandler{} // PdProfileHandlerFactory defines the factory function for the PdProfileHandler -func PdProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) { +func PdProfileHandlerFactory(name string, rawParameters json.RawMessage, handle plugin.Handle) (plugin.Plugin, error) { parameters := pdProfileHandlerParameters{ - Threshold: 0, - DecodeProfile: defaultDecodeProfile, - PrefillProfile: defaultPrefillProfile, - PrefixPluginType: defaultPrefixPluginType, - HashBlockSize: prefix.DefaultBlockSizeTokens * averageCharactersPerToken, - PrimaryPort: 0, + DecodeProfile: defaultDecodeProfile, + PrefillProfile: defaultPrefillProfile, + PrefixPluginType: defaultPrefixPluginType, + PrimaryPort: 0, + DeciderPluginName: defaultDeciderPluginName, } if rawParameters != nil { if err := json.Unmarshal(rawParameters, ¶meters); err != nil { @@ -64,39 +70,52 @@ func PdProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugi parameters.PrefixPluginName = parameters.PrefixPluginType } - if parameters.Threshold < 0 { - return nil, fmt.Errorf("invalid threshold: must be >= 0, got %d", parameters.Threshold) - } - - if parameters.HashBlockSize <= 0 { - return nil, fmt.Errorf("invalid hashBlockSize: must be > 0, got %d", parameters.HashBlockSize) - } - if parameters.PrimaryPort != 0 { if parameters.PrimaryPort < 1 || parameters.PrimaryPort > 65535 { return nil, fmt.Errorf("invalid primaryPort: must be between 1 and 65535, got %d", parameters.PrimaryPort) } } - return NewPdProfileHandler(parameters.PrefillProfile, parameters.DecodeProfile, parameters.PrefixPluginType, parameters.PrefixPluginName, - parameters.Threshold, parameters.HashBlockSize, parameters.PrimaryPort).WithName(name), nil + if parameters.DeciderPluginName == "" { + return nil, errors.New("decider plugin name is not defined") + } + + plugin := handle.Plugin(parameters.DeciderPluginName) + if plugin == nil { + return nil, fmt.Errorf("invalid decider plugin type: %s", parameters.DeciderPluginName) + } + + deciderPlugin, ok := plugin.(pdDeciderPlugin) + if !ok { + return nil, fmt.Errorf("decider plugin of type: %s does not implement pdDeciderPlugin", parameters.DeciderPluginName) + } + + handler, err := NewPdProfileHandler(parameters.PrefillProfile, parameters.DecodeProfile, parameters.PrefixPluginType, parameters.PrefixPluginName, + parameters.PrimaryPort, deciderPlugin) + + if err != nil { + return nil, err + } + + return handler.WithName(name), nil + } // NewPdProfileHandler initializes a new PdProfileHandler and returns its pointer. -func NewPdProfileHandler(prefillProfile, decodeProfile, prefixPluginType, prefixPluginName string, pdThreshold, hashBlockSize, primaryPort int) *PdProfileHandler { +func NewPdProfileHandler(prefillProfile, decodeProfile, prefixPluginType, prefixPluginName string, + primaryPort int, deciderPlugin pdDeciderPlugin) (*PdProfileHandler, error) { result := &PdProfileHandler{ typedName: plugin.TypedName{Type: PdProfileHandlerType}, prefixPluginTypedName: plugin.TypedName{Type: prefixPluginType, Name: prefixPluginName}, decodeProfile: decodeProfile, prefillProfile: prefillProfile, - pdThreshold: pdThreshold, - hashBlockSize: hashBlockSize, + decider: deciderPlugin, } if primaryPort != 0 { result.primaryPort = strconv.Itoa(primaryPort) } - return result + return result, nil } // PdProfileHandler handles scheduler profiles for PD. @@ -105,9 +124,8 @@ type PdProfileHandler struct { prefixPluginTypedName plugin.TypedName decodeProfile string prefillProfile string - pdThreshold int - hashBlockSize int primaryPort string + decider pdDeciderPlugin } // TypedName returns the typed name of the plugin. @@ -123,7 +141,7 @@ func (h *PdProfileHandler) WithName(name string) *PdProfileHandler { // Pick selects the SchedulingProfiles to run from the list of candidate profiles, while taking into consideration the request properties and the // previously executed cycles along with their results. -func (h *PdProfileHandler) Pick(ctx context.Context, cycleState *scheduling.CycleState, request *scheduling.LLMRequest, profiles map[string]scheduling.SchedulerProfile, +func (h *PdProfileHandler) Pick(ctx context.Context, _ *scheduling.CycleState, request *scheduling.LLMRequest, profiles map[string]scheduling.SchedulerProfile, profileResults map[string]*scheduling.ProfileRunResult) map[string]scheduling.SchedulerProfile { if _, executed := profileResults[h.decodeProfile]; !executed { // if decode profile was not executed yet, first let the scheduler run the decode profile @@ -139,41 +157,22 @@ func (h *PdProfileHandler) Pick(ctx context.Context, cycleState *scheduling.Cycl return map[string]scheduling.SchedulerProfile{} } - if h.pdThreshold > 0 { - userInput, err := getUserInputBytes(request) - if err != nil { - log.FromContext(ctx).V(logutil.DEBUG).Error(err, "Failed to get user input bytes") - return nil - } - - // if we're here that means decode profile ran successfully, and we have additional profile configured that didn't run yet, - // which means PD is enabled (otherwise, prefill profile is not configured at all and this profile handler is not used). - // inspect decode execution result to decide if prefill should run or not. - // if the request is short enough, use decode results only and don't run the prefill profile. - hitPercentagePrefix := 0.0 // default to 0, meaning no prefix cache hit - prefixState, err := scheduling.ReadCycleStateKey[*prefix.SchedulingContextState](cycleState, plugin.StateKey(h.prefixPluginTypedName.String())) - if err != nil { - log.FromContext(ctx).Error(err, "unable to read prefix state") - } else { - decodeEndpoint := profileResults[h.decodeProfile].TargetEndpoints[0].GetMetadata().NamespacedName - hitPrefix := max(prefixState.PrefixCacheServers[prefix.ServerID(decodeEndpoint)]-1, 0) // The first hit is always the model name - hitPercentagePrefix = float64(hitPrefix*h.hashBlockSize*averageCharactersPerToken) / float64(len(userInput)) - log.FromContext(ctx).V(logutil.DEBUG).Info("Computed hit percentage for prefix cache", "hitPercentage", hitPercentagePrefix, - "promptLength", len(userInput)) - } + inputTokens, err := getUserInputLenInTokens(request) + if err != nil { + log.FromContext(ctx).V(logutil.DEBUG).Error(err, "Failed to get user input") + return nil + } - if (1.0-hitPercentagePrefix)*float64(len(userInput)) < float64(h.pdThreshold) { - log.FromContext(ctx).Info("Non-cached suffix is smaller than threshold, using decode profile only", "hitPercentage", hitPercentagePrefix) - metrics.RecordPDDecision(request.TargetModel, metrics.DecisionTypeDecodeOnly) - return map[string]scheduling.SchedulerProfile{} // do not run prefill + if h.decider != nil && h.decider.disaggregate(ctx, inputTokens, profileResults[h.decodeProfile].TargetEndpoints[0]) { + metrics.RecordPDDecision(request.TargetModel, metrics.DecisionTypePrefillDecode) + // run the prefill profile + return map[string]scheduling.SchedulerProfile{ + h.prefillProfile: profiles[h.prefillProfile], } } - metrics.RecordPDDecision(request.TargetModel, metrics.DecisionTypePrefillDecode) - // run the prefill profile - return map[string]scheduling.SchedulerProfile{ - h.prefillProfile: profiles[h.prefillProfile], - } + metrics.RecordPDDecision(request.TargetModel, metrics.DecisionTypeDecodeOnly) + return map[string]scheduling.SchedulerProfile{} // do not run prefill } // ProcessResults handles the outcome of the profile runs after the selected profiles ran. @@ -193,17 +192,17 @@ func (h *PdProfileHandler) ProcessResults(_ context.Context, _ *scheduling.Cycle if h.primaryPort != "" { // Data Parallel is active - targetPod := decodeRunResults.TargetEndpoints[0].GetMetadata() - request.Headers[common.DataParallelPodHeader] = net.JoinHostPort(targetPod.Address, targetPod.Port) + targetEndpoint := decodeRunResults.TargetEndpoints[0].GetMetadata() + request.Headers[common.DataParallelPodHeader] = net.JoinHostPort(targetEndpoint.Address, targetEndpoint.Port) updatedResult := scheduling.ProfileRunResult{ TargetEndpoints: []scheduling.Endpoint{}, } for _, target := range decodeRunResults.TargetEndpoints { - updatedPodInfo := target.GetMetadata().Clone() - updatedPodInfo.Port = h.primaryPort - targetEndpoint := scheduling.NewEndpoint(updatedPodInfo, target.GetMetrics().Clone(), nil) + updatedEndpointInfo := target.GetMetadata().Clone() + updatedEndpointInfo.Port = h.primaryPort + targetEndpoint := scheduling.NewEndpoint(updatedEndpointInfo, target.GetMetrics().Clone(), nil) updatedResult.TargetEndpoints = append(updatedResult.TargetEndpoints, targetEndpoint) } updatedResults[h.decodeProfile] = &updatedResult @@ -223,11 +222,18 @@ func (h *PdProfileHandler) ProcessResults(_ context.Context, _ *scheduling.Cycle }, nil } -func getUserInputBytes(request *scheduling.LLMRequest) ([]byte, error) { +// returns length of user input in tokens +func getUserInputLenInTokens(request *scheduling.LLMRequest) (int, error) { if request.Body.Completions != nil { // assumed to be valid if not nil - return []byte(request.Body.Completions.Prompt), nil + return len([]byte(request.Body.Completions.Prompt)) / AverageCharactersPerToken, nil } // must be chat-completions request at this point, return bytes of entire messages - return json.Marshal(request.Body.ChatCompletions.Messages) + prompt, err := json.Marshal(request.Body.ChatCompletions.Messages) + + if err != nil { + return 0, err + } + + return len(prompt) / AverageCharactersPerToken, nil } diff --git a/pkg/plugins/profile/pd_profile_handler_test.go b/pkg/plugins/profile/pd_profile_handler_test.go index d138fe384..c932c8064 100644 --- a/pkg/plugins/profile/pd_profile_handler_test.go +++ b/pkg/plugins/profile/pd_profile_handler_test.go @@ -8,6 +8,7 @@ import ( "github.com/stretchr/testify/assert" k8stypes "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins/approximateprefix" fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" @@ -18,118 +19,109 @@ import ( ) func TestPdProfileHandlerFactory(t *testing.T) { + ctx := utils.NewTestContext(t) tests := []struct { name string pluginName string - jsonParams string + params map[string]any expectErr bool }{ { name: "valid configuration with all defaults", pluginName: "default-handler", - jsonParams: "{}", + params: map[string]any{}, expectErr: false, }, { name: "valid configuration with custom values", pluginName: "custom-handler", - jsonParams: `{ - "threshold": 100, - "decodeProfile": "my-decode", - "prefillProfile": "my-prefill", - "prefixPluginName": "my-prefix-cache", - "hashBlockSize": 32, - "primaryPort": 8080 - }`, + params: map[string]any{ + "decodeProfile": "my-decode", + "prefillProfile": "my-prefill", + "prefixPluginName": "my-prefix-cache", + "primaryPort": 8080, + "deciderPluginName": PrefixBasedPDDeciderPluginType, + }, expectErr: false, }, { name: "zero primaryPort is allowed", pluginName: "zero-port", - jsonParams: `{"primaryPort": 0}`, - expectErr: false, - }, - { - name: "threshold = 0 is allowed", - pluginName: "zero-threshold", - jsonParams: `{"threshold": 0}`, - expectErr: false, - }, - { - name: "negative threshold should error", - pluginName: "neg-threshold", - jsonParams: `{"threshold": -1}`, - expectErr: true, - }, - { - name: "hashBlockSize = 0 should error", - pluginName: "zero-block-size", - jsonParams: `{"hashBlockSize": 0}`, - expectErr: true, + params: map[string]any{ + "primaryPort": 0, + }, + expectErr: false, }, { - name: "negative hashBlockSize should error", - pluginName: "neg-block-size", - jsonParams: `{"hashBlockSize": -5}`, - expectErr: true, + name: "nonCachedTokens = 0 is allowed", + pluginName: "zero-non-cached-tokens", + params: map[string]any{ + "deciderPluginName": PrefixBasedPDDeciderPluginType, + }, + expectErr: false, }, { name: "primaryPort below range should error", pluginName: "port-too-low", - jsonParams: `{"primaryPort": 0}`, // OK + params: map[string]any{"primaryPort": 0}, // OK expectErr: false, }, { name: "primaryPort = 1 is valid", pluginName: "port-min", - jsonParams: `{"primaryPort": 1}`, + params: map[string]any{"primaryPort": 1}, expectErr: false, }, { name: "primaryPort = 65535 is valid", pluginName: "port-max", - jsonParams: `{"primaryPort": 65535}`, + params: map[string]any{"primaryPort": 65535}, expectErr: false, }, { name: "empty decodeProfile is valid", pluginName: "empty-decode", - jsonParams: `{"decodeProfile": ""}`, + params: map[string]any{"decodeProfile": ""}, expectErr: false, }, { name: "empty prefillProfile is valid", pluginName: "empty-prefill", - jsonParams: `{"prefillProfile": ""}`, + params: map[string]any{"prefillProfile": ""}, expectErr: false, }, { name: "empty prefixPluginName is valid", pluginName: "empty-prefix-plugin", - jsonParams: `{"prefixPluginName": ""}`, + params: map[string]any{"prefixPluginName": ""}, expectErr: false, }, { name: "primaryPort = 65536 should error", pluginName: "port-too-high", - jsonParams: `{"primaryPort": 65536}`, + params: map[string]any{"primaryPort": 65536}, expectErr: true, }, { name: "primaryPort = -10 should error", pluginName: "port-negative", - jsonParams: `{"primaryPort": -10}`, + params: map[string]any{"primaryPort": -10}, expectErr: true, }, } + handle, err := createHandleWithDeciderPlugins(ctx) + assert.NoError(t, err) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var rawParams json.RawMessage - if tt.jsonParams != "" { - rawParams = json.RawMessage(tt.jsonParams) + if tt.params != nil { + bytes, err := json.Marshal(tt.params) + assert.NoError(t, err) + rawParams = json.RawMessage(bytes) } - plugin, err := PdProfileHandlerFactory(tt.pluginName, rawParams, nil) + plugin, err := PdProfileHandlerFactory(tt.pluginName, rawParams, handle) if tt.expectErr { assert.Error(t, err) @@ -143,21 +135,19 @@ func TestPdProfileHandlerFactory(t *testing.T) { } func TestPdProfileHandlerFactoryInvalidJSON(t *testing.T) { + ctx := utils.NewTestContext(t) + invalidTests := []struct { name string jsonParams string }{ { name: "malformed JSON", - jsonParams: `{"threshold": 100, "hashBlockSize":`, // incomplete - }, - { - name: "threshold as string instead of int", - jsonParams: `{"threshold": "100"}`, + jsonParams: `{"deciderPluginName": `, // incomplete }, { - name: "hashBlockSize as boolean", - jsonParams: `{"hashBlockSize": true}`, + name: "invalid decider plugin type", + jsonParams: `{"deciderPluginName": "INVALID"}`, }, { name: "primaryPort as float", @@ -165,10 +155,13 @@ func TestPdProfileHandlerFactoryInvalidJSON(t *testing.T) { }, } + handle, err := createHandleWithDeciderPlugins(ctx) + assert.NoError(t, err) + for _, tt := range invalidTests { t.Run(tt.name, func(t *testing.T) { rawParams := json.RawMessage(tt.jsonParams) - plugin, err := PdProfileHandlerFactory("test", rawParams, nil) + plugin, err := PdProfileHandlerFactory("test", rawParams, handle) assert.Error(t, err) assert.Nil(t, plugin) @@ -178,7 +171,7 @@ func TestPdProfileHandlerFactoryInvalidJSON(t *testing.T) { const DefaultTestPodPort = "8000" -// createEndpoint creates a mock Pod with customizable IP and port. +// createEndpoint creates a mock Endpoint with customizable IP and port. func createEndpoint(nsn k8stypes.NamespacedName, ipaddr, port string, labels map[string]string) scheduling.Endpoint { return scheduling.NewEndpoint( &fwkdl.EndpointMetadata{ @@ -187,8 +180,8 @@ func createEndpoint(nsn k8stypes.NamespacedName, ipaddr, port string, labels map Port: port, Labels: labels, }, - &fwkdl.Metrics{}, nil, + fwkdl.NewAttributes(), ) } @@ -219,15 +212,33 @@ func (p *mockSchedulerProfile) Run(_ context.Context, _ *scheduling.LLMRequest, return &scheduling.ProfileRunResult{}, nil } -func TestPdProfileHandler_Pick(t *testing.T) { - ctx := utils.NewTestContext(t) - request := &scheduling.LLMRequest{ +// creates and returns llm completion request forthe given prompt +func createRequest(prompt string) *scheduling.LLMRequest { + return &scheduling.LLMRequest{ Body: &scheduling.LLMRequestBody{ Completions: &scheduling.CompletionsRequest{ - Prompt: "hello world", + Prompt: prompt, }, }, } +} + +// returns array of profile names in the given profile pick result +func getProfilesFromResult(result map[string]scheduling.SchedulerProfile) []string { + profiles := make([]string, len(result)) + index := 0 + + for name := range result { + profiles[index] = name + index++ + } + + return profiles +} + +func TestPdProfileHandler_Pick(t *testing.T) { + ctx := utils.NewTestContext(t) + request := createRequest("hello world hello world hello world") profiles := map[string]scheduling.SchedulerProfile{ "decode": newMockSchedulerProfile(), @@ -235,114 +246,205 @@ func TestPdProfileHandler_Pick(t *testing.T) { } tests := []struct { - name string - pdThreshold int - hashBlockSize int - prefixPluginType string - prefixPluginName string - setupPrefixState func(*scheduling.CycleState) - profileResults map[string]*scheduling.ProfileRunResult - expectedProfiles []string + name string + nonCachedTokensLimit int + prefixPluginType string + prefixPluginName string + cachedTokens int + profileResults map[string]*scheduling.ProfileRunResult + expectedProfiles []string }{ { - name: "decode not executed yet → run decode", - pdThreshold: 100, - hashBlockSize: 16, - prefixPluginType: prefix.PrefixCachePluginType, - prefixPluginName: prefix.PrefixCachePluginType, - profileResults: map[string]*scheduling.ProfileRunResult{}, - expectedProfiles: []string{"decode"}, + name: "decode not executed yet → run decode", + nonCachedTokensLimit: 10, + prefixPluginType: prefix.PrefixCachePluginType, + prefixPluginName: prefix.PrefixCachePluginType, + profileResults: map[string]*scheduling.ProfileRunResult{}, + expectedProfiles: []string{defaultDecodeProfile}, }, { - name: "decode failed (nil result) → run nothing", - pdThreshold: 100, - hashBlockSize: 16, - prefixPluginType: prefix.PrefixCachePluginType, - prefixPluginName: prefix.PrefixCachePluginType, + name: "decode failed (nil result) → run nothing", + nonCachedTokensLimit: 10, + prefixPluginType: prefix.PrefixCachePluginType, + prefixPluginName: prefix.PrefixCachePluginType, profileResults: map[string]*scheduling.ProfileRunResult{ - "decode": nil, + defaultDecodeProfile: nil, }, expectedProfiles: []string{}, }, { - name: "all profiles already executed → run nothing", - pdThreshold: 100, - hashBlockSize: 16, - prefixPluginType: prefix.PrefixCachePluginType, - prefixPluginName: prefix.PrefixCachePluginType, + name: "all profiles already executed → run nothing", + nonCachedTokensLimit: 10, + prefixPluginType: prefix.PrefixCachePluginType, + prefixPluginName: prefix.PrefixCachePluginType, profileResults: map[string]*scheduling.ProfileRunResult{ - "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"), - "prefill": newMockProfileRunResult(DefaultTestPodPort, "pod2"), + defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"), + defaultPrefillProfile: newMockProfileRunResult(DefaultTestPodPort, "pod2"), }, expectedProfiles: []string{}, }, { - name: "pd threshold NOT triggered → run prefill", - pdThreshold: 5, - hashBlockSize: 16, - prefixPluginType: prefix.PrefixCachePluginType, - prefixPluginName: prefix.PrefixCachePluginType, - setupPrefixState: func(cs *scheduling.CycleState) { - state := &prefix.SchedulingContextState{ - PrefixCacheServers: map[prefix.ServerID]int{ - prefix.ServerID(k8stypes.NamespacedName{Name: "pod1", Namespace: "default"}): 1, - }, - } - key := plugin.StateKey(fmt.Sprintf("%s/%s", prefix.PrefixCachePluginType, prefix.PrefixCachePluginType)) - cs.Write(key, state) - }, + name: "has enough not-cached tokens → run prefill", + // Need at least 4 non-cached tokens (16+ chars) to trigger disaggregated prefill + // In this case: prompt length is 35 chars (8 tokens), cached length is 2 tokens -> disaggregated prefill should trigger + nonCachedTokensLimit: 4, + cachedTokens: 2, + prefixPluginType: prefix.PrefixCachePluginType, + prefixPluginName: prefix.PrefixCachePluginType, profileResults: map[string]*scheduling.ProfileRunResult{ - "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"), + defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"), }, - expectedProfiles: []string{"prefill"}, + expectedProfiles: []string{defaultPrefillProfile}, }, { - name: "pd threshold triggered (short non-cached suffix) → skip prefill", - pdThreshold: 100, - hashBlockSize: 16, - prefixPluginType: prefix.PrefixCachePluginType, - prefixPluginName: prefix.PrefixCachePluginType, - setupPrefixState: func(cs *scheduling.CycleState) { - state := &prefix.SchedulingContextState{ - PrefixCacheServers: map[prefix.ServerID]int{ - prefix.ServerID(k8stypes.NamespacedName{Name: "pod1", Namespace: "default"}): 5, - }, - } - key := plugin.StateKey(fmt.Sprintf("%s/%s", prefix.PrefixCachePluginType, prefix.PrefixCachePluginType)) - cs.Write(key, state) - }, + name: "short non-cached suffix → skip prefill", + // Need at least 4 non-cached tokens (16+ chars) to trigger disaggregated prefill + // In this case: prompt length is 35 chars (8 tokens), cached length is 5 tokens -> skip prefill + nonCachedTokensLimit: 4, + cachedTokens: 5, + prefixPluginType: prefix.PrefixCachePluginType, + prefixPluginName: prefix.PrefixCachePluginType, profileResults: map[string]*scheduling.ProfileRunResult{ - "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"), + defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"), }, expectedProfiles: []string{}, }, } for _, tt := range tests { + deciderPlugin, err := NewPrefixBasedPDDecider(PrefixBasedPDDeciderConfig{NonCachedTokens: tt.nonCachedTokensLimit}) + assert.NoError(t, err) + t.Run(tt.name, func(t *testing.T) { - handler := NewPdProfileHandler( - "prefill", - "decode", + handler, err := NewPdProfileHandler( + defaultPrefillProfile, + defaultDecodeProfile, tt.prefixPluginType, tt.prefixPluginName, - tt.pdThreshold, - tt.hashBlockSize, 0, - ).WithName("test-handler") + deciderPlugin, + ) + assert.NoError(t, err) + + // set prefix to the given cached tokens number for pod "pod1" in decode profile results + inputTokens := len(request.Body.Completions.Prompt) / AverageCharactersPerToken - cs := &scheduling.CycleState{} - if tt.setupPrefixState != nil { - tt.setupPrefixState(cs) + for profileName, profileRes := range tt.profileResults { + if profileName == defaultDecodeProfile && profileRes != nil { + for _, pod := range profileRes.TargetEndpoints { + pod.Put(approximateprefix.PrefixCacheMatchInfoKey, + approximateprefix.NewPrefixCacheMatchInfo(tt.cachedTokens, inputTokens, 1)) + } + } } + result := handler.Pick(ctx, nil, request, profiles, tt.profileResults) + assert.ElementsMatch(t, tt.expectedProfiles, getProfilesFromResult(result)) + }) + } +} - result := handler.Pick(ctx, cs, request, profiles, tt.profileResults) +func TestPdProfileHandler_PickSeries(t *testing.T) { + ctx := context.Background() + prompt := "hello world, hello world, hello world, hello world, hello world, hello world, hello world!" + request := createRequest(prompt) + longerRequest := createRequest(prompt + "123") + longRequest := createRequest(prompt + prompt) - var actual []string - for name := range result { - actual = append(actual, name) - } + profiles := map[string]scheduling.SchedulerProfile{ + defaultDecodeProfile: newMockSchedulerProfile(), + defaultPrefillProfile: newMockSchedulerProfile(), + } + profileResults := map[string]*scheduling.ProfileRunResult{ + defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"), + } + + type testData struct { + request *scheduling.LLMRequest + cachedTokens int + expectedProfiles []string + } + tests := []struct { + name string + nonCachedTokensLimit int + tests []testData + }{ + { + name: "same request twice", + nonCachedTokensLimit: 2, + tests: []testData{{ + request: request, + cachedTokens: 0, + expectedProfiles: []string{defaultPrefillProfile}, + }, { + request: request, + cachedTokens: len(request.Body.Completions.Prompt) / AverageCharactersPerToken, + expectedProfiles: []string{}, + }}, + }, { + name: "short request and a little bit longer after it", + // Need at least 2 non-cached tokens (8+ chars) to trigger disaggregated prefill + // In this case: longer request is longer in 4 chars than the request -> no disaggregated prefill + nonCachedTokensLimit: 2, + tests: []testData{{ + request: request, + cachedTokens: 0, + expectedProfiles: []string{defaultPrefillProfile}, + }, { + request: longerRequest, + cachedTokens: len(request.Body.Completions.Prompt) / AverageCharactersPerToken, + expectedProfiles: []string{}, + }}, + }, { + name: "short request and a long one after it", + // Need at least 2 non-cached tokens (8+ chars) to trigger disaggregated prefill + // In this case: long request is longer enough than the request -> should have disaggregated prefill + nonCachedTokensLimit: 2, + tests: []testData{{ + request: request, + cachedTokens: 0, + expectedProfiles: []string{defaultPrefillProfile}, + }, { + request: longRequest, + cachedTokens: len(request.Body.Completions.Prompt) / AverageCharactersPerToken, + expectedProfiles: []string{defaultPrefillProfile}, + }}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + deciderPlugin, err := NewPrefixBasedPDDecider(PrefixBasedPDDeciderConfig{NonCachedTokens: tt.nonCachedTokensLimit}) + assert.NoError(t, err) + + handler, err := NewPdProfileHandler( + defaultPrefillProfile, + defaultDecodeProfile, + prefix.PrefixCachePluginType, + prefix.PrefixCachePluginType, + 0, + deciderPlugin, + ) + assert.NoError(t, err) + + // run sequences of request + for _, innerTest := range tt.tests { + cs := &scheduling.CycleState{} + + // set prefix to the given cached tokens number for pod "pod1" in decode profile results + inputTokens := len(innerTest.request.Body.Completions.Prompt) / AverageCharactersPerToken - assert.ElementsMatch(t, tt.expectedProfiles, actual) + for profileName, profileRes := range profileResults { + if profileName == defaultDecodeProfile && profileRes != nil { + for _, endpoint := range profileRes.TargetEndpoints { + endpoint.Put(approximateprefix.PrefixCacheMatchInfoKey, + approximateprefix.NewPrefixCacheMatchInfo(innerTest.cachedTokens, inputTokens, 1)) + } + } + } + + result := handler.Pick(ctx, cs, innerTest.request, profiles, profileResults) + assert.ElementsMatch(t, innerTest.expectedProfiles, getProfilesFromResult(result)) + } }) } } @@ -358,7 +460,7 @@ func TestPdProfileHandler_ProcessResults(t *testing.T) { { name: "decode failed → error", profileResults: map[string]*scheduling.ProfileRunResult{ - "decode": nil, + defaultDecodeProfile: nil, }, expectError: true, }, @@ -366,14 +468,14 @@ func TestPdProfileHandler_ProcessResults(t *testing.T) { name: "decode success, no prefill, no primaryPort", primaryPort: 0, profileResults: map[string]*scheduling.ProfileRunResult{ - "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"), + defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"), }, expectError: false, checkResult: func(t *testing.T, res *scheduling.SchedulingResult, headers map[string]string) { - assert.Equal(t, "decode", res.PrimaryProfileName) - assert.Contains(t, res.ProfileResults, "decode") - assert.NotContains(t, res.ProfileResults, "prefill") - metadata := res.ProfileResults["decode"].TargetEndpoints[0].GetMetadata() + assert.Equal(t, defaultDecodeProfile, res.PrimaryProfileName) + assert.Contains(t, res.ProfileResults, defaultDecodeProfile) + assert.NotContains(t, res.ProfileResults, defaultPrefillProfile) + metadata := res.ProfileResults[defaultDecodeProfile].TargetEndpoints[0].GetMetadata() assert.Equal(t, DefaultTestPodPort, metadata.Port) assert.Empty(t, headers[common.DataParallelPodHeader]) }, @@ -382,25 +484,25 @@ func TestPdProfileHandler_ProcessResults(t *testing.T) { name: "decode success, with prefill", primaryPort: 0, profileResults: map[string]*scheduling.ProfileRunResult{ - "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"), - "prefill": newMockProfileRunResult(DefaultTestPodPort, "pod2"), + defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"), + defaultPrefillProfile: newMockProfileRunResult(DefaultTestPodPort, "pod2"), }, expectError: false, checkResult: func(t *testing.T, res *scheduling.SchedulingResult, _ map[string]string) { - assert.Equal(t, "decode", res.PrimaryProfileName) - assert.Contains(t, res.ProfileResults, "decode") - assert.Contains(t, res.ProfileResults, "prefill") + assert.Equal(t, defaultDecodeProfile, res.PrimaryProfileName) + assert.Contains(t, res.ProfileResults, defaultDecodeProfile) + assert.Contains(t, res.ProfileResults, defaultPrefillProfile) }, }, { name: "with primaryPort → port updated and header set", primaryPort: 9000, profileResults: map[string]*scheduling.ProfileRunResult{ - "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"), + defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"), }, expectError: false, checkResult: func(t *testing.T, res *scheduling.SchedulingResult, headers map[string]string) { - metadata := res.ProfileResults["decode"].TargetEndpoints[0].GetMetadata() + metadata := res.ProfileResults[defaultDecodeProfile].TargetEndpoints[0].GetMetadata() assert.Equal(t, "9000", metadata.Port) hostPort := headers[common.DataParallelPodHeader] @@ -410,16 +512,19 @@ func TestPdProfileHandler_ProcessResults(t *testing.T) { } for _, tt := range tests { + deciderPlugin, err := NewPrefixBasedPDDecider(PrefixBasedPDDeciderConfig{NonCachedTokens: 0}) + assert.NoError(t, err) + t.Run(tt.name, func(t *testing.T) { - handler := NewPdProfileHandler( - "prefill", - "decode", + handler, err := NewPdProfileHandler( + defaultPrefillProfile, + defaultDecodeProfile, prefix.PrefixCachePluginType, prefix.PrefixCachePluginType, - 0, - prefix.DefaultBlockSizeTokens*averageCharactersPerToken, tt.primaryPort, - ).WithName("test-handler") + deciderPlugin, + ) + assert.NoError(t, err) headers := make(map[string]string) req := &scheduling.LLMRequest{ @@ -438,3 +543,16 @@ func TestPdProfileHandler_ProcessResults(t *testing.T) { }) } } + +func createHandleWithDeciderPlugins(ctx context.Context) (plugin.Handle, error) { + handle := plugin.NewEppHandle(ctx, nil) + plugin1, err := NewPrefixBasedPDDecider(PrefixBasedPDDeciderConfig{NonCachedTokens: 4}) + if err != nil { + return nil, err + } + handle.AddPlugin(PrefixBasedPDDeciderPluginType, plugin1) + plugin2 := newAlwaysDisaggPDDecider() + handle.AddPlugin(AlwaysDisaggDeciderPluginType, plugin2) + + return handle, nil +} diff --git a/pkg/plugins/profile/prefix_based_pd_decider.go b/pkg/plugins/profile/prefix_based_pd_decider.go new file mode 100644 index 000000000..8948b4d37 --- /dev/null +++ b/pkg/plugins/profile/prefix_based_pd_decider.go @@ -0,0 +1,137 @@ +package profile + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "sigs.k8s.io/controller-runtime/pkg/log" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins/approximateprefix" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" +) + +const ( + // PrefixBasedPDDeciderPluginType is the type-name of the prefixBasedPDDecider plugin. + PrefixBasedPDDeciderPluginType = "prefix-based-pd-decider" +) + +// PrefixBasedPDDeciderConfig holds the configuration for the prefixBasedPDDecider plugin. +type PrefixBasedPDDeciderConfig struct { + // NonCachedTokens non cached minimum tokens that triggers disaggregated PD + NonCachedTokens int `json:"nonCachedTokens"` +} + +func (p PrefixBasedPDDeciderConfig) validate() error { + if p.NonCachedTokens < 0 { + return errors.New("nonCachedTokens parameter of prefix disaggregation decider cannot be negative") + } + + return nil +} + +// compile-time type assertion +var _ pdDeciderPlugin = &PrefixBasedPDDecider{} + +// PrefixBasedPDDecider is a PD decider plugin which decision is based prefix aware +type PrefixBasedPDDecider struct { + typedName plugin.TypedName + config PrefixBasedPDDeciderConfig +} + +// PrefixBasedPDDeciderPluginFactory defines the factory function for creating +// a new instance of the prefixBasedPDDecider. +func PrefixBasedPDDeciderPluginFactory(name string, rawParameters json.RawMessage, + handle plugin.Handle) (plugin.Plugin, error) { + config := PrefixBasedPDDeciderConfig{ + NonCachedTokens: 0, + } + + if rawParameters != nil { + if err := json.Unmarshal(rawParameters, &config); err != nil { + return nil, fmt.Errorf("failed to parse %s plugin config: %w", PrefixBasedPDDeciderPluginType, err) + } + } + + decider, err := NewPrefixBasedPDDecider(config) + if err != nil { + return nil, fmt.Errorf("failed to create %s plugin: %w", PrefixBasedPDDeciderPluginType, err) + } + + return decider.WithName(name), nil +} + +// NewPrefixBasedPDDecider initializes a NewPrefixBasedPDDecider prefix based PD decider Plugin and returns its pointer. +// If the configuration is invalid an error is returned. +func NewPrefixBasedPDDecider(config PrefixBasedPDDeciderConfig) (*PrefixBasedPDDecider, error) { + if err := config.validate(); err != nil { + return nil, err + } + + return &PrefixBasedPDDecider{ + config: config, + }, nil +} + +// TypedName returns the typed name of the plugin. +func (d *PrefixBasedPDDecider) TypedName() plugin.TypedName { + return d.typedName +} + +// WithName sets the name of the plugin. +func (d *PrefixBasedPDDecider) WithName(name string) *PrefixBasedPDDecider { + d.typedName.Name = name + return d +} + +func (d *PrefixBasedPDDecider) disaggregate(ctx context.Context, inputTokens int, endpoint scheduling.Endpoint) bool { + logger := log.FromContext(ctx) + debugLogger := log.FromContext(ctx).V(logutil.DEBUG) + + if d.config.NonCachedTokens <= 0 { // always use disaggregation in case of non cached tokens number is 0 + return true + } + if endpoint == nil { + logger.Error(nil, "prefix decider: endpoint is nil") + return false + } + if inputTokens < d.config.NonCachedTokens { + debugLogger.Info("Input is shorter than the nonCachedToken, no disaggregated PD") + return false + } + // inspect the decode endpoint to decide if prefill should run or not. + // if the non-cached part is short enough - no disaggregation. + prefixInfoRaw, ok := endpoint.Get(approximateprefix.PrefixCacheMatchInfoKey) + if !ok || prefixInfoRaw == nil { + logger.Error(nil, "unable to read prefix cache state") + return false + } + prefixCacheMatchInfo, ok := prefixInfoRaw.(*approximateprefix.PrefixCacheMatchInfo) + if !ok { + logger.Error(nil, "wrong type of prefix cache match info") + return false + } + + // number of cached tokens + hitPrefixTokens := prefixCacheMatchInfo.MatchBlocks() * prefixCacheMatchInfo.BlockSizeTokens() + // length of non-cached suffix in tokens + nonCachedTokens := inputTokens - hitPrefixTokens + + debugLogger.Info("Computed hit percentage for prefix cache", + "absolute hit prefix len (tokens)", hitPrefixTokens, + "prompt length (token)", inputTokens) + + if nonCachedTokens < d.config.NonCachedTokens { + debugLogger.Info("Non-cached suffix is smaller than threshold, using decode profile only") + return false // do not run prefill + } + + return true +} + +// Consumes defines data types consumed by this plugin +func (*PrefixBasedPDDecider) Consumes() map[string]any { + return map[string]any{approximateprefix.PrefixCacheMatchInfoKey: approximateprefix.PrefixCacheMatchInfo{}} +} diff --git a/pkg/plugins/register.go b/pkg/plugins/register.go index e978de60c..d3c50fd31 100644 --- a/pkg/plugins/register.go +++ b/pkg/plugins/register.go @@ -25,4 +25,7 @@ func RegisterAllPlugins() { plugin.Register(scorer.NoHitLRUType, scorer.NoHitLRUFactory) plugin.Register(models.ModelsDataSourceType, models.ModelDataSourceFactory) plugin.Register(models.ModelsExtractorType, models.ModelServerExtractorFactory) + // pd decider plugins + plugin.Register(profile.PrefixBasedPDDeciderPluginType, profile.PrefixBasedPDDeciderPluginFactory) + plugin.Register(profile.AlwaysDisaggDeciderPluginType, profile.AlwaysDisaggPDDeciderPluginFactory) } diff --git a/pkg/scheduling/pd/scheduler_test.go b/pkg/scheduling/pd/scheduler_test.go index 06efcc1f0..e068ade12 100644 --- a/pkg/scheduling/pd/scheduler_test.go +++ b/pkg/scheduling/pd/scheduler_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" k8stypes "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/log" // Import config for thresholds + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins/approximateprefix" fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" fwkschd "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/picker" @@ -37,7 +38,7 @@ func TestPDSchedule(t *testing.T) { Labels: map[string]string{filter.RoleLabel: filter.RolePrefill}, }, &fwkdl.Metrics{WaitingQueueSize: 0}, - nil, + fwkdl.NewAttributes(), ) endpoint2 := fwkschd.NewEndpoint( &fwkdl.EndpointMetadata{ @@ -46,7 +47,7 @@ func TestPDSchedule(t *testing.T) { Labels: map[string]string{filter.RoleLabel: filter.RoleDecode}, }, &fwkdl.Metrics{WaitingQueueSize: 0}, - nil, + fwkdl.NewAttributes(), ) noRoleEndpoint1 := fwkschd.NewEndpoint( &fwkdl.EndpointMetadata{ @@ -54,7 +55,7 @@ func TestPDSchedule(t *testing.T) { Address: "1.1.1.1", }, &fwkdl.Metrics{WaitingQueueSize: 2}, - nil, + fwkdl.NewAttributes(), ) prefillDecodeResult := &fwkschd.SchedulingResult{ @@ -214,7 +215,7 @@ func TestPDSchedule(t *testing.T) { TargetModel: "critical", Body: &fwkschd.LLMRequestBody{ Completions: &fwkschd.CompletionsRequest{ - Prompt: "12345678906", + Prompt: "1234567890123456789012345678901234567890", }, }, }, @@ -233,12 +234,13 @@ func TestPDSchedule(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { // initialize scheduler with config - prefixScorer, _ := prefix.New(ctx, prefix.Config{BlockSizeTokens: 1, MaxPrefixBlocksToMatch: 256, LRUCapacityPerServer: 31250}) + prefixScorer, err := prefix.New(ctx, prefix.Config{AutoTune: false, BlockSizeTokens: 2, MaxPrefixBlocksToMatch: 256, LRUCapacityPerServer: 31250}) + assert.NoError(t, err, "Prefix plugin creation returned unexpected error") prefillSchedulerProfile := scheduling.NewSchedulerProfile(). WithFilters(filter.NewPrefillRole()). WithPicker(picker.NewMaxScorePicker(picker.DefaultMaxNumOfEndpoints)) - err := prefillSchedulerProfile.AddPlugins(scheduling.NewWeightedScorer(prefixScorer, 50)) + err = prefillSchedulerProfile.AddPlugins(scheduling.NewWeightedScorer(prefixScorer, 50)) assert.NoError(t, err, "SchedulerProfile AddPlugins returned unexpected error") decodeSchedulerProfile := scheduling.NewSchedulerProfile(). @@ -248,13 +250,23 @@ func TestPDSchedule(t *testing.T) { err = decodeSchedulerProfile.AddPlugins(scheduling.NewWeightedScorer(prefixScorer, 0)) assert.NoError(t, err, "SchedulerProfile AddPlugins returned unexpected error") - profileHandle := profile.NewPdProfileHandler(prefill, decode, prefixScorer.TypedName().Type, prefixScorer.TypedName().Name, 10, 1, 0) + deciderPlugin, err := profile.NewPrefixBasedPDDecider(profile.PrefixBasedPDDeciderConfig{NonCachedTokens: 2}) + assert.NoError(t, err) + + profileHandle, err := profile.NewPdProfileHandler(prefill, decode, prefixScorer.TypedName().Type, prefixScorer.TypedName().Name, + 0, deciderPlugin) + assert.NoError(t, err) schedulerConfig := scheduling.NewSchedulerConfig(profileHandle, map[string]fwkschd.SchedulerProfile{ prefill: prefillSchedulerProfile, decode: decodeSchedulerProfile, }) scheduler := scheduling.NewSchedulerWithConfig(schedulerConfig) + + inputTokens := len(test.req.Body.Completions.Prompt) / profile.AverageCharactersPerToken + for _, pod := range test.input { + pod.Put(approximateprefix.PrefixCacheMatchInfoKey, approximateprefix.NewPrefixCacheMatchInfo(0, inputTokens, 1)) + } got, err := scheduler.Schedule(ctx, test.req, test.input) if test.err != (err != nil) { @@ -264,12 +276,16 @@ func TestPDSchedule(t *testing.T) { if diff := cmp.Diff(test.wantRes, got, cmpopts.IgnoreUnexported(fwkdl.Attributes{}), cmpopts.IgnoreFields(fwkschd.ScoredEndpoint{}, "Score")); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } - if test.wantRes2 != nil { // Checking the prefix match in the decode pod. // make sure prefix plugin stores the prefix hit in cache, so we can test it in the following schedule call prefixScorer.PreRequest(ctx, test.req, got) time.Sleep(time.Second) + // update number of cached tokens "stored" in the first schedule execution + for _, pod := range test.input { + pod.Put(approximateprefix.PrefixCacheMatchInfoKey, approximateprefix.NewPrefixCacheMatchInfo(inputTokens, inputTokens, 1)) + } + got, err = scheduler.Schedule(ctx, test.req, test.input) if test.err != (err != nil) { t.Errorf("Unexpected error in schedule call, got %v, want %v", err, test.err) diff --git a/test/e2e/e2e_suite_test.go b/test/e2e/e2e_suite_test.go index 02da62944..ff400872e 100644 --- a/test/e2e/e2e_suite_test.go +++ b/test/e2e/e2e_suite_test.go @@ -293,10 +293,11 @@ func createInferencePool(numTargetPorts int, toDelete bool) []string { } infPoolYaml := testutils.ReadYaml(inferExtManifest) - targetPorts := "" + var b strings.Builder for idx := range numTargetPorts { - targetPorts += fmt.Sprintf("\n - number: %d", 8000+idx) + fmt.Fprintf(&b, "\n - number: %d", 8000+idx) } + targetPorts := b.String() infPoolYaml = substituteMany(infPoolYaml, map[string]string{ "${POOL_NAME}": poolName, diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go index 344cf8752..44671ed71 100644 --- a/test/e2e/e2e_test.go +++ b/test/e2e/e2e_test.go @@ -126,8 +126,8 @@ var _ = ginkgo.Describe("Run end to end tests", ginkgo.Ordered, func() { labelFilter2 := fmt.Sprintf(`decision_type="decode-only",model_name="%s"`, modelName) decodeOnlyCount := getCounterMetric(metricsURL, "llm_d_inference_scheduler_pd_decision_total", labelFilter2) - gomega.Expect(prefillDecodeCount).Should(gomega.Equal(6)) - gomega.Expect(decodeOnlyCount).Should(gomega.Equal(0)) + gomega.Expect(prefillDecodeCount).Should(gomega.Equal(4)) + gomega.Expect(decodeOnlyCount).Should(gomega.Equal(2)) testutils.DeleteObjects(testConfig, epp) testutils.DeleteObjects(testConfig, modelServers) @@ -843,20 +843,24 @@ schedulingProfiles: // EPP configuration for running with P/D const pdConfig = `apiVersion: inference.networking.x-k8s.io/v1alpha1 kind: EndpointPickerConfig +featureGates: +- prepareDataPlugins plugins: - type: prefill-header-handler - type: prefix-cache-scorer parameters: - blockSizeTokens: 10 + blockSizeTokens: 16 maxPrefixBlocksToMatch: 256 lruCapacityPerServer: 256 - type: prefill-filter - type: decode-filter - type: max-score-picker +- type: prefix-based-pd-decider + parameters: + nonCachedTokens: 16 - type: pd-profile-handler parameters: - hashBlockSize: 10 - threshold: 40 + deciderPluginName: prefix-based-pd-decider schedulingProfiles: - name: prefill plugins: