Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cmd/compute-domain-kubelet-plugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ func (c Config) DriverPluginPath() string {
}

func main() {
if err := common.MaskNvidiaDriverParams(); err != nil {
fmt.Fprintf(os.Stderr, "Error masking NVIDIA driver params: %v\n", err)
}

if err := newApp().Run(os.Args); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
Expand Down
15 changes: 15 additions & 0 deletions cmd/gpu-kubelet-plugin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,21 @@ func (c Config) DriverPluginPath() string {
}

func main() {
if err := common.MaskNvidiaDriverParams(); err != nil {
fmt.Fprintf(os.Stderr, "Error masking NVIDIA driver params: %v\n", err)
}

if len(os.Args) > 1 && os.Args[1] == "prestart" {
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM)
defer cancel()

if err := runPrestart(ctx); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
os.Exit(0)
}

if err := newApp().Run(os.Args); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
Expand Down
226 changes: 226 additions & 0 deletions cmd/gpu-kubelet-plugin/prestart.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
/*
* Copyright (c) 2026 NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package main

import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
"time"

"k8s.io/klog/v2"
)

// Main intent: help users to self-troubleshoot when the GPU driver is not set up
// properly before installing this DRA driver. In that case, the log of the init
// container running the prestart code is meant to yield an actionable error message.
// The check is retried in-process at a fixed interval until it succeeds or the
// process receives SIGTERM.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the comment might be out of date

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rely on k8s to implement a high-level retry with back-off

This bit does seem to be out of date given the forever loop. That's true for the shell script too; I can modify both to reflect that.

Do you have other modifications in mind? The other bits seem fine to me.

// runPrestart prepares the in-container /driver-root symlink and then
// repeatedly validates the GPU driver installation until validation succeeds
// or ctx is cancelled. It returns nil in both cases.
//
// The host driver root is read from the NVIDIA_DRIVER_ROOT environment
// variable; an unset or empty value is normalized to "/".
func runPrestart(ctx context.Context) error {
	const (
		retryInterval = 10 * time.Second
		parentDir     = "/driver-root-parent"
		linkPath      = "/driver-root"
	)

	hostDriverRoot := os.Getenv("NVIDIA_DRIVER_ROOT")
	if hostDriverRoot == "" {
		// Unset and empty cannot be told apart here; both mean "/"
		// (which is how the value is treated elsewhere).
		hostDriverRoot = "/"
	}

	// filepath.Base drops a trailing slash (if any) and yields the last path
	// element, which is then looked up below the parent directory.
	linkTarget := filepath.Join(parentDir, filepath.Base(hostDriverRoot))

	// Make sure the parent directory exists; a failure is only logged.
	if mkErr := os.MkdirAll(parentDir, 0755); mkErr != nil {
		klog.Warningf("Failed to create %s: %v", parentDir, mkErr)
	}

	// /driver-root is created as a symlink. The link may be dangling at first
	// (e.g., if the GPU operator isn't deployed yet). It heals once the driver
	// becomes mounted (e.g., once the GPU operator provides the driver on the
	// host at /run/nvidia/driver).
	fmt.Printf("create symlink: /driver-root -> %s\n", linkTarget)
	if linkErr := os.Symlink(linkTarget, linkPath); linkErr != nil {
		klog.Warningf("Failed to create symlink: %v", linkErr)
	}

	for attempt := 0; ; attempt++ {
		if validateAndExitOnSuccess(ctx, hostDriverRoot, attempt) {
			return nil
		}

		select {
		case <-time.After(retryInterval):
			// Fall through to the next validation attempt.
		case <-ctx.Done():
			// DS pods may get deleted (terminated with SIGTERM) and re-created
			// when the GPU Operator driver container creates a mount at
			// /run/nvidia. Make that explicit in the log.
			fmt.Printf("%s: received SIGTERM\n", time.Now().UTC().Format("2006-01-02T15:04:05.000Z"))
			return nil
		}
	}
}

// emitCommonErr prints the generic "check failed" message to stdout. The
// message names the configured NVIDIA_DRIVER_ROOT value (nvidiaDriverRoot) and
// points the user at the 'nvidiaDriverRoot' Helm chart variable.
func emitCommonErr(nvidiaDriverRoot string) {
	const format = "Check failed. Has the NVIDIA GPU driver been set up? " +
		"It is expected to be installed under " +
		"NVIDIA_DRIVER_ROOT (currently set to '%s') " +
		"in the host filesystem. If that path appears to be unexpected: " +
		"review the DRA driver's 'nvidiaDriverRoot' Helm chart variable. " +
		"Otherwise, review if the GPU driver has " +
		"actually been installed under that path.\n"

	fmt.Printf(format, nvidiaDriverRoot)
}

// validateAndExitOnSuccess performs one validation pass of the GPU driver
// installation visible under /driver-root and reports whether it succeeded.
//
// Despite its name, it never exits the process: it returns true when
// nvidia-smi was found, started with LD_PRELOAD pointing at the discovered
// libnvidia-ml.so.1, and exited with code 0. In every other case it returns
// false so that the caller can retry.
//
// nvidiaDriverRoot is the host-side driver root and is used for log messages
// and hints only. attempt is the zero-based retry counter; the verbose hints
// are emitted only when attempt is a multiple of 6 to reduce log volume.
func validateAndExitOnSuccess(ctx context.Context, nvidiaDriverRoot string, attempt int) bool {
	fmt.Printf("%s /driver-root (%s on host): ", time.Now().UTC().Format("2006-01-02T15:04:05Z"), nvidiaDriverRoot)

	// Search a specific set of directories (not recursively: not required, and
	// /driver-root may be a big tree). Limit to the first result (multiple
	// results are a bit of a pathological state, but continue with validation
	// logic).
	//
	// The original script does not follow symlinks for nvidia-smi, but a
	// symlink can be executed just as well, so findFirstFile is reused here
	// instead of adding a largely duplicative helper.
	nvPath := findFirstFile(
		"nvidia-smi",
		"/driver-root/opt/bin",
		"/driver-root/usr/bin",
		"/driver-root/usr/sbin",
		"/driver-root/bin",
		"/driver-root/sbin",
	)

	nvLibPath := findFirstFile(
		"libnvidia-ml.so.1",
		"/driver-root/usr/lib64",
		"/driver-root/usr/lib/x86_64-linux-gnu",
		"/driver-root/usr/lib/aarch64-linux-gnu",
		"/driver-root/lib64",
		"/driver-root/lib/x86_64-linux-gnu",
		"/driver-root/lib/aarch64-linux-gnu",
	)

	if nvPath == "" {
		fmt.Printf("nvidia-smi: not found, ")
	} else {
		fmt.Printf("nvidia-smi: '%s', ", nvPath)
	}

	if nvLibPath == "" {
		fmt.Printf("libnvidia-ml.so.1: not found, ")
	} else {
		fmt.Printf("libnvidia-ml.so.1: '%s', ", nvLibPath)
	}

	// Log top-level entries in /driver-root (this may be valuable debug info).
	// The read error is kept: an unreadable or dangling /driver-root must not
	// be reported as an empty directory.
	entries, readErr := os.ReadDir("/driver-root")
	var entryNames string
	for i, e := range entries {
		if i > 0 {
			entryNames += " "
		}
		entryNames += e.Name()
	}
	if readErr != nil {
		fmt.Printf("current contents: [%s] (cannot read /driver-root: %v).\n", entryNames, readErr)
	} else {
		fmt.Printf("current contents: [%s].\n", entryNames)
	}

	if nvPath != "" && nvLibPath != "" {
		if prestartRunNvidiaSmi(ctx, nvPath, nvLibPath) {
			return true
		}
		// If the context was cancelled, nvidia-smi was most likely killed for
		// that reason; troubleshooting hints would be misleading then.
		if ctx.Err() != nil {
			return false
		}
	}

	// Reduce log volume: log hints only every Nth attempt.
	if attempt%6 != 0 {
		return false
	}

	prestartEmitHints(nvidiaDriverRoot, nvPath, nvLibPath, len(entries), readErr)
	return false
}

// prestartRunNvidiaSmi runs the nvidia-smi binary at nvPath with an
// environment consisting only of LD_PRELOAD=nvLibPath, forwarding its output
// to stdout/stderr. It returns true only if the process exited with code 0.
func prestartRunNvidiaSmi(ctx context.Context, nvPath, nvLibPath string) bool {
	// Run with clean environment (only LD_PRELOAD; nvidia-smi has only this
	// dependency). Emit message before invocation (nvidia-smi may be slow or
	// hang).
	fmt.Printf("invoke: env -i LD_PRELOAD=%s %s\n", nvLibPath, nvPath)

	cmd := exec.CommandContext(ctx, nvPath)
	// Override the default (inherited) environment with just LD_PRELOAD.
	cmd.Env = []string{"LD_PRELOAD=" + nvLibPath}
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr

	err := cmd.Run()
	// For checking GPU driver health: rely on nvidia-smi's exit code. Rely on
	// code 0 signaling that the driver is properly set up. See section
	// 'RETURN VALUE' in the nvidia-smi man page for meaning of error codes.
	if err == nil {
		fmt.Printf("nvidia-smi returned with code 0: success, leave\n")
		return true
	}

	if exitErr, ok := err.(*exec.ExitError); ok {
		fmt.Printf("exit code: %d\n", exitErr.ExitCode())
	} else {
		// nvidia-smi failed to start, e.g. permission denied.
		fmt.Printf("execution failed: %v, exit code: -1\n", err)
	}
	return false
}

// prestartEmitHints prints the generic error message followed by actionable
// hints for common misconfigurations. entryCount is the number of top-level
// entries found in /driver-root and readErr the error (if any) from listing it.
func prestartEmitHints(nvidiaDriverRoot, nvPath, nvLibPath string, entryCount int, readErr error) {
	// nvidia-smi binaries not found, or execution failed. First, provide a
	// generic error message. Then, try to provide actionable hints for common
	// problems.
	fmt.Println()
	emitCommonErr(nvidiaDriverRoot)

	// For a host-provided driver not at / provide feedback for special cases.
	if nvidiaDriverRoot != "/" {
		switch {
		case readErr != nil:
			fmt.Printf("Hint: /driver-root (%s on the host) cannot be read: %v\n", nvidiaDriverRoot, readErr)
		case entryCount == 0:
			fmt.Printf("Hint: Directory %s on the host is empty\n", nvidiaDriverRoot)
		case nvPath == "" || nvLibPath == "":
			// Not empty, but at least one of the binaries not found: this is
			// a rather pathological state.
			fmt.Printf("Hint: Directory %s is not empty but at least one of the binaries wasn't found.\n", nvidiaDriverRoot)
		}
	}

	// Common mistake: driver container, but forgot `--set nvidiaDriverRoot`.
	if nvidiaDriverRoot == "/" {
		if _, err := os.Stat("/driver-root/run/nvidia/driver/usr/bin/nvidia-smi"); err == nil {
			fmt.Printf("Hint: '/run/nvidia/driver/usr/bin/nvidia-smi' exists on the host, you " +
				"may want to re-install the DRA driver Helm chart with " +
				"--set nvidiaDriverRoot=/run/nvidia/driver\n")
		}
	}

	if nvidiaDriverRoot == "/run/nvidia/driver" {
		fmt.Printf("Hint: NVIDIA_DRIVER_ROOT is set to '/run/nvidia/driver' " +
			"which typically means that the NVIDIA GPU Operator " +
			"manages the GPU driver. Make sure that the GPU Operator " +
			"is deployed and healthy.\n")
	}
	fmt.Println()
}

// findFirstFile looks for filename directly inside each of dirs (no
// recursion), in the order given, and returns the full path of the first
// match that is not a directory. Symlinks are followed (similar to find -L),
// so a dangling link does not count as a match. It returns "" when nothing
// matches.
func findFirstFile(filename string, dirs ...string) string {
	for i := range dirs {
		candidate := filepath.Join(dirs[i], filename)

		st, statErr := os.Stat(candidate)
		if statErr != nil || st.IsDir() {
			continue
		}
		return candidate
	}
	return ""
}
108 changes: 89 additions & 19 deletions cmd/gpu-kubelet-plugin/vfio-device.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"os"
"os/exec"
"path/filepath"
"strings"
"time"

"k8s.io/klog/v2"
Expand All @@ -31,19 +32,17 @@ import (
)

// Paths, driver names and timing parameters used by the VFIO PCI manager.
// The paths of the former bind/unbind helper scripts are no longer declared:
// binding and unbinding are done by writing to sysfs directly.
const (
	kernelIommuGroupPath = "/sys/kernel/iommu_groups"
	vfioPciModule        = "vfio_pci"
	vfioPciDriver        = "vfio-pci"
	nvidiaDriver         = "nvidia"
	hostRoot             = "/host-root"
	sysModulesRoot       = "/sys/module"
	pciDevicesRoot       = "/sys/bus/pci/devices"
	vfioDevicesRoot      = "/dev/vfio"
	// NOTE(review): kept as a string; presumably passed on as a textual
	// argument somewhere outside this view — confirm against its users.
	driverResetRetries   = "5"
	gpuFreeCheckInterval = 1 * time.Second
	gpuFreeCheckTimeout  = 60 * time.Second
)

type VfioPciManager struct {
Expand Down Expand Up @@ -248,20 +247,91 @@ func (vm *VfioPciManager) changeDriver(pciAddress, driver string) error {
return nil
}

// acquireUnbindLock tries to take the NVIDIA driver's per-GPU unbind lock by
// writing "1" to /proc/driver/nvidia/gpus/<gpu>/unbindLock and reading the
// value back to verify it.
//
// gpu is the identifier used as the directory name below
// /proc/driver/nvidia/gpus (unbindFromDriver passes the PCI address).
//
// It returns nil if the lock was acquired, or if the lock file cannot be
// stat'ed (in which case the lock is assumed not to be needed). It returns an
// error after all attempts have failed. Between attempts it waits for an
// increasing number of seconds (1s, 2s, ...).
func (vm *VfioPciManager) acquireUnbindLock(gpu string) error {
	const lockRetries = 5
	unbindLockFile := filepath.Join("/proc/driver/nvidia/gpus", gpu, "unbindLock")

	if _, err := os.Stat(unbindLockFile); err != nil {
		// If the lock file doesn't exist, we assume no lock is needed.
		return nil
	}

	for attempt := 1; attempt <= lockRetries; attempt++ {
		klog.Infof("[retry %d/%d] Attempting to acquire unbindLock for %s", attempt, lockRetries, gpu)

		// Try to write 1 to acquire the lock. A write failure is only logged:
		// the read-back below decides whether the lock is held.
		if err := os.WriteFile(unbindLockFile, []byte("1\n"), 0644); err != nil {
			klog.Warningf("failed to write to unbindLock file %s: %v", unbindLockFile, err)
		}

		// Read the lock file to verify.
		content, err := os.ReadFile(unbindLockFile)
		if err != nil {
			klog.Warningf("failed to read unbindLock file %s: %v", unbindLockFile, err)
		} else if strings.TrimSpace(string(content)) == "1" {
			klog.Infof("UnbindLock acquired for %s", gpu)
			return nil
		}

		// Back off before the next attempt; no point in waiting once the
		// last attempt has failed.
		if attempt < lockRetries {
			time.Sleep(time.Duration(attempt) * time.Second)
		}
	}

	return fmt.Errorf("cannot obtain unbindLock for %s", gpu)
}

// unbindFromDriver detaches the PCI device at pciAddress from the driver it
// is currently bound to by writing the address to that driver's sysfs
// "unbind" file.
//
// It returns nil without doing anything if the device's "driver" entry cannot
// be stat'ed (the device is treated as not bound to any driver). If the
// current driver is "nvidia", the unbind lock is acquired first and a failure
// to do so is returned.
func (vm *VfioPciManager) unbindFromDriver(pciAddress string) error {
	driverPath := filepath.Join(pciDevicesRoot, pciAddress, "driver")
	if _, err := os.Stat(driverPath); err != nil {
		// Not bound to any driver.
		return nil
	}

	// The "driver" entry is a symlink; resolve it to find the driver's own
	// sysfs directory (its base name is the driver name).
	existingDriver, err := filepath.EvalSymlinks(driverPath)
	if err != nil {
		return fmt.Errorf("failed to resolve driver symlink for %s: %w", pciAddress, err)
	}

	if filepath.Base(existingDriver) == "nvidia" {
		if err := vm.acquireUnbindLock(pciAddress); err != nil {
			return err
		}
	}

	unbindFile := filepath.Join(existingDriver, "unbind")
	if err := os.WriteFile(unbindFile, []byte(pciAddress+"\n"), 0644); err != nil {
		klog.Errorf("Attempting to unbind %s from its driver failed; err: %v", pciAddress, err)
		return err
	}
	return nil
}

func (vm *VfioPciManager) bindToDriver(pciAddress, driver string) error {
out, err := execCommand(bindToDriverScript, []string{pciAddress, driver}) //nolint:gosec
if err != nil {
klog.Errorf("Attempting to bind %s to %s driver failed; stdout: %s, err: %v", pciAddress, driver, string(out), err)
return err
driversPath := "/sys/bus/pci/drivers"
driverOverrideFile := filepath.Join(pciDevicesRoot, pciAddress, "driver_override")
bindFile := filepath.Join(driversPath, driver, "bind")

if _, err := os.Stat(driverOverrideFile); err != nil {
klog.Errorf("'%s' file does not exist", driverOverrideFile)
return fmt.Errorf("driver_override file not found: %v", err)
}

if err := os.WriteFile(driverOverrideFile, []byte(driver+"\n"), 0644); err != nil {
klog.Errorf("failed to write '%s' to %s", driver, driverOverrideFile)
return fmt.Errorf("failed to write to driver_override: %v", err)
}

if _, err := os.Stat(bindFile); err != nil {
klog.Errorf("'%s' file does not exist", bindFile)
return fmt.Errorf("bind file not found: %v", err)
}

if err := os.WriteFile(bindFile, []byte(pciAddress+"\n"), 0644); err != nil {
klog.Errorf("failed to write %s to %s; err: %v", pciAddress, bindFile, err)
// attempt to revert driver_override
_ = os.WriteFile(driverOverrideFile, []byte("\n"), 0644)
return fmt.Errorf("failed to write to bind file: %v", err)
}
return nil
}
Expand Down
Loading