diff --git a/cmd/main.go b/cmd/main.go index eb260fde..a1c77fa7 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -35,5 +35,6 @@ import ( func main() { device_plugin.PGPUAlias = os.Getenv("P_GPU_ALIAS") + device_plugin.NVSwitchAlias = os.Getenv("NVSWITCH_ALIAS") device_plugin.InitiateDevicePlugin() } diff --git a/go.sum b/go.sum index 011a060a..84e298c6 100644 --- a/go.sum +++ b/go.sum @@ -293,10 +293,28 @@ github.com/onsi/ginkgo/v2 v2.22.2 h1:/3X8Panh8/WwhU/3Ssa6rCKqPLuAkVY2I0RoyDLySlU github.com/onsi/ginkgo/v2 v2.22.2/go.mod h1:oeMosUL+8LtarXBHu/c0bx2D/K9zyQ6uX3cTyztHwsk= github.com/onsi/gomega v1.36.2 h1:koNYke6TVk6ZmnyHrCXba/T/MoLBXFjeC1PtvYgw0A8= github.com/onsi/gomega v1.36.2/go.mod h1:DdwyADRjrc825LhMEkD76cHR5+pUnjhUN8GlHlRPHzY= +github.com/opencontainers/runtime-spec v1.3.0 h1:YZupQUdctfhpZy3TM39nN9Ika5CBWT5diQ8ibYCRkxg= +github.com/opencontainers/runtime-spec v1.3.0/go.mod h1:jwyrGlmzljRJv/Fgzds9SsS/C5hL+LL3ko9hs6T5lQ0= +github.com/opencontainers/runtime-tools v0.9.1-0.20251114084447-edf4cb3d2116 h1:tAKu3NkKWZYpqBSOJKwTxT1wIGueiF7gcmcNgr5pNTY= +github.com/opencontainers/runtime-tools v0.9.1-0.20251114084447-edf4cb3d2116/go.mod h1:DKDEfzxvRkoQ6n9TGhxQgg2IM1lY4aM0eaQP4e3oElw= +github.com/opencontainers/selinux v1.10.0 h1:rAiKF8hTcgLI3w0DHm6i0ylVVcOrlgR1kK99DRLDhyU= +github.com/opencontainers/selinux v1.10.0/go.mod h1:2i0OySw99QjzBBQByd1Gr9gSjvuho1lHsJxIJ3gGbJI= +github.com/otiai10/copy v1.2.0/go.mod h1:rrF5dJ5F0t/EWSYODDu4j9/vEeYHMkc8jt0zJChqQWw= +github.com/otiai10/copy v1.14.0 h1:dCI/t1iTdYGtkvCuBG2BgR6KZa83PTclw4U5n2wAllU= +github.com/otiai10/copy v1.14.0/go.mod h1:ECfuL02W+/FkTWZWgQqXPWZgW9oeKCSQ5qVfSc4qc4w= +github.com/otiai10/curr v0.0.0-20150429015615-9b4961190c95/go.mod h1:9qAhocn7zKJG+0mI8eUu6xqkFDYS2kb2saOteoSB3cE= +github.com/otiai10/curr v1.0.0/go.mod h1:LskTG5wDwr8Rs+nNQ+1LlxRjAtTZZjtJW4rMXl6j4vs= +github.com/otiai10/mint v1.3.0/go.mod h1:F5AjcsTsWUqX+Na9fpHb52P8pcRX2CI6A3ctIT91xUo= +github.com/otiai10/mint v1.3.1/go.mod h1:/yxELlJQ0ufhjUwhshSj+wFjZ78CnZ48/1wtmBH1OTc= +github.com/pelletier/go-toml v1.9.5 h1:4yBQzkHv+7BHq2PQUZF3Mx0IYxG7LsP222s7Agd3ve8= +github.com/pelletier/go-toml v1.9.5/go.mod h1:u1nR/EPcESfeI/szUZKdtJ0xRNbUoANCkoOuaOx1Y+c= +github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= +github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/polyfloyd/go-errorlint v1.7.1 h1:RyLVXIbosq1gBdk/pChWA8zWYLsq9UEw7a1L5TVMCnA= github.com/polyfloyd/go-errorlint v1.7.1/go.mod h1:aXjNb1x2TNhoLsk26iv1yl7a+zTnXPhwEMtEXukiLR8= github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= diff --git a/pkg/device_plugin/cdi.go b/pkg/device_plugin/cdi.go index 9c8b21a4..2633dc0d 100644 --- a/pkg/device_plugin/cdi.go +++ b/pkg/device_plugin/cdi.go @@ -68,7 +68,11 @@ func GenerateCDISpec() error { // Generate NVSwitch CDI spec if we have NVSwitch devices if len(nvSwitchDeviceIDs) > 0 { - if err := generateCDISpecForClass(cdiNVSwitchClass, true); err != nil { + nvSwitchClass := cdiNVSwitchClass + if NVSwitchAlias != "" { + nvSwitchClass = NVSwitchAlias + } + if err := generateCDISpecForClass(nvSwitchClass, true); err != nil { return fmt.Errorf("failed to generate NVSwitch CDI spec: %w", err) } } diff --git a/pkg/device_plugin/device_plugin.go b/pkg/device_plugin/device_plugin.go index 16f6df1c..4313cae4 100644 --- a/pkg/device_plugin/device_plugin.go +++ b/pkg/device_plugin/device_plugin.go @@ -64,6 +64,7 @@ var nvpciLib nvpci.Interface var startDevicePlugin = startDevicePluginFunc var stop = make(chan struct{}) var PGPUAlias string +var NVSwitchAlias string func InitiateDevicePlugin() { // Initialize nvpci library if not already set (allows injection for testing) @@ -99,14 +100,15 @@ func createDevicePlugins() { }) } - // Determine device name - NVSwitches always use their actual name, - // GPUs can use PGPUAlias if set + // Determine device name - use alias if set, otherwise use actual device name var deviceName string if isNVSwitchDeviceID(deviceID) { - // NVSwitches always use their actual device name - deviceName = getDeviceNameForID(deviceID) + if NVSwitchAlias != "" { + deviceName = NVSwitchAlias + } else { + deviceName = getDeviceNameForID(deviceID) + } } else if PGPUAlias != "" { - // GPUs can use the alias deviceName = PGPUAlias } else { deviceName = getDeviceNameForID(deviceID) diff --git a/pkg/device_plugin/device_plugin_test.go b/pkg/device_plugin/device_plugin_test.go index b20f65f1..c012ded5 100644 --- a/pkg/device_plugin/device_plugin_test.go +++ b/pkg/device_plugin/device_plugin_test.go @@ -286,5 +286,26 @@ var _ = Describe("Device Plugin", func() { result := getDeviceNameForID("abcd") Expect(result).To(Equal("")) }) + + It("formats device names with special characters through formatDeviceName", func() { + iommuMap = map[string][]NvidiaPCIDevice{ + "1": { + { + Address: "0000:01:00.0", + DeviceID: 0x2330, + DeviceName: "NVIDIA H100 PCIe [Hopper]", + IommuGroup: 1, + }, + }, + } + result := getDeviceNameForID("2330") + Expect(result).To(Equal("NVIDIA_H100_PCIE_HOPPER")) + }) + + It("returns empty string when iommuMap is empty", func() { + iommuMap = map[string][]NvidiaPCIDevice{} + result := getDeviceNameForID("1b80") + Expect(result).To(Equal("")) + }) }) }) diff --git a/vendor/modules.txt b/vendor/modules.txt index fa60a4a0..4591e53d 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -643,6 +643,14 @@ github.com/onsi/gomega/matchers/support/goraph/edge github.com/onsi/gomega/matchers/support/goraph/node github.com/onsi/gomega/matchers/support/goraph/util github.com/onsi/gomega/types +# github.com/opencontainers/runtime-spec v1.3.0 +## explicit +github.com/opencontainers/runtime-spec/specs-go +# github.com/opencontainers/runtime-tools v0.9.1-0.20251114084447-edf4cb3d2116 +## explicit; go 1.21 +github.com/opencontainers/runtime-tools/generate +github.com/opencontainers/runtime-tools/generate/seccomp +github.com/opencontainers/runtime-tools/validate/capabilities # github.com/pelletier/go-toml v1.9.5 ## explicit; go 1.12 github.com/pelletier/go-toml