Skip to content

Commit aeccbd3

Browse files
committed
fix: reactant precompilation + throw error
1 parent 5744a4a commit aeccbd3

File tree

7 files changed

+51
-45
lines changed

7 files changed

+51
-45
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ Optimisers = "0.4.6"
122122
PrecompileTools = "1.2.1"
123123
Preferences = "1.4.3"
124124
Random = "1.10"
125-
Reactant = "0.2.200"
125+
Reactant = "0.2.203"
126126
ReactantCore = "0.1.16"
127127
Reexport = "1.2.2"
128128
ReverseDiff = "1.15"

docs/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ NPZ = "0.4.3"
6868
Optimisers = "0.4.6"
6969
Printf = "1.10"
7070
Random = "1.10"
71-
Reactant = "0.2.200"
71+
Reactant = "0.2.203"
7272
StableRNGs = "1"
7373
StaticArrays = "1"
7474
WeightInitializers = "1.3"

ext/ReactantExt/ReactantExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,6 @@ include("tracing.jl")
5959
include("saved_model.jl")
6060
include("batched_jacobian.jl")
6161

62-
# include("precompile_workloads.jl")
62+
include("precompile_workloads.jl")
6363

6464
end

ext/ReactantExt/precompile_workloads.jl

Lines changed: 45 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -16,54 +16,60 @@ end
1616

1717
if Reactant.Reactant_jll.is_available()
1818
@setup_workload begin
19-
orig_backend = Reactant.XLA.default_backend()
20-
Reactant.set_default_backend("cpu") # always precompile on CPU
19+
try
20+
orig_backend = Reactant.XLA.default_backend()
21+
Reactant.set_default_backend("cpu") # always precompile on CPU
2122

22-
dev = reactant_device(; force=true)
23+
dev = reactant_device(; force=true)
2324

24-
# attention model
25-
mha = Lux.MultiHeadAttention(4; nheads=2)
26-
ps_mha, st_mha = Lux.setup(Random.default_rng(), mha) |> dev
25+
# attention model
26+
mha = Lux.MultiHeadAttention(4; nheads=2)
27+
ps_mha, st_mha = Lux.setup(Random.default_rng(), mha) |> dev
2728

28-
q = rand(Float32, (4, 3, 2)) |> dev
29-
k = rand(Float32, (4, 3, 2)) |> dev
30-
v = rand(Float32, (4, 3, 2)) |> dev
29+
q = rand(Float32, (4, 3, 2)) |> dev
30+
k = rand(Float32, (4, 3, 2)) |> dev
31+
v = rand(Float32, (4, 3, 2)) |> dev
3132

32-
# convolution + dense model
33-
conv_model = Lux.Chain(
34-
Lux.Conv((3, 3), 3 => 32),
35-
Lux.Conv((3, 3), 32 => 64),
36-
Lux.GlobalMaxPool(),
37-
Lux.FlattenLayer(),
38-
Lux.Dense(64 => 10),
39-
)
40-
ps_conv_model, st_conv_model = Lux.setup(Random.default_rng(), conv_model) |> dev
33+
# convolution + dense model
34+
conv_model = Lux.Chain(
35+
Lux.Conv((3, 3), 3 => 32),
36+
Lux.Conv((3, 3), 32 => 64),
37+
Lux.GlobalMaxPool(),
38+
Lux.FlattenLayer(),
39+
Lux.Dense(64 => 10),
40+
)
41+
ps_conv_model, st_conv_model = Lux.setup(Random.default_rng(), conv_model) |> dev
4142

42-
x = rand(Float32, (28, 28, 3, 2)) |> dev
43+
x = rand(Float32, (28, 28, 3, 2)) |> dev
4344

44-
@compile_workload begin
45-
@compile mha((q, k, v), ps_mha, LuxCore.testmode(st_mha))
45+
@compile_workload begin
46+
@compile mha((q, k, v), ps_mha, LuxCore.testmode(st_mha))
4647

47-
Lux.Training.single_train_step(
48-
AutoEnzyme(),
49-
PrecompileWorkloads.sumabs2attnloss,
50-
(q, k, v),
51-
Lux.Training.TrainState(mha, ps_mha, st_mha, Optimisers.Adam(0.001f0)),
52-
)
48+
Lux.Training.single_train_step(
49+
AutoEnzyme(),
50+
PrecompileWorkloads.sumabs2attnloss,
51+
(q, k, v),
52+
Lux.Training.TrainState(mha, ps_mha, st_mha, Optimisers.Adam(0.001f0)),
53+
)
5354

54-
@compile conv_model(x, ps_conv_model, LuxCore.testmode(st_conv_model))
55+
@compile conv_model(x, ps_conv_model, LuxCore.testmode(st_conv_model))
5556

56-
Lux.Training.single_train_step(
57-
AutoEnzyme(),
58-
PrecompileWorkloads.sumabs2loss,
59-
x,
60-
Lux.Training.TrainState(
61-
conv_model, ps_conv_model, st_conv_model, Optimisers.Adam(0.001f0)
62-
),
63-
)
64-
end
57+
Lux.Training.single_train_step(
58+
AutoEnzyme(),
59+
PrecompileWorkloads.sumabs2loss,
60+
x,
61+
Lux.Training.TrainState(
62+
conv_model, ps_conv_model, st_conv_model, Optimisers.Adam(0.001f0)
63+
),
64+
)
65+
end
6566

66-
Reactant.clear_oc_cache()
67-
Reactant.set_default_backend(orig_backend)
67+
Reactant.clear_oc_cache()
68+
Reactant.set_default_backend(orig_backend)
69+
catch err
70+
if !(err isa Reactant.ReactantPrecompilationException)
71+
rethrow(err)
72+
end
73+
end
6874
end
6975
end

lib/LuxCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ Flux = "0.16.3"
4242
Functors = "0.5"
4343
MLDataDevices = "1.17"
4444
Random = "1.10"
45-
Reactant = "0.2.200"
45+
Reactant = "0.2.203"
4646
ReverseDiff = "1.15"
4747
SciMLPublic = "1.0.0"
4848
Setfield = "1"

lib/MLDataDevices/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ Metal = "1"
6767
OneHotArrays = "0.2.10"
6868
Preferences = "1.4.3"
6969
Random = "1.10"
70-
Reactant = "0.2.200"
70+
Reactant = "0.2.203"
7171
RecursiveArrayTools = "3.8"
7272
ReverseDiff = "1.15"
7373
SciMLPublic = "1.0.0"

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ Preferences = "1.4.3"
8080
PythonCall = "0.9"
8181
Random = "1.10"
8282
ReTestItems = "1.24.0"
83-
Reactant = "0.2.170"
83+
Reactant = "0.2.203"
8484
Reexport = "1.2.2"
8585
ReverseDiff = "1.15.3"
8686
Setfield = "1.1.1"

0 commit comments

Comments
 (0)