Skip to content

Commit 70ac1e2

Browse files
authored
Merge pull request #801 from elezar/fix-legacy-nvidia-imex-channels
Fix NVIDIA_IMEX_CHANNELS handling on legacy images
2 parents 1467f3f + f774cee commit 70ac1e2

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

internal/config/image/cuda_image.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,11 @@ func (i CUDA) CDIDevicesFromMounts() []string {
292292

293293
// ImexChannelsFromEnvVar returns the list of IMEX channels requested for the image.
294294
func (i CUDA) ImexChannelsFromEnvVar() []string {
295-
return i.DevicesFromEnvvars(EnvVarNvidiaImexChannels).List()
295+
imexChannels := i.DevicesFromEnvvars(EnvVarNvidiaImexChannels).List()
296+
if len(imexChannels) == 1 && imexChannels[0] == "all" {
297+
return nil
298+
}
299+
return imexChannels
296300
}
297301

298302
// ImexChannelsFromMounts returns the list of IMEX channels requested for the image.

internal/config/image/cuda_image_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,37 @@ func TestGetVisibleDevicesFromMounts(t *testing.T) {
203203
}
204204
}
205205

206+
func TestImexChannelsFromEnvVar(t *testing.T) {
207+
testCases := []struct {
208+
description string
209+
env []string
210+
expected []string
211+
}{
212+
{
213+
description: "no imex channels specified",
214+
},
215+
{
216+
description: "imex channel specified",
217+
env: []string{
218+
"NVIDIA_IMEX_CHANNELS=3,4",
219+
},
220+
expected: []string{"3", "4"},
221+
},
222+
}
223+
224+
for _, tc := range testCases {
225+
for id, baseEnvvars := range map[string][]string{"": nil, "legacy": {"CUDA_VERSION=1.2.3"}} {
226+
t.Run(tc.description+id, func(t *testing.T) {
227+
i, err := NewCUDAImageFromEnv(append(baseEnvvars, tc.env...))
228+
require.NoError(t, err)
229+
230+
channels := i.ImexChannelsFromEnvVar()
231+
require.EqualValues(t, tc.expected, channels)
232+
})
233+
}
234+
}
235+
}
236+
206237
func makeTestMounts(paths ...string) []specs.Mount {
207238
var mounts []specs.Mount
208239
for _, path := range paths {

0 commit comments

Comments
 (0)