Skip to content

Commit b2fbe2b

Browse files
ChrisRackauckas-ClaudeChrisRackauckasclaude
authored
Fix Julia 1.10 tests: replace Lux-internal NN tracer test with broadcast MLP (#49)
The test suite errored on Julia 1.10 with `UndefVarError: apply_activation not defined`: `test/runtests.jl` called `Lux.apply_activation`, an internal that no longer exists in modern Lux. Renaming it does not fix the suite — modern Lux layer dispatch routes through device-detection / type-introspection helpers (`fieldcount`, `can_setindex`, `__is_immutable_array_or_dual_val`, `_reshape_uncolon`, ...) that contain genuine, but value-independent, compile-time `GotoIfNot` branches. FunctionProperties' syntactic IR scan cannot distinguish those from value-dependent branches, so tracing a real Lux layer now reports `hasbranching == true`, which would make the `@test !hasbranching(ann, ...)` and `@test !hasbranching(dxdt_, ...)` assertions fail rather than pass. There is no Lux version on 1.10 where the suite passes as written (#46). Rework the neural-network portion of the test to exercise the same value flow the tool actually inspects (SciML/SciMLSensitivity.jl#997: a neural-network- shaped ODE right-hand side must trace branch-free so a tracing AD can compile a tape) using explicit affine transforms plus broadcast activations instead of a Lux layer. This drops the Lux test dependency entirely and removes the silent resolution to incompatible Lux versions noted in the issue. Also add `[compat]` bounds for the test dependencies (ComponentArrays, Random, SafeTestsets, Test) and bump the patch version. Co-authored-by: Chris Rackauckas <accounts@chrisrackauckas.com> Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
1 parent 257dfc2 commit b2fbe2b

2 files changed

Lines changed: 38 additions & 21 deletions

File tree

Project.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
11
name = "FunctionProperties"
22
uuid = "f62d2435-5019-4c03-9749-2d4c77af0cbc"
33
authors = ["SciML"]
4-
version = "0.1.2"
4+
version = "0.1.3"
55

66
[deps]
77
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
88
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
99

1010
[compat]
1111
Cassette = "0.3.12"
12+
ComponentArrays = "0.15"
1213
DiffRules = "1.15"
14+
Random = "1.10"
15+
SafeTestsets = "0.1"
16+
Test = "1.10"
1317
julia = "1.10"
1418

1519
[extras]
1620
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
17-
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
1821
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1922
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
2023
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2124

2225
[targets]
23-
test = ["Test", "SafeTestsets", "ComponentArrays", "Lux", "Random"]
26+
test = ["Test", "SafeTestsets", "ComponentArrays", "Random"]

test/runtests.jl

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -35,41 +35,55 @@ x = [1.0]
3535
@test !FunctionProperties.hasbranching(f, x)
3636

3737
# Neural networks
38-
using Lux, ComponentArrays, Random
38+
#
39+
# The relevant scenario is a neural-network-shaped ODE right-hand side (SciML/SciMLSensitivity.jl#997):
40+
# `hasbranching` must report it as branch-free so a tracing AD like ReverseDiff can compile a tape.
41+
# The forward pass is expressed here as explicit affine transforms plus broadcast activations, which
42+
# is the value flow `hasbranching` actually inspects. We deliberately do not trace a real Lux layer:
43+
# modern Lux layer dispatch routes through device-detection / type-introspection helpers that contain
44+
# genuine (but value-independent, compile-time) `GotoIfNot` branches, which this syntactic IR scan
45+
# cannot distinguish from value-dependent branches (SciML/FunctionProperties.jl#46).
46+
using ComponentArrays, Random
3947
rng = Random.default_rng()
40-
ann = Dense(1, 1, identity)
41-
ps, st = Lux.setup(rng, ann)
42-
p = ComponentArray(ps)
43-
x0 = [-4.0f0, 0.0f0]
48+
W = randn(rng, Float32, 1, 1)
49+
b = randn(rng, Float32, 1)
50+
p = ComponentArray(; weight = W, bias = b)
4451
t = [0.0]
4552

46-
function f(x, ps, st)
53+
function f(x, ps)
4754
return ps.weight * x
4855
end
49-
@test !FunctionProperties.hasbranching(f, t, p, st)
56+
@test !FunctionProperties.hasbranching(f, t, p)
5057

51-
function f(x, ps, st)
58+
function f(x, ps)
5259
return x .+ x
5360
end
54-
@test !FunctionProperties.hasbranching(f, t, p, st)
61+
@test !FunctionProperties.hasbranching(f, t, p)
5562

56-
function f2(x, ps, st)
57-
return Lux.apply_activation(identity, ps.weight * x .+ vec(ps.bias)), st
63+
# Affine transform followed by a broadcast activation (the original `apply_activation` intent).
64+
function f2(x, ps)
65+
return identity.(ps.weight * x .+ vec(ps.bias))
5866
end
59-
@test !FunctionProperties.hasbranching(f2, t, p, st)
60-
@test !FunctionProperties.hasbranching(ann, t, p, st)
67+
@test !FunctionProperties.hasbranching(f2, t, p)
6168

69+
# A multi-layer perceptron forward pass built from broadcast `tanh` activations.
6270
rng = Random.default_rng()
6371
tspan = (0.0f0, 8.0f0)
64-
ann = Chain(Dense(2, 32, tanh), Dense(32, 32, tanh), Dense(32, 1))
65-
ps, st = Lux.setup(rng, ann)
66-
p = ComponentArray(ps)
72+
W1 = randn(rng, Float32, 32, 2)
73+
b1 = randn(rng, Float32, 32)
74+
W2 = randn(rng, Float32, 32, 32)
75+
b2 = randn(rng, Float32, 32)
76+
W3 = randn(rng, Float32, 1, 32)
77+
b3 = randn(rng, Float32, 1)
78+
p = ComponentArray(; W1, b1, W2, b2, W3, b3)
6779
θ, ax = getdata(p), getaxes(p)
6880

81+
ann(x, p) = p.W3 * tanh.(p.W2 * tanh.(p.W1 * x .+ p.b1) .+ p.b2) .+ p.b3
82+
6983
function dxdt_(dx, x, p, t)
7084
x1, x2 = x
71-
dx[1] = x[2] + first(ann(x, p, st))[1]
72-
return dx[2] = first(ann([t, t], p, st))[1]
85+
dx[1] = x[2] + first(ann(x, p))
86+
return dx[2] = first(ann([t, t], p))
7387
end
7488
x0 = [-4.0f0, 0.0f0]
7589
ts = Float32.(collect(0.0:0.01:tspan[2]))

0 commit comments

Comments
 (0)