vfio-manage: optionally bind/unbind NVSwitches alongside GPUs

Summary of the hunks below:

  * cmd/vfio-manage/bind.go
      - adds the --bind-nvswitches flag (env BIND_NVSWITCHES);
      - bindAll appends the result of GetNVSwitches() to the device list
        when the flag is set;
      - bindDevice returns early (nil) for a nil device, for an NVSwitch
        when the flag is not set, and for anything that is neither a GPU
        nor an NVSwitch.
  * cmd/vfio-manage/unbind.go
      - mirrors the above with --unbind-nvswitches and fixes the
        --device-id usage text ("bind" -> "unbind").
  * scripts/vfio-manage
      - adds is_nvidia_nvswitch_device (PCI class 0x068000) and
        is_nvidia_device (GPU or NVSwitch);
      - bind_device/unbind_device accept both device kinds and only look
        up the graphics auxiliary device for GPUs.

NOTE(review): --unbind-nvswitches reads the same BIND_NVSWITCHES environment
variable as --bind-nvswitches; confirm this sharing is intentional.
NOTE(review): in the hunks shown, the shell script handles NVSwitches
unconditionally, while the Go commands require the flag; confirm the script
is not expected to honour BIND_NVSWITCHES as well.
NOTE(review): leading whitespace inside the hunks was restored by convention
(gofmt tabs for Go, 4 spaces for the script); apply with
--ignore-whitespace if the target files differ.

diff --git a/cmd/vfio-manage/bind.go b/cmd/vfio-manage/bind.go
index aaf1ec4b..2ab1e54c 100644
--- a/cmd/vfio-manage/bind.go
+++ b/cmd/vfio-manage/bind.go
@@ -34,9 +34,10 @@ type bindCommand struct {
 }
 
 type bindOptions struct {
-	all      bool
-	deviceID string
-	hostRoot string
+	all            bool
+	deviceID       string
+	hostRoot       string
+	bindNVSwitches bool
 }
 
 // newBindCommand constructs a bind command with the specified logger
@@ -78,6 +79,12 @@ func (m bindCommand) build() *cli.Command {
 				Value:       "/",
 				Usage:       "Path to the host's root filesystem. This is used when loading the vfio-pci module.",
 			},
+			&cli.BoolFlag{
+				Name:        "bind-nvswitches",
+				Destination: &m.options.bindNVSwitches,
+				EnvVars:     []string{"BIND_NVSWITCHES"},
+				Usage:       "Also bind NVSwitches to vfio-pci (default: false)",
+			},
 		},
 	}
 
@@ -115,6 +122,14 @@ func (m bindCommand) bindAll() error {
 		return fmt.Errorf("failed to get NVIDIA GPUs: %w", err)
 	}
 
+	if m.options.bindNVSwitches {
+		nvswitches, err := m.nvpciLib.GetNVSwitches()
+		if err != nil {
+			return fmt.Errorf("failed to get NVIDIA NVSwitches: %w", err)
+		}
+		devices = append(devices, nvswitches...)
+	}
+
 	for _, dev := range devices {
 		m.logger.Infof("Binding device %s", dev.Address)
 		// (cdesiniotis) ideally this should be replaced by a call to nvdev.BindToVFIODriver()
@@ -128,12 +143,22 @@ func (m bindCommand) bindAll() error {
 
 func (m bindCommand) bindDevice() error {
 	device := m.options.deviceID
+	// Note: Despite its name, GetGPUByPciBusID returns any NVIDIA PCI device
+	// (GPU, NVSwitch, etc.) at the specified address, not just GPUs.
 	nvdev, err := m.nvpciLib.GetGPUByPciBusID(device)
 	if err != nil {
-		return fmt.Errorf("failed to get NVIDIA GPU device: %w", err)
+		return fmt.Errorf("failed to get NVIDIA device: %w", err)
+	}
+	if nvdev == nil {
+		m.logger.Infof("Device %s is not an NVIDIA device", device)
+		return nil
+	}
+	if nvdev.IsNVSwitch() && !m.options.bindNVSwitches {
+		m.logger.Infof("Skipping NVSwitch %s (BIND_NVSWITCHES not set)", device)
+		return nil
 	}
-	if nvdev == nil || !nvdev.IsGPU() {
-		m.logger.Infof("Device %s is not a GPU", device)
+	if !nvdev.IsGPU() && !nvdev.IsNVSwitch() {
+		m.logger.Infof("Device %s is not an NVIDIA GPU or NVSwitch", device)
 		return nil
 	}
 
diff --git a/cmd/vfio-manage/unbind.go b/cmd/vfio-manage/unbind.go
index 8eeadc50..22ee8c22 100644
--- a/cmd/vfio-manage/unbind.go
+++ b/cmd/vfio-manage/unbind.go
@@ -34,8 +34,9 @@ type unbindCommand struct {
 }
 
 type unbindOptions struct {
-	all      bool
-	deviceID string
+	all              bool
+	deviceID         string
+	unbindNVSwitches bool
 }
 
 // newUnbindCommand constructs an unbind command with the specified logger
@@ -71,7 +72,13 @@ func (m unbindCommand) build() *cli.Command {
 				Name:        "device-id",
 				Aliases:     []string{"d"},
 				Destination: &m.options.deviceID,
-				Usage:       "Specific device ID to bind (e.g., 0000:01:00.0)",
+				Usage:       "Specific device ID to unbind (e.g., 0000:01:00.0)",
+			},
+			&cli.BoolFlag{
+				Name:        "unbind-nvswitches",
+				Destination: &m.options.unbindNVSwitches,
+				EnvVars:     []string{"BIND_NVSWITCHES"},
+				Usage:       "Also unbind NVSwitches from their driver (default: false)",
 			},
 		},
 	}
@@ -105,6 +112,14 @@ func (m unbindCommand) unbindAll() error {
 		return fmt.Errorf("failed to get NVIDIA GPUs: %w", err)
 	}
 
+	if m.options.unbindNVSwitches {
+		nvswitches, err := m.nvpciLib.GetNVSwitches()
+		if err != nil {
+			return fmt.Errorf("failed to get NVIDIA NVSwitches: %w", err)
+		}
+		devices = append(devices, nvswitches...)
+	}
+
 	for _, dev := range devices {
 		m.logger.Infof("Unbinding device %s", dev.Address)
 		// (cdesiniotis) ideally this should be replaced by a call to nvdev.UnbindFromDriver()
@@ -117,12 +132,22 @@ func (m unbindCommand) unbindAll() error {
 
 func (m unbindCommand) unbindDevice() error {
 	device := m.options.deviceID
+	// Note: Despite its name, GetGPUByPciBusID returns any NVIDIA PCI device
+	// (GPU, NVSwitch, etc.) at the specified address, not just GPUs.
 	nvdev, err := m.nvpciLib.GetGPUByPciBusID(device)
 	if err != nil {
-		return fmt.Errorf("failed to get NVIDIA GPU device: %w", err)
+		return fmt.Errorf("failed to get NVIDIA device: %w", err)
+	}
+	if nvdev == nil {
+		m.logger.Infof("Device %s is not an NVIDIA device", device)
+		return nil
+	}
+	if nvdev.IsNVSwitch() && !m.options.unbindNVSwitches {
+		m.logger.Infof("Skipping NVSwitch %s (BIND_NVSWITCHES not set)", device)
+		return nil
 	}
-	if nvdev == nil || !nvdev.IsGPU() {
-		m.logger.Infof("Device %s is not a GPU", device)
+	if !nvdev.IsGPU() && !nvdev.IsNVSwitch() {
+		m.logger.Infof("Device %s is not an NVIDIA GPU or NVSwitch", device)
 		return nil
 	}
 
diff --git a/scripts/vfio-manage b/scripts/vfio-manage
index 01f0912d..ca36f3ff 100755
--- a/scripts/vfio-manage
+++ b/scripts/vfio-manage
@@ -45,14 +45,29 @@ unbind_from_other_driver() {
 }
 
 is_nvidia_gpu_device() {
-    gpu=$1
-    # make sure device class is for NVIDIA GPU
-    device_class_file=$(readlink -f "/sys/bus/pci/devices/$gpu/class")
+    dev=$1
+    # make sure device class is for NVIDIA GPU (3D controller or VGA compatible)
+    device_class_file=$(readlink -f "/sys/bus/pci/devices/$dev/class")
     device_class=$(cat "$device_class_file")
     [ "$device_class" = "0x030200" ] || [ "$device_class" = "0x030000" ] || return 1
     return 0
 }
 
+is_nvidia_nvswitch_device() {
+    dev=$1
+    # make sure device class is for NVIDIA NVSwitch (bridge device)
+    device_class_file=$(readlink -f "/sys/bus/pci/devices/$dev/class")
+    device_class=$(cat "$device_class_file")
+    [ "$device_class" = "0x068000" ] || return 1
+    return 0
+}
+
+is_nvidia_device() {
+    dev=$1
+    # check if device is either a GPU or NVSwitch
+    is_nvidia_gpu_device "$dev" || is_nvidia_nvswitch_device "$dev"
+}
+
 is_bound_to_vfio() {
     gpu=$1
 
@@ -71,19 +86,21 @@ is_bound_to_vfio() {
 }
 
 unbind_device() {
-    gpu=$1
+    dev=$1
 
-    if ! is_nvidia_gpu_device "$gpu"; then
+    if ! is_nvidia_device "$dev"; then
         return 0
     fi
 
-    echo "unbinding device $gpu"
-    unbind_from_driver "$gpu"
-    #for graphics mode, we need to unbind the auxiliary device as well
-    aux_dev=$(get_graphics_aux_dev "$gpu")
-    if [ "$aux_dev" != "NONE" ]; then
-        echo "gpu $gpu is in graphics mode aux_dev $aux_dev"
-        unbind_from_driver "$aux_dev"
+    echo "unbinding device $dev"
+    unbind_from_driver "$dev"
+    # for graphics mode GPUs, we need to unbind the auxiliary device as well
+    if is_nvidia_gpu_device "$dev"; then
+        aux_dev=$(get_graphics_aux_dev "$dev")
+        if [ "$aux_dev" != "NONE" ]; then
+            echo "gpu $dev is in graphics mode aux_dev $aux_dev"
+            unbind_from_driver "$aux_dev"
+        fi
     fi
 }
 
@@ -136,19 +153,21 @@ get_graphics_aux_dev() {
 }
 
 bind_device() {
-    gpu=$1
+    dev=$1
 
-    if ! is_nvidia_gpu_device "$gpu"; then
-        echo "device $gpu is not a gpu!"
+    if ! is_nvidia_device "$dev"; then
+        echo "device $dev is not an NVIDIA GPU or NVSwitch!"
         return 0
     fi
 
-    bind_pci_device "$gpu"
-    #for graphics mode, we need to bind the auxiliary device as well
-    aux_dev=$(get_graphics_aux_dev "$gpu")
-    if [ "$aux_dev" != "NONE" ]; then
-        echo "gpu $gpu is in graphics mode aux_dev $aux_dev"
-        bind_pci_device "$aux_dev"
+    bind_pci_device "$dev"
+    # for graphics mode GPUs, we need to bind the auxiliary device as well
+    if is_nvidia_gpu_device "$dev"; then
+        aux_dev=$(get_graphics_aux_dev "$dev")
+        if [ "$aux_dev" != "NONE" ]; then
+            echo "gpu $dev is in graphics mode aux_dev $aux_dev"
+            bind_pci_device "$aux_dev"
+        fi
     fi
 }
 