# WARNING: I still have doubts that this works correctly; use with caution.

# Wrapper around a cudnnTensorDescriptor_t derived for the batchnorm
# scale/bias/mean/variance tensors.
mutable struct BND
    ptr::Ptr{Cvoid}
end

Base.unsafe_convert(::Type{Cptr}, bnd::BND) = bnd.ptr

function BND(xtd::TD, mode::Cuint)
    d = Ref{Cptr}(0)
    @cuda(cudnn, cudnnCreateTensorDescriptor, (Ptr{Cptr},), d)
    @cuda(cudnn, cudnnDeriveBNTensorDescriptor,
          (Cptr, Cptr, Cuint),
          d[], xtd, mode)
    return BND(d[])
end

# Shape of the scale/bias/mean/variance parameter tensors for input x in the
# given batchnorm mode, expressed in cuDNN's NCHW terms.
function batchnorm_param_size(x::CuArray, mode::UInt32)
    xsz = size(x)
    if mode == CUDNN_BATCHNORM_PER_ACTIVATION
        psz = (1, xsz[3], xsz[2], xsz[1])  # 1xCxHxW: one parameter per activation
    elseif mode == CUDNN_BATCHNORM_SPATIAL
        psz = (1, xsz[3], 1, 1)            # 1xCx1x1: one parameter per channel
    else
        error("Mode $mode is not supported")
    end
    return psz
end
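
# Example (a sketch, assuming the input follows the (W, H, C, N) layout that
# column-major Julia arrays map onto cuDNN's NCHW tensors):
#   x = CuArray(rand(Float32, 8, 8, 3, 4))                  # W=8, H=8, C=3, N=4
#   batchnorm_param_size(x, CUDNN_BATCHNORM_SPATIAL)        # -> (1, 3, 1, 1)
#   batchnorm_param_size(x, CUDNN_BATCHNORM_PER_ACTIVATION) # -> (1, 3, 8, 8)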

# Learnable parameters and statistics for one batchnorm layer.
# resultRunningMean/resultRunningVariance accumulate statistics across
# training batches and are consumed at inference time;
# resultSaveMean/resultSaveInvVariance cache the per-batch statistics from
# the last training forward pass and are consumed by the backward pass.
mutable struct BatchNormState{T}
    bnScale::CuArray{T,4}
    bnBias::CuArray{T,4}
    resultRunningMean::CuArray{T,4}
    resultRunningVariance::CuArray{T,4}
    resultSaveMean::CuArray{T,4}
    resultSaveInvVariance::CuArray{T,4}
end

function BatchNormState(x::CuArray{T,4}, mode=CUDNN_BATCHNORM_SPATIAL) where T
    psz = batchnorm_param_size(x, mode)
    # NB: scale/bias are initialized randomly here; ones/zeros would be the
    # more conventional choice. Running statistics start at zero.
    bnScale = CuArray(randn(T, psz))
    bnBias = CuArray(randn(T, psz))
    resultRunningMean = CuArray(zeros(T, psz))
    resultRunningVariance = CuArray(zeros(T, psz))
    resultSaveMean = CuArray(zeros(T, psz))
    resultSaveInvVariance = CuArray(zeros(T, psz))
    return BatchNormState(bnScale, bnBias, resultRunningMean, resultRunningVariance,
                          resultSaveMean, resultSaveInvVariance)
end
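
# Usage sketch: one BatchNormState per batchnorm layer, shaped from a sample
# input (the shapes here are hypothetical):
#   x = CuArray(rand(Float32, 8, 8, 3, 4))
#   s = BatchNormState(x)   # spatial mode by default
#   size(s.bnScale)         # -> (1, 3, 1, 1)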

# Forward pass in training mode: normalizes x with the statistics of the
# current batch, writes the result to y, updates the running mean/variance
# in s, and caches the batch statistics for the backward pass.
function batchnorm_train!(y::CuArray{T,4}, x::CuArray{T,4}, s::BatchNormState;
                          handle=cudnnhandle(), alpha=1, beta=0,
                          exponentialAverageFactor=T(1), mode=CUDNN_BATCHNORM_SPATIAL,
                          epsilon=CUDNN_BN_MIN_EPSILON) where T
    xtd = TD(x)
    ytd = TD(y)
    bnScaleBiasMeanVarDesc = BND(xtd, mode)
    @cuda(cudnn, cudnnBatchNormalizationForwardTraining,
          (Cptr, UInt32, Cptr, Cptr, Cptr, Cptr, Cptr, Cptr,
           Cptr, Cptr, Cptr, Cdouble,
           Cptr, Cptr, Cdouble,
           Cptr, Cptr),
          handle, mode, Ref(T(alpha)), Ref(T(beta)), xtd, x, ytd, y,
          bnScaleBiasMeanVarDesc, s.bnScale, s.bnBias, exponentialAverageFactor,
          s.resultRunningMean, s.resultRunningVariance, epsilon,
          s.resultSaveMean, s.resultSaveInvVariance)
    return y
end

function batchnorm_train(x::CuArray{T,4}, s::BatchNormState; opts...) where T
y = similar(x)
batchnorm_train!(y, x, s; opts...)
return y
end
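
# Training-time usage sketch: each call normalizes with the current batch's
# statistics and folds them into the running mean/variance at the rate given
# by exponentialAverageFactor (1 means "overwrite with this batch"):
#   y  = batchnorm_train(x, s)
#   y2 = batchnorm_train(x2, s; exponentialAverageFactor=0.1)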

# Forward pass in inference mode: normalizes x with the running mean/variance
# accumulated during training, so the result does not depend on the batch.
function batchnorm_infer!(y::CuArray{T,4}, x::CuArray{T,4}, s::BatchNormState;
                          handle=cudnnhandle(), alpha=1, beta=0,
                          mode=CUDNN_BATCHNORM_SPATIAL,
                          epsilon=CUDNN_BN_MIN_EPSILON) where T
    xtd = TD(x)
    ytd = TD(y)
    bnScaleBiasMeanVarDesc = BND(xtd, mode)
    estimatedMean = s.resultRunningMean
    estimatedVariance = s.resultRunningVariance
    @cuda(cudnn, cudnnBatchNormalizationForwardInference,
          (Cptr, UInt32, Cptr, Cptr, Cptr, Cptr, Cptr, Cptr,
           Cptr, Cptr, Cptr,
           Cptr, Cptr, Cdouble),
          handle, mode, Ref(T(alpha)), Ref(T(beta)), xtd, x, ytd, y,
          bnScaleBiasMeanVarDesc, s.bnScale, s.bnBias,
          estimatedMean, estimatedVariance, epsilon)
    return y
end

function batchnorm_infer(x::CuArray{T,4}, s::BatchNormState; opts...) where T
y = similar(x)
batchnorm_infer!(y, x, s; opts...)
return y
end
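
# Inference-time usage sketch: no batch statistics are computed, so this is
# only meaningful after at least one batchnorm_train! call has populated the
# running mean/variance:
#   y = batchnorm_infer(x, s)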

# Backward pass: computes the input gradient dx and the parameter gradients
# with respect to bnScale and bnBias. cuDNN only *computes* these gradients;
# it does not update bnScale/bnBias, so the caller must apply them.
function batchnorm_grad!(dx::CuArray{T,4}, x::CuArray{T,4}, dy::CuArray{T,4}, s::BatchNormState;
                         handle=cudnnhandle(), alpha_data=1, beta_data=0,
                         alpha_param=1, beta_param=0,
                         mode=CUDNN_BATCHNORM_SPATIAL,
                         epsilon=CUDNN_BN_MIN_EPSILON) where T
    xtd = TD(x)
    dytd = TD(dy)
    dxtd = TD(dx)
    bnScaleBiasMeanVarDesc = BND(xtd, mode)
    resultBnScaleDiff = similar(s.bnScale)
    resultBnBiasDiff = similar(s.bnBias)
    # batch statistics cached by the last batchnorm_train! call
    savedMean = s.resultSaveMean
    savedInvVariance = s.resultSaveInvVariance
    @cuda(cudnn, cudnnBatchNormalizationBackward,
          (Cptr, UInt32, Cptr, Cptr, Cptr, Cptr,
           Cptr, Cptr, Cptr, Cptr, Cptr, Cptr,
           Cptr, Cptr, Cptr, Cptr,
           Cdouble, Cptr, Cptr),
          handle, mode, Ref(T(alpha_data)), Ref(T(beta_data)), Ref(T(alpha_param)), Ref(T(beta_param)),
          xtd, x, dytd, dy, dxtd, dx,
          bnScaleBiasMeanVarDesc, s.bnScale, resultBnScaleDiff, resultBnBiasDiff,
          epsilon, savedMean, savedInvVariance)
    return dx, resultBnScaleDiff, resultBnBiasDiff
end

function batchnorm_grad(x::CuArray{T,4}, dy::CuArray{T,4}, s::BatchNormState; opts...) where T
    dx = similar(x)
    return batchnorm_grad!(dx, x, dy, s; opts...)  # -> (dx, dbnScale, dbnBias)
end
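
# End-to-end sketch of one training step with a manual SGD update of the
# batchnorm parameters (as noted above, cuDNN only computes the parameter
# gradients; applying them is up to the caller). The helper name, input
# shape, and learning rate are hypothetical, and the in-place broadcast
# update assumes the CuArray type in use supports it.
function _batchnorm_example(; lr=0.01f0)
    x  = CuArray(rand(Float32, 8, 8, 3, 4))   # W=8, H=8, C=3, N=4
    s  = BatchNormState(x)
    y  = batchnorm_train(x, s)                # forward pass, updates running stats
    dy = CuArray(ones(Float32, size(y)))      # stand-in for an upstream gradient
    dx, dscale, dbias = batchnorm_grad(x, dy, s)
    s.bnScale .-= lr .* dscale                # manual parameter update
    s.bnBias  .-= lr .* dbias
    return dx
end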