
Commit 254bc7e

Add an index typevar to CuDeviceArray.
1 parent 594a8b6

6 files changed: +48 −39 lines

Diff for: src/device/array.jl

+31 −24

@@ -6,31 +6,38 @@ export CuDeviceArray, CuDeviceVector, CuDeviceMatrix, ldg
 ## construction
 
 """
-    CuDeviceArray{T,N,A}(ptr, dims, [maxsize])
+    CuDeviceArray{T,N,A,I}(ptr, dims, [maxsize])
 
 Construct an `N`-dimensional dense CUDA device array with element type `T` wrapping a
-pointer, where `N` is determined from the length of `dims` and `T` is determined from the
-type of `ptr`. `dims` may be a single scalar, or a tuple of integers corresponding to the
-lengths in each dimension). If the rank `N` is supplied explicitly as in `Array{T,N}(dims)`,
-then it must match the length of `dims`. The same applies to the element type `T`, which
-should match the type of the pointer `ptr`.
+pointer `ptr` in address space `A`. `dims` should be a tuple of `N` integers corresponding
+to the lengths in each dimension. `maxsize` is the maximum number of bytes that can be
+stored in the array, and is determined automatically if not specified. `I` is the integer
+type used to store the size of the array, and is determined automatically if not specified.
 """
 CuDeviceArray
 
-# NOTE: we can't support the typical `tuple or series of integer` style construction,
-#       because we're currently requiring a trailing pointer argument.
-
-struct CuDeviceArray{T,N,A} <: DenseArray{T,N}
+struct CuDeviceArray{T,N,A,I} <: DenseArray{T,N}
     ptr::LLVMPtr{T,A}
-    maxsize::Int
-
-    dims::Dims{N}
-    len::Int
+    maxsize::I
+
+    dims::NTuple{N,I}
+    len::I
+
+    # determine an index type based on the size of the array.
+    # this is type unstable, so only use this constructor from the host side.
+    function CuDeviceArray{T,N,A}(ptr::LLVMPtr{T,A}, dims::NTuple{N,<:Integer},
+                                  maxsize::Integer=prod(dims)*sizeof(T)) where {T,A,N}
+        if maxsize <= typemax(Int32)
+            CuDeviceArray{T,N,A,Int32}(ptr, dims, maxsize)
+        else
+            CuDeviceArray{T,N,A,Int64}(ptr, dims, maxsize)
+        end
+    end
 
-    # inner constructors, fully parameterized, exact types (ie. Int not <:Integer)
-    CuDeviceArray{T,N,A}(ptr::LLVMPtr{T,A}, dims::Tuple,
-                         maxsize::Int=prod(dims)*sizeof(T)) where {T,A,N} =
-        new(ptr, maxsize, dims, prod(dims))
+    # fully typed, for use in device code
+    CuDeviceArray{T,N,A,I}(ptr::LLVMPtr{T,A}, dims::NTuple{N,<:Integer},
+                           maxsize::Integer=prod(dims)*sizeof(T)) where {T,A,N,I} =
+        new{T,N,A,I}(ptr, convert(I, maxsize), map(I, dims), convert(I, prod(dims)))
 end
 
 const CuDeviceVector = CuDeviceArray{T,1,A} where {T,A}

@@ -224,18 +231,18 @@ Base.show(io::IO, mime::MIME"text/plain", a::CuDeviceArray) = show(io, a)
     end
 end
 
-function Base.reinterpret(::Type{T}, a::CuDeviceArray{S,N,A}) where {T,S,N,A}
+function Base.reinterpret(::Type{T}, a::CuDeviceArray{S,N,A,I}) where {T,S,N,A,I}
     err = _reinterpret_exception(T, a)
     err === nothing || throw(err)
 
     if sizeof(T) == sizeof(S) # fast case
-        return CuDeviceArray{T,N,A}(reinterpret(LLVMPtr{T,A}, a.ptr), size(a), a.maxsize)
+        return CuDeviceArray{T,N,A,I}(reinterpret(LLVMPtr{T,A}, a.ptr), size(a), a.maxsize)
     end
 
     isize = size(a)
     size1 = div(isize[1]*sizeof(S), sizeof(T))
     osize = tuple(size1, Base.tail(isize)...)
-    return CuDeviceArray{T,N,A}(reinterpret(LLVMPtr{T,A}, a.ptr), osize, a.maxsize)
+    return CuDeviceArray{T,N,A,I}(reinterpret(LLVMPtr{T,A}, a.ptr), osize, a.maxsize)
 end
 
 

@@ -252,7 +259,7 @@ function Base.reshape(a::CuDeviceArray{T,M}, dims::NTuple{N,Int}) where {T,N,M}
 end
 
 # create a derived device array (reinterpreted or reshaped) that's still a CuDeviceArray
-@inline function _derived_array(::Type{T}, N::Int, a::CuDeviceArray{T,M,A},
-                                osize::Dims) where {T, M, A}
-    return CuDeviceArray{T,N,A}(a.ptr, osize, a.maxsize)
+@inline function _derived_array(::Type{T}, N::Int, a::CuDeviceArray{T,M,A,I},
+                                osize::Dims) where {T, M, A, I}
+    return CuDeviceArray{T,N,A,I}(a.ptr, osize, a.maxsize)
 end
Diff for: src/device/intrinsics/memory_shared.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ generator function will be called dynamically.
1616
# NOTE: this relies on const-prop to forward the literal length to the generator.
1717
# maybe we should include the size in the type, like StaticArrays does?
1818
ptr = emit_shmem(T, Val(len))
19-
CuDeviceArray{T,N,AS.Shared}(ptr, dims)
19+
# XXX: 4GB ought to be enough shared memory for anybody
20+
CuDeviceArray{T,N,AS.Shared,Int32}(ptr, dims)
2021
end
2122
CuStaticSharedArray(::Type{T}, len::Integer) where {T} = CuStaticSharedArray(T, (len,))
2223

@@ -53,7 +54,8 @@ shared memory; in the case of a homogeneous multi-part buffer it is preferred to
5354
end
5455
end
5556
ptr = emit_shmem(T) + offset
56-
CuDeviceArray{T,N,AS.Shared}(ptr, dims)
57+
# XXX: 4GB ought to be enough shared memory for anybody
58+
CuDeviceArray{T,N,AS.Shared,Int32}(ptr, dims)
5759
end
5860
Base.@propagate_inbounds CuDynamicSharedArray(::Type{T}, len::Integer, offset) where {T} =
5961
CuDynamicSharedArray(T, (len,), offset)
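
Both shared-memory constructors hard-code `I = Int32` (hence the "4GB ought to be enough" remark): per-block shared memory on current GPUs is at most a few hundred kilobytes, so a 32-bit size field cannot overflow there. A hedged usage sketch, assuming CUDA.jl and a CUDA-capable device (the kernel below is illustrative, not taken from the repository):

    using CUDA, Test

    function shmem_kernel(out)
        # CuStaticSharedArray now returns a CuDeviceArray{Float32,1,AS.Shared,Int32};
        # kernels that use it are unaffected by the extra type parameter.
        tmp = CuStaticSharedArray(Float32, 32)
        i = threadIdx().x
        tmp[i] = 2f0 * i
        sync_threads()
        out[i] = tmp[i]
        return
    end

    out = CUDA.zeros(Float32, 32)
    @cuda threads=32 shmem_kernel(out)
    @test Array(out) == 2f0 .* (1:32)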

Diff for: src/device/random.jl

+3 −3

@@ -17,7 +17,7 @@ import RandomNumbers
        }
        attributes #0 = { alwaysinline }
        """, "entry"), LLVMPtr{UInt32, AS.Shared}, Tuple{})
-    CuDeviceArray{UInt32,1,AS.Shared}(ptr, (32,))
+    CuDeviceArray{UInt32,1,AS.Shared,Int32}(ptr, (32,))
 end
 
 # shared memory with per-warp counters, incremented when generating numbers

@@ -31,7 +31,7 @@ end
        }
        attributes #0 = { alwaysinline }
        """, "entry"), LLVMPtr{UInt32, AS.Shared}, Tuple{})
-    CuDeviceArray{UInt32,1,AS.Shared}(ptr, (32,))
+    CuDeviceArray{UInt32,1,AS.Shared,Int32}(ptr, (32,))
 end
 
 @device_override Random.make_seed() = clock(UInt32)

@@ -190,7 +190,7 @@ end
 for var in [:ki, :wi, :fi, :ke, :we, :fe]
     val = getfield(Random, var)
     gpu_var = Symbol("gpu_$var")
-    arr_typ = :(CuDeviceArray{$(eltype(val)),$(ndims(val)),AS.Constant})
+    arr_typ = :(CuDeviceArray{$(eltype(val)),$(ndims(val)),AS.Constant,Int32})
     @eval @inline @generated function $gpu_var()
         ptr = emit_constant_array($(QuoteNode(var)), $val)
         Expr(:call, $arr_typ, ptr, $(size(val)))
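
The constant-memory lookup tables used by the device RNG (`Random.ki`, `Random.wi`, etc.) are only a few kilobytes each, so hard-coding `Int32` is always safe for them. A quick host-side check of that assumption, mirroring the loop above:

    using Random

    # Each ziggurat table is tiny compared to typemax(Int32) bytes.
    for var in [:ki, :wi, :fi, :ke, :we, :fe]
        @assert sizeof(getfield(Random, var)) <= typemax(Int32)
    end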

Diff for: test/codegen.jl

+1 −1

@@ -153,7 +153,7 @@ end
         return
     end
 
-    asm = sprint(io->CUDA.code_ptx(io, kernel, NTuple{2,CuDeviceArray{Float32,1,AS.Global}}))
+    asm = sprint(io->CUDA.code_ptx(io, kernel, NTuple{2,CuDeviceArray{Float32,1,AS.Global,Int32}}))
     @test !occursin(".local", asm)
 end
 

Diff for: test/device/intrinsics/math.jl

+1 −1

@@ -143,7 +143,7 @@ using SpecialFunctions
             @inbounds b[], c[] = @fastmath sincos(a[])
             return
         end
-        asm = sprint(io->CUDA.code_ptx(io, kernel, NTuple{3,CuDeviceArray{Float32,1,AS.Global}}))
+        asm = sprint(io->CUDA.code_ptx(io, kernel, NTuple{3,CuDeviceArray{Float32,1,AS.Global,Int32}}))
         @assert contains(asm, "sin.approx.f32")
         @assert contains(asm, "cos.approx.f32")
         @assert !contains(asm, "__nv") # from libdevice

Diff for: test/device/intrinsics/wmma.jl

+8 −8

@@ -7,7 +7,7 @@ map_ptx_to_jl_frag = Dict(
     "s32" => Int32(42),
     "f16" => ntuple(i -> VecElement{Float16}(42), 2),
     "f32" => Float32(42)
-)
+)
 # Return specific matrix shape given operation configuration
 function get_array_shape(mat, mnk, layout)
     if !(mat in ["a","b","c","d"])

@@ -46,13 +46,13 @@ end
         # Type-dependent variables
         array_ty = CUDA.WMMA.map_ptx_to_jl_array[elem_type]
         expected = map_ptx_to_jl_frag[elem_type]
-
+
         # Address-space dependent variables
         do_shared_test = (addr_space == "_shared")
 
         # Get the function name
         func = Symbol("llvm_wmma_load_$(mat)_$(layout)_$(shape)$(addr_space)_stride_$(elem_type)")
-
+
         input_shape = get_array_shape(mat, mnk, layout)
         input = array_ty(42) * ones(array_ty, input_shape)
         input_dev = CuArray(input)

@@ -96,7 +96,7 @@ end
         elem_type in ops[3],
         addr_space in ["", "_global", "_shared"],
         stride in ["stride"]
-
+
         # Skip all but d matrices
         if mat != "d"
             continue

@@ -171,7 +171,7 @@ end
         # Int/subint mma functions are distinguished by the a/b element type
         mma_sym = d_ty == Int32 ? Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(ab_elem_type)") :
                                   Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(d_elem_type)_$(c_elem_type)")
-        mma_func = getfield(Main, mma_sym)
+        mma_func = getfield(Main, mma_sym)
         std_func = getfield(Main, Symbol("llvm_wmma_store_d_col_$(shape)_global_stride_$(d_elem_type)"))
 
         a_shape = get_array_shape("a", mnk, a_layout)

@@ -207,7 +207,7 @@ end
         # Alter test depending on a/b element Type
         if ab_ty == Float16
             @test new_a * new_b + c ≈ Array(d_dev) rtol=Base.rtoldefault(Float16)
-        else # Cast a and b to prevent UInt8 rollover of resultant data
+        else # Cast a and b to prevent UInt8 rollover of resultant data
             @test Int32.(new_a) * Int32.(new_b) + c == Array(d_dev)
         end
     end

@@ -322,7 +322,7 @@ end
         return
     end
 
-    ptx = sprint(io -> CUDA.code_ptx(io, kernel, (CuDeviceArray{Float32,1,CUDA.AS.Global},)))
+    ptx = sprint(io -> CUDA.code_ptx(io, kernel, (CuDeviceArray{Float32,1,CUDA.AS.Global,Int32},)))
 
     @test !occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.f32", ptx)
     @test occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.global.f32", ptx)

@@ -344,4 +344,4 @@ end
     @test !occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.f32", ptx)
     @test occursin(r"wmma.store.d.sync(.aligned)?.col.m16n16k16.shared.f32", ptx)
 end
-end
+end
