@@ -33,6 +33,7 @@ import (
3333 lwsClient "sigs.k8s.io/lws/client-go/clientset/versioned"
3434
3535 "github.com/kubeflow/arena/pkg/apis/config"
36+ "github.com/kubeflow/arena/pkg/apis/types"
3637)
3738
3839var kubectlCmd = []string {"arena-kubectl" }
@@ -385,7 +386,36 @@ func UpdateLWSJob(lwsJob *lwsv1.LeaderWorkerSet) error {
385386 return err
386387}
387388
388- // PatchOwnerReferenceWithAppInfoFile patch tfjob / pytorchjob ownerReference
389+ // getTrainingJobCRDName returns the CRD resource name for a given training job type
390+ // Returns empty string if the training type is not supported
391+ func getTrainingJobCRDName (trainingType string ) string {
392+ switch trainingType {
393+ case string (types .TFTrainingJob ):
394+ return "tfjob.kubeflow.org"
395+ case string (types .PytorchTrainingJob ):
396+ return "pytorchjob.kubeflow.org"
397+ case string (types .MPITrainingJob ), string (types .HorovodTrainingJob ):
398+ return "mpijob.kubeflow.org"
399+ case string (types .ETTrainingJob ), string (types .DeepSpeedTrainingJob ):
400+ return "trainingjob.kai.alibabacloud.com"
401+ case string (types .VolcanoTrainingJob ):
402+ return "job.batch.volcano.sh"
403+ case string (types .SparkTrainingJob ):
404+ return "sparkapplication.sparkoperator.k8s.io"
405+ case string (types .RayJob ):
406+ return "rayjob.ray.io"
407+ default :
408+ return ""
409+ }
410+ }
411+
412+ // getTrainingJobCRDResourceName returns the full resource name (kind.apiVersion) for skipping owner reference patching
413+ func getTrainingJobCRDResourceName (trainingType , name string ) string {
414+ crdName := getTrainingJobCRDName (trainingType )
415+ return fmt .Sprintf ("%s/%s" , crdName , name )
416+ }
417+
418+ // PatchOwnerReferenceWithAppInfoFile patch training job ownerReference for all resources created by Arena
389419func PatchOwnerReferenceWithAppInfoFile (name , trainingType , appInfoFile , namespace string ) error {
390420 data , err := os .ReadFile (appInfoFile )
391421 if err != nil {
@@ -405,8 +435,14 @@ func PatchOwnerReferenceWithAppInfoFile(name, trainingType, appInfoFile, namespa
405435 }
406436 errs := []string {}
407437
438+ // get training job CRD resource name
439+ crdResourceName := getTrainingJobCRDName (trainingType )
440+ if crdResourceName == "" {
441+ return fmt .Errorf ("unsupported training job type: %s" , trainingType )
442+ }
443+
408444 // get training job
409- args := []string {binary , "get" , trainingType , name , "--namespace" , namespace , "-o json" }
445+ args := []string {binary , "get" , crdResourceName , name , "--namespace" , namespace , "-o json" }
410446 cmd := exec .Command ("bash" , "-c" , strings .Join (args , " " ))
411447 out , err := cmd .CombinedOutput ()
412448 if err != nil {
@@ -427,10 +463,12 @@ func PatchOwnerReferenceWithAppInfoFile(name, trainingType, appInfoFile, namespa
427463 configmapName := fmt .Sprintf ("%v-%v" , name , trainingType )
428464 resources = append (resources , "configmap/" + configmapName )
429465
466+ // get the training job resource name to skip
467+ trainingJobResourceName := getTrainingJobCRDResourceName (trainingType , name )
468+
430469 for _ , resource := range resources {
431- // skip tfjob / pytorchjob.
432- if resource == "tfjob.kubeflow.org/" + name ||
433- resource == "pytorchjob.kubeflow.org/" + name {
470+ // skip the training job CRD itself
471+ if resource == trainingJobResourceName {
434472 continue
435473 }
436474
0 commit comments