Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions cmd/driver-manager/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ type config struct {
gpuDirectRDMAEnabled bool
useHostMofed bool
kubeconfig string
driverVersion string
forceReinstall bool
}

// ComponentState tracks the deployment state of GPU operator components
Expand Down Expand Up @@ -208,6 +210,20 @@ func main() {
EnvVars: []string{"KUBECONFIG"},
Value: "",
},
&cli.StringFlag{
Name: "driver-version",
Usage: "Desired NVIDIA driver version",
Destination: &cfg.driverVersion,
EnvVars: []string{"DRIVER_VERSION"},
Value: "",
},
&cli.BoolFlag{
Name: "force-reinstall",
Usage: "Force driver reinstall regardless of current state",
Destination: &cfg.forceReinstall,
EnvVars: []string{"FORCE_REINSTALL"},
Value: false,
},
}

app.Commands = []*cli.Command{
Expand Down Expand Up @@ -288,6 +304,14 @@ func (dm *DriverManager) uninstallDriver() error {
return fmt.Errorf("failed to evict GPU operator components: %w", err)
}

if skip, reason := dm.shouldSkipUninstall(); skip {
dm.log.Infof("Skipping driver uninstall: %s", reason)
if err := dm.rescheduleGPUOperatorComponents(); err != nil {
dm.log.Warnf("Failed to reschedule GPU operator components: %v", err)
}
return nil
}

drainOpts := kube.DrainOptions{
Force: dm.config.drainUseForce,
DeleteEmptyDirData: dm.config.drainDeleteEmptyDirData,
Expand Down Expand Up @@ -629,6 +653,70 @@ func (dm *DriverManager) isDriverLoaded() bool {
return err == nil
}

func (dm *DriverManager) shouldSkipUninstall() (bool, string) {
if dm.config.forceReinstall {
dm.log.Info("Force reinstall is enabled, proceeding with driver uninstall")
return false, ""
}

if !dm.isDriverLoaded() {
return false, ""
}

if dm.config.driverVersion == "" {
return false, "Driver version environment variable is not set"
}

version, err := dm.detectCurrentDriverVersion()
if err != nil {
dm.log.Warnf("Unable to determine installed driver version: %v", err)
// If driver is loaded but we can't detect version, proceed with reinstall to ensure correct version
dm.log.Info("Cannot verify driver version, proceeding with reinstall to ensure correct version is installed")
return false, ""
}

if version != dm.config.driverVersion {
dm.log.Infof("Installed driver version %s does not match desired %s, proceeding with uninstall", version, dm.config.driverVersion)
return false, ""
}

dm.log.Infof("Installed driver version %s matches desired version, skipping uninstall", version)
return true, "desired version already present"
}

func (dm *DriverManager) detectCurrentDriverVersion() (string, error) {
baseCtx := dm.ctx
if baseCtx == nil {
baseCtx = context.Background()
}

ctx, cancel := context.WithTimeout(baseCtx, 10*time.Second)
defer cancel()

// Try chroot to /run/nvidia/driver for containerized driver
cmd := exec.CommandContext(ctx, "chroot", "/run/nvidia/driver", "modinfo", "-F", "version", "nvidia")
cmd.Env = append(os.Environ(), "LC_ALL=C")
cmdOutput, chrootErr := cmd.Output()
if chrootErr == nil {
version := strings.TrimSpace(string(cmdOutput))
if version != "" {
dm.log.Infof("Driver version detected via chroot: %s", version)
return version, nil
}
}

// Second try to read from /sys/module/nvidia/version if available
if versionData, err := os.ReadFile("/sys/module/nvidia/version"); err == nil {
version := strings.TrimSpace(string(versionData))
if version != "" {
dm.log.Infof("Driver version detected from /sys/module/nvidia/version: %s", version)
return version, nil
}
}

return "", fmt.Errorf("all version detection methods failed: chroot: %v", chrootErr)
}

func (dm *DriverManager) isNouveauLoaded() bool {
_, err := os.Stat("/sys/module/nouveau/refcnt")
return err == nil
Expand Down