-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathinit.jl
59 lines (45 loc) · 1.48 KB
/
init.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# Opaque C pointer type used for cuDNN handles passed through `ccall`.
const Cptr = Ptr{Void}
# cuDNN shared library as exposed by CuArrays; the `@cuda`/`@cuda1` macros
# compare it against `nothing` to detect a missing library.
const libcudnn = CuArrays.libcudnn
"""
    @cuda(lib, fun, argtypes, args...)

`ccall` the C function `fun` in `libcudnn` with a `UInt32` return type and
throw an error (message `"\$lib.\$fun error "` followed by the status) when
the returned status is nonzero; otherwise expand `@gs` (defined elsewhere)
after the call. `lib` is only used in messages — the call always targets
`libcudnn`.

If `libcudnn` is `nothing` when the macro is expanded, the expansion is a
call that throws an error asking the user to install the library and rerun
`Pkg.build("Knet")`.
"""
macro cuda(lib,fun,x...) # give an error if library missing, or if error code!=0
    # `!==` is the non-overloadable identity test appropriate for `nothing`.
    if libcudnn !== nothing
        fx = Expr(:call, :ccall, ("$fun",libcudnn), :UInt32, x...)
        msg = "$lib.$fun error "
        err = gensym()  # fresh name so the status variable cannot clash with caller locals
        esc(:(if ($err=$fx) != 0; error($msg, $err); end; @gs))
    else
        Expr(:call,:error,"Cannot find lib$lib, please install it and rerun Pkg.build(\"Knet\").")
    end
end
"""
    @cuda1(lib, fun, argtypes, args...)

`ccall` the C function `fun` in `libcudnn` with a `UInt32` return type,
expand `@gs` (defined elsewhere) and evaluate to the raw status code without
checking it. If `libcudnn` is `nothing` when the macro is expanded, the
expansion is the literal `-1` instead. `lib` is not used.
"""
macro cuda1(lib,fun,x...) # return -1 if library missing, error code if run
    # `!==` is the non-overloadable identity test appropriate for `nothing`.
    if libcudnn !== nothing
        fx = Expr(:call, :ccall, ("$fun",libcudnn), :UInt32, x...)
        err = gensym()  # fresh name for the captured status code
        esc(:($err=$fx; @gs; $err))
    else
        -1
    end
end
# Cached result of `cudnnGetVersion`; -1 means "not queried yet".
const CUDNN_VERSION = Ref{Int}(-1)

"""
    cudnn_version()

Return the cuDNN version number reported by `cudnnGetVersion` as an `Int`.
The library is queried only on the first call; later calls return the value
cached in `CUDNN_VERSION`.
"""
function cudnn_version()
    cached = CUDNN_VERSION[]
    cached != -1 && return cached
    ver = Int(ccall((:cudnnGetVersion,libcudnn),Csize_t,()))
    CUDNN_VERSION[] = ver
    return ver
end
# Starts empty; `cudnnhandle` lazily stores the handle it creates in slot 1.
const CUDNN_HANDLES = Ptr{Void}[]
"""
    cudnn_create_handle()

Create a new cuDNN handle with `cudnnCreate` and return it. An `atexit`
hook is registered that releases the handle with `cudnnDestroy` when Julia
exits. A nonzero status from `cudnnCreate` raises an error via `@cuda`
before any hook is registered.
"""
function cudnn_create_handle()
    handle_ref = Ref{Cptr}(C_NULL)  # out-parameter filled in by cudnnCreate
    @cuda(cudnn, cudnnCreate, (Ptr{Cptr},), handle_ref)
    h = handle_ref[]
    atexit() do
        @cuda(cudnn,cudnnDestroy,(Cptr,), h)
    end
    return h
end
# TODO: handle multiple GPUs
"""
    cudnnhandle()

Return the shared cuDNN handle stored in `CUDNN_HANDLES[1]`. It is created
with `cudnn_create_handle` on the first call. Only one handle is kept,
irrespective of device (see TODO above).
"""
function cudnnhandle()
    isempty(CUDNN_HANDLES) && push!(CUDNN_HANDLES, cudnn_create_handle())
    return first(CUDNN_HANDLES)
end