diff --git a/Project.toml b/Project.toml index 1e376220..0653ebd9 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ DifferentiationInterface = "0.6, 0.7" ExplicitImports = "1" ForwardDiff = "0.10" IncompleteLU = "0.2" -JET = "0.11.2" +JET = "0.9, 0.10, 0.11.2" Libdl = "1" LinearAlgebra = "1" LinearSolve = "3.40.0" diff --git a/src/common_interface/function_types.jl b/src/common_interface/function_types.jl index e3b86607..40b38e77 100644 --- a/src/common_interface/function_types.jl +++ b/src/common_interface/function_types.jl @@ -14,12 +14,17 @@ mutable struct FunJac{ u::Array{Float64, N} du::Array{Float64, N} resid::TResid + # Cached pointers for allocation-free array wrapping + cached_u_ptr::Ptr{Float64} + cached_du_ptr::Ptr{Float64} + cached_resid_ptr::Ptr{Float64} end function FunJac(fun, jac, p, m, jac_prototype, prec, psetup, u, du) return FunJac( fun, nothing, jac, p, m, jac_prototype, prec, - psetup, u, du, nothing + psetup, u, du, nothing, + Ptr{Float64}(0), Ptr{Float64}(0), Ptr{Float64}(0) ) end function FunJac(fun, jac, p, m, jac_prototype, prec, psetup, u, du, resid) @@ -28,31 +33,50 @@ function FunJac(fun, jac, p, m, jac_prototype, prec, psetup, u, du, resid) jac, p, m, jac_prototype, prec, psetup, u, - du, resid + du, resid, + Ptr{Float64}(0), Ptr{Float64}(0), Ptr{Float64}(0) + ) +end +function FunJac(fun, fun2, jac, p, m, jac_prototype, prec, psetup, u, du, resid) + return FunJac( + fun, fun2, + jac, p, m, + jac_prototype, + prec, psetup, u, + du, resid, + Ptr{Float64}(0), Ptr{Float64}(0), Ptr{Float64}(0) ) end function cvodefunjac(t::Float64, u::N_Vector, du::N_Vector, funjac::FunJac{N}) where {N} - funjac.u = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(u), size(funjac.u)) - funjac.du = unsafe_wrap( - Array{Float64, N}, N_VGetArrayPointer_Serial(du), - size(funjac.du) - ) - _du = funjac.du - _u = funjac.u - funjac.fun(_du, _u, funjac.p, t) + u_ptr = N_VGetArrayPointer_Serial(u) + du_ptr = N_VGetArrayPointer_Serial(du) + # Only create new wrapper if pointer changed (avoids allocation on cache hit) + if u_ptr != funjac.cached_u_ptr + funjac.cached_u_ptr = u_ptr + funjac.u = unsafe_wrap(Array{Float64, N}, u_ptr, size(funjac.u)) + end + if du_ptr != funjac.cached_du_ptr + funjac.cached_du_ptr = du_ptr + funjac.du = unsafe_wrap(Array{Float64, N}, du_ptr, size(funjac.du)) + end + funjac.fun(funjac.du, funjac.u, funjac.p, t) return CV_SUCCESS end function cvodefunjac2(t::Float64, u::N_Vector, du::N_Vector, funjac::FunJac{N}) where {N} - funjac.u = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(u), size(funjac.u)) - funjac.du = unsafe_wrap( - Array{Float64, N}, N_VGetArrayPointer_Serial(du), - size(funjac.du) - ) - _du = funjac.du - _u = funjac.u - funjac.fun2(_du, _u, funjac.p, t) + u_ptr = N_VGetArrayPointer_Serial(u) + du_ptr = N_VGetArrayPointer_Serial(du) + # Only create new wrapper if pointer changed (avoids allocation on cache hit) + if u_ptr != funjac.cached_u_ptr + funjac.cached_u_ptr = u_ptr + funjac.u = unsafe_wrap(Array{Float64, N}, u_ptr, size(funjac.u)) + end + if du_ptr != funjac.cached_du_ptr + funjac.cached_du_ptr = du_ptr + funjac.du = unsafe_wrap(Array{Float64, N}, du_ptr, size(funjac.du)) + end + funjac.fun2(funjac.du, funjac.u, funjac.p, t) return CV_SUCCESS end @@ -98,19 +122,23 @@ function idasolfun( t::Float64, u::N_Vector, du::N_Vector, resid::N_Vector, funjac::FunJac{N} ) where {N} - funjac.u = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(u), size(funjac.u)) - _u = funjac.u - funjac.du = unsafe_wrap( - Array{Float64, N}, N_VGetArrayPointer_Serial(du), - size(funjac.du) - ) - _du = funjac.du - funjac.resid = unsafe_wrap( - Array{Float64, N}, N_VGetArrayPointer_Serial(resid), - size(funjac.resid) - ) - _resid = funjac.resid - funjac.fun(_resid, _du, _u, funjac.p, t) + u_ptr = N_VGetArrayPointer_Serial(u) + du_ptr = N_VGetArrayPointer_Serial(du) + resid_ptr = N_VGetArrayPointer_Serial(resid) + # Only create new wrapper if pointer changed (avoids allocation on cache hit) + if u_ptr != funjac.cached_u_ptr + funjac.cached_u_ptr = u_ptr + funjac.u = unsafe_wrap(Array{Float64, N}, u_ptr, size(funjac.u)) + end + if du_ptr != funjac.cached_du_ptr + funjac.cached_du_ptr = du_ptr + funjac.du = unsafe_wrap(Array{Float64, N}, du_ptr, size(funjac.du)) + end + if resid_ptr != funjac.cached_resid_ptr + funjac.cached_resid_ptr = resid_ptr + funjac.resid = unsafe_wrap(Array{Float64, N}, resid_ptr, size(funjac.resid)) + end + funjac.fun(funjac.resid, funjac.du, funjac.u, funjac.p, t) return IDA_SUCCESS end diff --git a/src/common_interface/integrator_types.jl b/src/common_interface/integrator_types.jl index 3610bd86..03c90378 100644 --- a/src/common_interface/integrator_types.jl +++ b/src/common_interface/integrator_types.jl @@ -149,19 +149,9 @@ mutable struct ARKODEIntegrator{ ctx_handle::ContextHandle end -function ( - integrator::ARKODEIntegrator{ - N, pType, solType, algType, fType, UFType, JType, oType, - LStype, Atype, MLStype, Mtype, CallbackCacheType, ARKStepMem, IA, - } - )( - t::Number, - deriv::Type{Val{T}} = Val{0}; - idxs = nothing - ) where { - N, pType, solType, algType, fType, UFType, JType, oType, - LStype, Atype, MLStype, Mtype, CallbackCacheType, IA, T, - } +# runic: off +# Callable struct syntax - Runic formatting breaks these definitions +function (integrator::ARKODEIntegrator{N, pType, solType, algType, fType, UFType, JType, oType, LStype, Atype, MLStype, Mtype, CallbackCacheType, ARKStepMem, IA})(t::Number, deriv::Type{Val{T}} = Val{0}; idxs = nothing) where {N, pType, solType, algType, fType, UFType, JType, oType, LStype, Atype, MLStype, Mtype, CallbackCacheType, IA, T} out = similar(integrator.u) out_nvec = NVector(vec(out), integrator.ctx_handle.ctx) integrator.flag = @checkflag ARKStepGetDky(integrator.mem, t, Cint(T), out_nvec) @@ -169,19 +159,7 @@ function ( return idxs === nothing ? out : out[idxs] end -function ( - integrator::ARKODEIntegrator{ - N, pType, solType, algType, fType, UFType, JType, oType, - LStype, Atype, MLStype, Mtype, CallbackCacheType, ERKStepMem, IA, - } - )( - t::Number, - deriv::Type{Val{T}} = Val{0}; - idxs = nothing - ) where { - N, pType, solType, algType, fType, UFType, JType, oType, - LStype, Atype, MLStype, Mtype, CallbackCacheType, IA, T, - } +function (integrator::ARKODEIntegrator{N, pType, solType, algType, fType, UFType, JType, oType, LStype, Atype, MLStype, Mtype, CallbackCacheType, ERKStepMem, IA})(t::Number, deriv::Type{Val{T}} = Val{0}; idxs = nothing) where {N, pType, solType, algType, fType, UFType, JType, oType, LStype, Atype, MLStype, Mtype, CallbackCacheType, IA, T} out = similar(integrator.u) out_nvec = NVector(vec(out), integrator.ctx_handle.ctx) integrator.flag = @checkflag ERKStepGetDky(integrator.mem, t, Cint(T), out_nvec) @@ -189,45 +167,20 @@ function ( return idxs === nothing ? out : out[idxs] end -function ( - integrator::ARKODEIntegrator{ - N, pType, solType, algType, fType, UFType, JType, oType, - LStype, Atype, MLStype, Mtype, CallbackCacheType, ARKStepMem, IA, - } - )( - out, - t::Number, - deriv::Type{Val{T}} = Val{0}; - idxs = nothing - ) where { - N, pType, solType, algType, fType, UFType, JType, oType, - LStype, Atype, MLStype, Mtype, CallbackCacheType, IA, T, - } +function (integrator::ARKODEIntegrator{N, pType, solType, algType, fType, UFType, JType, oType, LStype, Atype, MLStype, Mtype, CallbackCacheType, ARKStepMem, IA})(out, t::Number, deriv::Type{Val{T}} = Val{0}; idxs = nothing) where {N, pType, solType, algType, fType, UFType, JType, oType, LStype, Atype, MLStype, Mtype, CallbackCacheType, IA, T} out_nvec = NVector(vec(out), integrator.ctx_handle.ctx) integrator.flag = @checkflag ARKStepGetDky(integrator.mem, t, Cint(T), out_nvec) copyto!(out, out_nvec.v) return idxs === nothing ? out : @view out[idxs] end -function ( - integrator::ARKODEIntegrator{ - N, pType, solType, algType, fType, UFType, JType, oType, - LStype, Atype, MLStype, Mtype, CallbackCacheType, ERKStepMem, IA, - } - )( - out, - t::Number, - deriv::Type{Val{T}} = Val{0}; - idxs = nothing - ) where { - N, pType, solType, algType, fType, UFType, JType, oType, - LStype, Atype, MLStype, Mtype, CallbackCacheType, IA, T, - } +function (integrator::ARKODEIntegrator{N, pType, solType, algType, fType, UFType, JType, oType, LStype, Atype, MLStype, Mtype, CallbackCacheType, ERKStepMem, IA})(out, t::Number, deriv::Type{Val{T}} = Val{0}; idxs = nothing) where {N, pType, solType, algType, fType, UFType, JType, oType, LStype, Atype, MLStype, Mtype, CallbackCacheType, IA, T} out_nvec = NVector(vec(out), integrator.ctx_handle.ctx) integrator.flag = @checkflag ERKStepGetDky(integrator.mem, t, Cint(T), out_nvec) copyto!(out, out_nvec.v) return idxs === nothing ? out : @view out[idxs] end +# runic: on mutable struct IDAIntegrator{ N,