Skip to content

Commit 722b541

Browse files
authored
Merge pull request #1667 from elezar/unify-csv-and-jit-cdi
Use automatic CDI spec generation to generate CDI specs for other modifiers
2 parents 88da3d2 + af7b79b commit 722b541

File tree

9 files changed

+214
-106
lines changed

9 files changed

+214
-106
lines changed

internal/devices/devices.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,11 @@ var assertCharDeviceStub = func(path string) error {
5353
}
5454
return nil
5555
}
56+
57+
// IsOverrideApplied returns whatever the package-level isOverrideAppliedStub
// hook reports. The hook is a variable, so the result can be swapped out
// without changing this function.
func IsOverrideApplied() bool {
58+
return isOverrideAppliedStub()
59+
}
60+
61+
// isOverrideAppliedStub is the function that IsOverrideApplied delegates to.
// It is declared as a variable so that it can be reassigned; the default
// implementation always returns false.
var isOverrideAppliedStub = func() bool {
62+
return false
63+
}

internal/devices/devices_mock.go

Lines changed: 39 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/devices/devices_tests.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
// Interface lists the functions that SetInterfaceForTests installs in place
// of the package-level hooks: DeviceFromPath, AssertCharDevice and
// IsOverrideApplied.
type Interface interface {
3030
DeviceFromPath(string, string) (*Device, error)
3131
AssertCharDevice(string) error
32+
IsOverrideApplied() bool
3233
}
3334

3435
// testDefaults carries no state; it only acts as the receiver for methods
// such as AssertCharDevice and IsOverrideApplied.
type testDefaults struct{}
@@ -47,6 +48,7 @@ func SetInterfaceForTests(m Interface) func() {
4748
funcs := []func(){
4849
SetDeviceFromPathForTest(m.DeviceFromPath),
4950
SetAssertCharDeviceForTest(m.AssertCharDevice),
51+
SetIsOverrideAppliedForTest(m.IsOverrideApplied),
5052
}
5153
return func() {
5254
for _, f := range funcs {
@@ -71,6 +73,14 @@ func SetAssertCharDeviceForTest(testFunc func(string) error) func() {
7173
}
7274
}
7375

76+
// SetIsOverrideAppliedForTest replaces the isOverrideAppliedStub hook with
// testFunc and returns a function that puts the previously installed hook
// back when called.
func SetIsOverrideAppliedForTest(testFunc func() bool) func() {
77+
current := isOverrideAppliedStub
78+
isOverrideAppliedStub = testFunc
79+
return func() {
80+
isOverrideAppliedStub = current
81+
}
82+
}
83+
7484
// testDevice wraps a Device by embedding it.
type testDevice struct {
7585
Device
7686
}
@@ -115,3 +125,7 @@ func (t *testDefaults) AssertCharDevice(path string) error {
115125

116126
return nil
117127
}
128+
129+
// IsOverrideApplied always reports true for testDefaults, the opposite of
// what the default isOverrideAppliedStub returns.
func (t *testDefaults) IsOverrideApplied() bool {
130+
return true
131+
}

internal/modifier/cdi.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
2828
"github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi"
2929
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
30+
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
3031
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
3132
)
3233

@@ -180,6 +181,14 @@ func (f *Factory) newAutomaticCDISpecModifier(devices []string) (oci.SpecModifie
180181
nvcdiFeatureFlags = append(nvcdiFeatureFlags, nvcdi.FeatureNoAdditionalGIDsForDeviceNodes)
181182
}
182183

184+
csvFiles, err := csv.GetFileList(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath)
185+
if err != nil {
186+
f.logger.Warningf("Failed to get the list of CSV files: %v", err)
187+
}
188+
if f.image.Getenv(image.EnvVarNvidiaRequireJetpack) != "csv-mounts=all" {
189+
csvFiles = csv.BaseFilesOnly(csvFiles)
190+
}
191+
183192
cdiModeIdentifiers := cdiModeIdentfiersFromDevices(devices...)
184193
f.logger.Debugf("Per-mode identifiers: %v", cdiModeIdentifiers)
185194
var modifiers oci.SpecModifiers
@@ -194,6 +203,7 @@ func (f *Factory) newAutomaticCDISpecModifier(devices []string) (oci.SpecModifie
194203
nvcdi.WithMode(mode),
195204
nvcdi.WithFeatureFlags(nvcdiFeatureFlags...),
196205
nvcdi.WithCSVCompatContainerRoot(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.CompatContainerRoot),
206+
nvcdi.WithCSVFiles(csvFiles),
197207
)
198208
if err != nil {
199209
return nil, fmt.Errorf("failed to construct CDI library for mode %q: %w", mode, err)

internal/modifier/cdi/spec.go

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,29 +20,51 @@ import (
2020
"fmt"
2121

2222
"github.com/opencontainers/runtime-spec/specs-go"
23-
"tags.cncf.io/container-device-interface/pkg/cdi"
23+
cdiapi "tags.cncf.io/container-device-interface/pkg/cdi"
24+
cdi "tags.cncf.io/container-device-interface/specs-go"
2425

26+
"github.com/NVIDIA/nvidia-container-toolkit/internal/devices"
2527
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
2628
)
2729

2830
// fromCDISpec represents the modifications performed from a raw CDI spec.
2931
type fromCDISpec struct {
30-
cdiSpec *cdi.Spec
32+
cdiSpec *cdiapi.Spec
3133
}
3234

3335
var _ oci.SpecModifier = (*fromCDISpec)(nil)
3436

3537
// Modify applies the modifications defined by the raw CDI spec to the incoming OCI spec.
3638
func (m fromCDISpec) Modify(spec *specs.Spec) error {
3739
for _, device := range m.cdiSpec.Devices {
38-
device := device
39-
cdiDevice := cdi.Device{
40+
device := m.enrichDevice(device)
41+
cdiDevice := cdiapi.Device{
4042
Device: &device,
4143
}
4244
if err := cdiDevice.ApplyEdits(spec); err != nil {
43-
return fmt.Errorf("failed to apply edits for device %q: %v", cdiDevice.GetQualifiedName(), err)
45+
return fmt.Errorf("failed to apply edits for device %q: %v", m.cdiSpec.Kind+"="+device.Name, err)
4446
}
4547
}
4648

4749
return m.cdiSpec.ApplyEdits(spec)
4850
}
51+
52+
// enrichDevice returns device unchanged unless devices.IsOverrideApplied()
// reports true. When it does, every entry in device.ContainerEdits.DeviceNodes
// gets its Type set to "c" and a zero Major replaced by 99 before device is
// returned.
//
// NOTE(review): device is received by value, but if DeviceNodes is a slice
// (or holds pointers) the writes below also reach the caller's device nodes;
// confirm that this is acceptable.
func (m fromCDISpec) enrichDevice(device cdi.Device) cdi.Device {
53+
if !devices.IsOverrideApplied() {
54+
return device
55+
}
56+
// For testing we need to override the device node information to ensure
57+
// that we don't trigger the CDI modification that requires the device node
58+
// to exist and be a character device.
59+
// The following condition is used to determine whether a failure to get
60+
// the info is fatal:
61+
// hasMinimalSpecification := d.Type != "" && (d.Major != 0 || d.Type == fifoDevice)
62+
for i, dn := range device.ContainerEdits.DeviceNodes {
63+
dn.Type = "c"
64+
if dn.Major == 0 {
65+
dn.Major = 99
66+
}
67+
device.ContainerEdits.DeviceNodes[i] = dn
68+
}
69+
return device
70+
}

internal/modifier/csv.go

Lines changed: 13 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,14 @@ import (
2222
"github.com/NVIDIA/nvidia-container-toolkit/internal/config/image"
2323
"github.com/NVIDIA/nvidia-container-toolkit/internal/cuda"
2424
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
25-
"github.com/NVIDIA/nvidia-container-toolkit/internal/modifier/cdi"
2625
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
27-
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra/csv"
2826
"github.com/NVIDIA/nvidia-container-toolkit/internal/requirements"
29-
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi"
3027
)
3128

3229
// newCSVModifier creates a modifier that applies modifications to an OCI spec if required by the runtime wrapper.
3330
// The modifications are defined by CSV MountSpecs.
3431
func (f *Factory) newCSVModifier() (oci.SpecModifier, error) {
35-
devices := f.image.VisibleDevices()
32+
devices := withUniqueDevices(csvDevices(*f.image)).DeviceRequests()
3633
if len(devices) == 0 {
3734
f.logger.Infof("No modification required; no devices requested")
3835
return nil, nil
@@ -43,37 +40,7 @@ func (f *Factory) newCSVModifier() (oci.SpecModifier, error) {
4340
return nil, fmt.Errorf("requirements not met: %v", err)
4441
}
4542

46-
csvFiles, err := csv.GetFileList(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.MountSpecPath)
47-
if err != nil {
48-
return nil, fmt.Errorf("failed to get list of CSV files: %v", err)
49-
}
50-
51-
if f.image.Getenv(image.EnvVarNvidiaRequireJetpack) != "csv-mounts=all" {
52-
csvFiles = csv.BaseFilesOnly(csvFiles)
53-
}
54-
55-
cdilib, err := nvcdi.New(
56-
nvcdi.WithLogger(f.logger),
57-
nvcdi.WithDriverRoot(f.driver.Root),
58-
nvcdi.WithDevRoot(f.driver.DevRoot),
59-
nvcdi.WithNVIDIACDIHookPath(f.cfg.NVIDIACTKConfig.Path),
60-
nvcdi.WithMode(nvcdi.ModeCSV),
61-
nvcdi.WithCSVFiles(csvFiles),
62-
nvcdi.WithCSVCompatContainerRoot(f.cfg.NVIDIAContainerRuntimeConfig.Modes.CSV.CompatContainerRoot),
63-
)
64-
if err != nil {
65-
return nil, fmt.Errorf("failed to construct CDI library: %v", err)
66-
}
67-
68-
spec, err := cdilib.GetSpec(devices...)
69-
if err != nil {
70-
return nil, fmt.Errorf("failed to get CDI spec: %v", err)
71-
}
72-
73-
return cdi.New(
74-
cdi.WithLogger(f.logger),
75-
cdi.WithSpec(spec.Raw()),
76-
)
43+
return f.newAutomaticCDISpecModifier(devices)
7744
}
7845

7946
func checkRequirements(logger logger.Interface, image *image.CUDA) error {
@@ -107,3 +74,14 @@ func checkRequirements(logger logger.Interface, image *image.CUDA) error {
10774

10875
return r.Assert()
10976
}
77+
78+
// csvDevices is image.CUDA under a distinct named type, which lets it carry
// its own DeviceRequests method.
type csvDevices image.CUDA
79+
80+
// DeviceRequests converts d back to image.CUDA and returns one string of the
// form "mode=csv,id=<deviceID>" for each ID yielded by VisibleDevices, in the
// same order. The result is nil when VisibleDevices yields nothing.
func (d csvDevices) DeviceRequests() []string {
81+
var devices []string
82+
i := (image.CUDA)(d)
83+
for _, deviceID := range i.VisibleDevices() {
84+
devices = append(devices, "mode=csv,id="+deviceID)
85+
}
86+
return devices
87+
}

0 commit comments

Comments
 (0)