Skip to content
1 change: 1 addition & 0 deletions stan/math/fwd/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
#include <stan/math/fwd/fun/log_rising_factorial.hpp>
#include <stan/math/fwd/fun/log_softmax.hpp>
#include <stan/math/fwd/fun/log_sum_exp.hpp>
#include <stan/math/fwd/fun/log_add_exp.hpp>
#include <stan/math/fwd/fun/logit.hpp>
#include <stan/math/fwd/fun/mdivide_left.hpp>
#include <stan/math/fwd/fun/mdivide_left_ldlt.hpp>
Expand Down
162 changes: 162 additions & 0 deletions stan/math/fwd/fun/log_add_exp.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
#ifndef STAN_MATH_FWD_FUN_LOG_ADD_EXP_HPP
#define STAN_MATH_FWD_FUN_LOG_ADD_EXP_HPP

#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/fwd/meta.hpp>
#include <stan/math/fwd/core.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/log_add_exp.hpp>
#include <cmath>
#include <vector>

namespace stan {
namespace math {

// Overload for fvar and fvar
template <typename T>
inline fvar<T> log_add_exp(const fvar<T>& x1, const fvar<T>& x2) {
auto val = stan::math::log_add_exp(x1.val_, x2.val_);

auto exp_x1 = stan::math::exp(x1.val_);
auto exp_x2 = stan::math::exp(x2.val_);
auto sum_exp = exp_x1 + exp_x2;

auto grad1 = exp_x1 / sum_exp;
auto grad2 = exp_x2 / sum_exp;

return fvar<T>(val, x1.d_ * grad1 + x2.d_ * grad2);
}

template <typename T>
inline fvar<T> log_add_exp(const fvar<T>& x1, double x2) {
if (x1.val_ == NEGATIVE_INFTY) {
return fvar<T>(x2, 0.0); // log_add_exp(-∞, b) = b
}
return log_add_exp(x2, x1);
}

template <typename T>
inline fvar<T> log_add_exp(double x1, const fvar<T>& x2) {
if (x2.val_ == NEGATIVE_INFTY) {
return fvar<T>(x1, 0.0); // log_add_exp(a, -∞) = a
}
auto val = stan::math::log_add_exp(x1, x2.val_);
auto exp_x2 = stan::math::exp(x2.val_);
auto grad = exp_x2 / (stan::math::exp(x1) + exp_x2);
return fvar<T>(val, x2.d_ * grad);
}

// Overload for matrices of fvar
template <typename T>
inline Eigen::Matrix<fvar<T>, -1, -1> log_add_exp(
const Eigen::Matrix<fvar<T>, -1, -1>& a,
const Eigen::Matrix<fvar<T>, -1, -1>& b) {
using fvar_mat_type = Eigen::Matrix<fvar<T>, -1, -1>;
fvar_mat_type result(a.rows(), a.cols());

// Check for empty inputs
if (a.size() == 0 || b.size() == 0) {
throw std::invalid_argument("Input containers must not be empty.");
}

// Check for NaN
if (a.array().isNaN().any() || b.array().isNaN().any()) {
result.setConstant(fvar<T>(std::numeric_limits<double>::quiet_NaN()));
return result;
}

// Check for infinity
if (a.array().isInf().any() || b.array().isInf().any()) {
result.setConstant(fvar<T>(std::numeric_limits<double>::quiet_NaN()));
return result;
}

// Apply the log_add_exp operation directly
for (int i = 0; i < a.rows(); ++i) {
for (int j = 0; j < a.cols(); ++j) {
result(i, j) = stan::math::log_add_exp(a(i, j), b(i, j));
}
}

return result; // Return the result matrix
}

// Overload for Eigen vectors
template <typename T>
inline Eigen::Matrix<fvar<T>, -1, 1> log_add_exp(
const Eigen::Matrix<fvar<T>, -1, 1>& a,
const Eigen::Matrix<fvar<T>, -1, 1>& b) {
using fvar_vec_type = Eigen::Matrix<fvar<T>, -1, 1>;
fvar_vec_type result(a.rows());

// Check for empty inputs
if (a.size() == 0 || b.size() == 0) {
throw std::invalid_argument("Input containers must not be empty.");
}

// Check for NaN
if (a.array().isNaN().any() || b.array().isNaN().any()) {
result.setConstant(fvar<T>(std::numeric_limits<double>::quiet_NaN()));
return result;
}

// Check for infinity
if (a.array().isInf().any() || b.array().isInf().any()) {
result.setConstant(fvar<T>(std::numeric_limits<double>::quiet_NaN()));
return result;
}

// Apply the log_add_exp operation directly
for (int i = 0; i < a.rows(); ++i) {
result(i) = stan::math::log_add_exp(a(i), b(i));
}

return result; // Return the result vector
}

// Specialization for nested fvar types
template <typename T>
inline auto log_add_exp(
const Eigen::Matrix<stan::math::fvar<stan::math::fvar<double>>, -1, -1>& a,
const Eigen::Matrix<stan::math::fvar<stan::math::fvar<double>>, -1, -1>&
b) {
using nested_fvar_mat_type
= Eigen::Matrix<stan::math::fvar<stan::math::fvar<double>>, -1, -1>;
nested_fvar_mat_type result(a.rows(), a.cols());

// Check for empty inputs
if (a.size() == 0 || b.size() == 0) {
throw std::invalid_argument("Input containers must not be empty.");
}

// Check for NaN
if (a.array().isNaN().any() || b.array().isNaN().any()) {
result.setConstant(stan::math::fvar<stan::math::fvar<double>>(
std::numeric_limits<double>::quiet_NaN()));
return result;
}

// Check for infinity
if (a.array().isInf().any() || b.array().isInf().any()) {
result.setConstant(stan::math::fvar<stan::math::fvar<double>>(
std::numeric_limits<double>::quiet_NaN()));
return result;
}

// Implement the logic for log_add_exp for nested fvar types
for (int i = 0; i < a.rows(); ++i) {
for (int j = 0; j < a.cols(); ++j) {
auto inner_a = a(i, j);
auto inner_b = b(i, j);
result(i, j) = stan::math::log_add_exp(inner_a, inner_b);
}
}

return result; // Return the result matrix
}

} // namespace math
} // namespace stan

#endif
1 change: 1 addition & 0 deletions stan/math/prim/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@
#include <stan/math/prim/fun/log_softmax.hpp>
#include <stan/math/prim/fun/log_sum_exp.hpp>
#include <stan/math/prim/fun/log_sum_exp_signed.hpp>
#include <stan/math/prim/fun/log_add_exp.hpp>
#include <stan/math/prim/fun/logical_and.hpp>
#include <stan/math/prim/fun/logical_eq.hpp>
#include <stan/math/prim/fun/logical_gt.hpp>
Expand Down
159 changes: 159 additions & 0 deletions stan/math/prim/fun/log_add_exp.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
#ifndef STAN_MATH_PRIM_FUN_LOG_ADD_EXP_HPP
#define STAN_MATH_PRIM_FUN_LOG_ADD_EXP_HPP

#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/log1p_exp.hpp>
#include <stan/math/prim/fun/log_sum_exp.hpp>
#include <stan/math/prim/functor/apply_scalar_binary.hpp>
#include <cmath>
#include <vector>
#include <algorithm>
#include <stan/math/prim/err/check_matching_dims.hpp>
#include <stan/math/prim/meta/is_eigen.hpp>

namespace stan {
namespace math {

/**
* Calculates the elementwise sum of exponentials without overflow.
*
* \f$\log (\exp(a) + \exp(b)) = m + \log(\exp(a-m) + \exp(b-m))\f$,
*
* where \f$m = max(a, b)\f$.
*
* @tparam T1 type of the first variable
* @tparam T2 type of the second variable
* @param a the first variable
* @param b the second variable
*/

template <typename T1, typename T2, require_all_not_st_var<T1, T2>* = nullptr,
require_all_stan_scalar_t<T1, T2>* = nullptr>
inline return_type_t<T1, T2> log_add_exp(const T2& a, const T1& b) {
if (a == NEGATIVE_INFTY) {
return b;
}
if (b == NEGATIVE_INFTY) {
return a;
}
if (a == INFTY || b == INFTY) {
return INFTY;
}

const double max_val = std::max(a, b);
return max_val + std::log(std::exp(a - max_val) + std::exp(b - max_val));
}

/**
* Calculates the element-wise log sum of exponentials for two containers.
* For vectors a and b, computes log(exp(a[i]) + exp(b[i])) for each element i.
* If sizes don't match, uses the smaller size.
*
* @tparam T1 type of first container
* @tparam T2 type of second container
* @param a First input container
* @param b Second input container
* @return Container with element-wise log_add_exp results
*/
template <typename T, require_container_st<std::is_arithmetic, T>* = nullptr>
inline auto log_add_exp(const T& a, const T& b) {
// Check if sizes are compatible
if constexpr (stan::is_eigen<T>::value) {
// Check if both matrices/vectors have the same dimensions
stan::math::check_matching_dims("log_add_exp", "a", a, "b", b);

// Determine the number of rows and columns for the result
size_t rows = a.rows();
size_t cols = b.cols();
using return_t = return_type_t<T>;

Eigen::Matrix<return_t, Eigen::Dynamic, Eigen::Dynamic> result(rows, cols);

// Iterate over each element
for (size_t i = 0; i < rows; ++i) {
for (size_t j = 0; j < cols; ++j) {
double a_val = (a.cols() == 1)
? a(i, 0)
: a(i, j); // Handle column vector or matrix
double b_val = (b.rows() == 1)
? b(0, j)
: b(i, j); // Handle row vector or matrix

if (a_val == NEGATIVE_INFTY) {
result(i, j) = b_val;
} else if (b_val == NEGATIVE_INFTY) {
result(i, j) = a_val;
} else if (a_val == INFTY || b_val == INFTY) {
result(i, j) = INFTY;
} else {
result(i, j) = log_sum_exp(a_val, b_val);
}
}
}

return result;
} else if constexpr (std::is_same_v<T, std::vector<typename T::value_type>>) {
// Handle std::vector
if (a.size() != b.size()) {
throw std::invalid_argument("Sizes of x and y must match.");
}

using return_t = return_type_t<T>;
std::vector<return_t> result(a.size());

for (size_t i = 0; i < a.size(); ++i) {
double a_val = a[i];
double b_val = b[i];

if (a_val == NEGATIVE_INFTY) {
result[i] = b_val;
} else if (b_val == NEGATIVE_INFTY) {
result[i] = a_val;
} else if (a_val == INFTY || b_val == INFTY) {
result[i] = INFTY;
} else {
result[i] = log_sum_exp(a_val, b_val);
}
}

return result;
} else {
throw std::invalid_argument("Unsupported container type.");
}
}

/**
* Enables the vectorized application of the log_add_exp function,
* when the first and/or second arguments are containers.
*
* @tparam T1
* @tparam T2
* @param a
* @param b
* @return auto
*/
template <typename T1, typename T2, require_any_container_t<T1, T2>* = nullptr>
inline auto log_add_exp(const T1& a, const T2& b) {
// Check if both are Eigen/vectors
if constexpr (stan::is_eigen<T1>::value && stan::is_eigen<T2>::value) {
// Check if both matrices/vectors have the same dimensions
stan::math::check_matching_dims("log_add_exp", "a", a, "b", b);
} else {
// Check if sizes are compatible for other types
if (a.size() != b.size()) {
throw std::invalid_argument(
"Sizes of x and y must match or be compatible.");
}
}

// If dimensions are verified to match, apply the operation
return apply_scalar_binary(
a, b, [](const auto& c, const auto& d) { return log_add_exp(c, d); });
}

} // namespace math
} // namespace stan

#endif
1 change: 1 addition & 0 deletions stan/math/rev/fun.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@
#include <stan/math/rev/fun/log_rising_factorial.hpp>
#include <stan/math/rev/fun/log_softmax.hpp>
#include <stan/math/rev/fun/log_sum_exp.hpp>
#include <stan/math/rev/fun/log_add_exp.hpp>
#include <stan/math/rev/fun/logit.hpp>
#include <stan/math/rev/fun/matrix_exp_multiply.hpp>
#include <stan/math/rev/fun/matrix_power.hpp>
Expand Down
Loading