Skip to content

Commit 16060b6

Browse files
committed
chore(runtimes): persist runtimes map and expose Runtimes function
Signed-off-by: Tomas Tormo <tomas.tormo@gmail.com>
1 parent ff52bd6 commit 16060b6

2 files changed

Lines changed: 84 additions & 6 deletions

File tree

pkg/runtime/core/core.go

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,28 +29,39 @@ import (
2929
// +kubebuilder:rbac:groups=trainer.kubeflow.org,resources=trainingruntimes,verbs=get;list;watch
3030
// +kubebuilder:rbac:groups=trainer.kubeflow.org,resources=clustertrainingruntimes,verbs=get;list;watch
3131

32+
var runtimes map[string]runtime.Runtime
33+
3234
func New(ctx context.Context, client client.Client, indexer client.FieldIndexer, cfg *configapi.Configuration) (map[string]runtime.Runtime, error) {
3335
registry := NewRuntimeRegistry()
34-
runtimes := make(map[string]runtime.Runtime, len(registry))
36+
newRuntimes := make(map[string]runtime.Runtime, len(registry))
3537
for name, registrar := range registry {
3638
for _, dep := range registrar.dependencies {
3739
depRegistrar, depExist := registry[dep]
38-
_, depRegistered := runtimes[dep]
40+
_, depRegistered := newRuntimes[dep]
3941
if depExist && !depRegistered {
4042
r, err := depRegistrar.factory(ctx, client, indexer, cfg)
4143
if err != nil {
4244
return nil, fmt.Errorf("initializing runtime %q on which %q depends: %w", dep, name, err)
4345
}
44-
runtimes[dep] = r
46+
newRuntimes[dep] = r
4547
}
4648
}
47-
if _, ok := runtimes[name]; !ok {
49+
if _, ok := newRuntimes[name]; !ok {
4850
r, err := registrar.factory(ctx, client, indexer, cfg)
4951
if err != nil {
5052
return nil, fmt.Errorf("initializing runtime %q: %w", name, err)
5153
}
52-
runtimes[name] = r
54+
newRuntimes[name] = r
5355
}
5456
}
55-
return runtimes, nil
57+
runtimes = newRuntimes
58+
return newRuntimes, nil
59+
}
60+
61+
func Runtimes() map[string]runtime.Runtime {
62+
runtimesCopy := make(map[string]runtime.Runtime, len(runtimes))
63+
for d, r := range runtimes {
64+
runtimesCopy[d] = r
65+
}
66+
return runtimesCopy
5667
}

pkg/runtime/core/core_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/*
2+
Copyright 2024 The Kubeflow Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package core
18+
19+
import (
20+
"context"
21+
"testing"
22+
23+
"github.com/google/go-cmp/cmp"
24+
"github.com/google/go-cmp/cmp/cmpopts"
25+
26+
testingutil "github.com/kubeflow/trainer/v2/pkg/util/testing"
27+
)
28+
29+
func TestRuntimes(t *testing.T) {
30+
cases := map[string]struct {
31+
}{
32+
"returns a copy of the runtimes map": {},
33+
}
34+
for name := range cases {
35+
t.Run(name, func(t *testing.T) {
36+
ctx, cancel := context.WithCancel(context.Background())
37+
t.Cleanup(cancel)
38+
clientBuilder := testingutil.NewClientBuilder()
39+
c := clientBuilder.Build()
40+
41+
newRuntimes, err := New(ctx, c, testingutil.AsIndex(clientBuilder), nil)
42+
if err != nil {
43+
t.Fatalf("Failed to initialize runtimes: %v", err)
44+
}
45+
46+
gotRuntimes := Runtimes()
47+
if diff := cmp.Diff(newRuntimes, gotRuntimes, cmpopts.IgnoreUnexported(TrainingRuntime{}, ClusterTrainingRuntime{})); len(diff) != 0 {
48+
t.Errorf("Unexpected difference between new and got runtimes (-want,+got):\n%s", diff)
49+
}
50+
51+
// Verify that modifying the returned map does not affect the persisted runtimes.
52+
gotRuntimes["mutated-key"] = nil
53+
gotRuntimes[TrainingRuntimeGroupKind] = nil
54+
delete(gotRuntimes, ClusterTrainingRuntimeGroupKind)
55+
56+
if _, exists := runtimes["mutated-key"]; exists {
57+
t.Error("Adding a key to Runtimes() return value should not affect the persisted runtimes")
58+
}
59+
if runtimes[TrainingRuntimeGroupKind] == nil {
60+
t.Error("Modifying an existing key in Runtimes() return value should not affect the persisted runtimes")
61+
}
62+
if _, exists := runtimes[ClusterTrainingRuntimeGroupKind]; !exists {
63+
t.Error("Deleting a key from Runtimes() return value should not affect the persisted runtimes")
64+
}
65+
})
66+
}
67+
}

0 commit comments

Comments
 (0)