Skip to content

Add NURBS #24

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@ The interpolation types are given by the corresponding interpolation dimension t

- `LinearInterpolationDimension(t)`: Linear interpolation in the sense of bilinear, trilinear interpolation etc.
- `ConstantInterpolationDimension(t)`: An interpolation with a constant value in each interval between `t` points. The Boolean option `left` (default `true`) can be used to indicate which side of the interval in which the input lies determines the output value.
- `BSplineInterpolationDimension(t, degree)`: Interpolation using BSpline basis functions. The input values `t` are interpreted as knots, and optionally knot multiplicities can be supplied. Per dimension a degree can be specified. Note that for an `NDInterpolation` of this type, the size of `u` for a certain dimension is equal to `sum(multiplicities) - degree - 1`.
- `BSplineInterpolationDimension(t, degree)`: Interpolation using BSpline basis functions. The input values `t` are interpreted as knots, and optionally knot multiplicities can be supplied. Per dimension a degree can be specified. Note that for an `NDInterpolation` of this type, the size of `u` for a certain dimension is equal to `sum(multiplicities) - degree - 1`. This interpolation dimension type can also be used to define NURBS, by passing `global_cache = NURBSWeights(weights)` to the `NDInterpolation` constructor.
1 change: 1 addition & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ NDInterpolation
LinearInterpolationDimension
ConstantInterpolationDimension
BSplineInterpolationDimension
NURBSWeights
```

## Multi-point evaluation
Expand Down
2 changes: 1 addition & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,4 @@ The interpolation types are given by the corresponding interpolation dimension t

- `LinearInterpolationDimension(t)`: Linear interpolation in the sense of bilinear, trilinear interpolation etc.
- `ConstantInterpolationDimension(t)`: An interpolation with a constant value in each interval between `t` points. The Boolean option `left` (default `true`) can be used to indicate which side of the interval in which the input lies determines the output value.
- `BSplineInterpolationDimension(t, degree)`: Interpolation using BSpline basis functions. The input values `t` are interpreted as knots, and optionally knot multiplicities can be supplied. Per dimension a degree can be specified. Note that for an `NDInterpolation` of this type, the size of `u` for a certain dimension is equal to `sum(multiplicities) - degree - 1`.
- `BSplineInterpolationDimension(t, degree)`: Interpolation using BSpline basis functions. The input values `t` are interpreted as knots, and optionally knot multiplicities can be supplied. Per dimension a degree can be specified. Note that for an `NDInterpolation` of this type, the size of `u` for a certain dimension is equal to `sum(multiplicities) - degree - 1`. This interpolation dimension type can also be used to define NURBS, by passing `global_cache = NURBSWeights(weights)` to the `NDInterpolation` constructor.
9 changes: 9 additions & 0 deletions docs/src/interpolation_types.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,12 @@ interp = NDInterpolation(u_bspline, interp_dims)
eval_grid!(out, interp)
heatmap(out)
```

## NURBS Interpolation

```@example tutorial
weights = rand(11, 11)
interp = NDInterpolation(u_bspline, interp_dims; global_cache = NURBSWeights(weights))
eval_grid!(out, interp)
heatmap(out)
```
28 changes: 22 additions & 6 deletions src/NDInterpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ using EllipsisNotation
using RecipesBase

abstract type AbstractInterpolationDimension end
abstract type AbstractGlobalCache end

struct TrivialGlobalCache <: AbstractGlobalCache end

"""
NDInterpolation(interp_dims, u)
Expand All @@ -19,26 +22,39 @@ the size of `u` along that dimension must match the length of `t` of the corresp
- `u`: The array to be interpolated.
"""
struct NDInterpolation{
N_in, N_out, ID <: AbstractInterpolationDimension, uType <: AbstractArray}
N_in, N_out,
ID <: AbstractInterpolationDimension,
gType <: AbstractGlobalCache,
uType <: AbstractArray
}
u::uType
interp_dims::NTuple{N_in, ID}
function NDInterpolation(u, interp_dims)
global_cache::gType
function NDInterpolation(u, interp_dims, global_cache)
if interp_dims isa AbstractInterpolationDimension
interp_dims = (interp_dims,)
end
N_in = length(interp_dims)
N_out = ndims(u) - N_in
@assert N_out≥0 "The number of dimensions of u must be at least the number of interpolation dimensions."
validate_size_u(interp_dims, u)
new{N_in, N_out, eltype(interp_dims), typeof(u)}(u, interp_dims)
validate_global_cache(global_cache, interp_dims, u)
new{N_in, N_out, eltype(interp_dims), typeof(global_cache), typeof(u)}(
u, interp_dims, global_cache
)
end
end

# Constructor with optional global cache
function NDInterpolation(u, interp_dims; global_cache = TrivialGlobalCache())
NDInterpolation(u, interp_dims, global_cache)
end

@adapt_structure NDInterpolation

include("interpolation_dimensions.jl")
include("interpolation_utils.jl")
include("spline_utils.jl")
include("interpolation_utils.jl")
include("interpolation_methods.jl")
include("interpolation_parallel.jl")
include("plot_rec.jl")
Expand All @@ -59,7 +75,7 @@ function (interp::NDInterpolation{N_in})(
t::Tuple{Vararg{Number, N_in}};
derivative_orders::NTuple{N_in, <:Integer} = ntuple(_ -> 0, N_in)
) where {N_in}
validate_derivative_orders(derivative_orders, interp.interp_dims)
validate_derivative_orders(derivative_orders, interp)
idx = get_idx(interp.interp_dims, t)
@assert size(out)==size(interp.u)[(N_in + 1):end] "The size of out must match the size of the last N_out dimensions of u."
_interpolate!(out, interp, t, idx, derivative_orders, nothing)
Expand All @@ -72,7 +88,7 @@ function (interp::NDInterpolation)(t::Tuple{Vararg{Number}}; kwargs...)
end

export NDInterpolation, LinearInterpolationDimension, ConstantInterpolationDimension,
BSplineInterpolationDimension,
BSplineInterpolationDimension, NURBSWeights,
eval_unstructured, eval_unstructured!, eval_grid, eval_grid!

end # module NDInterpolations
64 changes: 47 additions & 17 deletions src/interpolation_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@ function _interpolate!(
derivative_orders::NTuple{N_in, <:Integer},
multi_point_index
) where {N_in, N_out, ID <: LinearInterpolationDimension}
if iszero(N_out)
out = zero(out)
else
out .= 0
end
out = make_zero(out)
any(>(1), derivative_orders) && return out

tᵢ = ntuple(i -> A.interp_dims[i].t[idx[i]], N_in)
Expand Down Expand Up @@ -67,6 +63,7 @@ function _interpolate!(
return out
end

# BSpline evaluation
function _interpolate!(
out,
A::NDInterpolation{N_in, N_out, ID},
Expand All @@ -77,19 +74,10 @@ function _interpolate!(
) where {N_in, N_out, ID <: BSplineInterpolationDimension}
(; interp_dims) = A

if iszero(N_out)
out = zero(out)
else
out .= 0
end

out = make_zero(out)
degrees = ntuple(dim_in -> interp_dims[dim_in].degree, N_in)

basis_function_vals = ntuple(
dim_in -> get_basis_function_values(
interp_dims[dim_in], t[dim_in], idx[dim_in], derivative_orders[dim_in], multi_point_index, dim_in
),
N_in
basis_function_vals = get_basis_function_values_all(
A, t, idx, derivative_orders, multi_point_index
)

for I in CartesianIndices(ntuple(dim_in -> 1:(degrees[dim_in] + 1), N_in))
Expand All @@ -105,3 +93,45 @@ function _interpolate!(

return out
end

# NURBS evaluation
function _interpolate!(
out,
A::NDInterpolation{N_in, N_out, ID, <:NURBSWeights},
t::Tuple{Vararg{Number, N_in}},
idx::NTuple{N_in, <:Integer},
derivative_orders::NTuple{N_in, <:Integer},
multi_point_index
) where {N_in, N_out, ID <: BSplineInterpolationDimension}
(; interp_dims, global_cache) = A

out = make_zero(out)
degrees = ntuple(dim_in -> interp_dims[dim_in].degree, N_in)
basis_function_vals = get_basis_function_values_all(
A, t, idx, derivative_orders, multi_point_index
)

denom = zero(eltype(t))

for I in CartesianIndices(ntuple(dim_in -> 1:(degrees[dim_in] + 1), N_in))
B_product = prod(dim_in -> basis_function_vals[dim_in][I[dim_in]], 1:N_in)
cp_index = ntuple(
dim_in -> idx[dim_in] + I[dim_in] - degrees[dim_in] - 1, N_in)
weight = global_cache.weights[cp_index...]
product = weight * B_product
denom += product
if iszero(N_out)
out += product * A.u[cp_index...]
else
out .+= product * view(A.u, cp_index..., ..)
end
end

if iszero(N_out)
out /= denom
else
out ./= denom
end

return out
end
4 changes: 2 additions & 2 deletions src/interpolation_parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ function eval_unstructured!(
interp::NDInterpolation{N_in};
derivative_orders::NTuple{N_in, <:Integer} = ntuple(_ -> 0, N_in)
) where {N_in}
validate_derivative_orders(derivative_orders, interp.interp_dims; multi_point = true)
validate_derivative_orders(derivative_orders, interp; multi_point = true)
backend = get_backend(out)
@assert all(i -> length(interp.interp_dims[i].t_eval) == size(out, 1), N_in) "The t_eval of all interpolation dimensions must have the same length as the first dimension of out."
@assert size(out)[2:end]==get_output_size(interp) "The size of the last N_out dimensions of out must be the same as the output size of the interpolation."
Expand Down Expand Up @@ -87,7 +87,7 @@ function eval_grid!(
interp::NDInterpolation{N_in};
derivative_orders::NTuple{N_in, <:Integer} = ntuple(_ -> 0, N_in)
) where {N_in}
validate_derivative_orders(derivative_orders, interp.interp_dims; multi_point = true)
validate_derivative_orders(derivative_orders, interp; multi_point = true)
backend = get_backend(out)
@assert all(i -> size(out, i) == length(interp.interp_dims[i].t_eval), N_in) "For the first N_in dimensions of out the length must match the t_eval of the corresponding interpolation dimension."
@assert size(out)[(N_in + 1):end]==get_output_size(interp) "The size of the last N_out dimensions of out must be the same as the output size of the interpolation."
Expand Down
41 changes: 36 additions & 5 deletions src/interpolation_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,31 @@ Base.length(itp_dim::AbstractInterpolationDimension) = length(itp_dim.t)

function validate_derivative_orders(
derivative_orders::NTuple{N_in, <:Integer},
::NTuple{N_in, <:AbstractInterpolationDimension};
multi_point::Bool = false
::NDInterpolation{N_in};
kwargs...
) where {N_in}
@assert all(≥(0), derivative_orders) "Derivative orders must me non-negative."
end

function validate_derivative_orders(
derivative_orders::NTuple{N_in, <:Integer},
interp_dims::NTuple{N_in, <:BSplineInterpolationDimension};
A::NDInterpolation{N_in, N_out, <:BSplineInterpolationDimension};
multi_point::Bool = false
) where {N_in}
) where {N_in, N_out}
@assert all(≥(0), derivative_orders) "Derivative orders must me non-negative."

if multi_point
@assert all(
i -> derivative_orders[i] ≤ interp_dims[i].max_derivative_order_eval, 1:N_in
i -> derivative_orders[i] ≤ A.interp_dims[i].max_derivative_order_eval, 1:N_in
) "For BSpline interpolation, when using multi-point evaluation the derivative orders cannot be \
larger than the `max_derivative_order_eval` eval of of the `BSplineInterpolationDimension`. If you want \
to compute higher order multi-point derivatives, pass a larger `max_derivative_order_eval` to the \
`BSplineInterpolationDimension` constructor(s)."
end

if A.global_cache isa NURBSWeights
@assert all(==(0), derivative_orders) "Currently partial derivatives of NURBS are not supported."
end
end

function validate_t(t)
Expand All @@ -47,6 +51,26 @@ function validate_size_u(
@assert expected_size==size(u)[1:N_in] "Expected the size of the first N_in dimensions of u to be $expected_size based on the BSplineInterpolation properties."
end

function validate_global_cache(
::TrivialGlobalCache, ::NTuple{N_in, ID}, ::AbstractArray
) where {N_in, ID}
nothing
end

function validate_global_cache(
nurbs_weights::NURBSWeights,
::NTuple{N_in, BSplineInterpolationDimension},
u::AbstractArray
) where {N_in}
size_expected = size(u)[1:N_in]
@assert size(nurbs_weights.weights)==size_expected "The size of the weights array must match the length of the first N_in dimensions of u ($size_expected)."
end

function validate_global_cache(
::gType, ::NTuple{N_in, ID}, ::AbstractArray) where {gType, N_in, ID}
@error("Interpolation dimension type $ID is not compatible with global cache type $gType.")
end

function get_ts(interp_dims::NTuple{
N_in, AbstractInterpolationDimension}) where {N_in}
ntuple(i -> interp_dims[i].t, N_in)
Expand All @@ -56,6 +80,13 @@ function get_output_size(interp::NDInterpolation{N_in}) where {N_in}
size(interp.u)[(N_in + 1):end]
end

make_zero(::T) where {T <: Number} = zero(T)

function make_zero(v::T) where {T <: AbstractArray}
v .= 0
v
end

function make_out(
interp::NDInterpolation{N_in, 0},
t::NTuple{N_in, >:Number}
Expand Down
29 changes: 27 additions & 2 deletions src/spline_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,22 @@ function get_basis_function_values(
multi_point_index[dim_in], :, derivative_order + 1)
end

# Get all basis function values to evaluate a BSpline interpolation in t
function get_basis_function_values_all(
A::NDInterpolation{N_in, N_out, <:BSplineInterpolationDimension},
t::Tuple{Vararg{Number, N_in}},
idx::NTuple{N_in, <:Integer},
derivative_orders::NTuple{N_in, <:Integer},
multi_point_index
) where {N_in, N_out}
ntuple(
dim_in -> get_basis_function_values(
A.interp_dims[dim_in], t[dim_in], idx[dim_in], derivative_orders[dim_in], multi_point_index, dim_in
),
N_in
)
end

function set_basis_function_eval!(itp_dim::BSplineInterpolationDimension)::Nothing
backend = get_backend(itp_dim.t_eval)
basis_function_eval_kernel(backend)(
Expand All @@ -137,8 +153,8 @@ end
i, derivative_order_plus_1 = @index(Global, NTuple)

itp_dim.basis_function_eval[i,
:,
derivative_order_plus_1] .= get_basis_function_values(
:,
derivative_order_plus_1] .= get_basis_function_values(
itp_dim,
itp_dim.t_eval[i],
itp_dim.idx_eval[i],
Expand Down Expand Up @@ -181,3 +197,12 @@ end
function get_n_basis_functions(itp_dim::BSplineInterpolationDimension)
length(itp_dim.knots_all) - itp_dim.degree - 1
end

"""
NURBSWeights(weights::AbstractArray)

Weights associated with the control points to define a NURBS geometry.
"""
struct NURBSWeights{W <: AbstractArray} <: AbstractGlobalCache
weights::W
end
Loading
Loading