Skip to content

Commit 0d0558e

Browse files
authored
refactor: use EnzymeRules.@easy_rule in Lux.jl (#1542)
* refactor: use `EnzymeRules.@easy_rule` in Lux.jl * refactor: more migration * fix: import Annotation * fix: bump min Enzyme version * fix: try some fixes * feat: use enzyme rules for safe_reshape_internal
1 parent 226beb3 commit 0d0558e

File tree

11 files changed

+65
-122
lines changed

11 files changed

+65
-122
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Lux"
22
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "1.25.0"
4+
version = "1.26.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -87,7 +87,7 @@ ConcreteStructs = "0.2.3"
8787
DiffResults = "1.1"
8888
DispatchDoctor = "0.4.26"
8989
Enzyme = "0.13.81"
90-
EnzymeCore = "0.8.14"
90+
EnzymeCore = "0.8.15"
9191
FastClosures = "0.3.2"
9292
Flux = "0.16.3"
9393
ForwardDiff = "0.10.36, =1"

ext/LuxLossFunctionsExt.jl

Lines changed: 3 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -44,46 +44,9 @@ function CRC.rrule(
4444
return LossFunctionImpl.fused_agg(sum, lfn, x, y), ∇fused_agg
4545
end
4646

47-
# COV_EXCL_START
48-
49-
function EnzymeRules.augmented_primal(
50-
cfg::EnzymeRules.RevConfigWidth{1},
51-
func::EnzymeCore.Const{typeof(LossFunctionImpl.fused_agg)},
52-
::Type{<:EnzymeCore.Active},
53-
agg_f::EnzymeCore.Const{typeof(sum)},
54-
lfn::EnzymeCore.Const{<:LossFunctions.Traits.Loss},
55-
x::EnzymeCore.Annotation{<:AbstractArray},
56-
y::EnzymeCore.Const,
57-
)
58-
primal =
59-
EnzymeRules.needs_primal(cfg) ? func.val(agg_f.val, lfn.val, x.val, y.val) : nothing
60-
61-
cache_x = EnzymeRules.overwritten(cfg)[4] ? copy(x.val) : nothing
62-
cache_y = EnzymeRules.overwritten(cfg)[5] ? copy(y.val) : nothing
63-
64-
return EnzymeRules.AugmentedReturn(primal, nothing, (cache_x, cache_y))
65-
end
66-
67-
function EnzymeRules.reverse(
68-
cfg::EnzymeRules.RevConfigWidth{1},
69-
::EnzymeCore.Const{typeof(LossFunctionImpl.fused_agg)},
70-
dret::EnzymeCore.Active,
71-
(cache_x, cache_y),
72-
agg_f::EnzymeCore.Const{typeof(sum)},
73-
lfn::EnzymeCore.Const{<:LossFunctions.Traits.Loss},
74-
x::EnzymeCore.Annotation{<:AbstractArray},
75-
y::EnzymeCore.Const,
47+
EnzymeRules.@easy_rule(
48+
LossFunctionImpl.fused_agg(fn::typeof(sum), lfn::LossFunctions.Traits.Loss, x, y),
49+
(@Constant, @Constant, LossFunctions.deriv(lfn, x, y) .* Ω, @Constant)
7650
)
77-
EnzymeRules.overwritten(cfg)[4] || (cache_x = x.val)
78-
EnzymeRules.overwritten(cfg)[5] || (cache_y = y.val)
79-
80-
if !(typeof(x) <: EnzymeCore.Const)
81-
@. x.dval = LossFunctions.deriv(lfn.val, cache_x, cache_y) * dret.val
82-
end
83-
84-
return ntuple(Returns(nothing), 4)
85-
end
86-
87-
# COV_EXCL_STOP
8851

8952
end

lib/LuxLib/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "LuxLib"
22
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
3-
version = "1.12.1"
3+
version = "1.13.0"
44
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
55

66
[deps]
@@ -74,7 +74,7 @@ CUDA = "5.8"
7474
ChainRulesCore = "1.25.1"
7575
DispatchDoctor = "0.4.12"
7676
Enzyme = "0.13.81"
77-
EnzymeCore = "0.8.14"
77+
EnzymeCore = "0.8.15"
7878
FastClosures = "0.3.2"
7979
ForwardDiff = "0.10.36, 1"
8080
Functors = "0.5"

lib/LuxLib/src/impl/Impl.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using StaticArraysCore: StaticVector, SArray
77
using Static: StaticBool, True, False, static
88

99
using ChainRulesCore: ChainRulesCore, NoTangent, HasReverseMode, RuleConfig
10-
using EnzymeCore: EnzymeCore, EnzymeRules
10+
using EnzymeCore: EnzymeCore, EnzymeRules, Annotation
1111
using ForwardDiff: ForwardDiff
1212

1313
using KernelAbstractions: KernelAbstractions, @kernel, @Const, @index

lib/LuxLib/src/impl/batched_mul.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -226,12 +226,7 @@ function batched_matmul_cpu!(
226226
batched_matmul_loopvec_impl!(z, x, y)
227227
return nothing
228228
end
229-
if Utils.within_enzyme_autodiff()
230-
# XXX: https://github.com/LuxDL/Lux.jl/issues/1024
231-
fallback_batched_matmul!(z, LoopedArrayOp(), x, y)
232-
else
233-
NNlib.batched_mul!(z, x, y)
234-
end
229+
NNlib.batched_mul!(z, x, y)
235230
return nothing
236231
end
237232

lib/LuxLib/src/impl/batchnorm.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,10 @@ function batchnorm_affine_normalize(
8787
β::Optional{<:AbstractVector},
8888
ϵ,
8989
) where {F,xT,μT,σ²T,N}
90-
x′ = reshape(x, :, size(x, N - 1), size(x, N))
91-
return reshape(
90+
x′ = Utils.safe_reshape(x, :, size(x, N - 1), size(x, N))
91+
return Utils.safe_reshape(
9292
batchnorm_affine_normalize_internal(opmode, act, x′, vec(μ), vec(σ²), γ, β, ϵ),
93-
size(x),
93+
size(x)...,
9494
)
9595
end
9696

lib/LuxLib/src/impl/groupnorm.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,15 @@ function groupnorm(
1010
act::F,
1111
ϵ,
1212
) where {F,N,xT}
13-
x′ = reshape(x, size(x)[1:(N - 2)]..., size(x, N - 1) ÷ groups, groups, size(x, N))
13+
x′ = Utils.safe_reshape(
14+
x, size(x)[1:(N - 2)]..., size(x, N - 1) ÷ groups, groups, size(x, N)
15+
)
1416
(μ, σ²), _ = compute_batch_statistics(
1517
x′, nothing, nothing, groupnorm_reduce_dims(x), False(), nothing
1618
)
17-
return reshape(groupnorm_affine_normalize(act, x′, μ, σ², γ, β, ϵ), size(x))
19+
return Utils.safe_reshape(
20+
groupnorm_affine_normalize(act, x′, μ, σ², γ, β, ϵ), size(x)...
21+
)
1822
end
1923

2024
function groupnorm_affine_normalize(
@@ -58,8 +62,8 @@ end
5862
) where {F,N,xT,μT,σ²T}
5963
reshape_calls = if γ != Nothing
6064
quote
61-
γ′ = reshape(γ, 1, size(x, N - 2), size(x, N - 1), 1)
62-
β′ = reshape(β, 1, size(x, N - 2), size(x, N - 1), 1)
65+
γ′ = Utils.safe_reshape(γ, 1, size(x, N - 2), size(x, N - 1), 1)
66+
β′ = Utils.safe_reshape(β, 1, size(x, N - 2), size(x, N - 1), 1)
6367
end
6468
else
6569
quote
@@ -69,13 +73,13 @@ end
6973
end
7074

7175
return quote
72-
x′ = reshape(x, :, size(x, N - 2), size(x, N - 1), size(x, N))
73-
μ′ = reshape(μ, 1, 1, size(x, N - 1), size(x, N))
74-
σ²′ = reshape(σ², 1, 1, size(x, N - 1), size(x, N))
76+
x′ = Utils.safe_reshape(x, :, size(x, N - 2), size(x, N - 1), size(x, N))
77+
μ′ = Utils.safe_reshape(μ, 1, 1, size(x, N - 1), size(x, N))
78+
σ²′ = Utils.safe_reshape(σ², 1, 1, size(x, N - 1), size(x, N))
7579
$(reshape_calls)
76-
return reshape(
80+
return Utils.safe_reshape(
7781
groupnorm_affine_normalize_internal(opmode, act, x′, μ′, σ²′, γ′, β′, ϵ),
78-
size(x),
82+
size(x)...,
7983
)
8084
end
8185
end

lib/LuxLib/src/impl/layernorm.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ function expand_layernorm_dims(
4141
) where {xT,γT,βT,N,M}
4242
new_γ_size = (size(γ)..., ntuple(i -> 1, N - M)...)
4343
new_β_size = (size(β)..., ntuple(i -> 1, N - M)...)
44-
return reshape(γ, new_γ_size), reshape(β, new_β_size)
44+
return Utils.safe_reshape(γ, new_γ_size...), Utils.safe_reshape(β, new_β_size...)
4545
end
4646

4747
function expand_layernorm_dims(

lib/LuxLib/src/impl/normalization.jl

Lines changed: 1 addition & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -171,62 +171,8 @@ end
171171

172172
reshape_norm_dims(::Nothing, ::Dims) = nothing
173173
function reshape_norm_dims(x::AbstractArray, dims::Dims)
174-
if Utils.within_enzyme_autodiff()
175-
y = similar(x, get_norm_reshape_dims(dims, length(x)))
176-
reshape_norm_dims!(y, x)
177-
return y
178-
else
179-
return reshape(x, get_norm_reshape_dims(dims, length(x)))
180-
end
181-
end
182-
183-
function reshape_norm_dims!(y::AbstractArray, x::AbstractArray)
184-
copyto!(vec(y), vec(x))
185-
return nothing
186-
end
187-
188-
function CRC.rrule(::typeof(reshape_norm_dims), x::AbstractArray, dims::Dims)
189-
y = reshape_norm_dims(x, dims)
190-
∇reshape_norm_dims = @closure Δ -> begin
191-
∂x = CRC.@thunk reshape(recursive_unthunk(Δ), size(x))
192-
return ∂∅, ∂x, ∂∅
193-
end
194-
return y, ∇reshape_norm_dims
195-
end
196-
197-
# COV_EXCL_START
198-
# reshape_norm_dims is a constant source of runtime activity for Enzyme. Define custom
199-
# rules to avoid this.
200-
function EnzymeRules.augmented_primal(
201-
cfg::EnzymeRules.RevConfigWidth{1},
202-
::EnzymeCore.Const{typeof(reshape_norm_dims)},
203-
::Type{EnzymeCore.Const{Nothing}},
204-
y::EnzymeCore.Annotation{<:AbstractArray},
205-
x::EnzymeCore.Annotation{<:AbstractArray},
206-
)
207-
if EnzymeRules.needs_primal(cfg)
208-
copyto!(vec(y.val), vec(x.val))
209-
end
210-
return EnzymeRules.AugmentedReturn(nothing, nothing, ())
211-
end
212-
213-
function EnzymeRules.reverse(
214-
::EnzymeRules.RevConfigWidth{1},
215-
::EnzymeCore.Const{typeof(reshape_norm_dims)},
216-
::Type{EnzymeCore.Const{Nothing}},
217-
tape,
218-
y::EnzymeCore.Annotation{<:AbstractArray},
219-
x::EnzymeCore.Annotation{<:AbstractArray},
220-
)
221-
if !(typeof(y) <: EnzymeCore.Const)
222-
if !(typeof(x) <: EnzymeCore.Const)
223-
copyto!(vec(x.dval), vec(y.dval))
224-
end
225-
fill!(y.dval, false)
226-
end
227-
return ntuple(Returns(nothing), 2)
174+
return Utils.safe_reshape(x, get_norm_reshape_dims(dims, length(x))...)
228175
end
229-
# COV_EXCL_STOP
230176

231177
@inbounds function get_norm_reshape_dims(sx::NTuple{N,<:Int}, ly::Int) where {N}
232178
if ly == sx[N - 1]

lib/LuxLib/src/utils.jl

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module Utils
22

33
using ChainRulesCore: ChainRulesCore
4-
using EnzymeCore: EnzymeCore, EnzymeRules
4+
using EnzymeCore: EnzymeCore, EnzymeRules, Annotation
55
using FastClosures: @closure
66
using Functors: Functors
77
using ForwardDiff: ForwardDiff
@@ -40,9 +40,43 @@ ofeltype_array(::Type{T}, ::Nothing) where {T} = nothing
4040
contiguous(x::AbstractArray) = x
4141
contiguous(x::SubArray) = copy(x)
4242

43-
safe_reshape(x::AbstractArray, dims...) = reshape(x, dims...)
43+
safe_reshape(x::AbstractArray, dims...) = safe_reshape_internal(x, dims)
4444
safe_reshape(::Nothing, dims...) = nothing
4545

46+
safe_reshape_internal(x::AbstractArray, dims) = reshape(x, dims)
47+
48+
# reshape sometimes causes trouble inside ReshapedArray for Enzyme. This is a workaround
49+
# for that.
50+
struct EnzymeRulesReshapeJacobian{T,N,D,A<:AbstractArray{T,N}} <: AbstractArray{T,2}
51+
x::A
52+
outsize::D
53+
end
54+
55+
Base.size(J::EnzymeRulesReshapeJacobian) = (prod(J.outsize), length(J.x))
56+
57+
function EnzymeRules.multiply_fwd_into(
58+
previous, J::EnzymeRulesReshapeJacobian, x::AbstractArray
59+
)
60+
reshaped_x = reshape(x, J.outsize...)
61+
previous === nothing && return reshaped_x
62+
@. previous += reshaped_x
63+
return previous
64+
end
65+
66+
function EnzymeRules.multiply_rev_into(
67+
previous, J::EnzymeRulesReshapeJacobian, x::AbstractArray
68+
)
69+
reshaped_x = reshape(x, size(J.x)...)
70+
previous === nothing && return reshaped_x
71+
@. previous += reshaped_x
72+
return previous
73+
end
74+
75+
EnzymeRules.@easy_rule(
76+
safe_reshape_internal(x::AbstractArray, dims),
77+
(EnzymeRulesReshapeJacobian(x, size(Ω)), @Constant),
78+
)
79+
4680
remove_tracking(x) = x
4781
remove_tracking(x::AbstractArray) = x
4882
remove_tracking(::Type{T}) where {T} = T

0 commit comments

Comments
 (0)