Skip to content

Commit 7ec359f

Browse files
committed
Add XGBoost plugin unit tests and update framework test registry
Signed-off-by: Krishna-kg732 <2405732@kiit.ac.in>
1 parent e5c552e commit 7ec359f

File tree

5 files changed

+576
-7
lines changed

5 files changed

+576
-7
lines changed

pkg/runtime/framework/core/framework_test.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ import (
5656
"github.com/kubeflow/trainer/v2/pkg/runtime/framework/plugins/plainml"
5757
"github.com/kubeflow/trainer/v2/pkg/runtime/framework/plugins/torch"
5858
"github.com/kubeflow/trainer/v2/pkg/runtime/framework/plugins/volcano"
59+
"github.com/kubeflow/trainer/v2/pkg/runtime/framework/plugins/xgboost"
5960
index "github.com/kubeflow/trainer/v2/pkg/runtime/indexer"
6061
testingutil "github.com/kubeflow/trainer/v2/pkg/util/testing"
6162
)
@@ -84,12 +85,14 @@ func TestNew(t *testing.T) {
8485
torch.Name: &torch.Torch{},
8586
jobset.Name: &jobset.JobSet{},
8687
jax.Name: &jax.Jax{},
88+
xgboost.Name: &xgboost.XGBoost{},
8789
},
8890
enforceMLPlugins: []framework.EnforceMLPolicyPlugin{
8991
&mpi.MPI{},
9092
&plainml.PlainML{},
9193
&torch.Torch{},
9294
&jax.Jax{},
95+
&xgboost.XGBoost{},
9396
},
9497
enforcePodGroupPolicyPlugins: []framework.EnforcePodGroupPolicyPlugin{
9598
&coscheduling.CoScheduling{},
@@ -101,6 +104,7 @@ func TestNew(t *testing.T) {
101104
&jobset.JobSet{},
102105
&volcano.Volcano{},
103106
&jax.Jax{},
107+
&xgboost.XGBoost{},
104108
},
105109
watchExtensionPlugins: []framework.WatchExtensionPlugin{
106110
&coscheduling.CoScheduling{},
@@ -137,7 +141,7 @@ func TestNew(t *testing.T) {
137141
}
138142
cmpOpts := []cmp.Option{
139143
cmp.AllowUnexported(Framework{}),
140-
cmpopts.IgnoreUnexported(coscheduling.CoScheduling{}, volcano.Volcano{}, mpi.MPI{}, plainml.PlainML{}, torch.Torch{}, jobset.JobSet{}),
144+
cmpopts.IgnoreUnexported(coscheduling.CoScheduling{}, volcano.Volcano{}, mpi.MPI{}, plainml.PlainML{}, torch.Torch{}, jobset.JobSet{}, xgboost.XGBoost{}),
141145
cmpopts.IgnoreFields(coscheduling.CoScheduling{}, "client"),
142146
cmpopts.IgnoreFields(volcano.Volcano{}, "client"),
143147
cmpopts.IgnoreFields(jobset.JobSet{}, "client", "restMapper", "scheme", "logger"),

pkg/runtime/framework/plugins/xgboost/xgboost.go

Lines changed: 92 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
66
You may obtain a copy of the License at
77
8-
http://www.apache.org/licenses/LICENSE-2.0
8+
http://www.apache.org/licenses/LICENSE-2.0
99
1010
Unless required by applicable law or agreed to in writing, software
1111
distributed under the License is distributed on an "AS IS" BASIS,
@@ -18,32 +18,118 @@ package xgboost
1818

1919
import (
2020
"context"
21-
///"fmt"
21+
"fmt"
2222

23+
"k8s.io/apimachinery/pkg/util/validation/field"
24+
corev1ac "k8s.io/client-go/applyconfigurations/core/v1"
25+
"k8s.io/utils/ptr"
2326
"sigs.k8s.io/controller-runtime/pkg/client"
27+
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
2428

2529
trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1"
30+
"github.com/kubeflow/trainer/v2/pkg/apply"
31+
"github.com/kubeflow/trainer/v2/pkg/constants"
2632
"github.com/kubeflow/trainer/v2/pkg/runtime"
2733
"github.com/kubeflow/trainer/v2/pkg/runtime/framework"
2834
)
2935

30-
// XGBoost implements the EnforceMLPolicyPlugin interface for distributed
31-
// XGBoost training using Rabit coordination.
3236
type XGBoost struct{}
3337

3438
var _ framework.EnforceMLPolicyPlugin = (*XGBoost)(nil)
39+
var _ framework.CustomValidationPlugin = (*XGBoost)(nil)
3540

3641
const Name = "XGBoost"
3742

3843
func New(context.Context, client.Client, client.FieldIndexer) (framework.Plugin, error) {
3944
return &XGBoost{}, nil
4045
}
46+
4147
func (x *XGBoost) Name() string {
4248
return Name
4349
}
4450

45-
// TODO: Inject DMLC_* Rabit environment variables for
46-
// distributed XGBoost training. See KEP for env var specification.
51+
func (x *XGBoost) Validate(_ context.Context, runtimeInfo *runtime.Info, _, newObj *trainer.TrainJob) (admission.Warnings, field.ErrorList) {
52+
var allErrs field.ErrorList
53+
if newObj.Spec.Trainer != nil {
54+
specPath := field.NewPath("spec", "trainer", "env")
55+
for i, env := range newObj.Spec.Trainer.Env {
56+
if constants.XGBoostReservedEnvNames.Has(env.Name) {
57+
allErrs = append(allErrs, field.Forbidden(
58+
specPath.Index(i),
59+
fmt.Sprintf("%s is reserved for the XGBoost runtime", env.Name),
60+
))
61+
}
62+
}
63+
}
64+
return nil, allErrs
65+
}
66+
4767
func (x *XGBoost) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) error {
68+
// Guard: Return early if XGBoost policy not configured.
69+
if info == nil || info.RuntimePolicy.MLPolicySource == nil ||
70+
info.RuntimePolicy.MLPolicySource.XGBoost == nil {
71+
return nil
72+
}
73+
74+
// Find the trainer PodSet.
75+
trainerPS := info.FindPodSetByAncestor(constants.AncestorTrainer)
76+
77+
// Set the number of nodes from TrainJob if specified.
78+
if trainerPS.Count != nil &&
79+
trainJob.Spec.Trainer != nil && trainJob.Spec.Trainer.NumNodes != nil {
80+
*trainerPS.Count = *trainJob.Spec.Trainer.NumNodes
81+
}
82+
83+
// Find the trainer container and inject environment variables.
84+
var trainerContainer *runtime.Container
85+
if trainJob.Spec.Trainer != nil {
86+
if trainerContainer = info.FindContainerByPodSetAncestorContainerName(
87+
constants.AncestorTrainer, constants.Node,
88+
); trainerContainer != nil {
89+
numNodes := ptr.Deref(ptr.Deref(trainerPS, runtime.PodSet{}).Count, 1)
90+
91+
// Auto-derive numWorkersPerNode from GPU resources.
92+
// GPU training: 1 worker per GPU | CPU training: 1 worker per node.
93+
numWorkersPerNode := int32(1)
94+
if res := runtime.ExtractResourcePerNodeFromRuntime(info); res != nil {
95+
if gpuCount := runtime.GetNumGPUPerNode(res); gpuCount > 0 {
96+
numWorkersPerNode = int32(gpuCount)
97+
}
98+
}
99+
totalWorkers := numNodes * numWorkersPerNode
100+
101+
// Build tracker URI: <trainjob-name>-node-0-0.<trainjob-name>
102+
trackerURI := fmt.Sprintf("%s-%s-0-0.%s",
103+
trainJob.Name, constants.Node, trainJob.Name)
104+
105+
// Inject DMLC_* environment variables.
106+
apply.UpsertEnvVars(&trainerContainer.Env,
107+
// DMLC_TRACKER_URI - DNS name for rank-0 worker running tracker.
108+
*corev1ac.EnvVar().
109+
WithName(constants.XGBoostEnvTrackerURI).
110+
WithValue(trackerURI),
111+
// DMLC_TRACKER_PORT - Default tracker port.
112+
*corev1ac.EnvVar().
113+
WithName(constants.XGBoostEnvTrackerPort).
114+
WithValue(fmt.Sprintf("%d", constants.XGBoostDefaultTrackerPort)),
115+
// DMLC_TASK_ID - Worker rank from Job completion index.
116+
*corev1ac.EnvVar().
117+
WithName(constants.XGBoostEnvTaskID).
118+
WithValueFrom(corev1ac.EnvVarSource().
119+
WithFieldRef(corev1ac.ObjectFieldSelector().
120+
WithFieldPath(constants.JobCompletionIndexFieldPath))),
121+
// DMLC_NUM_WORKER - Total number of workers.
122+
*corev1ac.EnvVar().
123+
WithName(constants.XGBoostEnvNumWorker).
124+
WithValue(fmt.Sprintf("%d", totalWorkers)),
125+
)
126+
127+
// Add container port for tracker communication.
128+
apply.UpsertPort(&trainerContainer.Ports,
129+
*corev1ac.ContainerPort().
130+
WithContainerPort(constants.XGBoostDefaultTrackerPort))
131+
}
132+
}
133+
48134
return nil
49135
}

0 commit comments

Comments
 (0)