Skip to content

Commit 2895f1e

Browse files
committed
[BUG] Fix StudentT import/param, log_marginal_likelihood; add comprehensive tests for BayesianConjugateGLMRegressor
1 parent 6131211 commit 2895f1e

File tree

2 files changed

+455
-12
lines changed

2 files changed

+455
-12
lines changed

skpro/regression/bayesian/_glm_conjugate.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -346,17 +346,18 @@ def _predict_proba(self, X):
346346
# Student-t predictive if noise_prior_shape/rate are set
347347
if self.noise_prior_shape is not None and self.noise_prior_rate is not None:
348348
nu = 2 * self._noise_posterior_shape
349-
scale = (
349+
# predictive scale: sqrt(bN/aN * (1 + x^T Sigma_N x))
350+
pred_scale = np.sqrt(
350351
self._noise_posterior_rate
351352
/ self._noise_posterior_shape
352353
* (1 + pred_var_all_x_i)
353354
)
354-
from skpro.distributions.student_t import StudentT
355+
from skpro.distributions.t import TDistribution
355356

356357
mus = pred_mu.reshape(-1, 1).tolist()
357-
scales = scale.reshape(-1, 1).tolist()
358-
return StudentT(
359-
mu=mus, scale=scales, df=nu, columns=self._y_cols, index=idx
358+
sigmas = pred_scale.reshape(-1, 1).tolist()
359+
return TDistribution(
360+
mu=mus, sigma=sigmas, df=nu, columns=self._y_cols, index=idx
360361
)
361362
else:
362363
pred_sigma = np.sqrt(pred_var_all_x_i + 1 / self.noise_precision)
@@ -379,18 +380,27 @@ def log_marginal_likelihood(self, X, y):
379380
float
380381
Log marginal likelihood (evidence).
381382
"""
382-
# Convert to numpy arrays
383-
if isinstance(X, (np.ndarray, np.generic)):
384-
X_arr = X
383+
import pandas as pd
384+
385+
# Apply the same intercept logic used in _fit / _predict_proba
386+
if isinstance(X, pd.DataFrame):
387+
X_df = X.copy()
388+
if self.add_constant:
389+
X_df = self._add_intercept(X_df)
390+
X_arr = X_df.to_numpy(dtype=float)
385391
else:
386-
X_arr = X.to_numpy(dtype=float)
392+
X_arr = np.array(X, dtype=float)
393+
if self.add_constant:
394+
X_arr = np.column_stack([np.ones(X_arr.shape[0]), X_arr])
395+
387396
if isinstance(y, (np.ndarray, np.generic)):
388397
y_arr = y
389398
else:
390399
y_arr = y.to_numpy(dtype=float)
400+
391401
N = X_arr.shape[0]
392-
S0 = self.coefs_prior_cov
393-
m0 = self.coefs_prior_mu
402+
S0 = self._coefs_prior_cov
403+
m0 = self._coefs_prior_mu
394404
tau = self.noise_precision
395405
SN_inv = np.linalg.inv(S0) + tau * (X_arr.T @ X_arr)
396406
SN = np.linalg.inv(SN_inv)
@@ -399,7 +409,7 @@ def log_marginal_likelihood(self, X, y):
399409
term1 = -0.5 * N * np.log(2 * np.pi)
400410
term2 = 0.5 * np.log(np.linalg.det(SN) / np.linalg.det(S0))
401411
term3 = -0.5 * tau * np.sum((y_arr - X_arr @ mN) ** 2)
402-
term4 = -0.5 * (mN - m0).T @ np.linalg.inv(S0) @ (mN - m0)
412+
term4 = -0.5 * ((mN - m0).T @ np.linalg.inv(S0) @ (mN - m0)).item()
403413
log_ml = term1 + term2 + term3 + term4
404414
return float(log_ml)
405415

0 commit comments

Comments
 (0)