Skip to content

Commit 1991b8c

Browse files
Add shouldSkipUninstall to avoid GPU driver teardown on restart
Signed-off-by: Karthik Vetrivel <[email protected]>
1 parent acfb2a8 commit 1991b8c

File tree

1 file changed

+85
-0
lines changed

1 file changed

+85
-0
lines changed

cmd/driver-manager/main.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ type config struct {
7777
gpuDirectRDMAEnabled bool
7878
useHostMofed bool
7979
kubeconfig string
80+
driverVersion string
81+
forceReinstall bool
8082
}
8183

8284
// ComponentState tracks the deployment state of GPU operator components
@@ -208,6 +210,20 @@ func main() {
208210
EnvVars: []string{"KUBECONFIG"},
209211
Value: "",
210212
},
213+
&cli.StringFlag{
214+
Name: "driver-version",
215+
Usage: "Desired NVIDIA driver version",
216+
Destination: &cfg.driverVersion,
217+
EnvVars: []string{"DRIVER_VERSION"},
218+
Value: "",
219+
},
220+
&cli.BoolFlag{
221+
Name: "force-reinstall",
222+
Usage: "Force driver reinstall regardless of current state",
223+
Destination: &cfg.forceReinstall,
224+
EnvVars: []string{"FORCE_REINSTALL"},
225+
Value: false,
226+
},
211227
}
212228

213229
app.Commands = []*cli.Command{
@@ -271,6 +287,11 @@ func (dm *DriverManager) uninstallDriver() error {
271287
return fmt.Errorf("driver is pre-installed on host")
272288
}
273289

290+
if skip, reason := dm.shouldSkipUninstall(); skip {
291+
dm.log.Infof("Skipping driver uninstall: %s", reason)
292+
return nil
293+
}
294+
274295
// Fetch current component states
275296
if err := dm.fetchCurrentLabels(); err != nil {
276297
return fmt.Errorf("failed to fetch current labels: %w", err)
@@ -629,6 +650,70 @@ func (dm *DriverManager) isDriverLoaded() bool {
629650
return err == nil
630651
}
631652

653+
func (dm *DriverManager) shouldSkipUninstall() (bool, string) {
654+
if dm.config.forceReinstall {
655+
dm.log.Info("Force reinstall is enabled, proceeding with driver uninstall")
656+
return false, ""
657+
}
658+
659+
if !dm.isDriverLoaded() {
660+
return false, ""
661+
}
662+
663+
if dm.config.driverVersion == "" {
664+
return false, ""
665+
}
666+
667+
version, err := dm.detectCurrentDriverVersion()
668+
if err != nil {
669+
dm.log.Warnf("Unable to determine installed driver version: %v", err)
670+
// If driver is loaded but we can't detect version, proceed with reinstall to ensure correct version
671+
dm.log.Info("Cannot verify driver version, proceeding with reinstall to ensure correct version is installed")
672+
return false, ""
673+
}
674+
675+
if version != dm.config.driverVersion {
676+
dm.log.Infof("Installed driver version %s does not match desired %s, proceeding with uninstall", version, dm.config.driverVersion)
677+
return false, ""
678+
}
679+
680+
dm.log.Infof("Installed driver version %s matches desired version, skipping uninstall", version)
681+
return true, "desired version already present"
682+
}
683+
684+
func (dm *DriverManager) detectCurrentDriverVersion() (string, error) {
685+
baseCtx := dm.ctx
686+
if baseCtx == nil {
687+
baseCtx = context.Background()
688+
}
689+
690+
ctx, cancel := context.WithTimeout(baseCtx, 10*time.Second)
691+
defer cancel()
692+
693+
// Try chroot to /run/nvidia/driver for containerized driver
694+
cmd := exec.CommandContext(ctx, "chroot", "/run/nvidia/driver", "modinfo", "-F", "version", "nvidia")
695+
cmd.Env = append(os.Environ(), "LC_ALL=C")
696+
cmdOutput, chrootErr := cmd.Output()
697+
if chrootErr == nil {
698+
version := strings.TrimSpace(string(cmdOutput))
699+
if version != "" {
700+
dm.log.Infof("Driver version detected via chroot: %s", version)
701+
return version, nil
702+
}
703+
}
704+
705+
// Second try to read from /sys/module/nvidia/version if available
706+
if versionData, err := os.ReadFile("/sys/module/nvidia/version"); err == nil {
707+
version := strings.TrimSpace(string(versionData))
708+
if version != "" {
709+
dm.log.Infof("Driver version detected from /sys/module/nvidia/version: %s", version)
710+
return version, nil
711+
}
712+
}
713+
714+
return "", fmt.Errorf("all version detection methods failed: chroot: %v", chrootErr)
715+
}
716+
632717
func (dm *DriverManager) isNouveauLoaded() bool {
633718
_, err := os.Stat("/sys/module/nouveau/refcnt")
634719
return err == nil

0 commit comments

Comments
 (0)