 
 if Reactant.Reactant_jll.is_available()
     @setup_workload begin
-        orig_backend = Reactant.XLA.default_backend()
-        Reactant.set_default_backend("cpu") # always precompile on CPU
+        try
+            orig_backend = Reactant.XLA.default_backend()
+            Reactant.set_default_backend("cpu") # always precompile on CPU
 
-        dev = reactant_device(; force=true)
+            dev = reactant_device(; force=true)
 
-        # attention model
-        mha = Lux.MultiHeadAttention(4; nheads=2)
-        ps_mha, st_mha = Lux.setup(Random.default_rng(), mha) |> dev
+            # attention model
+            mha = Lux.MultiHeadAttention(4; nheads=2)
+            ps_mha, st_mha = Lux.setup(Random.default_rng(), mha) |> dev
 
-        q = rand(Float32, (4, 3, 2)) |> dev
-        k = rand(Float32, (4, 3, 2)) |> dev
-        v = rand(Float32, (4, 3, 2)) |> dev
+            q = rand(Float32, (4, 3, 2)) |> dev
+            k = rand(Float32, (4, 3, 2)) |> dev
+            v = rand(Float32, (4, 3, 2)) |> dev
 
-        # convolution + dense model
-        conv_model = Lux.Chain(
-            Lux.Conv((3, 3), 3 => 32),
-            Lux.Conv((3, 3), 32 => 64),
-            Lux.GlobalMaxPool(),
-            Lux.FlattenLayer(),
-            Lux.Dense(64 => 10),
-        )
-        ps_conv_model, st_conv_model = Lux.setup(Random.default_rng(), conv_model) |> dev
+            # convolution + dense model
+            conv_model = Lux.Chain(
+                Lux.Conv((3, 3), 3 => 32),
+                Lux.Conv((3, 3), 32 => 64),
+                Lux.GlobalMaxPool(),
+                Lux.FlattenLayer(),
+                Lux.Dense(64 => 10),
+            )
+            ps_conv_model, st_conv_model = Lux.setup(Random.default_rng(), conv_model) |> dev
 
-        x = rand(Float32, (28, 28, 3, 2)) |> dev
+            x = rand(Float32, (28, 28, 3, 2)) |> dev
 
-        @compile_workload begin
-            @compile mha((q, k, v), ps_mha, LuxCore.testmode(st_mha))
+            @compile_workload begin
+                @compile mha((q, k, v), ps_mha, LuxCore.testmode(st_mha))
 
-            Lux.Training.single_train_step(
-                AutoEnzyme(),
-                PrecompileWorkloads.sumabs2attnloss,
-                (q, k, v),
-                Lux.Training.TrainState(mha, ps_mha, st_mha, Optimisers.Adam(0.001f0)),
-            )
+                Lux.Training.single_train_step(
+                    AutoEnzyme(),
+                    PrecompileWorkloads.sumabs2attnloss,
+                    (q, k, v),
+                    Lux.Training.TrainState(mha, ps_mha, st_mha, Optimisers.Adam(0.001f0)),
+                )
 
-            @compile conv_model(x, ps_conv_model, LuxCore.testmode(st_conv_model))
+                @compile conv_model(x, ps_conv_model, LuxCore.testmode(st_conv_model))
 
-            Lux.Training.single_train_step(
-                AutoEnzyme(),
-                PrecompileWorkloads.sumabs2loss,
-                x,
-                Lux.Training.TrainState(
-                    conv_model, ps_conv_model, st_conv_model, Optimisers.Adam(0.001f0)
-                ),
-            )
-        end
+                Lux.Training.single_train_step(
+                    AutoEnzyme(),
+                    PrecompileWorkloads.sumabs2loss,
+                    x,
+                    Lux.Training.TrainState(
+                        conv_model, ps_conv_model, st_conv_model, Optimisers.Adam(0.001f0)
+                    ),
+                )
+            end
 
-        Reactant.clear_oc_cache()
-        Reactant.set_default_backend(orig_backend)
+            Reactant.clear_oc_cache()
+            Reactant.set_default_backend(orig_backend)
+        catch err
+            if !(err isa Reactant.ReactantPrecompilationException)
+                rethrow(err)
+            end
+        end
     end
 end
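The net effect of this change is that the whole Reactant precompile workload now runs inside a try/catch: a Reactant.ReactantPrecompilationException raised while compiling on the CPU backend is swallowed so it does not fail package precompilation, while any other error is still rethrown. A minimal sketch of that guard pattern follows, with the workload body elided; it assumes PrecompileTools is imported elsewhere in the file, and the Reactant calls shown are the ones appearing in the diff:

    using PrecompileTools: @setup_workload, @compile_workload

    if Reactant.Reactant_jll.is_available()
        @setup_workload begin
            try
                # remember the current backend and force CPU for precompilation
                orig_backend = Reactant.XLA.default_backend()
                Reactant.set_default_backend("cpu")

                @compile_workload begin
                    # representative model calls and train steps go here
                end

                # drop compiled closures and restore the original backend
                Reactant.clear_oc_cache()
                Reactant.set_default_backend(orig_backend)
            catch err
                # a failed Reactant precompile should not break package precompilation
                err isa Reactant.ReactantPrecompilationException || rethrow(err)
            end
        end
    end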