Skip to content

Commit dc460b7

Browse files
committed
Update in filter and decode filters implementation, both are based on the ByRoleLabel + add related test in scheduler test
Signed-off-by: Maya Barnea <mayab@il.ibm.com>
1 parent 47c0d8c commit dc460b7

File tree

3 files changed

+103
-84
lines changed

3 files changed

+103
-84
lines changed

pkg/plugins/filter/pd_role_filter.go

Lines changed: 38 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,8 @@ import (
88
)
99

1010
const (
11-
// DecodeFilterType is the type of the DecodeFilter
12-
DecodeFilterType = "decode-filter"
13-
14-
// PrefillFilterType is the type of the PrefillFilter
15-
PrefillFilterType = "prefill-filter"
11+
// ByRoleLabelFilterType is the type of the ByLabelsFilter
12+
ByRoleLabelFilterType = "role-label"
1613

1714
// RoleLabel name
1815
RoleLabel = "llm-d.ai/role"
@@ -24,90 +21,64 @@ const (
2421
RoleBoth = "both"
2522
)
2623

27-
// compile-time type assertion
28-
var _ framework.Filter = &PrefillFilter{}
29-
30-
// NewPrefillFilter creates a new instance of the DecodeFilter
31-
func NewPrefillFilter() *PrefillFilter {
32-
return &PrefillFilter{
33-
name: PrefillFilterType,
34-
}
35-
}
36-
37-
// PrefillFilter - filters out pods that are not marked with role Prefill
38-
type PrefillFilter struct {
24+
// ByRoleLabel - filters out pods based on the role defined by RoleLabel label
25+
type ByRoleLabel struct {
26+
// name defines the filter name
3927
name string
28+
// validRoles defines list of valid role header values
29+
validRoles map[string]bool
30+
// allowsNoRolesLabel - if true pods without role label will be considered as valid (not filtered out)
31+
allowsNoRolesLabel bool
4032
}
4133

42-
// Type returns the type of the filter
43-
func (pf *PrefillFilter) Type() string {
44-
return PrefillFilterType
45-
}
46-
47-
// Name returns the name of the instance of the filter.
48-
func (pf *PrefillFilter) Name() string {
49-
return pf.name
50-
}
51-
52-
// WithName sets the name of the filter.
53-
func (pf *PrefillFilter) WithName(name string) *PrefillFilter {
54-
pf.name = name
55-
return pf
56-
}
34+
var _ framework.Filter = &ByRoleLabel{} // validate interface conformance
5735

58-
// Filter filters out all pods that are not marked as "prefill"
59-
func (pf *PrefillFilter) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
60-
filteredPods := []types.Pod{}
36+
// NewByRoleLabel creates and returns an instance of the RoleBasedFilter based on the input parameters
37+
// name - the filter name
38+
// rolesArr - list of valid roles
39+
func NewByRoleLabel(name string, allowsNoRolesLabel bool, rolesArr ...string) *ByRoleLabel {
40+
roles := map[string]bool{}
6141

62-
for _, pod := range pods {
63-
role := pod.GetPod().Labels[RoleLabel]
64-
if role == RolePrefill { // TODO: doesn't RoleBoth also imply Prefill?
65-
filteredPods = append(filteredPods, pod)
66-
}
42+
for _, role := range rolesArr {
43+
roles[role] = true
6744
}
68-
return filteredPods
69-
}
70-
71-
// compile-time type assertion
72-
var _ framework.Filter = &DecodeFilter{}
7345

74-
// NewDecodeFilter creates a new instance of the DecodeFilter
75-
func NewDecodeFilter() *DecodeFilter {
76-
return &DecodeFilter{
77-
name: DecodeFilterType,
78-
}
46+
return &ByRoleLabel{name: name, allowsNoRolesLabel: allowsNoRolesLabel, validRoles: roles}
7947
}
8048

81-
// DecodeFilter - filters out pods that are not marked with role Decode or Both
82-
type DecodeFilter struct {
83-
name string
49+
// NewPrefillFilter creates and returns an instance of the Filter configured for prefill role
50+
func NewPrefillFilter() framework.Filter {
51+
return NewByRoleLabel("prefill-filter", false, RolePrefill)
8452
}
8553

86-
// Type returns the type of the filter
87-
func (df *DecodeFilter) Type() string {
88-
return DecodeFilterType
54+
// NewDecodeFilter creates and returns an instance of the Filter configured for decode role
55+
func NewDecodeFilter() framework.Filter {
56+
return NewByRoleLabel("decode-filter", true, RoleDecode, RoleBoth)
8957
}
9058

91-
// Name returns the name of the instance of the filter.
92-
func (df *DecodeFilter) Name() string {
93-
return df.name
59+
// Type returns the type of the filter
60+
func (f *ByRoleLabel) Type() string {
61+
return ByRoleLabelFilterType
9462
}
9563

96-
// WithName sets the name of the filter.
97-
func (df *DecodeFilter) WithName(name string) *DecodeFilter {
98-
df.name = name
99-
return df
64+
// Name returns the name of the filter
65+
func (f *ByRoleLabel) Name() string {
66+
return f.name
10067
}
10168

102-
// Filter removes all pods that are not marked as "decode" or "both"
103-
func (df *DecodeFilter) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
69+
// Filter filters out all pods that are not marked with one of roles from the validRoles collection
70+
// or has no role label in case allowsNoRolesLabel is true
71+
func (f *ByRoleLabel) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
10472
filteredPods := []types.Pod{}
10573

10674
for _, pod := range pods {
107-
role, defined := pod.GetPod().Labels[RoleLabel]
108-
if !defined || role == RoleDecode || role == RoleBoth {
75+
role, labelDefined := pod.GetPod().Labels[RoleLabel]
76+
_, roleExists := f.validRoles[role]
77+
78+
if (!labelDefined && f.allowsNoRolesLabel) || roleExists {
10979
filteredPods = append(filteredPods, pod)
11080
}
11181
}
82+
11283
return filteredPods
11384
}

pkg/scheduling/pd/scheduler.go

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,21 +31,13 @@ func CreatePDSchedulerConfig(ctx context.Context, pdConfig *config.Config, prefi
3131
// otherwise, PD is enabled.
3232

3333
// create decode scheduling profile.
34-
decodeFilter, err := filter.NewDecodeFilter()
35-
if err != nil {
36-
return nil, err
37-
}
38-
decodeProfile, err := createSchedulerProfile(ctx, decodeFilter, picker.NewMaxScorePicker(), pdConfig.DecodeSchedulerPlugins, prefixScorer, true)
34+
decodeProfile, err := createSchedulerProfile(ctx, filter.NewDecodeFilter(), picker.NewMaxScorePicker(), pdConfig.DecodeSchedulerPlugins, prefixScorer, true)
3935
if err != nil {
4036
return nil, fmt.Errorf("falied to create decode scheduling profile - %w", err)
4137
}
4238

4339
// create prefil scheduling profile.
44-
prefillFilter, err := filter.NewPrefillFilter()
45-
if err != nil {
46-
return nil, err
47-
}
48-
prefilProfile, err := createSchedulerProfile(ctx, prefillFilter, picker.NewMaxScorePicker(), pdConfig.PrefillSchedulerPlugins, prefixScorer, true)
40+
prefilProfile, err := createSchedulerProfile(ctx, filter.NewPrefillFilter(), picker.NewMaxScorePicker(), pdConfig.PrefillSchedulerPlugins, prefixScorer, true)
4941
if err != nil {
5042
return nil, fmt.Errorf("falied to create prefill scheduling profile - %w", err)
5143
}

pkg/scheduling/pd/scheduler_test.go

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,30 @@ func TestPDSchedule(t *testing.T) {
6060
WaitingModels: map[string]int{},
6161
},
6262
}
63+
noRolePod1 := &backendmetrics.FakePodMetrics{
64+
Pod: &backend.Pod{
65+
NamespacedName: k8stypes.NamespacedName{Name: "noRolePod1"},
66+
Address: "1.1.1.1",
67+
},
68+
Metrics: &backendmetrics.MetricsState{},
69+
}
70+
noRolePod2 := &backendmetrics.FakePodMetrics{
71+
Pod: &backend.Pod{
72+
NamespacedName: k8stypes.NamespacedName{Name: "noRolePod2"},
73+
Address: "2.2.2.2",
74+
},
75+
Metrics: &backendmetrics.MetricsState{},
76+
}
6377

6478
tests := []struct {
65-
name string
66-
req *types.LLMRequest
67-
input []backendmetrics.PodMetrics
68-
wantRes *types.SchedulingResult
69-
err bool
79+
name string
80+
req *types.LLMRequest
81+
input []backendmetrics.PodMetrics
82+
wantRes *types.SchedulingResult
83+
wantHeaders map[string]string
84+
unwantedHeaders []string
85+
unwantedPodIDs []string
86+
err bool
7087
}{
7188
{
7289
name: "no pods in datastore",
@@ -153,6 +170,16 @@ func TestPDSchedule(t *testing.T) {
153170
PrimaryProfileName: "decode",
154171
},
155172
},
173+
{
174+
name: "TestRoles",
175+
req: &types.LLMRequest{
176+
TargetModel: "critical",
177+
Prompt: "12345678901",
178+
},
179+
input: []backendmetrics.PodMetrics{pod1, noRolePod1, noRolePod2},
180+
wantRes: nil, // doesn't mater which pod was selected
181+
unwantedPodIDs: []string{pod1.GetPod().NamespacedName.String()},
182+
},
156183
}
157184

158185
ctx := context.Background()
@@ -189,8 +216,37 @@ func TestPDSchedule(t *testing.T) {
189216
t.Errorf("Unexpected error, got %v, want %v", err, test.err)
190217
}
191218

192-
if diff := cmp.Diff(test.wantRes, got); diff != "" {
193-
t.Errorf("Unexpected output (-want +got): %v", diff)
219+
if test.wantRes != nil {
220+
if diff := cmp.Diff(test.wantRes, got); diff != "" {
221+
t.Errorf("Unexpected output (-want +got): %v", diff)
222+
}
223+
224+
for header, value := range test.wantHeaders {
225+
gotValue, ok := test.req.Headers[header]
226+
if !ok {
227+
t.Errorf("Missing header: %s", header)
228+
} else if gotValue != value {
229+
t.Errorf("Wrong header value for %s: want %s got %s)", header, value, gotValue)
230+
}
231+
}
232+
233+
for _, header := range test.unwantedHeaders {
234+
if _, exists := test.req.Headers[header]; exists {
235+
t.Errorf("Unwanted header %s exists", header)
236+
}
237+
}
238+
}
239+
240+
if len(test.unwantedPodIDs) > 0 {
241+
// ensure that target pod is not one of the unwanted
242+
profileRes, found := got.ProfileResults[got.PrimaryProfileName]
243+
if found {
244+
for _, podID := range test.unwantedPodIDs {
245+
if podID == profileRes.TargetPod.GetPod().NamespacedName.String() {
246+
t.Errorf("Unwanted pod was selected: %s", podID)
247+
}
248+
}
249+
}
194250
}
195251
})
196252
}

0 commit comments

Comments
 (0)