@@ -6,31 +6,38 @@ export CuDeviceArray, CuDeviceVector, CuDeviceMatrix, ldg
6
6
# # construction
7
7
8
8
"""
    CuDeviceArray{T,N,A}(ptr, dims, [maxsize])
    CuDeviceArray{T,N,A,I}(ptr, dims, [maxsize])

Construct an `N`-dimensional dense CUDA device array with element type `T` wrapping a
pointer `ptr` in address space `A`. `dims` should be a tuple of `N` integers corresponding
to the lengths in each dimension. `maxsize` is the maximum number of bytes that can be
stored in the array, and defaults to `prod(dims) * sizeof(T)`.

`I` is the integer type used to store the sizes of the array. When it is omitted, `Int32`
is selected if the byte size, the length and every dimension fit in an `Int32`, and
`Int64` is selected otherwise. Selecting `I` automatically is type unstable, so that form
should only be used from the host side; device code should use the fully parameterized
constructor.
"""
CuDeviceArray

struct CuDeviceArray{T,N,A,I} <: DenseArray{T,N}
    ptr::LLVMPtr{T,A}
    maxsize::I

    dims::NTuple{N,I}
    len::I

    # determine an index type based on the size of the array.
    # this is type unstable, so only use this constructor from the host side.
    function CuDeviceArray{T,N,A}(ptr::LLVMPtr{T,A}, dims::Tuple,
                                  maxsize::Integer=prod(dims)*sizeof(T)) where {T,A,N}
        limit = typemax(Int32)
        # the byte size alone is not enough: with zero-size element types (or an
        # explicit `maxsize`) the length or a dimension can exceed the Int32 range
        # even though `maxsize` does not, which would make the conversion throw.
        fits = maxsize <= limit && prod(dims) <= limit && all(d -> d <= limit, dims)
        if fits
            return CuDeviceArray{T,N,A,Int32}(ptr, dims, maxsize)
        else
            return CuDeviceArray{T,N,A,Int64}(ptr, dims, maxsize)
        end
    end

    # fully typed, for use in device code. every size is converted to `I`, so an
    # `InexactError` is thrown when a value cannot be represented by the requested type.
    CuDeviceArray{T,N,A,I}(ptr::LLVMPtr{T,A}, dims::Tuple,
                           maxsize::Integer=prod(dims)*sizeof(T)) where {T,A,N,I} =
        new{T,N,A,I}(ptr, convert(I, maxsize), map(I, dims), convert(I, prod(dims)))
end
35
42
36
43
# Alias for one-dimensional device arrays with any element type `T` and address space `A`;
# the trailing type parameter of `CuDeviceArray` is left unspecified.
const CuDeviceVector = CuDeviceArray{T,1 ,A} where {T,A}
@@ -224,18 +231,18 @@ Base.show(io::IO, mime::MIME"text/plain", a::CuDeviceArray) = show(io, a)
224
231
end
225
232
end
226
233
227
"""
    reinterpret(T, a::CuDeviceArray)

Return a device array over the same pointer as `a`, viewed with element type `T`. The
index type and `maxsize` of `a` are carried over unchanged. When `T` and the source
element type differ in size, only the first dimension is rescaled.

Throws whatever exception `GPUArrays._reinterpret_exception` reports for the request.
"""
function Base.reinterpret(::Type{T}, a::CuDeviceArray{S,N,A,I}) where {T,S,N,A,I}
    failure = GPUArrays._reinterpret_exception(T, a)
    if failure !== nothing
        throw(failure)
    end

    newptr = reinterpret(LLVMPtr{T,A}, a.ptr)
    indims = size(a)

    # same element size: the shape carries over as is
    if sizeof(T) == sizeof(S)
        return CuDeviceArray{T,N,A,I}(newptr, indims, a.maxsize)
    end

    # different element size: rescale the leading dimension by the size ratio
    lead = div(indims[1] * sizeof(S), sizeof(T))
    outdims = (lead, Base.tail(indims)...)
    return CuDeviceArray{T,N,A,I}(newptr, outdims, a.maxsize)
end
240
247
241
248
@@ -252,7 +259,7 @@ function Base.reshape(a::CuDeviceArray{T,M,A}, dims::NTuple{N,Int}) where {T,N,M
252
259
end
253
260
254
261
# create a derived device array (reinterpreted or reshaped) that's still a CuDeviceArray
255
# create a derived device array (reinterpreted or reshaped) that's still a CuDeviceArray.
# the address space `A` and index type `I` are taken from `a`; its pointer and `maxsize`
# are passed through unchanged, while `T` and `osize` give the new element type and shape.
# NOTE: `I` must be declared in the `where` clause so that it binds to the index type of
#       `a` rather than being looked up as a global name.
@inline function _derived_array(a::CuDeviceArray{<:Any,<:Any,A,I}, ::Type{T},
                                osize::Dims{N}) where {T, N, A, I}
    return CuDeviceArray{T,N,A,I}(a.ptr, osize, a.maxsize)
end
0 commit comments