Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 0 additions & 28 deletions pkg/worker/gpu_info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package worker

import (
"errors"
"os"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -175,33 +174,6 @@ func TestAvailableGPUDevicesReturnsQueryErrors(t *testing.T) {
assert.Nil(t, devices)
}

// TestResolveVisibleDevicesUsesChildProcess verifies that the package-level
// resolveVisibleDevices hook invokes whatever function it currently points at:
// a stubbed resolver's value is returned verbatim.
func TestResolveVisibleDevicesUsesChildProcess(t *testing.T) {
	const wantUUID = "GPU-04612b44-abcd-1234-5678-aabbccddeeff"

	// Swap in a stub and restore the real resolver when the test ends.
	saved := resolveVisibleDevices
	defer func() { resolveVisibleDevices = saved }()

	resolveVisibleDevices = func() string { return wantUUID }

	assert.Equal(t, wantUUID, resolveVisibleDevices())
}

// TestResolveVisibleDevicesFallsBackToEnv verifies that when the resolver is
// stubbed to read the environment, it returns the current value of
// NVIDIA_VISIBLE_DEVICES.
func TestResolveVisibleDevicesFallsBackToEnv(t *testing.T) {
	origResolve := resolveVisibleDevices
	defer func() { resolveVisibleDevices = origResolve }()

	// t.Setenv restores any pre-existing NVIDIA_VISIBLE_DEVICES value after
	// the test; the previous os.Setenv + deferred os.Unsetenv pair would
	// clobber a value set by the surrounding environment.
	t.Setenv("NVIDIA_VISIBLE_DEVICES", "GPU-fallback-uuid")

	resolveVisibleDevices = func() string {
		return os.Getenv("NVIDIA_VISIBLE_DEVICES")
	}

	result := resolveVisibleDevices()
	assert.Equal(t, "GPU-fallback-uuid", result)
}

func TestAvailableGPUDevicesSingleGPUUUID(t *testing.T) {
cleanup := withMockDevices(eightGPUOutput, true)
defer cleanup()
Expand Down
102 changes: 102 additions & 0 deletions pkg/worker/gpu_integration_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package worker

import (
"fmt"
"os"
"strings"
"testing"

common "github.com/beam-cloud/beta9/pkg/common"
"gvisor.dev/gvisor/pkg/sync"
)

// TestIntegrationGPUIsolation exercises the real GPU-resolution and
// allocation path end to end. It is gated behind GPU_INTEGRATION=1 so it only
// runs on a node with actual NVIDIA hardware.
func TestIntegrationGPUIsolation(t *testing.T) {
	if os.Getenv("GPU_INTEGRATION") != "1" {
		t.Skip("set GPU_INTEGRATION=1 to run on a real GPU node")
	}

	// Step 1: Read PID 1's env (what os.Getenv sees — the broken path).
	readPid1VisibleDevices := func() string {
		raw, readErr := os.ReadFile("/proc/1/environ")
		if readErr != nil {
			return "(unknown)"
		}
		for _, kv := range strings.Split(string(raw), "\x00") {
			if strings.HasPrefix(kv, "NVIDIA_VISIBLE_DEVICES=") {
				return strings.TrimPrefix(kv, "NVIDIA_VISIBLE_DEVICES=")
			}
		}
		return "(unknown)"
	}
	pid1Env := readPid1VisibleDevices()
	t.Logf("PID 1 NVIDIA_VISIBLE_DEVICES = %q", pid1Env)

	// Step 2: Call the REAL resolveVisibleDevices() from gpu_info.go.
	resolvedUUID := resolveVisibleDevices()
	t.Logf("resolveVisibleDevices() = %q", resolvedUUID)

	switch {
	case resolvedUUID == "void" || resolvedUUID == "":
		t.Fatalf("resolveVisibleDevices() returned %q — void bug NOT fixed", resolvedUUID)
	case !strings.HasPrefix(resolvedUUID, "GPU-"):
		t.Fatalf("resolveVisibleDevices() returned %q — expected GPU UUID", resolvedUUID)
	}

	// Step 3: Create the REAL NvidiaInfoClient with the resolved value
	// (same construction as NewContainerNvidiaManager).
	infoClient := &NvidiaInfoClient{visibleDevices: resolvedUUID}

	// Step 4: Call the REAL AvailableGPUDevices().
	availableGPUs, err := infoClient.AvailableGPUDevices()
	if err != nil {
		t.Fatalf("AvailableGPUDevices() error: %v", err)
	}
	t.Logf("AvailableGPUDevices() = %v", availableGPUs)

	switch len(availableGPUs) {
	case 0:
		t.Fatal("AvailableGPUDevices() returned empty — GPU not visible")
	case 1:
		// Exactly one GPU visible: per-worker isolation is working.
	default:
		t.Fatalf("AvailableGPUDevices() returned %d GPUs, expected exactly 1 for per-worker isolation", len(availableGPUs))
	}

	// Step 5: Verify the OLD path (void) would have failed.
	legacyClient := &NvidiaInfoClient{visibleDevices: pid1Env}
	legacyDevices, legacyErr := legacyClient.AvailableGPUDevices()
	if legacyErr != nil {
		t.Logf("Old path error (expected): %v", legacyErr)
	}
	t.Logf("Old path (PID 1 env=%q) -> AvailableGPUDevices() = %v", pid1Env, legacyDevices)

	if pid1Env == "void" && len(legacyDevices) > 0 {
		t.Error("Old code path with void should return empty, but got devices — test logic wrong")
	}

	// Step 6: Exercise the REAL ContainerNvidiaManager.AssignGPUDevices (chooseDevices).
	manager := &ContainerNvidiaManager{
		gpuAllocationMap:       common.NewSafeMap[[]int](),
		gpuCount:               1,
		mu:                     sync.Mutex{},
		infoClient:             infoClient,
		resolvedVisibleDevices: resolvedUUID,
	}

	assignedGPUs, err := manager.AssignGPUDevices("test-container-1", 1)
	if err != nil {
		t.Fatalf("AssignGPUDevices() failed: %v", err)
	}
	t.Logf("AssignGPUDevices(\"test-container-1\", 1) = %v", assignedGPUs)

	if len(assignedGPUs) != 1 {
		t.Fatalf("Expected 1 assigned GPU, got %d", len(assignedGPUs))
	}
	if assignedGPUs[0] != availableGPUs[0] {
		t.Fatalf("Assigned GPU %d doesn't match available GPU %d", assignedGPUs[0], availableGPUs[0])
	}

	// Step 7: Verify second allocation to the same worker FAILS (only 1 GPU available).
	if _, err = manager.AssignGPUDevices("test-container-2", 1); err == nil {
		t.Fatal("Second allocation should fail — only 1 GPU per worker")
	}
	t.Logf("Second allocation correctly failed: %v", err)

	fmt.Printf("\nRESULT: resolved=%s gpu_index=%d old_path_would_fail=%v PASS\n",
		resolvedUUID, assignedGPUs[0], pid1Env == "void" && len(legacyDevices) == 0)
}
24 changes: 13 additions & 11 deletions pkg/worker/nvidia.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@ type GPUManager interface {
}

type ContainerNvidiaManager struct {
gpuAllocationMap *common.SafeMap[[]int]
gpuCount uint32
mu sync.Mutex
statFunc func(path string, stat *syscall.Stat_t) (err error)
infoClient GPUInfoClient
gpuAllocationMap *common.SafeMap[[]int]
gpuCount uint32
mu sync.Mutex
statFunc func(path string, stat *syscall.Stat_t) (err error)
infoClient GPUInfoClient
resolvedVisibleDevices string
}

func NewContainerNvidiaManager(gpuCount uint32) GPUManager {
Expand All @@ -51,11 +52,12 @@ func NewContainerNvidiaManager(gpuCount uint32) GPUManager {
log.Info().Str("resolved_visible_devices", visibleDevices).Msg("resolved NVIDIA_VISIBLE_DEVICES for GPU filtering")

return &ContainerNvidiaManager{
gpuAllocationMap: common.NewSafeMap[[]int](),
gpuCount: gpuCount,
mu: sync.Mutex{},
statFunc: syscall.Stat,
infoClient: &NvidiaInfoClient{visibleDevices: visibleDevices},
gpuAllocationMap: common.NewSafeMap[[]int](),
gpuCount: gpuCount,
mu: sync.Mutex{},
statFunc: syscall.Stat,
infoClient: &NvidiaInfoClient{visibleDevices: visibleDevices},
resolvedVisibleDevices: visibleDevices,
}
}

Expand Down Expand Up @@ -114,7 +116,7 @@ func (c *ContainerNvidiaManager) chooseDevices(containerId string, requestedGpuC
// Check if we managed to allocate the requested number of GPUs
if len(allocableDevices) < int(requestedGpuCount) {
return nil, fmt.Errorf("not enough GPUs available: requested=%d, allocable=%d, visible=%d, configured=%d, already_allocated=%d, NVIDIA_VISIBLE_DEVICES=%q",
requestedGpuCount, len(allocableDevices), len(availableDevices), c.gpuCount, len(currentAllocations), c.infoClient.(*NvidiaInfoClient).visibleDevices)
requestedGpuCount, len(allocableDevices), len(availableDevices), c.gpuCount, len(currentAllocations), c.resolvedVisibleDevices)
}

// Allocate the requested number of GPUs
Expand Down
Loading