1717package rm
1818
1919import (
20+ "context"
2021 "fmt"
2122 "os"
2223 "strconv"
@@ -40,8 +41,19 @@ const (
4041 envEnableHealthChecks = "DP_ENABLE_HEALTHCHECKS"
4142)
4243
44+ type nvmlHealthProvider struct {
45+ nvmllib nvml.Interface
46+ devices Devices
47+ parentToDeviceMap map [string ]* Device
48+ deviceIDToGiMap map [string ]uint32
49+ deviceIDToCiMap map [string ]uint32
50+
51+ xidsDisabled disabledXIDs
52+ unhealthy chan <- * Device
53+ }
54+
4355// CheckHealth performs health checks on a set of devices, writing to the 'unhealthy' channel with any unhealthy devices
44- func (r * nvmlResourceManager ) checkHealth (stop <- chan interface {} , devices Devices , unhealthy chan <- * Device ) error {
56+ func (r * nvmlResourceManager ) checkHealth (ctx context. Context , devices Devices , unhealthy chan <- * Device ) error {
4557 xids := getDisabledHealthCheckXids ()
4658 if xids .IsAllDisabled () {
4759 return nil
@@ -71,13 +83,14 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic
7183 _ = eventSet .Free ()
7284 }()
7385
86+ // Construct the device maps.
87+ // TODO: This should be factored out. The main issue is marking the devices
88+ // unhealthy as part of this loop.
7489 parentToDeviceMap := make (map [string ]* Device )
7590 deviceIDToGiMap := make (map [string ]uint32 )
7691 deviceIDToCiMap := make (map [string ]uint32 )
77-
78- eventMask := uint64 (nvml .EventTypeXidCriticalError | nvml .EventTypeDoubleBitEccError | nvml .EventTypeSingleBitEccError )
7992 for _ , d := range devices {
80- uuid , gi , ci , err := r .getDevicePlacement (d )
93+ uuid , gi , ci , err := ( & withDevicePlacements { r . nvml }) .getDevicePlacement (d )
8194 if err != nil {
8295 klog .Warningf ("Could not determine device placement for %v: %v; Marking it unhealthy." , d .ID , err )
8396 unhealthy <- d
@@ -86,18 +99,35 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic
8699 deviceIDToGiMap [d .ID ] = gi
87100 deviceIDToCiMap [d .ID ] = ci
88101 parentToDeviceMap [uuid ] = d
102+ }
103+
104+ p := nvmlHealthProvider {
105+ nvmllib : r .nvml ,
106+ devices : devices ,
107+ unhealthy : unhealthy ,
108+ parentToDeviceMap : parentToDeviceMap ,
109+ deviceIDToGiMap : deviceIDToGiMap ,
110+ deviceIDToCiMap : deviceIDToCiMap ,
111+ xidsDisabled : xids ,
112+ }
113+ p .registerDeviceEvents (eventSet )
114+
115+ return p .runEventMonitor (ctx , eventSet )
116+ }
89117
90- gpu , ret := r .nvml .DeviceGetHandleByUUID (uuid )
118+ func (r * nvmlHealthProvider ) registerDeviceEvents (eventSet nvml.EventSet ) {
119+ eventMask := uint64 (nvml .EventTypeXidCriticalError | nvml .EventTypeDoubleBitEccError | nvml .EventTypeSingleBitEccError )
120+ for uuid , d := range r .parentToDeviceMap {
121+ gpu , ret := r .nvmllib .DeviceGetHandleByUUID (uuid )
91122 if ret != nvml .SUCCESS {
92123 klog .Infof ("unable to get device handle from UUID: %v; marking it as unhealthy" , ret )
93- unhealthy <- d
124+ r . unhealthy <- d
94125 continue
95126 }
96-
97127 supportedEvents , ret := gpu .GetSupportedEventTypes ()
98128 if ret != nvml .SUCCESS {
99129 klog .Infof ("unable to determine the supported events for %v: %v; marking it as unhealthy" , d .ID , ret )
100- unhealthy <- d
130+ r . unhealthy <- d
101131 continue
102132 }
103133
@@ -107,14 +137,16 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic
107137 klog .Warningf ("Device %v is too old to support healthchecking." , d .ID )
108138 case ret != nvml .SUCCESS :
109139 klog .Infof ("Marking device %v as unhealthy: %v" , d .ID , ret )
110- unhealthy <- d
140+ r . unhealthy <- d
111141 }
112142 }
143+ }
113144
145+ func (r * nvmlHealthProvider ) runEventMonitor (ctx context.Context , eventSet nvml.EventSet ) error {
114146 for {
115147 select {
116- case <- stop :
117- return nil
148+ case <- ctx . Done () :
149+ return ctx . Err ()
118150 default :
119151 }
120152
@@ -124,18 +156,20 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic
124156 }
125157 if ret != nvml .SUCCESS {
126158 klog .Infof ("Error waiting for event: %v; Marking all devices as unhealthy" , ret )
127- for _ , d := range devices {
128- unhealthy <- d
159+ for _ , d := range r . devices {
160+ r . unhealthy <- d
129161 }
130162 continue
131163 }
132164
165+ // TODO: We create an event mask for other event types but don't handle
166+ // them here.
133167 if e .EventType != nvml .EventTypeXidCriticalError {
134168 klog .Infof ("Skipping non-nvmlEventTypeXidCriticalError event: %+v" , e )
135169 continue
136170 }
137171
138- if xids .IsDisabled (e .EventData ) {
172+ if r . xidsDisabled .IsDisabled (e .EventData ) {
139173 klog .Infof ("Skipping event %+v" , e )
140174 continue
141175 }
@@ -145,29 +179,29 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic
145179 if ret != nvml .SUCCESS {
146180 // If we cannot reliably determine the device UUID, we mark all devices as unhealthy.
147181 klog .Infof ("Failed to determine uuid for event %v: %v; Marking all devices as unhealthy." , e , ret )
148- for _ , d := range devices {
149- unhealthy <- d
182+ for _ , d := range r . devices {
183+ r . unhealthy <- d
150184 }
151185 continue
152186 }
153187
154- d , exists := parentToDeviceMap [eventUUID ]
188+ d , exists := r . parentToDeviceMap [eventUUID ]
155189 if ! exists {
156190 klog .Infof ("Ignoring event for unexpected device: %v" , eventUUID )
157191 continue
158192 }
159193
160194 if d .IsMigDevice () && e .GpuInstanceId != 0xFFFFFFFF && e .ComputeInstanceId != 0xFFFFFFFF {
161- gi := deviceIDToGiMap [d .ID ]
162- ci := deviceIDToCiMap [d .ID ]
195+ gi := r . deviceIDToGiMap [d .ID ]
196+ ci := r . deviceIDToCiMap [d .ID ]
163197 if gi != e .GpuInstanceId || ci != e .ComputeInstanceId {
164198 continue
165199 }
166200 klog .Infof ("Event for mig device %v (gi=%v, ci=%v)" , d .ID , gi , ci )
167201 }
168202
169203 klog .Infof ("XidCriticalError: Xid=%d on Device=%s; marking device as unhealthy." , e .EventData , d .ID )
170- unhealthy <- d
204+ r . unhealthy <- d
171205 }
172206}
173207
@@ -276,25 +310,29 @@ func newHealthCheckXIDs(xids ...string) disabledXIDs {
276310 return output
277311}
278312
313+ type withDevicePlacements struct {
314+ nvml.Interface
315+ }
316+
279317// getDevicePlacement returns the placement of the specified device.
280318// For a MIG device the placement is defined by the 3-tuple <parent UUID, GI, CI>
281319// For a full device the returned 3-tuple is the device's uuid and 0xFFFFFFFF for the other two elements.
282- func (r * nvmlResourceManager ) getDevicePlacement (d * Device ) (string , uint32 , uint32 , error ) {
320+ func (r * withDevicePlacements ) getDevicePlacement (d * Device ) (string , uint32 , uint32 , error ) {
283321 if ! d .IsMigDevice () {
284322 return d .GetUUID (), 0xFFFFFFFF , 0xFFFFFFFF , nil
285323 }
286324 return r .getMigDeviceParts (d )
287325}
288326
289327// getMigDeviceParts returns the parent GI and CI ids of the MIG device.
290- func (r * nvmlResourceManager ) getMigDeviceParts (d * Device ) (string , uint32 , uint32 , error ) {
328+ func (r * withDevicePlacements ) getMigDeviceParts (d * Device ) (string , uint32 , uint32 , error ) {
291329 if ! d .IsMigDevice () {
292330 return "" , 0 , 0 , fmt .Errorf ("cannot get GI and CI of full device" )
293331 }
294332
295333 uuid := d .GetUUID ()
296334 // For older driver versions, the call to DeviceGetHandleByUUID will fail for MIG devices.
297- mig , ret := r .nvml . DeviceGetHandleByUUID (uuid )
335+ mig , ret := r .DeviceGetHandleByUUID (uuid )
298336 if ret == nvml .SUCCESS {
299337 parentHandle , ret := mig .GetDeviceHandleFromMigDeviceHandle ()
300338 if ret != nvml .SUCCESS {
0 commit comments