|
1 | | -using Test |
2 | | -using FunctionProperties |
3 | | -using ComponentArrays, Random |
4 | | - |
5 | | -const GROUP = get(ENV, "GROUP", "All") |
6 | | - |
7 | | -if GROUP == "QA" |
8 | | - using Pkg |
9 | | - Pkg.activate(joinpath(@__DIR__, "qa")) |
10 | | - Pkg.instantiate() |
11 | | - include(joinpath(@__DIR__, "qa", "qa.jl")) |
12 | | -end |
13 | | - |
14 | | -if GROUP in ("All", "Core") |
15 | | - |
16 | | - @test hasbranching(1, 2) do x, y |
17 | | - (x < 0 ? -x : x) + exp(y) |
18 | | - end |
19 | | - |
20 | | - @test !hasbranching(1, 2) do x, y |
21 | | - ifelse(x < 0, -x, x) + exp(y) |
22 | | - end |
23 | | - |
24 | | - # Test overloading via is_leaf |
25 | | - |
26 | | - f_branch() = true ? 1 : 0 |
27 | | - @test FunctionProperties.hasbranching(f_branch) |
28 | | - FunctionProperties.is_leaf(::typeof(f_branch)) = true |
29 | | - @test !FunctionProperties.hasbranching(f_branch) |
30 | | - |
31 | | - # Test simple mutating functions |
32 | | - function f(dx, x) |
33 | | - return @inbounds dx[1] = x[1] |
34 | | - end |
35 | | - x = zeros(1) |
36 | | - dx = zeros(1) |
37 | | - @test !FunctionProperties.hasbranching(f, dx, x) |
38 | | - |
39 | | - # Test broadcast |
40 | | - function f(x) |
41 | | - return cos.(x .+ x .* x) |
42 | | - end |
43 | | - x = [1.0] |
44 | | - @test !FunctionProperties.hasbranching(f, x) |
45 | | - |
46 | | - # Neural networks |
47 | | - # |
48 | | - # The relevant scenario is a neural-network-shaped ODE right-hand side (SciML/SciMLSensitivity.jl#997): |
49 | | - # `hasbranching` must report it as branch-free so a tracing AD like ReverseDiff can compile a tape. |
50 | | - # The forward pass is expressed here as explicit affine transforms plus broadcast activations, which |
51 | | - # is the value flow `hasbranching` actually inspects. We deliberately do not trace a real Lux layer: |
52 | | - # modern Lux layer dispatch routes through device-detection / type-introspection helpers that contain |
53 | | - # genuine (but value-independent, compile-time) `GotoIfNot` branches, which this syntactic IR scan |
54 | | - # cannot distinguish from value-dependent branches (SciML/FunctionProperties.jl#46). |
55 | | - rng = Random.default_rng() |
56 | | - W = randn(rng, Float32, 1, 1) |
57 | | - b = randn(rng, Float32, 1) |
58 | | - p = ComponentArray(; weight = W, bias = b) |
59 | | - t = [0.0] |
60 | | - |
61 | | - function f(x, ps) |
62 | | - return ps.weight * x |
63 | | - end |
64 | | - @test !FunctionProperties.hasbranching(f, t, p) |
65 | | - |
66 | | - function f(x, ps) |
67 | | - return x .+ x |
68 | | - end |
69 | | - @test !FunctionProperties.hasbranching(f, t, p) |
70 | | - |
71 | | - # Affine transform followed by a broadcast activation (the original `apply_activation` intent). |
72 | | - function f2(x, ps) |
73 | | - return identity.(ps.weight * x .+ vec(ps.bias)) |
74 | | - end |
75 | | - @test !FunctionProperties.hasbranching(f2, t, p) |
76 | | - |
77 | | - # A multi-layer perceptron forward pass built from broadcast `tanh` activations. |
78 | | - rng = Random.default_rng() |
79 | | - tspan = (0.0f0, 8.0f0) |
80 | | - W1 = randn(rng, Float32, 32, 2) |
81 | | - b1 = randn(rng, Float32, 32) |
82 | | - W2 = randn(rng, Float32, 32, 32) |
83 | | - b2 = randn(rng, Float32, 32) |
84 | | - W3 = randn(rng, Float32, 1, 32) |
85 | | - b3 = randn(rng, Float32, 1) |
86 | | - p = ComponentArray(; W1, b1, W2, b2, W3, b3) |
87 | | - θ, ax = getdata(p), getaxes(p) |
88 | | - |
89 | | - ann(x, p) = p.W3 * tanh.(p.W2 * tanh.(p.W1 * x .+ p.b1) .+ p.b2) .+ p.b3 |
90 | | - |
91 | | - function dxdt_(dx, x, p, t) |
92 | | - x1, x2 = x |
93 | | - dx[1] = x[2] + first(ann(x, p)) |
94 | | - return dx[2] = first(ann([t, t], p)) |
95 | | - end |
96 | | - x0 = [-4.0f0, 0.0f0] |
97 | | - ts = Float32.(collect(0.0:0.01:tspan[2])) |
98 | | - @test !FunctionProperties.hasbranching(dxdt_, copy(x0), x0, p, tspan[1]) |
99 | | - |
100 | | -end |
| 1 | +using SciMLTesting |
| 2 | +run_tests() |
0 commit comments