|
| 1 | +# Precompilation workload for DiffEqFlux |
| 2 | +# This improves time-to-first-X (TTFX) by precompiling common code paths |
| 3 | + |
| 4 | +using PrecompileTools: @compile_workload, @setup_workload |
| 5 | + |
| 6 | +@setup_workload begin |
| 7 | + # Setup code - imports and minimal test data |
| 8 | + # This code is run during precompilation but the compilation results are discarded |
| 9 | + using Random: MersenneTwister |
| 10 | + using Lux: Chain, Dense |
| 11 | + |
| 12 | + @compile_workload begin |
| 13 | + # These operations will be precompiled |
| 14 | + # Focus on the most common use cases |
| 15 | + |
| 16 | + # Use a fixed RNG for reproducibility |
| 17 | + rng = MersenneTwister(0) |
| 18 | + |
| 19 | + # Create a simple model - this is the most common pattern |
| 20 | + model = Chain(Dense(2, 4, tanh), Dense(4, 2)) |
| 21 | + |
| 22 | + # Create NeuralODE layer - the main entry point |
| 23 | + # Note: We don't run the forward pass because it requires an ODE solver |
| 24 | + # which is not a direct dependency of DiffEqFlux |
| 25 | + tspan = (0.0f0, 1.0f0) |
| 26 | + node = NeuralODE(model, tspan) |
| 27 | + |
| 28 | + # Setup parameters and state - this is called often and benefits from precompilation |
| 29 | + ps, st = Lux.setup(rng, node) |
| 30 | + |
| 31 | + # Precompile StatefulLuxLayer creation (used in forward pass) |
| 32 | + stateful = StatefulLuxLayer{true}(node.model, nothing, st) |
| 33 | + |
| 34 | + # Precompile the dudt function creation pattern |
| 35 | + x0 = Float32[1.0, 0.0] |
| 36 | + dudt_out = stateful(x0, ps) |
| 37 | + |
| 38 | + # Precompile ODEFunction and ODEProblem creation |
| 39 | + dudt(u, p, t) = stateful(u, p) |
| 40 | + ff = ODEFunction{false}(dudt; tgrad = basic_tgrad) |
| 41 | + prob = ODEProblem{false}(ff, x0, node.tspan, ps) |
| 42 | + |
| 43 | + # Precompile FFJORD constructor |
| 44 | + ffjord_model = Chain(Dense(2, 4, tanh), Dense(4, 2)) |
| 45 | + ffjord = FFJORD(ffjord_model, tspan, (2,)) |
| 46 | + |
| 47 | + # Precompile collocation kernel calculations (commonly used) |
| 48 | + tpoints = Float32[0.0, 0.5, 1.0] |
| 49 | + data = Float32[1.0 1.1 1.2; 0.0 0.1 0.2] |
| 50 | + try |
| 51 | + collocate_data(data, tpoints, TriangularKernel()) |
| 52 | + catch |
| 53 | + # May fail with small data, but we still get the compilation |
| 54 | + end |
| 55 | + end |
| 56 | +end |
0 commit comments