Skip to content

Commit af7b79b

Browse files
committed
Use jit-cdi modifier for gated modifiers
Signed-off-by: Evan Lezar <elezar@nvidia.com>
1 parent 78713e7 commit af7b79b

File tree

1 file changed

+18
-45
lines changed

1 file changed

+18
-45
lines changed

internal/modifier/gated.go

Lines changed: 18 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ import (
2121

2222
"github.com/NVIDIA/nvidia-container-toolkit/api/config/v1"
2323
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
24-
"github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
25-
"github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
2624
"github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
2725
)
2826

@@ -41,84 +39,59 @@ func (f *Factory) newFeatureGatedModifier() (oci.SpecModifier, error) {
4139
return nil, nil
4240
}
4341

44-
var discoverers []discover.Discover
45-
46-
if f.image.Getenv("NVIDIA_GDS") == "enabled" {
47-
d, err := discover.NewGDSDiscoverer(f.logger, f.driver)
48-
if err != nil {
49-
return nil, fmt.Errorf("failed to construct discoverer for GDS devices: %w", err)
50-
}
51-
discoverers = append(discoverers, d)
52-
}
53-
54-
if f.image.Getenv("NVIDIA_MOFED") == "enabled" {
55-
d, err := discover.NewMOFEDDiscoverer(f.logger, f.driver)
56-
if err != nil {
57-
return nil, fmt.Errorf("failed to construct discoverer for MOFED devices: %w", err)
58-
}
59-
discoverers = append(discoverers, d)
60-
}
61-
62-
if f.image.Getenv("NVIDIA_NVSWITCH") == "enabled" {
63-
d, err := discover.NewNvSwitchDiscoverer(f.logger, f.driver)
64-
if err != nil {
65-
return nil, fmt.Errorf("failed to construct discoverer for NVSWITCH devices: %w", err)
66-
}
67-
discoverers = append(discoverers, d)
68-
}
69-
70-
if f.image.Getenv("NVIDIA_GDRCOPY") == "enabled" {
71-
d, err := discover.NewGDRCopyDiscoverer(f.logger, f.driver)
42+
var modifers list
43+
if gatedDeviceRequests := withUniqueDevices(gatedDevices(*f.image)).DeviceRequests(); len(gatedDeviceRequests) != 0 {
44+
featureGatedModifier, err := f.newAutomaticCDISpecModifier(gatedDeviceRequests)
7245
if err != nil {
73-
return nil, fmt.Errorf("failed to construct discoverer for GDRCopy devices: %w", err)
46+
return nil, err
7447
}
75-
discoverers = append(discoverers, d)
48+
modifers = append(modifers, featureGatedModifier)
7649
}
7750

7851
// If the feature flag has explicitly been toggled, we don't make any modification.
7952
if !f.cfg.Features.DisableCUDACompatLibHook.IsEnabled() {
80-
cudaCompatDiscoverer, err := getCudaCompatModeDiscoverer(f.logger, f.cfg, f.driver, f.hookCreator)
53+
cudaCompatModifer, err := f.getCudaCompatModeModifier()
8154
if err != nil {
8255
return nil, fmt.Errorf("failed to construct CUDA Compat discoverer: %w", err)
8356
}
84-
discoverers = append(discoverers, cudaCompatDiscoverer)
57+
modifers = append(modifers, cudaCompatModifer)
8558
}
8659

87-
return f.newModifierFromDiscoverer(discover.Merge(discoverers...))
60+
return modifers, nil
8861
}
8962

90-
func getCudaCompatModeDiscoverer(logger logger.Interface, cfg *config.Config, driver *root.Driver, hookCreator discover.HookCreator) (discover.Discover, error) {
63+
func (f *Factory) getCudaCompatModeModifier() (oci.SpecModifier, error) {
9164
// We don't support the enable-cuda-compat hook in CSV mode.
92-
if cfg.NVIDIAContainerRuntimeConfig.Mode == "csv" {
65+
if f.cfg.NVIDIAContainerRuntimeConfig.Mode == "csv" {
9366
return nil, nil
9467
}
9568

9669
// For legacy mode, we only include the enable-cuda-compat hook if cuda-compat-mode is set to hook.
97-
if cfg.NVIDIAContainerRuntimeConfig.Mode == "legacy" && cfg.NVIDIAContainerRuntimeConfig.Modes.Legacy.CUDACompatMode != config.CUDACompatModeHook {
70+
if f.cfg.NVIDIAContainerRuntimeConfig.Mode == "legacy" && f.cfg.NVIDIAContainerRuntimeConfig.Modes.Legacy.CUDACompatMode != config.CUDACompatModeHook {
9871
return nil, nil
9972
}
10073

101-
version, err := driver.Version()
74+
version, err := f.driver.Version()
10275
if err != nil {
10376
return nil, fmt.Errorf("failed to get driver version: %w", err)
10477
}
10578

106-
compatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(logger, hookCreator, &discover.EnableCUDACompatHookOptions{HostDriverVersion: version})
79+
compatLibHookDiscoverer := discover.NewCUDACompatHookDiscoverer(f.logger, f.hookCreator, &discover.EnableCUDACompatHookOptions{HostDriverVersion: version})
10780
// For non-legacy modes we return the hook as is. These modes *should* already include the update-ldcache hook.
108-
if cfg.NVIDIAContainerRuntimeConfig.Mode != "legacy" {
109-
return compatLibHookDiscoverer, nil
81+
if f.cfg.NVIDIAContainerRuntimeConfig.Mode != "legacy" {
82+
return f.newModifierFromDiscoverer(compatLibHookDiscoverer)
11083
}
11184

11285
// For legacy mode, we also need to inject a hook to update the LDCache
11386
// after we have modifed the configuration.
11487
ldcacheUpdateHookDiscoverer, err := discover.NewLDCacheUpdateHook(
115-
logger,
88+
f.logger,
11689
discover.None{},
117-
hookCreator,
90+
f.hookCreator,
11891
)
11992
if err != nil {
12093
return nil, fmt.Errorf("failed to construct ldcache update discoverer: %w", err)
12194
}
12295

123-
return discover.Merge(compatLibHookDiscoverer, ldcacheUpdateHookDiscoverer), nil
96+
return f.newModifierFromDiscoverer(discover.Merge(compatLibHookDiscoverer, ldcacheUpdateHookDiscoverer))
12497
}

0 commit comments

Comments
 (0)