@@ -188,6 +188,44 @@ function define_ROCCellArray()
188
188
end
189
189
end
190
190
191
+ """
192
+ @define_MtlCellArray
193
+
194
+ Define the following type alias and constructors in the caller module:
195
+
196
+ ********************************************************************************
197
+ MtlCellArray{T<:Cell,N,B,T_elem} <: AbstractArray{T,N} where Cell <: Union{Number, SArray, FieldArray}
198
+
199
+ `N`-dimensional CellArray with cells of type `T`, blocklength `B`, and `T_array` being a `MtlArray` of element type `T_elem`: alias for `CellArray{T,N,B,MtlArray{T_elem,CellArrays._N}}`.
200
+
201
+ --------------------------------------------------------------------------------
202
+
203
+ MtlCellArray{T,B}(undef, dims)
204
+ MtlCellArray{T}(undef, dims)
205
+
206
+ Construct an uninitialized `N`-dimensional `CellArray` containing `Cells` of type `T` which are stored in an array of kind `MtlArray`.
207
+
208
+ See also: [`CellArray`](@ref), [`CPUCellArray`](@ref), [`CuCellArray`](@ref), [`ROCCellArray`](@ref)
209
+ ********************************************************************************
210
+
211
+ !!! note "Avoiding unneeded dependencies"
212
+ The type aliases and constructors for GPU `CellArray`s are provided via macros to avoid unneeded dependencies on the GPU packages in CellArrays.
213
+
214
+ See also: [`@define_CuCellArray`](@ref), [`@define_ROCCellArray`](@ref)
215
+ """
216
+ macro define_MtlCellArray () esc (define_MtlCellArray ()) end
217
+
218
+ function define_MtlCellArray ()
219
+ quote
220
+ const MtlCellArray{T,N,B,T_elem} = CellArrays. CellArray{T,N,B,Metal. MtlArray{T_elem,CellArrays. _N}}
221
+
222
+ MtlCellArray {T,B} (:: UndefInitializer , dims:: NTuple{N,Int} ) where {T<: CellArrays.Cell ,N,B} = (CellArrays. check_T (T); MtlCellArray {T,N,B,CellArrays.eltype(T)} (undef, dims))
223
+ MtlCellArray {T,B} (:: UndefInitializer , dims:: Int... ) where {T<: CellArrays.Cell ,B} = MtlCellArray {T,B} (undef, dims)
224
+ MtlCellArray {T} (:: UndefInitializer , dims:: NTuple{N,Int} ) where {T<: CellArrays.Cell ,N} = MtlCellArray {T,0} (undef, dims)
225
+ MtlCellArray {T} (:: UndefInitializer , dims:: Int... ) where {T<: CellArrays.Cell } = MtlCellArray {T} (undef, dims)
226
+ end
227
+ end
228
+
191
229
192
230
# # AbstractArray methods
193
231
0 commit comments