forked from rsagroup/rsaModelComparison
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrsa_marglRidgeEB.m
More file actions
47 lines (43 loc) · 2.02 KB
/
rsa_marglRidgeEB.m
File metadata and controls
47 lines (43 loc) · 2.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
function [nlmlSum,dnlmlSum,wN,VN,nlml] = rsa_marglRidgeEB(logtheta, X, Y, Sigma)
% function [nlmlSum,dnlmlSum,wN,VN,nlml] = rsa_marglRidgeEB(logtheta, X, Y, Sigma);
% Calculates the negative log marginal likelihood of a ridge-regression
% model with a shared prior variance, and its derivative with respect to
% the log of that prior variance, summed over the different data sets
% (subjects).
%
% For subject s the data column Y(:,s) is scored under a zero-mean
% Gaussian with covariance
%   S_s = Sigma_s + X_s*V0*X_s',   V0 = exp(logtheta)*eye(numReg)
%
% INPUT:
%   logtheta : scalar, log of the prior variance of the regression
%              coefficients (V0 = exp(logtheta)*eye(numReg))
%   X        : N x numReg design matrix shared by all subjects, or
%              N x numReg x numSubj with one design per subject
%   Y        : N x numSubj data, one column per subject
%   Sigma    : N x N x numSubj noise covariance per subject, or a single
%              N x N matrix shared by all subjects
%
% OUTPUT:
%   nlmlSum  : sum over subjects of the negative log marginal likelihood
%   dnlmlSum : scalar derivative of nlmlSum with respect to logtheta
%   wN       : numReg x numSubj posterior mean of the regression coefficients
%   VN       : numReg x numReg x numSubj posterior covariance of the
%              regression coefficients
%   nlml     : 1 x numSubj negative log marginal likelihood per subject
%
% chol() raises an error if some S_s is not positive definite.
%
% Joern Diedrichsen
[N, numReg, depthX] = size(X);
[numObs, numSubj] = size(Y);
if (numObs ~= N)
    error('X and Y must have the same number of rows');
end;

% Slice of X used by each subject: one shared design, or one per subject.
% Indexing avoids replicating a shared X numSubj times in memory.
if (depthX == 1)
    xIdx = ones(1, numSubj);
elseif (depthX == numSubj)
    xIdx = 1:numSubj;
else
    error('X must needs to be a matrix or have a size in the 3rd dimension of numSubj');
end;

% Slice of Sigma used by each subject: a single N x N matrix is shared,
% otherwise subject s uses Sigma(:,:,s) exactly as before.
if (size(Sigma,3) == 1)
    sIdx = ones(1, numSubj);
else
    sIdx = 1:numSubj;
end;

V0 = eye(numReg)*exp(logtheta);   % Prior covariance of the regression coefficients

% Preallocate outputs (also makes numSubj==0 return empty/zero results)
nlml = zeros(1, numSubj);
if (nargout>1)
    dnlml = zeros(1, numSubj);
end;
if (nargout>2)
    wN = zeros(numReg, numSubj);
end;
if (nargout>3)
    VN = zeros(numReg, numReg, numSubj);
end;

for s = 1:numSubj
    Xs    = X(:,:,xIdx(s));
    y     = Y(:,s);
    XVX   = Xs*V0*Xs';              % Prior part of the covariance; equals dS/dlogtheta
    S     = Sigma(:,:,sIdx(s)) + XVX; % Marginal covariance of y
    L     = chol(S)';               % Lower Cholesky factor, S = L*L'
    alpha = L'\(L\y);               % alpha = inv(S)*y via two triangular solves
    % 0.5*y'*inv(S)*y + 0.5*log(det(S)) + 0.5*N*log(2*pi)
    nlml(s) = 0.5*(y'*alpha) + sum(log(diag(L))) + 0.5*N*log(2*pi);
    if (nargout>1)
        invS = L'\(L\eye(N));
        W = alpha*alpha' - invS;    % (alpha*alpha' - inv(S))
        % d nlml / d logtheta = -1/2 * trace(W * dS/dlogtheta); W and XVX are symmetric
        dnlml(s) = -1/2*sum(sum(W.*XVX));
    end;
    if (nargout>2)
        wN(:,s) = V0*Xs'*alpha;     % Posterior mean of the regression coefficients
    end;
    if (nargout>3)
        % invS is always defined here because nargout>3 implies nargout>1
        VNs = V0 - V0*Xs'*invS*Xs*V0;   % Posterior covariance
        VN(:,:,s) = (VNs + VNs')/2;     % Enforce symmetry lost to round-off
    end;
end;

% Sum marginal likelihoods (and derivatives) over subjects for optimization
nlmlSum = sum(nlml);
if (nargout>1)
    dnlmlSum = sum(dnlml);
end;