Skip to content

Commit 1fd86b7

Browse files
committed
Optimize intrinsics to avoid exceptions.
1 parent b4c57dc commit 1fd86b7

File tree

3 files changed

+18
-17
lines changed

3 files changed

+18
-17
lines changed

src/device/intrinsics/math.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -150,25 +150,25 @@ end
150150

151151

152152
@device_function clz(x::Union{Int32, UInt32}) =
153-
assume(within(UInt32(0), UInt32(32)),
154-
ccall("extern __nv_clz", llvmcall, Int32, (UInt32,), x))
153+
assume(within(0u32, 32u32),
154+
ccall("extern __nv_clz", llvmcall, UInt32, (UInt32,), x))
155155
@device_function clz(x::Union{Int64, UInt64}) =
156-
assume(within(UInt64(0), UInt64(64)),
157-
ccall("extern __nv_clzll", llvmcall, Int32, (UInt64,), x))
156+
assume(within(0u32, 64u32),
157+
ccall("extern __nv_clzll", llvmcall, UInt32, (UInt64,), x))
158158

159159
@device_function ffs(x::Union{Int32, UInt32}) =
160-
assume(within(UInt32(0), UInt32(32)),
161-
ccall("extern __nv_ffs", llvmcall, Int32, (UInt32,), x))
160+
assume(within(0u32, 32u32),
161+
ccall("extern __nv_ffs", llvmcall, UInt32, (UInt32,), x))
162162
@device_function ffs(x::Union{Int64, UInt64}) =
163-
assume(within(UInt64(0), UInt64(64)),
164-
ccall("extern __nv_ffsll", llvmcall, Int32, (UInt64,), x))
163+
assume(within(0u32, 64u32),
164+
ccall("extern __nv_ffsll", llvmcall, UInt32, (UInt64,), x))
165165

166166
@device_function popc(x::Union{Int32, UInt32}) =
167-
assume(within(UInt32(0), UInt32(32)),
168-
ccall("extern __nv_popc", llvmcall, Int32, (UInt32,), x))
167+
assume(within(0u32, 32u32),
168+
ccall("extern __nv_popc", llvmcall, UInt32, (UInt32,), x))
169169
@device_function popc(x::Union{Int64, UInt64}) =
170-
assume(within(UInt64(0), UInt64(64)),
171-
ccall("extern __nv_popcll", llvmcall, Int32, (UInt64,), x))
170+
assume(within(0u32, 64u32),
171+
ccall("extern __nv_popcll", llvmcall, UInt32, (UInt64,), x))
172172

173173
@device_function byte_perm(x::Union{Int32, UInt32}, y::Union{Int32, UInt32}, z::Union{Int32, UInt32}) =
174174
ccall("extern __nv_byte_perm", llvmcall, Int32, (UInt32, UInt32, UInt32), x, y, z)

src/device/intrinsics/warp_shuffle.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# TODO: does not work on sub-word (ie. Int16) or non-word divisible sized types
66

77
# TODO: these functions should dispatch based on the actual warp size
8-
const ws = Int32(32)
8+
const ws = 32u32
99

1010

1111
# core intrinsics
@@ -18,7 +18,7 @@ const ws = Int32(32)
1818
for (name, mode, mask, offset) in (("_up", :up, UInt32(0x00), src->src),
1919
("_down", :down, UInt32(0x1f), src->src),
2020
("_xor", :bfly, UInt32(0x1f), src->src),
21-
("", :idx, UInt32(0x1f), src->:($src-1)))
21+
("", :idx, UInt32(0x1f), src->:($src-(1u32))))
2222
fname = Symbol("shfl$(name)_sync")
2323
@eval export $fname
2424

@@ -28,8 +28,8 @@ for (name, mode, mask, offset) in (("_up", :up, UInt32(0x00), src->src),
2828
@eval begin
2929
@inline $fname(mask, val::$T, src, width=$ws) =
3030
ccall($intrinsic, llvmcall, $T,
31-
(UInt32, $T, UInt32, UInt32),
32-
mask, val, $(offset(:src)), pack(width, $mask))
31+
(UInt32, $T, UInt32, UInt32),
32+
mask, val, $(offset(:src)), pack(width, $mask))
3333
end
3434
end
3535
end

src/device/utils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
# helper type for writing Int32 literals
44
# TODO: upstream this
55
struct Literal{T} end
6-
Base.:(*)(x, ::Type{Literal{T}}) where {T} = T(x)
6+
Base.:(*)(x, ::Type{Literal{T}}) where {T} = x%T
77
const i32 = Literal{Int32}
8+
const u32 = Literal{UInt32}
89

910
# local method table for device functions
1011
@static if isdefined(Base.Experimental, Symbol("@overlay"))

0 commit comments

Comments
 (0)