-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathsoftmax.jl
33 lines (28 loc) · 1.29 KB
/
softmax.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
"""
    softmax4d!(y, x; handle, algorithm, mode, alpha, beta) -> y

Run cuDNN's `cudnnSoftmaxForward` on `x`, write the result into `y`, and return `y`.

Per the cuDNN documentation the output is blended as
`y = alpha * softmax(x) + beta * y`. With the default `beta = 0.0` the previous
contents of `y` are not read.

# Keywords
- `handle`: cuDNN handle (default `cudnnhandle()`).
- `algorithm`: `CUDNN_SOFTMAX_ACCURATE` (default) or `CUDNN_SOFTMAX_FAST`.
- `mode`: `CUDNN_SOFTMAX_MODE_INSTANCE` (default) or `CUDNN_SOFTMAX_MODE_CHANNEL`.
- `alpha`, `beta`: blending scalars (see above).

# Throws
- `DimensionMismatch` if `y` and `x` do not have the same number of elements.
"""
function softmax4d!(y::CuArray{T}, x::CuArray{T};
                    handle=cudnnhandle(),
                    algorithm=CUDNN_SOFTMAX_ACCURATE, # or CUDNN_SOFTMAX_FAST
                    mode=CUDNN_SOFTMAX_MODE_INSTANCE, # or CUDNN_SOFTMAX_MODE_CHANNEL
                    alpha=1.0, beta=0.0) where T
    length(y) == length(x) || throw(DimensionMismatch(
        "softmax4d!: y has $(length(y)) elements but x has $(length(x))"))
    # cuDNN reads the alpha/beta scaling factors as `float` for HALF tensors and as
    # the tensor's own type otherwise, so Float16 data needs Float32 scalars.
    # They are passed as untyped pointers (Cptr) so the scalar type can differ from T.
    S = T === Float16 ? Float32 : T
    # NOTE(review): TD(a) is assumed to build the cuDNN tensor descriptor for `a`;
    # it is defined outside this block.
    @cuda(cudnn, cudnnSoftmaxForward,
          (Cptr, Cuint, Cuint, Cptr, Cptr, Ptr{T}, Cptr, Cptr, Ptr{T}),
          handle, algorithm, mode, Ref(S(alpha)), TD(x), x, Ref(S(beta)), TD(y), y)
    return y
end
"""
    softmax4d(x; kwargs...)

Allocating version of `softmax4d!`: return the cuDNN softmax of `x` in a new array.

The keyword arguments `handle`, `algorithm`, `mode` and `alpha` are forwarded to
`softmax4d!`. `beta` is always 0 because the output buffer comes from `similar(x)`
and is uninitialized, so there is nothing to blend with.
"""
softmax4d(x::CuArray{T}; o...) where T = softmax4d!(similar(x), x; o..., beta=0.0)
"""
    softmax4d_grad!(dx, y, dy; handle, algorithm, mode, alpha, beta) -> dx

Run cuDNN's `cudnnSoftmaxBackward` and write the result into `dx`, then return `dx`.

The inputs are the softmax output `y` and the incoming gradient `dy`. The result is
the gradient with respect to the softmax input. `y` is expected to come from a
forward softmax using the same `algorithm` and `mode`. Per the cuDNN documentation
the output is blended as `dx = alpha * grad + beta * dx`. With the default
`beta = 0.0` the previous contents of `dx` are not read.

# Keywords
- `handle`: cuDNN handle (default `cudnnhandle()`).
- `algorithm`: `CUDNN_SOFTMAX_ACCURATE` (default) or `CUDNN_SOFTMAX_FAST`.
- `mode`: `CUDNN_SOFTMAX_MODE_INSTANCE` (default) or `CUDNN_SOFTMAX_MODE_CHANNEL`.
- `alpha`, `beta`: blending scalars (see above).

# Throws
- `DimensionMismatch` if `dx`, `y` and `dy` do not all have the same number of elements.
"""
function softmax4d_grad!(dx::CuArray{T}, y::CuArray{T}, dy::CuArray{T};
                         handle=cudnnhandle(),
                         algorithm=CUDNN_SOFTMAX_ACCURATE, # or CUDNN_SOFTMAX_FAST
                         mode=CUDNN_SOFTMAX_MODE_INSTANCE, # or CUDNN_SOFTMAX_MODE_CHANNEL
                         alpha=1.0, beta=0.0) where T
    n = length(y)
    (length(dy) == n && length(dx) == n) || throw(DimensionMismatch(
        "softmax4d_grad!: dx, y and dy must have the same number of elements, got " *
        "$(length(dx)), $n and $(length(dy))"))
    # cuDNN reads the alpha/beta scaling factors as `float` for HALF tensors and as
    # the tensor's own type otherwise, so Float16 data needs Float32 scalars.
    # They are passed as untyped pointers (Cptr) so the scalar type can differ from T.
    S = T === Float16 ? Float32 : T
    # NOTE(review): TD(a) is assumed to build the cuDNN tensor descriptor for `a`;
    # it is defined outside this block.
    @cuda(cudnn, cudnnSoftmaxBackward,
          (Cptr, Cuint, Cuint,
           Cptr, Cptr, Ptr{T},    # alpha, yDesc, y
           Cptr, Ptr{T},          # dyDesc, dy
           Cptr, Cptr, Ptr{T}),   # beta, dxDesc, dx
          handle, algorithm, mode,
          Ref(S(alpha)), TD(y), y,
          TD(dy), dy,
          Ref(S(beta)), TD(dx), dx)
    return dx
end
"""
    softmax4d_grad(y, dy; kwargs...)

Allocating version of `softmax4d_grad!`: return the softmax input gradient in a new
array shaped like `dy`.

The keyword arguments `handle`, `algorithm`, `mode` and `alpha` are forwarded to
`softmax4d_grad!`. `beta` is always 0 because the output buffer comes from
`similar(dy)` and is uninitialized, so there is nothing to blend with.
"""
softmax4d_grad(y::CuArray{T}, dy::CuArray{T}; o...) where T =
    softmax4d_grad!(similar(dy), y, dy; o..., beta=0.0)