diff --git a/pkg/scheduling/pd/scheduler.go b/pkg/scheduling/pd/scheduler.go index 124bc4440..2c0584a4c 100644 --- a/pkg/scheduling/pd/scheduler.go +++ b/pkg/scheduling/pd/scheduler.go @@ -71,13 +71,13 @@ func NewScheduler(ctx context.Context, schedCfg *config.Config, ds Datastore) (* scheduler.prefill = scheduling.NewSchedulerWithConfig( ds, scheduler.generateSchedulerConfig(ctx, schedCfg.PrefillSchedulerPlugins, - &filter.PrefillFilter{}), + filter.NewPrefillFilter()), ) scheduler.decode = scheduling.NewSchedulerWithConfig( ds, scheduler.generateSchedulerConfig(ctx, schedCfg.DecodeSchedulerPlugins, - &filter.DecodeFilter{}), + filter.NewDecodeFilter()), ) return scheduler, nil diff --git a/pkg/scheduling/plugins/filter/pd_role_filter.go b/pkg/scheduling/plugins/filter/pd_role_filter.go index 98089aa25..cd00a4d12 100644 --- a/pkg/scheduling/plugins/filter/pd_role_filter.go +++ b/pkg/scheduling/plugins/filter/pd_role_filter.go @@ -6,7 +6,7 @@ import ( ) const ( - // RoleLabel name + // RoleLabel name of the label that contains the pod's role RoleLabel = "llm-d.ai/role" // RolePrefill set for designated prefill workers RolePrefill = "prefill" @@ -16,46 +16,50 @@ const ( RoleBoth = "both" ) -// PrefillFilter - filters out pods that are not marked with role Prefill -type PrefillFilter struct{} +// RoleBasedFilter - filters out pods based on the role defined by RoleLabel +type RoleBasedFilter struct { + validRoles map[string]bool + name string +} -var _ plugins.Filter = &PrefillFilter{} // validate interface conformance +var _ plugins.Filter = &RoleBasedFilter{} -// Name returns the name of the filter -func (pf *PrefillFilter) Name() string { - return "prefill-filter" +// NewPrefillFilter creates and returns an instance of the RoleBasedFilter configured for prefill +func NewPrefillFilter() *RoleBasedFilter { + // TODO: doesn't RoleBoth also imply Prefill? 
+ return NewRoleBasedFilter("prefill-filter", RolePrefill) } -// Filter filters out all pods that are not marked as "prefill" -func (pf *PrefillFilter) Filter(_ *types.SchedulingContext, pods []types.Pod) []types.Pod { - filteredPods := []types.Pod{} - - for _, pod := range pods { - role := pod.GetPod().Labels[RoleLabel] - if role == RolePrefill { // TODO: doesn't RoleBoth also imply Prefill? - filteredPods = append(filteredPods, pod) - } - } - return filteredPods +// NewDecodeFilter creates and returns an instance of the RoleBasedFilter configured for decode: it keeps pods labeled decode or both, plus pods without a role label (a missing label reads as "", which the removed DecodeFilter also let through) +func NewDecodeFilter() *RoleBasedFilter { + return NewRoleBasedFilter("decode-filter", RoleDecode, RoleBoth, "") } -// DecodeFilter - filters out pods that are not marked with role Decode or Both -type DecodeFilter struct{} +// NewRoleBasedFilter creates and returns an instance of the RoleBasedFilter based on the input parameters +// name - the filter name +// rolesArr - list of valid roles +func NewRoleBasedFilter(name string, rolesArr ...string) *RoleBasedFilter { + roles := map[string]bool{} -var _ plugins.Filter = &DecodeFilter{} // validate interface conformance + for _, role := range rolesArr { + roles[role] = true + } + + return &RoleBasedFilter{name: name, validRoles: roles} +} // Name returns the name of the filter -func (df *DecodeFilter) Name() string { - return "decode-filter" +func (f *RoleBasedFilter) Name() string { + return f.name } -// Filter removes all pods that are not marked as "decode" or "both" -func (df *DecodeFilter) Filter(_ *types.SchedulingContext, pods []types.Pod) []types.Pod { +// Filter filters out all pods that are not marked with one of roles from the validRoles collection +func (f *RoleBasedFilter) Filter(_ *types.SchedulingContext, pods []types.Pod) []types.Pod { filteredPods := []types.Pod{} for _, pod := range pods { - role, defined := pod.GetPod().Labels[RoleLabel] - if !defined || role == RoleDecode || role == RoleBoth { + role := pod.GetPod().Labels[RoleLabel] + if _, 
exists := f.validRoles[role]; exists { filteredPods = append(filteredPods, pod) } }