Skip to content

Commit 84a3f93

Browse files
elezarArangoGutierrez
authored andcommitted
refactor(health): extract nvmlHealthProvider and add tests
Refactor the health monitoring implementation: - Extract nvmlHealthProvider struct to encapsulate health check state - Add registerDeviceEvents() and runEventMonitor() methods - Rename xids field to xidsDisabled for clarity - Migrate from stop channel to context.Context for cancellation Add comprehensive tests: - TestCheckHealth validates XID event handling with mocks - TestRegisterDeviceEventsNotSupported ensures old GPUs returning ERROR_NOT_SUPPORTED are not incorrectly marked unhealthy Signed-off-by: Carlos Eduardo Arango Gutierrez <eduardoa@nvidia.com>
1 parent 59a16eb commit 84a3f93

File tree

2 files changed

+210
-23
lines changed

2 files changed

+210
-23
lines changed

internal/rm/health.go

Lines changed: 61 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package rm
1818

1919
import (
20+
"context"
2021
"fmt"
2122
"os"
2223
"strconv"
@@ -40,8 +41,19 @@ const (
4041
envEnableHealthChecks = "DP_ENABLE_HEALTHCHECKS"
4142
)
4243

44+
type nvmlHealthProvider struct {
45+
nvmllib nvml.Interface
46+
devices Devices
47+
parentToDeviceMap map[string]*Device
48+
deviceIDToGiMap map[string]uint32
49+
deviceIDToCiMap map[string]uint32
50+
51+
xidsDisabled disabledXIDs
52+
unhealthy chan<- *Device
53+
}
54+
4355
// CheckHealth performs health checks on a set of devices, writing to the 'unhealthy' channel with any unhealthy devices
44-
func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devices, unhealthy chan<- *Device) error {
56+
func (r *nvmlResourceManager) checkHealth(ctx context.Context, devices Devices, unhealthy chan<- *Device) error {
4557
xids := getDisabledHealthCheckXids()
4658
if xids.IsAllDisabled() {
4759
return nil
@@ -71,13 +83,14 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic
7183
_ = eventSet.Free()
7284
}()
7385

86+
// Construct the device maps.
87+
// TODO: This should be factored out. The main issue is marking the devices
88+
// unhealthy as part of this loop.
7489
parentToDeviceMap := make(map[string]*Device)
7590
deviceIDToGiMap := make(map[string]uint32)
7691
deviceIDToCiMap := make(map[string]uint32)
77-
78-
eventMask := uint64(nvml.EventTypeXidCriticalError | nvml.EventTypeDoubleBitEccError | nvml.EventTypeSingleBitEccError)
7992
for _, d := range devices {
80-
uuid, gi, ci, err := r.getDevicePlacement(d)
93+
uuid, gi, ci, err := (&withDevicePlacements{r.nvml}).getDevicePlacement(d)
8194
if err != nil {
8295
klog.Warningf("Could not determine device placement for %v: %v; Marking it unhealthy.", d.ID, err)
8396
unhealthy <- d
@@ -86,18 +99,35 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic
8699
deviceIDToGiMap[d.ID] = gi
87100
deviceIDToCiMap[d.ID] = ci
88101
parentToDeviceMap[uuid] = d
102+
}
103+
104+
p := nvmlHealthProvider{
105+
nvmllib: r.nvml,
106+
devices: devices,
107+
unhealthy: unhealthy,
108+
parentToDeviceMap: parentToDeviceMap,
109+
deviceIDToGiMap: deviceIDToGiMap,
110+
deviceIDToCiMap: deviceIDToCiMap,
111+
xidsDisabled: xids,
112+
}
113+
p.registerDeviceEvents(eventSet)
114+
115+
return p.runEventMonitor(ctx, eventSet)
116+
}
89117

90-
gpu, ret := r.nvml.DeviceGetHandleByUUID(uuid)
118+
func (r *nvmlHealthProvider) registerDeviceEvents(eventSet nvml.EventSet) {
119+
eventMask := uint64(nvml.EventTypeXidCriticalError | nvml.EventTypeDoubleBitEccError | nvml.EventTypeSingleBitEccError)
120+
for uuid, d := range r.parentToDeviceMap {
121+
gpu, ret := r.nvmllib.DeviceGetHandleByUUID(uuid)
91122
if ret != nvml.SUCCESS {
92123
klog.Infof("unable to get device handle from UUID: %v; marking it as unhealthy", ret)
93-
unhealthy <- d
124+
r.unhealthy <- d
94125
continue
95126
}
96-
97127
supportedEvents, ret := gpu.GetSupportedEventTypes()
98128
if ret != nvml.SUCCESS {
99129
klog.Infof("unable to determine the supported events for %v: %v; marking it as unhealthy", d.ID, ret)
100-
unhealthy <- d
130+
r.unhealthy <- d
101131
continue
102132
}
103133

@@ -107,14 +137,16 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic
107137
klog.Warningf("Device %v is too old to support healthchecking.", d.ID)
108138
case ret != nvml.SUCCESS:
109139
klog.Infof("Marking device %v as unhealthy: %v", d.ID, ret)
110-
unhealthy <- d
140+
r.unhealthy <- d
111141
}
112142
}
143+
}
113144

145+
func (r *nvmlHealthProvider) runEventMonitor(ctx context.Context, eventSet nvml.EventSet) error {
114146
for {
115147
select {
116-
case <-stop:
117-
return nil
148+
case <-ctx.Done():
149+
return ctx.Err()
118150
default:
119151
}
120152

@@ -124,18 +156,20 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic
124156
}
125157
if ret != nvml.SUCCESS {
126158
klog.Infof("Error waiting for event: %v; Marking all devices as unhealthy", ret)
127-
for _, d := range devices {
128-
unhealthy <- d
159+
for _, d := range r.devices {
160+
r.unhealthy <- d
129161
}
130162
continue
131163
}
132164

165+
// TODO: We create an event mask for other event types but don't handle
166+
// them here.
133167
if e.EventType != nvml.EventTypeXidCriticalError {
134168
klog.Infof("Skipping non-nvmlEventTypeXidCriticalError event: %+v", e)
135169
continue
136170
}
137171

138-
if xids.IsDisabled(e.EventData) {
172+
if r.xidsDisabled.IsDisabled(e.EventData) {
139173
klog.Infof("Skipping event %+v", e)
140174
continue
141175
}
@@ -145,29 +179,29 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic
145179
if ret != nvml.SUCCESS {
146180
// If we cannot reliably determine the device UUID, we mark all devices as unhealthy.
147181
klog.Infof("Failed to determine uuid for event %v: %v; Marking all devices as unhealthy.", e, ret)
148-
for _, d := range devices {
149-
unhealthy <- d
182+
for _, d := range r.devices {
183+
r.unhealthy <- d
150184
}
151185
continue
152186
}
153187

154-
d, exists := parentToDeviceMap[eventUUID]
188+
d, exists := r.parentToDeviceMap[eventUUID]
155189
if !exists {
156190
klog.Infof("Ignoring event for unexpected device: %v", eventUUID)
157191
continue
158192
}
159193

160194
if d.IsMigDevice() && e.GpuInstanceId != 0xFFFFFFFF && e.ComputeInstanceId != 0xFFFFFFFF {
161-
gi := deviceIDToGiMap[d.ID]
162-
ci := deviceIDToCiMap[d.ID]
195+
gi := r.deviceIDToGiMap[d.ID]
196+
ci := r.deviceIDToCiMap[d.ID]
163197
if gi != e.GpuInstanceId || ci != e.ComputeInstanceId {
164198
continue
165199
}
166200
klog.Infof("Event for mig device %v (gi=%v, ci=%v)", d.ID, gi, ci)
167201
}
168202

169203
klog.Infof("XidCriticalError: Xid=%d on Device=%s; marking device as unhealthy.", e.EventData, d.ID)
170-
unhealthy <- d
204+
r.unhealthy <- d
171205
}
172206
}
173207

@@ -276,25 +310,29 @@ func newHealthCheckXIDs(xids ...string) disabledXIDs {
276310
return output
277311
}
278312

313+
type withDevicePlacements struct {
314+
nvml.Interface
315+
}
316+
279317
// getDevicePlacement returns the placement of the specified device.
280318
// For a MIG device the placement is defined by the 3-tuple <parent UUID, GI, CI>
281319
// For a full device the returned 3-tuple is the device's uuid and 0xFFFFFFFF for the other two elements.
282-
func (r *nvmlResourceManager) getDevicePlacement(d *Device) (string, uint32, uint32, error) {
320+
func (r *withDevicePlacements) getDevicePlacement(d *Device) (string, uint32, uint32, error) {
283321
if !d.IsMigDevice() {
284322
return d.GetUUID(), 0xFFFFFFFF, 0xFFFFFFFF, nil
285323
}
286324
return r.getMigDeviceParts(d)
287325
}
288326

289327
// getMigDeviceParts returns the parent GI and CI ids of the MIG device.
290-
func (r *nvmlResourceManager) getMigDeviceParts(d *Device) (string, uint32, uint32, error) {
328+
func (r *withDevicePlacements) getMigDeviceParts(d *Device) (string, uint32, uint32, error) {
291329
if !d.IsMigDevice() {
292330
return "", 0, 0, fmt.Errorf("cannot get GI and CI of full device")
293331
}
294332

295333
uuid := d.GetUUID()
296334
// For older driver versions, the call to DeviceGetHandleByUUID will fail for MIG devices.
297-
mig, ret := r.nvml.DeviceGetHandleByUUID(uuid)
335+
mig, ret := r.DeviceGetHandleByUUID(uuid)
298336
if ret == nvml.SUCCESS {
299337
parentHandle, ret := mig.GetDeviceHandleFromMigDeviceHandle()
300338
if ret != nvml.SUCCESS {

internal/rm/health_test.go

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,15 @@
1717
package rm
1818

1919
import (
20+
"context"
2021
"fmt"
2122
"strings"
23+
"sync"
2224
"testing"
2325

26+
"github.com/NVIDIA/go-nvml/pkg/nvml"
27+
"github.com/NVIDIA/go-nvml/pkg/nvml/mock"
28+
"github.com/NVIDIA/go-nvml/pkg/nvml/mock/dgxa100"
2429
"github.com/stretchr/testify/require"
2530
)
2631

@@ -221,3 +226,147 @@ func TestGetDisabledHealthCheckXids(t *testing.T) {
221226
})
222227
}
223228
}
229+
230+
func TestCheckHealth(t *testing.T) {
231+
ctx, cancel := context.WithCancel(context.Background())
232+
unhealthy := make(chan *Device)
233+
234+
server := dgxa100.New()
235+
236+
deviceMock := server.Devices[0].(*dgxa100.Device)
237+
deviceMock.GetSupportedEventTypesFunc = func() (uint64, nvml.Return) {
238+
return nvml.EventTypeXidCriticalError, nvml.SUCCESS
239+
}
240+
// TODO: Should this be more dynamic?
241+
deviceMock.RegisterEventsFunc = func(v uint64, eventSet nvml.EventSet) nvml.Return {
242+
return nvml.SUCCESS
243+
}
244+
245+
var count int
246+
eventData := []nvml.EventData{
247+
{
248+
EventData: 109,
249+
EventType: nvml.EventTypeXidCriticalError,
250+
Device: server.Devices[0],
251+
},
252+
{
253+
EventData: 48,
254+
EventType: nvml.EventTypeXidCriticalError,
255+
Device: server.Devices[0],
256+
},
257+
}
258+
259+
server.EventSetCreateFunc = func() (nvml.EventSet, nvml.Return) {
260+
es := &mock.EventSet{
261+
WaitFunc: func(v uint32) (nvml.EventData, nvml.Return) {
262+
ed := eventData[count%len(eventData)]
263+
count++
264+
if count == len(eventData) {
265+
// Cancel the context to signal the health checker to terminate
266+
// after the predefined events have been triggered.
267+
cancel()
268+
}
269+
return ed, nvml.SUCCESS
270+
},
271+
FreeFunc: func() nvml.Return {
272+
return nvml.SUCCESS
273+
},
274+
}
275+
return es, nvml.SUCCESS
276+
}
277+
278+
r := &nvmlResourceManager{
279+
nvml: server,
280+
}
281+
282+
var unhealthyDevices []*Device
283+
var wg sync.WaitGroup
284+
wg.Add(1)
285+
286+
go func() {
287+
defer wg.Done()
288+
for d := range unhealthy {
289+
unhealthyDevices = append(unhealthyDevices, d)
290+
}
291+
}()
292+
293+
var expectedDevices []*Device
294+
295+
devices := make(Devices)
296+
for i, d := range server.Devices {
297+
device, err := BuildDevice(newNvmlGPUDevice(i, d))
298+
require.NoError(t, err)
299+
devices[device.GetUUID()] = device
300+
expectedDevices = append(expectedDevices, device)
301+
// TODO: We only expect a single unhealthy event for the first device.
302+
break
303+
}
304+
305+
err := r.checkHealth(ctx, devices, unhealthy)
306+
require.ErrorIs(t, err, context.Canceled)
307+
308+
// Close the channel and wait for the goroutine to finish collecting unhealthy devices
309+
close(unhealthy)
310+
wg.Wait()
311+
312+
require.EqualValues(t, expectedDevices, unhealthyDevices)
313+
}
314+
315+
func TestRegisterDeviceEventsNotSupported(t *testing.T) {
316+
// This test verifies that devices which return ERROR_NOT_SUPPORTED from
317+
// RegisterEvents are NOT marked as unhealthy. This is the expected behavior
318+
// for old GPUs that don't support event registration.
319+
// See: commit 8cd14472a "Fix healthchecking on old devices"
320+
321+
unhealthy := make(chan *Device, 10)
322+
323+
server := dgxa100.New()
324+
325+
deviceMock := server.Devices[0].(*dgxa100.Device)
326+
deviceMock.GetSupportedEventTypesFunc = func() (uint64, nvml.Return) {
327+
return nvml.EventTypeXidCriticalError, nvml.SUCCESS
328+
}
329+
// Simulate an old device that doesn't support event registration
330+
deviceMock.RegisterEventsFunc = func(v uint64, eventSet nvml.EventSet) nvml.Return {
331+
return nvml.ERROR_NOT_SUPPORTED
332+
}
333+
334+
eventSet := &mock.EventSet{
335+
FreeFunc: func() nvml.Return {
336+
return nvml.SUCCESS
337+
},
338+
}
339+
340+
devices := make(Devices)
341+
parentToDeviceMap := make(map[string]*Device)
342+
343+
for i, d := range server.Devices {
344+
device, err := BuildDevice(newNvmlGPUDevice(i, d))
345+
require.NoError(t, err)
346+
devices[device.GetUUID()] = device
347+
parentToDeviceMap[device.GetUUID()] = device
348+
break
349+
}
350+
351+
p := nvmlHealthProvider{
352+
nvmllib: server,
353+
devices: devices,
354+
unhealthy: unhealthy,
355+
parentToDeviceMap: parentToDeviceMap,
356+
deviceIDToGiMap: make(map[string]uint32),
357+
deviceIDToCiMap: make(map[string]uint32),
358+
}
359+
360+
p.registerDeviceEvents(eventSet)
361+
362+
// Close the channel so we can drain it
363+
close(unhealthy)
364+
365+
// Verify that NO devices were marked as unhealthy
366+
var unhealthyDevices []*Device
367+
for d := range unhealthy {
368+
unhealthyDevices = append(unhealthyDevices, d)
369+
}
370+
371+
require.Empty(t, unhealthyDevices, "Devices returning ERROR_NOT_SUPPORTED should NOT be marked as unhealthy")
372+
}

0 commit comments

Comments
 (0)