@@ -134,41 +134,40 @@ end
 parameterlength(l::BatchNorm) = ifelse(has_affine(l), l.chs * 2, 0)
 statelength(l::BatchNorm) = ifelse(has_track_stats(l), l.chs * 2, 0) + 1

-function (BN::BatchNorm)(x::AbstractArray, ps, st::NamedTuple)
+function apply(::Type{<:BatchNorm}, bn, x::AbstractArray)
     CRC.ignore_derivatives() do
-        if st.training isa Val{true}
+        if bn.training isa Val{true}
             @argcheck size(x, ndims(x)) != 1 "Batch size for BatchNorm cannot be 1 during training"
         end
     end

-    x′ = match_eltype(BN, ps, st, x)
-    σ = NNlib.fast_act(BN.activation, x′)
+    # XXX: restore `match_eltype` support
+    # x′ = match_eltype(bn, ps, st, x)
+    x′ = x
+    σ = NNlib.fast_act(bn.activation, x′)
     y, stats = batchnorm(
         x′,
-        safe_getproperty(ps, Val(:scale)),
-        safe_getproperty(ps, Val(:bias)),
-        safe_getproperty(st, Val(:running_mean)),
-        safe_getproperty(st, Val(:running_var)),
-        st.training,
+        safe_getproperty(bn, Val(:scale)),
+        safe_getproperty(bn, Val(:bias)),
+        safe_getproperty(bn, Val(:running_mean)),
+        safe_getproperty(bn, Val(:running_var)),
+        bn.training,
         σ,
-        convert(unwrapped_eltype(x′), BN.momentum),
-        convert(unwrapped_eltype(x′), BN.epsilon),
+        convert(unwrapped_eltype(x′), bn.momentum),
+        convert(unwrapped_eltype(x′), bn.epsilon),
     )
-    return y, update_batchnorm_state(BN, st, stats)
+    update_batchnorm_state!(bn, stats)
+    return y
 end

-function update_batchnorm_state(BN::BatchNorm, st::NamedTuple, stats)
-    has_track_stats(BN) && return merge(
-        st,
-        (;
-            running_mean=Utils.vec(stats.running_mean),
-            running_var=Utils.vec(stats.running_var),
-        ),
-    )
-    return st
+function update_batchnorm_state!(bn, stats)
+    has_track_stats(bn) || return nothing
+    bn.running_mean = Utils.vec(stats.running_mean)
+    bn.running_var = Utils.vec(stats.running_var)
+    return nothing
 end

-CRC.@non_differentiable update_batchnorm_state(::Any...)
+CRC.@non_differentiable update_batchnorm_state!(::Any...)

 function Base.show(io::IO, l::BatchNorm)
     print(io, "BatchNorm($(l.chs)")
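
For orientation, the old entry point `(bn::BatchNorm)(x, ps, st)` returned `(y, st′)` and threaded updated running statistics back to the caller through a state NamedTuple, while the new `apply` path mutates the layer's own running statistics via `update_batchnorm_state!` and returns only `y`. The sketch below is not Lux code; it is a self-contained toy illustrating that change of contract, with the hypothetical `ToyNorm`, `toy_apply_old`, and `toy_apply!` standing in for the real `BatchNorm` machinery.

```julia
using Statistics

# Hypothetical stand-in for a mutable layer struct; not the actual Lux `BatchNorm`.
mutable struct ToyNorm
    running_mean::Vector{Float64}
    running_var::Vector{Float64}
    momentum::Float64
    track_stats::Bool
end

toy_has_track_stats(l::ToyNorm) = l.track_stats

# Mirrors the shape of `update_batchnorm_state!`: skip unless stats are tracked,
# otherwise overwrite the running statistics in place and return `nothing`.
function toy_update_state!(l::ToyNorm, stats)
    toy_has_track_stats(l) || return nothing
    l.running_mean = copy(stats.running_mean)
    l.running_var = copy(stats.running_var)
    return nothing
end

# Old contract: pure call, the updated state is returned to the caller.
function toy_apply_old(l::ToyNorm, st::NamedTuple, x::AbstractMatrix)
    μ, σ² = vec(mean(x; dims=2)), vec(var(x; dims=2))
    y = (x .- μ) ./ sqrt.(σ² .+ 1e-5)
    st′ = (;
        running_mean=(1 - l.momentum) .* st.running_mean .+ l.momentum .* μ,
        running_var=(1 - l.momentum) .* st.running_var .+ l.momentum .* σ²,
    )
    return y, st′
end

# New contract: the layer owns its state and is mutated in place.
function toy_apply!(l::ToyNorm, x::AbstractMatrix)
    μ, σ² = vec(mean(x; dims=2)), vec(var(x; dims=2))
    y = (x .- μ) ./ sqrt.(σ² .+ 1e-5)
    toy_update_state!(l, (;
        running_mean=(1 - l.momentum) .* l.running_mean .+ l.momentum .* μ,
        running_var=(1 - l.momentum) .* l.running_var .+ l.momentum .* σ²,
    ))
    return y
end

l = ToyNorm(zeros(3), ones(3), 0.1, true)
x = randn(3, 8)
y = toy_apply!(l, x)   # l.running_mean and l.running_var are now updated in place
```

The trade-off is that callers no longer need to carry a state object around, but the state updaters now rely on mutation, which is why they end in `!`, return `nothing`, and are marked `@non_differentiable`.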
@@ -280,18 +279,20 @@ end

 parameterlength(l::GroupNorm) = has_affine(l) ? (l.chs * 2) : 0

-function (GN::GroupNorm)(x::AbstractArray, ps, st::NamedTuple)
-    x′ = match_eltype(GN, ps, st, x)
-    σ = NNlib.fast_act(GN.activation, x′)
+function apply(::Type{<:GroupNorm}, gn, x::AbstractArray)
+    # XXX: restore `match_eltype` support
+    # x′ = match_eltype(GN, ps, st, x)
+    x′ = x
+    σ = NNlib.fast_act(gn.activation, x′)
     y = groupnorm(
         x′,
-        safe_getproperty(ps, Val(:scale)),
-        safe_getproperty(ps, Val(:bias)),
-        GN.groups,
+        safe_getproperty(gn, Val(:scale)),
+        safe_getproperty(gn, Val(:bias)),
+        gn.groups,
         σ,
-        convert(unwrapped_eltype(x′), GN.epsilon),
+        convert(unwrapped_eltype(x′), gn.epsilon),
     )
-    return y, st
+    return y
 end

 function Base.show(io::IO, l::GroupNorm)
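
For readers who have not seen the underlying kernel, the `groupnorm` call splits the channel dimension into `gn.groups` groups and normalizes each group with its own mean and variance. The following is a rough standalone sketch of that arithmetic, assuming a `(W, H, C, N)` layout and ignoring the activation and affine parameters handled by the real kernel.

```julia
using Statistics

# Rough sketch of group normalization for a (W, H, C, N) array; `groups` must divide C.
function toy_groupnorm(x::AbstractArray{T,4}, groups::Int; epsilon=1f-5) where {T}
    W, H, C, N = size(x)
    @assert C % groups == 0 "channels must divide evenly into groups"
    xg = reshape(x, W, H, C ÷ groups, groups, N)
    μ = mean(xg; dims=(1, 2, 3))                     # one mean per (group, sample)
    σ² = var(xg; dims=(1, 2, 3), corrected=false)    # matching biased variance
    ϵ = T(epsilon)
    y = @. (xg - μ) / sqrt(σ² + ϵ)
    return reshape(y, W, H, C, N)
end

x = randn(Float32, 8, 8, 6, 2)
y = toy_groupnorm(x, 3)   # 3 groups of 2 channels each
```

Setting `groups == C` recovers instance normalization and `groups == 1` recovers layer normalization over the spatial and channel dimensions, which is why GroupNorm needs no running statistics here.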
@@ -434,35 +435,34 @@ end
 parameterlength(l::InstanceNorm) = ifelse(has_affine(l), l.chs * 2, 0)
 statelength(l::InstanceNorm) = ifelse(has_track_stats(l), l.chs * 2, 0) + 1

-function (IN::InstanceNorm)(x::AbstractArray, ps, st::NamedTuple)
-    x′ = match_eltype(IN, ps, st, x)
-    σ = NNlib.fast_act(IN.activation, x′)
+function apply(::Type{<:InstanceNorm}, in, x::AbstractArray)
+    # XXX: restore `match_eltype` support
+    # x′ = match_eltype(IN, ps, st, x)
+    x′ = x
+    σ = NNlib.fast_act(in.activation, x′)
     y, stats = instancenorm(
         x′,
-        safe_getproperty(ps, Val(:scale)),
-        safe_getproperty(ps, Val(:bias)),
-        safe_getproperty(st, Val(:running_mean)),
-        safe_getproperty(st, Val(:running_var)),
-        st.training,
+        safe_getproperty(in, Val(:scale)),
+        safe_getproperty(in, Val(:bias)),
+        safe_getproperty(in, Val(:running_mean)),
+        safe_getproperty(in, Val(:running_var)),
+        in.training,
         σ,
-        convert(unwrapped_eltype(x′), IN.momentum),
-        convert(unwrapped_eltype(x′), IN.epsilon),
+        convert(unwrapped_eltype(x′), in.momentum),
+        convert(unwrapped_eltype(x′), in.epsilon),
     )
-    return y, update_instancenorm_state(IN, st, stats)
+    update_instancenorm_state!(in, stats)
+    return y
 end

-function update_instancenorm_state(IN::InstanceNorm, st::NamedTuple, stats)
-    has_track_stats(IN) && return merge(
-        st,
-        (;
-            running_mean=Utils.vec(stats.running_mean),
-            running_var=Utils.vec(stats.running_var),
-        ),
-    )
-    return st
+function update_instancenorm_state!(in::InstanceNorm, stats)
+    has_track_stats(in) || return nothing
+    in.running_mean = Utils.vec(stats.running_mean)
+    in.running_var = Utils.vec(stats.running_var)
+    return nothing
 end

-CRC.@non_differentiable update_instancenorm_state(::Any...)
+CRC.@non_differentiable update_instancenorm_state!(::Any...)

 function Base.show(io::IO, l::InstanceNorm)
     print(io, "InstanceNorm($(l.chs)")
@@ -554,18 +554,20 @@ function initialparameters(rng::AbstractRNG, ln::LayerNorm)
     return (;)
 end

-function (l::LayerNorm)(x::AbstractArray, ps, st::NamedTuple)
-    x′ = match_eltype(l, ps, st, x)
+function apply(::Type{<:LayerNorm}, l, x::AbstractArray)
+    # XXX: restore `match_eltype` support
+    # x′ = match_eltype(l, ps, st, x)
+    x′ = x
     σ = NNlib.fast_act(l.activation, x′)
     y = layernorm(
         x′,
-        safe_getproperty(ps, Val(:scale)),
-        safe_getproperty(ps, Val(:bias)),
+        safe_getproperty(l, Val(:scale)),
+        safe_getproperty(l, Val(:bias)),
         σ,
         l.dims,
         convert(unwrapped_eltype(x′), l.epsilon),
     )
-    return y, st
+    return y
 end

 function Base.show(io::IO, l::LayerNorm)
@@ -781,17 +783,14 @@ parameterlength(l::RMSNorm) = has_affine(l) ? prod(l.normalized_shape) : 0

 # specialization on `NT` is important here, else we won't be able to infer the
 # correct eltype of the output.
-function (rms::RMSNorm)(x::AbstractArray{T}, ps, st::NamedTuple) where {T}
+function apply(::Type{<:RMSNorm}, rms, x::AbstractArray{T}) where {T}
     # Don't use `match_eltype` here, since often times the eltypes are intentionally
     # different.
     ϵ = T(rms.epsilon)
     mean_sq = mean(abs2, x; dims=1:length(rms.normalized_shape))

     if has_affine(rms)
-        norm_x = @. (x * LuxOps.rsqrt(mean_sq + ϵ)) * ps.scale
-    else
-        norm_x = @. x * LuxOps.rsqrt(mean_sq + ϵ)
+        return @. (x * LuxOps.rsqrt(mean_sq + ϵ)) * rms.scale
     end
-
-    return norm_x, st
+    return @. x * LuxOps.rsqrt(mean_sq + ϵ)
 end
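
Both branches above compute the same core quantity, `x .* rsqrt(mean(abs2, x) + ϵ)`, with the affine scale applied only when `has_affine(rms)` holds. Below is a minimal standalone sketch of that arithmetic, using a plain `toy_rsqrt` and a `scale === nothing` convention as stand-ins for `LuxOps.rsqrt` and the layer's affine field.

```julia
using Statistics

# Plain-Julia stand-in for `LuxOps.rsqrt`.
toy_rsqrt(x) = inv(sqrt(x))

# RMS-normalize over the first `nd` dimensions; `scale === nothing` plays the role
# of `has_affine(rms) == false`.
function toy_rmsnorm(x::AbstractArray{T}, scale, nd::Int; epsilon=1f-5) where {T}
    ϵ = T(epsilon)
    mean_sq = mean(abs2, x; dims=1:nd)
    norm_x = @. x * toy_rsqrt(mean_sq + ϵ)
    return scale === nothing ? norm_x : norm_x .* scale
end

x = randn(Float32, 4, 2)
toy_rmsnorm(x, nothing, 1)              # unscaled
toy_rmsnorm(x, ones(Float32, 4), 1)     # with an affine scale along the first dim
```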