@@ -27,7 +27,7 @@ CRC.@non_differentiable get_batchnorm_statistics(::Any...)
2727function batchnorm(x:: AbstractArray{xT, N} , γ:: Optional{<:AbstractVector} ,
2828 β:: Optional{<:AbstractVector} , rμ:: Optional{<:AbstractVector} ,
2929 rσ²:: Optional{<:AbstractVector} , training:: StaticBool , act:: F ,
30- momentum:: Real , ϵ:: Real ) where {F, xT, N}
30+ momentum, ϵ) where {F, xT, N}
3131 (μ, σ²), (rμ, rσ²) = compute_batch_statistics(
3232 x, reshape_norm_dims(x, rμ), reshape_norm_dims(x, rσ²),
3333 batchnorm_reduce_dims(x), training, momentum)
"""
    batchnorm_affine_normalize(act, x, μ, σ², γ, β, ϵ)

Pick an internal operation mode from the array arguments via
`internal_operation_mode((x, μ, σ², γ, β))` and forward every argument, unchanged, to
the `batchnorm_affine_normalize` method that takes that mode as its first argument.

`x`, `μ` and `σ²` must share the same number of dimensions `N`. `γ` and `β` are
`Optional` vectors (presumably `nothing` or an `AbstractVector`; confirm against the
definition of `Optional`). `ϵ` carries no type annotation and is passed through as-is.
"""
function batchnorm_affine_normalize(
        act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N},
        σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector},
        β::Optional{<:AbstractVector}, ϵ) where {F, xT, μT, σ²T, N}
    # The mode is derived from all array-like inputs together, not from `x` alone.
    opmode = internal_operation_mode((x, μ, σ², γ, β))
    return batchnorm_affine_normalize(opmode, act, x, μ, σ², γ, β, ϵ)
end
4444
"""
    batchnorm_affine_normalize(::GenericBroadcastOp, act, x, μ, σ², γ, β, ϵ)

Method for the `GenericBroadcastOp` mode: run `γ` and `β` through
`reshape_norm_dims(x, ·)` and hand the result, together with `act`, `x`, `μ`, `σ²`
and `ϵ`, to `affine_normalize`.

`x`, `μ` and `σ²` must share the same number of dimensions `N`. `ϵ` carries no type
annotation and is passed through as-is.
"""
function batchnorm_affine_normalize(
        ::GenericBroadcastOp, act::F, x::AbstractArray{xT, N}, μ::AbstractArray{μT, N},
        σ²::AbstractArray{σ²T, N}, γ::Optional{<:AbstractVector},
        β::Optional{<:AbstractVector}, ϵ) where {F, xT, μT, σ²T, N}
    # Only the affine parameters are reshaped here; μ and σ² already have rank N.
    γ′ = reshape_norm_dims(x, γ)
    β′ = reshape_norm_dims(x, β)
    return affine_normalize(act, x, μ, σ², γ′, β′, ϵ)
end
@@ -54,7 +54,7 @@ function batchnorm_affine_normalize(
5454 opmode:: AbstractInternalArrayOpMode , act:: F , x:: AbstractArray{xT, N} ,
5555 μ:: AbstractArray{μT, N} , σ²:: AbstractArray{σ²T, N} ,
5656 γ:: Optional{<:AbstractVector} , β:: Optional{<:AbstractVector} ,
57- ϵ:: Real ) where {F, xT, μT, σ²T, N}
57+ ϵ) where {F, xT, μT, σ²T, N}
5858 x′ = reshape(x, :, size(x, N - 1 ), size(x, N))
5959 return reshape(
6060 batchnorm_affine_normalize_internal(opmode, act, x′, vec(μ), vec(σ²), γ, β, ϵ),
6464@stable default_mode= " disable" function batchnorm_affine_normalize_internal(
6565 opmode:: AbstractInternalArrayOpMode , act:: F , x:: AbstractArray{xT, 3} ,
6666 μ:: AbstractVector , σ²:: AbstractVector , γ:: Optional{<:AbstractVector} ,
67- β:: Optional{<:AbstractVector} , ϵ:: Real ) where {F, xT}
67+ β:: Optional{<:AbstractVector} , ϵ) where {F, xT}
6868 y = similar(x,
6969 promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²),
7070 safe_eltype(γ), safe_eltype(β)))
7575function batchnorm_affine_normalize_internal!(
7676 y:: AbstractArray{yT, 3} , opmode:: LoopedArrayOp , act:: F , x:: AbstractArray{xT, 3} ,
7777 μ:: AbstractVector , σ²:: AbstractVector , γ:: Optional{<:AbstractVector} ,
78- β:: Optional{<:AbstractVector} , ϵ:: Real ,
78+ β:: Optional{<:AbstractVector} , ϵ,
7979 γ′:: Optional{<:AbstractVector} = nothing ) where {F, xT, yT}
8080 N = size(y, 2 )
8181 γ′ = γ′ === nothing ?
225225function batchnorm_affine_normalize_internal!(
226226 y:: AbstractArray{yT, 3} , :: GPUBroadcastOp , act:: F , x:: AbstractArray{xT, 3} ,
227227 μ:: AbstractVector , σ²:: AbstractVector , γ:: Optional{<:AbstractVector} ,
228- β:: Optional{<:AbstractVector} , ϵ:: Real ,
228+ β:: Optional{<:AbstractVector} , ϵ,
229229 γ′:: Optional{<:AbstractVector} = nothing ) where {F, xT, yT}
230230 backend = KA. get_backend(y)
231231 run_ka_kernel(
@@ -278,7 +278,7 @@ function CRC.rrule(
278278 cfg:: RuleConfig{>:HasReverseMode} , :: typeof (batchnorm_affine_normalize_internal),
279279 opmode:: AbstractInternalArrayOpMode , act:: F , x:: AbstractArray{T, N} ,
280280 μ:: AbstractVector , σ²:: AbstractVector , γ:: Optional{<:AbstractVector} ,
281- β:: Optional{<:AbstractVector} , ϵ:: Real ) where {F, T, N}
281+ β:: Optional{<:AbstractVector} , ϵ) where {F, T, N}
282282 y = similar(x,
283283 promote_type(safe_eltype(x), safe_eltype(μ), safe_eltype(σ²),
284284 safe_eltype(γ), safe_eltype(β)))
304304
305305function ∇batchnorm_affine_normalize(opmode:: LoopedArrayOp , ∂y:: AbstractArray{∂yT, 3} ,
306306 x:: AbstractArray{xT, 3} , μ:: AbstractVector , σ²:: AbstractVector ,
307- γ:: Optional{<:AbstractVector} , β:: Optional{<:AbstractVector} , ϵ:: Real ,
307+ γ:: Optional{<:AbstractVector} , β:: Optional{<:AbstractVector} , ϵ,
308308 γ′:: AbstractVector ) where {∂yT, xT}
309309 ∂x, ∂μ, ∂σ² = similar(x), similar(μ), similar(σ²)
310310 ∂γ = γ === nothing ? nothing : similar(γ)
@@ -322,7 +322,7 @@ function ∇batchnorm_affine_normalize_cpu!(
322322 ∂x:: AbstractArray{∂xT, 3} , ∂μ:: AbstractVector{∂μT} ,
323323 ∂σ²:: AbstractVector{∂σ²T} , :: Nothing , :: Nothing , ∂y:: AbstractArray{∂yT, 3} ,
324324 x:: AbstractArray{xT, 3} , μ:: AbstractVector , σ²:: AbstractVector , :: Nothing ,
325- ϵ:: Real , γ′:: AbstractVector ) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT}
325+ ϵ, γ′:: AbstractVector ) where {∂xT, ∂μT, ∂σ²T, ∂yT, xT}
326326 half = eltype(∂σ²)(0.5 )
327327
328328 fill!(∂μ, 0 )
@@ -361,7 +361,7 @@ function ∇batchnorm_affine_normalize_cpu!(
361361 ∂x:: AbstractArray{∂xT, 3} , ∂μ:: AbstractVector{∂μT} ,
362362 ∂σ²:: AbstractVector{∂σ²T} , ∂γ:: AbstractVector{∂γT} ,
363363 ∂β:: AbstractVector{∂βT} , ∂y:: AbstractArray{∂yT, 3} , x:: AbstractArray{xT, 3} ,
364- μ:: AbstractVector , σ²:: AbstractVector , γ:: AbstractVector , ϵ:: Real ,
364+ μ:: AbstractVector , σ²:: AbstractVector , γ:: AbstractVector , ϵ,
365365 γ′:: AbstractVector ) where {∂xT, ∂μT, ∂σ²T, ∂γT, ∂βT, ∂yT, xT}
366366 half = eltype(∂σ²)(0.5 )
367367
406406function ∇batchnorm_affine_normalize(
407407 opmode:: AbstractInternalArrayOpMode , ∂y:: AbstractArray{∂yT, 3} ,
408408 x:: AbstractArray{xT, 3} , μ:: AbstractVector , σ²:: AbstractVector ,
409- γ:: Optional{<:AbstractVector} , β:: Optional{<:AbstractVector} , ϵ:: Real ,
409+ γ:: Optional{<:AbstractVector} , β:: Optional{<:AbstractVector} , ϵ,
410410 γ′:: AbstractVector ) where {∂yT, xT}
411411 ∂x, ∂σ² = similar(x), similar(σ², size(x))
412412 ∂γ = γ === nothing ? nothing : similar(γ, size(x))
@@ -425,7 +425,7 @@ function ∇batchnorm_affine_normalize!(
425425 ∂x:: AbstractArray{∂xT, 3} , ∂σ²:: AbstractArray{∂σ²T, 3} ,
426426 ∂γ:: Optional{<:AbstractArray{<:Any, 3}} , :: GPUBroadcastOp ,
427427 ∂y:: AbstractArray{∂yT, 3} , x:: AbstractArray{xT, 3} , μ:: AbstractVector ,
428- σ²:: AbstractVector , γ:: Optional{<:AbstractVector} , ϵ:: Real ,
428+ σ²:: AbstractVector , γ:: Optional{<:AbstractVector} , ϵ,
429429 γ′:: AbstractVector ) where {∂xT, ∂σ²T, ∂yT, xT}
430430 backend = KA. get_backend(∂x)
431431 run_ka_kernel(
0 commit comments