
Commit 84869a0

Author: Pavan Ramkumar
Merge pull request #224 from jasmainak/logexp
Bring back logexp trick
2 parents 09e1cba + 50f6cc6 · commit 84869a0

1 file changed: 13 additions, 10 deletions


pyglmnet/pyglmnet.py

Lines changed: 13 additions & 10 deletions
@@ -54,20 +54,21 @@ def _grad_mu(distr, z, eta):
     return grad_mu


-def _logL(distr, y, y_hat):
+def _logL(distr, y, y_hat, z=None):
     """The log likelihood."""
     if distr in ['softplus', 'poisson']:
         eps = np.spacing(1)
         logL = np.sum(y * np.log(y_hat + eps) - y_hat)
     elif distr == 'gaussian':
         logL = -0.5 * np.sum((y - y_hat)**2)
     elif distr == 'binomial':
-        # analytical formula
-        logL = np.sum(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))

-        # but this prevents underflow
-        # z = beta0 + np.dot(X, beta)
-        # logL = np.sum(y * z - np.log(1 + np.exp(z)))
+        # prevents underflow
+        if z is not None:
+            logL = np.sum(y * z - np.log(1 + np.exp(z)))
+        # for scoring
+        else:
+            logL = np.sum(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))
     elif distr == 'probit':
         logL = np.sum(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))
     elif distr == 'gamma':
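Why the z-based form matters: for the binomial distribution, y_hat = 1 / (1 + exp(-z)), so log(y_hat) = z - log(1 + exp(z)) and log(1 - y_hat) = -log(1 + exp(z)); the restored branch is therefore analytically identical to the y_hat-based one. Numerically it is not: at moderately large |z| the sigmoid rounds to exactly 0 or 1 in float64 and the y_hat-based formula hits log(0). A minimal standalone NumPy sketch (not part of the commit) of the failure mode:

import numpy as np

z = np.array([40., -40.])        # moderate values, yet the sigmoid saturates
y = np.array([0., 1.])           # labels falling in the saturated tails
y_hat = 1. / (1 + np.exp(-z))    # first entry rounds to exactly 1.0 in float64

# y_hat-based formula: log(1 - 1.0) -> -inf (divide-by-zero RuntimeWarning)
naive = np.sum(y * np.log(y_hat) + (1 - y) * np.log(1 - y_hat))

# logexp trick: finite, since log(1 + exp(z)) is well-behaved at |z| = 40
stable = np.sum(y * z - np.log(1 + np.exp(z)))

print(naive)   # -inf
print(stable)  # -80.0

Note that np.log(1 + np.exp(z)) itself overflows once z exceeds roughly 700; np.logaddexp(0, z) computes the same quantity safely for all z, should that ever become a concern.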
@@ -123,8 +124,9 @@ def _L1penalty(beta, group=None):
 def _loss(distr, alpha, Tau, reg_lambda, X, y, eta, group, beta):
     """Define the objective function for elastic net."""
     n_samples = X.shape[0]
-    y_hat = _mu(distr, beta[0] + np.dot(X, beta[1:]), eta)
-    L = 1. / n_samples * _logL(distr, y, y_hat)
+    z = beta[0] + np.dot(X, beta[1:])
+    y_hat = _mu(distr, z, eta)
+    L = 1. / n_samples * _logL(distr, y, y_hat, z)
     P = _penalty(alpha, beta[1:], Tau, group)
     J = -L + reg_lambda * P
     return J
@@ -133,8 +135,9 @@ def _loss(distr, alpha, Tau, reg_lambda, X, y, eta, group, beta):
 def _L2loss(distr, alpha, Tau, reg_lambda, X, y, eta, group, beta):
     """Define the objective function for elastic net."""
     n_samples = X.shape[0]
-    y_hat = _mu(distr, beta[0] + np.dot(X, beta[1:]), eta)
-    L = 1. / n_samples * _logL(distr, y, y_hat)
+    z = beta[0] + np.dot(X, beta[1:])
+    y_hat = _mu(distr, z, eta)
+    L = 1. / n_samples * _logL(distr, y, y_hat, z)
     P = 0.5 * (1 - alpha) * _L2penalty(beta[1:], Tau)
     J = -L + reg_lambda * P
     return J
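Both loss functions change in the same way: the linear predictor z is computed once, reused for the mean through _mu, and threaded into _logL so the binomial branch can take the stable path. A self-contained sketch of the resulting flow for a binomial GLM with a plain ridge penalty (hypothetical helper name; the real _mu, _penalty, and _L2penalty live elsewhere in pyglmnet.py, and y_hat is omitted here since the stable binomial logL needs only z):

import numpy as np

def binomial_l2_loss(beta, X, y, reg_lambda):
    """Sketch: compute z once and thread it through, as in the diff above."""
    n_samples = X.shape[0]
    z = beta[0] + np.dot(X, beta[1:])   # linear predictor, computed once
    L = 1. / n_samples * np.sum(y * z - np.log(1 + np.exp(z)))  # stable logL
    P = 0.5 * np.sum(beta[1:] ** 2)     # plain ridge penalty (no Tau, no groups)
    return -L + reg_lambda * P

X = np.random.randn(100, 5)
y = (np.random.rand(100) < 0.5).astype(float)
beta = np.zeros(6)                      # beta[0] is the intercept
print(binomial_l2_loss(beta, X, y, reg_lambda=0.1))  # log(2) ~= 0.6931 at beta = 0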
