
Commit 9b1d55d

Add an index typevar to CuDeviceArray.
1 parent c97bc77

6 files changed (+40 -31 lines)

Diff for: src/device/array.jl (+30 -23)

@@ -6,31 +6,38 @@ export CuDeviceArray, CuDeviceVector, CuDeviceMatrix, ldg
 ## construction
 
 """
-    CuDeviceArray{T,N,A}(ptr, dims, [maxsize])
+    CuDeviceArray{T,N,A,I}(ptr, dims, [maxsize])
 
 Construct an `N`-dimensional dense CUDA device array with element type `T` wrapping a
-pointer, where `N` is determined from the length of `dims` and `T` is determined from the
-type of `ptr`. `dims` may be a single scalar, or a tuple of integers corresponding to the
-lengths in each dimension). If the rank `N` is supplied explicitly as in `Array{T,N}(dims)`,
-then it must match the length of `dims`. The same applies to the element type `T`, which
-should match the type of the pointer `ptr`.
+pointer `ptr` in address space `A`. `dims` should be a tuple of `N` integers corresponding
+to the lengths in each dimension. `maxsize` is the maximum number of bytes that can be
+stored in the array, and is determined automatically if not specified. `I` is the integer
+type used to store the size of the array, and is determined automatically if not specified.
"""
 CuDeviceArray
 
-# NOTE: we can't support the typical `tuple or series of integer` style construction,
-# because we're currently requiring a trailing pointer argument.
-
-struct CuDeviceArray{T,N,A} <: DenseArray{T,N}
+struct CuDeviceArray{T,N,A,I} <: DenseArray{T,N}
     ptr::LLVMPtr{T,A}
-    maxsize::Int
-
-    dims::Dims{N}
-    len::Int
+    maxsize::I
+
+    dims::NTuple{N,I}
+    len::I
+
+    # determine an index type based on the size of the array.
+    # this is type unstable, so only use this constructor from the host side.
+    function CuDeviceArray{T,N,A}(ptr::LLVMPtr{T,A}, dims::Tuple,
+                                  maxsize::Integer=prod(dims)*sizeof(T)) where {T,A,N}
+        if maxsize <= typemax(Int32)
+            CuDeviceArray{T,N,A,Int32}(ptr, dims, maxsize)
+        else
+            CuDeviceArray{T,N,A,Int64}(ptr, dims, maxsize)
+        end
+    end
 
-    # inner constructors, fully parameterized, exact types (ie. Int not <:Integer)
-    CuDeviceArray{T,N,A}(ptr::LLVMPtr{T,A}, dims::Tuple,
-                         maxsize::Int=prod(dims)*sizeof(T)) where {T,A,N} =
-        new(ptr, maxsize, dims, prod(dims))
+    # fully typed, for use in device code
+    CuDeviceArray{T,N,A,I}(ptr::LLVMPtr{T,A}, dims::Tuple,
+                           maxsize::Integer=prod(dims)*sizeof(T)) where {T,A,N,I} =
+        new{T,N,A,I}(ptr, convert(I, maxsize), map(I, dims), convert(I, prod(dims)))
 end
 
 const CuDeviceVector = CuDeviceArray{T,1,A} where {T,A}
@@ -224,18 +231,18 @@ Base.show(io::IO, mime::MIME"text/plain", a::CuDeviceArray) = show(io, a)
     end
 end
 
-function Base.reinterpret(::Type{T}, a::CuDeviceArray{S,N,A}) where {T,S,N,A}
+function Base.reinterpret(::Type{T}, a::CuDeviceArray{S,N,A,I}) where {T,S,N,A,I}
     err = GPUArrays._reinterpret_exception(T, a)
     err === nothing || throw(err)
 
     if sizeof(T) == sizeof(S) # fast case
-        return CuDeviceArray{T,N,A}(reinterpret(LLVMPtr{T,A}, a.ptr), size(a), a.maxsize)
+        return CuDeviceArray{T,N,A,I}(reinterpret(LLVMPtr{T,A}, a.ptr), size(a), a.maxsize)
     end
 
     isize = size(a)
     size1 = div(isize[1]*sizeof(S), sizeof(T))
     osize = tuple(size1, Base.tail(isize)...)
-    return CuDeviceArray{T,N,A}(reinterpret(LLVMPtr{T,A}, a.ptr), osize, a.maxsize)
+    return CuDeviceArray{T,N,A,I}(reinterpret(LLVMPtr{T,A}, a.ptr), osize, a.maxsize)
 end
 
@@ -252,7 +259,7 @@ function Base.reshape(a::CuDeviceArray{T,M,A}, dims::NTuple{N,Int}) where {T,N,M
 end
 
 # create a derived device array (reinterpreted or reshaped) that's still a CuDeviceArray
-@inline function _derived_array(a::CuDeviceArray{<:Any,<:Any,A}, ::Type{T},
-                                osize::Dims{N}) where {T, N, A}
-    return CuDeviceArray{T,N,A}(a.ptr, osize, a.maxsize)
+@inline function _derived_array(a::CuDeviceArray{<:Any,<:Any,A,I}, ::Type{T},
+                                osize::Dims{N}) where {T, N, A, I}
+    return CuDeviceArray{T,N,A,I}(a.ptr, osize, a.maxsize)
 end
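
The noteworthy design point in this file: the three-parameter constructor inspects `maxsize` at run time to pick the narrowest index type, which makes it type-unstable; as its comment says, it is only meant to be called from the host, while device code uses the fully-parameterized four-parameter constructor. A minimal sketch of the selection rule, runnable in plain Julia (the `pick_index_type` helper is hypothetical, written to mirror the host-side constructor above):

# hypothetical helper mirroring the host-side constructor: pick the narrowest
# integer type that can represent the array's size in bytes
pick_index_type(maxsize::Integer) = maxsize <= typemax(Int32) ? Int32 : Int64

pick_index_type(1024 * sizeof(Float32))   # Int32: 4 KiB fits comfortably
pick_index_type(2^30 * sizeof(Float64))   # Int64: 8 GiB exceeds typemax(Int32)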

Diff for: src/device/intrinsics/memory_shared.jl (+4 -2)

@@ -16,7 +16,8 @@ generator function will be called dynamically.
     # NOTE: this relies on const-prop to forward the literal length to the generator.
     # maybe we should include the size in the type, like StaticArrays does?
     ptr = emit_shmem(T, Val(len))
-    CuDeviceArray{T,N,AS.Shared}(ptr, dims)
+    # XXX: 4GB ought to be enough shared memory for anybody
+    CuDeviceArray{T,N,AS.Shared,Int32}(ptr, dims)
 end
 CuStaticSharedArray(::Type{T}, len::Integer) where {T} = CuStaticSharedArray(T, (len,))
 

@@ -53,7 +54,8 @@ shared memory; in the case of a homogeneous multi-part buffer it is preferred to
         end
     end
     ptr = emit_shmem(T) + offset
-    CuDeviceArray{T,N,AS.Shared}(ptr, dims)
+    # XXX: 4GB ought to be enough shared memory for anybody
+    CuDeviceArray{T,N,AS.Shared,Int32}(ptr, dims)
 end
 Base.@propagate_inbounds CuDynamicSharedArray(::Type{T}, len::Integer, offset) where {T} =
     CuDynamicSharedArray(T, (len,), offset)
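
Both shared-memory constructors now hardcode `Int32` as the index type, which is safe because on-chip shared memory is nowhere near the 4 GB an `Int32` byte count can address (hence the tongue-in-cheek XXX comment). A short usage sketch under that assumption; the kernel and its name are illustrative, not part of this commit:

using CUDA

function reverse_kernel(a)
    # statically-sized shared memory: the resulting CuDeviceArray now carries Int32 sizes
    shmem = CuStaticSharedArray(Float32, 1024)
    i = threadIdx().x
    @inbounds shmem[i] = a[i]
    sync_threads()
    @inbounds a[i] = shmem[1025 - i]
    return
end

a = CUDA.rand(Float32, 1024)
@cuda threads=1024 reverse_kernel(a)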

Diff for: src/device/random.jl (+3 -3)

@@ -22,7 +22,7 @@ import RandomNumbers
        }
        attributes #0 = { alwaysinline }
        """, "entry"), LLVMPtr{UInt32, AS.Shared}, Tuple{})
-    CuDeviceArray{UInt32,1,AS.Shared}(ptr, (32,))
+    CuDeviceArray{UInt32,1,AS.Shared,Int32}(ptr, (32,))
 end
 
 # array with per-warp counters, incremented when generating numbers

@@ -36,7 +36,7 @@ end
        }
        attributes #0 = { alwaysinline }
        """, "entry"), LLVMPtr{UInt32, AS.Shared}, Tuple{})
-    CuDeviceArray{UInt32,1,AS.Shared}(ptr, (32,))
+    CuDeviceArray{UInt32,1,AS.Shared,Int32}(ptr, (32,))
 end
 
 # initialization function, called automatically at the start of each kernel because

@@ -192,7 +192,7 @@ end
 for var in [:ki, :wi, :fi, :ke, :we, :fe]
     val = getfield(Random, var)
     gpu_var = Symbol("gpu_$var")
-    arr_typ = :(CuDeviceArray{$(eltype(val)),$(ndims(val)),AS.Constant})
+    arr_typ = :(CuDeviceArray{$(eltype(val)),$(ndims(val)),AS.Constant,Int32})
     @eval @inline @generated function $gpu_var()
         ptr = emit_constant_array($(QuoteNode(var)), $val)
         Expr(:call, $arr_typ, ptr, $(size(val)))
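
The loop above metaprograms one accessor per Ziggurat table, splicing each table's element type, rank, and size into a `@generated` function. Roughly what a single generated definition boils down to, assuming `Random.ki` is a 256-element `UInt64` table (the exact element types and sizes come from `Random`, so treat these as illustrative):

@inline @generated function gpu_ki()
    # runs at compile time; the constant-memory pointer is spliced into the returned expression
    ptr = emit_constant_array(:ki, Random.ki)
    :(CuDeviceArray{UInt64,1,AS.Constant,Int32}($ptr, (256,)))
end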

Diff for: test/core/codegen.jl (+1 -1)

@@ -153,7 +153,7 @@ end
         return
     end
 
-    asm = sprint(io->CUDA.code_ptx(io, kernel, NTuple{2,CuDeviceArray{Float32,1,AS.Global}}))
+    asm = sprint(io->CUDA.code_ptx(io, kernel, NTuple{2,CuDeviceArray{Float32,1,AS.Global,Int32}}))
     @test !occursin(".local", asm)
 end
 
Diff for: test/core/device/intrinsics/math.jl (+1 -1)

@@ -143,7 +143,7 @@ using SpecialFunctions
         @inbounds b[], c[] = @fastmath sincos(a[])
         return
     end
-    asm = sprint(io->CUDA.code_ptx(io, kernel, NTuple{3,CuDeviceArray{Float32,1,AS.Global}}))
+    asm = sprint(io->CUDA.code_ptx(io, kernel, NTuple{3,CuDeviceArray{Float32,1,AS.Global,Int32}}))
     @assert contains(asm, "sin.approx.f32")
     @assert contains(asm, "cos.approx.f32")
     @assert !contains(asm, "__nv") # from libdevice

Diff for: test/core/device/intrinsics/wmma.jl (+1 -1)

@@ -344,7 +344,7 @@ end
         return
     end
 
-    ptx = sprint(io -> CUDA.code_ptx(io, kernel, (CuDeviceArray{Float32,1,CUDA.AS.Global},)))
+    ptx = sprint(io -> CUDA.code_ptx(io, kernel, (CuDeviceArray{Float32,1,CUDA.AS.Global,Int32},)))
 
     @test !occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.f32", ptx)
     @test occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.global.f32", ptx)
