Skip to content

Commit 5699ec1

Browse files
authored
Merge pull request #807 from elezar/add-cdi-imex-channels
Add imex mode to CDI spec generation
2 parents 4fc181a + 8603d60 commit 5699ec1

File tree

12 files changed

+339
-55
lines changed

12 files changed

+339
-55
lines changed

cmd/nvidia-ctk/cdi/generate/generate.go

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,12 @@ func (m command) build() *cli.Command {
104104
Destination: &opts.format,
105105
},
106106
&cli.StringFlag{
107-
Name: "mode",
108-
Aliases: []string{"discovery-mode"},
109-
Usage: "The mode to use when discovering the available entities. One of [auto | nvml | wsl]. If mode is set to 'auto' the mode will be determined based on the system configuration.",
110-
Value: nvcdi.ModeAuto,
107+
Name: "mode",
108+
Aliases: []string{"discovery-mode"},
109+
Usage: "The mode to use when discovering the available entities. " +
110+
"One of [" + strings.Join(nvcdi.AllModes[string](), " | ") + "]. " +
111+
"If mode is set to 'auto' the mode will be determined based on the system configuration.",
112+
Value: string(nvcdi.ModeAuto),
111113
Destination: &opts.mode,
112114
},
113115
&cli.StringFlag{
@@ -184,13 +186,7 @@ func (m command) validateFlags(c *cli.Context, opts *options) error {
184186
}
185187

186188
opts.mode = strings.ToLower(opts.mode)
187-
switch opts.mode {
188-
case nvcdi.ModeAuto:
189-
case nvcdi.ModeCSV:
190-
case nvcdi.ModeNvml:
191-
case nvcdi.ModeWsl:
192-
case nvcdi.ModeManagement:
193-
default:
189+
if !nvcdi.IsValidMode(opts.mode) {
194190
return fmt.Errorf("invalid discovery mode: %v", opts.mode)
195191
}
196192

pkg/nvcdi/api.go

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,24 +24,6 @@ import (
2424
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
2525
)
2626

27-
const (
28-
// ModeAuto configures the CDI spec generator to automatically detect the system configuration
29-
ModeAuto = "auto"
30-
// ModeNvml configures the CDI spec generator to use the NVML library.
31-
ModeNvml = "nvml"
32-
// ModeWsl configures the CDI spec generator to generate a WSL spec.
33-
ModeWsl = "wsl"
34-
// ModeManagement configures the CDI spec generator to generate a management spec.
35-
ModeManagement = "management"
36-
// ModeGds configures the CDI spec generator to generate a GDS spec.
37-
ModeGds = "gds"
38-
// ModeMofed configures the CDI spec generator to generate a MOFED spec.
39-
ModeMofed = "mofed"
40-
// ModeCSV configures the CDI spec generator to generate a spec based on the contents of CSV
41-
// mountspec files.
42-
ModeCSV = "csv"
43-
)
44-
4527
// Interface defines the API for the nvcdi package
4628
type Interface interface {
4729
GetSpec() (spec.Interface, error)

pkg/nvcdi/lib-imex.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
/**
2+
# Copyright 2024 NVIDIA CORPORATION
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package nvcdi
18+
19+
import (
20+
"fmt"
21+
"path/filepath"
22+
"strconv"
23+
"strings"
24+
25+
"tags.cncf.io/container-device-interface/pkg/cdi"
26+
"tags.cncf.io/container-device-interface/specs-go"
27+
28+
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
29+
30+
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
31+
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
32+
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
33+
)
34+
35+
type imexlib nvcdilib
36+
37+
var _ Interface = (*imexlib)(nil)
38+
39+
const (
40+
classImexChannel = "imex-channel"
41+
)
42+
43+
// GetSpec should not be called for imexlib.
44+
func (l *imexlib) GetSpec() (spec.Interface, error) {
45+
return nil, fmt.Errorf("unexpected call to imexlib.GetSpec()")
46+
}
47+
48+
// GetAllDeviceSpecs returns the device specs for all available devices.
49+
func (l *imexlib) GetAllDeviceSpecs() ([]specs.Device, error) {
50+
channelsDiscoverer := discover.NewCharDeviceDiscoverer(
51+
l.logger,
52+
l.devRoot,
53+
[]string{"/dev/nvidia-caps-imex-channels/channel*"},
54+
)
55+
56+
channels, err := channelsDiscoverer.Devices()
57+
if err != nil {
58+
return nil, err
59+
}
60+
61+
var channelIDs []string
62+
for _, channel := range channels {
63+
channelIDs = append(channelIDs, filepath.Base(channel.Path))
64+
}
65+
66+
return l.GetDeviceSpecsByID(channelIDs...)
67+
}
68+
69+
// GetCommonEdits returns an empty set of edits for IMEX devices.
70+
func (l *imexlib) GetCommonEdits() (*cdi.ContainerEdits, error) {
71+
return edits.FromDiscoverer(discover.None{})
72+
}
73+
74+
// GetDeviceSpecsByID returns the CDI device specs for the IMEX channels specified.
75+
func (l *imexlib) GetDeviceSpecsByID(ids ...string) ([]specs.Device, error) {
76+
var deviceSpecs []specs.Device
77+
for _, id := range ids {
78+
trimmed := strings.TrimPrefix(id, "channel")
79+
_, err := strconv.ParseUint(trimmed, 10, 64)
80+
if err != nil {
81+
return nil, fmt.Errorf("invalid channel ID %v: %w", id, err)
82+
}
83+
path := "/dev/nvidia-caps-imex-channels/channel" + trimmed
84+
deviceSpec := specs.Device{
85+
Name: trimmed,
86+
ContainerEdits: specs.ContainerEdits{
87+
DeviceNodes: []*specs.DeviceNode{
88+
{
89+
Path: path,
90+
HostPath: filepath.Join(l.devRoot, path),
91+
},
92+
},
93+
},
94+
}
95+
deviceSpecs = append(deviceSpecs, deviceSpec)
96+
}
97+
return deviceSpecs, nil
98+
}
99+
100+
// GetGPUDeviceEdits is unsupported for the imexlib specs
101+
func (l *imexlib) GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error) {
102+
return nil, fmt.Errorf("GetGPUDeviceEdits is not supported")
103+
}
104+
105+
// GetGPUDeviceSpecs is unsupported for the imexlib specs
106+
func (l *imexlib) GetGPUDeviceSpecs(int, device.Device) ([]specs.Device, error) {
107+
return nil, fmt.Errorf("GetGPUDeviceSpecs is not supported")
108+
}
109+
110+
// GetMIGDeviceEdits is unsupported for the imexlib specs
111+
func (l *imexlib) GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.ContainerEdits, error) {
112+
return nil, fmt.Errorf("GetMIGDeviceEdits is not supported")
113+
}
114+
115+
// GetMIGDeviceSpecs is unsupported for the imexlib specs
116+
func (l *imexlib) GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) ([]specs.Device, error) {
117+
return nil, fmt.Errorf("GetMIGDeviceSpecs is not supported")
118+
}

pkg/nvcdi/lib-imex_test.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/**
2+
# Copyright 2024 NVIDIA CORPORATION
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package nvcdi
18+
19+
import (
20+
"bytes"
21+
"path/filepath"
22+
"strings"
23+
"testing"
24+
25+
testlog "github.com/sirupsen/logrus/hooks/test"
26+
"github.com/stretchr/testify/require"
27+
28+
"github.com/NVIDIA/nvidia-container-toolkit/internal/test"
29+
)
30+
31+
func TestImexMode(t *testing.T) {
32+
t.Setenv("__NVCT_TESTING_DEVICES_ARE_FILES", "true")
33+
34+
logger, _ := testlog.NewNullLogger()
35+
36+
moduleRoot, err := test.GetModuleRoot()
37+
require.NoError(t, err)
38+
hostRoot := filepath.Join(moduleRoot, "testdata", "lookup", "rootfs-1")
39+
40+
expectedSpec := `---
41+
cdiVersion: 0.5.0
42+
containerEdits:
43+
env:
44+
- NVIDIA_VISIBLE_DEVICES=void
45+
devices:
46+
- containerEdits:
47+
deviceNodes:
48+
- hostPath: {{ .hostRoot }}/dev/nvidia-caps-imex-channels/channel0
49+
path: /dev/nvidia-caps-imex-channels/channel0
50+
name: "0"
51+
- containerEdits:
52+
deviceNodes:
53+
- hostPath: {{ .hostRoot }}/dev/nvidia-caps-imex-channels/channel1
54+
path: /dev/nvidia-caps-imex-channels/channel1
55+
name: "1"
56+
- containerEdits:
57+
deviceNodes:
58+
- hostPath: {{ .hostRoot }}/dev/nvidia-caps-imex-channels/channel2047
59+
path: /dev/nvidia-caps-imex-channels/channel2047
60+
name: "2047"
61+
kind: nvidia.com/imex-channel
62+
`
63+
expectedSpec = strings.ReplaceAll(expectedSpec, "{{ .hostRoot }}", hostRoot)
64+
65+
lib, err := New(
66+
WithLogger(logger),
67+
WithMode(ModeImex),
68+
WithDriverRoot(hostRoot),
69+
)
70+
require.NoError(t, err)
71+
72+
spec, err := lib.GetSpec()
73+
require.NoError(t, err)
74+
75+
var b bytes.Buffer
76+
77+
_, err = spec.WriteTo(&b)
78+
require.NoError(t, err)
79+
require.Equal(t, expectedSpec, b.String())
80+
}

pkg/nvcdi/lib-nvml.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ var _ Interface = (*nvmllib)(nil)
3737

3838
// GetSpec should not be called for nvmllib
3939
func (l *nvmllib) GetSpec() (spec.Interface, error) {
40-
return nil, fmt.Errorf("Unexpected call to nvmllib.GetSpec()")
40+
return nil, fmt.Errorf("unexpected call to nvmllib.GetSpec()")
4141
}
4242

4343
// GetAllDeviceSpecs returns the device specs for all available devices.

pkg/nvcdi/lib.go

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ type nvcdilib struct {
4646
logger logger.Interface
4747
nvmllib nvml.Interface
4848
nvsandboxutilslib nvsandboxutils.Interface
49-
mode string
49+
mode Mode
5050
devicelib device.Interface
5151
deviceNamers DeviceNamers
5252
driverRoot string
@@ -161,6 +161,11 @@ func New(opts ...Option) (Interface, error) {
161161
l.class = "mofed"
162162
}
163163
lib = (*mofedlib)(l)
164+
case ModeImex:
165+
if l.class == "" {
166+
l.class = classImexChannel
167+
}
168+
lib = (*imexlib)(l)
164169
default:
165170
return nil, fmt.Errorf("unknown mode %q", l.mode)
166171
}
@@ -206,28 +211,6 @@ func (m *wrapper) GetCommonEdits() (*cdi.ContainerEdits, error) {
206211
return edits, nil
207212
}
208213

209-
// resolveMode resolves the mode for CDI spec generation based on the current system.
210-
func (l *nvcdilib) resolveMode() (rmode string) {
211-
if l.mode != ModeAuto {
212-
return l.mode
213-
}
214-
defer func() {
215-
l.logger.Infof("Auto-detected mode as '%v'", rmode)
216-
}()
217-
218-
platform := l.infolib.ResolvePlatform()
219-
switch platform {
220-
case info.PlatformNVML:
221-
return ModeNvml
222-
case info.PlatformTegra:
223-
return ModeCSV
224-
case info.PlatformWSL:
225-
return ModeWsl
226-
}
227-
l.logger.Warningf("Unsupported platform detected: %v; assuming %v", platform, ModeNvml)
228-
return ModeNvml
229-
}
230-
231214
// getCudaVersion returns the CUDA version of the current system.
232215
func (l *nvcdilib) getCudaVersion() (string, error) {
233216
version, err := l.getCudaVersionNvsandboxutils()

0 commit comments

Comments
 (0)