Skip to content

Commit fa2b856

Browse files
committed
feat: emit batch_norm ops from stablehlo
1 parent d51fed9 commit fa2b856

File tree

3 files changed

+97
-8
lines changed

3 files changed

+97
-8
lines changed

ext/LuxReactantExt/LuxReactantExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Utils.contiguous(x::AnyTracedRArray) = Reactant.TracedUtils.materialize_traced_a
1717

1818
# Element type of a traced array *type*: the first type parameter `T`.
Utils.eltype(::Type{<:TracedRArray{T, N}}) where {T, N} = T
1919
# Element type of a traced number *type*: its type parameter `T`.
Utils.eltype(::Type{<:TracedRNumber{T}}) where {T} = T
20+
# Element type of a traced array *value*: delegates to `Reactant.unwrapped_eltype`.
Utils.eltype(x::Reactant.AnyTracedRArray) = Reactant.unwrapped_eltype(x)
2021

2122
function Utils.promote_to(::Type{T}, x::Number) where {T <: Number}
2223
x isa Reactant.TracedType && return x

lib/LuxLib/Project.toml

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
1919
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
2020
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
2121
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
22-
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
2322
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
23+
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
2424
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2525
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2626
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
@@ -32,23 +32,29 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
3232
AppleAccelerate = "13e28ba4-7ad8-5781-acae-3021b1ed3924"
3333
BLISBLAS = "6f275bd8-fec0-4d39-945b-7e95a765fa1e"
3434
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
35-
MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2"
3635
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
3736
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
37+
MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2"
3838
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
39+
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
3940
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
4041
SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
4142
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
4243
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
4344

45+
[sources]
46+
LuxCore = {path = "../LuxCore"}
47+
MLDataDevices = {path = "../MLDataDevices"}
48+
4449
[extensions]
4550
LuxLibAppleAccelerateExt = "AppleAccelerate"
4651
LuxLibBLISBLASExt = "BLISBLAS"
4752
LuxLibCUDAExt = "CUDA"
48-
LuxLibMKLExt = "MKL"
4953
LuxLibEnzymeExt = "Enzyme"
5054
LuxLibLoopVectorizationExt = "LoopVectorization"
55+
LuxLibMKLExt = "MKL"
5156
LuxLibOctavianExt = ["Octavian", "LoopVectorization"]
57+
LuxLibReactantExt = "Reactant"
5258
LuxLibReverseDiffExt = "ReverseDiff"
5359
LuxLibSLEEFPiratesExt = "SLEEFPirates"
5460
LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"]
@@ -79,9 +85,10 @@ MLDataDevices = "1.6"
7985
Markdown = "1.10"
8086
NNlib = "0.9.26"
8187
Octavian = "0.3.28"
82-
Preferences = "1.4.3"
8388
Polyester = "0.7.15"
89+
Preferences = "1.4.3"
8490
Random = "1.10"
91+
Reactant = "0.2.13"
8592
Reexport = "1"
8693
ReverseDiff = "1.15"
8794
SLEEFPirates = "0.6.43"
@@ -91,7 +98,3 @@ Statistics = "1.10"
9198
Tracker = "0.2.36"
9299
cuDNN = "1.3"
93100
julia = "1.10"
94-
95-
[sources]
96-
LuxCore = { path = "../LuxCore" }
97-
MLDataDevices = { path = "../MLDataDevices" }
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
module LuxLibReactantExt
2+
3+
using Reactant: Reactant, MLIR, Ops, TracedUtils, TracedRArray, AnyTracedRArray,
4+
AnyTracedRVector, TracedRNumber
5+
using Static: StaticBool, True, False
6+
7+
using LuxLib: LuxLib, Impl, Optional, Utils
8+
9+
# Most of the NN code gen happens in Reactant.jl via an extension on NNlib, however,
10+
# NNlib doesn't have certain ops implemented. In those cases we can emit more optimized
11+
# StableHLO
12+
13+
# Reactant-specific method of `Impl.batchnorm` that emits the StableHLO
# `batch_norm_training` / `batch_norm_inference` ops directly instead of going through
# the generic implementation.
#
# Arguments
#   - `x`: traced input array; the feature axis is taken to be `ndims(x) - 1`.
#   - `γ`, `β`: optional scale / bias vectors. When `nothing`, constants of ones / zeros
#     of length `size(x, ndims(x) - 1)` are used.
#   - `rμ`, `rσ²`: optional running mean / variance vectors. Either both are `nothing` or
#     both are provided; a mixed combination trips the `@assert`.
#   - `training`: `True` selects the training op, anything else the inference op.
#   - `act`: activation, broadcast over the normalized output.
#   - `momentum`: forwarded to `Impl.update_running_statistics` (training only).
#   - `ϵ`: numerical-stability constant, passed to StableHLO as a `Float32` attribute.
#
# Returns `(y, rμ, rσ²)`:
#   - training with running stats: `y` plus the updated running statistics;
#   - training without running stats: `(y, nothing, nothing)`;
#   - inference: `y` plus the running statistics exactly as they were passed in.
function Impl.batchnorm(
        x::AnyTracedRArray{T},
        γ::Optional{<:AnyTracedRVector}, β::Optional{<:AnyTracedRVector},
        rμ::Optional{<:AnyTracedRVector}, rσ²::Optional{<:AnyTracedRVector},
        training::StaticBool, act::F, momentum, ϵ
) where {T, F}
    x = TracedUtils.materialize_traced_array(x)

    # Length of the feature axis (second-to-last dimension of `x`).
    nfeatures = size(x, ndims(x) - 1)

    γ = if γ === nothing
        Ops.constant(fill(T(1), nfeatures))
    else
        TracedUtils.materialize_traced_array(γ)
    end
    β = if β === nothing
        Ops.constant(fill(T(0), nfeatures))
    else
        TracedUtils.materialize_traced_array(β)
    end

    # Attributes shared by both StableHLO ops. StableHLO dimension indices are
    # zero-based, so the Julia feature axis `ndims(x) - 1` becomes `ndims(x) - 2`.
    epsilon = Float32(ϵ)
    feature_index = Int64(ndims(x) - 2)

    if training isa True
        op = MLIR.Dialects.stablehlo.batch_norm_training(
            TracedUtils.get_mlir_data(x),
            TracedUtils.get_mlir_data(γ),
            TracedUtils.get_mlir_data(β);
            epsilon=epsilon,
            feature_index=feature_index
        )

        # Results of the op: 1 = normalized output, 2 = batch mean, 3 = batch variance.
        res = act.(TracedRArray{T, ndims(x)}((), MLIR.IR.result(op, 1), size(x)))
        μ = TracedRArray{T, 1}((), MLIR.IR.result(op, 2), nfeatures)
        σ² = TracedRArray{T, 1}((), MLIR.IR.result(op, 3), nfeatures)

        if rμ === nothing && rσ² === nothing
            return res, nothing, nothing
        else
            @assert rμ !== nothing && rσ² !== nothing
            m = T(Impl.accum_size(x, Impl.batchnorm_reduce_dims(x)))
            # NOTE(review): the `m / (m - 1)` factor presumably rescales the batch
            # variance to its unbiased estimate before accumulation — confirm against
            # `Impl.update_running_statistics`.
            rμ, rσ² = Impl.update_running_statistics(
                rμ, rσ², μ, σ², momentum, momentum * m / (m - one(m))
            )
            return res, rμ, rσ²
        end
    else
        if rμ === nothing && rσ² === nothing
            # No running statistics available: normalize with the statistics of `x`.
            μ, σ² = Impl.mean_var(
                x; dims=Utils.unsafe_known(Impl.batchnorm_reduce_dims(x)), corrected=false
            )
            μ = TracedUtils.materialize_traced_array(vec(μ))
            σ² = TracedUtils.materialize_traced_array(vec(σ²))
        else
            @assert rμ !== nothing && rσ² !== nothing
            μ = TracedUtils.materialize_traced_array(rμ)
            σ² = TracedUtils.materialize_traced_array(rσ²)
        end

        res = MLIR.IR.result(
            MLIR.Dialects.stablehlo.batch_norm_inference(
                TracedUtils.get_mlir_data(x),
                TracedUtils.get_mlir_data(γ),
                TracedUtils.get_mlir_data(β),
                TracedUtils.get_mlir_data(μ),
                TracedUtils.get_mlir_data(σ²);
                epsilon=epsilon,
                feature_index=feature_index
            ),
            1
        )

        # Running statistics are returned untouched in inference mode.
        return act.(TracedRArray{T, ndims(x)}((), res, size(x))), rμ, rσ²
    end
end
84+
85+
end

0 commit comments

Comments
 (0)