Skip to content

Commit 144e287

Browse files
committed
test: streamline installing packages in tests
1 parent 3b79c3c commit 144e287

File tree

15 files changed

+130
-159
lines changed

15 files changed

+130
-159
lines changed

.buildkite/testing_weightinitializers.yml

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ steps:
4646
# matrix:
4747
# setup:
4848
# julia:
49-
# - "1.10"
5049
# - "1.12"
5150

5251
- group: ":julia: (WeightInitializers) Metal GPU"
@@ -72,7 +71,6 @@ steps:
7271
matrix:
7372
setup:
7473
julia:
75-
- "1.10"
7674
- "1.12"
7775

7876
- group: ":julia: (WeightInitializers) oneAPI GPU"
@@ -98,7 +96,6 @@ steps:
9896
matrix:
9997
setup:
10098
julia:
101-
- "1.10"
10299
- "1.12"
103100

104101
env:

lib/LuxLib/test/runtests.jl

Lines changed: 3 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,34 +19,18 @@ end
1919
const PARSED_TEST_ARGS = parse_test_args()
2020

2121
const BACKEND_GROUP = lowercase(get(PARSED_TEST_ARGS, "BACKEND_GROUP", "all"))
22-
const EXTRA_PKGS = PackageSpec[]
23-
const EXTRA_DEV_PKGS = PackageSpec[]
2422

2523
const LUXLIB_BLAS_BACKEND = lowercase(
2624
get(PARSED_TEST_ARGS, "LUXLIB_BLAS_BACKEND", "default")
2725
)
2826
@assert LUXLIB_BLAS_BACKEND in ("default", "appleaccelerate", "blis", "mkl")
2927
@info "Running tests with BLAS backend: $(LUXLIB_BLAS_BACKEND)"
3028

31-
if (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda")
32-
if isdir(joinpath(@__DIR__, "../../LuxCUDA"))
33-
@info "Using local LuxCUDA"
34-
push!(EXTRA_DEV_PKGS, PackageSpec(; path=joinpath(@__DIR__, "../../LuxCUDA")))
35-
else
36-
push!(EXTRA_PKGS, PackageSpec(; name="LuxCUDA"))
37-
end
38-
end
39-
(BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") &&
40-
push!(EXTRA_PKGS, PackageSpec(; name="AMDGPU"))
41-
(BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi") &&
42-
push!(EXTRA_PKGS, PackageSpec(; name="oneAPI"))
43-
(BACKEND_GROUP == "all" || BACKEND_GROUP == "metal") &&
44-
push!(EXTRA_PKGS, PackageSpec(; name="Metal"))
29+
const EXTRA_PKGS = LuxTestUtils.packages_to_install(BACKEND_GROUP)
4530

46-
if !isempty(EXTRA_PKGS) || !isempty(EXTRA_DEV_PKGS)
47-
@info "Installing Extra Packages for testing" EXTRA_PKGS EXTRA_DEV_PKGS
31+
if !isempty(EXTRA_PKGS)
32+
@info "Installing Extra Packages for testing" EXTRA_PKGS
4833
isempty(EXTRA_PKGS) || Pkg.add(EXTRA_PKGS)
49-
isempty(EXTRA_DEV_PKGS) || Pkg.develop(EXTRA_DEV_PKGS)
5034
Base.retry_load_extensions()
5135
Pkg.instantiate()
5236
end

lib/LuxLib/test/shared_testsetup.jl

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,38 +30,34 @@ end
3030

3131
const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "All"))
3232

33-
if BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda"
33+
if LuxTestUtils.test_cuda(BACKEND_GROUP)
3434
using LuxCUDA
3535
end
3636

37-
if BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu"
37+
if LuxTestUtils.test_amdgpu(BACKEND_GROUP)
3838
using AMDGPU
3939
end
4040

41-
if BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi"
41+
if LuxTestUtils.test_oneapi(BACKEND_GROUP)
4242
using oneAPI
4343
end
4444

45-
if BACKEND_GROUP == "all" || BACKEND_GROUP == "metal"
45+
if LuxTestUtils.test_metal(BACKEND_GROUP)
4646
using Metal
4747
end
4848

4949
cpu_testing() = BACKEND_GROUP == "all" || BACKEND_GROUP == "cpu"
5050
function cuda_testing()
51-
return (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda") &&
52-
MLDataDevices.functional(CUDADevice)
51+
return LuxTestUtils.test_cuda(BACKEND_GROUP) && MLDataDevices.functional(CUDADevice)
5352
end
5453
function amdgpu_testing()
55-
return (BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") &&
56-
MLDataDevices.functional(AMDGPUDevice)
54+
return LuxTestUtils.test_amdgpu(BACKEND_GROUP) && MLDataDevices.functional(AMDGPUDevice)
5755
end
5856
function oneapi_testing()
59-
return (BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi") &&
60-
MLDataDevices.functional(oneAPIDevice)
57+
return LuxTestUtils.test_oneapi(BACKEND_GROUP) && MLDataDevices.functional(oneAPIDevice)
6158
end
6259
function metal_testing()
63-
return (BACKEND_GROUP == "all" || BACKEND_GROUP == "metal") &&
64-
MLDataDevices.functional(MetalDevice)
60+
return LuxTestUtils.test_metal(BACKEND_GROUP) && MLDataDevices.functional(MetalDevice)
6561
end
6662

6763
const MODES = begin

lib/LuxTestUtils/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LuxTestUtils"
22
uuid = "ac9de150-d08f-4546-94fb-7472b5760531"
3+
version = "2.1.0"
34
authors = ["Avik Pal <avikpal@mit.edu>"]
4-
version = "2.0.1"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -16,6 +16,7 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1616
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
1717
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
1818
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
19+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1920
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2021
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2122

@@ -35,6 +36,7 @@ Functors = "0.5"
3536
JET = "0.9.6, 0.10"
3637
MLDataDevices = "1.6.10"
3738
Optimisers = "0.3.4, 0.4"
39+
Pkg = "1.10"
3840
Test = "1.10"
3941
Zygote = "0.7"
4042
julia = "1.10"

lib/LuxTestUtils/src/LuxTestUtils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using DispatchDoctor: allow_unstable
66
using Functors: Functors
77
using MLDataDevices: cpu_device, gpu_device, get_device, get_device_type, AbstractGPUDevice
88
using Optimisers: Optimisers
9+
using Pkg: PackageSpec
910
using Test:
1011
Test,
1112
Error,
@@ -53,6 +54,7 @@ catch err
5354
global ENZYME_TESTING_ENABLED = false
5455
end
5556

57+
include("package_install.jl")
5658
include("test_softfail.jl")
5759
include("autodiff.jl")
5860
include("jet.jl")
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
function has_cuda()
2+
try
3+
return run(`nvidia-smi -L`) === nothing ? false : true
4+
catch
5+
return false
6+
end
7+
end
8+
9+
function has_amdgpu()
10+
try
11+
out = read(`rocminfo`, String)
12+
return occursin("GPU", out)
13+
catch
14+
return false
15+
end
16+
end
17+
18+
has_metal() = Sys.isapple() && Sys.KERNEL === :Darwin
19+
20+
function has_oneapi()
21+
try
22+
return isdir("/dev/dri") && any(occursin("render", f) for f in readdir("/dev/dri"))
23+
catch
24+
return false
25+
end
26+
end
27+
28+
has_opencl() = true
29+
30+
for backend_group in ("cuda", "amdgpu", "metal", "oneapi", "opencl")
31+
fnanme = Symbol(:test_, backend_group)
32+
has_fnname = Symbol(:has_, backend_group)
33+
@eval function $(fnanme)(backend_group::String="all")
34+
backend_group == $(QuoteNode(backend_group)) && return true
35+
backend_group != "all" && return false
36+
$(has_fnname)() && return true
37+
return false
38+
end
39+
end
40+
41+
function packages_to_install(backend_group::String="all")
42+
backend_group = lowercase(backend_group)
43+
@assert backend_group in
44+
("all", "cpu", "cuda", "amdgpu", "metal", "oneapi", "opencl", "reactant")
45+
46+
pkgs = PackageSpec[]
47+
if test_cuda(backend_group)
48+
push!(pkgs, PackageSpec(; name="CUDA"))
49+
push!(pkgs, PackageSpec(; name="cuDNN"))
50+
push!(pkgs, PackageSpec(; name="LuxCUDA"))
51+
end
52+
test_amdgpu(backend_group) && push!(pkgs, PackageSpec(; name="AMDGPU"))
53+
test_metal(backend_group) && push!(pkgs, PackageSpec(; name="Metal", version="1.9"))
54+
test_oneapi(backend_group) && push!(pkgs, PackageSpec(; name="oneAPI"))
55+
if test_opencl(backend_group)
56+
push!(pkgs, PackageSpec(; name="OpenCL"))
57+
push!(pkgs, PackageSpec(; name="pocl_jll"))
58+
end
59+
return pkgs
60+
end

lib/MLDataDevices/test/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
88
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
99
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1010
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
11+
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
1112
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
1213
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
1314
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -20,6 +21,9 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2021
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2122
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2223

24+
[sources]
25+
LuxTestUtils = {path = "../../LuxTestUtils"}
26+
2327
[compat]
2428
Adapt = "4"
2529
Aqua = "0.8.4"

lib/MLDataDevices/test/iterator_tests.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
1-
const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "none"))
2-
3-
if BACKEND_GROUP == "opencl" || BACKEND_GROUP == "all"
4-
using OpenCL, pocl_jll
5-
end
1+
using MLDataDevices, MLUtils, Test, LuxTestUtils
62

7-
using MLDataDevices, MLUtils, Test
3+
const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "none"))
84

9-
if BACKEND_GROUP == "cuda" || BACKEND_GROUP == "all"
5+
if LuxTestUtils.test_cuda(BACKEND_GROUP)
106
using LuxCUDA
117
end
128

@@ -15,18 +11,22 @@ if BACKEND_GROUP != "cuda"
1511
using Reactant
1612
end
1713

18-
if BACKEND_GROUP == "amdgpu" || BACKEND_GROUP == "all"
14+
if LuxTestUtils.test_amdgpu(BACKEND_GROUP)
1915
using AMDGPU
2016
end
2117

22-
if BACKEND_GROUP == "metal" || BACKEND_GROUP == "all"
18+
if LuxTestUtils.test_metal(BACKEND_GROUP)
2319
using Metal
2420
end
2521

26-
if BACKEND_GROUP == "oneapi" || BACKEND_GROUP == "all"
22+
if LuxTestUtils.test_oneapi(BACKEND_GROUP)
2723
using oneAPI
2824
end
2925

26+
if LuxTestUtils.test_opencl(BACKEND_GROUP)
27+
using OpenCL, pocl_jll
28+
end
29+
3030
DEVICES = [
3131
CPUDevice,
3232
CUDADevice,

lib/MLDataDevices/test/runtests.jl

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Pkg: Pkg, PackageSpec
22
using Test
3+
using LuxTestUtils
34

45
function parse_test_args()
56
test_args_from_env = @isdefined(TEST_ARGS) ? TEST_ARGS : ARGS
@@ -18,32 +19,11 @@ const PARSED_TEST_ARGS = parse_test_args()
1819

1920
const BACKEND_GROUP = lowercase(get(PARSED_TEST_ARGS, "BACKEND_GROUP", "none"))
2021

21-
const EXTRA_PKGS = PackageSpec[]
22-
const EXTRA_DEV_PKGS = PackageSpec[]
22+
const EXTRA_PKGS = LuxTestUtils.packages_to_install(BACKEND_GROUP)
2323

24-
if (BACKEND_GROUP == "all" || BACKEND_GROUP == "cuda")
25-
if isdir(joinpath(@__DIR__, "../../LuxCUDA"))
26-
@info "Using local LuxCUDA"
27-
push!(EXTRA_DEV_PKGS, PackageSpec(; path=joinpath(@__DIR__, "../../LuxCUDA")))
28-
else
29-
push!(EXTRA_PKGS, PackageSpec(; name="LuxCUDA"))
30-
end
31-
end
32-
(BACKEND_GROUP == "all" || BACKEND_GROUP == "amdgpu") &&
33-
push!(EXTRA_PKGS, PackageSpec(; name="AMDGPU"))
34-
(BACKEND_GROUP == "all" || BACKEND_GROUP == "oneapi") &&
35-
push!(EXTRA_PKGS, PackageSpec(; name="oneAPI"))
36-
(BACKEND_GROUP == "all" || BACKEND_GROUP == "opencl") && begin
37-
push!(EXTRA_PKGS, PackageSpec(; name="OpenCL"))
38-
push!(EXTRA_PKGS, PackageSpec(; name="pocl_jll"))
39-
end
40-
(BACKEND_GROUP == "all" || BACKEND_GROUP == "metal") &&
41-
push!(EXTRA_PKGS, PackageSpec(; name="Metal"))
42-
43-
if !isempty(EXTRA_PKGS) || !isempty(EXTRA_DEV_PKGS)
44-
@info "Installing Extra Packages for testing" EXTRA_PKGS EXTRA_DEV_PKGS
24+
if !isempty(EXTRA_PKGS)
25+
@info "Installing Extra Packages for testing" EXTRA_PKGS
4526
isempty(EXTRA_PKGS) || Pkg.add(EXTRA_PKGS)
46-
isempty(EXTRA_DEV_PKGS) || Pkg.develop(EXTRA_DEV_PKGS)
4727
Base.retry_load_extensions()
4828
Pkg.instantiate()
4929
end

lib/WeightInitializers/test/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
77
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
88
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
10+
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
1011
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
1112
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1213
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
@@ -15,6 +16,9 @@ StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1516
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1617
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1718

19+
[sources]
20+
LuxTestUtils = {path = "../../LuxTestUtils"}
21+
1822
[compat]
1923
Aqua = "0.8.7"
2024
CPUSummary = "0.2.6"

0 commit comments

Comments
 (0)