Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,26 @@ version = "0.1.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"

[weakdeps]
LinearMaps = "7a12625a-238d-50fd-b39a-03d52299707e"

[extensions]
NDInterpolationsLinearMapsExt = "LinearMaps"

[compat]
Adapt = "4.3.0"
Aqua = "0.8"
Atomix = "1.1.1"
DataInterpolations = "8"
EllipsisNotation = "1.8.0"
ForwardDiff = "0"
KernelAbstractions = "0.9.34"
LinearMaps = "3.11.4"
Random = "1"
RecipesBase = "1.3.4"
SafeTestsets = "0.1"
Expand Down
69 changes: 69 additions & 0 deletions ext/NDInterpolationsLinearMapsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
module NDInterpolationsLinearMapsExt
using NDInterpolations
using NDInterpolations: validate_derivative_orders, get_output_size
using LinearMaps

# A linear map interp.u -> grid evaluation
function LinearMaps.LinearMap(
interp::NDInterpolation{N_in}, ::Val{:grid};
derivative_orders::NTuple{N_in, <:Integer} = ntuple(_ -> 0, N_in)
) where {N_in}
validate_derivative_orders(derivative_orders, interp)

T = NDInterpolations.output_type(interp)

grid_size = map(itp_dim -> length(itp_dim.t_eval), interp.interp_dims)
size_out = (grid_size..., get_output_size(interp)...)
N_input = length(interp.u)
N_output = prod(size_out)

function map!(out_flat, u_flat)
u_reshaped = reshape(u_flat, size(interp.u))
out_reshaped = reshape(out_flat, size_out)

interp_ = NDInterpolation(u_reshaped, interp.interp_dims, interp.cache)
eval_grid!(out_reshaped, interp_; derivative_orders)
return out_flat
end

function map_adjoint!(u_flat, out_flat)
u_reshaped = reshape(u_flat, size(interp.u))
out_reshaped = reshape(out_flat, size_out)

interp_ = NDInterpolation(u_reshaped, interp.interp_dims, interp.cache)
eval_grid!(out_reshaped, interp_; derivative_orders, adjoint = true)
return u_flat
end

return FunctionMap{T}(map!, map_adjoint!, N_output, N_input)
end

# A linear map interp.u -> unstructured evaluation
function LinearMaps.LinearMap(
interp::NDInterpolation{N_in}, ::Val{:unstructured};
derivative_orders::NTuple{N_in, <:Integer} = ntuple(_ -> 0, N_in)
) where {N_in}
validate_derivative_orders(derivative_orders, interp)

T = NDInterpolations.output_type(interp)

N_input = length(interp.u)
size_out = (length(first(interp.interp_dims).t_eval), get_output_size(interp)...)
N_output = prod(size_out)

function map!(out_flat, u_flat)
u_reshaped = reshape(u_flat, size(interp.u))
out_reshaped = reshape(out_flat, size_out)

interp_ = NDInterpolation(u_reshaped, interp.interp_dims, interp.cache)
eval_unstructured!(out_reshaped, interp_; derivative_orders)
return out_flat
end

# function map_adjoint!()
# end

return FunctionMap{T}(map!, N_output, N_input)
end

end # module NDInterpolationsLinearMapsExt
4 changes: 3 additions & 1 deletion src/NDInterpolations.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module NDInterpolations
using KernelAbstractions # Keep as dependency or make extension?
import Atomix
using KernelAbstractions
using Adapt: @adapt_structure
using EllipsisNotation
using RecipesBase
Expand Down Expand Up @@ -56,6 +57,7 @@ include("interpolation_dimensions.jl")
include("spline_utils.jl")
include("interpolation_utils.jl")
include("interpolation_methods.jl")
include("interpolation_methods_adjoint.jl")
include("interpolation_parallel.jl")
include("plot_rec.jl")

Expand Down
39 changes: 39 additions & 0 deletions src/interpolation_methods_adjoint.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
function _interpolate_adjoint!(
A::NDInterpolation{N_in, N_out, ID},
out,
t::Tuple{Vararg{Number, N_in}},
idx::NTuple{N_in, <:Integer},
derivative_orders::NTuple{N_in, <:Integer},
multi_point_index
)::Nothing where {N_in, N_out, ID <: LinearInterpolationDimension}
any(>(1), derivative_orders) && return out

tᵢ = ntuple(i -> A.interp_dims[i].t[idx[i]], N_in)
tᵢ₊₁ = ntuple(i -> A.interp_dims[i].t[idx[i] + 1], N_in)

# Size of the (hyper)rectangle `t` is in
t_vol = one(eltype(tᵢ))
for (t₁, t₂) in zip(tᵢ, tᵢ₊₁)
t_vol *= t₂ - t₁
end

# Loop over the corners of the (hyper)rectangle `t` is in
for I in Iterators.product(ntuple(i -> (false, true), N_in)...)
c = eltype(out)(inv(t_vol))
for (t_, right_point, d, t₁, t₂) in zip(t, I, derivative_orders, tᵢ, tᵢ₊₁)
c *= if right_point
iszero(d) ? t_ - t₁ : one(t_)
else
iszero(d) ? t₂ - t_ : -one(t_)
end
end
J = (ntuple(i -> idx[i] + I[i], N_in)..., ..)

if iszero(N_out)
Atomix.@atomic A.u[J...] += c * out
else
Atomix.@atomic A.u[J...] += c * out
end
end
return nothing
end
71 changes: 52 additions & 19 deletions src/interpolation_parallel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,34 +85,37 @@ in place.
function eval_grid!(
out::AbstractArray,
interp::NDInterpolation{N_in};
derivative_orders::NTuple{N_in, <:Integer} = ntuple(_ -> 0, N_in)
derivative_orders::NTuple{N_in, <:Integer} = ntuple(_ -> 0, N_in),
adjoint::Bool = false
) where {N_in}
validate_derivative_orders(derivative_orders, interp; multi_point = true)
backend = get_backend(out)
@assert all(i -> size(out, i) == length(interp.interp_dims[i].t_eval), N_in) "For the first N_in dimensions of out the length must match the t_eval of the corresponding interpolation dimension."
@assert size(out)[(N_in + 1):end]==get_output_size(interp) "The size of the last N_out dimensions of out must be the same as the output size of the interpolation."
eval_kernel(backend)(
out,
interp,
derivative_orders,
true,
ndrange = size(out)[1:N_in]
)
if adjoint
eval_kernel_adjoint(backend)(
interp,
out,
derivative_orders,
true,
ndrange = size(out)[1:N_in]
)
else
eval_kernel(backend)(
out,
interp,
derivative_orders,
true,
ndrange = size(out)[1:N_in]
)
end
synchronize(backend)
return out
end

@kernel function eval_kernel(
out,
@Const(A),
derivative_orders,
eval_grid
)
N_in = length(A.interp_dims)
N_out = ndims(A.u) - N_in

k = @index(Global, NTuple)

function get_eval_params(
A::NDInterpolation{N_in, N_out}, eval_grid::Bool, k::NTuple{N_in, Int}
) where {N_in, N_out}
if eval_grid
t_eval = ntuple(i -> A.interp_dims[i].t_eval[k[i]], N_in)
idx_eval = ntuple(i -> A.interp_dims[i].idx_eval[k[i]], N_in)
Expand All @@ -121,6 +124,18 @@ end
idx_eval = ntuple(i -> A.interp_dims[i].idx_eval[only(k)], N_in)
end

return N_out, t_eval, idx_eval
end

@kernel function eval_kernel(
out,
@Const(A),
derivative_orders,
eval_grid
)
k = @index(Global, NTuple)
N_out, t_eval, idx_eval = get_eval_params(A, eval_grid, k)

if iszero(N_out)
out[k...] = _interpolate!(
make_out(A, t_eval), A, t_eval, idx_eval, derivative_orders, k)
Expand All @@ -130,3 +145,21 @@ end
A, t_eval, idx_eval, derivative_orders, k)
end
end

@kernel function eval_kernel_adjoint(
A,
@Const(out),
derivative_orders,
eval_grid
)
k = @index(Global, NTuple)
N_out, t_eval, idx_eval = get_eval_params(A, eval_grid, k)

if iszero(N_out)
_interpolate_adjoint!(
A, out[k...], t_eval, idx_eval, derivative_orders, k)
else
_interpolate_adjoint!(
A, view(out, k..., ..), t_eval, idx_eval, derivative_orders, k)
end
end
16 changes: 13 additions & 3 deletions src/interpolation_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,19 +87,29 @@ function make_zero!!(v::T) where {T <: AbstractArray}
v
end

function output_type(interp::NDInterpolation)
promote_type(eltype(interp.u), map(itp_dim -> eltype(itp_dim.t), interp.interp_dims)...)
end

function output_type(
interp::NDInterpolation{N_in},
t::NTuple{N_in, >:Number}
) where {N_in}
promote_type(output_type(interp), map(typeof, t)...)
end

function make_out(
interp::NDInterpolation{N_in, 0},
t::NTuple{N_in, >:Number}
) where {N_in}
zero(promote_type(eltype(interp.u), map(typeof, t)...))
zero(output_type(interp, t))
end

function make_out(
interp::NDInterpolation{N_in},
t::NTuple{N_in, >:Number}
) where {N_in}
similar(
interp.u, promote_type(eltype(interp.u), map(eltype, t)...), get_output_size(interp))
similar(interp.u, ouput_type(interp, t))
end

get_left(::AbstractInterpolationDimension) = false
Expand Down
Loading