Skip to content

Commit e89782d

Browse files
Adding support for DRA
Signed-off-by: Vishesh Tanksale <vtanksale@nvidia.com>
1 parent 72a326d commit e89782d

File tree

8 files changed

+245
-48
lines changed

8 files changed

+245
-48
lines changed

api/apps/v1alpha1/nimservice_types.go

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -223,16 +223,6 @@ func (n *NIMService) GetLWSName() string {
223223
return fmt.Sprintf("%s-lws", n.GetName())
224224
}
225225

226-
// GetMultiNodeGPUsPerPod returns the number of GPUs per pod for the multi-node NIMService.
227-
func (n *NIMService) GetMultiNodeGPUsPerPod() int {
228-
gpuQuantity, ok := n.Spec.Resources.Requests["nvidia.com/gpu"]
229-
if !ok {
230-
// return 0 if no GPU request is specified because auto-determining it based on tp*pp/(.spec.multiNode.size) is a TODO
231-
return 0
232-
}
233-
return int(gpuQuantity.Value())
234-
}
235-
236226
// GetPVCName returns the name to be used for the PVC based on the custom spec
237227
// Prefers pvc.Name if explicitly set by the user in the NIMService instance.
238228
func (n *NIMService) GetPVCName(pvc PersistentVolumeClaim) string {
@@ -324,7 +314,7 @@ func (n *NIMService) GetStandardEnv() []corev1.EnvVar {
324314
return envVars
325315
}
326316

327-
func (n *NIMService) getLWSCommonEnv() []corev1.EnvVar {
317+
func (n *NIMService) getLWSCommonEnv(multiNodeGPUsPerPod int) []corev1.EnvVar {
328318
env := n.GetEnv()
329319

330320
env = utils.MergeEnvVars([]corev1.EnvVar{
@@ -342,7 +332,7 @@ func (n *NIMService) getLWSCommonEnv() []corev1.EnvVar {
342332
},
343333
{
344334
Name: "NIM_TENSOR_PARALLEL_SIZE",
345-
Value: fmt.Sprintf("%d", n.GetMultiNodeGPUsPerPod()),
335+
Value: fmt.Sprintf("%d", multiNodeGPUsPerPod),
346336
},
347337
{
348338
Name: "NIM_PIPELINE_PARALLEL_SIZE",
@@ -360,8 +350,8 @@ func (n *NIMService) getLWSCommonEnv() []corev1.EnvVar {
360350
return env
361351
}
362352

363-
func (n *NIMService) GetLWSLeaderEnv() []corev1.EnvVar {
364-
env := n.getLWSCommonEnv()
353+
func (n *NIMService) GetLWSLeaderEnv(multiNodeGPUsPerPod int) []corev1.EnvVar {
354+
env := n.getLWSCommonEnv(multiNodeGPUsPerPod)
365355

366356
mpiTimeout := DefaultMPITimeout
367357
if n.Spec.MultiNode.MPI != nil && n.Spec.MultiNode.MPI.MPIStartTimeout != 0 {
@@ -383,7 +373,7 @@ func (n *NIMService) GetLWSLeaderEnv() []corev1.EnvVar {
383373
},
384374
{
385375
Name: "GPUS_PER_NODE",
386-
Value: fmt.Sprintf("%d", n.GetMultiNodeGPUsPerPod()),
376+
Value: fmt.Sprintf("%d", multiNodeGPUsPerPod),
387377
},
388378
{
389379
Name: "CLUSTER_START_TIMEOUT",
@@ -409,8 +399,8 @@ func (n *NIMService) GetLWSLeaderEnv() []corev1.EnvVar {
409399
return env
410400
}
411401

412-
func (n *NIMService) GetLWSWorkerEnv() []corev1.EnvVar {
413-
env := n.getLWSCommonEnv()
402+
func (n *NIMService) GetLWSWorkerEnv(multiNodeGPUsPerPod int) []corev1.EnvVar {
403+
env := n.getLWSCommonEnv(multiNodeGPUsPerPod)
414404
env = utils.MergeEnvVars([]corev1.EnvVar{
415405
{
416406
Name: "NIM_LEADER_ROLE",
@@ -1093,7 +1083,7 @@ func (n *NIMService) GetDeploymentParams() *rendertypes.DeploymentParams {
10931083
return params
10941084
}
10951085

1096-
func (n *NIMService) GetLWSParams() *rendertypes.LeaderWorkerSetParams {
1086+
func (n *NIMService) GetLWSParams(multiNodeGPUsPerPod int) *rendertypes.LeaderWorkerSetParams {
10971087
params := &rendertypes.LeaderWorkerSetParams{}
10981088

10991089
// Set metadata
@@ -1110,8 +1100,8 @@ func (n *NIMService) GetLWSParams() *rendertypes.LeaderWorkerSetParams {
11101100
params.ContainerName = n.GetContainerName()
11111101
params.Args = n.GetArgs()
11121102
params.Command = n.GetCommand()
1113-
params.LeaderEnvs = n.GetLWSLeaderEnv()
1114-
params.WorkerEnvs = n.GetLWSWorkerEnv()
1103+
params.LeaderEnvs = n.GetLWSLeaderEnv(multiNodeGPUsPerPod)
1104+
params.WorkerEnvs = n.GetLWSWorkerEnv(multiNodeGPUsPerPod)
11151105
params.UserID = n.GetUserID()
11161106
params.GroupID = n.GetGroupID()
11171107
params.Image = n.GetImage()
@@ -1200,14 +1190,14 @@ func (n *NIMService) GetStatefulSetParams() *rendertypes.StatefulSetParams {
12001190
return params
12011191
}
12021192

1203-
func (n *NIMService) generateMPIConfigData() map[string]string {
1193+
func (n *NIMService) generateMPIConfigData(multiNodeGPUsPerPod int) map[string]string {
12041194
// Construct ConfigMap data
12051195
data := make(map[string]string)
12061196
for i := 0; i < n.Spec.Replicas; i++ {
1207-
hostfile := fmt.Sprintf("localhost slots=%d\n", n.GetMultiNodeGPUsPerPod())
1197+
hostfile := fmt.Sprintf("localhost slots=%d\n", multiNodeGPUsPerPod)
12081198
for j := 1; j < n.Spec.MultiNode.Size; j++ {
12091199
workerHostname := fmt.Sprintf("%s-%d-%d.%s.%s.svc slots=%d",
1210-
n.GetLWSName(), i, j, n.GetLWSName(), n.GetNamespace(), n.GetMultiNodeGPUsPerPod())
1200+
n.GetLWSName(), i, j, n.GetLWSName(), n.GetNamespace(), multiNodeGPUsPerPod)
12111201
hostfile += workerHostname + "\n"
12121202
}
12131203
dataKey := fmt.Sprintf("hostfile-%d", i)
@@ -1216,14 +1206,14 @@ func (n *NIMService) generateMPIConfigData() map[string]string {
12161206
return data
12171207
}
12181208

1219-
func (n *NIMService) GetMPIConfigParams() *rendertypes.ConfigMapParams {
1209+
func (n *NIMService) GetMPIConfigParams(multiNodeGPUsPerPod int) *rendertypes.ConfigMapParams {
12201210
if n.Spec.MultiNode == nil {
12211211
return nil
12221212
}
12231213
return &rendertypes.ConfigMapParams{
12241214
Name: fmt.Sprintf("%s-mpi-config", n.GetName()),
12251215
Namespace: n.GetNamespace(),
1226-
ConfigMapData: n.generateMPIConfigData(),
1216+
ConfigMapData: n.generateMPIConfigData(multiNodeGPUsPerPod),
12271217
Labels: n.GetLabels(),
12281218
Annotations: n.GetAnnotations(),
12291219
}

internal/controller/platform/kserve/nimservice.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,12 @@ func (r *NIMServiceReconciler) renderAndSyncInferenceService(ctx context.Context
411411
var conType, failedCon string
412412
var renderObj client.Object
413413

414+
multiNodeGPUsPerPod, err := shared.GetMultiNodeGPUsPerPod(ctx, r.Client, nimService)
415+
if err != nil {
416+
logger.Error(err, "Failed to get multi-node GPUs per pod")
417+
return err
418+
}
419+
414420
// Set up env when an explicit override profile is specified
415421
if modelProfile != "" {
416422
profileEnv = &[]corev1.EnvVar{
@@ -432,7 +438,7 @@ func (r *NIMServiceReconciler) renderAndSyncInferenceService(ctx context.Context
432438

433439
// Auto assign GPU resources in case of the optimized profile
434440
if profile != nil {
435-
gpuResources, err = r.addGPUResources(ctx, nimService, profile)
441+
gpuResources, err = r.addGPUResources(ctx, nimService, profile, multiNodeGPUsPerPod)
436442
if err != nil {
437443
logger.Error(err, "Failed to get GPU resources")
438444
return err
@@ -445,7 +451,7 @@ func (r *NIMServiceReconciler) renderAndSyncInferenceService(ctx context.Context
445451

446452
initContainers = nimService.GetInitContainers()
447453
namedDraResources := shared.GenerateNamedDRAResources(nimService)
448-
err := r.reconcileDRAResources(ctx, nimService, namedDraResources)
454+
err = r.reconcileDRAResources(ctx, nimService, namedDraResources)
449455
if err != nil {
450456
logger.Error(err, "Failed to reconcile DRAResources")
451457
return err
@@ -548,7 +554,7 @@ func (r *NIMServiceReconciler) getNIMCacheProfile(ctx context.Context, nimServic
548554
// If the TP value is not present, the function defaults to allocating 1 GPU.
549555
//
550556
// In case of multi-node NIMs, this function assigns the number of GPUs equal to the multiNodeGPUsPerPod argument.
551-
func (r *NIMServiceReconciler) addGPUResources(ctx context.Context, nimService *appsv1alpha1.NIMService, profile *appsv1alpha1.NIMProfile) (*corev1.ResourceRequirements, error) {
557+
func (r *NIMServiceReconciler) addGPUResources(ctx context.Context, nimService *appsv1alpha1.NIMService, profile *appsv1alpha1.NIMProfile, multiNodeGPUsPerPod int) (*corev1.ResourceRequirements, error) {
552558
logger := log.FromContext(ctx)
553559

554560
// TODO: Refine this to detect GPU claims and only assign GPU resources if no GPU claims are present.
@@ -577,7 +583,7 @@ func (r *NIMServiceReconciler) addGPUResources(ctx context.Context, nimService *
577583
// if deployed as multi-node, use the GPU per worker value to assign GPU resources to each worker
578584
// TODO: auto-determine based on tp*pp/(.spec.multiNode.size)
579585
if nimService.Spec.MultiNode != nil {
580-
gpuQuantity, err = apiResource.ParseQuantity(fmt.Sprintf("%d", nimService.GetMultiNodeGPUsPerPod()))
586+
gpuQuantity, err = apiResource.ParseQuantity(fmt.Sprintf("%d", multiNodeGPUsPerPod))
581587
if err != nil {
582588
logger.Error(err, "Failed to parse GPU per worker value")
583589
return nil, err

internal/controller/platform/kserve/nimservice_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1691,7 +1691,7 @@ var _ = Describe("NIMServiceReconciler for a KServe platform", func() {
16911691
},
16921692
}
16931693

1694-
resources, err := reconciler.addGPUResources(context.TODO(), nimService, profile)
1694+
resources, err := reconciler.addGPUResources(context.TODO(), nimService, profile, 8)
16951695
Expect(err).ToNot(HaveOccurred())
16961696
Expect(resources).ToNot(BeNil())
16971697

@@ -1706,7 +1706,7 @@ var _ = Describe("NIMServiceReconciler for a KServe platform", func() {
17061706
Config: map[string]string{"tp": "4"},
17071707
}
17081708

1709-
resources, err := reconciler.addGPUResources(context.TODO(), nimService, profile)
1709+
resources, err := reconciler.addGPUResources(context.TODO(), nimService, profile, 0)
17101710
Expect(err).ToNot(HaveOccurred())
17111711
Expect(resources).ToNot(BeNil())
17121712

@@ -1720,7 +1720,7 @@ var _ = Describe("NIMServiceReconciler for a KServe platform", func() {
17201720
Config: map[string]string{},
17211721
}
17221722

1723-
resources, err := reconciler.addGPUResources(context.TODO(), nimService, profile)
1723+
resources, err := reconciler.addGPUResources(context.TODO(), nimService, profile, 8)
17241724
Expect(err).ToNot(HaveOccurred())
17251725
Expect(resources).ToNot(BeNil())
17261726

@@ -1736,7 +1736,7 @@ var _ = Describe("NIMServiceReconciler for a KServe platform", func() {
17361736
Config: map[string]string{},
17371737
}
17381738

1739-
resources, err := reconciler.addGPUResources(context.TODO(), nimService, profile)
1739+
resources, err := reconciler.addGPUResources(context.TODO(), nimService, profile, 0)
17401740
Expect(err).ToNot(HaveOccurred())
17411741
Expect(resources).ToNot(BeNil())
17421742

@@ -1750,7 +1750,7 @@ var _ = Describe("NIMServiceReconciler for a KServe platform", func() {
17501750
Config: map[string]string{"tp": "invalid"},
17511751
}
17521752

1753-
_, err := reconciler.addGPUResources(context.TODO(), nimService, profile)
1753+
_, err := reconciler.addGPUResources(context.TODO(), nimService, profile, 0)
17541754
Expect(err).To(HaveOccurred())
17551755
})
17561756
})

internal/controller/platform/standalone/nimservice.go

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,12 @@ func (r *NIMServiceReconciler) reconcileNIMService(ctx context.Context, nimServi
349349
var conType, failedCon string
350350
var renderObj client.Object
351351

352+
multiNodeGPUsPerPod, err := shared.GetMultiNodeGPUsPerPod(ctx, r.GetClient(), nimService)
353+
if err != nil {
354+
logger.Error(err, "failed to get GPU count for DRA resources")
355+
return ctrl.Result{}, err
356+
}
357+
352358
if modelProfile != "" {
353359
profileEnv = &[]corev1.EnvVar{{
354360
Name: "NIM_MODEL_PROFILE",
@@ -366,7 +372,7 @@ func (r *NIMServiceReconciler) reconcileNIMService(ctx context.Context, nimServi
366372

367373
// Auto assign GPU resources in case of the optimized profile
368374
if profile != nil {
369-
gpuResources, err = r.addGPUResources(ctx, nimService, profile)
375+
gpuResources, err = r.addGPUResources(ctx, nimService, profile, multiNodeGPUsPerPod)
370376
if err != nil {
371377
logger.Error(err, "Failed to get GPU resources")
372378
return ctrl.Result{}, err
@@ -386,7 +392,7 @@ func (r *NIMServiceReconciler) reconcileNIMService(ctx context.Context, nimServi
386392
}
387393

388394
if nimService.Spec.MultiNode != nil && nimService.Spec.MultiNode.BackendType == appsv1alpha1.NIMBackendTypeLWS {
389-
lwsParams := nimService.GetLWSParams()
395+
lwsParams := nimService.GetLWSParams(multiNodeGPUsPerPod)
390396
lwsParams.PodResourceClaims = shared.GetPodResourceClaims(namedDraResources)
391397
lwsParams.OrchestratorType = string(r.GetOrchestratorType())
392398
lwsParams.LeaderVolumes = nimService.GetLeaderVolumes(*modelPVC)
@@ -428,7 +434,7 @@ func (r *NIMServiceReconciler) reconcileNIMService(ctx context.Context, nimServi
428434
renderObj = &lws.LeaderWorkerSet{}
429435

430436
// Create configmap for MPI
431-
err = r.createMultiNodeVolumeObjects(ctx, nimService)
437+
err = r.createMultiNodeVolumeObjects(ctx, nimService, multiNodeGPUsPerPod)
432438
if err != nil {
433439
return ctrl.Result{}, fmt.Errorf("failed to create multi-node volumes: %v", err)
434440
}
@@ -523,8 +529,8 @@ func (r *NIMServiceReconciler) reconcileNIMService(ctx context.Context, nimServi
523529
return ctrl.Result{}, nil
524530
}
525531

526-
func (r *NIMServiceReconciler) createMultiNodeVolumeObjects(ctx context.Context, nimService *appsv1alpha1.NIMService) error {
527-
if err := r.createMultiNodeConfigMap(ctx, nimService, nimService.GetMPIConfigParams()); err != nil {
532+
func (r *NIMServiceReconciler) createMultiNodeVolumeObjects(ctx context.Context, nimService *appsv1alpha1.NIMService, multiNodeGPUsPerPod int) error {
533+
if err := r.createMultiNodeConfigMap(ctx, nimService, nimService.GetMPIConfigParams(multiNodeGPUsPerPod)); err != nil {
528534
return fmt.Errorf("failed to create MPI configmap for %s: %v", nimService.Name, err)
529535
}
530536

@@ -975,7 +981,7 @@ func (r *NIMServiceReconciler) getNIMCacheProfile(ctx context.Context, nimServic
975981
// If the TP value is not present, the function defaults to allocating 1 GPU.
976982
//
977983
// In case of multi-node NIMs, this function assigns the number of GPUs equal to the multiNodeGPUsPerPod argument.
978-
func (r *NIMServiceReconciler) addGPUResources(ctx context.Context, nimService *appsv1alpha1.NIMService, profile *appsv1alpha1.NIMProfile) (*corev1.ResourceRequirements, error) {
984+
func (r *NIMServiceReconciler) addGPUResources(ctx context.Context, nimService *appsv1alpha1.NIMService, profile *appsv1alpha1.NIMProfile, multiNodeGPUsPerPod int) (*corev1.ResourceRequirements, error) {
979985
logger := log.FromContext(ctx)
980986

981987
// TODO: Refine this to detect GPU claims and only assign GPU resources if no GPU claims are present.
@@ -1004,7 +1010,7 @@ func (r *NIMServiceReconciler) addGPUResources(ctx context.Context, nimService *
10041010
// if deployed as multi-node, use the GPU per worker value to assign GPU resources to each worker
10051011
// TODO: auto-determine based on tp*pp/(.spec.multiNode.size)
10061012
if nimService.Spec.MultiNode != nil {
1007-
gpuQuantity, err = apiResource.ParseQuantity(fmt.Sprintf("%d", nimService.GetMultiNodeGPUsPerPod()))
1013+
gpuQuantity, err = apiResource.ParseQuantity(fmt.Sprintf("%d", multiNodeGPUsPerPod))
10081014
if err != nil {
10091015
logger.Error(err, "Failed to parse GPU per worker value")
10101016
return nil, err
@@ -1083,5 +1089,6 @@ func (r *NIMServiceReconciler) reconcileDRAResources(ctx context.Context, nimSer
10831089
return fmt.Errorf("failed to reconcile DRAResource %s: %w", namedDraResource.ResourceName, err)
10841090
}
10851091
}
1092+
10861093
return nil
10871094
}

internal/controller/platform/standalone/nimservice_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,8 +1038,8 @@ var _ = Describe("NIMServiceReconciler for a standalone platform", func() {
10381038
},
10391039
}
10401040

1041-
leaderEnv := utils.SortKeys(nimService.GetLWSLeaderEnv())
1042-
workerEnv := utils.SortKeys(nimService.GetLWSWorkerEnv())
1041+
leaderEnv := utils.SortKeys(nimService.GetLWSLeaderEnv(8))
1042+
workerEnv := utils.SortKeys(nimService.GetLWSWorkerEnv(8))
10431043

10441044
Expect(reflect.DeepEqual(leaderEnv, []corev1.EnvVar{
10451045
{
@@ -2317,7 +2317,7 @@ var _ = Describe("NIMServiceReconciler for a standalone platform", func() {
23172317
},
23182318
}
23192319

2320-
resources, err := reconciler.addGPUResources(context.TODO(), nimService, profile)
2320+
resources, err := reconciler.addGPUResources(context.TODO(), nimService, profile, 8)
23212321
Expect(err).ToNot(HaveOccurred())
23222322
Expect(resources).ToNot(BeNil())
23232323

@@ -2332,7 +2332,7 @@ var _ = Describe("NIMServiceReconciler for a standalone platform", func() {
23322332
Config: map[string]string{"tp": "4"},
23332333
}
23342334

2335-
resources, err := reconciler.addGPUResources(context.TODO(), nimService, profile)
2335+
resources, err := reconciler.addGPUResources(context.TODO(), nimService, profile, 0)
23362336
Expect(err).ToNot(HaveOccurred())
23372337
Expect(resources).ToNot(BeNil())
23382338

@@ -2346,7 +2346,7 @@ var _ = Describe("NIMServiceReconciler for a standalone platform", func() {
23462346
Config: map[string]string{},
23472347
}
23482348

2349-
resources, err := reconciler.addGPUResources(context.TODO(), nimService, profile)
2349+
resources, err := reconciler.addGPUResources(context.TODO(), nimService, profile, 0)
23502350
Expect(err).ToNot(HaveOccurred())
23512351
Expect(resources).ToNot(BeNil())
23522352

@@ -2362,7 +2362,7 @@ var _ = Describe("NIMServiceReconciler for a standalone platform", func() {
23622362
Config: map[string]string{},
23632363
}
23642364

2365-
resources, err := reconciler.addGPUResources(context.TODO(), nimService, profile)
2365+
resources, err := reconciler.addGPUResources(context.TODO(), nimService, profile, 0)
23662366
Expect(err).ToNot(HaveOccurred())
23672367
Expect(resources).ToNot(BeNil())
23682368

@@ -2376,7 +2376,7 @@ var _ = Describe("NIMServiceReconciler for a standalone platform", func() {
23762376
Config: map[string]string{"tp": "invalid"},
23772377
}
23782378

2379-
_, err := reconciler.addGPUResources(context.TODO(), nimService, profile)
2379+
_, err := reconciler.addGPUResources(context.TODO(), nimService, profile, 0)
23802380
Expect(err).To(HaveOccurred())
23812381
})
23822382
})

internal/k8sutil/resourceclaimsutil.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,11 @@ func GetResourceClaimState(claim *resourcev1beta2.ResourceClaim) string {
7474
}
7575
return strings.Join(states, ",")
7676
}
77+
78+
func GetResourceClaimTemplate(ctx context.Context, k8sclient client.Client, name string, namespace string) (*resourcev1beta2.ResourceClaimTemplate, error) {
79+
claimTemplate := &resourcev1beta2.ResourceClaimTemplate{}
80+
if err := k8sclient.Get(ctx, client.ObjectKey{Name: name, Namespace: namespace}, claimTemplate); err != nil {
81+
return nil, err
82+
}
83+
return claimTemplate, nil
84+
}

0 commit comments

Comments
 (0)