Skip to content

Commit bcd7bb5

Browse files
committed
remove bash requirements from final image
rewrite bash scripts with golang Signed-off-by: Léiyì Zhang <leiyiz@google.com>
1 parent 8ad5c2b commit bcd7bb5

File tree

7 files changed

+397
-78
lines changed

7 files changed

+397
-78
lines changed

cmd/compute-domain-kubelet-plugin/main.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ func (c Config) DriverPluginPath() string {
6969
}
7070

7171
func main() {
72+
if err := common.MaskNvidiaDriverParams(); err != nil {
73+
fmt.Fprintf(os.Stderr, "Error masking NVIDIA driver params: %v\n", err)
74+
}
75+
7276
if err := newApp().Run(os.Args); err != nil {
7377
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
7478
os.Exit(1)

cmd/gpu-kubelet-plugin/main.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,21 @@ func (c Config) DriverPluginPath() string {
6969
}
7070

7171
func main() {
72+
if err := common.MaskNvidiaDriverParams(); err != nil {
73+
fmt.Fprintf(os.Stderr, "Error masking NVIDIA driver params: %v\n", err)
74+
}
75+
76+
if len(os.Args) > 1 && os.Args[1] == "prestart-init" {
77+
ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGTERM)
78+
defer cancel()
79+
80+
if err := runPrestartInit(ctx); err != nil {
81+
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
82+
os.Exit(1)
83+
}
84+
os.Exit(0)
85+
}
86+
7287
if err := newApp().Run(os.Args); err != nil {
7388
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
7489
os.Exit(1)

cmd/gpu-kubelet-plugin/prestart.go

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
/*
2+
* Copyright (c) 2026 NVIDIA CORPORATION. All rights reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package main
18+
19+
import (
20+
"context"
21+
"fmt"
22+
"os"
23+
"os/exec"
24+
"path/filepath"
25+
"time"
26+
27+
"k8s.io/klog/v2"
28+
)
29+
30+
// Main intent: help users to self-troubleshoot when the GPU driver is not set up
31+
// properly before installing this DRA driver. In that case, the log of the init
32+
// container running this script is meant to yield an actionable error message.
33+
// For now, rely on k8s to implement a high-level retry with back-off.
34+
func runPrestartInit(ctx context.Context) error {
35+
// Design goal: long-running init container that retries at constant frequency,
36+
// and leaves only upon success (with code 0).
37+
waitS := 10 * time.Second
38+
attempt := 0
39+
40+
nvidiaDriverRoot := os.Getenv("NVIDIA_DRIVER_ROOT")
41+
if nvidiaDriverRoot == "" {
42+
// Not set, or set to empty string (not distinguishable).
43+
// Normalize to "/" (treated as such elsewhere).
44+
nvidiaDriverRoot = "/"
45+
}
46+
47+
driverRootParent := "/driver-root-parent"
48+
// Remove trailing slash (if existing) and get last path element.
49+
driverRootPath := filepath.Join(driverRootParent, filepath.Base(nvidiaDriverRoot))
50+
51+
// Ensure the /driver-root-parent directory exists
52+
if err := os.MkdirAll(driverRootParent, 0755); err != nil {
53+
klog.Warningf("Failed to create %s: %v", driverRootParent, err)
54+
}
55+
56+
// Create in-container path /driver-root as a symlink. Expectation: link may be
57+
// broken initially (e.g., if the GPU operator isn't deployed yet. The link heals
58+
// once the driver becomes mounted (e.g., once GPU operator provides the driver
59+
// on the host at /run/nvidia/driver).
60+
fmt.Printf("create symlink: /driver-root -> %s\n", driverRootPath)
61+
_ = os.Remove("/driver-root")
62+
if err := os.Symlink(driverRootPath, "/driver-root"); err != nil {
63+
klog.Warningf("Failed to create symlink: %v", err)
64+
}
65+
66+
for {
67+
if validateAndExitOnSuccess(ctx, nvidiaDriverRoot, attempt) {
68+
return nil
69+
}
70+
71+
select {
72+
case <-ctx.Done():
73+
// DS pods may get deleted (terminated with SIGTERM) and re-created when the GPU
74+
// Operator driver container creates a mount at /run/nvidia. Make that explicit.
75+
fmt.Printf("%s: received SIGTERM\n", time.Now().UTC().Format("2006-01-02T15:04:05.000Z"))
76+
return nil
77+
case <-time.After(waitS):
78+
attempt++
79+
}
80+
}
81+
}
82+
83+
// emitCommonErr prints the generic "check failed" explanation to stdout as a
// single newline-terminated line. nvidiaDriverRoot is echoed verbatim as the
// current value of NVIDIA_DRIVER_ROOT so the reader can compare it with the
// host path where the GPU driver is expected to be installed.
func emitCommonErr(nvidiaDriverRoot string) {
	const commonErrFormat = "Check failed. Has the NVIDIA GPU driver been set up? " +
		"It is expected to be installed under " +
		"NVIDIA_DRIVER_ROOT (currently set to '%s') " +
		"in the host filesystem. If that path appears to be unexpected: " +
		"review the DRA driver's 'nvidiaDriverRoot' Helm chart variable. " +
		"Otherwise, review if the GPU driver has " +
		"actually been installed under that path.\n"

	fmt.Printf(commonErrFormat, nvidiaDriverRoot)
}
92+
93+
func validateAndExitOnSuccess(ctx context.Context, nvidiaDriverRoot string, attempt int) bool {
94+
fmt.Printf("%s /driver-root (%s on host): ", time.Now().UTC().Format("2006-01-02T15:04:05Z"), nvidiaDriverRoot)
95+
96+
// Search specific set of directories (not recursively: not required, and
97+
// /driver-root may be a big tree). Limit to first result (multiple results
98+
// are a bit of a pathological state, but continue with validation logic).
99+
nvPath := findFirstFile(
100+
"nvidia-smi",
101+
"/driver-root/opt/bin",
102+
"/driver-root/usr/bin",
103+
"/driver-root/usr/sbin",
104+
"/driver-root/bin",
105+
"/driver-root/sbin",
106+
)
107+
108+
// Follow symlinks (-L), because `libnvidia-ml.so.1` is typically a link.
109+
nvLibPath := findFirstFile(
110+
"libnvidia-ml.so.1",
111+
"/driver-root/usr/lib64",
112+
"/driver-root/usr/lib/x86_64-linux-gnu",
113+
"/driver-root/usr/lib/aarch64-linux-gnu",
114+
"/driver-root/lib64",
115+
"/driver-root/lib/x86_64-linux-gnu",
116+
"/driver-root/lib/aarch64-linux-gnu",
117+
)
118+
119+
if nvPath == "" {
120+
fmt.Printf("nvidia-smi: not found, ")
121+
} else {
122+
fmt.Printf("nvidia-smi: '%s', ", nvPath)
123+
}
124+
125+
if nvLibPath == "" {
126+
fmt.Printf("libnvidia-ml.so.1: not found, ")
127+
} else {
128+
fmt.Printf("libnvidia-ml.so.1: '%s', ", nvLibPath)
129+
}
130+
131+
// Log top-level entries in /driver-root (this may be valuable debug info).
132+
entries, _ := os.ReadDir("/driver-root")
133+
var entryNames string
134+
for i, e := range entries {
135+
if i > 0 {
136+
entryNames += " "
137+
}
138+
entryNames += e.Name()
139+
}
140+
fmt.Printf("current contents: [%s].\n", entryNames)
141+
142+
if nvPath != "" && nvLibPath != "" {
143+
// Run with clean environment (only LD_PRELOAD; nvidia-smi has only this
144+
// dependency). Emit message before invocation (nvidia-smi may be slow or
145+
// hang).
146+
fmt.Printf("invoke: env -i LD_PRELOAD=%s %s\n", nvLibPath, nvPath)
147+
148+
cmd := exec.CommandContext(ctx, nvPath)
149+
cmd.Env = []string{"LD_PRELOAD=" + nvLibPath}
150+
cmd.Stdout = os.Stdout
151+
cmd.Stderr = os.Stderr
152+
153+
err := cmd.Run()
154+
// For checking GPU driver health: rely on nvidia-smi's exit code. Rely
155+
// on code 0 signaling that the driver is properly set up. See section
156+
// 'RETURN VALUE' in the nvidia-smi man page for meaning of error codes.
157+
if err == nil {
158+
fmt.Printf("nvidia-smi returned with code 0: success, leave\n")
159+
return true
160+
}
161+
162+
if exitErr, ok := err.(*exec.ExitError); ok {
163+
fmt.Printf("exit code: %d\n", exitErr.ExitCode())
164+
} else {
165+
fmt.Printf("execution failed: %v, exit code: -1\n", err)
166+
}
167+
}
168+
169+
// Reduce log volume: log hints only every Nth attempt.
170+
if attempt%6 != 0 {
171+
return false
172+
}
173+
174+
// nvidia-smi binaries not found, or execution failed. First, provide generic
175+
// error message. Then, try to provide actionable hints for common problems.
176+
fmt.Println()
177+
emitCommonErr(nvidiaDriverRoot)
178+
179+
// For host-provided driver not at / provide feedback for two special cases.
180+
if nvidiaDriverRoot != "/" {
181+
if len(entries) == 0 {
182+
fmt.Printf("Hint: Directory %s on the host is empty\n", nvidiaDriverRoot)
183+
} else {
184+
// Not empty, but at least one of the binaries not found: this is a
185+
// rather pathological state.
186+
if nvPath == "" || nvLibPath == "" {
187+
fmt.Printf("Hint: Directory %s is not empty but at least one of the binaries wasn't found.\n", nvidiaDriverRoot)
188+
}
189+
}
190+
}
191+
192+
// Common mistake: driver container, but forgot `--set nvidiaDriverRoot`
193+
if nvidiaDriverRoot == "/" {
194+
if _, err := os.Stat("/driver-root/run/nvidia/driver/usr/bin/nvidia-smi"); err == nil {
195+
fmt.Printf("Hint: '/run/nvidia/driver/usr/bin/nvidia-smi' exists on the host, you " +
196+
"may want to re-install the DRA driver Helm chart with " +
197+
"--set nvidiaDriverRoot=/run/nvidia/driver\n")
198+
}
199+
}
200+
201+
if nvidiaDriverRoot == "/run/nvidia/driver" {
202+
fmt.Printf("Hint: NVIDIA_DRIVER_ROOT is set to '/run/nvidia/driver' " +
203+
"which typically means that the NVIDIA GPU Operator " +
204+
"manages the GPU driver. Make sure that the GPU Operator " +
205+
"is deployed and healthy.\n")
206+
}
207+
fmt.Println()
208+
209+
return false
210+
}
211+
212+
// findFirstFile probes dirs in the given order (without recursing) and returns
// the path of the first entry named filename that exists and is not a
// directory. os.Stat follows symlinks, so a link that resolves to a file
// counts as a match. It returns the empty string when nothing matches.
func findFirstFile(filename string, dirs ...string) string {
	for i := 0; i < len(dirs); i++ {
		candidate := filepath.Join(dirs[i], filename)
		st, statErr := os.Stat(candidate)
		if statErr != nil || st.IsDir() {
			continue
		}
		return candidate
	}
	return ""
}

cmd/gpu-kubelet-plugin/vfio-device.go

Lines changed: 90 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"os"
2424
"os/exec"
2525
"path/filepath"
26+
"strings"
2627
"time"
2728

2829
"k8s.io/klog/v2"
@@ -31,19 +32,17 @@ import (
3132
)
3233

3334
const (
34-
kernelIommuGroupPath = "/sys/kernel/iommu_groups"
35-
vfioPciModule = "vfio_pci"
36-
vfioPciDriver = "vfio-pci"
37-
nvidiaDriver = "nvidia"
38-
hostRoot = "/host-root"
39-
sysModulesRoot = "/sys/module"
40-
pciDevicesRoot = "/sys/bus/pci/devices"
41-
vfioDevicesRoot = "/dev/vfio"
42-
unbindFromDriverScript = "/usr/bin/unbind_from_driver.sh"
43-
bindToDriverScript = "/usr/bin/bind_to_driver.sh"
44-
driverResetRetries = "5"
45-
gpuFreeCheckInterval = 1 * time.Second
46-
gpuFreeCheckTimeout = 60 * time.Second
35+
kernelIommuGroupPath = "/sys/kernel/iommu_groups"
36+
vfioPciModule = "vfio_pci"
37+
vfioPciDriver = "vfio-pci"
38+
nvidiaDriver = "nvidia"
39+
hostRoot = "/host-root"
40+
sysModulesRoot = "/sys/module"
41+
pciDevicesRoot = "/sys/bus/pci/devices"
42+
vfioDevicesRoot = "/dev/vfio"
43+
driverResetRetries = "5"
44+
gpuFreeCheckInterval = 1 * time.Second
45+
gpuFreeCheckTimeout = 60 * time.Second
4746
)
4847

4948
type VfioPciManager struct {
@@ -248,20 +247,92 @@ func (vm *VfioPciManager) changeDriver(pciAddress, driver string) error {
248247
return nil
249248
}
250249

250+
func (vm *VfioPciManager) acquireUnbindLock(gpu string) error {
251+
lockRetries := 5
252+
unbindLockFile := filepath.Join("/proc/driver/nvidia/gpus", gpu, "unbindLock")
253+
254+
if _, err := os.Stat(unbindLockFile); err != nil {
255+
// If the lock file doesn't exist, we assume no lock is needed.
256+
return nil
257+
}
258+
259+
for attempt := 1; attempt <= lockRetries; attempt++ {
260+
klog.Infof("[retry %d/%d] Attempting to acquire unbindLock for %s", attempt, lockRetries, gpu)
261+
262+
// Try to write 1 to acquire the lock
263+
err := os.WriteFile(unbindLockFile, []byte("1\n"), 0200)
264+
if err != nil {
265+
klog.Warningf("failed to write to unbindLock file %s: %v", unbindLockFile, err)
266+
}
267+
268+
// Read the lock file to verify
269+
content, err := os.ReadFile(unbindLockFile)
270+
if err == nil {
271+
val := strings.TrimSpace(string(content))
272+
if val == "1" {
273+
klog.Infof("UnbindLock acquired for %s", gpu)
274+
return nil
275+
}
276+
}
277+
278+
time.Sleep(time.Duration(attempt) * time.Second)
279+
}
280+
281+
return fmt.Errorf("cannot obtain unbindLock for %s", gpu)
282+
}
283+
251284
func (vm *VfioPciManager) unbindFromDriver(pciAddress string) error {
252-
out, err := execCommand(unbindFromDriverScript, []string{pciAddress}) //nolint:gosec
285+
driverPath := filepath.Join(pciDevicesRoot, pciAddress, "driver")
286+
if _, err := os.Stat(driverPath); err != nil {
287+
// Not bound to any driver
288+
return nil
289+
}
290+
291+
existingDriver, err := filepath.EvalSymlinks(driverPath)
253292
if err != nil {
254-
klog.Errorf("Attempting to unbind %s from its driver failed; stdout: %s, err: %v", pciAddress, string(out), err)
293+
return fmt.Errorf("failed to resolve driver symlink for %s: %v", pciAddress, err)
294+
}
295+
296+
existingDriverName := filepath.Base(existingDriver)
297+
if existingDriverName == "nvidia" {
298+
if err := vm.acquireUnbindLock(pciAddress); err != nil {
299+
return err
300+
}
301+
}
302+
303+
unbindFile := filepath.Join(existingDriver, "unbind")
304+
if err := os.WriteFile(unbindFile, []byte(pciAddress+"\n"), 0200); err != nil {
305+
klog.Errorf("Attempting to unbind %s from its driver failed; err: %v", pciAddress, err)
255306
return err
256307
}
257308
return nil
258309
}
259310

260311
func (vm *VfioPciManager) bindToDriver(pciAddress, driver string) error {
261-
out, err := execCommand(bindToDriverScript, []string{pciAddress, driver}) //nolint:gosec
262-
if err != nil {
263-
klog.Errorf("Attempting to bind %s to %s driver failed; stdout: %s, err: %v", pciAddress, driver, string(out), err)
264-
return err
312+
driversPath := "/sys/bus/pci/drivers"
313+
driverOverrideFile := filepath.Join(pciDevicesRoot, pciAddress, "driver_override")
314+
bindFile := filepath.Join(driversPath, driver, "bind")
315+
316+
if _, err := os.Stat(driverOverrideFile); err != nil {
317+
klog.Errorf("'%s' file does not exist", driverOverrideFile)
318+
return fmt.Errorf("driver_override file not found: %v", err)
319+
}
320+
321+
if err := os.WriteFile(driverOverrideFile, []byte(driver+"\n"), 0200); err != nil {
322+
klog.Errorf("failed to write '%s' to %s", driver, driverOverrideFile)
323+
return fmt.Errorf("failed to write to driver_override: %v", err)
324+
}
325+
326+
if _, err := os.Stat(bindFile); err != nil {
327+
klog.Errorf("'%s' file does not exist", bindFile)
328+
return fmt.Errorf("bind file not found: %v", err)
329+
}
330+
331+
if err := os.WriteFile(bindFile, []byte(pciAddress+"\n"), 0200); err != nil {
332+
klog.Errorf("Attempting to bind %s to %s driver failed; err: %v", pciAddress, driver, err)
333+
// attempt to revert driver_override
334+
_ = os.WriteFile(driverOverrideFile, []byte("\n"), 0200)
335+
return fmt.Errorf("failed to write to bind file: %v", err)
265336
}
266337
return nil
267338
}

0 commit comments

Comments
 (0)