Skip to content

Commit 719654d

Browse files
authored
feat: add OpenCL support to MLDataDevices (#1590)
* add opencl support
* fix
* add opencl tests
* gemini recommended fixes
* minor test fixes
* Add 'opencl' to CI matrix groups
* address comments
* use JuliaFormatter to format the file
* format more files
* address comments
* attempt to fix tests
* test fix
* new functional and fix tests
* fix typo
* enable opencl cpu fallback tests
1 parent bf85cee commit 719654d

File tree

12 files changed

+355
-12
lines changed

12 files changed

+355
-12
lines changed

.github/workflows/CI_MLDataDevices.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ jobs:
2626
- windows-latest
2727
group:
2828
- cpu
29+
- opencl
2930
- reactant
3031
uses: ./.github/workflows/CommonCI.yml
3132
with:
@@ -40,6 +41,7 @@ jobs:
4041
matrix:
4142
group:
4243
- cpu
44+
- opencl
4345
- reactant
4446
uses: ./.github/workflows/CommonCI.yml
4547
with:

lib/MLDataDevices/Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MLDataDevices"
22
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "1.15.3"
4+
version = "1.16.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -29,6 +29,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
2929
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3030
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
3131
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
32+
OpenCL = "08131aa3-fb12-5dee-8b74-c09406e224a2"
3233

3334
[extensions]
3435
MLDataDevicesAMDGPUExt = "AMDGPU"
@@ -49,6 +50,7 @@ MLDataDevicesTrackerExt = "Tracker"
4950
MLDataDevicesZygoteExt = "Zygote"
5051
MLDataDevicescuDNNExt = ["CUDA", "cuDNN"]
5152
MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"]
53+
MLDataDevicesOpenCLExt = ["GPUArrays", "OpenCL"]
5254

5355
[compat]
5456
AMDGPU = "1, 2"
@@ -75,3 +77,4 @@ Zygote = "0.7"
7577
cuDNN = "1.3"
7678
julia = "1.10"
7779
oneAPI = "1.5, 2"
80+
OpenCL = "0.10.5"
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
module MLDataDevicesOpenCLExt

using Adapt: Adapt
using GPUArrays: GPUArrays
using OpenCL: OpenCL, cl, CLArray
using MLDataDevices: MLDataDevices, Internal, OpenCLDevice, reset_gpu_device!

# Per-device cache of FP64 support (presence of the `cl_khr_fp64` extension).
# Populated eagerly in `__init__` and extended lazily via `supports_fp64` for
# any device that was not visible at load time.
const SUPPORTS_FP64 = Dict{cl.Device,Bool}()

# Query (and cache) whether `dev` supports FP64. Using `get!` instead of raw
# indexing avoids a `KeyError` for devices that appear after `__init__` ran.
function supports_fp64(dev::cl.Device)
    return get!(() -> "cl_khr_fp64" in dev.extensions, SUPPORTS_FP64, dev)
end

function __init__()
    # Invalidate any previously-selected GPU device now that OpenCL is loaded.
    reset_gpu_device!()
    for dev in vcat(cl.devices.(cl.platforms())...)
        SUPPORTS_FP64[dev] = "cl_khr_fp64" in dev.extensions
    end
    return nothing
end

MLDataDevices.loaded(::Union{OpenCLDevice,Type{<:OpenCLDevice}}) = true
function MLDataDevices.functional(::Union{OpenCLDevice,Type{<:OpenCLDevice}})
    # Functional iff at least one OpenCL platform exposes at least one device.
    return !isempty(cl.platforms()) && !isempty(vcat(cl.devices.(cl.platforms())...))
end

# Default RNG
MLDataDevices.default_device_rng(::OpenCLDevice) = GPUArrays.default_rng(CLArray)

# Query Device from Array
Internal.get_device(::CLArray) = OpenCLDevice()

Internal.get_device_type(::CLArray) = OpenCLDevice

# unsafe_free!
function Internal.unsafe_free_internal!(::Type{OpenCLDevice}, x::AbstractArray)
    if applicable(OpenCL.unsafe_free!, x)
        OpenCL.unsafe_free!(x)
    else
        @warn "OpenCL.unsafe_free! is not defined for $(typeof(x))." maxlog = 1
    end
    return nothing
end

# Device Transfer
#
# For double-precision element types the behavior depends on the device's FP64
# capability and on the eltype adaptor parameter of the `OpenCLDevice`:
#   - `Missing` (default): silently demote Float64/ComplexF64 to the single
#     precision counterpart when FP64 is unavailable (with a warning).
#   - `Nothing` (preserve eltype): error when FP64 is unavailable, since the
#     requested eltype cannot be honored.
#   - a concrete `T<:AbstractFloat`: convert to `T`; error only when `T` itself
#     is Float64 and unsupported.
for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32))
    @eval function Adapt.adapt_storage(::OpenCLDevice{Missing}, x::AbstractArray{$(T1)})
        MLDataDevices.get_device_type(x) <: OpenCLDevice && return x
        if !supports_fp64(cl.device())
            @warn LazyString(
                "Double type is not supported on this device. Using `", $(T2), "` instead."
            )
            return CLArray{$(T2)}(x)
        end
        return CLArray(x)
    end

    @eval function Adapt.adapt_storage(::OpenCLDevice{Nothing}, x::AbstractArray{$(T1)})
        MLDataDevices.get_device_type(x) <: OpenCLDevice && return x
        # `eltype=nothing` means "preserve the eltype"; both loop eltypes are
        # double precision, so a device without FP64 cannot honor the request.
        if !supports_fp64(cl.device())
            throw(
                ArgumentError(
                    "FP64 is not supported on this device and eltype=nothing was specified"
                ),
            )
        end
        return CLArray(x)
    end

    @eval function Adapt.adapt_storage(
        ::OpenCLDevice{T}, x::AbstractArray{$(T1)}
    ) where {T<:AbstractFloat}
        MLDataDevices.get_device_type(x) <: OpenCLDevice && eltype(x) == T && return x
        if T === Float64 && !supports_fp64(cl.device())
            throw(ArgumentError("FP64 is not supported on this device"))
        end
        return CLArray{T}(x)
    end
end

# Fallback adaptor for non-double eltypes: delegate to the shared internal
# machinery with `CLArray` as both the storage and wrapper type.
opencl_array_adapt(::Type{T}, x) where {T} = Internal.array_adapt(CLArray, CLArray, T, x)

function Adapt.adapt_storage(::OpenCLDevice{Missing}, x::AbstractArray)
    MLDataDevices.get_device_type(x) <: OpenCLDevice && return x
    return opencl_array_adapt(Missing, x)
end

function Adapt.adapt_storage(::OpenCLDevice{Nothing}, x::AbstractArray)
    MLDataDevices.get_device_type(x) <: OpenCLDevice && return x
    return opencl_array_adapt(Nothing, x)
end

function Adapt.adapt_storage(::OpenCLDevice{T}, x::AbstractArray) where {T<:AbstractFloat}
    MLDataDevices.get_device_type(x) <: OpenCLDevice && eltype(x) == T && return x
    if T === Float64 && !supports_fp64(cl.device())
        throw(ArgumentError("FP64 is not supported on this device"))
    end
    return opencl_array_adapt(T, x)
end

end

lib/MLDataDevices/src/MLDataDevices.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ export gpu_device, cpu_device
2121
export xla_device, reactant_device
2222

2323
export CPUDevice
24-
export CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice
24+
export CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, OpenCLDevice
2525
export XLADevice, ReactantDevice
2626
export get_device, get_device_type
2727

lib/MLDataDevices/src/internal.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@ using ..MLDataDevices:
1212
AMDGPUDevice,
1313
MetalDevice,
1414
oneAPIDevice,
15+
OpenCLDevice,
1516
ReactantDevice,
1617
UnknownDevice,
1718
supported_gpu_backends,
1819
GPU_DEVICES,
1920
loaded,
2021
functional
2122

22-
for dev in (CPUDevice, MetalDevice, oneAPIDevice)
23+
for dev in (CPUDevice, MetalDevice, oneAPIDevice, OpenCLDevice)
2324
msg = "`device_id` is not applicable for `$dev`."
2425
@eval begin
2526
with_device(::Type{$dev}, ::Nothing) = $dev()
@@ -30,7 +31,7 @@ for dev in (CPUDevice, MetalDevice, oneAPIDevice)
3031
end
3132
end
3233

33-
for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
34+
for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :OpenCL)
3435
tpkg = name === :CPU ? "" : string(name)
3536
ldev = Symbol(name, :Device)
3637
@eval begin
@@ -47,6 +48,7 @@ for T in (
4748
AMDGPUDevice{Nothing},
4849
MetalDevice,
4950
oneAPIDevice,
51+
OpenCLDevice,
5052
ReactantDevice,
5153
)
5254
@eval get_device_id(::$(T)) = nothing
@@ -116,7 +118,8 @@ function get_gpu_device(; force::Bool)
116118
a. `CUDA.jl` and `cuDNN.jl` (or just `LuxCUDA.jl`) for NVIDIA CUDA Support.
117119
b. `AMDGPU.jl` for AMD GPU ROCM Support.
118120
c. `Metal.jl` for Apple Metal GPU Support. (Experimental)
119-
d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)""" maxlog = 1
121+
d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)
122+
e. `OpenCL.jl` for OpenCL support. (Experimental)""" maxlog = 1
120123
return CPUDevice
121124
end
122125

lib/MLDataDevices/src/public.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ MetalDevice() = MetalDevice{Missing}()
2121
struct oneAPIDevice{T<:EltypeAdaptorType} <: AbstractGPUDevice end
2222
oneAPIDevice() = oneAPIDevice{Missing}()
2323

24+
struct OpenCLDevice{T<:EltypeAdaptorType} <: AbstractGPUDevice end
25+
OpenCLDevice() = OpenCLDevice{Missing}()
26+
2427
struct ReactantDevice{C,D,S,T<:EltypeAdaptorType,TN} <: AbstractAcceleratorDevice
2528
client::C
2629
device::D
@@ -49,6 +52,7 @@ Base.eltype(::CUDADevice{D,T}) where {D,T} = T
4952
Base.eltype(::AMDGPUDevice{D,T}) where {D,T} = T
5053
Base.eltype(::MetalDevice{T}) where {T} = T
5154
Base.eltype(::oneAPIDevice{T}) where {T} = T
55+
Base.eltype(::OpenCLDevice{T}) where {T} = T
5256
Base.eltype(::ReactantDevice{C,D,S,T}) where {C,D,S,T} = T
5357

5458
# Helper functions to create devices with specific eltypes
@@ -78,6 +82,12 @@ function with_eltype(::oneAPIDevice, ::Type{T}) where {T<:AbstractFloat}
7882
return oneAPIDevice{T}()
7983
end
8084

85+
with_eltype(::OpenCLDevice, ::Nothing) = OpenCLDevice{Nothing}()
86+
with_eltype(::OpenCLDevice, ::Missing) = OpenCLDevice{Missing}()
87+
function with_eltype(::OpenCLDevice, ::Type{T}) where {T<:AbstractFloat}
88+
return OpenCLDevice{T}()
89+
end
90+
8191
function with_eltype(dev::ReactantDevice{C,D,S,<:Any,TN}, ::Missing) where {C,D,S,TN}
8292
return ReactantDevice{C,D,S,Missing,TN}(dev.client, dev.device, dev.sharding)
8393
end
@@ -141,12 +151,13 @@ Checks if the trigger package for the device is loaded. Trigger packages are as
141151
- `AMDGPU.jl` for AMD GPU ROCM Support.
142152
- `Metal.jl` for Apple Metal GPU Support.
143153
- `oneAPI.jl` for Intel oneAPI GPU Support.
154+
- `OpenCL.jl` for OpenCL support.
144155
"""
145156
loaded(x) = false
146157
loaded(::Union{CPUDevice,Type{<:CPUDevice}}) = true
147158

148159
# Order is important here
149-
const GPU_DEVICES = (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)
160+
const GPU_DEVICES = (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, OpenCLDevice)
150161

151162
const GPU_DEVICE = Ref{Union{Nothing,AbstractDevice}}(nothing)
152163

@@ -204,13 +215,13 @@ Selects GPU device based on the following criteria:
204215
the device.
205216
- `missing` (default): Device specific. For `CUDADevice` this calls `CUDA.cu(x)`,
206217
for `AMDGPUDevice` this calls `AMDGPU.roc(x)`, for `MetalDevice` this calls
207-
`Metal.mtl(x)`, for `oneAPIDevice` this calls `oneArray(x)`.
218+
`Metal.mtl(x)`, for `oneAPIDevice` this calls `oneArray(x)`, for `OpenCLDevice` this calls `CLArray(x)`.
208219
- `nothing`: Preserves the original element type.
209220
- `Type{<:AbstractFloat}`: Converts floating-point arrays to the specified type.
210221
211222
!!! warning
212223
213-
`device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal`, `oneAPI`
224+
`device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal`, `oneAPI`, `OpenCL`
214225
and `CPU` backends, `device_id` is ignored and a warning is printed.
215226
216227
!!! warning
@@ -475,6 +486,8 @@ function set_device!(::Type{T}, dev_or_id) where {T<:AbstractDevice}
475486
@warn "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting."
476487
T === oneAPIDevice &&
477488
@warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting."
489+
T === OpenCLDevice &&
490+
@warn "Support for Multi Device OpenCL hasn't been implemented yet. Ignoring the device setting."
478491
T === CPUDevice &&
479492
@warn "Setting device for `CPUDevice` doesn't make sense. Ignoring the device setting."
480493
T === ReactantDevice &&

lib/MLDataDevices/test/eltype_tests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@
123123
oneapi_old = oneAPIDevice()
124124
@test oneapi_old isa oneAPIDevice{Missing}
125125

126+
opencl_old = OpenCLDevice()
127+
@test opencl_old isa OpenCLDevice{Missing}
128+
126129
reactant_old = ReactantDevice()
127130
@test reactant_old isa ReactantDevice{Missing,Missing,Missing,Missing}
128131
end

lib/MLDataDevices/test/iterator_tests.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
using MLDataDevices, MLUtils, Test
2-
31
const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "none"))
42

3+
if BACKEND_GROUP == "opencl" || BACKEND_GROUP == "all"
4+
using OpenCL, pocl_jll
5+
end
6+
7+
using MLDataDevices, MLUtils, Test
8+
59
if BACKEND_GROUP == "cuda" || BACKEND_GROUP == "all"
610
using LuxCUDA
711
end
@@ -23,7 +27,15 @@ if BACKEND_GROUP == "oneapi" || BACKEND_GROUP == "all"
2327
using oneAPI
2428
end
2529

26-
DEVICES = [CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice]
30+
DEVICES = [
31+
CPUDevice,
32+
CUDADevice,
33+
AMDGPUDevice,
34+
MetalDevice,
35+
oneAPIDevice,
36+
OpenCLDevice,
37+
ReactantDevice,
38+
]
2739

2840
freed_if_can_be_freed(x) = freed_if_can_be_freed(get_device_type(x), x)
2941
freed_if_can_be_freed(::Type{CPUDevice}, x) = true

lib/MLDataDevices/test/misc_tests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,12 @@ end
137137
:AMDGPU,
138138
:oneAPI,
139139
:Metal,
140+
:OpenCL,
140141
AMDGPUDevice(),
141142
CUDADevice(),
142143
MetalDevice(),
143144
oneAPIDevice(),
145+
OpenCLDevice(),
144146
)
145147
backend_name = if backend isa Symbol
146148
string(backend)

0 commit comments

Comments
 (0)