Skip to content

Commit 139f8d2

Browse files
Add "ReadFileExt" helper function and update KFTO tests
1 parent 8803090 commit 139f8d2

6 files changed

Lines changed: 20 additions & 56 deletions

File tree

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ toolchain go1.21.5
77
require (
88
github.com/kubeflow/training-operator v1.7.0
99
github.com/onsi/gomega v1.31.1
10-
github.com/project-codeflare/codeflare-common v0.0.0-20241108084652-0d76fd215a22
10+
github.com/project-codeflare/codeflare-common v0.0.0-20241121090634-e99e941c6921
1111
github.com/prometheus/client_golang v1.20.4
1212
github.com/prometheus/common v0.57.0
1313
github.com/ray-project/kuberay/ray-operator v1.1.0-alpha.0

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,8 +365,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
365365
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
366366
github.com/project-codeflare/appwrapper v0.8.0 h1:vWHNtXUtHutN2EzYb6rryLdESnb8iDXsCokXOuNYXvg=
367367
github.com/project-codeflare/appwrapper v0.8.0/go.mod h1:FMQ2lI3fz6LakUVXgN1FTdpsc3BBkNIZZgtMmM9J5UM=
368-
github.com/project-codeflare/codeflare-common v0.0.0-20241108084652-0d76fd215a22 h1:wzIJHoGAmNZupO3ZI7gbONuXgIUireabHsZvMt+3fqQ=
369-
github.com/project-codeflare/codeflare-common v0.0.0-20241108084652-0d76fd215a22/go.mod h1:v7XKwaDoCspsHQlWJNarO7gOpR+iumSS+c1bWs3kJOI=
368+
github.com/project-codeflare/codeflare-common v0.0.0-20241121090634-e99e941c6921 h1:OI9jKDW4yxbXDTpf4Y+8H4uVfdCH+jIqN0JTQfdUMYw=
369+
github.com/project-codeflare/codeflare-common v0.0.0-20241121090634-e99e941c6921/go.mod h1:v7XKwaDoCspsHQlWJNarO7gOpR+iumSS+c1bWs3kJOI=
370370
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
371371
github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo=
372372
github.com/prometheus/client_golang v1.7.1/go.mod h1:PY5Wy2awLA44sXw4AOSfFBetzPP4j5+D6mVACh+pe2M=

tests/kfto/core/environment.go

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,12 @@ const (
3434
minioCliImageEnvVar = "MINIO_CLI_IMAGE"
3535
// The environment variable for HuggingFace token to download models which require authentication
3636
huggingfaceTokenEnvVar = "HF_TOKEN"
37-
// The environment variable specifying existing namespace name to be used for tests
38-
testNamespaceEnvVar = "TEST_NAMESPACE_NAME"
3937
// The environment variable specifying the name of the PersistentVolumeClaim containing GPTQ models
4038
gptqModelPvcNameEnvVar = "GPTQ_MODEL_PVC_NAME"
4139
// The environment variable referring to image simulating sleep condition in container
4240
sleepImageEnvVar = "SLEEP_IMAGE"
4341
// The environment variable specifying s3 bucket folder path used to store model
4442
storageBucketModelPath = "AWS_STORAGE_BUCKET_MODEL_PATH"
45-
// The environment variable for the CUDA training image
46-
cudaTrainingImageEnvVar = "CUDA_TRAINING_IMAGE"
47-
// The environment variable for the ROCm training image
48-
rocmTrainingImageEnvVar = "ROCM_TRAINING_IMAGE"
4943
)
5044

5145
func GetFmsHfTuningImage(t Test) string {
@@ -57,24 +51,6 @@ func GetFmsHfTuningImage(t Test) string {
5751
return image
5852
}
5953

60-
func GetCudaTrainingImage(t Test) string {
61-
t.T().Helper()
62-
image, ok := os.LookupEnv(cudaTrainingImageEnvVar)
63-
if !ok {
64-
t.T().Fatalf("Expected environment variable %s not found, please use this environment variable to specify the cuda training image to be tested.", cudaTrainingImageEnvVar)
65-
}
66-
return image
67-
}
68-
69-
func GetROCmTrainingImage(t Test) string {
70-
t.T().Helper()
71-
image, ok := os.LookupEnv(rocmTrainingImageEnvVar)
72-
if !ok {
73-
t.T().Fatalf("Expected environment variable %s not found, please use this environment variable to specify the cuda training image to be tested.", rocmTrainingImageEnvVar)
74-
}
75-
return image
76-
}
77-
7854
func GetBloomModelImage() string {
7955
return lookupEnvOrDefault(bloomModelImageEnvVar, "quay.io/ksuta/bloom-560m@sha256:f6db02bb7b5d09a8d698c04994d747bfb9e581bbb4c07d00290244d207623733")
8056
}
@@ -96,10 +72,6 @@ func GetHuggingFaceToken(t Test) string {
9672
return image
9773
}
9874

99-
func GetTestNamespaceName() (namespaceName string, exists bool) {
100-
return os.LookupEnv(testNamespaceEnvVar)
101-
}
102-
10375
func GetGptqModelPvcName() (string, error) {
10476
image, ok := os.LookupEnv(gptqModelPvcNameEnvVar)
10577
if !ok {

tests/kfto/core/kfto_pytorchjob_failed_test.go

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
package core
22

33
import (
4+
"testing"
5+
46
. "github.com/onsi/gomega"
57
. "github.com/project-codeflare/codeflare-common/support"
6-
"testing"
78

89
corev1 "k8s.io/api/core/v1"
910
"k8s.io/apimachinery/pkg/api/resource"
@@ -13,15 +14,11 @@ import (
1314
)
1415

1516
func TestPyTorchJobFailureWithCuda(t *testing.T) {
16-
test := With(t)
17-
cudaBaseImage := GetCudaTrainingImage(test)
18-
runFailedPyTorchJobTest(t, cudaBaseImage)
17+
runFailedPyTorchJobTest(t, GetCudaTrainingImage())
1918
}
2019

2120
func TestPyTorchJobFailureWithROCm(t *testing.T) {
22-
test := With(t)
23-
rocmBaseImage := GetROCmTrainingImage(test)
24-
runFailedPyTorchJobTest(t, rocmBaseImage)
21+
runFailedPyTorchJobTest(t, GetROCmTrainingImage())
2522
}
2623

2724
func runFailedPyTorchJobTest(t *testing.T, image string) {
@@ -65,7 +62,7 @@ func createFailedPyTorchJob(test Test, namespace string, config corev1.ConfigMap
6562
{
6663
Name: "pytorch",
6764
Image: baseImage,
68-
Command: []string{"python", "-c", "raise Exception('Test failure')"},
65+
Command: []string{"python", "-c", "raise Exception('Test failure')"},
6966
ImagePullPolicy: corev1.PullIfNotPresent,
7067
Resources: corev1.ResourceRequirements{
7168
Requests: corev1.ResourceList{

tests/kfto/core/kfto_training_test.go

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ package core
1818

1919
import (
2020
"fmt"
21-
"os"
2221
"testing"
2322

2423
. "github.com/onsi/gomega"
@@ -32,17 +31,11 @@ import (
3231
)
3332

3433
func TestPyTorchJobWithCuda(t *testing.T) {
35-
test := With(t)
36-
cudaBaseImage := GetCudaTrainingImage(test)
37-
gpuLabel := "nvidia.com/gpu"
38-
runKFTOPyTorchJob(t, cudaBaseImage, gpuLabel, 1)
34+
runKFTOPyTorchJob(t, GetCudaTrainingImage(), "nvidia.com/gpu", 1)
3935
}
4036

4137
func TestPyTorchJobWithROCm(t *testing.T) {
42-
test := With(t)
43-
rocmBaseImage := GetROCmTrainingImage(test)
44-
gpuLabel := "amd.com/gpu"
45-
runKFTOPyTorchJob(t, rocmBaseImage, gpuLabel, 1)
38+
runKFTOPyTorchJob(t, GetROCmTrainingImage(), "amd.com/gpu", 1)
4639
}
4740

4841
func runKFTOPyTorchJob(t *testing.T, image string, gpuLabel string, numGpus int) {
@@ -51,16 +44,9 @@ func runKFTOPyTorchJob(t *testing.T, image string, gpuLabel string, numGpus int)
5144
// Create a namespace
5245
namespace := GetOrCreateTestNamespace(test)
5346

54-
// Parse training script
55-
trainingScriptPath := "hf_llm_training.py"
56-
trainingScript, err := os.ReadFile(trainingScriptPath)
57-
if err != nil {
58-
test.T().Fatalf("Error reading training script file: %v", err)
59-
}
60-
6147
// Create a ConfigMap with training script
6248
configData := map[string][]byte{
63-
"hf_llm_training.py": trainingScript,
49+
"hf_llm_training.py": ReadFileExt(test, "hf_llm_training.py"),
6450
}
6551
config := CreateConfigMap(test, namespace, configData)
6652

tests/kfto/core/support.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@ package core
1919
import (
2020
"embed"
2121
"fmt"
22+
"os"
2223
"time"
2324

25+
"github.com/onsi/gomega"
2426
. "github.com/onsi/gomega"
2527
. "github.com/project-codeflare/codeflare-common/support"
2628

@@ -41,6 +43,13 @@ func ReadFile(t Test, fileName string) []byte {
4143
return file
4244
}
4345

46+
func ReadFileExt(t Test, fileName string) []byte {
47+
t.T().Helper()
48+
file, err := os.ReadFile(fileName)
49+
t.Expect(err).NotTo(gomega.HaveOccurred())
50+
return file
51+
}
52+
4453
func PyTorchJob(t Test, namespace, name string) func(g Gomega) *kftov1.PyTorchJob {
4554
return func(g Gomega) *kftov1.PyTorchJob {
4655
job, err := t.Client().Kubeflow().KubeflowV1().PyTorchJobs(namespace).Get(t.Ctx(), name, metav1.GetOptions{})

0 commit comments

Comments
 (0)