Skip to content

Commit 78713e7

Browse files
committed
refactor: Use automatic CDI modifier for CSV
Signed-off-by: Evan Lezar <elezar@nvidia.com>
1 parent 6b6914f commit 78713e7

File tree

2 files changed

+23
-35
lines changed

2 files changed

+23
-35
lines changed

internal/modifier/cdi.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2828
"github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi"
2929
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
30+
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
3031
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
3132
)
3233

@@ -180,6 +181,14 @@ func (f *Factory) newAutomaticCDISpecModifier(devices []string) (oci.SpecModifie
180181
nvcdiFeatureFlags = append(nvcdiFeatureFlags, nvcdi.FeatureNoAdditionalGIDsForDeviceNodes)
181182
}
182183

184+
csvFiles, err := csv.GetFileList(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath)
185+
if err != nil {
186+
f.logger.Warningf("Failed to get the list of CSV files: %v", err)
187+
}
188+
if f.image.Getenv(image.EnvVarNvidiaRequireJetpack) != "csv-mounts=all" {
189+
csvFiles = csv.BaseFilesOnly(csvFiles)
190+
}
191+
183192
cdiModeIdentifiers := cdiModeIdentfiersFromDevices(devices...)
184193
f.logger.Debugf("Per-mode identifiers: %v", cdiModeIdentifiers)
185194
var modifiers oci.SpecModifiers
@@ -194,6 +203,7 @@ func (f *Factory) newAutomaticCDISpecModifier(devices []string) (oci.SpecModifie
194203
nvcdi.WithMode(mode),
195204
nvcdi.WithFeatureFlags(nvcdiFeatureFlags...),
196205
nvcdi.WithCSVCompatContainerRoot(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.CompatContainerRoot),
206+
nvcdi.WithCSVFiles(csvFiles),
197207
)
198208
if err != nil {
199209
return nil, fmt.Errorf("failed to construct CDI library for mode %q: %w", mode, err)

internal/modifier/csv.go

Lines changed: 13 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,14 @@ import (
2222
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
2323
"github.com/NVIDIA/nvidia-container-toolkit/internal/cuda"
2424
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
25-
"github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi"
2625
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
27-
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
2826
"github.com/NVIDIA/nvidia-container-toolkit/internal/requirements"
29-
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
3027
)
3128

3229
// newCSVModifier creates a modifier that applies modications to an OCI spec if required by the runtime wrapper.
3330
// The modifications are defined by CSV MountSpecs.
3431
func (f *Factory) newCSVModifier() (oci.SpecModifier, error) {
35-
devices := f.image.VisibleDevices()
32+
devices := withUniqueDevices(csvDevices(*f.image)).DeviceRequests()
3633
if len(devices) == 0 {
3734
f.logger.Infof("No modification required; no devices requested")
3835
return nil, nil
@@ -43,37 +40,7 @@ func (f *Factory) newCSVModifier() (oci.SpecModifier, error) {
4340
return nil, fmt.Errorf("requirements not met: %v", err)
4441
}
4542

46-
csvFiles, err := csv.GetFileList(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath)
47-
if err != nil {
48-
return nil, fmt.Errorf("failed to get list of CSV files: %v", err)
49-
}
50-
51-
if f.image.Getenv(image.EnvVarNvidiaRequireJetpack) != "csv-mounts=all" {
52-
csvFiles = csv.BaseFilesOnly(csvFiles)
53-
}
54-
55-
cdilib, err := nvcdi.New(
56-
nvcdi.WithLogger(f.logger),
57-
nvcdi.WithDriverRoot(f.driver.Root),
58-
nvcdi.WithDevRoot(f.driver.DevRoot),
59-
nvcdi.WithNVIDIACDIHookPath(f.cfg.NVIDIACTKConfig.Path),
60-
nvcdi.WithMode(nvcdi.ModeCSV),
61-
nvcdi.WithCSVFiles(csvFiles),
62-
nvcdi.WithCSVCompatContainerRoot(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.CompatContainerRoot),
63-
)
64-
if err != nil {
65-
return nil, fmt.Errorf("failed to construct CDI library: %v", err)
66-
}
67-
68-
spec, err := cdilib.GetSpec(devices...)
69-
if err != nil {
70-
return nil, fmt.Errorf("failed to get CDI spec: %v", err)
71-
}
72-
73-
return cdi.New(
74-
cdi.WithLogger(f.logger),
75-
cdi.WithSpec(spec.Raw()),
76-
)
43+
return f.newAutomaticCDISpecModifier(devices)
7744
}
7845

7946
func checkRequirements(logger logger.Interface, image *image.CUDA) error {
@@ -107,3 +74,14 @@ func checkRequirements(logger logger.Interface, image *image.CUDA) error {
10774

10875
return r.Assert()
10976
}
77+
78+
type csvDevices image.CUDA
79+
80+
func (d csvDevices) DeviceRequests() []string {
81+
var devices []string
82+
i := (image.CUDA)(d)
83+
for _, deviceID := range i.VisibleDevices() {
84+
devices = append(devices, "mode=csv,id="+deviceID)
85+
}
86+
return devices
87+
}

0 commit comments

Comments
 (0)