Skip to content

Commit 35498ed

Browse files
committed
Add MtlCellArrays
1 parent 13ec0e3 commit 35498ed

File tree

4 files changed

+172
-99
lines changed

4 files changed

+172
-99
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,18 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1010
[weakdeps]
1111
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
1212
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
13+
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
1314

1415
[compat]
1516
Adapt = "3, 4"
1617
AMDGPU = "0.3.7, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1"
1718
CUDA = "3.12, 4, 5"
19+
Metal = "1"
1820
julia = "1.9" # Minimum required Julia version (supporting extensions and weak dependencies)
1921
StaticArrays = "1"
2022

2123
[extras]
2224
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2325

2426
[targets]
25-
test = ["Test", "AMDGPU", "CUDA"]
27+
test = ["Test", "AMDGPU", "CUDA", "Metal"]

src/CellArray.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,44 @@ function define_ROCCellArray()
188188
end
189189
end
190190

191+
"""
192+
@define_MtlCellArray
193+
194+
Define the following type alias and constructors in the caller module:
195+
196+
********************************************************************************
197+
MtlCellArray{T<:Cell,N,B,T_elem} <: AbstractArray{T,N} where Cell <: Union{Number, SArray, FieldArray}
198+
199+
`N`-dimensional CellArray with cells of type `T`, blocklength `B`, and `T_array` being a `MtlArray` of element type `T_elem`: alias for `CellArray{T,N,B,MtlArray{T_elem,CellArrays._N}}`.
200+
201+
--------------------------------------------------------------------------------
202+
203+
MtlCellArray{T,B}(undef, dims)
204+
MtlCellArray{T}(undef, dims)
205+
206+
Construct an uninitialized `N`-dimensional `CellArray` containing `Cells` of type `T` which are stored in an array of kind `MtlArray`.
207+
208+
See also: [`CellArray`](@ref), [`CPUCellArray`](@ref), [`CuCellArray`](@ref), [`ROCCellArray`](@ref)
209+
********************************************************************************
210+
211+
!!! note "Avoiding unneeded dependencies"
212+
The type aliases and constructors for GPU `CellArray`s are provided via macros to avoid unneeded dependencies on the GPU packages in CellArrays.
213+
214+
See also: [`@define_CuCellArray`](@ref), [`@define_ROCCellArray`](@ref)
215+
"""
216+
macro define_MtlCellArray() esc(define_MtlCellArray()) end
217+
218+
function define_MtlCellArray()
219+
quote
220+
const MtlCellArray{T,N,B,T_elem} = CellArrays.CellArray{T,N,B,Metal.MtlArray{T_elem,CellArrays._N}}
221+
222+
MtlCellArray{T,B}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N,B} = (CellArrays.check_T(T); MtlCellArray{T,N,B,CellArrays.eltype(T)}(undef, dims))
223+
MtlCellArray{T,B}(::UndefInitializer, dims::Int...) where {T<:CellArrays.Cell,B} = MtlCellArray{T,B}(undef, dims)
224+
MtlCellArray{T}(::UndefInitializer, dims::NTuple{N,Int}) where {T<:CellArrays.Cell,N} = MtlCellArray{T,0}(undef, dims)
225+
MtlCellArray{T}(::UndefInitializer, dims::Int...) where {T<:CellArrays.Cell} = MtlCellArray{T}(undef, dims)
226+
end
227+
end
228+
191229

192230
## AbstractArray methods
193231

src/CellArrays.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ https://github.com/omlins/CellArray.jl
1111
- [`CPUCellArray`](@ref)
1212
- `CuCellArray` (available via [`@define_CuCellArray`](@ref))
1313
- `ROCCellArray` (available via [`@define_ROCCellArray`](@ref))
14+
- `MtlCellArray` (available via [`@define_MtlCellArray`](@ref))
1415
1516
# Functions (additional to standard AbstractArray functionality)
1617
- [`cellsize`](@ref)
@@ -31,5 +32,5 @@ using .Exceptions
3132
include("CellArray.jl")
3233

3334
## Exports (need to be after include of submodules if re-exports from them)
34-
export CellArray, CPUCellArray, @define_CuCellArray, @define_ROCCellArray, cellsize, blocklength, field
35+
export CellArray, CPUCellArray, @define_CuCellArray, @define_ROCCellArray, @define_MtlCellArray, cellsize, blocklength, field
3536
end

0 commit comments

Comments
 (0)