Skip to content

Commit d036417

Browse files
[rayapp] Test templates against different ray versions (#453)
Adding --ray-version flag to rayapp test Tested against model-composition-recsys template <img width="2904" height="1580" alt="image" src="https://github.com/user-attachments/assets/10a9411a-5b1d-434e-ab27-b5c56753fbeb" /> <img width="1452" height="790" alt="image" src="https://github.com/user-attachments/assets/0237e8c6-4c7f-4976-8d56-9f323e180426" /> --------- Signed-off-by: elliot-barn <elliot.barnwell@anyscale.com> Signed-off-by: Elliot Barnwell <elliot.barnwell@anyscale.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent d3197ad commit d036417

File tree

9 files changed

+337
-19
lines changed

9 files changed

+337
-19
lines changed

rayapp/rayapp/main.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ func main() {
2222

2323
testFlags := flag.NewFlagSet("test", flag.ExitOnError)
2424
testBuildFile := testFlags.String("build", "BUILD.yaml", "build file")
25+
testRayVersion := testFlags.String("ray-version", "", "ray version to test against")
2526

2627
probeFlags := flag.NewFlagSet("probe", flag.ExitOnError)
2728
probeBuildFile := probeFlags.String("build", "BUILD.yaml", "build file")
@@ -49,11 +50,11 @@ func main() {
4950
log.Fatal("test requires <template-name> or 'all'")
5051
}
5152
if args[0] == "all" {
52-
if err := rayapp.RunAllTemplateTests(*testBuildFile); err != nil {
53+
if err := rayapp.RunAllTemplateTests(*testBuildFile, *testRayVersion); err != nil {
5354
log.Fatal(err)
5455
}
5556
} else {
56-
if err := rayapp.RunTemplateTest(args[0], *testBuildFile); err != nil {
57+
if err := rayapp.RunTemplateTest(args[0], *testBuildFile, *testRayVersion); err != nil {
5758
log.Fatal(err)
5859
}
5960
}
@@ -88,5 +89,6 @@ func printUsage() {
8889
fmt.Println(" --build string Build file (default \"BUILD.yaml\")")
8990
fmt.Println()
9091
fmt.Println("Test flags (test):")
91-
fmt.Println(" --build string Build file (default \"BUILD.yaml\")")
92+
fmt.Println(" --build string Build file (default \"BUILD.yaml\")")
93+
fmt.Println(" --ray-version string ray version to test against")
9294
}

rayapp/template.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,17 @@ var validTemplateName = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`)
5656

5757
var buildIDVersionFindRe = regexp.MustCompile(`(\d)(\d{2})(\d+)`)
5858

59+
// isRayImageURI returns true if the image URI refers to a ray image
60+
// (e.g. "anyscale/ray:2.44.0", "anyscale/ray-llm:2.44.0-py312").
61+
func isRayImageURI(imageURI string) bool {
62+
parts := strings.SplitN(imageURI, "/", 2)
63+
if len(parts) < 2 {
64+
return false
65+
}
66+
imageName := strings.SplitN(parts[1], ":", 2)[0]
67+
return imageName == "ray" || strings.HasPrefix(imageName, "ray-")
68+
}
69+
5970
// buildIDToImageName maps build ID slugified image-type (after "anyscale") to image name for URI.
6071
var buildIDToImageName = map[string]string{
6172
"ray": "ray",
@@ -120,6 +131,9 @@ func convertImageURIToBuildID(imageURI string) (buildID string, err error) {
120131

121132
var imageURIVersionRe = regexp.MustCompile(`(\d)\.(\d{2})\.(\d+)`)
122133

134+
// validRayVersionRe matches a full ray version string like "2.44.0".
135+
var validRayVersionRe = regexp.MustCompile(`^\d\.\d{2}\.\d+$`)
136+
123137
// extractRayVersionFromImageURI returns the ray version from the image URI.
124138
func extractRayVersionFromImageURI(imageURI string) (rayVersion string, err error) {
125139
matches := imageURIVersionRe.FindStringSubmatch(imageURI)
@@ -134,6 +148,25 @@ func extractRayVersionFromImageURI(imageURI string) (rayVersion string, err erro
134148
return fmt.Sprintf("%s.%s.%s", major, minor, patch), nil
135149
}
136150

151+
// overrideClusterEnvRayVersion returns a new ClusterEnv with its ray version replaced.
152+
// It converts the original cluster env to an image URI, swaps the version portion, and
153+
// returns a ClusterEnv using image_uri with the new version.
154+
func overrideClusterEnvRayVersion(env *ClusterEnv, newVersion string) (*ClusterEnv, error) {
155+
if env.BYOD != nil {
156+
return env, nil
157+
}
158+
159+
imageURI, _, err := getImageURIAndRayVersionFromClusterEnv(env)
160+
if err != nil {
161+
return nil, fmt.Errorf("resolve cluster env: %w", err)
162+
}
163+
164+
newImageURI := imageURIVersionRe.ReplaceAllStringFunc(
165+
imageURI, func(string) string { return newVersion },
166+
)
167+
return &ClusterEnv{ImageURI: newImageURI}, nil
168+
}
169+
137170
// getImageURIAndRayVersionFromClusterEnv returns image URI and ray version from cluster env.
138171
// It supports BYOD (docker_image + ray_version) or BuildID/ImageURI.
139172
func getImageURIAndRayVersionFromClusterEnv(env *ClusterEnv) (string, string, error) {

rayapp/template_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,28 @@ func TestConvertBuildIDToImageURI(t *testing.T) {
721721
}
722722
}
723723

724+
func TestIsRayImageURI(t *testing.T) {
725+
tests := []struct {
726+
imageURI string
727+
want bool
728+
}{
729+
{"anyscale/ray:2.44.0-py311", true},
730+
{"anyscale/ray-llm:2.44.1-py312-cu128", true},
731+
{"anyscale/ray-ml:2.44.0-py311", true},
732+
{"other/ray:2.44.0", true},
733+
{"anyscale/notray:2.44.0", false},
734+
{"anyscale/myimage:latest", false},
735+
{"noregistry", false},
736+
}
737+
for _, tt := range tests {
738+
t.Run(tt.imageURI, func(t *testing.T) {
739+
if got := isRayImageURI(tt.imageURI); got != tt.want {
740+
t.Errorf("isRayImageURI(%q) = %v, want %v", tt.imageURI, got, tt.want)
741+
}
742+
})
743+
}
744+
}
745+
724746
func TestConvertImageURIToBuildID(t *testing.T) {
725747
tests := []struct {
726748
name string
@@ -760,3 +782,57 @@ func TestConvertImageURIToBuildID(t *testing.T) {
760782
})
761783
}
762784
}
785+
786+
func TestOverrideClusterEnvRayVersion(t *testing.T) {
787+
tests := []struct {
788+
name string
789+
env *ClusterEnv
790+
newVersion string
791+
wantImageURI string
792+
wantSameEnv bool
793+
}{
794+
{
795+
name: "override build_id",
796+
env: &ClusterEnv{BuildID: "anyscaleray2370-py311"},
797+
newVersion: "2.44.0",
798+
wantImageURI: "anyscale/ray:2.44.0-py311",
799+
},
800+
{
801+
name: "override image_uri",
802+
env: &ClusterEnv{ImageURI: "anyscale/ray:2.37.0-py311"},
803+
newVersion: "2.44.0",
804+
wantImageURI: "anyscale/ray:2.44.0-py311",
805+
},
806+
{
807+
name: "BYOD returned unchanged",
808+
env: &ClusterEnv{
809+
BYOD: &ClusterEnvBYOD{
810+
DockerImage: "cr.ray.io/ray:2.37.0-py311",
811+
RayVersion: "2.37.0",
812+
},
813+
},
814+
newVersion: "2.44.0",
815+
wantSameEnv: true,
816+
},
817+
}
818+
for _, tt := range tests {
819+
t.Run(tt.name, func(t *testing.T) {
820+
got, err := overrideClusterEnvRayVersion(tt.env, tt.newVersion)
821+
if err != nil {
822+
t.Fatalf("unexpected error: %v", err)
823+
}
824+
if tt.wantSameEnv {
825+
if got != tt.env {
826+
t.Errorf("expected same env pointer back, got a different one")
827+
}
828+
} else {
829+
if got.ImageURI != tt.wantImageURI {
830+
t.Errorf("ImageURI = %q, want %q", got.ImageURI, tt.wantImageURI)
831+
}
832+
if got.BuildID != "" {
833+
t.Errorf("BuildID should be empty, got %q", got.BuildID)
834+
}
835+
}
836+
})
837+
}
838+
}

rayapp/template_test_runner.go

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,17 +92,17 @@ func probe(tmplName, buildFile string, cli *AnyscaleCLI, api *anyscaleAPI) error
9292
return nil
9393
}
9494

95-
func RunAllTemplateTests(buildFile string) error {
95+
func RunAllTemplateTests(buildFile, rayVersion string) error {
9696
cli := NewAnyscaleCLI()
9797
host, token := os.Getenv("ANYSCALE_HOST"), os.Getenv("ANYSCALE_CLI_TOKEN")
9898
api, err := newAnyscaleAPI(host, token)
9999
if err != nil {
100100
return fmt.Errorf("new anyscale api failed: %w", err)
101101
}
102-
return runTemplateTestsWithFilter(buildFile, nil, cli, api)
102+
return runTemplateTestsWithFilter(buildFile, nil, rayVersion, cli, api)
103103
}
104104

105-
func RunTemplateTest(tmplName, buildFile string) error {
105+
func RunTemplateTest(tmplName, buildFile, rayVersion string) error {
106106
cli := NewAnyscaleCLI()
107107
host, token := os.Getenv("ANYSCALE_HOST"), os.Getenv("ANYSCALE_CLI_TOKEN")
108108
api, err := newAnyscaleAPI(host, token)
@@ -111,15 +111,23 @@ func RunTemplateTest(tmplName, buildFile string) error {
111111
}
112112
return runTemplateTestsWithFilter(buildFile, func(tmpl *Template) bool {
113113
return tmpl.Name == tmplName
114-
}, cli, api)
114+
}, rayVersion, cli, api)
115115
}
116116

117117
func runTemplateTestsWithFilter(
118118
buildFile string,
119119
filter func(tmpl *Template) bool,
120+
rayVersion string,
120121
cli *AnyscaleCLI,
121122
api *anyscaleAPI,
122123
) error {
124+
if rayVersion != "" && !validRayVersionRe.MatchString(rayVersion) {
125+
return fmt.Errorf(
126+
"invalid ray version %q: must match X.YY.Z (e.g. 2.44.0)",
127+
rayVersion,
128+
)
129+
}
130+
123131
tmpls, err := readTemplates(buildFile)
124132
if err != nil {
125133
return fmt.Errorf("read templates failed: %w", err)
@@ -138,6 +146,39 @@ func runTemplateTestsWithFilter(
138146
skippedNoTest++
139147
continue
140148
}
149+
if rayVersion != "" {
150+
if t.ClusterEnv == nil {
151+
log.Printf(
152+
"Template %s has no cluster_env, "+
153+
"skipping ray version override",
154+
t.Name,
155+
)
156+
continue
157+
}
158+
hasBuildID := strings.TrimSpace(t.ClusterEnv.BuildID) != ""
159+
hasImageURI := strings.TrimSpace(t.ClusterEnv.ImageURI) != ""
160+
if !hasBuildID && !hasImageURI {
161+
log.Printf(
162+
"Template %s has no build_id or image_uri, "+
163+
"skipping ray version override",
164+
t.Name,
165+
)
166+
continue
167+
}
168+
if hasImageURI && !isRayImageURI(t.ClusterEnv.ImageURI) {
169+
log.Printf(
170+
"Template %s image_uri %q is not a ray image, "+
171+
"skipping ray version override",
172+
t.Name, t.ClusterEnv.ImageURI,
173+
)
174+
continue
175+
}
176+
env, err := overrideClusterEnvRayVersion(t.ClusterEnv, rayVersion)
177+
if err != nil {
178+
return fmt.Errorf("override ray version for %q: %w", t.Name, err)
179+
}
180+
t.ClusterEnv = env
181+
}
141182
filteredTmpls = append(filteredTmpls, t)
142183
}
143184
if len(filteredTmpls) == 0 {

0 commit comments

Comments
 (0)