Skip to content

Commit 4a84232

Browse files
refactor(plugin): use context for health check lifecycle
Replace the stop channel with context.Context cancellation and sync.WaitGroup for health check goroutine lifecycle management:

- Initialize health channel and healthCtx before Serve() to eliminate race where kubelet could call ListAndWatch before fields are set
- Use sync.WaitGroup to ensure health goroutine completes before channel close in Stop()
- Add nil guards in Stop() to handle partial Start() failure safely
- Add structured error handling for health check completion (success, canceled, error cases)
- Add debug logging in ListAndWatch for context cancellation and channel closure
- Fix updateResponseForMPS receiver from value to pointer

Signed-off-by: Carlos Eduardo Arango Gutierrez <carangog@redhat.com>
Co-authored-by: Evan Lezar <elezar@nvidia.com>
1 parent 4109c98 commit 4a84232

File tree

1 file changed

+38
-28
lines changed

1 file changed

+38
-28
lines changed

internal/plugin/server.go

Lines changed: 38 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ import (
2525
"path"
2626
"path/filepath"
2727
"strings"
28+
"sync"
2829
"time"
2930

3031
cdiapi "tags.cncf.io/container-device-interface/pkg/cdi"
@@ -62,7 +63,12 @@ type nvidiaDevicePlugin struct {
6263
socket string
6364
server *grpc.Server
6465
health chan *rm.Device
65-
stop chan interface{}
66+
67+
// healthCtx and healthCancel control the health check goroutine lifecycle.
68+
healthCtx context.Context
69+
healthCancel context.CancelFunc
70+
// healthWg is used to wait for the health check goroutine to complete during cleanup.
71+
healthWg sync.WaitGroup
6672

6773
imexChannels imex.Channels
6874

@@ -90,11 +96,6 @@ func (o *options) devicePluginForResource(ctx context.Context, resourceManager r
9096
mps: mpsOptions,
9197

9298
socket: getPluginSocketPath(resourceManager.Resource()),
93-
// These will be reinitialized every
94-
// time the plugin server is restarted.
95-
server: nil,
96-
health: nil,
97-
stop: nil,
9899
}
99100
return &plugin, nil
100101
}
@@ -106,19 +107,6 @@ func getPluginSocketPath(resource spec.ResourceName) string {
106107
return filepath.Join(pluginapi.DevicePluginPath, pluginName) + ".sock"
107108
}
108109

109-
func (plugin *nvidiaDevicePlugin) initialize() {
110-
plugin.server = grpc.NewServer([]grpc.ServerOption{}...)
111-
plugin.health = make(chan *rm.Device)
112-
plugin.stop = make(chan interface{})
113-
}
114-
115-
func (plugin *nvidiaDevicePlugin) cleanup() {
116-
close(plugin.stop)
117-
plugin.server = nil
118-
plugin.health = nil
119-
plugin.stop = nil
120-
}
121-
122110
// Devices returns the full set of devices associated with the plugin.
123111
func (plugin *nvidiaDevicePlugin) Devices() rm.Devices {
124112
return plugin.rm.Devices()
@@ -127,16 +115,16 @@ func (plugin *nvidiaDevicePlugin) Devices() rm.Devices {
127115
// Start starts the gRPC server, registers the device plugin with the Kubelet,
128116
// and starts the device healthchecks.
129117
func (plugin *nvidiaDevicePlugin) Start(kubeletSocket string) error {
130-
plugin.initialize()
131-
132118
if err := plugin.mps.waitForDaemon(); err != nil {
133119
return fmt.Errorf("error waiting for MPS daemon: %w", err)
134120
}
135121

122+
plugin.health = make(chan *rm.Device)
123+
plugin.healthCtx, plugin.healthCancel = context.WithCancel(plugin.ctx)
124+
136125
err := plugin.Serve()
137126
if err != nil {
138127
klog.Errorf("Could not start device plugin for '%s': %s", plugin.rm.Resource(), err)
139-
plugin.cleanup()
140128
return err
141129
}
142130
klog.Infof("Starting to serve '%s' on %s", plugin.rm.Resource(), plugin.socket)
@@ -148,10 +136,17 @@ func (plugin *nvidiaDevicePlugin) Start(kubeletSocket string) error {
148136
}
149137
klog.Infof("Registered device plugin for '%s' with Kubelet", plugin.rm.Resource())
150138

139+
plugin.healthWg.Add(1)
151140
go func() {
141+
defer plugin.healthWg.Done()
152142
// TODO: add MPS health check
153-
err := plugin.rm.CheckHealth(plugin.stop, plugin.health)
154-
if err != nil {
143+
err := plugin.rm.CheckHealth(plugin.healthCtx, plugin.health)
144+
switch {
145+
case err == nil:
146+
klog.Infof("Health check completed successfully for '%s'", plugin.rm.Resource())
147+
case errors.Is(err, context.Canceled):
148+
klog.V(4).Infof("Health check canceled for '%s' (plugin shutdown)", plugin.rm.Resource())
149+
default:
155150
klog.Errorf("Failed to start health check: %v; continuing with health checks disabled", err)
156151
}
157152
}()
@@ -164,12 +159,21 @@ func (plugin *nvidiaDevicePlugin) Stop() error {
164159
if plugin == nil || plugin.server == nil {
165160
return nil
166161
}
162+
// Stop health checks if they were started.
163+
if plugin.healthCancel != nil {
164+
plugin.healthCancel()
165+
plugin.healthWg.Wait()
166+
}
167+
if plugin.health != nil {
168+
close(plugin.health)
169+
}
170+
167171
klog.Infof("Stopping to serve '%s' on %s", plugin.rm.Resource(), plugin.socket)
168172
plugin.server.Stop()
173+
plugin.server = nil
169174
if err := os.Remove(plugin.socket); err != nil && !os.IsNotExist(err) {
170175
return err
171176
}
172-
plugin.cleanup()
173177
return nil
174178
}
175179

@@ -181,6 +185,7 @@ func (plugin *nvidiaDevicePlugin) Serve() error {
181185
return err
182186
}
183187

188+
plugin.server = grpc.NewServer([]grpc.ServerOption{}...)
184189
pluginapi.RegisterDevicePluginServer(plugin.server, plugin)
185190

186191
go func() {
@@ -271,9 +276,14 @@ func (plugin *nvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.D
271276

272277
for {
273278
select {
274-
case <-plugin.stop:
279+
case <-plugin.healthCtx.Done():
280+
klog.V(4).Infof("Stopping health checks for '%s'", plugin.rm.Resource())
275281
return nil
276-
case d := <-plugin.health:
282+
case d, ok := <-plugin.health:
283+
if !ok {
284+
klog.V(4).Infof("Health channel closed for '%s'", plugin.rm.Resource())
285+
return nil
286+
}
277287
// FIXME: there is no way to recover from the Unhealthy state.
278288
d.Health = pluginapi.Unhealthy
279289
klog.Infof("'%s' device marked unhealthy: %s", plugin.rm.Resource(), d.ID)
@@ -368,7 +378,7 @@ func (plugin *nvidiaDevicePlugin) getAllocateResponse(requestIds []string) (*plu
368378
// updateResponseForMPS ensures that the ContainerAllocate response contains the information required to use MPS.
369379
// This includes per-resource pipe and log directories as well as a global daemon-specific shm
370380
// and assumes that an MPS control daemon has already been started.
371-
func (plugin nvidiaDevicePlugin) updateResponseForMPS(response *pluginapi.ContainerAllocateResponse) {
381+
func (plugin *nvidiaDevicePlugin) updateResponseForMPS(response *pluginapi.ContainerAllocateResponse) {
372382
plugin.mps.updateReponse(response)
373383
}
374384

0 commit comments

Comments
 (0)