|
1 | 1 | module LuxTestUtils |
2 | 2 |
|
| 3 | +using Adapt: Adapt, adapt |
3 | 4 | using ArrayInterface: ArrayInterface |
4 | 5 | using ComponentArrays: ComponentArray, getdata, getaxes |
5 | 6 | using DispatchDoctor: allow_unstable |
6 | | -using Functors: Functors |
| 7 | +using Functors: Functors, fmap |
7 | 8 | using MLDataDevices: cpu_device, gpu_device, get_device, get_device_type, AbstractGPUDevice |
8 | 9 | using Optimisers: Optimisers |
9 | 10 | using Pkg: PackageSpec |
@@ -33,25 +34,41 @@ using Zygote: Zygote |
33 | 34 | const CRC = ChainRulesCore |
34 | 35 | const FD = FiniteDiff |
35 | 36 |
|
| 37 | +const JET_TESTING_ENABLED = Ref{Bool}(false) |
| 38 | +const ENZYME_TESTING_ENABLED = Ref{Bool}(false) |
| 39 | +const ZYGOTE_TESTING_ENABLED = Ref{Bool}(false) |
| 40 | + |
36 | 41 | # Check if JET will work |
37 | 42 | try |
38 | 43 | using JET: JET, JETTestFailure, get_reports, report_call, report_opt |
39 | | - # XXX: In 1.11, JET leads to stack overflows |
40 | | - global JET_TESTING_ENABLED = v"1.10-" ≤ VERSION < v"1.11-" |
| 44 | + JET_TESTING_ENABLED[] = true |
41 | 45 | catch err |
42 | 46 | @error "`JET.jl` did not successfully precompile on $(VERSION). All `@jet` tests will \ |
43 | 47 | be skipped." maxlog = 1 err = err |
44 | | - global JET_TESTING_ENABLED = false |
| 48 | + JET_TESTING_ENABLED[] = false |
45 | 49 | end |
46 | 50 |
|
47 | | -# Check if Enzyme will work |
48 | | -try |
49 | | - using Enzyme: Enzyme |
50 | | - __ftest(x) = x |
51 | | - Enzyme.autodiff(Enzyme.Reverse, __ftest, Enzyme.Active, Enzyme.Active(2.0)) |
52 | | - global ENZYME_TESTING_ENABLED = Sys.islinux() |
53 | | -catch err |
54 | | - global ENZYME_TESTING_ENABLED = false |
| 51 | +# Check if Enzyme will work (only on non-prerelease versions) |
| 52 | +@static if isempty(VERSION.prerelease) |
| 53 | + try |
| 54 | + using Enzyme: Enzyme |
| 55 | + Enzyme.gradient(Enzyme.Reverse, Base.Fix1(sum, abs2), ones(Float32, 10)) |
| 56 | + ENZYME_TESTING_ENABLED[] = Sys.islinux() |
| 57 | + catch err |
| 58 | + @error "`Enzyme.jl` did not successfully differentiate a simple function or \ |
| 59 | + failed to load on $(VERSION). All Enzyme tests will be \ |
| 60 | + skipped." maxlog = 1 err = err |
| 61 | + ENZYME_TESTING_ENABLED[] = false |
| 62 | + end |
| 63 | +end |
| 64 | + |
| 65 | +function __init__() |
| 66 | + ZYGOTE_TESTING_ENABLED[] = VERSION < v"1.12-" |
| 67 | + |
| 68 | + if JET_TESTING_ENABLED[] |
| 69 | + # JET doesn't work nicely on 1.11 |
| 70 | + JET_TESTING_ENABLED[] = VERSION < v"1.11-" || VERSION ≥ v"1.12-" |
| 71 | + end |
55 | 72 | end |
56 | 73 |
|
57 | 74 | include("package_install.jl") |
|
0 commit comments