diff --git a/stan/math/prim/err/elementwise_check.hpp b/stan/math/prim/err/elementwise_check.hpp index de2e192d41d..4c4386c1f06 100644 --- a/stan/math/prim/err/elementwise_check.hpp +++ b/stan/math/prim/err/elementwise_check.hpp @@ -124,6 +124,13 @@ inline void elementwise_check(const F& is_good, const char* function, }(); } } +template * = nullptr> +inline void elementwise_check(const F& is_good, const char* function, + const char* name, const T& x, const char* must_be, + const Indexings&... indexings) { + // XXX skip closures +} /** * Check that the predicate holds for all elements of the value of `x`. This * overload works on Eigen types that support linear indexing. diff --git a/stan/math/prim/fun/value_of.hpp b/stan/math/prim/fun/value_of.hpp index 7cc37ac9b2c..f575d032fcd 100644 --- a/stan/math/prim/fun/value_of.hpp +++ b/stan/math/prim/fun/value_of.hpp @@ -3,6 +3,8 @@ #include #include +#include +#include #include #include @@ -77,6 +79,23 @@ inline auto value_of(EigMat&& M) { std::forward(M)); } +/** + * Closures that capture non-arithmetic types have value_of__() method. + * + * @tparam F Input element type + * @param[in] f Input closure + * @return closure + **/ +template * = nullptr, + require_not_st_arithmetic* = nullptr> +inline auto value_of(const F& f) { + return apply( + [&f](const auto&... s) { + return typename F::partials_closure_t_(f.f_, eval(value_of(s))...); + }, + f.captures_); +} + } // namespace math } // namespace stan diff --git a/stan/math/prim/functor.hpp b/stan/math/prim/functor.hpp index 0ec5c343ff7..01bb580684a 100644 --- a/stan/math/prim/functor.hpp +++ b/stan/math/prim/functor.hpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/prim/functor/closure_adapter.hpp b/stan/math/prim/functor/closure_adapter.hpp new file mode 100644 index 00000000000..a708d27c16c --- /dev/null +++ b/stan/math/prim/functor/closure_adapter.hpp @@ -0,0 +1,220 @@ +#ifndef STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP +#define STAN_MATH_PRIM_FUNCTOR_CLOSURE_ADAPTER_HPP + +#include +#include +#include + +namespace stan { +namespace math { +namespace internal { + +/** + * A closure that wraps a C++ lambda and captures values. + * + * @tparam Ref if true values are captured by reference + * @tparam F the lambda functor type + * @tparam Ts types of the captured values + */ +template +struct base_closure { + using return_scalar_t_ = return_type_t; + /*The base closure with `Ts` as the non-expression partials of `Ts`*/ + using partials_closure_t_ + = base_closure())))...>; + using Base_ = base_closure; + std::decay_t f_; + std::tuple...> captures_; + template * = nullptr, typename... Args> + explicit base_closure(FF&& f, Args&&... args) + : f_(std::forward(f)), captures_(std::forward(args)...) {} + + template + auto operator()(std::ostream* msgs, const Args&... args) const { + return apply( + [this, msgs, &args...](const auto&... s) { + return this->f_(s..., args..., msgs); + }, + captures_); + } +}; + +/** + * A closure that takes rng argument. + * + * @tparam Ref if true values are captured by reference + * @tparam F the lambda functor type + * @tparam Ts types of the captured values + */ +template +struct closure_rng { + using return_scalar_t_ = double; + using partials_closure_t_ = closure_rng; + using Base_ = closure_rng; + std::decay_t f_; + std::tuple...> captures_; + + template * = nullptr, typename... Args> + explicit closure_rng(FF&& f, Args&&... args) + : f_(std::forward(f)), captures_(std::forward(args)...) {} + + template + auto operator()(Rng& rng, std::ostream* msgs, const Args&... args) const { + return apply( + [this, &rng, msgs, &args...](const auto&... s) { + return this->f_(s..., args..., rng, msgs); + }, + captures_); + } +}; + +/** + * A closure that may compute an unnormalized propability density. + * + * @tparam Propto if true the function is unnormalized + * @tparam Ref if true values are captured by reference + * @tparam F the lambda functor type + * @tparam Ts types of the captured values + */ +template +struct closure_lpdf { + using return_scalar_t_ = return_type_t; + using partials_closure_t_ = closure_lpdf; + using Base_ = closure_lpdf; + std::decay_t f_; + std::tuple...> captures_; + + template * = nullptr, typename... Args> + explicit closure_lpdf(FF&& f, Args&&... args) + : f_(std::forward(f)), captures_(std::forward(args)...) {} + + template + auto with_propto() { + return apply( + [this](const auto&... args) { + return closure_lpdf < Propto && propto, true, F, + Ts... > (this->f_, args...); + }, + captures_); + } + + template + auto operator()(std::ostream* msgs, const Args&... args) const { + return apply( + [this, msgs, &args...](const auto&... s) { + return this->f_.template operator()(s..., args..., msgs); + }, + captures_); + } +}; + +/** + * A closure that accesses logprob accumulator. + * + * @tparam Propto if true the logprob is unnormalized + * @tparam Ref if true values are captured by reference + * @tparam F the lambda functor type + * @tparam Ts types of the captured values + */ +template +struct closure_lp { + using return_scalar_t_ = return_type_t; + using partials_closure_t_ = closure_lp; + using Base_ = closure_lp; + std::decay_t f_; + std::tuple...> captures_; + + template * = nullptr, typename... Args> + explicit closure_lp(FF&& f, Args&&... args) + : f_(std::forward(f)), captures_(std::forward(args)...) {} + + template + auto operator()(T_lp& lp, T_lp_accum& lp_accum, std::ostream* msgs, + const Args&... args) const { + return apply( + [this, &lp, &lp_accum, msgs, &args...](const auto&... s) { + return this->f_.template operator()(s..., args..., lp, + lp_accum, msgs); + }, + captures_); + } +}; + +} // namespace internal + +/** + * Higher-order functor suitable for calling a closure inside variadic ODE + * solvers. + */ +struct ode_closure_adapter { + template + auto operator()(const T0& t, const T1& y, std::ostream* msgs, F&& f, + Args&&... args) const { + return std::forward(f)(msgs, t, y, std::forward(args)...); + } +}; + +struct integrate_ode_closure_adapter { + template + auto operator()(const T0& t, const T1& y, std::ostream* msgs, F&& f, + Args&&... args) const { + return to_vector(std::forward(f)(msgs, t, to_array_1d(y), + std::forward(args)...)); + } +}; + +/** + * Create a closure from a C++ lambda and captures. + */ +template +auto from_lambda(F&& f, Args&&... args) { + return internal::base_closure(std::forward(f), + std::forward(args)...); +} + +/** + * Create a closure from an rng functor. + */ +template +auto rng_from_lambda(F&& f, Args&&... args) { + return internal::closure_rng(std::forward(f), + std::forward(args)...); +} + +/** + * Create a closure from an lpdf functor. + */ +template +auto lpdf_from_lambda(F&& f, Args&&... args) { + return internal::closure_lpdf( + std::forward(f), std::forward(args)...); +} + +/** + * Create a closure from a functor that needs access to logprob accumulator. + */ +template +auto lp_from_lambda(F&& f, Args&&... args) { + return internal::closure_lp( + std::forward(f), std::forward(args)...); +} + +/** + * Higher-order functor that invokes a closure inside a reduce_sum call. + */ +struct reduce_sum_closure_adapter { + template + auto operator()(const std::vector& sub_slice, std::size_t start, + std::size_t end, std::ostream* msgs, F&& f, + Args&&... args) const { + return std::forward(f)(msgs, sub_slice, start + error_index::value, + end + error_index::value, + std::forward(args)...); + } +}; + +} // namespace math +} // namespace stan + +#endif diff --git a/stan/math/prim/functor/integrate_1d.hpp b/stan/math/prim/functor/integrate_1d.hpp index e4494c81ed8..78288a89f17 100644 --- a/stan/math/prim/functor/integrate_1d.hpp +++ b/stan/math/prim/functor/integrate_1d.hpp @@ -236,7 +236,7 @@ inline double integrate_1d_impl(const F& f, double a, double b, * @param relative_tolerance tolerance passed to Boost quadrature * @return numeric integral of function f */ -template +template * = nullptr> inline double integrate_1d(const F& f, double a, double b, const std::vector& theta, const std::vector& x_r, @@ -247,6 +247,18 @@ inline double integrate_1d(const F& f, double a, double b, msgs, theta, x_r, x_i); } +template * = nullptr, + require_arithmetic_t>* = nullptr> +inline double integrate_1d(const F& f, double a, double b, + const std::vector& theta, + const std::vector& x_r, + const std::vector& x_i, std::ostream* msgs, + const double relative_tolerance + = std::sqrt(EPSILON)) { + return integrate_1d_impl(integrate_1d_closure_adapter(), a, b, + relative_tolerance, msgs, f, theta, x_r, x_i); +} + } // namespace math } // namespace stan diff --git a/stan/math/prim/functor/integrate_1d_adapter.hpp b/stan/math/prim/functor/integrate_1d_adapter.hpp index ecfcaaaa9a6..1bf065f809f 100644 --- a/stan/math/prim/functor/integrate_1d_adapter.hpp +++ b/stan/math/prim/functor/integrate_1d_adapter.hpp @@ -25,4 +25,19 @@ struct integrate_1d_adapter { } }; +/** + * Call a closure object from integrate_1d + */ +struct integrate_1d_closure_adapter { + integrate_1d_closure_adapter() {} + + template + auto operator()(const T_a& x, const T_b& xc, std::ostream* msgs, const F& f, + const std::vector& theta, + const std::vector& x_r, + const std::vector& x_i) const { + return f(msgs, x, xc, theta, x_r, x_i); + } +}; + #endif diff --git a/stan/math/prim/functor/integrate_ode_rk45.hpp b/stan/math/prim/functor/integrate_ode_rk45.hpp index 1e568d6b1ea..47fac5b4377 100644 --- a/stan/math/prim/functor/integrate_ode_rk45.hpp +++ b/stan/math/prim/functor/integrate_ode_rk45.hpp @@ -2,6 +2,7 @@ #define STAN_MATH_PRIM_FUNCTOR_INTEGRATE_ODE_RK45_HPP #include +#include #include #include #include @@ -10,6 +11,38 @@ namespace stan { namespace math { +namespace internal { + +template * = nullptr> +inline auto integrate_ode_rk45_impl( + const F& f, const std::vector& y0, const T_t0& t0, + const std::vector& ts, const std::vector& theta, + const std::vector& x, const std::vector& x_int, + std::ostream* msgs, double relative_tolerance, double absolute_tolerance, + int max_num_steps) { + internal::integrate_ode_std_vector_interface_adapter f_adapted(f); + return ode_rk45_tol_impl("integrate_ode_rk45", f_adapted, to_vector(y0), t0, + ts, relative_tolerance, absolute_tolerance, + max_num_steps, msgs, theta, x, x_int); +} + +template * = nullptr> +inline auto integrate_ode_rk45_impl( + const F& f, const std::vector& y0, const T_t0& t0, + const std::vector& ts, const std::vector& theta, + const std::vector& x, const std::vector& x_int, + std::ostream* msgs, double relative_tolerance, double absolute_tolerance, + int max_num_steps) { + return ode_rk45_tol_impl("integrate_ode_rk45", + integrate_ode_closure_adapter(), to_vector(y0), t0, + ts, relative_tolerance, absolute_tolerance, + max_num_steps, msgs, f, theta, x, x_int); +} + +} // namespace internal + /** * @deprecated use ode_rk45 */ @@ -21,12 +54,11 @@ inline auto integrate_ode_rk45( const std::vector& x, const std::vector& x_int, std::ostream* msgs = nullptr, double relative_tolerance = 1e-6, double absolute_tolerance = 1e-6, int max_num_steps = 1e6) { - internal::integrate_ode_std_vector_interface_adapter f_adapted(f); - auto y = ode_rk45_tol_impl("integrate_ode_rk45", f_adapted, to_vector(y0), t0, - ts, relative_tolerance, absolute_tolerance, - max_num_steps, msgs, theta, x, x_int); + auto y = internal::integrate_ode_rk45_impl(f, y0, t0, ts, theta, x, x_int, + msgs, relative_tolerance, + absolute_tolerance, max_num_steps); - std::vector>> + std::vector>> y_converted; y_converted.reserve(y.size()); for (size_t i = 0; i < y.size(); ++i) diff --git a/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp b/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp index 9b0178e27b5..b61d9a9bfa6 100644 --- a/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp +++ b/stan/math/prim/functor/integrate_ode_std_vector_interface_adapter.hpp @@ -1,6 +1,7 @@ #ifndef STAN_MATH_PRIM_FUNCTOR_INTEGRATE_ODE_STD_VECTOR_INTERFACE_ADAPTER_HPP #define STAN_MATH_PRIM_FUNCTOR_INTEGRATE_ODE_STD_VECTOR_INTERFACE_ADAPTER_HPP +#include #include #include #include @@ -21,8 +22,7 @@ namespace internal { */ template struct integrate_ode_std_vector_interface_adapter { - const F f_; - + const F& f_; explicit integrate_ode_std_vector_interface_adapter(const F& f) : f_(f) {} template diff --git a/stan/math/prim/functor/ode_ckrk.hpp b/stan/math/prim/functor/ode_ckrk.hpp index 3f8fd0a0972..3693aecefa2 100644 --- a/stan/math/prim/functor/ode_ckrk.hpp +++ b/stan/math/prim/functor/ode_ckrk.hpp @@ -52,7 +52,8 @@ namespace math { * @return Solution to ODE at times \p ts */ template * = nullptr> + typename... Args, require_eigen_vector_t* = nullptr, + require_not_stan_closure_t* = nullptr> std::vector, Eigen::Dynamic, 1>> ode_ckrk_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg, @@ -194,7 +195,7 @@ ode_ckrk_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg, */ template * = nullptr> -std::vector, +std::vector, Eigen::Dynamic, 1>> ode_ckrk_tol(const F& f, const T_y0& y0_arg, T_t0 t0, const std::vector& ts, double relative_tolerance, @@ -240,7 +241,7 @@ ode_ckrk_tol(const F& f, const T_y0& y0_arg, T_t0 t0, */ template * = nullptr> -std::vector, +std::vector, Eigen::Dynamic, 1>> ode_ckrk(const F& f, const T_y0& y0, T_t0 t0, const std::vector& ts, std::ostream* msgs, const Args&... args) { diff --git a/stan/math/prim/functor/ode_rk45.hpp b/stan/math/prim/functor/ode_rk45.hpp index 8a90bd1e2c7..521e66cface 100644 --- a/stan/math/prim/functor/ode_rk45.hpp +++ b/stan/math/prim/functor/ode_rk45.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -53,7 +54,8 @@ namespace math { * @return Solution to ODE at times \p ts */ template * = nullptr> + typename... Args, require_eigen_vector_t* = nullptr, + require_not_stan_closure_t* = nullptr> std::vector, Eigen::Dynamic, 1>> ode_rk45_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg, @@ -196,7 +198,7 @@ ode_rk45_tol_impl(const char* function_name, const F& f, const T_y0& y0_arg, */ template * = nullptr> -std::vector, +std::vector, Eigen::Dynamic, 1>> ode_rk45_tol(const F& f, const T_y0& y0_arg, T_t0 t0, const std::vector& ts, double relative_tolerance, @@ -242,7 +244,7 @@ ode_rk45_tol(const F& f, const T_y0& y0_arg, T_t0 t0, */ template * = nullptr> -std::vector, +std::vector, Eigen::Dynamic, 1>> ode_rk45(const F& f, const T_y0& y0, T_t0 t0, const std::vector& ts, std::ostream* msgs, const Args&... args) { diff --git a/stan/math/prim/meta.hpp b/stan/math/prim/meta.hpp index 2fa62a1e55e..564b2b8c56c 100644 --- a/stan/math/prim/meta.hpp +++ b/stan/math/prim/meta.hpp @@ -214,6 +214,7 @@ #include #include #include +#include #include #include #include diff --git a/stan/math/prim/meta/is_stan_closure.hpp b/stan/math/prim/meta/is_stan_closure.hpp new file mode 100644 index 00000000000..3ff256c3de0 --- /dev/null +++ b/stan/math/prim/meta/is_stan_closure.hpp @@ -0,0 +1,60 @@ +#ifndef STAN_MATH_PRIM_META_IS_STAN_CLOSURE_HPP +#define STAN_MATH_PRIM_META_IS_STAN_CLOSURE_HPP + +#include +#include +#include + +#include + +namespace stan { + +/** + * Checks if type is a closure object. + * @tparam The type to check + * @ingroup type_trait + */ +template +struct is_stan_closure : std::false_type {}; + +template +struct is_stan_closure::return_scalar_t_>> + : std::true_type {}; + +STAN_ADD_REQUIRE_UNARY(stan_closure, is_stan_closure, general_types); + +template +struct scalar_type> { + using type = typename std::decay_t::return_scalar_t_; +}; + +template +struct closure_return_type; + +template +struct closure_return_type { + using type = const std::decay_t&; +}; + +template +struct closure_return_type> { + using type = std::remove_reference_t; +}; + +template +struct closure_return_type> { + using type = typename std::remove_reference_t::Base_; +}; + +/** + * Type for things captured either by const reference or by copy. + * + * @tparam T type of object being captured + * @tparam Ref true if reference, false if copy + */ +template +using closure_return_type_t = typename closure_return_type::type; + +} // namespace stan + +#endif diff --git a/stan/math/prim/meta/promote_scalar_type.hpp b/stan/math/prim/meta/promote_scalar_type.hpp index 4c03569903a..a1574d81149 100644 --- a/stan/math/prim/meta/promote_scalar_type.hpp +++ b/stan/math/prim/meta/promote_scalar_type.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include namespace stan { @@ -93,6 +94,23 @@ struct promote_scalar_type> { S::RowsAtCompileTime, S::ColsAtCompileTime>>::type; }; +/** + * Template metaprogram to calculate a type for a closure whose + * underlying scalar is converted from the second template + * parameter type to the first. + * + * @tparam T result scalar type. + * @tparam S input closure type + */ +template +struct promote_scalar_type> { + /** + * The promoted type. + */ + using type = typename std::conditional::value, F, + typename F::partials_closure_t_>::type; +}; + template using promote_scalar_t = typename promote_scalar_type, std::decay_t>::type; diff --git a/stan/math/prim/meta/return_type.hpp b/stan/math/prim/meta/return_type.hpp index 79ef5816a9a..0f41544ef26 100644 --- a/stan/math/prim/meta/return_type.hpp +++ b/stan/math/prim/meta/return_type.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -116,7 +117,12 @@ using row_vector_return_t = Eigen::Matrix, 1, -1>; */ template struct scalar_lub { - using type = promote_args_t; + using type = std::conditional_t< + is_stan_scalar::value && is_stan_scalar::value, + promote_args_t, + std::conditional_t< + is_stan_scalar::value, T1, + std::conditional_t::value, T2, double>>>; }; template @@ -190,8 +196,8 @@ struct return_type { template struct return_type { - using type - = scalar_lub_t, typename return_type::type>; + using type = scalar_lub_t, + typename return_type...>::type>; }; /** diff --git a/stan/math/rev/core/accumulate_adjoints.hpp b/stan/math/rev/core/accumulate_adjoints.hpp index e5b27354ebd..ddcd9ae4d00 100644 --- a/stan/math/rev/core/accumulate_adjoints.hpp +++ b/stan/math/rev/core/accumulate_adjoints.hpp @@ -2,6 +2,7 @@ #define STAN_MATH_REV_CORE_ACCUMULATE_ADJOINTS_HPP #include +#include #include #include @@ -29,6 +30,10 @@ template * = nullptr, typename... Pargs> inline double* accumulate_adjoints(double* dest, EigT&& x, Pargs&&... args); +template * = nullptr, + require_not_st_arithmetic* = nullptr, typename... Pargs> +inline double* accumulate_adjoints(double* dest, F& f, Pargs&&... args); + template * = nullptr, typename... Pargs> inline double* accumulate_adjoints(double* dest, Arith&& x, Pargs&&... args); @@ -121,6 +126,29 @@ inline double* accumulate_adjoints(double* dest, EigT&& x, Pargs&&... args) { return accumulate_adjoints(dest + x.size(), std::forward(args)...); } +/** + * Accumulate adjoints from f (a closure type containing vars) + * into storage pointed to by dest, + * increment the adjoint storage pointer, + * recursively accumulate the adjoints of the rest of the arguments, + * and return final position of storage pointer. + * + * @tparam F A closure type capturing vars. + * @tparam Pargs Types of remaining arguments + * @param dest Pointer to where adjoints are to be accumulated + * @param f A closure holding vars to accumulate over + * @param args Further args to accumulate over + * @return Final position of adjoint storage pointer + */ +template *, require_not_st_arithmetic*, + typename... Pargs> +inline double* accumulate_adjoints(double* dest, F& f, Pargs&&... args) { + return accumulate_adjoints( + apply([dest](auto... s) { return accumulate_adjoints(dest, s...); }, + f.captures_), + std::forward(args)...); +} + /** * Ignore arithmetic types. * diff --git a/stan/math/rev/core/count_vars.hpp b/stan/math/rev/core/count_vars.hpp index b0b536a27ab..466d8609214 100644 --- a/stan/math/rev/core/count_vars.hpp +++ b/stan/math/rev/core/count_vars.hpp @@ -2,6 +2,7 @@ #define STAN_MATH_REV_CORE_COUNT_VARS_HPP #include +#include #include #include @@ -29,6 +30,10 @@ inline size_t count_vars_impl(size_t count, EigT&& x, Pargs&&... args); template inline size_t count_vars_impl(size_t count, const var& x, Pargs&&... args); +template * = nullptr, + require_not_st_arithmetic* = nullptr, typename... Pargs> +inline size_t count_vars_impl(size_t count, const F& f, Pargs&&... args); + template >* = nullptr, typename... Pargs> inline size_t count_vars_impl(size_t count, Arith& x, Pargs&&... args); @@ -110,6 +115,28 @@ inline size_t count_vars_impl(size_t count, const var& x, Pargs&&... args) { return count_vars_impl(count + 1, std::forward(args)...); } +/** + * Count the number of vars in f (a closure capturing vars), + * add it to the running total, + * count the number of vars in the remaining arguments + * and return the result. + * + * @tparam F A closure type + * @tparam Pargs Types of remaining arguments + * @param[in] count The current count of the number of vars + * @param[in] f A closure holding vars + * @param[in] args objects to be forwarded to recursive call of + * `count_vars_impl` + */ +template *, require_not_st_arithmetic*, + typename... Pargs> +inline size_t count_vars_impl(size_t count, const F& f, Pargs&&... args) { + return count_vars_impl( + apply([count](auto... s) { return count_vars_impl(count, s...); }, + f.captures_), + std::forward(args)...); +} + /** * Arguments without vars contribute zero to the total number of vars. * diff --git a/stan/math/rev/core/deep_copy_vars.hpp b/stan/math/rev/core/deep_copy_vars.hpp index 06561d1a9e0..b8c3dd077e8 100644 --- a/stan/math/rev/core/deep_copy_vars.hpp +++ b/stan/math/rev/core/deep_copy_vars.hpp @@ -2,6 +2,8 @@ #define STAN_MATH_REV_CORE_DEEP_COPY_VARS_HPP #include +#include +#include #include #include @@ -19,7 +21,8 @@ namespace math { * @param arg For lvalue references this will be passed by reference. * Otherwise it will be moved. */ -template >> +template >, + typename = require_not_stan_closure_t> inline Arith deep_copy_vars(Arith&& arg) { return std::forward(arg); } @@ -81,6 +84,22 @@ inline auto deep_copy_vars(EigT&& arg) { .eval(); } +/** + * Copy the vars in f but reallocate new varis for them + * + * @tparam F A closure type + * @param f A closure containing vars + * @return A new closure containing vars + */ +template * = nullptr> +inline auto deep_copy_vars(const F& f) { + return apply( + [&f](const auto&... s) { + return typename F::Base_(f.f_, eval(deep_copy_vars(s))...); + }, + f.captures_); +} + } // namespace math } // namespace stan diff --git a/stan/math/rev/core/save_varis.hpp b/stan/math/rev/core/save_varis.hpp index c53a5390539..37d246d81dc 100644 --- a/stan/math/rev/core/save_varis.hpp +++ b/stan/math/rev/core/save_varis.hpp @@ -1,6 +1,7 @@ #ifndef STAN_MATH_REV_CORE_SAVE_VARIS_HPP #define STAN_MATH_REV_CORE_SAVE_VARIS_HPP +#include #include #include #include @@ -29,6 +30,10 @@ template * = nullptr, typename... Pargs> inline vari** save_varis(vari** dest, EigT&& x, Pargs&&... args); +template * = nullptr, + require_not_st_arithmetic* = nullptr, typename... Pargs> +inline vari** save_varis(vari** dest, F& f, Pargs&&... args); + template * = nullptr, typename... Pargs> inline vari** save_varis(vari** dest, Arith&& x, Pargs&&... args); @@ -118,6 +123,27 @@ inline vari** save_varis(vari** dest, EigT&& x, Pargs&&... args) { return save_varis(dest + x.size(), std::forward(args)...); } +/** + * Save the vari pointers in f into the memory pointed to by dest, + * increment the dest storage pointer, + * recursively call save_varis on the rest of the arguments, + * and return the final value of the dest storage pointer. + * + * @tparam F A closure type with var value type + * @tparam Pargs Types of remaining arguments + * @param[in, out] dest Pointer to where vari pointers are saved + * @param[in] f A closure capturing vars + * @param[in] args Additional arguments to have their varis saved + * @return Final position of dest pointer + */ +template *, require_not_st_arithmetic*, + typename... Pargs> +inline vari** save_varis(vari** dest, F& f, Pargs&&... args) { + return save_varis( + apply([dest](auto... s) { return save_varis(dest, s...); }, f.captures_), + std::forward(args)...); +} + /** * Ignore arithmetic types. * diff --git a/stan/math/rev/core/zero_adjoints.hpp b/stan/math/rev/core/zero_adjoints.hpp index a05362fc79d..82e49cf3d62 100644 --- a/stan/math/rev/core/zero_adjoints.hpp +++ b/stan/math/rev/core/zero_adjoints.hpp @@ -3,6 +3,7 @@ #include #include +#include #include namespace stan { @@ -54,6 +55,23 @@ inline void zero_adjoints(EigMat& x) { x.coeffRef(i).adj() = 0; } +/** + * Zero the adjoints of the varis of every var in a closure. + * Recursively call zero_adjoints on the rest of the arguments. + * + * @tparam F type of current argument + * @tparam Pargs type of rest of arguments + * + * @param f current argument + * @param args rest of arguments to zero + */ +template * = nullptr, + require_not_st_arithmetic* = nullptr> +inline void zero_adjoints(F& f, Pargs&... args) { + apply([](auto... s) { zero_adjoints(s...); }, f.captures_); + zero_adjoints(args...); +} + /** * Zero the adjoints of every element in a vector. Recursively call * zero_adjoints on the rest of the arguments. diff --git a/stan/math/rev/functor/integrate_1d.hpp b/stan/math/rev/functor/integrate_1d.hpp index be8ff106393..73dc3437b1d 100644 --- a/stan/math/rev/functor/integrate_1d.hpp +++ b/stan/math/rev/functor/integrate_1d.hpp @@ -211,6 +211,7 @@ inline return_type_t integrate_1d_impl( * @return numeric integral of function f */ template , typename = require_any_var_t> inline return_type_t integrate_1d( const F &f, const T_a &a, const T_b &b, const std::vector &theta, @@ -220,6 +221,17 @@ inline return_type_t integrate_1d( msgs, theta, x_r, x_i); } +template , + typename = require_any_var_t, T_a, T_b, T_theta>> +inline return_type_t integrate_1d( + const F &f, const T_a &a, const T_b &b, const std::vector &theta, + const std::vector &x_r, const std::vector &x_i, + std::ostream *msgs, const double relative_tolerance = std::sqrt(EPSILON)) { + return integrate_1d_impl(integrate_1d_closure_adapter(), a, b, + relative_tolerance, msgs, f, theta, x_r, x_i); +} + } // namespace math } // namespace stan diff --git a/stan/math/rev/functor/integrate_ode_adams.hpp b/stan/math/rev/functor/integrate_ode_adams.hpp index 0cba70a321e..1f2fee91cb0 100644 --- a/stan/math/rev/functor/integrate_ode_adams.hpp +++ b/stan/math/rev/functor/integrate_ode_adams.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -10,12 +11,48 @@ namespace stan { namespace math { +namespace internal { + +template * = nullptr> +auto integrate_ode_adams_impl(const F& f, const std::vector& y0, + const T_t0& t0, const std::vector& ts, + const std::vector& theta, + const std::vector& x, + const std::vector& x_int, std::ostream* msgs, + double relative_tolerance, + double absolute_tolerance, + long int max_num_steps) { // NOLINT(runtime/int) + internal::integrate_ode_std_vector_interface_adapter f_adapted(f); + return ode_adams_tol_impl("integrate_ode_adams", f_adapted, to_vector(y0), t0, + ts, relative_tolerance, absolute_tolerance, + max_num_steps, msgs, theta, x, x_int); +} + +template * = nullptr> +auto integrate_ode_adams_impl(const F& f, const std::vector& y0, + const T_t0& t0, const std::vector& ts, + const std::vector& theta, + const std::vector& x, + const std::vector& x_int, std::ostream* msgs, + double relative_tolerance, + double absolute_tolerance, + long int max_num_steps) { // NOLINT(runtime/int) + return ode_adams_tol_impl("integrate_ode_adams", + integrate_ode_closure_adapter(), to_vector(y0), t0, + ts, relative_tolerance, absolute_tolerance, + max_num_steps, msgs, f, theta, x, x_int); +} + +} // namespace internal + /** * @deprecated use ode_adams */ template -std::vector>> +std::vector>> integrate_ode_adams(const F& f, const std::vector& y0, const T_t0& t0, const std::vector& ts, const std::vector& theta, @@ -24,12 +61,11 @@ integrate_ode_adams(const F& f, const std::vector& y0, const T_t0& t0, double relative_tolerance = 1e-10, double absolute_tolerance = 1e-10, long int max_num_steps = 1e8) { // NOLINT(runtime/int) - internal::integrate_ode_std_vector_interface_adapter f_adapted(f); - auto y = ode_adams_tol_impl("integrate_ode_adams", f_adapted, to_vector(y0), - t0, ts, relative_tolerance, absolute_tolerance, - max_num_steps, msgs, theta, x, x_int); + auto y = internal::integrate_ode_adams_impl( + f, y0, t0, ts, theta, x, x_int, msgs, relative_tolerance, + absolute_tolerance, max_num_steps); - std::vector>> + std::vector>> y_converted; for (size_t i = 0; i < y.size(); ++i) y_converted.push_back(to_array_1d(y[i])); diff --git a/stan/math/rev/functor/integrate_ode_bdf.hpp b/stan/math/rev/functor/integrate_ode_bdf.hpp index c3877bdb875..4e25abd1b1b 100644 --- a/stan/math/rev/functor/integrate_ode_bdf.hpp +++ b/stan/math/rev/functor/integrate_ode_bdf.hpp @@ -2,6 +2,7 @@ #define STAN_MATH_REV_FUNCTOR_INTEGRATE_ODE_BDF_HPP #include +#include #include #include #include @@ -10,12 +11,48 @@ namespace stan { namespace math { +namespace internal { + +template * = nullptr> +auto integrate_ode_bdf_impl(const F& f, const std::vector& y0, + const T_t0& t0, const std::vector& ts, + const std::vector& theta, + const std::vector& x, + const std::vector& x_int, std::ostream* msgs, + double relative_tolerance, + double absolute_tolerance, + long int max_num_steps) { // NOLINT(runtime/int) + internal::integrate_ode_std_vector_interface_adapter f_adapted(f); + return ode_bdf_tol_impl("integrate_ode_bdf", f_adapted, to_vector(y0), t0, ts, + relative_tolerance, absolute_tolerance, max_num_steps, + msgs, theta, x, x_int); +} + +template * = nullptr> +auto integrate_ode_bdf_impl(const F& f, const std::vector& y0, + const T_t0& t0, const std::vector& ts, + const std::vector& theta, + const std::vector& x, + const std::vector& x_int, std::ostream* msgs, + double relative_tolerance, + double absolute_tolerance, + long int max_num_steps) { // NOLINT(runtime/int) + return ode_bdf_tol_impl("integrate_ode_bdf", integrate_ode_closure_adapter(), + to_vector(y0), t0, ts, relative_tolerance, + absolute_tolerance, max_num_steps, msgs, f, theta, x, + x_int); +} + +} // namespace internal + /** * @deprecated use ode_bdf */ template -std::vector>> + typename T_ts, typename = require_not_stan_closure_t> +std::vector>> integrate_ode_bdf(const F& f, const std::vector& y0, const T_t0& t0, const std::vector& ts, const std::vector& theta, @@ -24,12 +61,11 @@ integrate_ode_bdf(const F& f, const std::vector& y0, const T_t0& t0, double relative_tolerance = 1e-10, double absolute_tolerance = 1e-10, long int max_num_steps = 1e8) { // NOLINT(runtime/int) - internal::integrate_ode_std_vector_interface_adapter f_adapted(f); - auto y = ode_bdf_tol_impl("integrate_ode_bdf", f_adapted, to_vector(y0), t0, - ts, relative_tolerance, absolute_tolerance, - max_num_steps, msgs, theta, x, x_int); + auto y = internal::integrate_ode_bdf_impl(f, y0, t0, ts, theta, x, x_int, + msgs, relative_tolerance, + absolute_tolerance, max_num_steps); - std::vector>> + std::vector>> y_converted; for (size_t i = 0; i < y.size(); ++i) y_converted.push_back(to_array_1d(y[i])); diff --git a/stan/math/rev/functor/ode_adams.hpp b/stan/math/rev/functor/ode_adams.hpp index ee0bdafbbd5..d6b64eeb45a 100644 --- a/stan/math/rev/functor/ode_adams.hpp +++ b/stan/math/rev/functor/ode_adams.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -45,7 +46,8 @@ namespace math { * @return Solution to ODE at times \p ts */ template * = nullptr> + typename... T_Args, require_eigen_col_vector_t* = nullptr, + require_not_stan_closure_t* = nullptr> std::vector, Eigen::Dynamic, 1>> ode_adams_tol_impl(const char* function_name, const F& f, const T_y0& y0, diff --git a/stan/math/rev/functor/ode_bdf.hpp b/stan/math/rev/functor/ode_bdf.hpp index a07af2e3339..cf53b1f54e9 100644 --- a/stan/math/rev/functor/ode_bdf.hpp +++ b/stan/math/rev/functor/ode_bdf.hpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -46,7 +47,8 @@ namespace math { * @return Solution to ODE at times \p ts */ template * = nullptr> + typename... T_Args, require_eigen_col_vector_t* = nullptr, + require_not_stan_closure_t* = nullptr> std::vector, Eigen::Dynamic, 1>> ode_bdf_tol_impl(const char* function_name, const F& f, const T_y0& y0, diff --git a/test/unit/math/prim/meta/is_stan_closure_test.cpp b/test/unit/math/prim/meta/is_stan_closure_test.cpp new file mode 100644 index 00000000000..7138e31f448 --- /dev/null +++ b/test/unit/math/prim/meta/is_stan_closure_test.cpp @@ -0,0 +1,18 @@ +#include +#include +#include +#include + +TEST(MathMetaPrim, IsStanClosure) { + auto lambda = [](auto msg) { return 0.0; }; + auto cl = stan::math::from_lambda(lambda); + EXPECT_FALSE((stan::is_stan_closure::value)); + EXPECT_TRUE((stan::is_stan_closure::value)); +} + +TEST(MathMetaPrim, ClosureReturnType) { + EXPECT_SAME_TYPE(const std::vector&, + stan::closure_return_type, true>::type); + EXPECT_SAME_TYPE(std::vector, + stan::closure_return_type, false>::type); +} diff --git a/test/unit/math/rev/functor/closure_ode_typed_test.cpp b/test/unit/math/rev/functor/closure_ode_typed_test.cpp new file mode 100644 index 00000000000..dcd4fecc8e1 --- /dev/null +++ b/test/unit/math/rev/functor/closure_ode_typed_test.cpp @@ -0,0 +1,80 @@ +#include +#include +#include +#include +#include +#include + +/** + * + * Use same solver functor type for both w & w/o tolerance control + */ +template +using ode_test_tuple = std::tuple; + +/** + * Outer product of test types + */ +using closure_test_types = boost::mp11::mp_product< + ode_test_tuple, ::testing::Types >; + +TYPED_TEST_SUITE_P(closure_test); +TYPED_TEST_P(closure_test, y0_error) { + this->test_y0_error(); + this->test_y0_error_with_tol(); +} +TYPED_TEST_P(closure_test, t0_error) { + this->test_t0_error(); + this->test_t0_error_with_tol(); +} +TYPED_TEST_P(closure_test, ts_error) { + this->test_ts_error(); + this->test_ts_error_with_tol(); +} +TYPED_TEST_P(closure_test, two_arg_error) { + this->test_two_arg_error(); + this->test_two_arg_error_with_tol(); +} +TYPED_TEST_P(closure_test, tol_error) { + this->test_rtol_error(); + this->test_atol_error(); + this->test_max_num_step_error(); + this->test_too_much_work(); +} +TYPED_TEST_P(closure_test, value) { this->test_value(); } +TYPED_TEST_P(closure_test, grad) { + this->test_grad_t0(); + this->test_grad_ts(); + this->test_grad_ts_repeat(); + this->test_scalar_arg(); + this->test_std_vector_arg(); + this->test_vector_arg(); + this->test_row_vector_arg(); + this->test_matrix_arg(); + this->test_scalar_std_vector_args(); + this->test_std_vector_std_vector_args(); + this->test_std_vector_vector_args(); + this->test_std_vector_row_vector_args(); + this->test_std_vector_matrix_args(); + this->test_arg_combos_test(); +} +TYPED_TEST_P(closure_test, tol_grad) { + this->test_tol_t0(); + this->test_tol_ts(); + this->test_tol_ts_repeat(); + this->test_tol_scalar_arg(); + this->test_tol_scalar_arg_multi_time(); + this->test_tol_std_vector_arg(); + this->test_tol_vector_arg(); + this->test_tol_row_vector_arg(); + this->test_tol_matrix_arg(); + this->test_tol_scalar_std_vector_args(); + this->test_tol_std_vector_std_vector_args(); + this->test_tol_std_vector_vector_args(); + this->test_tol_std_vector_row_vector_args(); + this->test_tol_std_vector_matrix_args(); +} +REGISTER_TYPED_TEST_SUITE_P(closure_test, y0_error, t0_error, ts_error, + two_arg_error, tol_error, value, grad, tol_grad); +INSTANTIATE_TYPED_TEST_SUITE_P(StanOde, closure_test, closure_test_types); diff --git a/test/unit/math/rev/functor/reduce_sum_closure_test.cpp b/test/unit/math/rev/functor/reduce_sum_closure_test.cpp new file mode 100644 index 00000000000..f31711acf8b --- /dev/null +++ b/test/unit/math/rev/functor/reduce_sum_closure_test.cpp @@ -0,0 +1,81 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +TEST(StanMathRev_reduce_sum, grouped_gradient_closure) { + using stan::math::from_lambda; + using stan::math::reduce_sum_closure_adapter; + using stan::math::var; + using stan::math::test::get_new_msg; + + double lambda_d = 10.0; + const std::size_t groups = 10; + const std::size_t elems_per_group = 1000; + const std::size_t elems = groups * elems_per_group; + + std::vector data(elems); + std::vector gidx(elems); + + for (std::size_t i = 0; i != elems; ++i) { + data[i] = i; + gidx[i] = i / elems_per_group; + } + + std::vector vlambda_v; + + for (std::size_t i = 0; i != groups; ++i) + vlambda_v.push_back(i + 0.2); + + var lambda_v = vlambda_v[0]; + + auto functor = from_lambda( + [](auto& lambda, auto& slice, std::size_t start, std::size_t end, + auto& gidx, std::ostream* msgs) { + const std::size_t num_terms = end - start + 1; + std::decay_t lambda_slice(num_terms); + for (std::size_t i = 0; i != num_terms; ++i) + lambda_slice[i] = lambda[gidx[start - 1 + i]]; + return stan::math::poisson_lpmf(slice, lambda_slice); + }, + vlambda_v); + + var poisson_lpdf = stan::math::reduce_sum( + data, 5, get_new_msg(), functor, gidx); + + std::vector vref_lambda_v; + for (std::size_t i = 0; i != elems; ++i) { + vref_lambda_v.push_back(vlambda_v[gidx[i]]); + } + var lambda_ref = vlambda_v[0]; + var poisson_lpdf_ref = stan::math::poisson_lpmf(data, vref_lambda_v); + + EXPECT_FLOAT_EQ(value_of(poisson_lpdf), value_of(poisson_lpdf_ref)); + + stan::math::grad(poisson_lpdf_ref.vi_); + const double lambda_ref_adj = lambda_ref.adj(); + + stan::math::set_zero_all_adjoints(); + stan::math::grad(poisson_lpdf.vi_); + const double lambda_adj = lambda_v.adj(); + + EXPECT_FLOAT_EQ(lambda_adj, lambda_ref_adj) + << "ref value of poisson lpdf : " << poisson_lpdf_ref.val() << std::endl + << "ref gradient wrt to lambda: " << lambda_ref_adj << std::endl + << "value of poisson lpdf : " << poisson_lpdf.val() << std::endl + << "gradient wrt to lambda: " << lambda_adj << std::endl; + + var poisson_lpdf_static + = stan::math::reduce_sum_static( + data, 5, get_new_msg(), functor, gidx); + + stan::math::set_zero_all_adjoints(); + stan::math::grad(poisson_lpdf_static.vi_); + const double lambda_adj_static = lambda_v.adj(); + EXPECT_FLOAT_EQ(lambda_adj_static, lambda_ref_adj); + stan::math::recover_memory(); +} diff --git a/test/unit/math/rev/functor/test_fixture_ode_closure.hpp b/test/unit/math/rev/functor/test_fixture_ode_closure.hpp new file mode 100644 index 00000000000..9b2a70b040a --- /dev/null +++ b/test/unit/math/rev/functor/test_fixture_ode_closure.hpp @@ -0,0 +1,1198 @@ +#ifndef STAN_MATH_TEST_FIXTURE_ODE_CLOSURE_HPP +#define STAN_MATH_TEST_FIXTURE_ODE_CLOSURE_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct closure_ode_base { + Eigen::VectorXd y0; + double t0; + std::vector ts; + double a; + double rtol; + double atol; + int max_num_step; + + closure_ode_base() + : y0(1), + t0(0.0), + ts{0.45, 1.1}, + a(1.5), + rtol(1.e-10), + atol(1.e-10), + max_num_step(100000) { + y0[0] = 0.0; + } +}; + +/** + * Inheriting base type, various fixtures differs by the type of ODE + * functor used in apply_solver calls, intended for + * different kind of tests. + * + */ +template +struct closure_test : public closure_ode_base, + public ODETestFixture> { + closure_test() : closure_ode_base() {} + + Eigen::VectorXd init() { return y0; } + std::vector param() { return {a}; } + + auto apply_solver() { + std::tuple_element_t<0, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0, ts, nullptr, + stan::math::from_lambda( + [](auto& a_, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a_); + }, + a)); + } + + template + auto apply_solver(Eigen::Matrix& init, std::vector& va) { + std::tuple_element_t<0, T> sol; + return sol(stan::math::ode_closure_adapter(), init, t0, ts, nullptr, + stan::math::from_lambda( + [](auto& a, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a); + }, + va)); + } + + template + auto apply_solver_ts(const std::vector& ts_) { + std::tuple_element_t<0, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0, ts_, nullptr, + stan::math::from_lambda( + [](auto& a_, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a_); + }, + a)); + } + + template + auto apply_solver_ts(const std::vector& ts_, const a_type& arg) { + std::tuple_element_t<0, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0, ts_, nullptr, + stan::math::from_lambda( + [](auto& a_, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a_); + }, + arg)); + } + + template + auto apply_solver_ts_tol(const std::vector& ts_, double rtol, double atol, + int max_num_steps, const a_type& a_) { + std::tuple_element_t<1, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0, ts_, rtol, atol, + max_num_steps, nullptr, + stan::math::from_lambda( + [](auto& a, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a); + }, + a_)); + } + + template + auto apply_solver_t0(const T0& t0_) { + std::tuple_element_t<0, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0_, ts, nullptr, + stan::math::from_lambda( + [](auto& a_, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a_); + }, + a)); + } + + template + auto apply_solver_t0_tol(const T0& t0_, double rtol, double atol, + int max_num_steps, const a_type& a_) { + std::tuple_element_t<1, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0_, ts, rtol, atol, + max_num_steps, nullptr, + stan::math::from_lambda( + [](auto& a, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a); + }, + a_)); + } + + auto apply_solver_tol() { + std::tuple_element_t<1, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0, ts, rtol, atol, + max_num_step, nullptr, + stan::math::from_lambda( + [](auto& a_, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a_); + }, + a)); + } + + template + auto apply_solver_arg(a_type const& a_) { + std::tuple_element_t<0, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0, ts, nullptr, + stan::math::from_lambda( + [](auto& a, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a); + }, + a_)); + } + + template + auto apply_solver_arg_tol(a_type const& a_) { + std::tuple_element_t<1, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0, ts, rtol, atol, + max_num_step, nullptr, + stan::math::from_lambda( + [](auto& a, auto& t, auto& y, std::ostream* msgs) { + return stan::test::CosArg1()(t, y, msgs, a); + }, + a_)); + } + + template + auto apply_solver_arg(a_type const& a_, b_type const& b_) { + std::tuple_element_t<0, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0, ts, nullptr, + stan::math::from_lambda( + [](auto& a, auto& t, auto& y, auto& b, std::ostream* msgs) { + return stan::test::Cos2Arg()(t, y, msgs, a, b); + }, + a_), + b_); + } + + template + auto apply_solver_arg_tol(a_type const& a_, b_type const& b_) { + std::tuple_element_t<1, T> sol; + return sol(stan::math::ode_closure_adapter(), y0, t0, ts, rtol, atol, + max_num_step, nullptr, + stan::math::from_lambda( + [](auto& a, auto& t, auto& y, auto& b, std::ostream* msgs) { + return stan::test::Cos2Arg()(t, y, msgs, a, b); + }, + a_), + b_); + } + + void test_y0_error() { + y0 = Eigen::VectorXd::Zero(1); + ASSERT_NO_THROW(apply_solver()); + + y0[0] = stan::math::INFTY; + EXPECT_THROW(apply_solver(), std::domain_error); + + y0[0] = stan::math::NOT_A_NUMBER; + EXPECT_THROW(apply_solver(), std::domain_error); + + y0 = Eigen::VectorXd(); + EXPECT_THROW(apply_solver(), std::invalid_argument); + } + + void test_y0_error_with_tol() { + y0 = Eigen::VectorXd::Zero(1); + ASSERT_NO_THROW(apply_solver_tol()); + + y0[0] = stan::math::INFTY; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + y0[0] = stan::math::NOT_A_NUMBER; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + y0 = Eigen::VectorXd(); + EXPECT_THROW(apply_solver_tol(), std::invalid_argument); + } + + void test_t0_error() { + t0 = stan::math::INFTY; + EXPECT_THROW(apply_solver(), std::domain_error); + + t0 = stan::math::NOT_A_NUMBER; + EXPECT_THROW(apply_solver(), std::domain_error); + } + + void test_t0_error_with_tol() { + t0 = stan::math::INFTY; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + t0 = stan::math::NOT_A_NUMBER; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + } + + void test_ts_error() { + std::vector ts_repeat = {0.45, 0.45}; + std::vector ts_lots = {0.45, 0.45, 1.1, 1.1, 2.0}; + std::vector ts_empty = {}; + std::vector ts_early = {-0.45, 0.2}; + std::vector ts_decreasing = {0.45, 0.2}; + std::vector tsinf = {stan::math::INFTY, 1.1}; + std::vector tsNaN = {0.45, stan::math::NOT_A_NUMBER}; + + std::vector out; + EXPECT_NO_THROW(out = apply_solver()); + EXPECT_EQ(out.size(), ts.size()); + + ts = ts_repeat; + EXPECT_NO_THROW(out = apply_solver()); + EXPECT_EQ(out.size(), ts_repeat.size()); + EXPECT_MATRIX_FLOAT_EQ(out[0], out[1]); + + ts = ts_lots; + EXPECT_NO_THROW(out = apply_solver()); + EXPECT_EQ(out.size(), ts_lots.size()); + + ts = ts_empty; + EXPECT_THROW(apply_solver(), std::invalid_argument); + + ts = ts_early; + EXPECT_THROW(apply_solver(), std::domain_error); + + ts = ts_decreasing; + EXPECT_THROW(apply_solver(), std::domain_error); + + ts = tsinf; + EXPECT_THROW(apply_solver(), std::domain_error); + + ts = tsNaN; + EXPECT_THROW(apply_solver(), std::domain_error); + + ts = {0.45, 1.1}; + } + + void test_ts_error_with_tol() { + std::vector ts_repeat = {0.45, 0.45}; + std::vector ts_lots = {0.45, 0.45, 1.1, 1.1, 2.0}; + std::vector ts_empty = {}; + std::vector ts_early = {-0.45, 0.2}; + std::vector ts_decreasing = {0.45, 0.2}; + std::vector tsinf = {stan::math::INFTY, 1.1}; + std::vector tsNaN = {0.45, stan::math::NOT_A_NUMBER}; + + std::vector out; + EXPECT_NO_THROW(out = apply_solver_tol()); + EXPECT_EQ(out.size(), ts.size()); + + ts = ts_repeat; + EXPECT_NO_THROW(out = apply_solver_tol()); + EXPECT_EQ(out.size(), ts_repeat.size()); + EXPECT_MATRIX_FLOAT_EQ(out[0], out[1]); + + ts = ts_lots; + EXPECT_NO_THROW(out = apply_solver_tol()); + EXPECT_EQ(out.size(), ts_lots.size()); + + ts = ts_empty; + EXPECT_THROW(apply_solver_tol(), std::invalid_argument); + + ts = ts_early; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + ts = ts_decreasing; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + ts = tsinf; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + ts = tsNaN; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + ts = {0.45, 1.1}; + } + + void test_two_arg_error() { + a = 1.5; + double ainf = stan::math::INFTY; + double aNaN = stan::math::NOT_A_NUMBER; + + std::vector va = {a}; + std::vector vainf = {ainf}; + std::vector vaNaN = {aNaN}; + + Eigen::VectorXd ea(1); + ea << a; + Eigen::VectorXd eainf(1); + eainf << ainf; + Eigen::VectorXd eaNaN(1); + eaNaN << aNaN; + + std::vector> vva = {va}; + std::vector> vvainf = {vainf}; + std::vector> vvaNaN = {vaNaN}; + + std::vector vea = {ea}; + std::vector veainf = {eainf}; + std::vector veaNaN = {eaNaN}; + + EXPECT_NO_THROW(apply_solver_arg(a, a)); + + EXPECT_THROW(apply_solver_arg(a, ainf), std::domain_error); + + EXPECT_THROW(apply_solver_arg(a, aNaN), std::domain_error); + + EXPECT_NO_THROW(apply_solver_arg(a, va)); + + EXPECT_THROW(apply_solver_arg(a, vainf), std::domain_error); + + EXPECT_THROW(apply_solver_arg(a, vaNaN), std::domain_error); + + EXPECT_NO_THROW(apply_solver_arg(a, ea)); + + EXPECT_THROW(apply_solver_arg(a, eainf), std::domain_error); + + EXPECT_THROW(apply_solver_arg(a, eaNaN), std::domain_error); + + EXPECT_NO_THROW(apply_solver_arg(a, vva)); + + EXPECT_THROW(apply_solver_arg(a, vvainf), std::domain_error); + + EXPECT_THROW(apply_solver_arg(a, vvaNaN), std::domain_error); + + EXPECT_NO_THROW(apply_solver_arg(a, vea)); + + EXPECT_THROW(apply_solver_arg(a, veainf), std::domain_error); + + EXPECT_THROW(apply_solver_arg(a, veaNaN), std::domain_error); + } + + void test_two_arg_error_with_tol() { + a = 1.5; + double ainf = stan::math::INFTY; + double aNaN = stan::math::NOT_A_NUMBER; + + std::vector va = {a}; + std::vector vainf = {ainf}; + std::vector vaNaN = {aNaN}; + + Eigen::VectorXd ea(1); + ea << a; + Eigen::VectorXd eainf(1); + eainf << ainf; + Eigen::VectorXd eaNaN(1); + eaNaN << aNaN; + + std::vector> vva = {va}; + std::vector> vvainf = {vainf}; + std::vector> vvaNaN = {vaNaN}; + + std::vector vea = {ea}; + std::vector veainf = {eainf}; + std::vector veaNaN = {eaNaN}; + + EXPECT_NO_THROW(apply_solver_arg_tol(a, a)); + + EXPECT_THROW(apply_solver_arg_tol(a, ainf), std::domain_error); + + EXPECT_THROW(apply_solver_arg_tol(a, aNaN), std::domain_error); + + EXPECT_NO_THROW(apply_solver_arg_tol(a, va)); + + EXPECT_THROW(apply_solver_arg_tol(a, vainf), std::domain_error); + + EXPECT_THROW(apply_solver_arg_tol(a, vaNaN), std::domain_error); + + EXPECT_NO_THROW(apply_solver_arg_tol(a, ea)); + + EXPECT_THROW(apply_solver_arg_tol(a, eainf), std::domain_error); + + EXPECT_THROW(apply_solver_arg_tol(a, eaNaN), std::domain_error); + + EXPECT_NO_THROW(apply_solver_arg_tol(a, vva)); + + EXPECT_THROW(apply_solver_arg_tol(a, vvainf), std::domain_error); + + EXPECT_THROW(apply_solver_arg_tol(a, vvaNaN), std::domain_error); + + EXPECT_NO_THROW(apply_solver_arg_tol(a, vea)); + + EXPECT_THROW(apply_solver_arg_tol(a, veainf), std::domain_error); + + EXPECT_THROW(apply_solver_arg_tol(a, veaNaN), std::domain_error); + } + + void test_rtol_error() { + y0 = Eigen::VectorXd::Zero(1); + t0 = 0; + ts = {0.45, 1.1}; + a = 1.5; + + rtol = 1e-6; + atol = 1e-6; + double rtol_negative = -1e-6; + double rtolinf = stan::math::INFTY; + double rtolNaN = stan::math::NOT_A_NUMBER; + + EXPECT_NO_THROW(apply_solver_tol()); + + rtol = rtol_negative; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + rtol = rtolinf; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + rtol = rtolNaN; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + } + + void test_atol_error() { + y0 = Eigen::VectorXd::Zero(1); + t0 = 0; + ts = {0.45, 1.1}; + a = 1.5; + + rtol = 1e-6; + atol = 1e-6; + double atol_negative = -1e-6; + double atolinf = stan::math::INFTY; + double atolNaN = stan::math::NOT_A_NUMBER; + + EXPECT_NO_THROW(apply_solver_tol()); + + atol = atol_negative; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + atol = atolinf; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + atol = atolNaN; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + } + + void test_max_num_step_error() { + rtol = 1e-6; + atol = 1e-6; + max_num_step = 500; + int max_num_steps_negative = -500; + int max_num_steps_zero = 0; + + EXPECT_NO_THROW(apply_solver_tol()); + + max_num_step = max_num_steps_negative; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + + max_num_step = max_num_steps_zero; + EXPECT_THROW(apply_solver_tol(), std::domain_error); + } + + void test_too_much_work() { + ts[1] = 1e4; + max_num_step = 10; + EXPECT_THROW_MSG(apply_solver_tol(), std::domain_error, + "Failed to integrate to next output time"); + } + + void test_value() { + std::vector res = apply_solver(); + EXPECT_NEAR(res[0][0], 0.4165982112, 1e-5); + EXPECT_NEAR(res[1][0], 0.66457668563, 1e-5); + + std::vector ts_i = {1, 2}; + std::tuple_element_t<0, T> sol; + res = apply_solver_ts(ts_i); + EXPECT_NEAR(res[0][0], 0.6649966577, 1e-5); + EXPECT_NEAR(res[1][0], 0.09408000537, 1e-5); + + int t0_i = 0; + res = apply_solver_t0(t0_i); + EXPECT_NEAR(res[0][0], 0.4165982112, 1e-5); + EXPECT_NEAR(res[1][0], 0.66457668563, 1e-5); + } + + void test_grad_t0() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + stan::math::var t0v = 0.0; + auto res = apply_solver_t0(t0v); + + res[0][0].grad(); + + EXPECT_NEAR(res[0][0].val(), 0.4165982112, 1e-5); + EXPECT_NEAR(t0v.adj(), -1.0, 1e-5); + + nested.set_zero_all_adjoints(); + + res[1][0].grad(); + + EXPECT_NEAR(res[1][0].val(), 0.66457668563, 1e-5); + EXPECT_NEAR(t0v.adj(), -1.0, 1e-5); + } + + void test_grad_ts() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + std::vector tsv = {0.45, 1.1}; + auto res = apply_solver_ts(tsv); + + res[0][0].grad(); + + EXPECT_NEAR(res[0][0].val(), 0.4165982112, 1e-5); + EXPECT_NEAR(tsv[0].adj(), 0.78070695113, 1e-5); + nested.set_zero_all_adjoints(); + + res[1][0].grad(); + + EXPECT_NEAR(res[1][0].val(), 0.66457668563, 1e-5); + EXPECT_NEAR(tsv[1].adj(), -0.0791208888, 1e-5); + } + + void test_grad_ts_repeat() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + std::vector tsv = {0.45, 0.45, 1.1, 1.1}; + auto output = apply_solver_ts(tsv); + + EXPECT_EQ(output.size(), tsv.size()); + + output[0][0].grad(); + + EXPECT_NEAR(output[0][0].val(), 0.4165982112, 1e-5); + EXPECT_NEAR(tsv[0].adj(), 0.78070695113, 1e-5); + nested.set_zero_all_adjoints(); + + output[1][0].grad(); + + EXPECT_NEAR(output[1][0].val(), 0.4165982112, 1e-5); + EXPECT_NEAR(tsv[1].adj(), 0.78070695113, 1e-5); + nested.set_zero_all_adjoints(); + + output[2][0].grad(); + + EXPECT_NEAR(output[2][0].val(), 0.66457668563, 1e-5); + EXPECT_NEAR(tsv[2].adj(), -0.0791208888, 1e-5); + nested.set_zero_all_adjoints(); + + output[3][0].grad(); + EXPECT_NEAR(output[3][0].val(), 0.66457668563, 1e-5); + EXPECT_NEAR(tsv[3].adj(), -0.0791208888, 1e-5); + } + + void test_scalar_arg() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + stan::math::var av = 1.5; + + { + std::vector ts1{1.1}; + auto output = apply_solver_ts(ts1, av)[0][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(av.adj(), -0.50107310888, 1e-5); + nested.set_zero_all_adjoints(); + } + + { + auto output = apply_solver_arg(av); + + output[0](0).grad(); + + EXPECT_NEAR(output[0](0).val(), 0.4165982112, 1e-5); + EXPECT_NEAR(av.adj(), -0.04352005542, 1e-5); + nested.set_zero_all_adjoints(); + + output[1](0).grad(); + + EXPECT_NEAR(output[1](0).val(), 0.66457668563, 1e-5); + EXPECT_NEAR(av.adj(), -0.50107310888, 1e-5); + nested.set_zero_all_adjoints(); + } + } + + void test_std_vector_arg() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + std::vector av = {1.5}; + var output = apply_solver_arg(av)[1][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(av[0].adj(), -0.50107310888, 1e-5); + } + + void test_vector_arg() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + Eigen::Matrix av(1); + av << 1.5; + + var output = apply_solver_arg(av)[1][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(av(0).adj(), -0.50107310888, 1e-5); + } + + void test_row_vector_arg() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + Eigen::Matrix av(1); + av << 1.5; + + var output = apply_solver_arg(av)[1][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(av(0).adj(), -0.50107310888, 1e-5); + } + + void test_matrix_arg() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + Eigen::Matrix av(1, 1); + av << 1.5; + + var output = apply_solver_arg(av)[1][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(av(0).adj(), -0.50107310888, 1e-5); + } + + void test_scalar_std_vector_args() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + var a0 = 0.75; + std::vector a1 = {0.75}; + + var output = apply_solver_arg(a0, a1)[1][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(a0.adj(), -0.50107310888, 1e-5); + EXPECT_NEAR(a1[0].adj(), -0.50107310888, 1e-5); + } + + void test_std_vector_std_vector_args() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + var a0 = 1.5; + std::vector a1(1, a0); + std::vector> a2(1, a1); + + var output = apply_solver_arg(a2)[1][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(a2[0][0].adj(), -0.50107310888, 1e-5); + } + + void test_std_vector_vector_args() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + var a0 = 1.5; + Eigen::Matrix a1(1); + a1 << a0; + std::vector> a2(1, a1); + + var output = apply_solver_arg(a2)[1][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(a2[0](0).adj(), -0.50107310888, 1e-5); + } + + void test_std_vector_row_vector_args() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + var a0 = 1.5; + Eigen::Matrix a1(1); + a1 << a0; + std::vector> a2(1, a1); + var output = apply_solver_arg(a2)[1][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(a2[0](0).adj(), -0.50107310888, 1e-5); + } + + void test_std_vector_matrix_args() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + var a0 = 1.5; + Eigen::Matrix a1(1, 1); + a1 << a0; + std::vector> a2(1, a1); + + var output = apply_solver_arg(a2)[1][0]; + + output.grad(); + + EXPECT_NEAR(output.val(), 0.66457668563, 1e-5); + EXPECT_NEAR(a2[0](0).adj(), -0.50107310888, 1e-5); + } + + void test_arg_combos_test() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<0, T> sol; + + var t0 = 0.5; + var a = 0.2; + std::vector ts = {1.25}; + Eigen::Matrix y0(1); + y0 << 0.75; + + double t0d = stan::math::value_of(t0); + double ad = stan::math::value_of(a); + std::vector tsd = stan::math::value_of(ts); + Eigen::VectorXd y0d = stan::math::value_of(y0); + + auto check_yT = [&](auto yT) { + EXPECT_NEAR(stan::math::value_of(yT), + y0d(0) * exp(-0.5 * ad * (tsd[0] * tsd[0] - t0d * t0d)), + 1e-5); + }; + + auto check_t0 = [&](var t0) { + EXPECT_NEAR( + t0.adj(), + ad * t0d * y0d(0) * exp(-0.5 * ad * (tsd[0] * tsd[0] - t0d * t0d)), + 1e-5); + }; + + auto check_a = [&](var a) { + EXPECT_NEAR(a.adj(), + -0.5 * (tsd[0] * tsd[0] - t0d * t0d) * y0d(0) + * exp(-0.5 * ad * (tsd[0] * tsd[0] - t0d * t0d)), + 1e-5); + }; + + auto check_ts = [&](std::vector ts) { + EXPECT_NEAR(ts[0].adj(), + -ad * tsd[0] * y0d(0) + * exp(-0.5 * ad * (tsd[0] * tsd[0] - t0d * t0d)), + 1e-5); + }; + + auto check_y0 = [&](Eigen::Matrix y0) { + EXPECT_NEAR(y0(0).adj(), exp(-0.5 * ad * (tsd[0] * tsd[0] - t0d * t0d)), + 1e-5); + }; + + double yT1 = sol(stan::test::ayt(), y0d, t0d, tsd, nullptr, ad)[0](0); + check_yT(yT1); + + var yT2 = sol(stan::test::ayt(), y0d, t0d, tsd, nullptr, a)[0](0); + nested.set_zero_all_adjoints(); + yT2.grad(); + check_yT(yT2); + check_a(a); + + var yT3 = sol(stan::test::ayt(), y0d, t0d, ts, nullptr, ad)[0](0); + nested.set_zero_all_adjoints(); + yT3.grad(); + check_yT(yT3); + check_ts(ts); + + var yT4 = sol(stan::test::ayt(), y0d, t0d, ts, nullptr, a)[0](0); + nested.set_zero_all_adjoints(); + yT4.grad(); + check_yT(yT4); + check_ts(ts); + check_a(a); + + var yT5 = sol(stan::test::ayt(), y0d, t0, tsd, nullptr, ad)[0](0); + nested.set_zero_all_adjoints(); + yT5.grad(); + check_yT(yT5); + check_t0(t0); + + var yT6 = sol(stan::test::ayt(), y0d, t0, tsd, nullptr, a)[0](0); + nested.set_zero_all_adjoints(); + yT6.grad(); + check_yT(yT6); + check_t0(t0); + check_a(a); + + var yT7 = sol(stan::test::ayt(), y0d, t0, ts, nullptr, ad)[0](0); + nested.set_zero_all_adjoints(); + yT7.grad(); + check_yT(yT7); + check_t0(t0); + check_ts(ts); + + var yT8 = sol(stan::test::ayt(), y0d, t0, ts, nullptr, a)[0](0); + nested.set_zero_all_adjoints(); + yT8.grad(); + check_yT(yT8); + check_t0(t0); + check_ts(ts); + check_a(a); + + var yT9 = sol(stan::test::ayt(), y0, t0d, tsd, nullptr, ad)[0](0); + nested.set_zero_all_adjoints(); + yT9.grad(); + check_yT(yT9); + check_y0(y0); + + var yT10 = sol(stan::test::ayt(), y0, t0d, tsd, nullptr, a)[0](0); + nested.set_zero_all_adjoints(); + yT10.grad(); + check_yT(yT10); + check_y0(y0); + check_a(a); + + var yT11 = sol(stan::test::ayt(), y0, t0d, ts, nullptr, ad)[0](0); + nested.set_zero_all_adjoints(); + yT11.grad(); + check_yT(yT11); + check_y0(y0); + check_ts(ts); + + var yT12 = sol(stan::test::ayt(), y0, t0d, ts, nullptr, a)[0](0); + nested.set_zero_all_adjoints(); + yT12.grad(); + check_yT(yT12); + check_y0(y0); + check_ts(ts); + check_a(a); + + var yT13 = sol(stan::test::ayt(), y0, t0, tsd, nullptr, ad)[0](0); + nested.set_zero_all_adjoints(); + yT13.grad(); + check_yT(yT13); + check_y0(y0); + check_t0(t0); + + var yT14 = sol(stan::test::ayt(), y0, t0, tsd, nullptr, a)[0](0); + nested.set_zero_all_adjoints(); + yT14.grad(); + check_yT(yT14); + check_y0(y0); + check_t0(t0); + check_a(a); + + var yT15 = sol(stan::test::ayt(), y0, t0, ts, nullptr, ad)[0](0); + nested.set_zero_all_adjoints(); + yT15.grad(); + check_yT(yT15); + check_y0(y0); + check_t0(t0); + check_ts(ts); + + var yT16 = sol(stan::test::ayt(), y0, t0, ts, nullptr, a)[0](0); + nested.set_zero_all_adjoints(); + yT16.grad(); + check_yT(yT16); + check_y0(y0); + check_t0(t0); + check_ts(ts); + check_a(a); + } + + void test_tol_int_ts() { + std::vector ts = {1, 2}; + + double a = 1.5; + + std::vector> output + = apply_solver_ts_tol(ts, 1e-10, 1e-10, 1e6, a); + + EXPECT_FLOAT_EQ(output[0][0], 0.6649966577); + EXPECT_FLOAT_EQ(output[1][0], 0.09408000537); + } + + void test_tol_t0() { + stan::math::nested_rev_autodiff nested; + + var t0 = 0.0; + + double a = 1.5; + + std::vector> output + = apply_solver_t0_tol(t0, 1e-10, 1e-10, 1e6, a); + + output[0][0].grad(); + + EXPECT_FLOAT_EQ(output[0][0].val(), 0.4165982112); + EXPECT_FLOAT_EQ(t0.adj(), -1.0); + + nested.set_zero_all_adjoints(); + + output[1][0].grad(); + + EXPECT_FLOAT_EQ(output[1][0].val(), 0.66457668563); + EXPECT_FLOAT_EQ(t0.adj(), -1.0); + } + + void test_tol_ts() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {0.45, 1.1}; + + double a = 1.5; + + std::vector> output + = apply_solver_ts_tol(ts, 1e-10, 1e-10, 1e6, a); + + output[0][0].grad(); + + EXPECT_FLOAT_EQ(output[0][0].val(), 0.4165982112); + EXPECT_FLOAT_EQ(ts[0].adj(), 0.78070695113); + + nested.set_zero_all_adjoints(); + + output[1][0].grad(); + + EXPECT_FLOAT_EQ(output[1][0].val(), 0.66457668563); + EXPECT_FLOAT_EQ(ts[1].adj(), -0.0791208888); + } + + void test_tol_ts_repeat() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {0.45, 0.45, 1.1, 1.1}; + + double a = 1.5; + + std::vector> output + = apply_solver_ts_tol(ts, 1e-10, 1e-10, 1e6, a); + + EXPECT_EQ(output.size(), ts.size()); + + output[0][0].grad(); + + EXPECT_FLOAT_EQ(output[0][0].val(), 0.4165982112); + EXPECT_FLOAT_EQ(ts[0].adj(), 0.78070695113); + + nested.set_zero_all_adjoints(); + + output[1][0].grad(); + + EXPECT_FLOAT_EQ(output[1][0].val(), 0.4165982112); + EXPECT_FLOAT_EQ(ts[1].adj(), 0.78070695113); + + nested.set_zero_all_adjoints(); + + output[2][0].grad(); + + EXPECT_FLOAT_EQ(output[2][0].val(), 0.66457668563); + EXPECT_FLOAT_EQ(ts[2].adj(), -0.0791208888); + + nested.set_zero_all_adjoints(); + + output[3][0].grad(); + + EXPECT_FLOAT_EQ(output[3][0].val(), 0.66457668563); + EXPECT_FLOAT_EQ(ts[3].adj(), -0.0791208888); + } + + void test_tol_scalar_arg() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {1.1}; + + var a = 1.5; + + var output = apply_solver_ts_tol(ts, 1e-8, 1e-10, 1e6, a)[0][0]; + + output.grad(); + + EXPECT_FLOAT_EQ(output.val(), 0.66457668563); + EXPECT_FLOAT_EQ(a.adj(), -0.50107310888); + } + + void test_tol_scalar_arg_multi_time() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {0.45, 1.1}; + + var a = 1.5; + + std::vector> output + = apply_solver_ts_tol(ts, 1e-10, 1e-10, 1e6, a); + + output[0](0).grad(); + + EXPECT_FLOAT_EQ(output[0](0).val(), 0.4165982112); + EXPECT_FLOAT_EQ(a.adj(), -0.04352005542); + + nested.set_zero_all_adjoints(); + + output[1](0).grad(); + + EXPECT_FLOAT_EQ(output[1](0).val(), 0.66457668563); + EXPECT_FLOAT_EQ(a.adj(), -0.50107310888); + } + + void test_tol_std_vector_arg() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {1.1}; + + std::vector a = {1.5}; + + var output = apply_solver_ts_tol(ts, 1e-8, 1e-10, 1e6, a)[0][0]; + + output.grad(); + + EXPECT_FLOAT_EQ(output.val(), 0.66457668563); + EXPECT_FLOAT_EQ(a[0].adj(), -0.50107310888); + } + + void test_tol_vector_arg() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {1.1}; + + Eigen::Matrix a(1); + a << 1.5; + + var output = apply_solver_t0_tol(t0, 1e-8, 1e-10, 1e6, a)[1][0]; + + output.grad(); + + EXPECT_FLOAT_EQ(output.val(), 0.66457668563); + EXPECT_FLOAT_EQ(a(0).adj(), -0.50107310888); + } + + void test_tol_row_vector_arg() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {1.1}; + + Eigen::Matrix a(1); + a << 1.5; + + var output = apply_solver_ts_tol(ts, 1e-8, 1e-10, 1e6, a)[0][0]; + + output.grad(); + + EXPECT_FLOAT_EQ(output.val(), 0.66457668563); + EXPECT_FLOAT_EQ(a(0).adj(), -0.50107310888); + } + + void test_tol_matrix_arg() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<1, T> sol; + + std::vector ts = {1.1}; + + Eigen::Matrix a(1, 1); + a << 1.5; + + var output = apply_solver_ts_tol(ts, 1e-8, 1e-10, 1e6, a)[0][0]; + + output.grad(); + + EXPECT_FLOAT_EQ(output.val(), 0.66457668563); + EXPECT_FLOAT_EQ(a(0, 0).adj(), -0.50107310888); + } + + void test_tol_scalar_std_vector_args() { + stan::math::nested_rev_autodiff nested; + std::tuple_element_t<1, T> sol; + + std::vector ts = {1.1}; + + var a0 = 0.75; + std::vector a1 = {0.75}; + + var output + = sol(stan::math::ode_closure_adapter(), y0, t0, ts, 1e-8, 1e-10, 1e6, + nullptr, + stan::math::from_lambda( + [](auto& a, auto& t, auto& y, auto& b, std::ostream* msgs) { + return stan::test::Cos2Arg()(t, y, msgs, a, b); + }, + a0), + a1)[0][0]; + + output.grad(); + + EXPECT_FLOAT_EQ(output.val(), 0.66457668563); + EXPECT_FLOAT_EQ(a0.adj(), -0.50107310888); + EXPECT_FLOAT_EQ(a1[0].adj(), -0.50107310888); + } + + void test_tol_std_vector_std_vector_args() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {1.1}; + + var a0 = 1.5; + std::vector a1(1, a0); + std::vector> a2(1, a1); + + var output = apply_solver_ts_tol(ts, 1e-8, 1e-10, 1e6, a2)[0][0]; + + output.grad(); + + EXPECT_FLOAT_EQ(output.val(), 0.66457668563); + EXPECT_FLOAT_EQ(a2[0][0].adj(), -0.50107310888); + } + + void test_tol_std_vector_vector_args() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {1.1}; + + var a0 = 1.5; + Eigen::Matrix a1(1); + a1 << a0; + std::vector> a2(1, a1); + + var output = apply_solver_ts_tol(ts, 1e-8, 1e-10, 1e6, a2)[0][0]; + + output.grad(); + + EXPECT_FLOAT_EQ(output.val(), 0.66457668563); + EXPECT_FLOAT_EQ(a2[0](0).adj(), -0.50107310888); + } + + void test_tol_std_vector_row_vector_args() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {1.1}; + + var a0 = 1.5; + Eigen::Matrix a1(1); + a1 << a0; + std::vector> a2(1, a1); + + var output = apply_solver_ts_tol(ts, 1e-8, 1e-10, 1e6, a2)[0][0]; + + output.grad(); + + EXPECT_FLOAT_EQ(output.val(), 0.66457668563); + EXPECT_FLOAT_EQ(a2[0](0).adj(), -0.50107310888); + } + + void test_tol_std_vector_matrix_args() { + stan::math::nested_rev_autodiff nested; + + std::vector ts = {1.1}; + + var a0 = 1.5; + Eigen::Matrix a1(1, 1); + a1 << a0; + std::vector> a2(1, a1); + + var output = apply_solver_ts_tol(ts, 1e-8, 1e-10, 1e6, a2)[0][0]; + + output.grad(); + + EXPECT_FLOAT_EQ(output.val(), 0.66457668563); + EXPECT_FLOAT_EQ(a2[0](0).adj(), -0.50107310888); + } +}; + +#endif