@@ -219,6 +219,8 @@ const (
219219 wslNvidiaSMIPath = "/usr/lib/wsl/lib/nvidia-smi"
220220 // shell indicates what shell to use when invoking commands in a subprocess
221221 shell = "sh"
222+ // path where host is mounted
223+ hostMountPath = "/host"
222224)
223225
224226func main () {
@@ -608,6 +610,18 @@ func runCommand(command string, args []string, silent bool) error {
608610 return cmd .Run ()
609611}
610612
613+ func getHostCommandPath (command string ) (string , error ) {
614+ args := []string {hostMountPath , "/bin/sh" , "-l" , "-c" , fmt .Sprintf ("which %s" , command )}
615+ cmd := exec .Command ("chroot" , args ... )
616+
617+ path , err := cmd .Output ()
618+ if err != nil {
619+ return "" , err
620+ }
621+
622+ return string (path ), nil
623+ }
624+
611625func runCommandWithWait (command string , args []string , sleepSeconds int , silent bool ) error {
612626 for {
613627 cmd := exec .Command (command , args ... )
@@ -698,20 +712,26 @@ func isDriverManagedByOperator(ctx context.Context) (bool, error) {
698712
699713func validateHostDriver (silent bool ) error {
700714 log .Info ("Attempting to validate a pre-installed driver on the host" )
701- if fileInfo , err := os .Lstat (filepath .Join ("/host" , wslNvidiaSMIPath )); err == nil && fileInfo .Size () != 0 {
715+ if fileInfo , err := os .Lstat (filepath .Join (hostMountPath , wslNvidiaSMIPath )); err == nil && fileInfo .Size () != 0 {
702716 log .Infof ("WSL2 system detected, assuming driver is pre-installed" )
703717 disableDevCharSymlinkCreation = true
704718 return nil
705719 }
706- fileInfo , err := os .Lstat ("/host/usr/bin/nvidia-smi" )
720+
721+ nvidiaSMIPath , err := getHostCommandPath ("nvidia-smi" )
707722 if err != nil {
708- return fmt .Errorf ("no 'nvidia-smi' file present on the host: %w" , err )
723+ return fmt .Errorf ("no 'nvidia-smi' executable present on the host $PATH: %w" , err )
724+ }
725+
726+ fileInfo , err := os .Lstat (nvidiaSMIPath )
727+ if err != nil {
728+ return fmt .Errorf ("failed to stat 'nvidia-smi' path on the host: %w" , err )
709729 }
710730 if fileInfo .Size () == 0 {
711731 return fmt .Errorf ("empty 'nvidia-smi' file found on the host" )
712732 }
713733 command := "chroot"
714- args := []string {"/host" , "nvidia-smi" }
734+ args := []string {hostMountPath , nvidiaSMIPath }
715735
716736 return runCommand (command , args , silent )
717737}
@@ -770,7 +790,7 @@ func (d *Driver) runValidation(silent bool) (driverInfo, error) {
770790 err := validateHostDriver (silent )
771791 if err == nil {
772792 log .Info ("Detected a pre-installed driver on the host" )
773- return getDriverInfo (true , hostRootFlag , hostRootFlag , "/host" ), nil
793+ return getDriverInfo (true , hostRootFlag , hostRootFlag , hostMountPath ), nil
774794 }
775795
776796 err = validateDriverContainer (silent , d .ctx )
@@ -848,7 +868,7 @@ func createDevCharSymlinks(driverInfo driverInfo, disableDevCharSymlinkCreation
848868 // either '/host' or '/driver-root', both paths would exist in the validation container.
849869 driverRootCtrPath := driverInstallDirCtrPathFlag
850870 if driverInfo .isHostDriver {
851- driverRootCtrPath = "/host"
871+ driverRootCtrPath = hostMountPath
852872 }
853873
854874 // We now create the symlinks in /dev/char.
@@ -1560,8 +1580,9 @@ func (v *VGPUManager) runValidation(silent bool) (hostDriver bool, err error) {
15601580 args := []string {"/run/nvidia/driver" , "nvidia-smi" }
15611581
15621582 // check if driver is pre-installed on the host and use host path for validation
1563- if _ , err := os .Lstat ("/host/usr/bin/nvidia-smi" ); err == nil {
1564- args = []string {"/host" , "nvidia-smi" }
1583+ nvidiaSMIPath , err := getHostCommandPath ("nvidia-smi" )
1584+ if err == nil {
1585+ args = []string {hostMountPath , nvidiaSMIPath }
15651586 hostDriver = true
15661587 }
15671588
0 commit comments