Skip to content

Commit 7515925

Browse files
committed
[vfio-manage] load vfio-pci module before binding GPUs to it
Signed-off-by: Christopher Desiniotis <cdesiniotis@nvidia.com>
1 parent 0d30fa9 commit 7515925

File tree

4 files changed

+67
-32
lines changed

4 files changed

+67
-32
lines changed

cmd/vfio-manage/bind.go

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,72 +30,80 @@ import (
3030
type bindCommand struct {
3131
logger *logrus.Logger
3232
nvpciLib nvpci.Interface
33+
options bindOptions
3334
}
3435

3536
type bindOptions struct {
3637
all bool
3738
deviceID string
39+
hostRoot string
3840
}
3941

4042
// newBindCommand constructs a bind command with the specified logger
4143
func newBindCommand(logger *logrus.Logger) *cli.Command {
4244
c := bindCommand{
4345
logger: logger,
44-
nvpciLib: nvpci.New(
45-
nvpci.WithLogger(logger),
46-
),
4746
}
4847
return c.build()
4948
}
5049

5150
// build the bind command
5251
func (m bindCommand) build() *cli.Command {
53-
cfg := bindOptions{}
54-
55-
// Create the 'bind' command
5652
c := cli.Command{
5753
Name: "bind",
5854
Usage: "Bind device(s) to vfio-pci driver",
5955
Before: func(c *cli.Context) error {
60-
return m.validateFlags(&cfg)
56+
return m.validateFlags()
6157
},
6258
Action: func(c *cli.Context) error {
63-
return m.run(&cfg)
59+
return m.run()
6460
},
6561
Flags: []cli.Flag{
6662
&cli.BoolFlag{
6763
Name: "all",
6864
Aliases: []string{"a"},
69-
Destination: &cfg.all,
65+
Destination: &m.options.all,
7066
Usage: "Bind all NVIDIA devices to vfio-pci",
7167
},
7268
&cli.StringFlag{
7369
Name: "device-id",
7470
Aliases: []string{"d"},
75-
Destination: &cfg.deviceID,
71+
Destination: &m.options.deviceID,
7672
Usage: "Specific device ID to bind (e.g., 0000:01:00.0)",
7773
},
74+
&cli.StringFlag{
75+
Name: "host-root",
76+
Destination: &m.options.hostRoot,
77+
EnvVars: []string{"HOST_ROOT"},
78+
Value: "/",
79+
Usage: "Path to the host's root filesystem. This is used when loading the vfio-pci module.",
80+
},
7881
},
7982
}
8083

8184
return &c
8285
}
8386

84-
func (m bindCommand) validateFlags(cfg *bindOptions) error {
85-
if !cfg.all && cfg.deviceID == "" {
87+
func (m bindCommand) validateFlags() error {
88+
if !m.options.all && m.options.deviceID == "" {
8689
return fmt.Errorf("either --all or --device-id must be specified")
8790
}
8891

89-
if cfg.all && cfg.deviceID != "" {
92+
if m.options.all && m.options.deviceID != "" {
9093
return fmt.Errorf("cannot specify both --all and --device-id")
9194
}
9295

9396
return nil
9497
}
9598

96-
func (m bindCommand) run(cfg *bindOptions) error {
97-
if cfg.deviceID != "" {
98-
return m.bindDevice(cfg.deviceID)
99+
func (m bindCommand) run() error {
100+
m.nvpciLib = nvpci.New(
101+
nvpci.WithLogger(m.logger),
102+
nvpci.WithHostRoot(m.options.hostRoot),
103+
)
104+
105+
if m.options.deviceID != "" {
106+
return m.bindDevice()
99107
}
100108

101109
return m.bindAll()
@@ -118,7 +126,8 @@ func (m bindCommand) bindAll() error {
118126
return nil
119127
}
120128

121-
func (m bindCommand) bindDevice(device string) error {
129+
func (m bindCommand) bindDevice() error {
130+
device := m.options.deviceID
122131
nvdev, err := m.nvpciLib.GetGPUByPciBusID(device)
123132
if err != nil {
124133
return fmt.Errorf("failed to get NVIDIA GPU device: %w", err)

cmd/vfio-manage/unbind.go

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
type unbindCommand struct {
3131
logger *logrus.Logger
3232
nvpciLib nvpci.Interface
33+
options unbindOptions
3334
}
3435

3536
type unbindOptions struct {
@@ -50,29 +51,26 @@ func newUnbindCommand(logger *logrus.Logger) *cli.Command {
5051

5152
// build the unbind command
5253
func (m unbindCommand) build() *cli.Command {
53-
cfg := unbindOptions{}
54-
55-
// Create the 'unbind' command
5654
c := cli.Command{
5755
Name: "unbind",
5856
Usage: "Unbind device(s) from their current driver",
5957
Before: func(c *cli.Context) error {
60-
return m.validateFlags(&cfg)
58+
return m.validateFlags()
6159
},
6260
Action: func(c *cli.Context) error {
63-
return m.run(&cfg)
61+
return m.run()
6462
},
6563
Flags: []cli.Flag{
6664
&cli.BoolFlag{
6765
Name: "all",
6866
Aliases: []string{"a"},
69-
Destination: &cfg.all,
67+
Destination: &m.options.all,
7068
Usage: "Bind all NVIDIA devices to vfio-pci",
7169
},
7270
&cli.StringFlag{
7371
Name: "device-id",
7472
Aliases: []string{"d"},
75-
Destination: &cfg.deviceID,
73+
Destination: &m.options.deviceID,
7674
Usage: "Specific device ID to bind (e.g., 0000:01:00.0)",
7775
},
7876
},
@@ -81,21 +79,21 @@ func (m unbindCommand) build() *cli.Command {
8179
return &c
8280
}
8381

84-
func (m unbindCommand) validateFlags(cfg *unbindOptions) error {
85-
if !cfg.all && cfg.deviceID == "" {
82+
func (m unbindCommand) validateFlags() error {
83+
if !m.options.all && m.options.deviceID == "" {
8684
return fmt.Errorf("either --all or --device-id must be specified")
8785
}
8886

89-
if cfg.all && cfg.deviceID != "" {
87+
if m.options.all && m.options.deviceID != "" {
9088
return fmt.Errorf("cannot specify both --all and --device-id")
9189
}
9290

9391
return nil
9492
}
9593

96-
func (m unbindCommand) run(cfg *unbindOptions) error {
97-
if cfg.deviceID != "" {
98-
return m.unbindDevice(cfg.deviceID)
94+
func (m unbindCommand) run() error {
95+
if m.options.deviceID != "" {
96+
return m.unbindDevice()
9997
}
10098

10199
return m.unbindAll()
@@ -117,7 +115,8 @@ func (m unbindCommand) unbindAll() error {
117115
return nil
118116
}
119117

120-
func (m unbindCommand) unbindDevice(device string) error {
118+
func (m unbindCommand) unbindDevice() error {
119+
device := m.options.deviceID
121120
nvdev, err := m.nvpciLib.GetGPUByPciBusID(device)
122121
if err != nil {
123122
return fmt.Errorf("failed to get NVIDIA GPU device: %w", err)

internal/linuxutils/kmod.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"bufio"
2121
"fmt"
2222
"os"
23+
"os/exec"
2324
"path/filepath"
2425
"strconv"
2526
"strings"
@@ -44,6 +45,9 @@ func NewKernelModules(log *logrus.Logger, options ...func(modules *KernelModules
4445
for _, option := range options {
4546
option(km)
4647
}
48+
if km.root == "" {
49+
km.root = "/"
50+
}
4751
return km
4852
}
4953

@@ -105,3 +109,8 @@ func (km *KernelModules) List(searchKey string) error {
105109
}
106110
return nil
107111
}
112+
113+
func (km *KernelModules) Load(module string) error {
114+
cmd := exec.Command("chroot", km.root, "modprobe", module)
115+
return cmd.Run()
116+
}

internal/nvpci/nvpci.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ import (
2424

2525
"github.com/NVIDIA/go-nvlib/pkg/nvpci"
2626
"github.com/sirupsen/logrus"
27+
28+
"github.com/NVIDIA/k8s-driver-manager/internal/linuxutils"
2729
)
2830

2931
const (
@@ -43,7 +45,8 @@ type Interface interface {
4345

4446
type nvpciWrapper struct {
4547
nvpci.Interface
46-
logger *logrus.Logger
48+
logger *logrus.Logger
49+
hostRoot string
4750
}
4851

4952
type nvidiaPCIDevice struct {
@@ -64,6 +67,9 @@ func New(opts ...Option) Interface {
6467
if n.logger == nil {
6568
n.logger = logrus.New()
6669
}
70+
if n.hostRoot == "" {
71+
n.hostRoot = "/"
72+
}
6773

6874
// (cdesiniotis) Create an identical logger for the underlying nvpci library,
6975
// with the exception being the log level. Currently, the nvpci library
@@ -92,6 +98,13 @@ func WithLogger(logger *logrus.Logger) Option {
9298
}
9399
}
94100

101+
// WithHostRoot provides an Option to set the path to the host root filesystem
102+
func WithHostRoot(hostRoot string) Option {
103+
return func(w *nvpciWrapper) {
104+
w.hostRoot = hostRoot
105+
}
106+
}
107+
95108
// (cdesiniotis) ideally this method would be attached to the nvcpi.NvidiaPCIDevice struct
96109
// which removes the need for this wrapper
97110
func (w *nvpciWrapper) BindToVFIODriver(dev *nvpci.NvidiaPCIDevice) error {
@@ -112,6 +125,11 @@ func (w *nvpciWrapper) bindToVFIODriver(device *nvidiaPCIDevice) error {
112125
return fmt.Errorf("failed to find best vfio variant driver: %w", err)
113126
}
114127

128+
km := linuxutils.NewKernelModules(w.logger, linuxutils.WithRoot(w.hostRoot))
129+
if err := km.Load(vfioDriverName); err != nil {
130+
return fmt.Errorf("failed to load %q driver: %w", vfioDriverName, err)
131+
}
132+
115133
// (cdesiniotis) Module names in the modules.alias file will only ever contain
116134
// underscores characters and not dashes -- this aligns with how the linux kernel
117135
// stores module names internally. This can sometimes differ from the name of the

0 commit comments

Comments
 (0)