Skip to content

Commit 5ff5893

Browse files
fix: add ownerReferences to resources created by arena (#1407)
Signed-off-by: zhoujinyu <2319109590@qq.com>
1 parent a4944e2 commit 5ff5893

2 files changed

Lines changed: 47 additions & 14 deletions

File tree

pkg/util/kubectl/kubectl.go

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ import (
3333
lwsClient "sigs.k8s.io/lws/client-go/clientset/versioned"
3434

3535
"github.com/kubeflow/arena/pkg/apis/config"
36+
"github.com/kubeflow/arena/pkg/apis/types"
3637
)
3738

3839
var kubectlCmd = []string{"arena-kubectl"}
@@ -385,7 +386,36 @@ func UpdateLWSJob(lwsJob *lwsv1.LeaderWorkerSet) error {
385386
return err
386387
}
387388

388-
// PatchOwnerReferenceWithAppInfoFile patch tfjob / pytorchjob ownerReference
389+
// getTrainingJobCRDName returns the CRD resource name for a given training job type
390+
// Returns empty string if the training type is not supported
391+
func getTrainingJobCRDName(trainingType string) string {
392+
switch trainingType {
393+
case string(types.TFTrainingJob):
394+
return "tfjob.kubeflow.org"
395+
case string(types.PytorchTrainingJob):
396+
return "pytorchjob.kubeflow.org"
397+
case string(types.MPITrainingJob), string(types.HorovodTrainingJob):
398+
return "mpijob.kubeflow.org"
399+
case string(types.ETTrainingJob), string(types.DeepSpeedTrainingJob):
400+
return "trainingjob.kai.alibabacloud.com"
401+
case string(types.VolcanoTrainingJob):
402+
return "job.batch.volcano.sh"
403+
case string(types.SparkTrainingJob):
404+
return "sparkapplication.sparkoperator.k8s.io"
405+
case string(types.RayJob):
406+
return "rayjob.ray.io"
407+
default:
408+
return ""
409+
}
410+
}
411+
412+
// getTrainingJobCRDResourceName returns the full resource name (kind.apiVersion) for skipping owner reference patching
413+
func getTrainingJobCRDResourceName(trainingType, name string) string {
414+
crdName := getTrainingJobCRDName(trainingType)
415+
return fmt.Sprintf("%s/%s", crdName, name)
416+
}
417+
418+
// PatchOwnerReferenceWithAppInfoFile patch training job ownerReference for all resources created by Arena
389419
func PatchOwnerReferenceWithAppInfoFile(name, trainingType, appInfoFile, namespace string) error {
390420
data, err := os.ReadFile(appInfoFile)
391421
if err != nil {
@@ -405,8 +435,14 @@ func PatchOwnerReferenceWithAppInfoFile(name, trainingType, appInfoFile, namespa
405435
}
406436
errs := []string{}
407437

438+
// get training job CRD resource name
439+
crdResourceName := getTrainingJobCRDName(trainingType)
440+
if crdResourceName == "" {
441+
return fmt.Errorf("unsupported training job type: %s", trainingType)
442+
}
443+
408444
// get training job
409-
args := []string{binary, "get", trainingType, name, "--namespace", namespace, "-o json"}
445+
args := []string{binary, "get", crdResourceName, name, "--namespace", namespace, "-o json"}
410446
cmd := exec.Command("bash", "-c", strings.Join(args, " "))
411447
out, err := cmd.CombinedOutput()
412448
if err != nil {
@@ -427,10 +463,12 @@ func PatchOwnerReferenceWithAppInfoFile(name, trainingType, appInfoFile, namespa
427463
configmapName := fmt.Sprintf("%v-%v", name, trainingType)
428464
resources = append(resources, "configmap/"+configmapName)
429465

466+
// get the training job resource name to skip
467+
trainingJobResourceName := getTrainingJobCRDResourceName(trainingType, name)
468+
430469
for _, resource := range resources {
431-
// skip tfjob / pytorchjob.
432-
if resource == "tfjob.kubeflow.org/"+name ||
433-
resource == "pytorchjob.kubeflow.org/"+name {
470+
// skip the training job CRD itself
471+
if resource == trainingJobResourceName {
434472
continue
435473
}
436474

pkg/workflow/workflow.go

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ import (
2424
log "github.com/sirupsen/logrus"
2525
"github.com/spf13/viper"
2626
k8serrors "k8s.io/apimachinery/pkg/api/errors"
27-
28-
"github.com/kubeflow/arena/pkg/apis/types"
2927
)
3028

3129
/**
@@ -200,13 +198,10 @@ func SubmitJob(name string, trainingType string, namespace string, values interf
200198
return err
201199
}
202200

203-
// 6. Patch OwnerReference for tfjob / pytorchjob
204-
if trainingType == string(types.TFTrainingJob) ||
205-
trainingType == string(types.PytorchTrainingJob) {
206-
err := kubectl.PatchOwnerReferenceWithAppInfoFile(name, trainingType, appInfoFileName, namespace)
207-
if err != nil {
208-
log.Debugf("Failed to patch ownerReference %s due to %v`", name, err)
209-
}
201+
// 6. Patch OwnerReference for all training job types
202+
err = kubectl.PatchOwnerReferenceWithAppInfoFile(name, trainingType, appInfoFileName, namespace)
203+
if err != nil {
204+
log.Debugf("Failed to patch ownerReference %s due to %v", name, err)
210205
}
211206

212207
// 7. Clean up the template file

0 commit comments

Comments
 (0)