forked from stan-dev/math
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbernoulli_logit_glm_lpmf.hpp
174 lines (159 loc) · 6.81 KB
/
bernoulli_logit_glm_lpmf.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
#ifndef STAN_MATH_PRIM_PROB_BERNOULLI_LOGIT_GLM_LPMF_HPP
#define STAN_MATH_PRIM_PROB_BERNOULLI_LOGIT_GLM_LPMF_HPP
#include <stan/math/prim/meta.hpp>
#include <stan/math/prim/err.hpp>
#include <stan/math/prim/fun/constants.hpp>
#include <stan/math/prim/fun/Eigen.hpp>
#include <stan/math/prim/fun/exp.hpp>
#include <stan/math/prim/fun/size.hpp>
#include <stan/math/prim/fun/size_zero.hpp>
#include <stan/math/prim/fun/to_ref.hpp>
#include <stan/math/prim/fun/value_of_rec.hpp>
#include <cmath>
namespace stan {
namespace math {
/** \ingroup multivar_dists
* Returns the log PMF of the Generalized Linear Model (GLM)
* with Bernoulli distribution and logit link function.
* The idea is that bernoulli_logit_glm_lpmf(y, x, alpha, beta) should
* compute a more efficient version of bernoulli_logit_lpmf(y, alpha + x * beta)
* by using analytically simplified gradients.
* If containers are supplied, returns the log sum of the probabilities.
*
* @tparam T_y type of binary vector of dependent variables (labels);
* this can also be a single binary value;
* @tparam T_x_scalar type of a scalar in the matrix of independent variables
* (features)
* @tparam T_x_rows compile-time number of rows of `x`. It can be either
* `Eigen::Dynamic` or 1.
* @tparam T_alpha type of the intercept(s);
* this can be a vector (of the same length as y) of intercepts or a single
* value (for models with constant intercept);
* @tparam T_beta type of the weight vector
*
* @param y binary scalar or vector parameter. If it is a scalar it will be
* broadcast - used for all instances.
* @param x design matrix or row vector. If it is a row vector it will be
* broadcast - used for all instances.
* @param alpha intercept (in log odds)
* @param beta weight vector
* @return log probability or log sum of probabilities
* @throw std::domain_error if x, beta or alpha is infinite.
* @throw std::domain_error if y is not binary.
* @throw std::invalid_argument if container sizes mismatch.
*/
template <bool propto, typename T_y, typename T_x_scalar, int T_x_rows,
typename T_alpha, typename T_beta>
return_type_t<T_x_scalar, T_alpha, T_beta> bernoulli_logit_glm_lpmf(
const T_y &y, const Eigen::Matrix<T_x_scalar, T_x_rows, Eigen::Dynamic> &x,
const T_alpha &alpha, const T_beta &beta) {
using Eigen::Array;
using Eigen::Dynamic;
using Eigen::log1p;
using Eigen::Matrix;
using std::exp;
// Scalar type used for all value/derivative computations (autodiff info
// is stripped below and the analytic gradients are written back at the end).
using T_partials_return = partials_return_t<T_y, T_x_scalar, T_alpha, T_beta>;
// Column vector when y is a container, plain scalar otherwise.
using T_y_val =
typename std::conditional_t<is_vector<T_y>::value,
Eigen::Matrix<partials_return_t<T_y>, -1, 1>,
partials_return_t<T_y>>;
// When x is a single (broadcast) row, x * beta collapses to one scalar;
// otherwise it is one value per instance.
using T_ytheta_tmp =
typename std::conditional_t<T_x_rows == 1, T_partials_return,
Array<T_partials_return, Dynamic, 1>>;
// With a single broadcast row of x the instance count comes from y,
// otherwise from the rows of the design matrix.
const size_t N_instances = T_x_rows == 1 ? stan::math::size(y) : x.rows();
const size_t N_attributes = x.cols();
static const char *function = "bernoulli_logit_glm_lpmf";
// Size consistency checks plus the Bernoulli requirement y in {0, 1}.
check_consistent_size(function, "Vector of dependent variables", y,
N_instances);
check_consistent_size(function, "Weight vector", beta, N_attributes);
check_consistent_size(function, "Vector of intercepts", alpha, N_instances);
check_bounded(function, "Vector of dependent variables", y, 0, 1);
if (size_zero(y)) {
return 0;
}
// With propto = true and all-constant arguments every term drops out.
if (!include_summand<propto, T_x_scalar, T_alpha, T_beta>::value) {
return 0;
}
T_partials_return logp(0);
// Plain double views of the inputs; to_ref materializes x once so it can
// be reused for both the value and the gradient computations.
const auto &x_val = to_ref(value_of_rec(x));
const auto &y_val = value_of_rec(y);
const auto &beta_val = value_of_rec(beta);
const auto &alpha_val = value_of_rec(alpha);
const auto &y_val_vec = as_column_vector_or_scalar(y_val);
const auto &beta_val_vec = to_ref(as_column_vector_or_scalar(beta_val));
const auto &alpha_val_vec = as_column_vector_or_scalar(alpha_val);
// Map y in {0, 1} to signs in {-1, +1}; then
// ytheta = sign * (x * beta + alpha) and log P(y) = -log1p(exp(-ytheta)).
T_y_val signs = 2 * as_array_or_scalar(y_val_vec) - 1;
Array<T_partials_return, Dynamic, 1> ytheta(N_instances);
if (T_x_rows == 1) {
// Broadcast case: x * beta is a single scalar shared by all instances;
// forward_as selects the scalar branch of T_ytheta_tmp.
T_ytheta_tmp ytheta_tmp
= forward_as<T_ytheta_tmp>((x_val * beta_val_vec)(0, 0));
ytheta = as_array_or_scalar(signs)
* (ytheta_tmp + as_array_or_scalar(alpha_val_vec));
} else {
ytheta = (x_val * beta_val_vec).array();
ytheta = as_array_or_scalar(signs)
* (ytheta + as_array_or_scalar(alpha_val_vec));
}
// Compute the log-density and handle extreme values gracefully
// using Taylor approximations.
// And compute the derivatives wrt theta.
static const double cutoff = 20.0;
Eigen::Array<T_partials_return, Dynamic, 1> exp_m_ytheta = exp(-ytheta);
// For ytheta > cutoff: -log1p(exp(-ytheta)) ~= -exp(-ytheta);
// for ytheta < -cutoff: ~= ytheta. This avoids overflow in exp(-ytheta).
logp += sum(
(ytheta > cutoff)
.select(-exp_m_ytheta,
(ytheta < -cutoff).select(ytheta, -log1p(exp_m_ytheta))));
// A non-finite logp means some input was non-finite; the checks below
// re-scan the arguments to raise a domain_error naming the culprit.
if (!std::isfinite(logp)) {
check_finite(function, "Weight vector", beta);
check_finite(function, "Intercept", alpha);
check_finite(function, "Matrix of independent variables", ytheta);
}
operands_and_partials<Eigen::Matrix<T_x_scalar, T_x_rows, Eigen::Dynamic>,
T_alpha, T_beta>
ops_partials(x, alpha, beta);
// Compute the necessary derivatives.
if (!is_constant_all<T_beta, T_x_scalar, T_alpha>::value) {
// d logp / d theta = sign * exp(-ytheta) / (1 + exp(-ytheta)),
// with the same tail approximations as the density above.
Matrix<T_partials_return, Dynamic, 1> theta_derivative
= (ytheta > cutoff)
.select(-exp_m_ytheta,
(ytheta < -cutoff)
.select(as_array_or_scalar(signs),
as_array_or_scalar(signs) * exp_m_ytheta
/ (exp_m_ytheta + 1)));
if (!is_constant_all<T_beta>::value) {
// edge3_ = beta: gradient is x^T * theta_derivative (or the shared row
// of x scaled by the summed derivative in the broadcast case).
if (T_x_rows == 1) {
ops_partials.edge3_.partials_
= forward_as<Matrix<T_partials_return, 1, Dynamic>>(
theta_derivative.sum() * x_val);
} else {
ops_partials.edge3_.partials_ = x_val.transpose() * theta_derivative;
}
}
if (!is_constant_all<T_x_scalar>::value) {
// edge1_ = x: gradient is the outer product theta_derivative * beta^T.
if (T_x_rows == 1) {
ops_partials.edge1_.partials_
= forward_as<Array<T_partials_return, Dynamic, T_x_rows>>(
beta_val_vec * theta_derivative.sum());
} else {
ops_partials.edge1_.partials_
= (beta_val_vec * theta_derivative.transpose()).transpose();
}
}
if (!is_constant_all<T_alpha>::value) {
// edge2_ = alpha: per-instance derivative for vector intercepts, the
// sum over instances for a shared scalar intercept.
if (is_vector<T_alpha>::value) {
ops_partials.edge2_.partials_ = theta_derivative;
} else {
ops_partials.edge2_.partials_[0] = sum(theta_derivative);
}
}
}
return ops_partials.build(logp);
}
/** \ingroup multivar_dists
 * Convenience overload used when no proportionality template argument is
 * given: delegates to the `propto = false` instantiation, which keeps all
 * terms of the log PMF (no constants dropped).
 */
template <typename T_y, typename T_x, typename T_alpha, typename T_beta>
inline return_type_t<T_x, T_beta, T_alpha> bernoulli_logit_glm_lpmf(
const T_y &y, const T_x &x, const T_alpha &alpha, const T_beta &beta) {
return bernoulli_logit_glm_lpmf<false>(y, x, alpha, beta);
}
} // namespace math
} // namespace stan
#endif