Skip to content

Commit 9bf4ce8

Browse files
authored
fix: reactant precompilation + throw error (#1631)
* fix: reactant precompilation + throw error
* fix: rmsnorm
1 parent 5744a4a commit 9bf4ce8

File tree

8 files changed: 63 additions and 44 deletions.

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
name = "Lux"
22
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "1.29.4"
4+
version = "1.29.5"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -122,7 +122,7 @@ Optimisers = "0.4.6"
122122
PrecompileTools = "1.2.1"
123123
Preferences = "1.4.3"
124124
Random = "1.10"
125-
Reactant = "0.2.200"
125+
Reactant = "0.2.205"
126126
ReactantCore = "0.1.16"
127127
Reexport = "1.2.2"
128128
ReverseDiff = "1.15"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,7 @@ NPZ = "0.4.3"
6868
Optimisers = "0.4.6"
6969
Printf = "1.10"
7070
Random = "1.10"
71-
Reactant = "0.2.200"
71+
Reactant = "0.2.203"
7272
StableRNGs = "1"
7373
StaticArrays = "1"
7474
WeightInitializers = "1.3"

ext/ReactantExt/ReactantExt.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,6 @@ include("tracing.jl")
5959
include("saved_model.jl")
6060
include("batched_jacobian.jl")
6161

62-
# include("precompile_workloads.jl")
62+
include("precompile_workloads.jl")
6363

6464
end

ext/ReactantExt/precompile_workloads.jl

Lines changed: 55 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -17,53 +17,72 @@ end
1717
if Reactant.Reactant_jll.is_available()
1818
@setup_workload begin
1919
orig_backend = Reactant.XLA.default_backend()
20-
Reactant.set_default_backend("cpu") # always precompile on CPU
20+
Reactant.set_default_backend("cpu")
2121

22-
dev = reactant_device(; force=true)
22+
@compile_workload begin
23+
@static if Reactant.precompilation_supported()
24+
dev = reactant_device(; force=true)
2325

24-
# attention model
25-
mha = Lux.MultiHeadAttention(4; nheads=2)
26-
ps_mha, st_mha = Lux.setup(Random.default_rng(), mha) |> dev
26+
# attention model
27+
mha = Lux.MultiHeadAttention(4; nheads=2)
28+
ps_mha, st_mha = Lux.setup(Random.default_rng(), mha) |> dev
2729

28-
q = rand(Float32, (4, 3, 2)) |> dev
29-
k = rand(Float32, (4, 3, 2)) |> dev
30-
v = rand(Float32, (4, 3, 2)) |> dev
30+
q = ones(Float32, (4, 3, 2)) |> dev
31+
k = ones(Float32, (4, 3, 2)) |> dev
32+
v = ones(Float32, (4, 3, 2)) |> dev
3133

32-
# convolution + dense model
33-
conv_model = Lux.Chain(
34-
Lux.Conv((3, 3), 3 => 32),
35-
Lux.Conv((3, 3), 32 => 64),
36-
Lux.GlobalMaxPool(),
37-
Lux.FlattenLayer(),
38-
Lux.Dense(64 => 10),
39-
)
40-
ps_conv_model, st_conv_model = Lux.setup(Random.default_rng(), conv_model) |> dev
34+
try
35+
@compile mha((q, k, v), ps_mha, LuxCore.testmode(st_mha))
4136

42-
x = rand(Float32, (28, 28, 3, 2)) |> dev
37+
Lux.Training.single_train_step(
38+
AutoEnzyme(),
39+
PrecompileWorkloads.sumabs2attnloss,
40+
(q, k, v),
41+
Lux.Training.TrainState(
42+
mha, ps_mha, st_mha, Optimisers.Adam(0.001f0)
43+
),
44+
)
45+
catch err
46+
if !(err isa Reactant.ReactantPrecompilationException)
47+
rethrow(err)
48+
end
49+
end
4350

44-
@compile_workload begin
45-
@compile mha((q, k, v), ps_mha, LuxCore.testmode(st_mha))
51+
# convolution + dense model
52+
conv_model = Lux.Chain(
53+
Lux.Conv((3, 3), 3 => 32),
54+
Lux.Conv((3, 3), 32 => 64),
55+
Lux.GlobalMaxPool(),
56+
Lux.FlattenLayer(),
57+
Lux.Dense(64 => 10),
58+
)
59+
ps_conv_model, st_conv_model =
60+
Lux.setup(Random.default_rng(), conv_model) |> dev
4661

47-
Lux.Training.single_train_step(
48-
AutoEnzyme(),
49-
PrecompileWorkloads.sumabs2attnloss,
50-
(q, k, v),
51-
Lux.Training.TrainState(mha, ps_mha, st_mha, Optimisers.Adam(0.001f0)),
52-
)
62+
x = ones(Float32, (28, 28, 3, 2)) |> dev
5363

54-
@compile conv_model(x, ps_conv_model, LuxCore.testmode(st_conv_model))
64+
try
65+
@compile conv_model(x, ps_conv_model, LuxCore.testmode(st_conv_model))
5566

56-
Lux.Training.single_train_step(
57-
AutoEnzyme(),
58-
PrecompileWorkloads.sumabs2loss,
59-
x,
60-
Lux.Training.TrainState(
61-
conv_model, ps_conv_model, st_conv_model, Optimisers.Adam(0.001f0)
62-
),
63-
)
67+
Lux.Training.single_train_step(
68+
AutoEnzyme(),
69+
PrecompileWorkloads.sumabs2loss,
70+
x,
71+
Lux.Training.TrainState(
72+
conv_model,
73+
ps_conv_model,
74+
st_conv_model,
75+
Optimisers.Adam(0.001f0),
76+
),
77+
)
78+
catch err
79+
if !(err isa Reactant.ReactantPrecompilationException)
80+
rethrow(err)
81+
end
82+
end
83+
end
6484
end
6585

66-
Reactant.clear_oc_cache()
6786
Reactant.set_default_backend(orig_backend)
6887
end
6988
end

lib/LuxCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ Flux = "0.16.3"
4242
Functors = "0.5"
4343
MLDataDevices = "1.17"
4444
Random = "1.10"
45-
Reactant = "0.2.200"
45+
Reactant = "0.2.203"
4646
ReverseDiff = "1.15"
4747
SciMLPublic = "1.0.0"
4848
Setfield = "1"

lib/MLDataDevices/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,7 @@ Metal = "1"
6767
OneHotArrays = "0.2.10"
6868
Preferences = "1.4.3"
6969
Random = "1.10"
70-
Reactant = "0.2.200"
70+
Reactant = "0.2.203"
7171
RecursiveArrayTools = "3.8"
7272
ReverseDiff = "1.15"
7373
SciMLPublic = "1.0.0"

src/layers/normalize.jl

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -785,7 +785,7 @@ parameterlength(l::RMSNorm) = has_affine(l) ? prod(l.normalized_shape) : 0
785785
# Don't use `match_eltype` here, since often times the eltypes are intentionally
786786
# different.
787787
ϵ = T(rms.epsilon)
788-
mean_sq = mean(abs2, x; dims=1:length(rms.normalized_shape))
788+
mean_sq = mean(abs2, x; dims=Tuple(1:length(rms.normalized_shape)))
789789

790790
if has_affine(rms)
791791
norm_x = @. (x * LuxOps.rsqrt(mean_sq + ϵ)) * ps.scale

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -80,7 +80,7 @@ Preferences = "1.4.3"
8080
PythonCall = "0.9"
8181
Random = "1.10"
8282
ReTestItems = "1.24.0"
83-
Reactant = "0.2.170"
83+
Reactant = "0.2.205"
8484
Reexport = "1.2.2"
8585
ReverseDiff = "1.15.3"
8686
Setfield = "1.1.1"

0 commit comments

Comments
 (0)