Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions cmd/gpu-kubelet-plugin/device_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ import (
drametrics "sigs.k8s.io/nvidia-dra-driver-gpu/pkg/metrics"
)

var (
deviceEnumerationInterval = 5 * time.Second
deviceEnumerationTimeout = 5 * time.Minute
)

type deviceEnumerator interface {
enumerateAllPossibleDevices() (*PerGPUAllocatableDevices, error)
}

type OpaqueDeviceConfig struct {
Requests []string
Config runtime.Object
Expand Down Expand Up @@ -71,6 +80,32 @@ type DeviceState struct {
cplock *flock.Flock
}

// enumerateDevicesWithRetry calls enumerateAllPossibleDevices in a loop until
// at least one device is found, the context is cancelled, or the timeout elapses.
// Retries should prevent a ResourceSlice without any device in the spec.
func enumerateDevicesWithRetry(ctx context.Context, nvdevlib deviceEnumerator) (*PerGPUAllocatableDevices, error) {
deadline := time.Now().Add(deviceEnumerationTimeout)
for {
perGPUAllocatable, err := nvdevlib.enumerateAllPossibleDevices()
if err != nil {
return nil, fmt.Errorf("error enumerating all possible devices: %w", err)
}
if len(perGPUAllocatable.allocatablesMap) > 0 {
return perGPUAllocatable, nil
}
if time.Now().After(deadline) {
klog.Warningf("No GPU devices found after %v, proceeding with empty device set", deviceEnumerationTimeout)
return perGPUAllocatable, nil
}
klog.Infof("No GPU devices found yet (driver may still be initializing), retrying in %v...", deviceEnumerationInterval)
select {
case <-ctx.Done():
return nil, fmt.Errorf("context cancelled while waiting for GPU devices: %w", ctx.Err())
case <-time.After(deviceEnumerationInterval):
}
}
}

func NewDeviceState(ctx context.Context, config *Config) (*DeviceState, error) {
containerDriverRoot := root(config.flags.containerDriverRoot)
devRoot := containerDriverRoot.getDevRoot()
Expand All @@ -81,9 +116,9 @@ func NewDeviceState(ctx context.Context, config *Config) (*DeviceState, error) {
return nil, fmt.Errorf("failed to create device library: %w", err)
}

perGPUAllocatable, err := nvdevlib.enumerateAllPossibleDevices()
perGPUAllocatable, err := enumerateDevicesWithRetry(ctx, nvdevlib)
if err != nil {
return nil, fmt.Errorf("error enumerating all possible devices: %w", err)
return nil, err
}

hostDriverRoot := config.flags.hostDriverRoot
Expand Down
143 changes: 143 additions & 0 deletions cmd/gpu-kubelet-plugin/device_state_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*
Copyright The Kubernetes Authors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package main

import (
"context"
"errors"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

// fakeEnumerator implements deviceEnumerator for tests. Each call to
// enumerateAllPossibleDevices consumes the next result in the slice; once
// exhausted, the last result is repeated.
type fakeEnumerator struct {
results []enumerateResult
calls int
}

type enumerateResult struct {
devices *PerGPUAllocatableDevices
err error
}

func (f *fakeEnumerator) enumerateAllPossibleDevices() (*PerGPUAllocatableDevices, error) {
idx := f.calls
if idx >= len(f.results) {
idx = len(f.results) - 1
}
f.calls++
return f.results[idx].devices, f.results[idx].err
}

func emptyDevices() *PerGPUAllocatableDevices {
return &PerGPUAllocatableDevices{allocatablesMap: map[PCIBusID]AllocatableDevices{}}
}

func oneDevice() *PerGPUAllocatableDevices {
return &PerGPUAllocatableDevices{
allocatablesMap: map[PCIBusID]AllocatableDevices{
"0000:00:04.0": {"gpu-0": {Gpu: &GpuInfo{UUID: "GPU-fake-uuid"}}},
},
}
}

func TestEnumerateDevicesWithRetry(t *testing.T) {
tests := map[string]struct {
enumerator *fakeEnumerator
ctxFn func() (context.Context, context.CancelFunc)
timeout time.Duration
interval time.Duration
wantErr error
wantDeviceCount int
wantMinCalls int
}{
"devices found on first call": {
enumerator: &fakeEnumerator{results: []enumerateResult{{devices: oneDevice()}}},
wantDeviceCount: 1,
wantMinCalls: 1,
},
"devices found after retries": {
enumerator: &fakeEnumerator{results: []enumerateResult{
{devices: emptyDevices()},
{devices: emptyDevices()},
{devices: oneDevice()},
}},
interval: 1 * time.Millisecond,
wantDeviceCount: 1,
wantMinCalls: 3,
},
"error propagated immediately without retry": {
enumerator: &fakeEnumerator{results: []enumerateResult{{err: errors.New("nvml failed")}}},
wantErr: errors.New("nvml failed"),
wantMinCalls: 1,
},
"context cancelled returns context error": {
enumerator: &fakeEnumerator{results: []enumerateResult{{devices: emptyDevices()}}},
ctxFn: func() (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(context.Background())
cancel()
return ctx, cancel
},
interval: 1 * time.Millisecond,
wantErr: context.Canceled,
},
"timeout returns empty device set without error": {
enumerator: &fakeEnumerator{results: []enumerateResult{{devices: emptyDevices()}}},
timeout: 1 * time.Millisecond,
interval: 1 * time.Millisecond,
wantDeviceCount: 0,
},
}

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
// Override package-level timing variables for fast tests.
if tc.interval != 0 {
old := deviceEnumerationInterval
deviceEnumerationInterval = tc.interval
defer func() { deviceEnumerationInterval = old }()
}
if tc.timeout != 0 {
old := deviceEnumerationTimeout
deviceEnumerationTimeout = tc.timeout
defer func() { deviceEnumerationTimeout = old }()
}

ctx := context.Background()
if tc.ctxFn != nil {
var cancel context.CancelFunc
ctx, cancel = tc.ctxFn()
defer cancel()
}

got, err := enumerateDevicesWithRetry(ctx, tc.enumerator)

if tc.wantErr != nil {
require.ErrorContains(t, err, tc.wantErr.Error())
return
}
require.NoError(t, err)
assert.Len(t, got.allocatablesMap, tc.wantDeviceCount)
assert.GreaterOrEqual(t, tc.enumerator.calls, tc.wantMinCalls)
})
}
}
Loading