Skip to content

Commit 517859c

Browse files
cdesiniotiselezar
authored andcommitted
Add 'vfio' mode to pkg/nvcdi for generating CDI specs for NVIDIA passthrough GPUs
Signed-off-by: Christopher Desiniotis <[email protected]>
1 parent 4091718 commit 517859c

File tree

5 files changed

+267
-0
lines changed

5 files changed

+267
-0
lines changed

pkg/nvcdi/lib-vfio.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/**
2+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
**/
17+
18+
package nvcdi
19+
20+
import (
21+
"fmt"
22+
"path/filepath"
23+
"strconv"
24+
25+
"tags.cncf.io/container-device-interface/pkg/cdi"
26+
"tags.cncf.io/container-device-interface/specs-go"
27+
)
28+
29+
type vfiolib nvcdilib
30+
31+
type vfioDevice struct {
32+
index int
33+
group int
34+
devRoot string
35+
}
36+
37+
var _ deviceSpecGeneratorFactory = (*vfiolib)(nil)
38+
39+
func (l *vfiolib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) {
40+
vfioDevices, err := l.getVfioDevices(ids...)
41+
if err != nil {
42+
return nil, err
43+
}
44+
var deviceSpecGenerators DeviceSpecGenerators
45+
for _, vfioDevice := range vfioDevices {
46+
deviceSpecGenerators = append(deviceSpecGenerators, vfioDevice)
47+
}
48+
49+
return deviceSpecGenerators, nil
50+
}
51+
52+
// GetDeviceSpecs returns the CDI device specs for a vfio device.
53+
func (l *vfioDevice) GetDeviceSpecs() ([]specs.Device, error) {
54+
path := fmt.Sprintf("/dev/vfio/%d", l.group)
55+
deviceSpec := specs.Device{
56+
Name: fmt.Sprintf("%d", l.index),
57+
ContainerEdits: specs.ContainerEdits{
58+
DeviceNodes: []*specs.DeviceNode{
59+
{
60+
Path: path,
61+
HostPath: filepath.Join(l.devRoot, path),
62+
},
63+
},
64+
},
65+
}
66+
return []specs.Device{deviceSpec}, nil
67+
}
68+
69+
// GetCommonEdits returns common edits for ALL devices.
70+
// Note, currently there are no common edits.
71+
func (l *vfiolib) GetCommonEdits() (*cdi.ContainerEdits, error) {
72+
e := cdi.ContainerEdits{
73+
ContainerEdits: &specs.ContainerEdits{
74+
DeviceNodes: []*specs.DeviceNode{
75+
{
76+
Path: "/dev/vfio/vfio",
77+
HostPath: filepath.Join(l.devRoot, "/dev/vfio/vfio"),
78+
},
79+
},
80+
},
81+
}
82+
return &e, nil
83+
}
84+
85+
func (l *vfiolib) getVfioDevices(ids ...string) ([]*vfioDevice, error) {
86+
var vfioDevices []*vfioDevice
87+
for _, id := range ids {
88+
if id == "all" {
89+
return l.getAllVfioDevices()
90+
}
91+
index, err := strconv.ParseInt(id, 10, 32)
92+
if err != nil {
93+
return nil, fmt.Errorf("invalid channel ID %v: %w", id, err)
94+
}
95+
i := int(index)
96+
dev, err := l.nvpcilib.GetGPUByIndex(i)
97+
if err != nil {
98+
return nil, fmt.Errorf("failed to get device: %w", err)
99+
}
100+
vfioDevices = append(vfioDevices, &vfioDevice{index: i, group: dev.IommuGroup, devRoot: l.devRoot})
101+
}
102+
103+
return vfioDevices, nil
104+
}
105+
106+
func (l *vfiolib) getAllVfioDevices() ([]*vfioDevice, error) {
107+
devices, err := l.nvpcilib.GetGPUs()
108+
if err != nil {
109+
return nil, fmt.Errorf("failed getting NVIDIA GPUs: %v", err)
110+
}
111+
112+
var vfioDevices []*vfioDevice
113+
for i, dev := range devices {
114+
if dev.Driver != "vfio-pci" {
115+
continue
116+
}
117+
l.logger.Debugf("Found NVIDIA device: address=%s, driver=%s, iommu_group=%d, deviceId=%x",
118+
dev.Address, dev.Driver, dev.IommuGroup, dev.Device)
119+
vfioDevices = append(vfioDevices, &vfioDevice{index: i, group: dev.IommuGroup, devRoot: l.devRoot})
120+
}
121+
return vfioDevices, nil
122+
}

pkg/nvcdi/lib-vfio_test.go

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/**
2+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
**/
17+
18+
package nvcdi
19+
20+
import (
21+
"bytes"
22+
"testing"
23+
24+
"github.com/NVIDIA/go-nvlib/pkg/nvpci"
25+
"github.com/stretchr/testify/require"
26+
)
27+
28+
func TestModeVfio(t *testing.T) {
29+
testCases := []struct {
30+
description string
31+
pcilib *nvpci.InterfaceMock
32+
ids []string
33+
expectedError error
34+
expectedSpec string
35+
}{
36+
{
37+
description: "get all specs single device",
38+
pcilib: &nvpci.InterfaceMock{
39+
GetGPUsFunc: func() ([]*nvpci.NvidiaPCIDevice, error) {
40+
devices := []*nvpci.NvidiaPCIDevice{
41+
{
42+
Driver: "vfio-pci",
43+
IommuGroup: 5,
44+
},
45+
}
46+
return devices, nil
47+
},
48+
},
49+
expectedSpec: `---
50+
cdiVersion: 0.5.0
51+
kind: nvidia.com/pgpu
52+
devices:
53+
- name: "0"
54+
containerEdits:
55+
deviceNodes:
56+
- path: /dev/vfio/5
57+
hostPath: /dev/vfio/5
58+
containerEdits:
59+
env:
60+
- NVIDIA_VISIBLE_DEVICES=void
61+
deviceNodes:
62+
- path: /dev/vfio/vfio
63+
hostPath: /dev/vfio/vfio
64+
`,
65+
},
66+
{
67+
description: "get single device spec by index",
68+
pcilib: &nvpci.InterfaceMock{
69+
GetGPUByIndexFunc: func(n int) (*nvpci.NvidiaPCIDevice, error) {
70+
devices := []*nvpci.NvidiaPCIDevice{
71+
{
72+
Driver: "vfio-pci",
73+
IommuGroup: 45,
74+
},
75+
{
76+
Driver: "vfio-pci",
77+
IommuGroup: 5,
78+
},
79+
}
80+
return devices[n], nil
81+
},
82+
},
83+
ids: []string{"1"},
84+
expectedSpec: `---
85+
cdiVersion: 0.5.0
86+
kind: nvidia.com/pgpu
87+
devices:
88+
- name: "1"
89+
containerEdits:
90+
deviceNodes:
91+
- path: /dev/vfio/5
92+
hostPath: /dev/vfio/5
93+
containerEdits:
94+
env:
95+
- NVIDIA_VISIBLE_DEVICES=void
96+
deviceNodes:
97+
- path: /dev/vfio/vfio
98+
hostPath: /dev/vfio/vfio
99+
`,
100+
},
101+
}
102+
103+
for _, tc := range testCases {
104+
t.Run(tc.description, func(t *testing.T) {
105+
lib, err := New(
106+
WithMode(ModeVfio),
107+
WithPCILib(tc.pcilib),
108+
)
109+
require.NoError(t, err)
110+
111+
spec, err := lib.GetSpec(tc.ids...)
112+
require.EqualValues(t, tc.expectedError, err)
113+
114+
var output bytes.Buffer
115+
116+
_, err = spec.WriteTo(&output)
117+
require.NoError(t, err)
118+
119+
require.Equal(t, tc.expectedSpec, output.String())
120+
})
121+
}
122+
123+
}

pkg/nvcdi/lib.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121

2222
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
2323
"github.com/NVIDIA/go-nvlib/pkg/nvlib/info"
24+
"github.com/NVIDIA/go-nvlib/pkg/nvpci"
2425
"github.com/NVIDIA/go-nvml/pkg/nvml"
2526

2627
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
@@ -54,6 +55,8 @@ type nvcdilib struct {
5455
driver *root.Driver
5556
infolib info.Interface
5657

58+
nvpcilib nvpci.Interface
59+
5760
mergedDeviceOptions []transform.MergedDeviceOption
5861

5962
featureFlags map[FeatureFlag]bool
@@ -140,6 +143,14 @@ func New(opts ...Option) (Interface, error) {
140143
l.class = classImexChannel
141144
}
142145
factory = (*imexlib)(l)
146+
case ModeVfio:
147+
if l.class == "" {
148+
l.class = "pgpu"
149+
}
150+
if l.nvpcilib == nil {
151+
l.nvpcilib = nvpci.New()
152+
}
153+
factory = (*vfiolib)(l)
143154
default:
144155
return nil, fmt.Errorf("unknown mode %q", l.mode)
145156
}

pkg/nvcdi/mode.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ const (
4646
ModeImex = Mode("imex")
4747
// ModeNvswitch configures the CDI spec generator to generate a spec for the available nvswitch devices.
4848
ModeNvswitch = Mode("nvswitch")
49+
// ModeVfio configures the CDI spec generator to generate a VFIO spec.
50+
ModeVfio = Mode("vfio")
4951
)
5052

5153
type modeConstraint interface {
@@ -72,6 +74,7 @@ func getModes() modes {
7274
ModeMofed,
7375
ModeNvml,
7476
ModeNvswitch,
77+
ModeVfio,
7578
ModeWsl,
7679
}
7780
lookup := make(map[Mode]bool)

pkg/nvcdi/options.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package nvcdi
1919
import (
2020
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
2121
"github.com/NVIDIA/go-nvlib/pkg/nvlib/info"
22+
"github.com/NVIDIA/go-nvlib/pkg/nvpci"
2223
"github.com/NVIDIA/go-nvml/pkg/nvml"
2324

2425
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
@@ -43,6 +44,13 @@ func WithInfoLib(infolib info.Interface) Option {
4344
}
4445
}
4546

47+
// WithPCILib sets the PCI library to be used for CDI spec generation.
48+
func WithPCILib(pcilib nvpci.Interface) Option {
49+
return func(l *nvcdilib) {
50+
l.nvpcilib = pcilib
51+
}
52+
}
53+
4654
// WithDeviceNamers sets the device namer for the library
4755
func WithDeviceNamers(namers ...DeviceNamer) Option {
4856
return func(l *nvcdilib) {

0 commit comments

Comments
 (0)