Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions lib/MLDataDevices/src/internal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,11 @@ function array_adapt(
) where {F,aType}
return f(x)
end
function array_adapt(
f::F, ::Type{aType}, ::Type{Missing}, x::AbstractArray{<:AbstractChar}
) where {F,aType}
return f(x)
end

function array_adapt(
::F, ::Type{aType}, ::Type{Nothing}, x::AbstractArray{<:AbstractFloat}
Expand All @@ -317,6 +322,11 @@ function array_adapt(
) where {F,aType}
return aType(x)
end
function array_adapt(
::F, ::Type{aType}, ::Type{Nothing}, x::AbstractArray{<:AbstractChar}
) where {F,aType}
return aType(x)
end

function array_adapt(
::F, ::Type{aType}, ::Type{T}, x::AbstractArray{<:AbstractFloat}
Expand All @@ -333,5 +343,10 @@ function array_adapt(
) where {F,aType,T}
return aType(x)
end
function array_adapt(
::F, ::Type{aType}, ::Type{T}, x::AbstractArray{<:AbstractChar}
) where {F,aType,T}
return aType(x)
end

end
19 changes: 19 additions & 0 deletions lib/MLDataDevices/test/cuda_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,25 @@ using FillArrays, Zygote # Extensions
y = device(x)
@test x === y
end

@testset "Character Arrays" begin
# Test that character arrays can be transferred to GPU
char_array = ['a', 'b', 'c']
char_array_xpu = device(char_array)

if MLDataDevices.functional(CUDADevice)
@test char_array_xpu isa CuArray{Char}
@test Array(char_array_xpu) == char_array

# Test transfer back to CPU
char_array_cpu = cpu_device()(char_array_xpu)
@test char_array_cpu isa Array{Char}
@test char_array_cpu == char_array
else
@test char_array_xpu isa Array{Char}
@test char_array_xpu == char_array
end
end
end

@testset "Functions" begin
Expand Down
15 changes: 15 additions & 0 deletions lib/MLDataDevices/test/misc_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,18 @@ end
@test get_device(x) == CPUDevice()
@test get_device_type(x) == CPUDevice
end

@testset "Character Arrays" begin
# Test CPU device with character arrays
cdev = cpu_device()
char_array = ['a', 'b', 'c', 'd']
char_array_cpu = cdev(char_array)
@test char_array_cpu isa Array{Char}
@test char_array_cpu == char_array
@test get_device(char_array_cpu) isa CPUDevice

# Test GPU device with character arrays
gdev = gpu_device()
char_array_gpu = gdev(char_array)
@test get_device(char_array_gpu) isa parameterless_type(typeof(gdev))
end
Loading