Skip to content

Commit d06f5fd

Browse files
t4c1stan-buildbot
andauthored
Let value_of and value_of_rec return expressions (#1872)
* let value_of and value_of_rec return expressions * [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.04.1 (tags/RELEASE_600/final) * added missing include * added missing include and generalized hmm_marginal_lpdf_val * added more missing includes * addressed review comments * [Jenkins] auto-formatting by clang-format version 6.0.0-1ubuntu2~16.04.1 (tags/RELEASE_600/final) Co-authored-by: Stan Jenkins <[email protected]>
1 parent 3b2dbbd commit d06f5fd

20 files changed

+250
-56
lines changed

stan/math/opencl/prim/categorical_logit_glm_lpmf.hpp

-2
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,6 @@ return_type_t<T_alpha_scalar, T_beta_scalar> categorical_logit_glm_lpmf(
7272
const auto& beta_val = value_of_rec(beta);
7373
const auto& alpha_val = value_of_rec(alpha);
7474

75-
const auto& alpha_val_vec = as_column_vector_or_scalar(alpha_val).transpose();
76-
7775
const int local_size
7876
= opencl_kernels::categorical_logit_glm.get_option("LOCAL_SIZE_");
7977
const int wgs = (N_instances + local_size - 1) / local_size;

stan/math/opencl/prim/neg_binomial_2_log_glm_lpmf.hpp

-2
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,6 @@ return_type_t<T_alpha, T_beta, T_precision> neg_binomial_2_log_glm_lpmf(
100100
const auto& alpha_val_vec = as_column_vector_or_scalar(alpha_val);
101101
const auto& phi_val_vec = as_column_vector_or_scalar(phi_val);
102102

103-
const auto& phi_arr = as_array_or_scalar(phi_val_vec);
104-
105103
const int local_size
106104
= opencl_kernels::neg_binomial_2_log_glm.get_option("LOCAL_SIZE_");
107105
const int wgs = (N + local_size - 1) / local_size;

stan/math/opencl/prim/normal_id_glm_lpdf.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <stan/math/prim/fun/size.hpp>
99
#include <stan/math/prim/fun/size_zero.hpp>
1010
#include <stan/math/prim/fun/sum.hpp>
11+
#include <stan/math/prim/fun/to_ref.hpp>
1112
#include <stan/math/prim/fun/value_of_rec.hpp>
1213
#include <stan/math/prim/prob/normal_id_glm_lpdf.hpp>
1314
#include <stan/math/opencl/copy.hpp>
@@ -96,7 +97,7 @@ return_type_t<T_alpha, T_beta, T_scale> normal_id_glm_lpdf(
9697

9798
const auto &beta_val_vec = as_column_vector_or_scalar(beta_val);
9899
const auto &alpha_val_vec = as_column_vector_or_scalar(alpha_val);
99-
const auto &sigma_val_vec = as_column_vector_or_scalar(sigma_val);
100+
const auto &sigma_val_vec = to_ref(as_column_vector_or_scalar(sigma_val));
100101

101102
T_scale_val inv_sigma = 1 / as_array_or_scalar(sigma_val_vec);
102103
Matrix<T_partials_return, Dynamic, 1> y_minus_mu_over_sigma_mat(N);

stan/math/opencl/prim/ordered_logistic_glm_lpmf.hpp

-3
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,6 @@ return_type_t<T_beta_scalar, T_cuts_scalar> ordered_logistic_glm_lpmf(
8383
const auto& beta_val = value_of_rec(beta);
8484
const auto& cuts_val = value_of_rec(cuts);
8585

86-
const auto& beta_val_vec = as_column_vector_or_scalar(beta_val);
87-
const auto& cuts_val_vec = as_column_vector_or_scalar(cuts_val);
88-
8986
operands_and_partials<Eigen::Matrix<T_beta_scalar, Eigen::Dynamic, 1>,
9087
Eigen::Matrix<T_cuts_scalar, Eigen::Dynamic, 1>>
9188
ops_partials(beta, cuts);

stan/math/prim/fun.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@
306306
#include <stan/math/prim/fun/to_array_1d.hpp>
307307
#include <stan/math/prim/fun/to_array_2d.hpp>
308308
#include <stan/math/prim/fun/to_matrix.hpp>
309+
#include <stan/math/prim/fun/to_ref.hpp>
309310
#include <stan/math/prim/fun/to_row_vector.hpp>
310311
#include <stan/math/prim/fun/to_vector.hpp>
311312
#include <stan/math/prim/fun/trace.hpp>

stan/math/prim/fun/log_mix.hpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <stan/math/prim/fun/log1m.hpp>
99
#include <stan/math/prim/fun/log_sum_exp.hpp>
1010
#include <stan/math/prim/fun/size.hpp>
11+
#include <stan/math/prim/fun/to_ref.hpp>
1112
#include <stan/math/prim/fun/value_of.hpp>
1213
#include <vector>
1314
#include <cmath>
@@ -86,8 +87,8 @@ return_type_t<T_theta, T_lam> log_mix(const T_theta& theta,
8687
check_finite(function, "theta", theta);
8788
check_consistent_sizes(function, "theta", theta, "lambda", lambda);
8889

89-
const auto& theta_dbl = value_of(as_column_vector_or_scalar(theta));
90-
const auto& lam_dbl = value_of(as_column_vector_or_scalar(lambda));
90+
const auto& theta_dbl = to_ref(value_of(as_column_vector_or_scalar(theta)));
91+
const auto& lam_dbl = to_ref(value_of(as_column_vector_or_scalar(lambda)));
9192

9293
T_partials_return logp = log_sum_exp(log(theta_dbl) + lam_dbl);
9394

@@ -158,7 +159,7 @@ return_type_t<T_theta, std::vector<T_lam>> log_mix(
158159
check_consistent_sizes(function, "theta", theta, "lambda", lambda[n]);
159160
}
160161

161-
const auto& theta_dbl = value_of(as_column_vector_or_scalar(theta));
162+
const auto& theta_dbl = to_ref(value_of(as_column_vector_or_scalar(theta)));
162163

163164
T_partials_mat lam_dbl(M, N);
164165
for (int n = 0; n < N; ++n) {

stan/math/prim/fun/to_ref.hpp

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#ifndef STAN_MATH_PRIM_FUN_TO_REF_HPP
2+
#define STAN_MATH_PRIM_FUN_TO_REF_HPP
3+
4+
#include <stan/math/prim/meta.hpp>
5+
6+
namespace stan {
7+
namespace math {
8+
9+
/**
10+
* No-op that should be optimized away.
11+
* @tparam T non-Eigen argument type
12+
* @param a argument
13+
* @return argument
14+
*/
15+
template <typename T, require_not_eigen_t<T>* = nullptr>
16+
inline T to_ref(T&& a) {
17+
return std::forward<T>(a);
18+
}
19+
20+
/**
21+
* Converts Eigen argument into `Eigen::Ref`. This evaluate expensive
22+
* expressions.
23+
* @tparam T argument type (Eigen expression)
24+
* @param a argument
25+
* @return argument converted to `Eigen::Ref`
26+
*/
27+
template <typename T, require_eigen_t<T>* = nullptr>
28+
inline Eigen::Ref<const plain_type_t<T>> to_ref(T&& a) {
29+
return std::forward<T>(a);
30+
}
31+
32+
} // namespace math
33+
} // namespace stan
34+
#endif

stan/math/prim/fun/value_of.hpp

+4-12
Original file line numberDiff line numberDiff line change
@@ -99,22 +99,15 @@ inline Vec value_of(Vec&& x) {
9999
* T must implement value_of. See
100100
* test/math/fwd/fun/value_of.cpp for fvar and var usage.
101101
*
102-
* @tparam T type of elements in the matrix
103-
* @tparam R number of rows in the matrix, can be Eigen::Dynamic
104-
* @tparam C number of columns in the matrix, can be Eigen::Dynamic
102+
* @tparam EigMat type of the matrix
105103
*
106104
* @param[in] M Matrix to be converted
107105
* @return Matrix of values
108106
**/
109107
template <typename EigMat, require_eigen_t<EigMat>* = nullptr,
110108
require_not_vt_double_or_int<EigMat>* = nullptr>
111-
inline Eigen::Matrix<typename child_type<value_type_t<EigMat>>::type,
112-
EigMat::RowsAtCompileTime, EigMat::ColsAtCompileTime>
113-
value_of(const EigMat& M) {
114-
return M.array()
115-
.unaryExpr([](const auto& scal) { return value_of(scal); })
116-
.matrix()
117-
.eval();
109+
inline auto value_of(const EigMat& M) {
110+
return M.unaryExpr([](const auto& scal) { return value_of(scal); });
118111
}
119112

120113
/**
@@ -125,8 +118,7 @@ value_of(const EigMat& M) {
125118
*
126119
* <p>This inline pass-through no-op should be compiled away.
127120
*
128-
* @tparam R number of rows in the matrix, can be Eigen::Dynamic
129-
* @tparam C number of columns in the matrix, can be Eigen::Dynamic
121+
* @tparam EigMat type of the matrix
130122
*
131123
* @param x Specified matrix.
132124
* @return Specified matrix.

stan/math/prim/fun/value_of_rec.hpp

+7-5
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ inline std::complex<double> value_of_rec(const std::complex<T>& x) {
6565
* @param[in] x std::vector to be converted
6666
* @return std::vector of values
6767
**/
68-
template <typename T>
68+
template <typename T, require_not_same_t<double, T>* = nullptr>
6969
inline std::vector<double> value_of_rec(const std::vector<T>& x) {
7070
size_t x_size = x.size();
7171
std::vector<double> result(x_size);
@@ -86,8 +86,10 @@ inline std::vector<double> value_of_rec(const std::vector<T>& x) {
8686
* @param x Specified std::vector.
8787
* @return Specified std::vector.
8888
*/
89-
inline const std::vector<double>& value_of_rec(const std::vector<double>& x) {
90-
return x;
89+
template <typename T, require_std_vector_t<T>* = nullptr,
90+
require_vt_same<double, T>* = nullptr>
91+
inline T value_of_rec(T&& x) {
92+
return std::forward<T>(x);
9193
}
9294

9395
/**
@@ -120,8 +122,8 @@ inline auto value_of_rec(const T& M) {
120122
*/
121123
template <typename T, typename = require_st_same<T, double>,
122124
typename = require_eigen_t<T>>
123-
inline const T& value_of_rec(const T& x) {
124-
return x;
125+
inline T value_of_rec(T&& x) {
126+
return std::forward<T>(x);
125127
}
126128
} // namespace math
127129
} // namespace stan

stan/math/prim/prob/bernoulli_logit_glm_lpmf.hpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <stan/math/prim/fun/exp.hpp>
99
#include <stan/math/prim/fun/size.hpp>
1010
#include <stan/math/prim/fun/size_zero.hpp>
11+
#include <stan/math/prim/fun/to_ref.hpp>
1112
#include <stan/math/prim/fun/value_of_rec.hpp>
1213
#include <cmath>
1314

@@ -82,13 +83,13 @@ return_type_t<T_x_scalar, T_alpha, T_beta> bernoulli_logit_glm_lpmf(
8283
}
8384

8485
T_partials_return logp(0);
85-
const auto &x_val = value_of_rec(x);
86+
const auto &x_val = to_ref(value_of_rec(x));
8687
const auto &y_val = value_of_rec(y);
8788
const auto &beta_val = value_of_rec(beta);
8889
const auto &alpha_val = value_of_rec(alpha);
8990

9091
const auto &y_val_vec = as_column_vector_or_scalar(y_val);
91-
const auto &beta_val_vec = as_column_vector_or_scalar(beta_val);
92+
const auto &beta_val_vec = to_ref(as_column_vector_or_scalar(beta_val));
9293
const auto &alpha_val_vec = as_column_vector_or_scalar(alpha_val);
9394

9495
T_y_val signs = 2 * as_array_or_scalar(y_val_vec) - 1;

stan/math/prim/prob/categorical_logit_glm_lpmf.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include <stan/math/prim/fun/log.hpp>
88
#include <stan/math/prim/fun/size.hpp>
99
#include <stan/math/prim/fun/size_zero.hpp>
10+
#include <stan/math/prim/fun/to_ref.hpp>
11+
#include <stan/math/prim/fun/value_of_rec.hpp>
1012
#include <Eigen/Core>
1113
#include <cmath>
1214

@@ -73,8 +75,8 @@ categorical_logit_glm_lpmf(
7375
return 0;
7476
}
7577

76-
const auto& x_val = value_of_rec(x);
77-
const auto& beta_val = value_of_rec(beta);
78+
const auto& x_val = to_ref(value_of_rec(x));
79+
const auto& beta_val = to_ref(value_of_rec(beta));
7880
const auto& alpha_val = value_of_rec(alpha);
7981

8082
const auto& alpha_val_vec = as_column_vector_or_scalar(alpha_val).transpose();

stan/math/prim/prob/hmm_marginal_lpdf.hpp

+14-10
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,25 @@
77
#include <stan/math/prim/fun/col.hpp>
88
#include <stan/math/prim/fun/transpose.hpp>
99
#include <stan/math/prim/fun/exp.hpp>
10+
#include <stan/math/prim/fun/to_ref.hpp>
1011
#include <stan/math/prim/fun/value_of.hpp>
1112
#include <stan/math/prim/core.hpp>
1213
#include <vector>
1314

1415
namespace stan {
1516
namespace math {
1617

17-
template <typename T_omega, typename T_Gamma, typename T_rho, typename T_alpha>
18-
inline auto hmm_marginal_lpdf_val(
19-
const Eigen::Matrix<T_omega, Eigen::Dynamic, Eigen::Dynamic>& omegas,
20-
const Eigen::Matrix<T_Gamma, Eigen::Dynamic, Eigen::Dynamic>& Gamma_val,
21-
const Eigen::Matrix<T_rho, Eigen::Dynamic, 1>& rho_val,
22-
Eigen::Matrix<T_alpha, Eigen::Dynamic, Eigen::Dynamic>& alphas,
23-
Eigen::Matrix<T_alpha, Eigen::Dynamic, 1>& alpha_log_norms,
24-
T_alpha& norm_norm) {
18+
template <typename T_omega, typename T_Gamma, typename T_rho, typename T_alphas,
19+
typename T_alpha_log_norm, typename T_norm,
20+
require_all_eigen_matrix_t<T_omega, T_Gamma, T_alphas>* = nullptr,
21+
require_all_eigen_col_vector_t<T_rho, T_alpha_log_norm>* = nullptr,
22+
require_stan_scalar_t<T_norm>* = nullptr,
23+
require_all_vt_same<T_alphas, T_alpha_log_norm, T_norm>* = nullptr>
24+
inline auto hmm_marginal_lpdf_val(const T_omega& omegas,
25+
const T_Gamma& Gamma_val,
26+
const T_rho& rho_val, T_alphas& alphas,
27+
T_alpha_log_norm& alpha_log_norms,
28+
T_norm& norm_norm) {
2529
const int n_states = omegas.rows();
2630
const int n_transitions = omegas.cols() - 1;
2731
alphas.col(0) = omegas.col(0).cwiseProduct(rho_val);
@@ -100,10 +104,10 @@ inline auto hmm_marginal_lpdf(
100104

101105
eig_matrix_partial alphas(n_states, n_transitions + 1);
102106
eig_vector_partial alpha_log_norms(n_transitions + 1);
103-
auto Gamma_val = value_of(Gamma);
107+
const auto& Gamma_val = to_ref(value_of(Gamma));
104108

105109
// compute the density using the forward algorithm.
106-
auto rho_val = value_of(rho);
110+
const auto& rho_val = to_ref(value_of(rho));
107111
eig_matrix_partial omegas = value_of(log_omegas).array().exp();
108112
T_partial_type norm_norm;
109113
auto log_marginal_density = hmm_marginal_lpdf_val(

stan/math/prim/prob/neg_binomial_2_log_glm_lpmf.hpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <stan/math/prim/fun/multiply_log.hpp>
1212
#include <stan/math/prim/fun/size.hpp>
1313
#include <stan/math/prim/fun/sum.hpp>
14+
#include <stan/math/prim/fun/to_ref.hpp>
1415
#include <stan/math/prim/fun/value_of_rec.hpp>
1516
#include <vector>
1617
#include <cmath>
@@ -105,16 +106,16 @@ neg_binomial_2_log_glm_lpmf(
105106
}
106107

107108
T_partials_return logp(0);
108-
const auto& x_val = value_of_rec(x);
109+
const auto& x_val = to_ref(value_of_rec(x));
109110
const auto& y_val = value_of_rec(y);
110111
const auto& beta_val = value_of_rec(beta);
111112
const auto& alpha_val = value_of_rec(alpha);
112113
const auto& phi_val = value_of_rec(phi);
113114

114-
const auto& y_val_vec = as_column_vector_or_scalar(y_val);
115-
const auto& beta_val_vec = as_column_vector_or_scalar(beta_val);
115+
const auto& y_val_vec = to_ref(as_column_vector_or_scalar(y_val));
116+
const auto& beta_val_vec = to_ref(as_column_vector_or_scalar(beta_val));
116117
const auto& alpha_val_vec = as_column_vector_or_scalar(alpha_val);
117-
const auto& phi_val_vec = as_column_vector_or_scalar(phi_val);
118+
const auto& phi_val_vec = to_ref(as_column_vector_or_scalar(phi_val));
118119

119120
const auto& y_arr = as_array_or_scalar(y_val_vec);
120121
const auto& phi_arr = as_array_or_scalar(phi_val_vec);
@@ -147,7 +148,7 @@ neg_binomial_2_log_glm_lpmf(
147148
}
148149
if (include_summand<propto, T_precision>::value) {
149150
if (is_vector<T_precision>::value) {
150-
scalar_seq_view<decltype(phi_val)> phi_vec(phi_val);
151+
scalar_seq_view<decltype(phi_val_vec)> phi_vec(phi_val_vec);
151152
for (size_t n = 0; n < N_instances; ++n) {
152153
logp += multiply_log(phi_vec[n], phi_vec[n]) - lgamma(phi_vec[n]);
153154
}

stan/math/prim/prob/normal_id_glm_lpdf.hpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <stan/math/prim/fun/size.hpp>
99
#include <stan/math/prim/fun/size_zero.hpp>
1010
#include <stan/math/prim/fun/sum.hpp>
11+
#include <stan/math/prim/fun/to_ref.hpp>
1112
#include <stan/math/prim/fun/value_of_rec.hpp>
1213
#include <cmath>
1314

@@ -88,15 +89,15 @@ return_type_t<T_y, T_x_scalar, T_alpha, T_beta, T_scale> normal_id_glm_lpdf(
8889
return 0;
8990
}
9091

91-
const auto &x_val = value_of_rec(x);
92+
const auto &x_val = to_ref(value_of_rec(x));
9293
const auto &beta_val = value_of_rec(beta);
9394
const auto &alpha_val = value_of_rec(alpha);
9495
const auto &sigma_val = value_of_rec(sigma);
9596
const auto &y_val = value_of_rec(y);
9697

97-
const auto &beta_val_vec = as_column_vector_or_scalar(beta_val);
98+
const auto &beta_val_vec = to_ref(as_column_vector_or_scalar(beta_val));
9899
const auto &alpha_val_vec = as_column_vector_or_scalar(alpha_val);
99-
const auto &sigma_val_vec = as_column_vector_or_scalar(sigma_val);
100+
const auto &sigma_val_vec = to_ref(as_column_vector_or_scalar(sigma_val));
100101
const auto &y_val_vec = as_column_vector_or_scalar(y_val);
101102

102103
T_scale_val inv_sigma = 1 / as_array_or_scalar(sigma_val_vec);

stan/math/prim/prob/ordered_logistic_glm_lpmf.hpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <stan/math/prim/fun/log1m_exp.hpp>
88
#include <stan/math/prim/fun/size.hpp>
99
#include <stan/math/prim/fun/size_zero.hpp>
10+
#include <stan/math/prim/fun/to_ref.hpp>
1011
#include <stan/math/prim/fun/value_of_rec.hpp>
1112
#include <cmath>
1213

@@ -81,12 +82,12 @@ ordered_logistic_glm_lpmf(
8182
if (!include_summand<propto, T_x_scalar, T_beta_scalar, T_cuts_scalar>::value)
8283
return 0;
8384

84-
const auto& x_val = value_of_rec(x);
85+
const auto& x_val = to_ref(value_of_rec(x));
8586
const auto& beta_val = value_of_rec(beta);
8687
const auto& cuts_val = value_of_rec(cuts);
8788

88-
const auto& beta_val_vec = as_column_vector_or_scalar(beta_val);
89-
const auto& cuts_val_vec = as_column_vector_or_scalar(cuts_val);
89+
const auto& beta_val_vec = to_ref(as_column_vector_or_scalar(beta_val));
90+
const auto& cuts_val_vec = to_ref(as_column_vector_or_scalar(cuts_val));
9091

9192
scalar_seq_view<T_y> y_seq(y);
9293
Array<double, Dynamic, 1> cuts_y1(N_instances), cuts_y2(N_instances);

stan/math/prim/prob/poisson_log_glm_lpmf.hpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <stan/math/prim/fun/lgamma.hpp>
99
#include <stan/math/prim/fun/size.hpp>
1010
#include <stan/math/prim/fun/size_zero.hpp>
11+
#include <stan/math/prim/fun/to_ref.hpp>
1112
#include <stan/math/prim/fun/value_of_rec.hpp>
1213
#include <cmath>
1314

@@ -82,13 +83,13 @@ return_type_t<T_x_scalar, T_alpha, T_beta> poisson_log_glm_lpmf(
8283

8384
T_partials_return logp(0);
8485

85-
const auto& x_val = value_of_rec(x);
86+
const auto& x_val = to_ref(value_of_rec(x));
8687
const auto& y_val = value_of_rec(y);
8788
const auto& beta_val = value_of_rec(beta);
8889
const auto& alpha_val = value_of_rec(alpha);
8990

90-
const auto& y_val_vec = as_column_vector_or_scalar(y_val);
91-
const auto& beta_val_vec = as_column_vector_or_scalar(beta_val);
91+
const auto& y_val_vec = to_ref(as_column_vector_or_scalar(y_val));
92+
const auto& beta_val_vec = to_ref(as_column_vector_or_scalar(beta_val));
9293
const auto& alpha_val_vec = as_column_vector_or_scalar(alpha_val);
9394

9495
Array<T_partials_return, Dynamic, 1> theta(N_instances);

0 commit comments

Comments
 (0)