Skip to content

Commit 84a7a33

Browse files
Merge pull request #54 from ChrisRackauckas-Claude/static-improvements-20260112-064852
Improve type stability in BSpline basis function evaluation
2 parents e8704a6 + dcd01de commit 84a7a33

File tree

9 files changed

+55
-20
lines changed

9 files changed

+55
-20
lines changed

.github/workflows/Tests.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ jobs:
3535
- "ubuntu-latest"
3636
- "macos-latest"
3737
- "windows-latest"
38+
exclude:
39+
- group: QA
40+
version: "pre"
3841
uses: "SciML/.github/.github/workflows/tests.yml@v1"
3942
with:
4043
julia-version: "${{ matrix.version }}"

Project.toml

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,9 @@ DataInterpolationsNDSymbolicsExt = "Symbolics"
1717

1818
[compat]
1919
Adapt = "4.3.0"
20-
Aqua = "0.8"
2120
DataInterpolations = "8"
2221
EllipsisNotation = "1.8.0"
23-
ExplicitImports = "1.14.0"
2422
ForwardDiff = "0"
25-
JET = "0.9, 0.10, 0.11.2"
2623
KernelAbstractions = "0.9.34"
2724
PrecompileTools = "1.0"
2825
Random = "1"
@@ -33,11 +30,8 @@ Test = "1"
3330
julia = "1"
3431

3532
[extras]
36-
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3733
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
38-
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
3934
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
40-
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
4135
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
4236
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
4337
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
@@ -46,4 +40,4 @@ Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
4640
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4741

4842
[targets]
49-
test = ["Aqua", "DataInterpolations", "ExplicitImports", "ForwardDiff", "JET", "Pkg", "Random", "SafeTestsets", "Symbolics", "Test", "SymbolicUtils"]
43+
test = ["DataInterpolations", "ForwardDiff", "Pkg", "Random", "SafeTestsets", "Symbolics", "Test", "SymbolicUtils"]

src/spline_utils.jl

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,40 +63,54 @@ end
6363
# The trailing zero is just a convenience for the algorithm and is removed in the output
6464
function get_basis_function_values(
6565
itp_dim::BSplineInterpolationDimension,
66-
t::Number,
66+
t::T_t,
6767
idx::Integer,
6868
derivative_order::Integer,
6969
multi_point_index::Nothing,
7070
dim_in::Integer
71-
)
71+
) where {T_t <: Number}
7272
(; degree, knots_all) = itp_dim
73-
T = promote_type(typeof(t), eltype(itp_dim.basis_function_eval))
73+
T_eval = eltype(itp_dim.basis_function_eval)
74+
# Use a concrete zero value for type stability
75+
_zero = zero(promote_type(T_t, T_eval))
76+
_one = one(promote_type(T_t, T_eval))
7477
degree_plus_1 = degree + 1
7578

7679
if derivative_order > degree
77-
return ntuple(_ -> zero(T), degree_plus_1)
80+
return ntuple(_ -> _zero, degree_plus_1)
7881
end
7982

8083
degree_plus_2 = degree + 2
8184

8285
# Degree 0 basis function values
8386
basis_function_values = ntuple(
84-
k -> (k == degree_plus_1) ? one(T) : zero(T),
87+
k -> (k == degree_plus_1) ? _one : _zero,
8588
degree_plus_2
8689
)
8790

8891
# Higher order basis function values
92+
# Use a helper function to avoid capturing mutable variables in closures
93+
basis_function_values = _compute_basis_function_values(
94+
basis_function_values, knots_all, t, idx, degree, derivative_order, degree_plus_2
95+
)
96+
97+
return basis_function_values[1:degree_plus_1]
98+
end
99+
100+
# Helper function to avoid type instability from captured variables in closures
101+
function _compute_basis_function_values(
102+
basis_function_values, knots_all, t, idx, degree, derivative_order, degree_plus_2
103+
)
89104
for d in 1:degree
90105
deriv = d > degree - derivative_order
106+
# Create a local copy for the closure to capture
107+
bfv = basis_function_values
91108
basis_function_values = ntuple(
92-
k -> cox_de_boor(
93-
basis_function_values, knots_all, t, idx, degree, d, k, deriv
94-
),
109+
k -> cox_de_boor(bfv, knots_all, t, idx, degree, d, k, deriv),
95110
degree_plus_2
96111
)
97112
end
98-
99-
return basis_function_values[1:degree_plus_1]
113+
return basis_function_values
100114
end
101115

102116
# Get the basis function values for one point in an

test/qa/Project.toml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
[deps]
2+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3+
DataInterpolationsND = "4f1ef021-621a-47a5-ac8a-95402a2d1ea8"
4+
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
5+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
6+
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
7+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8+
9+
[compat]
10+
Aqua = "0.8"
11+
DataInterpolationsND = "0.1"
12+
ExplicitImports = "1.14.0"
13+
JET = "0.9, 0.10, 0.11.2"
14+
SafeTestsets = "0.1"
File renamed without changes.
File renamed without changes.

test/qa/runtests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
using SafeTestsets
2+
3+
@safetestset "Aqua" include("aqua.jl")
4+
@safetestset "ExplicitImports" include("explicit_imports.jl")
5+
@safetestset "JET" include("jet.jl")

test/runtests.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@ function activate_gpu_env()
77
return Pkg.instantiate()
88
end
99

10+
function activate_qa_env()
11+
Pkg.activate("qa")
12+
Pkg.develop(PackageSpec(path = dirname(@__DIR__)))
13+
return Pkg.instantiate()
14+
end
15+
1016
if GROUP == "All" || GROUP == "Core"
1117
@safetestset "Interpolations" include("test_interpolations.jl")
1218
@safetestset "Derivatives" include("test_derivatives.jl")
@@ -15,9 +21,8 @@ if GROUP == "All" || GROUP == "Core"
1521
elseif GROUP == "Extensions"
1622
@safetestset "Symbolics Extension" include("test_symbolics_ext.jl")
1723
elseif GROUP == "QA"
18-
@safetestset "Aqua" include("aqua.jl")
19-
@safetestset "ExplicitImports" include("explicit_imports.jl")
20-
@safetestset "JET" include("jet.jl")
24+
activate_qa_env()
25+
include("qa/runtests.jl")
2126
elseif GROUP == "GPU"
2227
activate_gpu_env()
2328
# TODO: Add GPU tests

0 commit comments

Comments
 (0)