@@ -17,53 +17,72 @@
 if Reactant.Reactant_jll.is_available()
     @setup_workload begin
         orig_backend = Reactant.XLA.default_backend()
-        Reactant.set_default_backend("cpu") # always precompile on CPU
+        Reactant.set_default_backend("cpu")
 
-        dev = reactant_device(; force=true)
+        @compile_workload begin
+            @static if Reactant.precompilation_supported()
+                dev = reactant_device(; force=true)
 
-        # attention model
-        mha = Lux.MultiHeadAttention(4; nheads=2)
-        ps_mha, st_mha = Lux.setup(Random.default_rng(), mha) |> dev
+                # attention model
+                mha = Lux.MultiHeadAttention(4; nheads=2)
+                ps_mha, st_mha = Lux.setup(Random.default_rng(), mha) |> dev
 
-        q = rand(Float32, (4, 3, 2)) |> dev
-        k = rand(Float32, (4, 3, 2)) |> dev
-        v = rand(Float32, (4, 3, 2)) |> dev
+                q = ones(Float32, (4, 3, 2)) |> dev
+                k = ones(Float32, (4, 3, 2)) |> dev
+                v = ones(Float32, (4, 3, 2)) |> dev
 
-        # convolution + dense model
-        conv_model = Lux.Chain(
-            Lux.Conv((3, 3), 3 => 32),
-            Lux.Conv((3, 3), 32 => 64),
-            Lux.GlobalMaxPool(),
-            Lux.FlattenLayer(),
-            Lux.Dense(64 => 10),
-        )
-        ps_conv_model, st_conv_model = Lux.setup(Random.default_rng(), conv_model) |> dev
+                try
+                    @compile mha((q, k, v), ps_mha, LuxCore.testmode(st_mha))
 
-        x = rand(Float32, (28, 28, 3, 2)) |> dev
+                    Lux.Training.single_train_step(
+                        AutoEnzyme(),
+                        PrecompileWorkloads.sumabs2attnloss,
+                        (q, k, v),
+                        Lux.Training.TrainState(
+                            mha, ps_mha, st_mha, Optimisers.Adam(0.001f0)
+                        ),
+                    )
+                catch err
+                    if !(err isa Reactant.ReactantPrecompilationException)
+                        rethrow(err)
+                    end
+                end
 
-        @compile_workload begin
-            @compile mha((q, k, v), ps_mha, LuxCore.testmode(st_mha))
+                # convolution + dense model
+                conv_model = Lux.Chain(
+                    Lux.Conv((3, 3), 3 => 32),
+                    Lux.Conv((3, 3), 32 => 64),
+                    Lux.GlobalMaxPool(),
+                    Lux.FlattenLayer(),
+                    Lux.Dense(64 => 10),
+                )
+                ps_conv_model, st_conv_model =
+                    Lux.setup(Random.default_rng(), conv_model) |> dev
 
-            Lux.Training.single_train_step(
-                AutoEnzyme(),
-                PrecompileWorkloads.sumabs2attnloss,
-                (q, k, v),
-                Lux.Training.TrainState(mha, ps_mha, st_mha, Optimisers.Adam(0.001f0)),
-            )
+                x = ones(Float32, (28, 28, 3, 2)) |> dev
 
-            @compile conv_model(x, ps_conv_model, LuxCore.testmode(st_conv_model))
+                try
+                    @compile conv_model(x, ps_conv_model, LuxCore.testmode(st_conv_model))
 
-            Lux.Training.single_train_step(
-                AutoEnzyme(),
-                PrecompileWorkloads.sumabs2loss,
-                x,
-                Lux.Training.TrainState(
-                    conv_model, ps_conv_model, st_conv_model, Optimisers.Adam(0.001f0)
-                ),
-            )
+                    Lux.Training.single_train_step(
+                        AutoEnzyme(),
+                        PrecompileWorkloads.sumabs2loss,
+                        x,
+                        Lux.Training.TrainState(
+                            conv_model,
+                            ps_conv_model,
+                            st_conv_model,
+                            Optimisers.Adam(0.001f0),
+                        ),
+                    )
+                catch err
+                    if !(err isa Reactant.ReactantPrecompilationException)
+                        rethrow(err)
+                    end
+                end
+            end
         end
 
-        Reactant.clear_oc_cache()
         Reactant.set_default_backend(orig_backend)
     end
 end
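
What the change does: the device setup and both models now live inside `@compile_workload`, the whole workload is guarded by `@static if Reactant.precompilation_supported()`, and the example inputs switch from `rand` to `ones`, which makes the traced inputs deterministic from one precompile run to the next. For context, here is a minimal sketch of the PrecompileTools.jl pattern these macros come from (a toy workload, not the Lux one):

    using PrecompileTools

    @setup_workload begin
        # Setup code runs at precompile time, but calls made here are
        # not themselves cached; it only builds inputs for the workload.
        x = ones(Float32, 4)

        @compile_workload begin
            # Calls in this block are executed during precompilation and
            # the compiled code they exercise is cached into the package
            # image, so the first real call after loading is fast.
            sum(abs2, x)
        end
    end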
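
Each `@compile` plus `single_train_step` pair is wrapped in a try/catch that swallows only `Reactant.ReactantPrecompilationException`, so an environment where Reactant cannot compile during precompilation skips the workload instead of failing the whole package precompilation, while any other error still propagates. A self-contained sketch of the idiom, with a placeholder exception type standing in for Reactant's:

    # Placeholder for Reactant.ReactantPrecompilationException.
    struct PrecompileSkipped <: Exception end

    # Stand-in for an @compile call that may refuse to run here.
    compile_or_skip() = throw(PrecompileSkipped())

    try
        compile_or_skip()
    catch err
        # Swallow only the expected "cannot precompile in this
        # environment" signal; anything else is a real bug and
        # must still propagate.
        if !(err isa PrecompileSkipped)
            rethrow(err)
        end
    end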
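
The backend handling is unchanged in spirit: save the current default backend, force CPU for the workload, restore it afterwards; the `Reactant.clear_oc_cache()` call is simply dropped. One thing to note is that the restore runs after the workload block rather than in a `finally`, so an error that escapes the catch guards above would leave the CPU backend set. A generic sketch of the more defensive variant, using a plain `Ref` as a stand-in for Reactant's default-backend state:

    # BACKEND stands in for Reactant's default-backend setting.
    const BACKEND = Ref("gpu")

    orig = BACKEND[]
    BACKEND[] = "cpu"        # always precompile on CPU
    try
        # ... the precompile workload would run here ...
    finally
        BACKEND[] = orig     # restored even if the workload throws
    end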