diff --git a/pkg/sentry/devices/nvproxy/nvproxy_test.go b/pkg/sentry/devices/nvproxy/nvproxy_test.go index c339cab7cd..edc1f9b251 100644 --- a/pkg/sentry/devices/nvproxy/nvproxy_test.go +++ b/pkg/sentry/devices/nvproxy/nvproxy_test.go @@ -38,8 +38,11 @@ func TestInit(t *testing.T) { func TestAllSupportedHashesPresent(t *testing.T) { Init() for version, abi := range abis { - if abi.checksum == "" { - t.Errorf("unexpected empty value for driver %q", version.String()) + if abi.checksumX86_64 == "" { + t.Errorf("unexpected empty checksumX86_64 value for driver %q", version.String()) + } + if abi.checksumARM64 == "" { + t.Errorf("unexpected empty checksumARM64 value for driver %q", version.String()) } } } diff --git a/pkg/sentry/devices/nvproxy/version.go b/pkg/sentry/devices/nvproxy/version.go index 60f0c9f259..a0ddac5e49 100644 --- a/pkg/sentry/devices/nvproxy/version.go +++ b/pkg/sentry/devices/nvproxy/version.go @@ -107,8 +107,9 @@ type driverABIStructsFunc func() *driverABIStructs // abiConAndChecksum couples the driver's abiConstructor to the SHA256 checksum of its linux .run // driver installer file from NVIDIA. type abiConAndChecksum struct { - cons driverABIFunc - checksum string + cons driverABIFunc + checksumX86_64 string + checksumARM64 string } // driverABI defines the Nvidia kernel driver ABI proxied at a given version. @@ -157,17 +158,28 @@ type DriverStruct struct { var abis map[DriverVersion]abiConAndChecksum var abisOnce sync.Once -// Note: runfileChecksum is the checksum of the .run file of the driver installer for linux from +type addDriverABIArgs struct { + major, minor, patch int + runfileChecksumX86_64 string + runfileChecksumARM64 string + cons driverABIFunc +} + +// Note: runfileChecksums are the checksum of the .run file of the driver installer for linux from // nvidia. // To add a new version, add in support as normal and add the "addDriverABI" call for your version. // Run `make sudo TARGETS=//tools/gpu:main ARGS="checksum --version={}"` to get checksum. -func addDriverABI(major, minor, patch int, runfileChecksum string, cons driverABIFunc) driverABIFunc { +func addDriverABI(args addDriverABIArgs) driverABIFunc { if abis == nil { abis = make(map[DriverVersion]abiConAndChecksum) } - version := NewDriverVersion(major, minor, patch) - abis[version] = abiConAndChecksum{cons: cons, checksum: runfileChecksum} - return cons + version := NewDriverVersion(args.major, args.minor, args.patch) + abis[version] = abiConAndChecksum{ + cons: args.cons, + checksumX86_64: args.runfileChecksumX86_64, + checksumARM64: args.runfileChecksumARM64, + } + return args.cons } // Init initializes abis global map. @@ -702,10 +714,41 @@ func Init() { // The following exist on the "535" branch. They branched off the main // branch at 535.113.01. - v535_183_01 := addDriverABI(535, 183, 01, "f6707afbdda9407e3cbc2e5128e60bcbcdbf02fae29958c72fafb5d405e8b883", v535_113_01) - v535_183_06 := addDriverABI(535, 183, 06, "c7bb0a0569c5347845479ed4e3e4d885c6ee3b8adf068c3401cdf754d5ba3d3b", v535_183_01) - v535_216_01 := addDriverABI(535, 216, 01, "5ddea1147810012e33967c3181341bcd6624bd3d654c63f845df833b4ece6af7", v535_183_06) - _ = addDriverABI(535, 230, 02, "20cca9118083fcc8083158466e9cb2b616a7922206bcb7296b1fa5cc9af2e0fd", v535_216_01) + v535_183_01 := addDriverABI( + addDriverABIArgs{ + major: 535, + minor: 183, + patch: 01, + runfileChecksumX86_64: "f6707afbdda9407e3cbc2e5128e60bcbcdbf02fae29958c72fafb5d405e8b883", + runfileChecksumARM64: "c9d13b6250d24b76ef87a49b179f234564184a9f6d6414184668958b7f6d21e6", + cons: v535_113_01, + }) + v535_183_06 := addDriverABI( + addDriverABIArgs{ + major: 535, + minor: 183, + patch: 06, + runfileChecksumX86_64: "c7bb0a0569c5347845479ed4e3e4d885c6ee3b8adf068c3401cdf754d5ba3d3b", + runfileChecksumARM64: "af3f72f5e4906805987844636b87ad1132650d05116272824c76dcc3f816d8e9", + cons: v535_183_01, + }) + v535_216_01 := addDriverABI(addDriverABIArgs{ + major: 535, + minor: 216, + patch: 01, + runfileChecksumX86_64: "5ddea1147810012e33967c3181341bcd6624bd3d654c63f845df833b4ece6af7", + runfileChecksumARM64: "4869ae0345b5892b2a50aed566c8226d3e07813d1190aa466feba5e9e21b33b9", + cons: v535_183_06, + }) + + _ = addDriverABI(addDriverABIArgs{ + major: 535, + minor: 230, + patch: 02, + runfileChecksumX86_64: "20cca9118083fcc8083158466e9cb2b616a7922206bcb7296b1fa5cc9af2e0fd", + runfileChecksumARM64: "ea000e6ff481f55e9bfedbea93b739368c635fe4be6156fdad560524ac7f363b", + cons: v535_216_01, + }) // 545.23.06 is an intermediate unqualified version from the main branch. v545_23_06 := func() *driverABI { @@ -774,47 +817,90 @@ func Init() { return abi } - v550_54_14 := addDriverABI(550, 54, 14, "8c497ff1cfc7c310fb875149bc30faa4fd26d2237b2cba6cd2e8b0780157cfe3", func() *driverABI { - abi := v550_40_07() - abi.uvmIoctl[nvgpu.UVM_ALLOC_SEMAPHORE_POOL] = uvmHandler(uvmIoctlSimple[nvgpu.UVM_ALLOC_SEMAPHORE_POOL_PARAMS_V550], compUtil) - abi.uvmIoctl[nvgpu.UVM_MAP_EXTERNAL_ALLOCATION] = uvmHandler(uvmIoctlHasFrontendFD[nvgpu.UVM_MAP_EXTERNAL_ALLOCATION_PARAMS_V550], compUtil) + v550_54_14 := addDriverABI( + addDriverABIArgs{ + major: 550, + minor: 54, + patch: 14, + runfileChecksumX86_64: "8c497ff1cfc7c310fb875149bc30faa4fd26d2237b2cba6cd2e8b0780157cfe3", + runfileChecksumARM64: "b0fae8061633885c24f6b0c047649b46249a3bb44cadffbf658af28f80642c1d", + cons: func() *driverABI { + abi := v550_40_07() + abi.uvmIoctl[nvgpu.UVM_ALLOC_SEMAPHORE_POOL] = uvmHandler(uvmIoctlSimple[nvgpu.UVM_ALLOC_SEMAPHORE_POOL_PARAMS_V550], compUtil) + abi.uvmIoctl[nvgpu.UVM_MAP_EXTERNAL_ALLOCATION] = uvmHandler(uvmIoctlHasFrontendFD[nvgpu.UVM_MAP_EXTERNAL_ALLOCATION_PARAMS_V550], compUtil) - prevStructs := abi.getStructs - abi.getStructs = func() *driverABIStructs { - structs := prevStructs() - structs.uvmStructs[nvgpu.UVM_ALLOC_SEMAPHORE_POOL] = driverStructWithName(nvgpu.UVM_ALLOC_SEMAPHORE_POOL_PARAMS_V550{}, "UVM_ALLOC_SEMAPHORE_POOL_PARAMS") - structs.uvmStructs[nvgpu.UVM_MAP_EXTERNAL_ALLOCATION] = driverStructWithName(nvgpu.UVM_MAP_EXTERNAL_ALLOCATION_PARAMS_V550{}, "UVM_MAP_EXTERNAL_ALLOCATION_PARAMS") - return structs - } - - return abi - }) + prevStructs := abi.getStructs + abi.getStructs = func() *driverABIStructs { + structs := prevStructs() + structs.uvmStructs[nvgpu.UVM_ALLOC_SEMAPHORE_POOL] = driverStructWithName(nvgpu.UVM_ALLOC_SEMAPHORE_POOL_PARAMS_V550{}, "UVM_ALLOC_SEMAPHORE_POOL_PARAMS") + structs.uvmStructs[nvgpu.UVM_MAP_EXTERNAL_ALLOCATION] = driverStructWithName(nvgpu.UVM_MAP_EXTERNAL_ALLOCATION_PARAMS_V550{}, "UVM_MAP_EXTERNAL_ALLOCATION_PARAMS") + return structs + } + return abi + }, + }, + ) - v550_54_15 := addDriverABI(550, 54, 15, "2e859ae5f912a9a47aaa9b2d40a94a14f6f486b5d3b67c0ddf8b72c1c9650385", v550_54_14) + v550_54_15 := addDriverABI( + addDriverABIArgs{ + major: 550, + minor: 54, + patch: 15, + runfileChecksumX86_64: "2e859ae5f912a9a47aaa9b2d40a94a14f6f486b5d3b67c0ddf8b72c1c9650385", + runfileChecksumARM64: "49072d0c36ed85c7d8046776d34886f9ede9a6e4f46d5c7d533e8a8921d94cc1", + cons: v550_54_14, + }, + ) - v550_90_07 := addDriverABI(550, 90, 07, "51acf579d5a9884f573a1d3f522e7fafa5e7841e22a9cec0b4bbeae31b0b9733", func() *driverABI { - abi := v550_54_15() - abi.controlCmd[nvgpu.NV_CONF_COMPUTE_CTRL_CMD_GPU_GET_KEY_ROTATION_STATE] = ctrlHandler(rmControlSimple, compUtil) + v550_90_07 := addDriverABI( + addDriverABIArgs{ + major: 550, + minor: 90, + patch: 07, + runfileChecksumX86_64: "51acf579d5a9884f573a1d3f522e7fafa5e7841e22a9cec0b4bbeae31b0b9733", + runfileChecksumARM64: "b896b76ae465307afc5b269c40bd8ccb279e6ea7d3ecae95534a91ecb1971572", + cons: func() *driverABI { + abi := v550_54_15() + abi.controlCmd[nvgpu.NV_CONF_COMPUTE_CTRL_CMD_GPU_GET_KEY_ROTATION_STATE] = ctrlHandler(rmControlSimple, compUtil) - prevStructs := abi.getStructs - abi.getStructs = func() *driverABIStructs { - structs := prevStructs() - structs.controlStructs[nvgpu.NV_CONF_COMPUTE_CTRL_CMD_GPU_GET_KEY_ROTATION_STATE] = simpleDriverStruct("NV_CONF_COMPUTE_CTRL_CMD_GPU_GET_KEY_ROTATION_STATE_PARAMS") - return structs - } + prevStructs := abi.getStructs + abi.getStructs = func() *driverABIStructs { + structs := prevStructs() + structs.controlStructs[nvgpu.NV_CONF_COMPUTE_CTRL_CMD_GPU_GET_KEY_ROTATION_STATE] = simpleDriverStruct("NV_CONF_COMPUTE_CTRL_CMD_GPU_GET_KEY_ROTATION_STATE_PARAMS") + return structs + } - return abi - }) + return abi + }, + }, + ) // This version does not belong on any branch, but it is a child of 550.90.07. - _ = addDriverABI(550, 90, 12, "391883846713b9e700af2ae87f8ac671f5527508ce3f9f60058deb363e05162a", v550_90_07) + _ = addDriverABI( + addDriverABIArgs{ + major: 550, + minor: 90, + patch: 12, + runfileChecksumX86_64: "391883846713b9e700af2ae87f8ac671f5527508ce3f9f60058deb363e05162a", + runfileChecksumARM64: "0c410aff85b772bdb411d749c23e12ef2658f997e3094c41de8a0495a9fab4b4", + cons: v550_90_07, + }, + ) // 550.100 is an intermediate unqualified version from the main branch. v550_100 := v550_90_07 // The following exist on the "550" branch. They branched off the main // branch at 550.100. - _ = addDriverABI(550, 127, 05, "d384f34f5d2a896bd7536d3deb6a6d973d8094a3ad485a1c2ee3bf5192086ae9", v550_100) + _ = addDriverABI(addDriverABIArgs{ + major: 550, + minor: 127, + patch: 05, + runfileChecksumX86_64: "d384f34f5d2a896bd7536d3deb6a6d973d8094a3ad485a1c2ee3bf5192086ae9", + runfileChecksumARM64: "df0b06a89bc37fc8a8e2a152a9ba5a7de1c70636dab0ae62fd6f94e937847816", + cons: v550_100, + }, + ) // 555.42.02 is an intermediate unqualified version. v555_42_02 := func() *driverABI { @@ -844,24 +930,45 @@ func Init() { return abi } - v560_35_03 := addDriverABI(560, 35, 03, "f2932c92fadd43c5b2341be453fc4f73f0ad7185c26bb7a43fbde81ae29f1fe3", v560_28_03) - v565_57_01 := addDriverABI(565, 57, 01, "6eebe94e585e385e8804f5a74152df414887bf819cc21bd95b72acd0fb182c7a", v560_35_03) + v560_35_03 := addDriverABI(addDriverABIArgs{ + major: 560, + minor: 35, + patch: 03, + runfileChecksumX86_64: "f2932c92fadd43c5b2341be453fc4f73f0ad7185c26bb7a43fbde81ae29f1fe3", + runfileChecksumARM64: "b3c64054abd1357a63c5162a337139a2cb3915da96fadbf5a900b6a438df1beb", + cons: v560_28_03, + }) + v565_57_01 := addDriverABI(addDriverABIArgs{ + major: 565, + minor: 57, + patch: 01, + runfileChecksumX86_64: "6eebe94e585e385e8804f5a74152df414887bf819cc21bd95b72acd0fb182c7a", + runfileChecksumARM64: "68355cdec3531b83b7cbebca5bcee6c3e8bd02a5c2636f4656a108525b2f61f1", + cons: v560_35_03, + }) - _ = addDriverABI(570, 86, 15, "87709c19c7401243136bc0ec9e7f147c6803070a11449ae8f0819dee7963f76b", func() *driverABI { - abi := v565_57_01() - abi.allocationClass[nvgpu.TURING_CHANNEL_GPFIFO_A] = allocHandler(rmAllocChannelV570, compUtil) - abi.allocationClass[nvgpu.AMPERE_CHANNEL_GPFIFO_A] = allocHandler(rmAllocChannelV570, compUtil) - abi.allocationClass[nvgpu.HOPPER_CHANNEL_GPFIFO_A] = allocHandler(rmAllocChannelV570, compUtil) + _ = addDriverABI(addDriverABIArgs{ + major: 570, + minor: 86, + patch: 15, + runfileChecksumX86_64: "87709c19c7401243136bc0ec9e7f147c6803070a11449ae8f0819dee7963f76b", + runfileChecksumARM64: "a663f81873bafda8313abb5a09f36c593426bb94a8bcc3f2017c79c95bf32978", + cons: func() *driverABI { + abi := v565_57_01() + abi.allocationClass[nvgpu.TURING_CHANNEL_GPFIFO_A] = allocHandler(rmAllocChannelV570, compUtil) + abi.allocationClass[nvgpu.AMPERE_CHANNEL_GPFIFO_A] = allocHandler(rmAllocChannelV570, compUtil) + abi.allocationClass[nvgpu.HOPPER_CHANNEL_GPFIFO_A] = allocHandler(rmAllocChannelV570, compUtil) - prevStructs := abi.getStructs - abi.getStructs = func() *driverABIStructs { - structs := prevStructs() - structs.allocationStructs[nvgpu.TURING_CHANNEL_GPFIFO_A] = driverStructWithName(nvgpu.NV_CHANNEL_ALLOC_PARAMS_V570{}, "NV_CHANNEL_ALLOC_PARAMS") - structs.allocationStructs[nvgpu.AMPERE_CHANNEL_GPFIFO_A] = driverStructWithName(nvgpu.NV_CHANNEL_ALLOC_PARAMS_V570{}, "NV_CHANNEL_ALLOC_PARAMS") - structs.allocationStructs[nvgpu.HOPPER_CHANNEL_GPFIFO_A] = driverStructWithName(nvgpu.NV_CHANNEL_ALLOC_PARAMS_V570{}, "NV_CHANNEL_ALLOC_PARAMS") - return structs - } - return abi + prevStructs := abi.getStructs + abi.getStructs = func() *driverABIStructs { + structs := prevStructs() + structs.allocationStructs[nvgpu.TURING_CHANNEL_GPFIFO_A] = driverStructWithName(nvgpu.NV_CHANNEL_ALLOC_PARAMS_V570{}, "NV_CHANNEL_ALLOC_PARAMS") + structs.allocationStructs[nvgpu.AMPERE_CHANNEL_GPFIFO_A] = driverStructWithName(nvgpu.NV_CHANNEL_ALLOC_PARAMS_V570{}, "NV_CHANNEL_ALLOC_PARAMS") + structs.allocationStructs[nvgpu.HOPPER_CHANNEL_GPFIFO_A] = driverStructWithName(nvgpu.NV_CHANNEL_ALLOC_PARAMS_V570{}, "NV_CHANNEL_ALLOC_PARAMS") + return structs + } + return abi + }, }) }) } @@ -908,9 +1015,9 @@ func newDriverStruct(paramType reflect.Type, name string) DriverStruct { // ForEachSupportDriver calls f on all supported drivers. // Precondition: Init() must have been called. -func ForEachSupportDriver(f func(version DriverVersion, checksum string)) { +func ForEachSupportDriver(f func(version DriverVersion, checksum_X86_64, checksum_ARM64 string)) { for version, abi := range abis { - f(version, abi.checksum) + f(version, abi.checksumX86_64, abi.checksumARM64) } } @@ -941,12 +1048,12 @@ func SupportedDrivers() []DriverVersion { // ExpectedDriverChecksum returns the expected checksum for a given version. // Precondition: Init() must have been called. -func ExpectedDriverChecksum(version DriverVersion) (string, bool) { +func ExpectedDriverChecksum(version DriverVersion) (string, string, bool) { abi, ok := abis[version] if !ok { - return "", false + return "", "", false } - return abi.checksum, true + return abi.checksumX86_64, abi.checksumARM64, true } // SupportedIoctls returns the ioctl numbers that are supported by nvproxy at diff --git a/tools/gpu/drivers/install_driver.go b/tools/gpu/drivers/install_driver.go index 091f29401b..90481d1d1d 100644 --- a/tools/gpu/drivers/install_driver.go +++ b/tools/gpu/drivers/install_driver.go @@ -23,6 +23,7 @@ import ( "net/http" "os" "os/exec" + "runtime" "sort" "strings" @@ -40,13 +41,32 @@ func init() { nvproxy.Init() } +// CPUArchitecture is the CPU architecture of the driver. +type CPUArchitecture string + +const ( + // X86_64 is the default architecture. + X86_64 CPUArchitecture = "x86_64" + // ARM64 is the architecture for ARM based GPUs. + ARM64 CPUArchitecture = "aarch64" +) + +// GetArchitecture returns the CPU architecture of the driver. Right now, we only support X86_64 and +// ARM64. +func GetArchitecture() CPUArchitecture { + if strings.HasPrefix(runtime.GOARCH, "arm") { + return ARM64 + } + return X86_64 +} + // Installer handles the logic to install drivers. type Installer struct { requestedVersion nvproxy.DriverVersion // include functions so they can be mocked in tests. - expectedChecksumFunc func(nvproxy.DriverVersion) (string, bool) + expectedChecksumFunc func(nvproxy.DriverVersion) (string, string, bool) getCurrentDriverFunc func() (nvproxy.DriverVersion, error) - downloadFunc func(context.Context, string) (io.ReadCloser, error) + downloadFunc func(context.Context, string, CPUArchitecture) (io.ReadCloser, error) installFunc func(string) error } @@ -68,15 +88,14 @@ func NewInstaller(requestedVersion string, latest bool) (*Installer, error) { } ret.requestedVersion = d } - return ret, nil } // MaybeInstall installs a driver if 1) no driver is present on the system already or 2) the // driver currently installed does not match the requested version. -func (i *Installer) MaybeInstall(ctx context.Context) error { +func (i *Installer) MaybeInstall(ctx context.Context, arch CPUArchitecture) error { // If we don't support the driver, don't attempt to install it. - if _, ok := i.expectedChecksumFunc(i.requestedVersion); !ok { + if _, _, ok := i.expectedChecksumFunc(i.requestedVersion); !ok { return fmt.Errorf("requested driver %q is not supported", i.requestedVersion) } @@ -98,7 +117,7 @@ func (i *Installer) MaybeInstall(ctx context.Context) error { } log.Infof("Downloading driver: %s", i.requestedVersion) - reader, err := i.downloadFunc(ctx, i.requestedVersion.String()) + reader, err := i.downloadFunc(ctx, i.requestedVersion.String(), arch) if err != nil { return fmt.Errorf("failed to download driver: %w", err) } @@ -108,7 +127,7 @@ func (i *Installer) MaybeInstall(ctx context.Context) error { return fmt.Errorf("failed to open driver file: %w", err) } defer os.Remove(f.Name()) - if err := i.writeAndCheck(f, reader); err != nil { + if err := i.writeAndCheck(f, reader, arch); err != nil { f.Close() return fmt.Errorf("writeAndCheck: %w", err) } @@ -138,7 +157,7 @@ func (i *Installer) uninstallDriver(ctx context.Context, driverVersion string) e return nil } -func (i *Installer) writeAndCheck(f *os.File, reader io.ReadCloser) error { +func (i *Installer) writeAndCheck(f *os.File, reader io.ReadCloser, arch CPUArchitecture) error { checksum := sha256.New() buf := make([]byte, 1024*1024) for { @@ -157,10 +176,15 @@ func (i *Installer) writeAndCheck(f *os.File, reader io.ReadCloser) error { } } gotChecksum := fmt.Sprintf("%x", checksum.Sum(nil)) - wantChecksum, ok := i.expectedChecksumFunc(i.requestedVersion) + wantChecksumX86_64, wantChecksumARM64, ok := i.expectedChecksumFunc(i.requestedVersion) if !ok { return fmt.Errorf("requested driver %q is not supported", i.requestedVersion) } + wantChecksum := wantChecksumX86_64 + if arch == ARM64 { + wantChecksum = wantChecksumARM64 + } + if gotChecksum != wantChecksum { return fmt.Errorf("driver %q checksum mismatch: got %q, want %q", i.requestedVersion, gotChecksum, wantChecksum) } @@ -217,7 +241,7 @@ func ListSupportedDrivers(outfile string) error { } var list []string - nvproxy.ForEachSupportDriver(func(version nvproxy.DriverVersion, checksum string) { + nvproxy.ForEachSupportDriver(func(version nvproxy.DriverVersion, _, _ string) { list = append(list, version.String()) }) sort.Strings(list) @@ -228,8 +252,8 @@ func ListSupportedDrivers(outfile string) error { } // ChecksumDriver downloads and returns the SHA265 checksum of the driver. -func ChecksumDriver(ctx context.Context, driverVersion string) (string, error) { - f, err := DownloadDriver(ctx, driverVersion) +func ChecksumDriver(ctx context.Context, driverVersion string, arch CPUArchitecture) (string, error) { + f, err := DownloadDriver(ctx, driverVersion, arch) if err != nil { return "", fmt.Errorf("failed to download driver: %w", err) } @@ -248,8 +272,8 @@ func ChecksumDriver(ctx context.Context, driverVersion string) (string, error) { // DownloadDriver downloads the requested driver and returns the binary as a []byte so it can be // checked before written to disk. -func DownloadDriver(ctx context.Context, driverVersion string) (io.ReadCloser, error) { - url := fmt.Sprintf("%s%s/NVIDIA-Linux-x86_64-%s.run", nvidiaBaseURL, driverVersion, driverVersion) +func DownloadDriver(ctx context.Context, driverVersion string, arch CPUArchitecture) (io.ReadCloser, error) { + url := fmt.Sprintf("%s%s/NVIDIA-Linux-%s-%s.run", nvidiaBaseURL, driverVersion, arch, driverVersion) resp, err := http.Get(url) if err != nil { return nil, fmt.Errorf("failed to download driver: %w", err) diff --git a/tools/gpu/drivers/install_driver_test.go b/tools/gpu/drivers/install_driver_test.go index 6f2b6f065f..1f1e09d472 100644 --- a/tools/gpu/drivers/install_driver_test.go +++ b/tools/gpu/drivers/install_driver_test.go @@ -32,19 +32,21 @@ func TestVersionInstalled(t *testing.T) { checksum := fmt.Sprintf("%x", sha256.Sum256(versionContent)) version := nvproxy.NewDriverVersion(1, 2, 3) getFunction := func() (nvproxy.DriverVersion, error) { return version, nil } - downloadFunction := func(context.Context, string) (io.ReadCloser, error) { return nil, fmt.Errorf("should not get here") } + downloadFunction := func(context.Context, string, CPUArchitecture) (io.ReadCloser, error) { + return nil, fmt.Errorf("should not get here") + } installer := &Installer{ requestedVersion: version, - expectedChecksumFunc: func(v nvproxy.DriverVersion) (string, bool) { + expectedChecksumFunc: func(v nvproxy.DriverVersion) (string, string, bool) { if v == version { - return checksum, true + return checksum, checksum, true } - return "", false + return "", "", false }, getCurrentDriverFunc: getFunction, downloadFunc: downloadFunction, } - if err := installer.MaybeInstall(ctx); err != nil { + if err := installer.MaybeInstall(ctx, X86_64); err != nil { t.Fatalf("Installation failed: %v", err) } } @@ -55,11 +57,11 @@ func TestVersionNotSupported(t *testing.T) { unsupportedVersion := nvproxy.NewDriverVersion(1, 2, 3) installer := &Installer{ requestedVersion: unsupportedVersion, - expectedChecksumFunc: func(v nvproxy.DriverVersion) (string, bool) { - return "", false + expectedChecksumFunc: func(v nvproxy.DriverVersion) (string, string, bool) { + return "", "", false }, } - err := installer.MaybeInstall(ctx) + err := installer.MaybeInstall(ctx, X86_64) if err == nil { t.Fatalf("Installation succeeded, want error") } @@ -72,23 +74,25 @@ func TestVersionNotSupported(t *testing.T) { func TestShaMismatch(t *testing.T) { ctx := context.Background() version := nvproxy.NewDriverVersion(1, 2, 3) + content := []byte("some content") + checksum := fmt.Sprintf("%x", sha256.Sum256(content)) installer := &Installer{ requestedVersion: version, getCurrentDriverFunc: func() (nvproxy.DriverVersion, error) { return nvproxy.DriverVersion{}, nil }, - expectedChecksumFunc: func(v nvproxy.DriverVersion) (string, bool) { + expectedChecksumFunc: func(v nvproxy.DriverVersion) (string, string, bool) { if v == version { - return "mismatch", true + return "mismatch", "mismatch", true } - return "", false + return checksum, checksum, false }, - downloadFunc: func(context.Context, string) (io.ReadCloser, error) { + downloadFunc: func(context.Context, string, CPUArchitecture) (io.ReadCloser, error) { reader := bytes.NewReader([]byte("some content")) return io.NopCloser(reader), nil }, } - err := installer.MaybeInstall(ctx) + err := installer.MaybeInstall(ctx, X86_64) if err == nil { t.Fatalf("Installation succeeded, want error") } @@ -100,29 +104,37 @@ func TestShaMismatch(t *testing.T) { // TestDriverInstalls tests the successful installation of a driver. func TestDriverInstalls(t *testing.T) { ctx := context.Background() - content := []byte("some content") - checksum := fmt.Sprintf("%x", sha256.Sum256(content)) + for _, arch := range []CPUArchitecture{X86_64, ARM64} { + t.Run(fmt.Sprintf("%s", arch), func(t *testing.T) { + testDriverInstalls(ctx, t, arch) + }) + } +} + +func testDriverInstalls(ctx context.Context, t *testing.T, arch CPUArchitecture) { version := nvproxy.NewDriverVersion(1, 2, 3) installer := &Installer{ requestedVersion: version, getCurrentDriverFunc: func() (nvproxy.DriverVersion, error) { return nvproxy.DriverVersion{}, nil }, - expectedChecksumFunc: func(v nvproxy.DriverVersion) (string, bool) { + expectedChecksumFunc: func(v nvproxy.DriverVersion) (string, string, bool) { + checksumX86_64 := fmt.Sprintf("%x", sha256.Sum256([]byte(X86_64))) + checksumARM64 := fmt.Sprintf("%x", sha256.Sum256([]byte(ARM64))) if v == version { - return checksum, true + return checksumX86_64, checksumARM64, true } - return "", false + return "garbage", "garbage", false }, - downloadFunc: func(context.Context, string) (io.ReadCloser, error) { - reader := bytes.NewReader(content) + downloadFunc: func(context.Context, string, CPUArchitecture) (io.ReadCloser, error) { + reader := bytes.NewReader([]byte(arch)) return io.NopCloser(reader), nil }, installFunc: func(_ string) error { return nil }, } - if err := installer.MaybeInstall(ctx); err != nil { + if err := installer.MaybeInstall(ctx, arch); err != nil { t.Fatalf("Installation failed: %v", err) } } diff --git a/tools/gpu/main.go b/tools/gpu/main.go index 41a8a6bbf5..72d2a468da 100644 --- a/tools/gpu/main.go +++ b/tools/gpu/main.go @@ -19,6 +19,7 @@ import ( "context" "fmt" "os" + "sync" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/devices/nvproxy" @@ -92,7 +93,7 @@ func main() { log.Warningf("Failed to create installer: %v", err.Error()) os.Exit(1) } - if err := installer.MaybeInstall(ctx); err != nil { + if err := installer.MaybeInstall(ctx, drivers.GetArchitecture()); err != nil { log.Warningf("Failed to install driver: %v", err.Error()) os.Exit(1) } @@ -102,30 +103,46 @@ func main() { os.Exit(1) } - checksum, err := drivers.ChecksumDriver(ctx, *checksumVersion) - if err != nil { - log.Warningf("Failed to compute checksum: %v", err) - os.Exit(1) + for _, arch := range []drivers.CPUArchitecture{drivers.X86_64, drivers.ARM64} { + checksum, err := drivers.ChecksumDriver(ctx, *checksumVersion, arch) + if err != nil { + log.Warningf("Failed to compute checksum on arch %q: %v", arch, err) + os.Exit(1) + } + fmt.Printf("Checksum for driver %q on %s : %q\n", *checksumVersion, arch, checksum) } - fmt.Printf("Checksum: %q\n", checksum) case validateChecksumCmdStr: if err := validateChecksumCmd.Parse(os.Args[2:]); err != nil { log.Warningf("%s failed with: %v", validateChecksumCmdStr, err) os.Exit(1) } - nvproxy.ForEachSupportDriver(func(version nvproxy.DriverVersion, checksum string) { - wantChecksum, err := drivers.ChecksumDriver(ctx, version.String()) - if err != nil { - log.Warningf("error on version %q: %v", version.String(), err) - return - } - if checksum != wantChecksum { - log.Warningf("Checksum mismatch on driver %q got: %q want: %q", version.String(), checksum, wantChecksum) - return + var wg sync.WaitGroup + + nvproxy.ForEachSupportDriver(func(version nvproxy.DriverVersion, x86Checksum, armChecksum string) { + for _, arch := range []drivers.CPUArchitecture{drivers.X86_64, drivers.ARM64} { + wg.Add(1) + go func() { + defer wg.Done() + gotChecksum, err := drivers.ChecksumDriver(ctx, version.String(), arch) + if err != nil { + log.Warningf("error on version %q on arch %q: %v", version.String(), arch, err) + return + } + checksum := x86Checksum + if arch == drivers.ARM64 { + checksum = armChecksum + } + + if checksum != gotChecksum { + log.Warningf("Checksum mismatch on driver %q on arch %q: got: %q want: %q", version.String(), arch, gotChecksum, checksum) + return + } + log.Infof("Checksum matched on driver %q on arch %q.", version.String(), arch) + }() } - log.Infof("Checksum matched on driver %q.", version.String()) }) + wg.Wait() case listCmdStr: if err := listCmd.Parse(os.Args[2:]); err != nil { log.Warningf("%s failed with: %v", listCmdStr, err)