@@ -21,6 +21,9 @@ MetalDevice() = MetalDevice{Missing}()
2121struct oneAPIDevice{T<: EltypeAdaptorType } <: AbstractGPUDevice end
2222oneAPIDevice() = oneAPIDevice{Missing}()
2323
24+ struct OpenCLDevice{T<: EltypeAdaptorType } <: AbstractGPUDevice end
25+ OpenCLDevice() = OpenCLDevice{Missing}()
26+
2427struct ReactantDevice{C,D,S,T<: EltypeAdaptorType ,TN} <: AbstractAcceleratorDevice
2528 client:: C
2629 device:: D
@@ -49,6 +52,7 @@ Base.eltype(::CUDADevice{D,T}) where {D,T} = T
4952Base. eltype(:: AMDGPUDevice{D,T} ) where {D,T} = T
5053Base. eltype(:: MetalDevice{T} ) where {T} = T
5154Base. eltype(:: oneAPIDevice{T} ) where {T} = T
55+ Base. eltype(:: OpenCLDevice{T} ) where {T} = T
5256Base. eltype(:: ReactantDevice{C,D,S,T} ) where {C,D,S,T} = T
5357
5458# Helper functions to create devices with specific eltypes
@@ -78,6 +82,12 @@ function with_eltype(::oneAPIDevice, ::Type{T}) where {T<:AbstractFloat}
7882 return oneAPIDevice{T}()
7983end
8084
85+ with_eltype(:: OpenCLDevice , :: Nothing ) = OpenCLDevice{Nothing}()
86+ with_eltype(:: OpenCLDevice , :: Missing ) = OpenCLDevice{Missing}()
87+ function with_eltype(:: OpenCLDevice , :: Type{T} ) where {T<: AbstractFloat }
88+ return OpenCLDevice{T}()
89+ end
90+
8191function with_eltype(dev:: ReactantDevice{C,D,S,<:Any,TN} , :: Missing ) where {C,D,S,TN}
8292 return ReactantDevice{C,D,S,Missing,TN}(dev. client, dev. device, dev. sharding)
8393end
@@ -141,12 +151,13 @@ Checks if the trigger package for the device is loaded. Trigger packages are as
141151 - `AMDGPU.jl` for AMD GPU ROCM Support.
142152 - `Metal.jl` for Apple Metal GPU Support.
143153 - `oneAPI.jl` for Intel oneAPI GPU Support.
154+ - `OpenCL.jl` for OpenCL support.
144155"""
145156loaded(x) = false
146157loaded(:: Union{CPUDevice,Type{<:CPUDevice}} ) = true
147158
148159# Order is important here
149- const GPU_DEVICES = (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)
160+ const GPU_DEVICES = (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, OpenCLDevice )
150161
151162const GPU_DEVICE = Ref{Union{Nothing,AbstractDevice}}(nothing )
152163
@@ -204,13 +215,13 @@ Selects GPU device based on the following criteria:
204215 the device.
205216 - `missing` (default): Device specific. For `CUDADevice` this calls `CUDA.cu(x)`,
206217 for `AMDGPUDevice` this calls `AMDGPU.roc(x)`, for `MetalDevice` this calls
207- `Metal.mtl(x)`, for `oneAPIDevice` this calls `oneArray(x)`.
218+ `Metal.mtl(x)`, for `oneAPIDevice` this calls `oneArray(x)`, for `OpenCLDevice` this calls `CLArray(x)` .
208219 - `nothing`: Preserves the original element type.
209220 - `Type{<:AbstractFloat}`: Converts floating-point arrays to the specified type.
210221
211222!!! warning
212223
213- `device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal`, `oneAPI`
224+ `device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal`, `oneAPI`, `OpenCL`
214225 and `CPU` backends, `device_id` is ignored and a warning is printed.
215226
216227!!! warning
@@ -475,6 +486,8 @@ function set_device!(::Type{T}, dev_or_id) where {T<:AbstractDevice}
475486 @warn " Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting."
476487 T === oneAPIDevice &&
477488 @warn " Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting."
489+ T === OpenCLDevice &&
490+ @warn " Support for Multi Device OpenCL hasn't been implemented yet. Ignoring the device setting."
478491 T === CPUDevice &&
479492 @warn " Setting device for `CPUDevice` doesn't make sense. Ignoring the device setting."
480493 T === ReactantDevice &&
0 commit comments