|
7 | 7 | #include <stan/math/prim/fun/col.hpp>
|
8 | 8 | #include <stan/math/prim/fun/transpose.hpp>
|
9 | 9 | #include <stan/math/prim/fun/exp.hpp>
|
| 10 | +#include <stan/math/prim/fun/to_ref.hpp> |
10 | 11 | #include <stan/math/prim/fun/value_of.hpp>
|
11 | 12 | #include <stan/math/prim/core.hpp>
|
12 | 13 | #include <vector>
|
13 | 14 |
|
14 | 15 | namespace stan {
|
15 | 16 | namespace math {
|
16 | 17 |
|
17 |
| -template <typename T_omega, typename T_Gamma, typename T_rho, typename T_alpha> |
18 |
| -inline auto hmm_marginal_lpdf_val( |
19 |
| - const Eigen::Matrix<T_omega, Eigen::Dynamic, Eigen::Dynamic>& omegas, |
20 |
| - const Eigen::Matrix<T_Gamma, Eigen::Dynamic, Eigen::Dynamic>& Gamma_val, |
21 |
| - const Eigen::Matrix<T_rho, Eigen::Dynamic, 1>& rho_val, |
22 |
| - Eigen::Matrix<T_alpha, Eigen::Dynamic, Eigen::Dynamic>& alphas, |
23 |
| - Eigen::Matrix<T_alpha, Eigen::Dynamic, 1>& alpha_log_norms, |
24 |
| - T_alpha& norm_norm) { |
| 18 | +template <typename T_omega, typename T_Gamma, typename T_rho, typename T_alphas, |
| 19 | + typename T_alpha_log_norm, typename T_norm, |
| 20 | + require_all_eigen_matrix_t<T_omega, T_Gamma, T_alphas>* = nullptr, |
| 21 | + require_all_eigen_col_vector_t<T_rho, T_alpha_log_norm>* = nullptr, |
| 22 | + require_stan_scalar_t<T_norm>* = nullptr, |
| 23 | + require_all_vt_same<T_alphas, T_alpha_log_norm, T_norm>* = nullptr> |
| 24 | +inline auto hmm_marginal_lpdf_val(const T_omega& omegas, |
| 25 | + const T_Gamma& Gamma_val, |
| 26 | + const T_rho& rho_val, T_alphas& alphas, |
| 27 | + T_alpha_log_norm& alpha_log_norms, |
| 28 | + T_norm& norm_norm) { |
25 | 29 | const int n_states = omegas.rows();
|
26 | 30 | const int n_transitions = omegas.cols() - 1;
|
27 | 31 | alphas.col(0) = omegas.col(0).cwiseProduct(rho_val);
|
@@ -100,10 +104,10 @@ inline auto hmm_marginal_lpdf(
|
100 | 104 |
|
101 | 105 | eig_matrix_partial alphas(n_states, n_transitions + 1);
|
102 | 106 | eig_vector_partial alpha_log_norms(n_transitions + 1);
|
103 |
| - auto Gamma_val = value_of(Gamma); |
| 107 | + const auto& Gamma_val = to_ref(value_of(Gamma)); |
104 | 108 |
|
105 | 109 | // compute the density using the forward algorithm.
|
106 |
| - auto rho_val = value_of(rho); |
| 110 | + const auto& rho_val = to_ref(value_of(rho)); |
107 | 111 | eig_matrix_partial omegas = value_of(log_omegas).array().exp();
|
108 | 112 | T_partial_type norm_norm;
|
109 | 113 | auto log_marginal_density = hmm_marginal_lpdf_val(
|
|
0 commit comments