Commit a1c648a
feat: rework more layers
1 parent 055a6d9

File tree (4 files changed, +91/-99 lines):

src/extended_ops.jl
src/layers/embedding.jl
src/layers/extension.jl
src/layers/normalize.jl

src/extended_ops.jl

Lines changed: 1 addition & 1 deletion
@@ -266,7 +266,7 @@ for (op, field) in (
     :track_stats => :track_stats,
     :train_state => :train_state,
 )
-    @eval function $(Symbol(:has_, op))(l::AbstractLuxLayer)
+    @eval function $(Symbol(:has_, op))(l)
         res = known(safe_getproperty(l, Val($(Meta.quot(field)))))
         return ifelse(res === nothing, false, res)
     end
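With the `AbstractLuxLayer` constraint dropped, the generated `has_*` traits accept any object that exposes the corresponding property. A minimal self-contained sketch of the generated pattern, with toy stand-ins for `known` and `safe_getproperty` (assumptions here, not the real internals):

```julia
known(x) = x  # Static.known unwraps static Bools; identity suffices for plain values
safe_getproperty(obj, ::Val{f}) where {f} =
    hasproperty(obj, f) ? getproperty(obj, f) : nothing

# What the @eval loop expands to for op = :affine:
function has_affine(l)
    res = known(safe_getproperty(l, Val(:affine)))
    return ifelse(res === nothing, false, res)
end

struct MyNorm
    affine::Bool
end

has_affine(MyNorm(true))  # true, even though MyNorm is not an AbstractLuxLayer
has_affine((; chs=4))     # false: no `affine` property, so the trait degrades gracefully
```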

src/layers/embedding.jl

Lines changed: 24 additions & 28 deletions
@@ -57,32 +57,29 @@ end
 
 outputsize(e::Embedding, _, ::AbstractRNG) = (e.out_dims,)
 
-function (e::Embedding)(x::Union{Number,AbstractVector}, ps, st::NamedTuple)
+function apply(::Type{<:Embedding}, e, x::Union{Number,AbstractVector})
     @argcheck Utils.eltype(x) <: Integer
-    return ps.weight[:, x], st
+    return e.weight[:, x]
 end
-function (e::Embedding)(x::AbstractArray, ps, st::NamedTuple)
+function apply(T::Type{<:Embedding}, e, x::AbstractArray)
     @argcheck Utils.eltype(x) <: Integer
-    y, stₙ = e(Utils.vec(x), ps, st)
-    return reshape(y, :, size(x)...), stₙ
+    y = apply(T, e, Utils.vec(x))
+    return reshape(y, :, size(x)...)
 end
-function (e::Embedding)(x::NTuple{N,T}, ps, st::NamedTuple) where {N,T}
+function apply(::Type{<:Embedding}, e, x::T...) where {T}
     @argcheck Utils.eltype(T) <: Integer
-    return ps.weight[:, x...], st
+    return e.weight[:, x...]
 end
-function (e::Embedding)(x::NTuple{N,<:AbstractVector{T}}, ps, st::NamedTuple) where {N,T}
+
+function apply(::Type{<:Embedding}, e, x::AbstractVector{T}...) where {T}
     @argcheck Utils.eltype(T) <: Integer
     @argcheck allequal(size, x) DimensionMismatch("Input vectors must have the same shape")
-    return NNlib.gather(ps.weight, x...), st
+    return NNlib.gather(e.weight, x...)
 end
-function (e::Embedding)(x::NTuple{N,<:AbstractArray{T}}, ps, st::NamedTuple) where {N,T}
-    @argcheck Utils.eltype(T) <: Integer
+function apply(T::Type{<:Embedding}, e, x::AbstractArray...)
     @argcheck allequal(size, x) DimensionMismatch("Input arrays must have the same shape")
-    y, stₙ = e(vec.(x), ps, st)
-    return reshape(y, :, size(first(x))...), stₙ
-end
-function (e::Embedding)(::Tuple{}, _, ::NamedTuple)
-    throw(ArgumentError("Input tuple must contain at least one element"))
+    y = apply(T, e, vec.(x)...)
+    return reshape(y, :, size(first(x))...)
 end
 
 @doc doc"""
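All of these methods follow the calling convention this commit introduces: `apply` dispatches on the layer's type and reads parameters (here `e.weight`) off the layer struct itself, with no separate `ps`/`st` arguments. A hedged sketch of how that might be exercised; the standalone `apply` and the weight-on-the-struct layout are assumptions read off the diff, not a documented API:

```julia
struct Embedding{W<:AbstractMatrix}
    weight::W  # the commit moves parameters like this onto the layer itself
end

apply(::Type{<:Embedding}, e, x::AbstractVector{<:Integer}) = e.weight[:, x]

e = Embedding(randn(Float32, 8, 100))  # 8-dimensional vectors, vocabulary of 100
apply(typeof(e), e, [1, 5, 42])        # 8×3 matrix, one column per token id
```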
@@ -145,10 +142,11 @@ function initialstates(::AbstractRNG, spe::SinusoidalPositionalEmbedding{T}) whe
     return (; sigmas)
 end
 
-function (spe::SinusoidalPositionalEmbedding)(x::AbstractArray, ps, st::NamedTuple)
-    y = reshape(match_eltype(spe, ps, st, x), 1, size(x)...) .* st.sigmas
-    z = vcat(sin.(y), cos.(y)) .* spe.scale
-    return z, st
+function apply(::Type{<:SinusoidalPositionalEmbedding}, spe, x::AbstractArray)
+    # XXX: restore `match_eltype` support
+    # x′ = match_eltype(spe, x)
+    y = reshape(x, 1, size(x)...) .* spe.sigmas
+    return vcat(sin.(y), cos.(y)) .* spe.scale
 end
 
 """
@@ -220,16 +218,14 @@ function initialstates(::AbstractRNG, rope::RotaryPositionalEmbedding)
     )
 end
 
-function (rope::RotaryPositionalEmbedding)(
-    x::AbstractArray{T,4}, ps, st::NamedTuple
-) where {T}
-    y = apply_rotary_embedding(x, st.cos_cache, st.sin_cache; seq_dim=3)
-    return y, st
+function apply(::Type{<:RotaryPositionalEmbedding}, rope, x::AbstractArray{T,4}) where {T}
+    return apply_rotary_embedding(x, rope.cos_cache, rope.sin_cache; seq_dim=3)
 end
 
-function (rope::RotaryPositionalEmbedding)((x, input_pos)::Tuple, ps, st::NamedTuple)
-    y = apply_rotary_embedding(x, input_pos, st.cos_cache, st.sin_cache; seq_dim=3)
-    return y, st
+function apply(
+    ::Type{<:RotaryPositionalEmbedding}, rope, x::AbstractArray{T,4}, input_pos
+) where {T}
+    return apply_rotary_embedding(x, input_pos, rope.cos_cache, rope.sin_cache; seq_dim=3)
 end
 
 ## Functional variants since Qwen3 like models tend to share the same rotary embedding
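For reference, a hedged sketch of the rotation that the cached `cos_cache`/`sin_cache` feed into (the half-split RoPE variant; the package's actual `apply_rotary_embedding` and cache layout may differ):

```julia
# Assumed shapes: x is (head_dim, nheads, seq_len, batch); the caches broadcast
# as (head_dim ÷ 2, 1, seq_len, 1). Each (x1, x2) pair is rotated by the angle
# associated with its sequence position.
function rotary_sketch(x::AbstractArray{T,4}, cos_cache, sin_cache) where {T}
    d = size(x, 1) ÷ 2
    x1 = @view x[1:d, :, :, :]
    x2 = @view x[(d + 1):end, :, :, :]
    return vcat(x1 .* cos_cache .- x2 .* sin_cache,
                x2 .* cos_cache .+ x1 .* sin_cache)
end
```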

src/layers/extension.jl

Lines changed: 5 additions & 8 deletions
@@ -76,14 +76,11 @@ function Base.show(io::IO, ::MIME"text/plain", s::SimpleChainsLayer)
     return PrettyPrinting.print_wrapper_model(io, "SimpleChainsLayer", s.lux_layer)
 end
 
-function (sc::SimpleChainsLayer)(x, ps, st)
-    y = match_eltype(sc, ps, st, x)
-    return (
-        to_array(
-            sc.to_array,
-            apply_simple_chain(sc.layer, y, ps.params, MLDataDevices.get_device(x)),
-        ),
-        st,
+function apply(::Type{<:SimpleChainsLayer}, sc, x)
+    # XXX: restore `match_eltype` support
+    # x′ = match_eltype(sc, x)
+    return to_array(
+        sc.to_array, apply_simple_chain(sc.layer, x, sc.params, MLDataDevices.get_device(x))
     )
 end
 
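The wrapper loses its state tuple entirely: the SimpleChains parameter vector now lives on the struct (`sc.params`, previously `ps.params`). A toy sketch of the resulting shape of the code; everything below is illustrative, not the package implementation:

```julia
struct ChainWrapper{L,P}
    layer::L       # the wrapped SimpleChains model
    params::P      # its flat parameter vector, now stored on the layer
    to_array::Bool # whether outputs should be copied back to a plain Array
end

function apply(::Type{<:ChainWrapper}, sc, x)
    y = sc.layer(x, sc.params)  # SimpleChains-style call: model(input, params)
    return sc.to_array ? convert(Array, y) : y
end
```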
src/layers/normalize.jl

Lines changed: 61 additions & 62 deletions
@@ -134,41 +134,40 @@ end
 parameterlength(l::BatchNorm) = ifelse(has_affine(l), l.chs * 2, 0)
 statelength(l::BatchNorm) = ifelse(has_track_stats(l), l.chs * 2, 0) + 1
 
-function (BN::BatchNorm)(x::AbstractArray, ps, st::NamedTuple)
+function apply(::Type{<:BatchNorm}, bn, x::AbstractArray)
     CRC.ignore_derivatives() do
-        if st.training isa Val{true}
+        if bn.training isa Val{true}
             @argcheck size(x, ndims(x)) != 1 "Batch size for BatchNorm cannot be 1 during training"
         end
     end
 
-    x′ = match_eltype(BN, ps, st, x)
-    σ = NNlib.fast_act(BN.activation, x′)
+    # XXX: restore `match_eltype` support
+    # x′ = match_eltype(bn, ps, st, x)
+    x′ = x
+    σ = NNlib.fast_act(bn.activation, x′)
     y, stats = batchnorm(
         x′,
-        safe_getproperty(ps, Val(:scale)),
-        safe_getproperty(ps, Val(:bias)),
-        safe_getproperty(st, Val(:running_mean)),
-        safe_getproperty(st, Val(:running_var)),
-        st.training,
+        safe_getproperty(bn, Val(:scale)),
+        safe_getproperty(bn, Val(:bias)),
+        safe_getproperty(bn, Val(:running_mean)),
+        safe_getproperty(bn, Val(:running_var)),
+        bn.training,
         σ,
-        convert(unwrapped_eltype(x′), BN.momentum),
-        convert(unwrapped_eltype(x′), BN.epsilon),
+        convert(unwrapped_eltype(x′), bn.momentum),
+        convert(unwrapped_eltype(x′), bn.epsilon),
     )
-    return y, update_batchnorm_state(BN, st, stats)
+    update_batchnorm_state!(bn, stats)
+    return y
 end
 
-function update_batchnorm_state(BN::BatchNorm, st::NamedTuple, stats)
-    has_track_stats(BN) && return merge(
-        st,
-        (;
-            running_mean=Utils.vec(stats.running_mean),
-            running_var=Utils.vec(stats.running_var),
-        ),
-    )
-    return st
+function update_batchnorm_state!(bn, stats)
+    has_track_stats(bn) || return nothing
+    bn.running_mean = Utils.vec(stats.running_mean)
+    bn.running_var = Utils.vec(stats.running_var)
+    return nothing
 end
 
-CRC.@non_differentiable update_batchnorm_state(::Any...)
+CRC.@non_differentiable update_batchnorm_state!(::Any...)
 
 function Base.show(io::IO, l::BatchNorm)
     print(io, "BatchNorm($(l.chs)")
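The switch from `update_batchnorm_state` (which returned an updated `st` NamedTuple) to `update_batchnorm_state!` implies the layer now owns its running statistics and must be mutable. A minimal sketch of that pattern with a toy struct, not the real `BatchNorm`:

```julia
mutable struct ToyBatchNorm{V<:AbstractVector}
    running_mean::V  # reassignable fields require `mutable struct` (or Ref-like wrappers)
    running_var::V
    track_stats::Bool
end

function update_stats!(bn::ToyBatchNorm, stats)
    bn.track_stats || return nothing
    bn.running_mean = vec(stats.running_mean)
    bn.running_var = vec(stats.running_var)
    return nothing
end

bn = ToyBatchNorm(zeros(4), ones(4), true)
update_stats!(bn, (; running_mean=rand(4, 1), running_var=rand(4, 1)))
```

Since the update now mutates instead of returning state, keeping it behind `CRC.@non_differentiable` (as the diff does) stops AD from tracing the in-place writes.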
@@ -280,18 +279,20 @@ end
 
 parameterlength(l::GroupNorm) = has_affine(l) ? (l.chs * 2) : 0
 
-function (GN::GroupNorm)(x::AbstractArray, ps, st::NamedTuple)
-    x′ = match_eltype(GN, ps, st, x)
-    σ = NNlib.fast_act(GN.activation, x′)
+function apply(::Type{<:GroupNorm}, gn, x::AbstractArray)
+    # XXX: restore `match_eltype` support
+    # x′ = match_eltype(GN, ps, st, x)
+    x′ = x
+    σ = NNlib.fast_act(gn.activation, x′)
     y = groupnorm(
         x′,
-        safe_getproperty(ps, Val(:scale)),
-        safe_getproperty(ps, Val(:bias)),
-        GN.groups,
+        safe_getproperty(gn, Val(:scale)),
+        safe_getproperty(gn, Val(:bias)),
+        gn.groups,
         σ,
-        convert(unwrapped_eltype(x′), GN.epsilon),
+        convert(unwrapped_eltype(x′), gn.epsilon),
     )
-    return y, st
+    return y
 end
 
 function Base.show(io::IO, l::GroupNorm)
@@ -434,35 +435,34 @@ end
 parameterlength(l::InstanceNorm) = ifelse(has_affine(l), l.chs * 2, 0)
 statelength(l::InstanceNorm) = ifelse(has_track_stats(l), l.chs * 2, 0) + 1
 
-function (IN::InstanceNorm)(x::AbstractArray, ps, st::NamedTuple)
-    x′ = match_eltype(IN, ps, st, x)
-    σ = NNlib.fast_act(IN.activation, x′)
+function apply(::Type{<:InstanceNorm}, in, x::AbstractArray)
+    # XXX: restore `match_eltype` support
+    # x′ = match_eltype(IN, ps, st, x)
+    x′ = x
+    σ = NNlib.fast_act(in.activation, x′)
     y, stats = instancenorm(
         x′,
-        safe_getproperty(ps, Val(:scale)),
-        safe_getproperty(ps, Val(:bias)),
-        safe_getproperty(st, Val(:running_mean)),
-        safe_getproperty(st, Val(:running_var)),
-        st.training,
+        safe_getproperty(in, Val(:scale)),
+        safe_getproperty(in, Val(:bias)),
+        safe_getproperty(in, Val(:running_mean)),
+        safe_getproperty(in, Val(:running_var)),
+        in.training,
         σ,
-        convert(unwrapped_eltype(x′), IN.momentum),
-        convert(unwrapped_eltype(x′), IN.epsilon),
+        convert(unwrapped_eltype(x′), in.momentum),
+        convert(unwrapped_eltype(x′), in.epsilon),
     )
-    return y, update_instancenorm_state(IN, st, stats)
+    update_instancenorm_state!(in, stats)
+    return y
 end
 
-function update_instancenorm_state(IN::InstanceNorm, st::NamedTuple, stats)
-    has_track_stats(IN) && return merge(
-        st,
-        (;
-            running_mean=Utils.vec(stats.running_mean),
-            running_var=Utils.vec(stats.running_var),
-        ),
-    )
-    return st
+function update_instancenorm_state!(in::InstanceNorm, stats)
+    has_track_stats(in) || return nothing
+    in.running_mean = Utils.vec(stats.running_mean)
+    in.running_var = Utils.vec(stats.running_var)
+    return nothing
 end
 
-CRC.@non_differentiable update_instancenorm_state(::Any...)
+CRC.@non_differentiable update_instancenorm_state!(::Any...)
 
 function Base.show(io::IO, l::InstanceNorm)
     print(io, "InstanceNorm($(l.chs)")
@@ -554,18 +554,20 @@ function initialparameters(rng::AbstractRNG, ln::LayerNorm)
     return (;)
 end
 
-function (l::LayerNorm)(x::AbstractArray, ps, st::NamedTuple)
-    x′ = match_eltype(l, ps, st, x)
+function apply(::Type{<:LayerNorm}, l, x::AbstractArray)
+    # XXX: restore `match_eltype` support
+    # x′ = match_eltype(l, ps, st, x)
+    x′ = x
     σ = NNlib.fast_act(l.activation, x′)
     y = layernorm(
         x′,
-        safe_getproperty(ps, Val(:scale)),
-        safe_getproperty(ps, Val(:bias)),
+        safe_getproperty(l, Val(:scale)),
+        safe_getproperty(l, Val(:bias)),
         σ,
         l.dims,
         convert(unwrapped_eltype(x′), l.epsilon),
     )
-    return y, st
+    return y
 end
 
 function Base.show(io::IO, l::LayerNorm)
@@ -781,17 +783,14 @@ parameterlength(l::RMSNorm) = has_affine(l) ? prod(l.normalized_shape) : 0
 
 # specialization on `NT` is important here, else we won't be able to infer the
 # correct eltype of the output.
-function (rms::RMSNorm)(x::AbstractArray{T}, ps, st::NamedTuple) where {T}
+function apply(::Type{<:RMSNorm}, rms, x::AbstractArray{T}) where {T}
     # Don't use `match_eltype` here, since often times the eltypes are intentionally
     # different.
     ϵ = T(rms.epsilon)
     mean_sq = mean(abs2, x; dims=1:length(rms.normalized_shape))
 
     if has_affine(rms)
-        norm_x = @. (x * LuxOps.rsqrt(mean_sq + ϵ)) * ps.scale
-    else
-        norm_x = @. x * LuxOps.rsqrt(mean_sq + ϵ)
+        return @. (x * LuxOps.rsqrt(mean_sq + ϵ)) * rms.scale
     end
-
-    return norm_x, st
+    return @. x * LuxOps.rsqrt(mean_sq + ϵ)
 end
0 commit comments

Comments
 (0)