@@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
55you may not use this file except in compliance with the License.
66You may obtain a copy of the License at
77
8- http://www.apache.org/licenses/LICENSE-2.0
8+ http://www.apache.org/licenses/LICENSE-2.0
99
1010Unless required by applicable law or agreed to in writing, software
1111distributed under the License is distributed on an "AS IS" BASIS,
@@ -18,32 +18,118 @@ package xgboost
1818
1919import (
2020 "context"
21- /// "fmt"
21+ "fmt"
2222
23+ "k8s.io/apimachinery/pkg/util/validation/field"
24+ corev1ac "k8s.io/client-go/applyconfigurations/core/v1"
25+ "k8s.io/utils/ptr"
2326 "sigs.k8s.io/controller-runtime/pkg/client"
27+ "sigs.k8s.io/controller-runtime/pkg/webhook/admission"
2428
2529 trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1"
30+ "github.com/kubeflow/trainer/v2/pkg/apply"
31+ "github.com/kubeflow/trainer/v2/pkg/constants"
2632 "github.com/kubeflow/trainer/v2/pkg/runtime"
2733 "github.com/kubeflow/trainer/v2/pkg/runtime/framework"
2834)
2935
30- // XGBoost implements the EnforceMLPolicyPlugin interface for distributed
31- // XGBoost training using Rabit coordination.
3236type XGBoost struct {}
3337
3438var _ framework.EnforceMLPolicyPlugin = (* XGBoost )(nil )
39+ var _ framework.CustomValidationPlugin = (* XGBoost )(nil )
3540
3641const Name = "XGBoost"
3742
3843func New (context.Context , client.Client , client.FieldIndexer ) (framework.Plugin , error ) {
3944 return & XGBoost {}, nil
4045}
46+
4147func (x * XGBoost ) Name () string {
4248 return Name
4349}
4450
45- // TODO: Inject DMLC_* Rabit environment variables for
46- // distributed XGBoost training. See KEP for env var specification.
51+ func (x * XGBoost ) Validate (_ context.Context , runtimeInfo * runtime.Info , _ , newObj * trainer.TrainJob ) (admission.Warnings , field.ErrorList ) {
52+ var allErrs field.ErrorList
53+ if newObj .Spec .Trainer != nil {
54+ specPath := field .NewPath ("spec" , "trainer" , "env" )
55+ for i , env := range newObj .Spec .Trainer .Env {
56+ if constants .XGBoostReservedEnvNames .Has (env .Name ) {
57+ allErrs = append (allErrs , field .Forbidden (
58+ specPath .Index (i ),
59+ fmt .Sprintf ("%s is reserved for the XGBoost runtime" , env .Name ),
60+ ))
61+ }
62+ }
63+ }
64+ return nil , allErrs
65+ }
66+
4767func (x * XGBoost ) EnforceMLPolicy (info * runtime.Info , trainJob * trainer.TrainJob ) error {
68+ // Guard: Return early if XGBoost policy not configured.
69+ if info == nil || info .RuntimePolicy .MLPolicySource == nil ||
70+ info .RuntimePolicy .MLPolicySource .XGBoost == nil {
71+ return nil
72+ }
73+
74+ // Find the trainer PodSet.
75+ trainerPS := info .FindPodSetByAncestor (constants .AncestorTrainer )
76+
77+ // Set the number of nodes from TrainJob if specified.
78+ if trainerPS .Count != nil &&
79+ trainJob .Spec .Trainer != nil && trainJob .Spec .Trainer .NumNodes != nil {
80+ * trainerPS .Count = * trainJob .Spec .Trainer .NumNodes
81+ }
82+
83+ // Find the trainer container and inject environment variables.
84+ var trainerContainer * runtime.Container
85+ if trainJob .Spec .Trainer != nil {
86+ if trainerContainer = info .FindContainerByPodSetAncestorContainerName (
87+ constants .AncestorTrainer , constants .Node ,
88+ ); trainerContainer != nil {
89+ numNodes := ptr .Deref (ptr .Deref (trainerPS , runtime.PodSet {}).Count , 1 )
90+
91+ // Auto-derive numWorkersPerNode from GPU resources.
92+ // GPU training: 1 worker per GPU | CPU training: 1 worker per node.
93+ numWorkersPerNode := int32 (1 )
94+ if res := runtime .ExtractResourcePerNodeFromRuntime (info ); res != nil {
95+ if gpuCount := runtime .GetNumGPUPerNode (res ); gpuCount > 0 {
96+ numWorkersPerNode = int32 (gpuCount )
97+ }
98+ }
99+ totalWorkers := numNodes * numWorkersPerNode
100+
101+ // Build tracker URI: <trainjob-name>-node-0-0.<trainjob-name>
102+ trackerURI := fmt .Sprintf ("%s-%s-0-0.%s" ,
103+ trainJob .Name , constants .Node , trainJob .Name )
104+
105+ // Inject DMLC_* environment variables.
106+ apply .UpsertEnvVars (& trainerContainer .Env ,
107+ // DMLC_TRACKER_URI - DNS name for rank-0 worker running tracker.
108+ * corev1ac .EnvVar ().
109+ WithName (constants .XGBoostEnvTrackerURI ).
110+ WithValue (trackerURI ),
111+ // DMLC_TRACKER_PORT - Default tracker port.
112+ * corev1ac .EnvVar ().
113+ WithName (constants .XGBoostEnvTrackerPort ).
114+ WithValue (fmt .Sprintf ("%d" , constants .XGBoostDefaultTrackerPort )),
115+ // DMLC_TASK_ID - Worker rank from Job completion index.
116+ * corev1ac .EnvVar ().
117+ WithName (constants .XGBoostEnvTaskID ).
118+ WithValueFrom (corev1ac .EnvVarSource ().
119+ WithFieldRef (corev1ac .ObjectFieldSelector ().
120+ WithFieldPath (constants .JobCompletionIndexFieldPath ))),
121+ // DMLC_NUM_WORKER - Total number of workers.
122+ * corev1ac .EnvVar ().
123+ WithName (constants .XGBoostEnvNumWorker ).
124+ WithValue (fmt .Sprintf ("%d" , totalWorkers )),
125+ )
126+
127+ // Add container port for tracker communication.
128+ apply .UpsertPort (& trainerContainer .Ports ,
129+ * corev1ac .ContainerPort ().
130+ WithContainerPort (constants .XGBoostDefaultTrackerPort ))
131+ }
132+ }
133+
48134 return nil
49135}
0 commit comments