Skip to content

Commit 8d39975

Browse files
Merge pull request #997 from ChrisRackauckas-Claude/precompile-improvements-20251230-111636
Add PrecompileTools workload to improve startup time
2 parents 6f3f600 + 2725701 commit 8d39975

File tree

3 files changed

+61
-0
lines changed

3 files changed

+61
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1313
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
1414
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
1515
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
16+
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1617
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1718
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1819
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
@@ -58,6 +59,7 @@ Optimization = "4"
5859
OptimizationOptimJL = "0.4"
5960
OptimizationOptimisers = "0.3"
6061
OrdinaryDiffEq = "6.76.0"
62+
PrecompileTools = "1.3.3"
6163
Printf = "1.10"
6264
Random = "1.10"
6365
ReTestItems = "1.25.1"

src/DiffEqFlux.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,7 @@ export BacksolveAdjoint, QuadratureAdjoint, GaussAdjoint, InterpolatingAdjoint,
5555
AdjointLSS, NILSS, NILSAS
5656
export TrackerVJP, ZygoteVJP, EnzymeVJP, ReverseDiffVJP
5757

58+
# Precompilation workload - must be at the end
59+
include("precompilation.jl")
60+
5861
end

src/precompilation.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Precompilation workload for DiffEqFlux
2+
# This improves time-to-first-X (TTFX) by precompiling common code paths
3+
4+
using PrecompileTools: @compile_workload, @setup_workload
5+
6+
@setup_workload begin
7+
# Setup code - imports and minimal test data
8+
# This code is run during precompilation but the compilation results are discarded
9+
using Random: MersenneTwister
10+
using Lux: Chain, Dense
11+
12+
@compile_workload begin
13+
# These operations will be precompiled
14+
# Focus on the most common use cases
15+
16+
# Use a fixed RNG for reproducibility
17+
rng = MersenneTwister(0)
18+
19+
# Create a simple model - this is the most common pattern
20+
model = Chain(Dense(2, 4, tanh), Dense(4, 2))
21+
22+
# Create NeuralODE layer - the main entry point
23+
# Note: We don't run the forward pass because it requires an ODE solver
24+
# which is not a direct dependency of DiffEqFlux
25+
tspan = (0.0f0, 1.0f0)
26+
node = NeuralODE(model, tspan)
27+
28+
# Setup parameters and state - this is called often and benefits from precompilation
29+
ps, st = Lux.setup(rng, node)
30+
31+
# Precompile StatefulLuxLayer creation (used in forward pass)
32+
stateful = StatefulLuxLayer{true}(node.model, nothing, st)
33+
34+
# Precompile the dudt function creation pattern
35+
x0 = Float32[1.0, 0.0]
36+
dudt_out = stateful(x0, ps)
37+
38+
# Precompile ODEFunction and ODEProblem creation
39+
dudt(u, p, t) = stateful(u, p)
40+
ff = ODEFunction{false}(dudt; tgrad = basic_tgrad)
41+
prob = ODEProblem{false}(ff, x0, node.tspan, ps)
42+
43+
# Precompile FFJORD constructor
44+
ffjord_model = Chain(Dense(2, 4, tanh), Dense(4, 2))
45+
ffjord = FFJORD(ffjord_model, tspan, (2,))
46+
47+
# Precompile collocation kernel calculations (commonly used)
48+
tpoints = Float32[0.0, 0.5, 1.0]
49+
data = Float32[1.0 1.1 1.2; 0.0 0.1 0.2]
50+
try
51+
collocate_data(data, tpoints, TriangularKernel())
52+
catch
53+
# May fail with small data, but we still get the compilation
54+
end
55+
end
56+
end

0 commit comments

Comments
 (0)