2 changes: 2 additions & 0 deletions .github/workflows/CI_MLDataDevices.yml
@@ -26,6 +26,7 @@ jobs:
- windows-latest
group:
- cpu
- opencl
- reactant
uses: ./.github/workflows/CommonCI.yml
with:
@@ -40,6 +41,7 @@
matrix:
group:
- cpu
- opencl
- reactant
uses: ./.github/workflows/CommonCI.yml
with:
5 changes: 4 additions & 1 deletion lib/MLDataDevices/Project.toml
@@ -1,7 +1,7 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "1.15.3"
version = "1.16.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -29,6 +29,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"
OpenCL = "08131aa3-fb12-5dee-8b74-c09406e224a2"

[extensions]
MLDataDevicesAMDGPUExt = "AMDGPU"
@@ -49,6 +50,7 @@ MLDataDevicesTrackerExt = "Tracker"
MLDataDevicesZygoteExt = "Zygote"
MLDataDevicescuDNNExt = ["CUDA", "cuDNN"]
MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"]
MLDataDevicesOpenCLExt = ["GPUArrays", "OpenCL"]

[compat]
AMDGPU = "1, 2"
@@ -75,3 +77,4 @@ Zygote = "0.7"
cuDNN = "1.3"
julia = "1.10"
oneAPI = "1.5, 2"
OpenCL = "0.10.5"
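Once both weak dependencies are loaded, Julia's package extension mechanism activates `MLDataDevicesOpenCLExt` automatically. A minimal check, not part of the diff (assumes OpenCL.jl >= 0.10.5 is installed):

using GPUArrays, OpenCL  # both weakdeps present, so the extension loads
using MLDataDevices
MLDataDevices.loaded(OpenCLDevice)      # true once the extension is active
MLDataDevices.functional(OpenCLDevice)  # true if at least one OpenCL device is visible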
97 changes: 97 additions & 0 deletions lib/MLDataDevices/ext/MLDataDevicesOpenCLExt.jl
@@ -0,0 +1,97 @@
module MLDataDevicesOpenCLExt

using Adapt: Adapt
using GPUArrays: GPUArrays
using OpenCL: OpenCL, cl, CLArray
using MLDataDevices: MLDataDevices, Internal, OpenCLDevice, reset_gpu_device!

const SUPPORTS_FP64 = Dict{cl.Device,Bool}()

function __init__()
reset_gpu_device!()
for dev in vcat(cl.devices.(cl.platforms())...)
SUPPORTS_FP64[dev] = "cl_khr_fp64" in dev.extensions
end
return nothing
end

MLDataDevices.loaded(::Union{OpenCLDevice,Type{<:OpenCLDevice}}) = true
function MLDataDevices.functional(::Union{OpenCLDevice,Type{<:OpenCLDevice}})
return !isempty(cl.platforms()) && !isempty(vcat(cl.devices.(cl.platforms())...))
end
Comment on lines 19 to 21

Member: Is there no nicer way than a try/catch?

Contributor (author): This was the most straightforward way to test for functionality, but the change in the latest commit should work well enough.

# Default RNG
MLDataDevices.default_device_rng(::OpenCLDevice) = GPUArrays.default_rng(CLArray)

# Query Device from Array
Internal.get_device(::CLArray) = OpenCLDevice()

Internal.get_device_type(::CLArray) = OpenCLDevice

# unsafe_free!
function Internal.unsafe_free_internal!(::Type{OpenCLDevice}, x::AbstractArray)
if applicable(OpenCL.unsafe_free!, x)
OpenCL.unsafe_free!(x)
else
@warn "OpenCL.unsafe_free! is not defined for $(typeof(x))." maxlog = 1
end
return nothing
end

# Device Transfer
for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32))
@eval function Adapt.adapt_storage(::OpenCLDevice{Missing}, x::AbstractArray{$(T1)})
MLDataDevices.get_device_type(x) <: OpenCLDevice && return x
if !SUPPORTS_FP64[cl.device()]
@warn LazyString(
"Double type is not supported on this device. Using `", $(T2), "` instead."
)
return CLArray{$(T2)}(x)
end
return CLArray(x)
end

@eval function Adapt.adapt_storage(::OpenCLDevice{Nothing}, x::AbstractArray{$(T1)})
MLDataDevices.get_device_type(x) <: OpenCLDevice && return x
if !SUPPORTS_FP64[cl.device()] && $(T1) <: Union{Float64,ComplexF64}
throw(
ArgumentError(
"FP64 is not supported on this device and eltype=nothing was specified"
),
)
end
return CLArray(x)
end

@eval function Adapt.adapt_storage(
::OpenCLDevice{T}, x::AbstractArray{$(T1)}
) where {T<:AbstractFloat}
MLDataDevices.get_device_type(x) <: OpenCLDevice && eltype(x) == T && return x
if T === Float64 && !SUPPORTS_FP64[cl.device()]
throw(ArgumentError("FP64 is not supported on this device"))
end
return CLArray{T}(x)
end
end

opencl_array_adapt(::Type{T}, x) where {T} = Internal.array_adapt(CLArray, CLArray, T, x)

function Adapt.adapt_storage(::OpenCLDevice{Missing}, x::AbstractArray)
MLDataDevices.get_device_type(x) <: OpenCLDevice && return x
return opencl_array_adapt(Missing, x)
end

function Adapt.adapt_storage(::OpenCLDevice{Nothing}, x::AbstractArray)
MLDataDevices.get_device_type(x) <: OpenCLDevice && return x
return opencl_array_adapt(Nothing, x)
end

function Adapt.adapt_storage(::OpenCLDevice{T}, x::AbstractArray) where {T<:AbstractFloat}
MLDataDevices.get_device_type(x) <: OpenCLDevice && eltype(x) == T && return x
if T === Float64 && !SUPPORTS_FP64[cl.device()]
throw(ArgumentError("FP64 is not supported on this device"))
end
return opencl_array_adapt(T, x)
end

end
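Taken together, the adapt rules above give the following round-trip behavior. A usage sketch, not part of the diff (assumes OpenCL.jl plus a working driver such as pocl_jll, and that no higher-priority backend is functional):

using MLDataDevices, OpenCL, pocl_jll

dev = OpenCLDevice()                   # default eltype adaptor is Missing
x = rand(Float64, 4, 4)
x_cl = dev(x)                          # CLArray; warns and falls back to Float32
                                       # if the device lacks cl_khr_fp64
get_device_type(x_cl) <: OpenCLDevice  # true
cpu_device()(x_cl)                     # back to a CPU Array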
2 changes: 1 addition & 1 deletion lib/MLDataDevices/src/MLDataDevices.jl
@@ -21,7 +21,7 @@ export gpu_device, cpu_device
export xla_device, reactant_device

export CPUDevice
export CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice
export CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, OpenCLDevice
export XLADevice, ReactantDevice
export get_device, get_device_type

9 changes: 6 additions & 3 deletions lib/MLDataDevices/src/internal.jl
@@ -12,14 +12,15 @@ using ..MLDataDevices:
AMDGPUDevice,
MetalDevice,
oneAPIDevice,
OpenCLDevice,
ReactantDevice,
UnknownDevice,
supported_gpu_backends,
GPU_DEVICES,
loaded,
functional

for dev in (CPUDevice, MetalDevice, oneAPIDevice)
for dev in (CPUDevice, MetalDevice, oneAPIDevice, OpenCLDevice)
msg = "`device_id` is not applicable for `$dev`."
@eval begin
with_device(::Type{$dev}, ::Nothing) = $dev()
@@ -30,7 +31,7 @@
end
end

for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :OpenCL)
tpkg = name === :CPU ? "" : string(name)
ldev = Symbol(name, :Device)
@eval begin
@@ -47,6 +48,7 @@ for T in (
AMDGPUDevice{Nothing},
MetalDevice,
oneAPIDevice,
OpenCLDevice,
ReactantDevice,
)
@eval get_device_id(::$(T)) = nothing
@@ -116,7 +118,8 @@ function get_gpu_device(; force::Bool)
a. `CUDA.jl` and `cuDNN.jl` (or just `LuxCUDA.jl`) for NVIDIA CUDA Support.
b. `AMDGPU.jl` for AMD GPU ROCM Support.
c. `Metal.jl` for Apple Metal GPU Support. (Experimental)
d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)""" maxlog = 1
d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)
e. `OpenCL.jl` for OpenCL support. (Experimental)""" maxlog = 1
return CPUDevice
end

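Because `OpenCLDevice` joins the `device_id`-less group in the first loop above, passing an id is a no-op. A hedged sketch (assuming OpenCL is the only loaded and functional backend):

gpu_device()   # returns OpenCLDevice()
gpu_device(2)  # warns "`device_id` is not applicable for `OpenCLDevice`."
               # and returns OpenCLDevice() anyway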
19 changes: 16 additions & 3 deletions lib/MLDataDevices/src/public.jl
@@ -21,6 +21,9 @@ MetalDevice() = MetalDevice{Missing}()
struct oneAPIDevice{T<:EltypeAdaptorType} <: AbstractGPUDevice end
oneAPIDevice() = oneAPIDevice{Missing}()

struct OpenCLDevice{T<:EltypeAdaptorType} <: AbstractGPUDevice end
OpenCLDevice() = OpenCLDevice{Missing}()

struct ReactantDevice{C,D,S,T<:EltypeAdaptorType,TN} <: AbstractAcceleratorDevice
client::C
device::D
@@ -49,6 +52,7 @@ Base.eltype(::CUDADevice{D,T}) where {D,T} = T
Base.eltype(::AMDGPUDevice{D,T}) where {D,T} = T
Base.eltype(::MetalDevice{T}) where {T} = T
Base.eltype(::oneAPIDevice{T}) where {T} = T
Base.eltype(::OpenCLDevice{T}) where {T} = T
Base.eltype(::ReactantDevice{C,D,S,T}) where {C,D,S,T} = T

# Helper functions to create devices with specific eltypes
@@ -78,6 +82,12 @@ function with_eltype(::oneAPIDevice, ::Type{T}) where {T<:AbstractFloat}
return oneAPIDevice{T}()
end

with_eltype(::OpenCLDevice, ::Nothing) = OpenCLDevice{Nothing}()
with_eltype(::OpenCLDevice, ::Missing) = OpenCLDevice{Missing}()
function with_eltype(::OpenCLDevice, ::Type{T}) where {T<:AbstractFloat}
return OpenCLDevice{T}()
end

function with_eltype(dev::ReactantDevice{C,D,S,<:Any,TN}, ::Missing) where {C,D,S,TN}
return ReactantDevice{C,D,S,Missing,TN}(dev.client, dev.device, dev.sharding)
end
@@ -141,12 +151,13 @@ Checks if the trigger package for the device is loaded. Trigger packages are as
- `AMDGPU.jl` for AMD GPU ROCM Support.
- `Metal.jl` for Apple Metal GPU Support.
- `oneAPI.jl` for Intel oneAPI GPU Support.
- `OpenCL.jl` for OpenCL support.
"""
loaded(x) = false
loaded(::Union{CPUDevice,Type{<:CPUDevice}}) = true

# Order is important here
const GPU_DEVICES = (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice)
const GPU_DEVICES = (CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, OpenCLDevice)

const GPU_DEVICE = Ref{Union{Nothing,AbstractDevice}}(nothing)

@@ -204,13 +215,13 @@ Selects GPU device based on the following criteria:
the device.
- `missing` (default): Device specific. For `CUDADevice` this calls `CUDA.cu(x)`,
for `AMDGPUDevice` this calls `AMDGPU.roc(x)`, for `MetalDevice` this calls
`Metal.mtl(x)`, for `oneAPIDevice` this calls `oneArray(x)`.
`Metal.mtl(x)`, for `oneAPIDevice` this calls `oneArray(x)`, for `OpenCLDevice` this calls `CLArray(x)`.
- `nothing`: Preserves the original element type.
- `Type{<:AbstractFloat}`: Converts floating-point arrays to the specified type.

!!! warning

`device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal`, `oneAPI`
`device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal`, `oneAPI`, `OpenCL`
and `CPU` backends, `device_id` is ignored and a warning is printed.

!!! warning
@@ -475,6 +486,8 @@ function set_device!(::Type{T}, dev_or_id) where {T<:AbstractDevice}
@warn "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting."
T === oneAPIDevice &&
@warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting."
T === OpenCLDevice &&
@warn "Support for Multi Device OpenCL hasn't been implemented yet. Ignoring the device setting."
T === CPUDevice &&
@warn "Setting device for `CPUDevice` doesn't make sense. Ignoring the device setting."
T === ReactantDevice &&
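A short sketch of the three eltype-adaptor modes documented above, applied to the new device (not part of the diff; FP64-less devices follow the extension's adapt rules):

x = rand(Float64, 8)

OpenCLDevice()(x)           # Missing (default): CLArray(x), Float32 fallback with a warning
OpenCLDevice{Nothing}()(x)  # preserve eltype: CLArray{Float64}, or ArgumentError without FP64
OpenCLDevice{Float32}()(x)  # convert floats: CLArray{Float32}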
3 changes: 3 additions & 0 deletions lib/MLDataDevices/test/eltype_tests.jl
@@ -123,6 +123,9 @@
oneapi_old = oneAPIDevice()
@test oneapi_old isa oneAPIDevice{Missing}

opencl_old = OpenCLDevice()
@test opencl_old isa OpenCLDevice{Missing}

reactant_old = ReactantDevice()
@test reactant_old isa ReactantDevice{Missing,Missing,Missing,Missing}
end
18 changes: 15 additions & 3 deletions lib/MLDataDevices/test/iterator_tests.jl
@@ -1,7 +1,11 @@
using MLDataDevices, MLUtils, Test

const BACKEND_GROUP = lowercase(get(ENV, "BACKEND_GROUP", "none"))

if BACKEND_GROUP == "opencl" || BACKEND_GROUP == "all"
using OpenCL, pocl_jll
end

using MLDataDevices, MLUtils, Test

if BACKEND_GROUP == "cuda" || BACKEND_GROUP == "all"
using LuxCUDA
end
@@ -23,7 +27,15 @@ if BACKEND_GROUP == "oneapi" || BACKEND_GROUP == "all"
using oneAPI
end

DEVICES = [CPUDevice, CUDADevice, AMDGPUDevice, MetalDevice, oneAPIDevice, ReactantDevice]
DEVICES = [
CPUDevice,
CUDADevice,
AMDGPUDevice,
MetalDevice,
oneAPIDevice,
OpenCLDevice,
ReactantDevice,
]

freed_if_can_be_freed(x) = freed_if_can_be_freed(get_device_type(x), x)
freed_if_can_be_freed(::Type{CPUDevice}, x) = true
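The test suite keys off the `BACKEND_GROUP` environment variable read at the top of this file. A hedged sketch for running just the new group locally (the project path is an assumption):

ENV["BACKEND_GROUP"] = "opencl"  # inherited by the test subprocess
using Pkg
Pkg.activate("lib/MLDataDevices")
Pkg.test()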
2 changes: 2 additions & 0 deletions lib/MLDataDevices/test/misc_tests.jl
@@ -137,10 +137,12 @@ end
:AMDGPU,
:oneAPI,
:Metal,
:OpenCL,
AMDGPUDevice(),
CUDADevice(),
MetalDevice(),
oneAPIDevice(),
OpenCLDevice(),
)
backend_name = if backend isa Symbol
string(backend)