@@ -112,6 +112,20 @@ function test_worker(name)
112112end
113113
114114# code to run in each test's sandbox module before running the test
115+
116+ setup_code = quote
117+ using Metal, Adapt, ObjectiveC, ObjectiveC. Foundation, BFloat16s
118+
119+ import GPUArrays
120+ include($ gpuarrays_testsuite)
121+
122+ const eltypes = [Int16, Int32, Int64,
123+ Complex{Int16}, Complex{Int32}, Complex{Int64},
124+ Float16, Float32,
125+ ComplexF16, ComplexF32]
126+ TestSuite. supported_eltypes(:: Type{<:MtlArray} ) = eltypes
127+ end
128+
115129init_code = quote
116130 using Metal, Adapt, ObjectiveC, ObjectiveC. Foundation, BFloat16s
117131
@@ -121,16 +135,12 @@ init_code = quote
121135
122136 const capturing = parse(Int, get(ENV , " METAL_CAPTURE_ENABLED" , " 0" )) > 0
123137
124- import GPUArrays
125- include($ gpuarrays_testsuite)
138+ # import GPUArrays
139+ # include($gpuarrays_testsuite)
140+ # testf(f, xs...; kwargs...) = TestSuite.compare(f, MtlArray, xs...; kwargs...)
141+ import Main: TestSuite
126142 testf(f, xs... ; kwargs... ) = TestSuite. compare(f, MtlArray, xs... ; kwargs... )
127143
128- const eltypes = [Int16, Int32, Int64,
129- Complex{Int16}, Complex{Int32}, Complex{Int64},
130- Float16, Float32,
131- ComplexF16, ComplexF32]
132- TestSuite. supported_eltypes(:: Type{<:MtlArray} ) = eltypes
133-
134144 # NOTE: based on test/pkg.jl::capture_stdout, but doesn't discard exceptions
135145 macro grab_output(ex)
136146 quote
@@ -165,4 +175,5 @@ init_code = quote
165175 end
166176end
167177
168- runtests(Metal, args; testsuite, init_code, test_worker)
178+ runtests(Metal, args; testsuite, setup_code, init_code, test_worker)
179+ # runtests(Metal, args; testsuite, init_code, test_worker)
0 commit comments