Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/MLDataDevices/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLDataDevices"
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "1.15.2"
version = "1.15.3"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
51 changes: 28 additions & 23 deletions lib/MLDataDevices/src/internal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ end
function to_rarray_internal end

# Utility function to facilitate data transfer
# For AbstractFloat and Complex{<:AbstractFloat} arrays, we provide specialized methods to avoid
# ambiguity with the general fallback and to enable efficient type conversion when needed.
function array_adapt(
f::F, ::Type{aType}, ::Type{Missing}, x::AbstractArray{<:AbstractFloat}
) where {F,aType}
Expand All @@ -296,16 +298,6 @@ function array_adapt(
) where {F,aType}
return f(x)
end
function array_adapt(
f::F, ::Type{aType}, ::Type{Missing}, x::AbstractArray{<:Number}
) where {F,aType}
return f(x)
end
function array_adapt(
f::F, ::Type{aType}, ::Type{Missing}, x::AbstractArray{<:AbstractChar}
) where {F,aType}
return f(x)
end

function array_adapt(
::F, ::Type{aType}, ::Type{Nothing}, x::AbstractArray{<:AbstractFloat}
Expand All @@ -317,17 +309,8 @@ function array_adapt(
) where {F,aType}
return aType(x)
end
function array_adapt(
::F, ::Type{aType}, ::Type{Nothing}, x::AbstractArray{<:Number}
) where {F,aType}
return aType(x)
end
function array_adapt(
::F, ::Type{aType}, ::Type{Nothing}, x::AbstractArray{<:AbstractChar}
) where {F,aType}
return aType(x)
end

# For specific type parameters, we do type conversion
function array_adapt(
::F, ::Type{aType}, ::Type{T}, x::AbstractArray{<:AbstractFloat}
) where {F,aType,T}
Expand All @@ -338,14 +321,36 @@ function array_adapt(
) where {F,aType,T}
return aType{Complex{T}}(x)
end

# Fallback for all other isbits types (e.g., Int32, Char, or custom immutable structs)
function array_adapt(
::F, ::Type{aType}, ::Type{T}, x::AbstractArray{<:Number}
f::F, ::Type{aType}, ::Type{Missing}, x::AbstractArray{T}
) where {F,aType,T}
return aType(x)
isbitstype(T) || error(
"Cannot move array with element type `$(T)` to device. Element type must be an \
`isbits` type (e.g., Number, Char, or custom immutable struct with only \
`isbits` fields)."
)
return f(x)
end
Comment on lines 326 to 335
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot Number/AbstractChar are isbitstype, we can remove the old dispatches

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the redundant dispatches for Number and AbstractChar types. They're now handled by the fallback methods since they're isbits types. Kept only the specialized methods for AbstractFloat and Complex{<:AbstractFloat} for type conversion. (commit 0ae7ae4)


function array_adapt(
::F, ::Type{aType}, ::Type{T}, x::AbstractArray{<:AbstractChar}
::F, ::Type{aType}, ::Type{Nothing}, x::AbstractArray{T}
) where {F,aType,T}
isbitstype(T) || error(
"Cannot move array with element type `$(T)` to device. Element type must be an \
`isbits` type (e.g., Number, Char, or custom immutable struct with only \
`isbits` fields)."
)
return aType(x)
end

function array_adapt(::F, ::Type{aType}, ::Type{E}, x::AbstractArray{T}) where {F,aType,E,T}
isbitstype(T) || error(
"Cannot move array with element type `$(T)` to device. Element type must be an \
`isbits` type (e.g., Number, Char, or custom immutable struct with only \
`isbits` fields)."
)
return aType(x)
end

Expand Down
23 changes: 23 additions & 0 deletions lib/MLDataDevices/test/amdgpu_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,29 @@ using FillArrays, Zygote # Extensions
y = device(x)
@test x === y
end

@testset "IsBits Types" begin
# Test that custom isbits types can be transferred to GPU
struct SimpleBitsAMD
field::Int32
end

isbits_array = [SimpleBitsAMD(1), SimpleBitsAMD(2), SimpleBitsAMD(3)]
isbits_array_xpu = device(isbits_array)

if MLDataDevices.functional(AMDGPUDevice)
@test isbits_array_xpu isa ROCArray{SimpleBitsAMD}
@test Array(isbits_array_xpu) == isbits_array

# Test transfer back to CPU
isbits_array_cpu = cpu_device()(isbits_array_xpu)
@test isbits_array_cpu isa Array{SimpleBitsAMD}
@test isbits_array_cpu == isbits_array
else
@test isbits_array_xpu isa Array{SimpleBitsAMD}
@test isbits_array_xpu == isbits_array
end
end
end

@testset "Functions" begin
Expand Down
23 changes: 23 additions & 0 deletions lib/MLDataDevices/test/cuda_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,29 @@ using FillArrays, Zygote # Extensions
@test char_array_xpu == char_array
end
end

@testset "IsBits Types" begin
# Test that custom isbits types can be transferred to GPU
struct SimpleBits
field::Int32
end

isbits_array = [SimpleBits(1), SimpleBits(2), SimpleBits(3)]
isbits_array_xpu = device(isbits_array)

if MLDataDevices.functional(CUDADevice)
@test isbits_array_xpu isa CuArray{SimpleBits}
@test Array(isbits_array_xpu) == isbits_array

# Test transfer back to CPU
isbits_array_cpu = cpu_device()(isbits_array_xpu)
@test isbits_array_cpu isa Array{SimpleBits}
@test isbits_array_cpu == isbits_array
else
@test isbits_array_xpu isa Array{SimpleBits}
@test isbits_array_xpu == isbits_array
end
end
end

@testset "Functions" begin
Expand Down
23 changes: 23 additions & 0 deletions lib/MLDataDevices/test/metal_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,29 @@ using FillArrays, Zygote # Extensions
y = device(x)
@test x === y
end

@testset "IsBits Types" begin
# Test that custom isbits types can be transferred to GPU
struct SimpleBitsMetal
field::Int32
end

isbits_array = [SimpleBitsMetal(1), SimpleBitsMetal(2), SimpleBitsMetal(3)]
isbits_array_xpu = device(isbits_array)

if MLDataDevices.functional(MetalDevice)
@test isbits_array_xpu isa MtlArray{SimpleBitsMetal}
@test Array(isbits_array_xpu) == isbits_array

# Test transfer back to CPU
isbits_array_cpu = cpu_device()(isbits_array_xpu)
@test isbits_array_cpu isa Array{SimpleBitsMetal}
@test isbits_array_cpu == isbits_array
else
@test isbits_array_xpu isa Array{SimpleBitsMetal}
@test isbits_array_xpu == isbits_array
end
end
end

@testset "Functions" begin
Expand Down
23 changes: 23 additions & 0 deletions lib/MLDataDevices/test/oneapi_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,29 @@ using FillArrays, Zygote # Extensions
y = device(x)
@test x === y
end

@testset "IsBits Types" begin
# Test that custom isbits types can be transferred to GPU
struct SimpleBitsOneAPI
field::Int32
end

isbits_array = [SimpleBitsOneAPI(1), SimpleBitsOneAPI(2), SimpleBitsOneAPI(3)]
isbits_array_xpu = device(isbits_array)

if MLDataDevices.functional(oneAPIDevice)
@test isbits_array_xpu isa oneArray{SimpleBitsOneAPI}
@test Array(isbits_array_xpu) == isbits_array

# Test transfer back to CPU
isbits_array_cpu = cpu_device()(isbits_array_xpu)
@test isbits_array_cpu isa Array{SimpleBitsOneAPI}
@test isbits_array_cpu == isbits_array
else
@test isbits_array_xpu isa Array{SimpleBitsOneAPI}
@test isbits_array_xpu == isbits_array
end
end
end

@testset "Functions" begin
Expand Down
Loading