2 changes: 2 additions & 0 deletions pkg/scheduler/pool.go
@@ -29,6 +29,8 @@ const (
imagesVolumeName string = "beta9-images"
storageVolumeName string = "beta9-storage"
checkpointVolumeName string = "beta9-checkpoints"
devicePluginVolumeName string = "kubelet-device-plugins"
defaultDevicePluginPath string = "/var/lib/kubelet/device-plugins"
defaultContainerName string = "worker"
defaultWorkerEntrypoint string = "/usr/local/bin/worker"
defaultWorkerLogPath string = "/var/log/worker"
25 changes: 25 additions & 0 deletions pkg/scheduler/pool_external.go
@@ -477,6 +477,14 @@ func (wpc *ExternalWorkerPoolController) getWorkerEnvironment(workerId, machineI
Name: "GPU_COUNT",
Value: strconv.FormatInt(int64(gpuCount), 10),
},
{
Name: "POD_UID",
ValueFrom: &corev1.EnvVarSource{
FieldRef: &corev1.ObjectFieldSelector{
FieldPath: "metadata.uid",
},
},
},
{
Name: "POD_NAMESPACE",
Value: wpc.config.Worker.Namespace,
@@ -607,6 +615,17 @@ func (wpc *ExternalWorkerPoolController) getWorkerVolumes(workerMemory int64) []
})
}

hostPathDir := corev1.HostPathDirectory
volumes = append(volumes, corev1.Volume{
Name: devicePluginVolumeName,
VolumeSource: corev1.VolumeSource{
HostPath: &corev1.HostPathVolumeSource{
Path: defaultDevicePluginPath,
Type: &hostPathDir,
},
},
})

return volumes
}

@@ -633,6 +652,12 @@ func (wpc *ExternalWorkerPoolController) getWorkerVolumeMounts() []corev1.Volume
},
}

volumeMounts = append(volumeMounts, corev1.VolumeMount{
Name: devicePluginVolumeName,
MountPath: defaultDevicePluginPath,
ReadOnly: true,
})

if wpc.workerPoolConfig.CRIUEnabled && wpc.config.Worker.CRIU.Storage.Mode == string(types.CheckpointStorageModeLocal) {
volumeMounts = append(volumeMounts, corev1.VolumeMount{
Name: checkpointVolumeName,
25 changes: 25 additions & 0 deletions pkg/scheduler/pool_local.go
@@ -376,6 +376,17 @@ func (wpc *LocalKubernetesWorkerPoolController) getWorkerVolumes(workerMemory in
})
}

hostPathDir := corev1.HostPathDirectory
volumes = append(volumes, corev1.Volume{
Name: devicePluginVolumeName,
VolumeSource: corev1.VolumeSource{
HostPath: &corev1.HostPathVolumeSource{
Path: defaultDevicePluginPath,
Type: &hostPathDir,
},
},
})

return append(volumes,
corev1.Volume{
Name: imagesVolumeName,
@@ -407,6 +418,12 @@ func (wpc *LocalKubernetesWorkerPoolController) getWorkerVolumeMounts() []corev1
},
}

volumeMounts = append(volumeMounts, corev1.VolumeMount{
Name: devicePluginVolumeName,
MountPath: defaultDevicePluginPath,
ReadOnly: true,
})

if len(wpc.workerPoolConfig.JobSpec.VolumeMounts) > 0 {
volumeMounts = append(volumeMounts, wpc.workerPoolConfig.JobSpec.VolumeMounts...)
}
@@ -461,6 +478,14 @@ func (wpc *LocalKubernetesWorkerPoolController) getWorkerEnvironment(workerId st
Name: "GPU_COUNT",
Value: strconv.FormatInt(int64(gpuCount), 10),
},
{
Name: "POD_UID",
ValueFrom: &corev1.EnvVarSource{
FieldRef: &corev1.ObjectFieldSelector{
FieldPath: "metadata.uid",
},
},
},
{
Name: "POD_IP",
ValueFrom: &corev1.EnvVarSource{
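Taken together, the changes to both pool controllers hand each worker pod two new inputs: its own UID via the downward API (POD_UID from metadata.uid) and a read-only hostPath mount of the kubelet device-plugin directory. A minimal sketch of a startup check a worker could run to confirm both are in place — the check itself and its log messages are illustrative only, not part of this PR:

package main

import (
	"log"
	"os"
)

func main() {
	// POD_UID is injected by the pool controllers via the downward API (metadata.uid).
	if os.Getenv("POD_UID") == "" {
		log.Println("POD_UID not set; GPU resolution falls back to NVIDIA_VISIBLE_DEVICES")
	}

	// The kubelet device-plugin directory is mounted read-only into the worker pod.
	checkpoint := "/var/lib/kubelet/device-plugins/kubelet_internal_checkpoint"
	if _, err := os.Stat(checkpoint); err != nil {
		log.Printf("device plugin checkpoint not readable: %v", err)
	}
}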
49 changes: 41 additions & 8 deletions pkg/worker/gpu_info.go
@@ -2,6 +2,7 @@ package worker

import (
"bufio"
"encoding/json"
"errors"
"fmt"
"os"
@@ -24,22 +25,54 @@ type NvidiaInfoClient struct {
visibleDevices string
}

// resolveVisibleDevices gets the runtime-injected NVIDIA_VISIBLE_DEVICES by spawning
// a child process. The nvidia container runtime hook injects the correct per-worker
// GPU UUID into new processes, but PID 1 retains the base image default ("void").
// A child sh process receives the hook-injected value.
const defaultDeviceCheckpointPath = "/var/lib/kubelet/device-plugins/kubelet_internal_checkpoint"

type kubeletCheckpoint struct {
Data struct {
PodDeviceEntries []podDeviceEntry `json:"PodDeviceEntries"`
} `json:"Data"`
}

type podDeviceEntry struct {
PodUID string `json:"PodUID"`
ResourceName string `json:"ResourceName"`
DeviceIDs map[string][]string `json:"DeviceIDs"`
}

// resolveVisibleDevices determines which GPU is assigned to this worker pod.
//
// The nvidia/cuda base image sets ENV NVIDIA_VISIBLE_DEVICES=void which the
// container runtime processes AFTER PID 1 starts, so os.Getenv always returns
// "void". The authoritative GPU assignment lives in the kubelet device plugin
// checkpoint file, which maps pod UIDs to allocated GPU UUIDs.
var resolveVisibleDevices = func() string {
out, err := exec.Command("sh", "-c", "printenv NVIDIA_VISIBLE_DEVICES").Output()
podUID := os.Getenv("POD_UID")
if podUID == "" {
return os.Getenv("NVIDIA_VISIBLE_DEVICES")
}

data, err := os.ReadFile(defaultDeviceCheckpointPath)
if err != nil {
return os.Getenv("NVIDIA_VISIBLE_DEVICES")
}

resolved := strings.TrimSpace(string(out))
if resolved == "" || resolved == "void" {
var checkpoint kubeletCheckpoint
if err := json.Unmarshal(data, &checkpoint); err != nil {
return os.Getenv("NVIDIA_VISIBLE_DEVICES")
}

return resolved
for _, entry := range checkpoint.Data.PodDeviceEntries {
if entry.PodUID != podUID || entry.ResourceName != "nvidia.com/gpu" {
continue
}
for _, uuids := range entry.DeviceIDs {
if len(uuids) > 0 {
return strings.Join(uuids, ",")
}
}
}

return os.Getenv("NVIDIA_VISIBLE_DEVICES")
}

func (c *NvidiaInfoClient) hexToPaddedString(hexStr string) (string, error) {
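For context, the kubelet device manager checkpoint that resolveVisibleDevices now parses is a JSON document whose PodDeviceEntries map pod UIDs and resource names to allocated device IDs. A self-contained sketch of the same lookup against a hypothetical payload (real checkpoint files also carry RegisteredDevices and a Checksum, which the worker ignores):

package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

// Trimmed mirror of the types added in gpu_info.go.
type checkpoint struct {
	Data struct {
		PodDeviceEntries []struct {
			PodUID       string              `json:"PodUID"`
			ResourceName string              `json:"ResourceName"`
			DeviceIDs    map[string][]string `json:"DeviceIDs"`
		} `json:"PodDeviceEntries"`
	} `json:"Data"`
}

func main() {
	// Hypothetical checkpoint contents for two GPU pods sharing a node.
	sample := `{"Data":{"PodDeviceEntries":[
		{"PodUID":"pod-a","ResourceName":"nvidia.com/gpu","DeviceIDs":{"0":["GPU-aaaa"]}},
		{"PodUID":"pod-b","ResourceName":"nvidia.com/gpu","DeviceIDs":{"0":["GPU-bbbb"]}}]}}`

	var cp checkpoint
	if err := json.Unmarshal([]byte(sample), &cp); err != nil {
		panic(err)
	}

	// Same lookup the worker performs: match this pod's UID and the GPU resource,
	// then join whatever UUIDs were allocated to it.
	for _, entry := range cp.Data.PodDeviceEntries {
		if entry.PodUID == "pod-b" && entry.ResourceName == "nvidia.com/gpu" {
			for _, uuids := range entry.DeviceIDs {
				fmt.Println(strings.Join(uuids, ",")) // prints GPU-bbbb
			}
		}
	}
}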
78 changes: 78 additions & 0 deletions pkg/worker/gpu_info_test.go
@@ -1,7 +1,10 @@
package worker

import (
"encoding/json"
"errors"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
@@ -186,3 +189,78 @@ func TestAvailableGPUDevicesSingleGPUUUID(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, []int{7}, devices)
}

func writeCheckpointFile(t *testing.T, dir string, entries []podDeviceEntry) string {
t.Helper()
checkpoint := kubeletCheckpoint{}
checkpoint.Data.PodDeviceEntries = entries
data, err := json.Marshal(checkpoint)
assert.NoError(t, err)
path := filepath.Join(dir, "kubelet_internal_checkpoint")
assert.NoError(t, os.WriteFile(path, data, 0644))
return path
}

func TestResolveVisibleDevicesFromCheckpoint(t *testing.T) {
origResolve := resolveVisibleDevices
defer func() { resolveVisibleDevices = origResolve }()

tmpDir := t.TempDir()
checkpointPath := writeCheckpointFile(t, tmpDir, []podDeviceEntry{
{
PodUID: "test-pod-uid-1",
ResourceName: "nvidia.com/gpu",
DeviceIDs: map[string][]string{"0": {"GPU-aaaa-bbbb-cccc"}},
},
{
PodUID: "test-pod-uid-2",
ResourceName: "nvidia.com/gpu",
DeviceIDs: map[string][]string{"1": {"GPU-dddd-eeee-ffff"}},
},
})

resolveVisibleDevices = func() string {
podUID := "test-pod-uid-1"
data, err := os.ReadFile(checkpointPath)
if err != nil {
return "fallback"
}
var cp kubeletCheckpoint
if err := json.Unmarshal(data, &cp); err != nil {
return "fallback"
}
for _, entry := range cp.Data.PodDeviceEntries {
if entry.PodUID != podUID || entry.ResourceName != "nvidia.com/gpu" {
continue
}
for _, uuids := range entry.DeviceIDs {
if len(uuids) > 0 {
return uuids[0]
}
}
}
return "fallback"
}

result := resolveVisibleDevices()
assert.Equal(t, "GPU-aaaa-bbbb-cccc", result)
}

func TestResolveVisibleDevicesFallsBackWithoutPodUID(t *testing.T) {
origResolve := resolveVisibleDevices
defer func() { resolveVisibleDevices = origResolve }()

os.Setenv("NVIDIA_VISIBLE_DEVICES", "all")
defer os.Unsetenv("NVIDIA_VISIBLE_DEVICES")

resolveVisibleDevices = func() string {
podUID := ""
if podUID == "" {
return os.Getenv("NVIDIA_VISIBLE_DEVICES")
}
return "should-not-reach"
}

result := resolveVisibleDevices()
assert.Equal(t, "all", result)
}
24 changes: 7 additions & 17 deletions pkg/worker/gpu_integration_test.go
@@ -15,7 +15,7 @@ func TestIntegrationGPUIsolation(t *testing.T) {
t.Skip("set GPU_INTEGRATION=1 to run on a real GPU node")
}

// Step 1: Read PID 1's env (what os.Getenv sees — the broken path)
// Step 1: Confirm PID 1 has void (the bug condition)
pid1Env := "(unknown)"
data, err := os.ReadFile("/proc/1/environ")
if err == nil {
@@ -28,18 +28,18 @@
}
t.Logf("PID 1 NVIDIA_VISIBLE_DEVICES = %q", pid1Env)

// Step 2: Call the REAL resolveVisibleDevices() from gpu_info.go
// Step 2: Call the REAL resolveVisibleDevices() — reads from kubelet checkpoint
resolved := resolveVisibleDevices()
t.Logf("resolveVisibleDevices() = %q", resolved)

if resolved == "void" || resolved == "" {
t.Fatalf("resolveVisibleDevices() returned %q — void bug NOT fixed", resolved)
t.Fatalf("resolveVisibleDevices() returned %q — checkpoint resolution failed", resolved)
}
if !strings.HasPrefix(resolved, "GPU-") {
t.Fatalf("resolveVisibleDevices() returned %q — expected GPU UUID", resolved)
}

// Step 3: Create the REAL NvidiaInfoClient with the resolved value (same as NewContainerNvidiaManager)
// Step 3: Create the REAL NvidiaInfoClient with the resolved value
client := &NvidiaInfoClient{visibleDevices: resolved}

// Step 4: Call the REAL AvailableGPUDevices()
@@ -58,17 +58,10 @@

// Step 5: Verify the OLD path (void) would have failed
oldClient := &NvidiaInfoClient{visibleDevices: pid1Env}
oldDevices, err := oldClient.AvailableGPUDevices()
if err != nil {
t.Logf("Old path error (expected): %v", err)
}
oldDevices, _ := oldClient.AvailableGPUDevices()
t.Logf("Old path (PID 1 env=%q) -> AvailableGPUDevices() = %v", pid1Env, oldDevices)

if pid1Env == "void" && len(oldDevices) > 0 {
t.Error("Old code path with void should return empty, but got devices — test logic wrong")
}

// Step 6: Exercise the REAL ContainerNvidiaManager.AssignGPUDevices (chooseDevices)
// Step 6: Exercise the REAL ContainerNvidiaManager.AssignGPUDevices
manager := &ContainerNvidiaManager{
gpuAllocationMap: common.NewSafeMap[[]int](),
gpuCount: 1,
@@ -83,14 +76,11 @@
}
t.Logf("AssignGPUDevices(\"test-container-1\", 1) = %v", assigned)

if len(assigned) != 1 {
t.Fatalf("Expected 1 assigned GPU, got %d", len(assigned))
}
if assigned[0] != devices[0] {
t.Fatalf("Assigned GPU %d doesn't match available GPU %d", assigned[0], devices[0])
}

// Step 7: Verify second allocation to same worker FAILS (only 1 GPU available)
// Step 7: Verify second allocation FAILS (only 1 GPU per worker)
_, err = manager.AssignGPUDevices("test-container-2", 1)
if err == nil {
t.Fatal("Second allocation should fail — only 1 GPU per worker")
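Since the integration test is gated on an environment variable, on a real GPU node it would presumably be invoked along the lines of GPU_INTEGRATION=1 go test ./pkg/worker/ -run TestIntegrationGPUIsolation -v; the package path is inferred from the file's location under pkg/worker.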