@@ -21,8 +21,6 @@ import (
2121
2222 "github.com/NVIDIA/nvidia-container-toolkit/api/config/v1"
2323 "github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
24- "github.com/NVIDIA/nvidia-container-toolkit/internal/logger"
25- "github.com/NVIDIA/nvidia-container-toolkit/internal/lookup/root"
2624 "github.com/NVIDIA/nvidia-container-toolkit/internal/oci"
2725)
2826
@@ -41,84 +39,59 @@ func (f *Factory) newFeatureGatedModifier() (oci.SpecModifier, error) {
4139 return nil , nil
4240 }
4341
44- var discoverers []discover.Discover
45-
46- if f .image .Getenv ("NVIDIA_GDS" ) == "enabled" {
47- d , err := discover .NewGDSDiscoverer (f .logger , f .driver )
48- if err != nil {
49- return nil , fmt .Errorf ("failed to construct discoverer for GDS devices: %w" , err )
50- }
51- discoverers = append (discoverers , d )
52- }
53-
54- if f .image .Getenv ("NVIDIA_MOFED" ) == "enabled" {
55- d , err := discover .NewMOFEDDiscoverer (f .logger , f .driver )
56- if err != nil {
57- return nil , fmt .Errorf ("failed to construct discoverer for MOFED devices: %w" , err )
58- }
59- discoverers = append (discoverers , d )
60- }
61-
62- if f .image .Getenv ("NVIDIA_NVSWITCH" ) == "enabled" {
63- d , err := discover .NewNvSwitchDiscoverer (f .logger , f .driver )
64- if err != nil {
65- return nil , fmt .Errorf ("failed to construct discoverer for NVSWITCH devices: %w" , err )
66- }
67- discoverers = append (discoverers , d )
68- }
69-
70- if f .image .Getenv ("NVIDIA_GDRCOPY" ) == "enabled" {
71- d , err := discover .NewGDRCopyDiscoverer (f .logger , f .driver )
42+ var modifers list
43+ if gatedDeviceRequests := withUniqueDevices (gatedDevices (* f .image )).DeviceRequests (); len (gatedDeviceRequests ) != 0 {
44+ featureGatedModifier , err := f .newAutomaticCDISpecModifier (gatedDeviceRequests )
7245 if err != nil {
73- return nil , fmt . Errorf ( "failed to construct discoverer for GDRCopy devices: %w" , err )
46+ return nil , err
7447 }
75- discoverers = append (discoverers , d )
48+ modifers = append (modifers , featureGatedModifier )
7649 }
7750
7851 // If the feature flag has explicitly been toggled, we don't make any modification.
7952 if ! f .cfg .Features .DisableCUDACompatLibHook .IsEnabled () {
80- cudaCompatDiscoverer , err := getCudaCompatModeDiscoverer ( f . logger , f . cfg , f . driver , f . hookCreator )
53+ cudaCompatModifer , err := f . getCudaCompatModeModifier ( )
8154 if err != nil {
8255 return nil , fmt .Errorf ("failed to construct CUDA Compat discoverer: %w" , err )
8356 }
84- discoverers = append (discoverers , cudaCompatDiscoverer )
57+ modifers = append (modifers , cudaCompatModifer )
8558 }
8659
87- return f . newModifierFromDiscoverer ( discover . Merge ( discoverers ... ))
60+ return modifers , nil
8861}
8962
90- func getCudaCompatModeDiscoverer ( logger logger. Interface , cfg * config. Config , driver * root. Driver , hookCreator discover. HookCreator ) (discover. Discover , error ) {
63+ func ( f * Factory ) getCudaCompatModeModifier ( ) (oci. SpecModifier , error ) {
9164 // We don't support the enable-cuda-compat hook in CSV mode.
92- if cfg .NVIDIAContainerRuntimeConfig .Mode == "csv" {
65+ if f . cfg .NVIDIAContainerRuntimeConfig .Mode == "csv" {
9366 return nil , nil
9467 }
9568
9669 // For legacy mode, we only include the enable-cuda-compat hook if cuda-compat-mode is set to hook.
97- if cfg .NVIDIAContainerRuntimeConfig .Mode == "legacy" && cfg .NVIDIAContainerRuntimeConfig .Modes .Legacy .CUDACompatMode != config .CUDACompatModeHook {
70+ if f . cfg .NVIDIAContainerRuntimeConfig .Mode == "legacy" && f . cfg .NVIDIAContainerRuntimeConfig .Modes .Legacy .CUDACompatMode != config .CUDACompatModeHook {
9871 return nil , nil
9972 }
10073
101- version , err := driver .Version ()
74+ version , err := f . driver .Version ()
10275 if err != nil {
10376 return nil , fmt .Errorf ("failed to get driver version: %w" , err )
10477 }
10578
106- compatLibHookDiscoverer := discover .NewCUDACompatHookDiscoverer (logger , hookCreator , & discover.EnableCUDACompatHookOptions {HostDriverVersion : version })
79+ compatLibHookDiscoverer := discover .NewCUDACompatHookDiscoverer (f . logger , f . hookCreator , & discover.EnableCUDACompatHookOptions {HostDriverVersion : version })
10780 // For non-legacy modes we return the hook as is. These modes *should* already include the update-ldcache hook.
108- if cfg .NVIDIAContainerRuntimeConfig .Mode != "legacy" {
109- return compatLibHookDiscoverer , nil
81+ if f . cfg .NVIDIAContainerRuntimeConfig .Mode != "legacy" {
82+ return f . newModifierFromDiscoverer ( compatLibHookDiscoverer )
11083 }
11184
11285 // For legacy mode, we also need to inject a hook to update the LDCache
11386 // after we have modifed the configuration.
11487 ldcacheUpdateHookDiscoverer , err := discover .NewLDCacheUpdateHook (
115- logger ,
88+ f . logger ,
11689 discover.None {},
117- hookCreator ,
90+ f . hookCreator ,
11891 )
11992 if err != nil {
12093 return nil , fmt .Errorf ("failed to construct ldcache update discoverer: %w" , err )
12194 }
12295
123- return discover .Merge (compatLibHookDiscoverer , ldcacheUpdateHookDiscoverer ), nil
96+ return f . newModifierFromDiscoverer ( discover .Merge (compatLibHookDiscoverer , ldcacheUpdateHookDiscoverer ))
12497}
0 commit comments