Skip to content

Commit fd043d8

Browse files
authored
Merge pull request #155 from NVIDIA/refactor-vfio-code
Move internal/nvpci to internal/nvpassthrough
2 parents 32de8a3 + c3374ca commit fd043d8

File tree

5 files changed

+166
-168
lines changed

5 files changed

+166
-168
lines changed

cmd/vfio-manage/bind.go

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,18 @@ package main
2121
import (
2222
"fmt"
2323

24+
"github.com/NVIDIA/go-nvlib/pkg/nvpci"
2425
"github.com/sirupsen/logrus"
2526
"github.com/urfave/cli/v2"
2627

27-
"github.com/NVIDIA/k8s-driver-manager/internal/nvpci"
28+
"github.com/NVIDIA/k8s-driver-manager/internal/nvpassthrough"
2829
)
2930

3031
type bindCommand struct {
31-
logger *logrus.Logger
32-
nvpciLib nvpci.Interface
33-
options bindOptions
32+
logger *logrus.Logger
33+
nvpci nvpci.Interface
34+
nvpassthrough nvpassthrough.Interface
35+
options bindOptions
3436
}
3537

3638
type bindOptions struct {
@@ -104,9 +106,13 @@ func (m bindCommand) validateFlags() error {
104106
}
105107

106108
func (m bindCommand) run() error {
107-
m.nvpciLib = nvpci.New(
109+
m.nvpci = nvpci.New(
108110
nvpci.WithLogger(m.logger),
109-
nvpci.WithHostRoot(m.options.hostRoot),
111+
)
112+
113+
m.nvpassthrough = nvpassthrough.New(
114+
nvpassthrough.WithLogger(m.logger),
115+
nvpassthrough.WithHostRoot(m.options.hostRoot),
110116
)
111117

112118
if m.options.deviceID != "" {
@@ -117,13 +123,13 @@ func (m bindCommand) run() error {
117123
}
118124

119125
func (m bindCommand) bindAll() error {
120-
devices, err := m.nvpciLib.GetGPUs()
126+
devices, err := m.nvpci.GetGPUs()
121127
if err != nil {
122128
return fmt.Errorf("failed to get NVIDIA GPUs: %w", err)
123129
}
124130

125131
if m.options.bindNVSwitches {
126-
nvswitches, err := m.nvpciLib.GetNVSwitches()
132+
nvswitches, err := m.nvpci.GetNVSwitches()
127133
if err != nil {
128134
return fmt.Errorf("failed to get NVIDIA NVSwitches: %w", err)
129135
}
@@ -132,8 +138,7 @@ func (m bindCommand) bindAll() error {
132138

133139
for _, dev := range devices {
134140
m.logger.Infof("Binding device %s", dev.Address)
135-
// (cdesiniotis) ideally this should be replaced by a call to nvdev.BindToVFIODriver()
136-
if err := m.nvpciLib.BindToVFIODriver(dev); err != nil {
141+
if err := m.nvpassthrough.BindToVFIODriver(dev); err != nil {
137142
m.logger.Warnf("Failed to bind device %s: %v", dev.Address, err)
138143
}
139144
}
@@ -145,7 +150,7 @@ func (m bindCommand) bindDevice() error {
145150
device := m.options.deviceID
146151
// Note: Despite its name, GetGPUByPciBusID returns any NVIDIA PCI device
147152
// (GPU, NVSwitch, etc.) at the specified address, not just GPUs.
148-
nvdev, err := m.nvpciLib.GetGPUByPciBusID(device)
153+
nvdev, err := m.nvpci.GetGPUByPciBusID(device)
149154
if err != nil {
150155
return fmt.Errorf("failed to get NVIDIA device: %w", err)
151156
}
@@ -164,8 +169,7 @@ func (m bindCommand) bindDevice() error {
164169

165170
m.logger.Infof("Binding device %s", device)
166171

167-
// (cdesiniotis) ideally this should be replaced by a call to nvdev.BindToVFIODriver()
168-
if err := m.nvpciLib.BindToVFIODriver(nvdev); err != nil {
172+
if err := m.nvpassthrough.BindToVFIODriver(nvdev); err != nil {
169173
return fmt.Errorf("failed to bind device %s to vfio driver: %w", device, err)
170174
}
171175

cmd/vfio-manage/unbind.go

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,18 @@ package main
2121
import (
2222
"fmt"
2323

24+
"github.com/NVIDIA/go-nvlib/pkg/nvpci"
2425
"github.com/sirupsen/logrus"
2526
"github.com/urfave/cli/v2"
2627

27-
"github.com/NVIDIA/k8s-driver-manager/internal/nvpci"
28+
"github.com/NVIDIA/k8s-driver-manager/internal/nvpassthrough"
2829
)
2930

3031
type unbindCommand struct {
31-
logger *logrus.Logger
32-
nvpciLib nvpci.Interface
33-
options unbindOptions
32+
logger *logrus.Logger
33+
nvpci nvpci.Interface
34+
nvpassthrough nvpassthrough.Interface
35+
options unbindOptions
3436
}
3537

3638
type unbindOptions struct {
@@ -43,9 +45,12 @@ type unbindOptions struct {
4345
func newUnbindCommand(logger *logrus.Logger) *cli.Command {
4446
c := unbindCommand{
4547
logger: logger,
46-
nvpciLib: nvpci.New(
48+
nvpci: nvpci.New(
4749
nvpci.WithLogger(logger),
4850
),
51+
nvpassthrough: nvpassthrough.New(
52+
nvpassthrough.WithLogger(logger),
53+
),
4954
}
5055
return c.build()
5156
}
@@ -107,13 +112,13 @@ func (m unbindCommand) run() error {
107112
}
108113

109114
func (m unbindCommand) unbindAll() error {
110-
devices, err := m.nvpciLib.GetGPUs()
115+
devices, err := m.nvpci.GetGPUs()
111116
if err != nil {
112117
return fmt.Errorf("failed to get NVIDIA GPUs: %w", err)
113118
}
114119

115120
if m.options.unbindNVSwitches {
116-
nvswitches, err := m.nvpciLib.GetNVSwitches()
121+
nvswitches, err := m.nvpci.GetNVSwitches()
117122
if err != nil {
118123
return fmt.Errorf("failed to get NVIDIA NVSwitches: %w", err)
119124
}
@@ -122,8 +127,7 @@ func (m unbindCommand) unbindAll() error {
122127

123128
for _, dev := range devices {
124129
m.logger.Infof("Unbinding device %s", dev.Address)
125-
// (cdesiniotis) ideally this should be replaced by a call to nvdev.UnbindFromDriver()
126-
if err := m.nvpciLib.UnbindFromDriver(dev); err != nil {
130+
if err := m.nvpassthrough.UnbindFromDriver(dev); err != nil {
127131
m.logger.Warnf("Failed to unbind device %s: %v", dev.Address, err)
128132
}
129133
}
@@ -134,7 +138,7 @@ func (m unbindCommand) unbindDevice() error {
134138
device := m.options.deviceID
135139
// Note: Despite its name, GetGPUByPciBusID returns any NVIDIA PCI device
136140
// (GPU, NVSwitch, etc.) at the specified address, not just GPUs.
137-
nvdev, err := m.nvpciLib.GetGPUByPciBusID(device)
141+
nvdev, err := m.nvpci.GetGPUByPciBusID(device)
138142
if err != nil {
139143
return fmt.Errorf("failed to get NVIDIA device: %w", err)
140144
}
@@ -153,8 +157,7 @@ func (m unbindCommand) unbindDevice() error {
153157

154158
m.logger.Infof("Unbinding device %s", device)
155159

156-
// (cdesiniotis) ideally this should be replaced by a call to nvdev.UnbindFromDriver()
157-
if err := m.nvpciLib.UnbindFromDriver(nvdev); err != nil {
160+
if err := m.nvpassthrough.UnbindFromDriver(nvdev); err != nil {
158161
return fmt.Errorf("failed to unbind device %s from driver: %w", device, err)
159162
}
160163

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
* Copyright (c) NVIDIA CORPORATION. All rights reserved.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -14,7 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17-
package nvpci
17+
package nvpassthrough
1818

1919
import (
2020
"fmt"
Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,20 @@
1-
package nvpci
1+
/*
2+
* Copyright (c) NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package nvpassthrough
218

319
import (
420
"testing"

0 commit comments

Comments
 (0)