Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions deploy/config/pd-epp-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ kind: EndpointPickerConfig
plugins:
- type: prefill-header-handler
- type: prefix-cache-scorer
- type: proximity-scorer
parameters:
sameNodeScore: 1.0
defaultScore: 0.0
- type: prefill-filter
- type: decode-filter
- type: max-score-picker
Expand All @@ -15,6 +19,8 @@ schedulingProfiles:
- pluginRef: max-score-picker
- pluginRef: prefix-cache-scorer
weight: 2
- pluginRef: proximity-scorer
weight: 1
- name: decode
plugins:
- pluginRef: decode-filter
Expand Down
23 changes: 23 additions & 0 deletions pkg/plugins/profile/pd_profile_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,16 @@
defaultDecodeProfile = "decode"
defaultPrefillProfile = "prefill"
defaultPrefixPluginType = prefix.PrefixCachePluginType

// ProximityStateKey is the CycleState key for storing decode pod proximity information
ProximityStateKey = "proximity-state"
)

// ProximityState stores the decode pod's node name for proximity scoring
type ProximityState struct {
DecodeNodeName string
}

type pdProfileHandlerParameters struct {
Threshold int `json:"threshold"`
DecodeProfile string `json:"decodeProfile"`
Expand Down Expand Up @@ -169,6 +177,21 @@
}

metrics.RecordPDDecision(request.TargetModel, metrics.DecisionTypePrefillDecode)

// Extract decode pod's node name and store in CycleState for proximity scoring
decodePod := profileResults[h.decodeProfile].TargetPods[0]
decodePodMetadata := decodePod.GetPod()
nodeName := decodePodMetadata.NodeName

Check failure on line 184 in pkg/plugins/profile/pd_profile_handler.go

View workflow job for this annotation

GitHub Actions / lint-and-test

decodePodMetadata.NodeName undefined (type *backend.Pod has no field or method NodeName)

Check failure on line 184 in pkg/plugins/profile/pd_profile_handler.go

View workflow job for this annotation

GitHub Actions / lint-and-test

decodePodMetadata.NodeName undefined (type *backend.Pod has no field or method NodeName)

Check failure on line 184 in pkg/plugins/profile/pd_profile_handler.go

View workflow job for this annotation

GitHub Actions / lint-and-test

decodePodMetadata.NodeName undefined (type *backend.Pod has no field or method NodeName)

// Store decode pod node name in CycleState for proximity scorer
proximityState := &ProximityState{
DecodeNodeName: nodeName,
}
cycleState.Write(plugins.StateKey(ProximityStateKey), proximityState)

Check failure on line 190 in pkg/plugins/profile/pd_profile_handler.go

View workflow job for this annotation

GitHub Actions / lint-and-test

cannot use proximityState (variable of type *ProximityState) as plugins.StateData value in argument to cycleState.Write: *ProximityState does not implement plugins.StateData (missing method Clone)) (typecheck)

Check failure on line 190 in pkg/plugins/profile/pd_profile_handler.go

View workflow job for this annotation

GitHub Actions / lint-and-test

cannot use proximityState (variable of type *ProximityState) as plugins.StateData value in argument to cycleState.Write: *ProximityState does not implement plugins.StateData (missing method Clone)) (typecheck)

Check failure on line 190 in pkg/plugins/profile/pd_profile_handler.go

View workflow job for this annotation

GitHub Actions / lint-and-test

cannot use proximityState (variable of type *ProximityState) as plugins.StateData value in argument to cycleState.Write: *ProximityState does not implement plugins.StateData (missing method Clone) (typecheck)

log.FromContext(ctx).V(logutil.DEBUG).Info("Stored decode pod node for proximity scoring",
"nodeName", nodeName)

// run the prefill profile
return map[string]*framework.SchedulerProfile{
h.prefillProfile: profiles[h.prefillProfile],
Expand Down
1 change: 1 addition & 0 deletions pkg/plugins/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import (
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/filter"
prerequest "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/pre-request"
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/profile"

Check failure on line 6 in pkg/plugins/register.go

View workflow job for this annotation

GitHub Actions / lint-and-test

could not import github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/profile (-: # github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/profile
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
)
Expand All @@ -22,4 +22,5 @@
plugins.Register(scorer.SessionAffinityType, scorer.SessionAffinityFactory)
plugins.Register(scorer.ActiveRequestType, scorer.ActiveRequestFactory)
plugins.Register(scorer.NoHitLRUType, scorer.NoHitLRUFactory)
plugins.Register(scorer.ProximityScorerType, scorer.ProximityScorerFactory)
}
126 changes: 126 additions & 0 deletions pkg/plugins/scorer/proximity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package scorer

import (
"context"
"encoding/json"
"fmt"

"sigs.k8s.io/controller-runtime/pkg/log"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"

"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/profile"

Check failure on line 14 in pkg/plugins/scorer/proximity.go

View workflow job for this annotation

GitHub Actions / lint-and-test

could not import github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/profile (-: # github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/profile
)

// ProximityScorerType is the plugin type string under which the
// ProximityScorer is registered and referenced from configuration.
const ProximityScorerType = "proximity-scorer"

// ProximityScorerParams holds the JSON-configurable scores used by the
// proximity scorer.
type ProximityScorerParams struct {
	// SameNodeScore is the score given to a pod whose node matches the decode
	// pod's node (the original author cites NVLink, ~900GB/s, for this case).
	SameNodeScore float64 `json:"sameNodeScore"`
	// DefaultScore is the score given to every other pod (the original author
	// cites InfiniBand, ~400GB/s, for the cross-node case).
	DefaultScore float64 `json:"defaultScore"`
}

// Compile-time check that *ProximityScorer satisfies framework.Scorer.
var _ framework.Scorer = (*ProximityScorer)(nil)

// ProximityScorerFactory is the plugin factory for the ProximityScorer.
//
// name is assigned to the created plugin. rawParameters, when non-empty, is
// decoded as JSON into ProximityScorerParams on top of the defaults
// (SameNodeScore 1.0, DefaultScore 0.0), so fields omitted from the JSON keep
// their default values. The handle argument is unused.
//
// It returns an error only when rawParameters is non-empty and cannot be
// decoded.
//
// NOTE(review): the decoded scores are not range-checked here; confirm the
// score range the scheduling framework expects from a scorer.
func ProximityScorerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
	params := ProximityScorerParams{
		SameNodeScore: 1.0,
		DefaultScore:  0.0,
	}

	// A length check covers both a nil and an empty-but-non-nil message; the
	// latter would otherwise make json.Unmarshal fail on empty input.
	if len(rawParameters) > 0 {
		if err := json.Unmarshal(rawParameters, &params); err != nil {
			return nil, fmt.Errorf("failed to parse proximity scorer parameters: %w", err)
		}
	}

	return NewProximityScorer(params).WithName(name), nil
}

// NewProximityScorer builds a ProximityScorer that uses the given parameters.
// The plugin type is set to ProximityScorerType; the plugin name is left empty
// and can be set afterwards with WithName.
func NewProximityScorer(params ProximityScorerParams) *ProximityScorer {
	ps := new(ProximityScorer)
	ps.typedName.Type = ProximityScorerType
	ps.params = params
	return ps
}

// ProximityScorer is a scorer that favours prefill pods located on the same
// node as the decode pod.
type ProximityScorer struct {
	// typedName carries the plugin's type and its configured name.
	typedName plugins.TypedName
	// params holds the scores assigned for same-node and all other pods.
	params ProximityScorerParams
}

// TypedName reports the plugin's type together with the name assigned to this
// instance.
func (s *ProximityScorer) TypedName() plugins.TypedName {
	return s.typedName
}

// WithName assigns the plugin instance name and returns the same scorer so the
// call can be chained onto a constructor. The receiver is modified in place.
func (s *ProximityScorer) WithName(name string) *ProximityScorer {
	s.typedName.Name = name
	return s
}

// Category returns the scorer category (Affinity to prefer locality)
func (s *ProximityScorer) Category() framework.ScorerCategory {

Check failure on line 73 in pkg/plugins/scorer/proximity.go

View workflow job for this annotation

GitHub Actions / lint-and-test

undefined: framework.ScorerCategory (typecheck)
return framework.Affinity

Check failure on line 74 in pkg/plugins/scorer/proximity.go

View workflow job for this annotation

GitHub Actions / lint-and-test

undefined: framework.Affinity (typecheck)
}

// Score assigns each candidate (prefill) pod a score based on whether it runs
// on the same node as the decode pod recorded in CycleState.
//
// The decode node name is read from CycleState under profile.ProximityStateKey.
// When that entry cannot be read, or holds a nil state, every pod receives
// params.DefaultScore. Otherwise a pod receives params.SameNodeScore when its
// node name is non-empty and equals the decode node name, and
// params.DefaultScore in every other case. The request argument is not used.
//
// NOTE(review): CI reports that *backend.Pod has no NodeName field for the
// same kind of access in the profile handler; metadata.NodeName below reads
// from the value returned by pod.GetPod(), presumably the same type — confirm
// where the node name should come from.
func (s *ProximityScorer) Score(ctx context.Context, cycleState *types.CycleState,
	request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
	// Resolve the debug logger once; it does not change between pods.
	debugLog := log.FromContext(ctx).V(logutil.DEBUG)

	scores := make(map[types.Pod]float64, len(pods))

	// Read decode pod node name from CycleState.
	proximityState, err := types.ReadCycleStateKey[*profile.ProximityState](
		cycleState,
		plugins.StateKey(profile.ProximityStateKey),
	)

	// A nil state is treated like a missing one so it is never dereferenced.
	if err != nil || proximityState == nil {
		debugLog.Info(
			"Proximity state not found in CycleState, using default scores",
			"error", err,
		)
		for _, pod := range pods {
			scores[pod] = s.params.DefaultScore
		}
		return scores
	}

	decodeNodeName := proximityState.DecodeNodeName

	// Score each prefill pod based on proximity to the decode pod.
	for _, pod := range pods {
		metadata := pod.GetPod()
		prefillNodeName := metadata.NodeName

		// Binary scoring with the configured values: SameNodeScore on a node
		// match, DefaultScore otherwise. An empty node name never matches.
		score := s.params.DefaultScore
		if prefillNodeName != "" && prefillNodeName == decodeNodeName {
			score = s.params.SameNodeScore
		}
		scores[pod] = score

		debugLog.Info("Proximity score calculated",
			"pod", metadata.NamespacedName.String(),
			"prefillNode", prefillNodeName,
			"decodeNode", decodeNodeName,
			"score", score,
		)
	}

	return scores
}
Loading