Skip to content

Commit b5120e1

Browse files
authored
Add metadata propagation for Kueue configurations to both Deployment and LeaderWorkerSet workloads (kserve#4747)
Signed-off-by: Hannah DeFazio <h2defazio@gmail.com>
1 parent 6b2ce35 commit b5120e1

File tree

8 files changed

+204
-44
lines changed

8 files changed

+204
-44
lines changed

pkg/constants/constants.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ var (
4242
AutoscalerConfigmapNamespace = GetEnvOrDefault("KNATIVE_CONFIG_AUTOSCALER_NAMESPACE", DefaultKnServingNamespace)
4343
)
4444

45+
// Kueue Constants
46+
const (
47+
KueueAPIGroupName = "kueue.x-k8s.io"
48+
)
49+
4550
// InferenceService Constants
4651
var (
4752
InferenceServiceName = "inferenceservice"

pkg/controller/v1alpha1/llmisvc/controller_int_multi_node_test.go

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -403,8 +403,8 @@ var _ = Describe("LLMInferenceService Multi-Node Controller", func() {
403403
})
404404
})
405405

406-
Context("Multi-Node Label Management", func() {
407-
It("should set correct labels", func(ctx SpecContext) {
406+
Context("Multi-Node Label and Annotation Management", func() {
407+
It("should set correct labels and annotation", func(ctx SpecContext) {
408408
// given
409409
svcName := "test-llm-lws-labels"
410410
nsName := kmeta.ChildName(svcName, "-test")
@@ -420,6 +420,10 @@ var _ = Describe("LLMInferenceService Multi-Node Controller", func() {
420420
envTest.DeleteAll(namespace)
421421
}()
422422

423+
localQueueName := "test-local-q"
424+
preemptPriority := "0"
425+
testValue := "test"
426+
423427
llmSvc := LLMInferenceService(svcName,
424428
InNamespace[*v1alpha1.LLMInferenceService](nsName),
425429
WithModelURI("hf://facebook/opt-125m"),
@@ -430,6 +434,16 @@ var _ = Describe("LLMInferenceService Multi-Node Controller", func() {
430434
WithWorker(&corev1.PodSpec{}),
431435
WithManagedRoute(),
432436
WithManagedGateway(),
437+
// Add a kueue label and annotation to ensure value propagation to the LWS
438+
// the kueue functionality itself will not be tested here
439+
WithAnnotations(map[string]string{
440+
PreemptionReclaimAnnotationKey: preemptPriority,
441+
testValue: testValue, // dummy value, should not be propagated
442+
}),
443+
WithLabels(map[string]string{
444+
LocalQueueNameLabelKey: localQueueName,
445+
testValue: testValue, // dummy value, should not be propagated
446+
}),
433447
)
434448

435449
// safety check
@@ -451,11 +465,29 @@ var _ = Describe("LLMInferenceService Multi-Node Controller", func() {
451465
}, expectedLWS)
452466
}).WithContext(ctx).Should(Succeed())
453467

468+
By("checking the LeaderWorkerSet's top-level metadata")
454469
Expect(expectedLWS).To(BeOwnedBy(llmSvc))
470+
Expect(expectedLWS.Labels).To(HaveKeyWithValue(LocalQueueNameLabelKey, localQueueName))
471+
Expect(expectedLWS.Labels).ToNot(HaveKeyWithValue(testValue, testValue))
472+
473+
Expect(expectedLWS.Annotations).To(HaveKeyWithValue(PreemptionReclaimAnnotationKey, preemptPriority))
474+
475+
By("checking the leader pod template metadata")
455476
Expect(expectedLWS.Spec.LeaderWorkerTemplate.Size).To(Equal(ptr.To(int32(1))))
456477
Expect(expectedLWS.Spec.LeaderWorkerTemplate.LeaderTemplate).To(Not(BeNil()))
457478
Expect(expectedLWS.Spec.LeaderWorkerTemplate.LeaderTemplate.Labels).To(HaveKeyWithValue("kserve.io/component", "workload"))
458479
Expect(expectedLWS.Spec.LeaderWorkerTemplate.LeaderTemplate.Labels).To(HaveKeyWithValue("llm-d.ai/role", "both"))
480+
Expect(expectedLWS.Spec.LeaderWorkerTemplate.LeaderTemplate.Labels).To(HaveKeyWithValue(LocalQueueNameLabelKey, localQueueName))
481+
Expect(expectedLWS.Spec.LeaderWorkerTemplate.LeaderTemplate.Labels).ToNot(HaveKeyWithValue(testValue, testValue))
482+
483+
Expect(expectedLWS.Spec.LeaderWorkerTemplate.LeaderTemplate.Annotations).To(HaveKeyWithValue(PreemptionReclaimAnnotationKey, preemptPriority))
484+
485+
By("checking the worker pod template metadata")
486+
Expect(expectedLWS.Spec.LeaderWorkerTemplate.WorkerTemplate).To(Not(BeNil()))
487+
Expect(expectedLWS.Spec.LeaderWorkerTemplate.WorkerTemplate.Labels).To(HaveKeyWithValue(LocalQueueNameLabelKey, localQueueName))
488+
Expect(expectedLWS.Spec.LeaderWorkerTemplate.WorkerTemplate.Labels).ToNot(HaveKeyWithValue(testValue, testValue))
489+
490+
Expect(expectedLWS.Spec.LeaderWorkerTemplate.WorkerTemplate.Annotations).To(HaveKeyWithValue(PreemptionReclaimAnnotationKey, preemptPriority))
459491
})
460492
})
461493
})

pkg/controller/v1alpha1/llmisvc/controller_int_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,79 @@ var _ = Describe("LLMInferenceService Controller", func() {
164164
g.Expect(current.Status).To(HaveCondition(string(v1alpha1.HTTPRoutesReady), "True"))
165165
})).WithContext(ctx).Should(Succeed())
166166
})
167+
168+
It("should propagate kueue labels and annotations to the deployment", func(ctx SpecContext) {
169+
// given
170+
svcName := "test-llm-kueue"
171+
nsName := kmeta.ChildName(svcName, "-test")
172+
namespace := &corev1.Namespace{
173+
ObjectMeta: metav1.ObjectMeta{
174+
Name: nsName,
175+
},
176+
}
177+
178+
Expect(envTest.Client.Create(ctx, namespace)).To(Succeed())
179+
Expect(envTest.Client.Create(ctx, IstioShadowService(svcName, nsName))).To(Succeed())
180+
defer func() {
181+
envTest.DeleteAll(namespace)
182+
}()
183+
184+
localQueueName := "test-local-q"
185+
preemptPriority := "0"
186+
testValue := "test"
187+
188+
llmSvc := LLMInferenceService(svcName,
189+
InNamespace[*v1alpha1.LLMInferenceService](nsName),
190+
WithModelURI("hf://facebook/opt-125m"),
191+
WithManagedRoute(),
192+
WithManagedGateway(),
193+
WithManagedScheduler(),
194+
// Add a kueue label and annotation to ensure value propagation to the deployment
195+
// the kueue functionality itself will not be tested here
196+
WithAnnotations(map[string]string{
197+
PreemptionReclaimAnnotationKey: preemptPriority,
198+
testValue: testValue, // dummy value, should not be propagated
199+
}),
200+
WithLabels(map[string]string{
201+
LocalQueueNameLabelKey: localQueueName,
202+
testValue: testValue, // dummy value, should not be propagated
203+
}),
204+
)
205+
206+
// when
207+
Expect(envTest.Create(ctx, llmSvc)).To(Succeed())
208+
defer func() {
209+
Expect(envTest.Delete(ctx, llmSvc)).To(Succeed())
210+
}()
211+
212+
// then
213+
expectedDeployment := &appsv1.Deployment{}
214+
Eventually(func(g Gomega, ctx context.Context) error {
215+
return envTest.Get(ctx, types.NamespacedName{
216+
Name: svcName + "-kserve",
217+
Namespace: nsName,
218+
}, expectedDeployment)
219+
}).WithContext(ctx).Should(Succeed())
220+
221+
Expect(expectedDeployment.Spec.Replicas).To(Equal(ptr.To[int32](1)))
222+
Expect(expectedDeployment).To(BeOwnedBy(llmSvc))
223+
224+
By("checking the Deployment's top-level metadata")
225+
// Check that the kueue label/annotation was propagated
226+
Expect(expectedDeployment.Labels).To(HaveKeyWithValue(LocalQueueNameLabelKey, localQueueName))
227+
Expect(expectedDeployment.Annotations).To(gomega.HaveKeyWithValue(PreemptionReclaimAnnotationKey, preemptPriority))
228+
// Check that the test label/annotation was not propagated as it is not in the approved prefixes for propagation
229+
Expect(expectedDeployment.Labels).ToNot(HaveKeyWithValue(testValue, testValue))
230+
Expect(expectedDeployment.Annotations).ToNot(HaveKeyWithValue(testValue, testValue))
231+
232+
By("checking the Deployment's pod template metadata")
233+
// Check that the kueue label/annotation was propagated
234+
Expect(expectedDeployment.Spec.Template.Labels).To(HaveKeyWithValue(LocalQueueNameLabelKey, localQueueName))
235+
Expect(expectedDeployment.Spec.Template.Annotations).To(gomega.HaveKeyWithValue(PreemptionReclaimAnnotationKey, preemptPriority))
236+
// Check that the test label/annotation was not propagated as it is not in the approved prefixes for propagation
237+
Expect(expectedDeployment.Spec.Template.Labels).ToNot(HaveKeyWithValue(testValue, testValue))
238+
Expect(expectedDeployment.Spec.Template.Annotations).ToNot(HaveKeyWithValue(testValue, testValue))
239+
})
167240
})
168241

169242
Context("Routing reconciliation ", func() {

pkg/controller/v1alpha1/llmisvc/fixture/llmisvc_builders.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ limitations under the License.
1717
package fixture
1818

1919
import (
20+
"maps"
21+
2022
corev1 "k8s.io/api/core/v1"
2123
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
2224
"knative.dev/pkg/apis"
@@ -122,6 +124,24 @@ func WithManagedRoute() LLMInferenceServiceOption {
122124
}
123125
}
124126

127+
func WithAnnotations(annotationsToAdd map[string]string) LLMInferenceServiceOption {
128+
return func(llmSvc *v1alpha1.LLMInferenceService) {
129+
if llmSvc.Annotations == nil {
130+
llmSvc.Annotations = make(map[string]string)
131+
}
132+
maps.Copy(llmSvc.Annotations, annotationsToAdd)
133+
}
134+
}
135+
136+
func WithLabels(labelsToAdd map[string]string) LLMInferenceServiceOption {
137+
return func(llmSvc *v1alpha1.LLMInferenceService) {
138+
if llmSvc.Labels == nil {
139+
llmSvc.Labels = make(map[string]string)
140+
}
141+
maps.Copy(llmSvc.Labels, labelsToAdd)
142+
}
143+
}
144+
125145
func LLMGatewayRef(name, namespace string) v1alpha1.UntypedObjectReference {
126146
return v1alpha1.UntypedObjectReference{
127147
Name: gwapiv1.ObjectName(name),

pkg/controller/v1alpha1/llmisvc/suite_test.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,17 @@ import (
2222
. "github.com/onsi/ginkgo/v2"
2323
. "github.com/onsi/gomega"
2424

25+
"github.com/kserve/kserve/pkg/constants"
2526
"github.com/kserve/kserve/pkg/controller/v1alpha1/llmisvc/fixture"
2627
pkgtest "github.com/kserve/kserve/pkg/testing"
2728
)
2829

30+
// Kueue Constants
31+
var (
32+
LocalQueueNameLabelKey = constants.KueueAPIGroupName + "/queue-name"
33+
PreemptionReclaimAnnotationKey = constants.KueueAPIGroupName + "/preemption-reclaim-if-below-priority"
34+
)
35+
2936
func TestLLMInferenceServiceController(t *testing.T) {
3037
RegisterFailHandler(Fail)
3138
RunSpecs(t, "LLMInferenceService Controller Suite")

pkg/controller/v1alpha1/llmisvc/workload_multi_node.go

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ package llmisvc
1919
import (
2020
"context"
2121
"fmt"
22-
"strings"
2322

2423
corev1 "k8s.io/api/core/v1"
2524
rbacv1 "k8s.io/api/rbac/v1"
@@ -34,6 +33,8 @@ import (
3433
lwsapi "sigs.k8s.io/lws/api/leaderworkerset/v1"
3534

3635
"github.com/kserve/kserve/pkg/apis/serving/v1alpha1"
36+
"github.com/kserve/kserve/pkg/constants"
37+
"github.com/kserve/kserve/pkg/utils"
3738
)
3839

3940
func (r *LLMISVCReconciler) reconcileMultiNodeWorkload(ctx context.Context, llmSvc *v1alpha1.LLMInferenceService, config *Config) error {
@@ -519,35 +520,47 @@ func (r *LLMISVCReconciler) expectedMultiNodeRoleBinding(llmSvc *v1alpha1.LLMInf
519520
}
520521

521522
func (r *LLMISVCReconciler) propagateLeaderWorkerSetMetadata(llmSvc *v1alpha1.LLMInferenceService, expected *lwsapi.LeaderWorkerSet) {
522-
ann := make(map[string]string, len(expected.Annotations))
523-
for k, v := range llmSvc.GetAnnotations() {
524-
if strings.HasPrefix(k, "leaderworkerset.sigs.k8s.io") ||
525-
strings.HasPrefix(k, "k8s.v1.cni.cncf.io") {
526-
ann[k] = v
527-
if expected.Annotations == nil {
528-
expected.Annotations = make(map[string]string, 1)
529-
}
530-
expected.Annotations[k] = v
531-
}
523+
// Define the prefixes to approve for annotations and labels
524+
approvedAnnotationPrefixes := []string{
525+
"leaderworkerset.sigs.k8s.io",
526+
"k8s.v1.cni.cncf.io",
527+
constants.KueueAPIGroupName,
528+
}
529+
approvedLabelPrefixes := []string{
530+
constants.KueueAPIGroupName,
532531
}
533532

533+
// Propagate approved annotations to the LeaderWorkerSet's top-level metadata
534+
utils.PropagatePrefixedMap(llmSvc.GetAnnotations(), &expected.Annotations, approvedAnnotationPrefixes...)
535+
534536
if expected.Spec.LeaderWorkerTemplate.LeaderTemplate != nil {
535-
if expected.Spec.LeaderWorkerTemplate.LeaderTemplate.Annotations == nil {
536-
expected.Spec.LeaderWorkerTemplate.LeaderTemplate.Annotations = ann
537-
} else {
538-
for k, v := range ann {
539-
expected.Spec.LeaderWorkerTemplate.LeaderTemplate.Annotations[k] = v
540-
}
541-
}
537+
utils.PropagatePrefixedMap(
538+
llmSvc.GetAnnotations(),
539+
&expected.Spec.LeaderWorkerTemplate.LeaderTemplate.Annotations,
540+
approvedAnnotationPrefixes...,
541+
)
542542
}
543543

544-
if expected.Spec.LeaderWorkerTemplate.WorkerTemplate.Annotations == nil {
545-
expected.Spec.LeaderWorkerTemplate.WorkerTemplate.Annotations = ann
546-
} else {
547-
for k, v := range ann {
548-
expected.Spec.LeaderWorkerTemplate.WorkerTemplate.Annotations[k] = v
549-
}
544+
utils.PropagatePrefixedMap(
545+
llmSvc.GetAnnotations(),
546+
&expected.Spec.LeaderWorkerTemplate.WorkerTemplate.Annotations,
547+
approvedAnnotationPrefixes...,
548+
)
549+
550+
// Propagate approved labels
551+
utils.PropagatePrefixedMap(llmSvc.GetLabels(), &expected.Labels, approvedLabelPrefixes...)
552+
if expected.Spec.LeaderWorkerTemplate.LeaderTemplate != nil {
553+
utils.PropagatePrefixedMap(
554+
llmSvc.GetLabels(),
555+
&expected.Spec.LeaderWorkerTemplate.LeaderTemplate.Labels,
556+
approvedLabelPrefixes...,
557+
)
550558
}
559+
utils.PropagatePrefixedMap(
560+
llmSvc.GetLabels(),
561+
&expected.Spec.LeaderWorkerTemplate.WorkerTemplate.Labels,
562+
approvedLabelPrefixes...,
563+
)
551564
}
552565

553566
func semanticLWSIsEqual(expected *lwsapi.LeaderWorkerSet, curr *lwsapi.LeaderWorkerSet) bool {

pkg/controller/v1alpha1/llmisvc/workload_single_node.go

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ import (
2020
"context"
2121
"fmt"
2222
"maps"
23-
"strings"
23+
24+
"github.com/kserve/kserve/pkg/utils"
2425

2526
appsv1 "k8s.io/api/apps/v1"
2627
corev1 "k8s.io/api/core/v1"
@@ -35,6 +36,7 @@ import (
3536
"sigs.k8s.io/controller-runtime/pkg/log"
3637

3738
"github.com/kserve/kserve/pkg/apis/serving/v1alpha1"
39+
"github.com/kserve/kserve/pkg/constants"
3840
)
3941

4042
func (r *LLMISVCReconciler) reconcileSingleNodeWorkload(ctx context.Context, llmSvc *v1alpha1.LLMInferenceService, config *Config) error {
@@ -217,24 +219,17 @@ func (r *LLMISVCReconciler) expectedPrefillMainDeployment(ctx context.Context, l
217219
}
218220

219221
func (r *LLMISVCReconciler) propagateDeploymentMetadata(llmSvc *v1alpha1.LLMInferenceService, expected *appsv1.Deployment) {
220-
ann := make(map[string]string, len(expected.Annotations))
221-
for k, v := range llmSvc.GetAnnotations() {
222-
if strings.HasPrefix(k, "k8s.v1.cni.cncf.io") {
223-
ann[k] = v
224-
if expected.Annotations == nil {
225-
expected.Annotations = make(map[string]string, 1)
226-
}
227-
expected.Annotations[k] = v
228-
}
229-
}
222+
// Define the prefixes to approve for annotations and labels
223+
approvedAnnotationPrefixes := []string{"k8s.v1.cni.cncf.io", constants.KueueAPIGroupName}
224+
approvedLabelPrefixes := []string{constants.KueueAPIGroupName}
230225

231-
if expected.Spec.Template.Annotations == nil {
232-
expected.Spec.Template.Annotations = ann
233-
} else {
234-
for k, v := range ann {
235-
expected.Spec.Template.Annotations[k] = v
236-
}
237-
}
226+
// Propagate approved annotations to the Deployment and its Pod template
227+
utils.PropagatePrefixedMap(llmSvc.GetAnnotations(), &expected.Annotations, approvedAnnotationPrefixes...)
228+
utils.PropagatePrefixedMap(llmSvc.GetAnnotations(), &expected.Spec.Template.Annotations, approvedAnnotationPrefixes...)
229+
230+
// Propagate approved labels to the Deployment and its Pod template
231+
utils.PropagatePrefixedMap(llmSvc.GetLabels(), &expected.Labels, approvedLabelPrefixes...)
232+
utils.PropagatePrefixedMap(llmSvc.GetLabels(), &expected.Spec.Template.Labels, approvedLabelPrefixes...)
238233
}
239234

240235
func (r *LLMISVCReconciler) propagateDeploymentStatus(ctx context.Context, expected *appsv1.Deployment, ready func(), notReady func(reason, messageFormat string, messageA ...interface{})) error {

pkg/utils/utils.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,21 @@ func IsPrefixSupported(input string, prefixes []string) bool {
9999
return false
100100
}
101101

102+
// PropagatePrefixedMap filters keys in source by the provided prefixes and propagates matching key-value pairs to dest.
103+
// Initializes dest if nil.
104+
func PropagatePrefixedMap(source map[string]string, dest *map[string]string, prefixes ...string) {
105+
for k, v := range source {
106+
// The nested loop and if statement are replaced with a single, clear function call.
107+
if IsPrefixSupported(k, prefixes) {
108+
// Initialize the destination map if it's nil
109+
if *dest == nil {
110+
*dest = make(map[string]string)
111+
}
112+
(*dest)[k] = v
113+
}
114+
}
115+
}
116+
102117
// MergeEnvs Merge a slice of EnvVars (`O`) into another slice of EnvVars (`B`), which does the following:
103118
// 1. If an EnvVar is present in B but not in O, value remains unchanged in the result
104119
// 2. If an EnvVar is present in `O` but not in `B`, appends to the result

0 commit comments

Comments
 (0)