Skip to content

Commit 459ebd5

Browse files
committed
Add MIG devices support
Signed-off-by: Gilad Zamoscinski <[email protected]>
1 parent 70c3650 commit 459ebd5

File tree

4 files changed

+94
-25
lines changed

4 files changed

+94
-25
lines changed

main.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,17 @@ import (
2929
"volcano.sh/k8s-device-plugin/pkg/filewatcher"
3030
"volcano.sh/k8s-device-plugin/pkg/plugin"
3131
"volcano.sh/k8s-device-plugin/pkg/plugin/nvidia"
32+
"github.com/NVIDIA/go-gpuallocator/gpuallocator"
3233
)
3334

3435
func getAllPlugins() []plugin.DevicePlugin {
3536
return []plugin.DevicePlugin{
36-
nvidia.NewNvidiaDevicePlugin(),
37+
nvidia.NewNvidiaDevicePlugin(
38+
nvidia.VolcanoGPUResource,
39+
nvidia.NewGpuDeviceManager(false),
40+
nvidia.VisibleDevice,
41+
gpuallocator.Policy(nil),
42+
pluginapi.DevicePluginPath + "volcano.sock"),
3743
}
3844
}
3945

pkg/plugin/nvidia/mig-strategy.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
* limitations under the License.
1515
*/
1616

17-
package main
17+
package nvidia
1818

1919
import (
2020
"fmt"

pkg/plugin/nvidia/server.go

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import (
2727
"time"
2828

2929
"github.com/NVIDIA/gpu-monitoring-tools/bindings/go/nvml"
30+
"github.com/NVIDIA/go-gpuallocator/gpuallocator"
3031
v1 "k8s.io/api/core/v1"
3132
"k8s.io/klog"
3233

@@ -41,10 +42,12 @@ type NvidiaDevicePlugin struct {
4142
socket string
4243

4344
server *grpc.Server
45+
deviceListEnvvar string
46+
allocatePolicy gpuallocator.Policy
4447
// Physical gpu card
4548
physicalDevices []*Device
4649
health chan *Device
47-
stop chan struct{}
50+
stop chan interface{}
4851

4952
// Virtual devices
5053
virtualDevices []*pluginapi.Device
@@ -54,7 +57,7 @@ type NvidiaDevicePlugin struct {
5457
}
5558

5659
// NewNvidiaDevicePlugin returns an initialized NvidiaDevicePlugin
57-
func NewNvidiaDevicePlugin() *NvidiaDevicePlugin {
60+
func NewNvidiaDevicePlugin(resourceName string, resourceManager ResourceManager, deviceListEnvvar string, allocatePolicy gpuallocator.Policy, socket string) *NvidiaDevicePlugin {
5861
log.Println("Loading NVML")
5962
if err := nvml.Init(); err != nil {
6063
log.Printf("Failed to initialize NVML: %s.", err)
@@ -69,9 +72,11 @@ func NewNvidiaDevicePlugin() *NvidiaDevicePlugin {
6972
}
7073

7174
return &NvidiaDevicePlugin{
72-
ResourceManager: NewGpuDeviceManager(),
73-
resourceName: VolcanoGPUResource,
74-
socket: pluginapi.DevicePluginPath + "volcano.sock",
75+
ResourceManager: resourceManager,
76+
deviceListEnvvar: deviceListEnvvar,
77+
resourceName: resourceName,
78+
socket: socket,
79+
allocatePolicy: allocatePolicy,
7580
kubeInteractor: ki,
7681

7782
// These will be reinitialized every
@@ -89,7 +94,7 @@ func (m *NvidiaDevicePlugin) initialize() {
8994
m.physicalDevices = m.Devices()
9095
m.server = grpc.NewServer([]grpc.ServerOption{}...)
9196
m.health = make(chan *Device)
92-
m.stop = make(chan struct{})
97+
m.stop = make(chan interface{})
9398

9499
m.virtualDevices, m.devicesByIndex = GetDevices()
95100
}
@@ -122,7 +127,7 @@ func (m *NvidiaDevicePlugin) Name() string {
122127
func (m *NvidiaDevicePlugin) Start() error {
123128
m.initialize()
124129
// must be called after initialize
125-
if err := m.kubeInteractor.PatchGPUResourceOnNode(len(m.physicalDevices)); err != nil {
130+
if err := m.kubeInteractor.PatchGPUResourceOnNode(len(m.devicesByIndex)); err != nil {
126131
log.Printf("failed to patch gpu resource: %v", err)
127132
m.cleanup()
128133
return fmt.Errorf("failed to patch gpu resource: %v", err)
@@ -314,7 +319,7 @@ Allocate:
314319
klog.Warningf("Failed to get the gpu id for pod %s/%s", candidatePod.Namespace, candidatePod.Name)
315320
return nil, fmt.Errorf("failed to find gpu id")
316321
}
317-
_, exist := m.GetDeviceNameByIndex(uint(id))
322+
deviceName, exist := m.GetDeviceNameByIndex(uint(id))
318323
if !exist {
319324
klog.Warningf("Failed to find the dev for pod %s/%s because it's not able to find dev with index %d",
320325
candidatePod.Namespace, candidatePod.Name, id)
@@ -325,7 +330,7 @@ Allocate:
325330
reqGPU := len(req.DevicesIDs)
326331
response := pluginapi.ContainerAllocateResponse{
327332
Envs: map[string]string{
328-
VisibleDevice: fmt.Sprintf("%d", id),
333+
VisibleDevice: fmt.Sprintf("%s", deviceName),
329334
AllocatedGPUResource: fmt.Sprintf("%d", reqGPU),
330335
TotalGPUResource: fmt.Sprintf("%d", gpuMemory),
331336
},
@@ -362,3 +367,33 @@ func (m *NvidiaDevicePlugin) dial(unixSocketPath string, timeout time.Duration)
362367

363368
return c, nil
364369
}
370+
371+
// GetPreferredAllocation returns the preferred allocation from the set of devices specified in the request
372+
func (m *NvidiaDevicePlugin) GetPreferredAllocation(ctx context.Context, r *pluginapi.PreferredAllocationRequest) (*pluginapi.PreferredAllocationResponse, error) {
373+
response := &pluginapi.PreferredAllocationResponse{}
374+
for _, req := range r.ContainerRequests {
375+
available, err := gpuallocator.NewDevicesFrom(req.AvailableDeviceIDs)
376+
if err != nil {
377+
return nil, fmt.Errorf("Unable to retrieve list of available devices: %v", err)
378+
}
379+
380+
required, err := gpuallocator.NewDevicesFrom(req.MustIncludeDeviceIDs)
381+
if err != nil {
382+
return nil, fmt.Errorf("Unable to retrieve list of required devices: %v", err)
383+
}
384+
385+
allocated := m.allocatePolicy.Allocate(available, required, int(req.AllocationSize))
386+
387+
var deviceIds []string
388+
for _, device := range allocated {
389+
deviceIds = append(deviceIds, device.UUID)
390+
}
391+
392+
resp := &pluginapi.ContainerPreferredAllocationResponse{
393+
DeviceIDs: deviceIds,
394+
}
395+
396+
response.ContainerResponses = append(response.ContainerResponses, resp)
397+
}
398+
return response, nil
399+
}

pkg/plugin/nvidia/utils.go

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ func GenerateVirtualDeviceID(id uint, fakeCounter uint) string {
6363

6464
func SetGPUMemory(raw uint) {
6565
v := raw
66+
// TODO Add cli flag for units
67+
v = uint(math.Floor(float64(raw) / 100.0))
6668
gpuMemory = v
6769
log.Infof("set gpu memory: %d", gpuMemory)
6870
}
@@ -81,23 +83,49 @@ func GetDevices() ([]*pluginapi.Device, map[uint]string) {
8183
for i := uint(0); i < n; i++ {
8284
d, err := nvml.NewDevice(i)
8385
check(err)
84-
var id uint
85-
_, err = fmt.Sscanf(d.Path, "/dev/nvidia%d", &id)
86+
migEnabled, err := d.IsMigEnabled()
8687
check(err)
87-
deviceByIndex[id] = d.UUID
88-
// TODO: Do we assume all cards are of same capacity
89-
if GetGPUMemory() == uint(0) {
90-
SetGPUMemory(uint(*d.Memory))
91-
}
92-
for j := uint(0); j < GetGPUMemory(); j++ {
93-
fakeID := GenerateVirtualDeviceID(id, j)
94-
virtualDevs = append(virtualDevs, &pluginapi.Device{
95-
ID: fakeID,
96-
Health: pluginapi.Healthy,
97-
})
88+
89+
var id uint
90+
// TODO: Support only MigStrategySingle
91+
if migEnabled {
92+
migs, err := d.GetMigDevices()
93+
check(err)
94+
for j, mig := range migs {
95+
// TODO: explain formula (based on device and mig numbers)
96+
id = i*uint(2) + i + uint(j)
97+
deviceByIndex[id] = mig.UUID
98+
if GetGPUMemory() == uint(0) {
99+
SetGPUMemory(uint(*mig.Memory))
100+
}
101+
for j := uint(0); j < GetGPUMemory(); j++ {
102+
fakeID := GenerateVirtualDeviceID(id, j)
103+
virtualDevs = append(virtualDevs, &pluginapi.Device{
104+
ID: fakeID,
105+
Health: pluginapi.Healthy,
106+
})
107+
}
108+
109+
}
110+
111+
} else {
112+
113+
_, err = fmt.Sscanf(d.Path, "/dev/nvidia%d", &id)
114+
check(err)
115+
deviceByIndex[id] = d.UUID
116+
// TODO: Do we assume all cards are of same capacity
117+
if GetGPUMemory() == uint(0) {
118+
SetGPUMemory(uint(*d.Memory))
119+
}
120+
for j := uint(0); j < GetGPUMemory(); j++ {
121+
fakeID := GenerateVirtualDeviceID(id, j)
122+
virtualDevs = append(virtualDevs, &pluginapi.Device{
123+
ID: fakeID,
124+
Health: pluginapi.Healthy,
125+
})
126+
}
98127
}
99128
}
100-
101129
return virtualDevs, deviceByIndex
102130
}
103131

0 commit comments

Comments
 (0)